diff --git a/py/objstr.c b/py/objstr.c index c7e4ebf53b..a51be74747 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -743,11 +743,29 @@ static mp_obj_t str_finder(size_t n_args, const mp_obj_t *args, int direction, b const mp_obj_type_t *self_type = mp_obj_get_type(args[0]); check_is_str_or_bytes(args[0]); - // check argument type - str_check_arg_type(self_type, args[1]); - GET_STR_DATA_LEN(args[0], haystack, haystack_len); - GET_STR_DATA_LEN(args[1], needle, needle_len); + + mp_int_t val; + byte needle_data; + const byte *needle; + size_t needle_len; + if (self_type != &mp_type_str && mp_obj_get_int_maybe(args[1], &val)) { + // Allow {bytes/bytearray}.{find,index}(int). + #if MICROPY_FULL_CHECKS + if (val < 0 || val > 255) { + mp_raise_ValueError(MP_ERROR_TEXT("bytes value out of range")); + } + #endif + needle_data = val; + needle = &needle_data; + needle_len = 1; + } else { + // check argument type + str_check_arg_type(self_type, args[1]); + GET_STR_DATA_LEN(args[1], needle_tmp, needle_len_tmp); + needle = needle_tmp; + needle_len = needle_len_tmp; + } const byte *start = haystack; const byte *end = haystack + haystack_len; diff --git a/tests/basics/bytearray_byte_operations.py b/tests/basics/bytearray_byte_operations.py index 48b08ab261..794fe8cafa 100644 --- a/tests/basics/bytearray_byte_operations.py +++ b/tests/basics/bytearray_byte_operations.py @@ -2,6 +2,7 @@ print(bytearray(b"hello world").find(b"ll")) print(bytearray(b"hello\x00world").rfind(b"l")) +print(bytearray(b"hello world").find(ord(b"l"))) print(bytearray(b"abc efg ").strip(b"g a")) print(bytearray(b" spacious ").lstrip()) @@ -14,10 +15,18 @@ print(bytearray(b"abcabc").rsplit(b"bc")) print(bytearray(b"asdfasdf").replace(b"a", b"b")) -print("00\x0000".index("0", 0)) -print("00\x0000".index("0", 3)) -print("00\x0000".rindex("0", 0)) -print("00\x0000".rindex("0", 3)) +print(b"00\x0000".index(b"0", 0)) +print(b"00\x0000".index(b"0", 3)) +print(b"00\x0000".index(ord("0"), 0)) +print(b"00\x0000".index(ord("0"), 3)) +print(b"00\x0000".index(0, 0)) +try: + print(b"00\x0000".index(0, 3)) +except ValueError: + print("ValueError") +print(b"00\x0000".index(b"0", 3)) +print(b"00\x0000".rindex(b"0", 0)) +print(b"00\x0000".rindex(b"0", 3)) print(bytearray(b"foobar").endswith(b"bar")) print(bytearray(b"1foo").startswith(b"foo", 1)) diff --git a/tests/basics/bytes_find.py b/tests/basics/bytes_find.py index 75ef9796cd..4a2707cc53 100644 --- a/tests/basics/bytes_find.py +++ b/tests/basics/bytes_find.py @@ -21,6 +21,34 @@ print(b"0000".find(b'-1', 3)) print(b"0000".find(b'1', 3)) print(b"0000".find(b'1', 4)) print(b"0000".find(b'1', 5)) +print(b"0000".find(ord(b'0'))) +print(b"0000".find(ord(b'0'), 0)) +print(b"0000".find(ord(b'0'), 1)) +print(b"0000".find(ord(b'0'), 2)) +print(b"0000".find(ord(b'0'), 3)) +print(b"0000".find(ord(b'0'), 4)) +print(b"0000".find(ord(b'0'), 5)) +print(b"0000".find(ord(b'x'), 3)) +print(b"0000".find(ord(b'1'), 3)) +print(b"0000".find(ord(b'1'), 4)) +print(b"0000".find(ord(b'1'), 5)) # Non-ascii values (make sure not treated as unicode-like) print(b"\x80abc".find(b"a", 1)) + +# Int-like conversion. +print(b"00\x0000".find(b'0', True)) + +# Out of bounds int. +try: + print(b"0000".find(b'0', -1)) +except ValueError: + print("ValueError") +try: + print(b"0000".find(b'0', 256)) +except ValueError: + print("ValueError") +try: + print(b"0000".find(b'0', 91273611)) +except ValueError: + print("ValueError") diff --git a/tests/basics/bytes_index.py b/tests/basics/bytes_index.py new file mode 100644 index 0000000000..08ed46e4a9 --- /dev/null +++ b/tests/basics/bytes_index.py @@ -0,0 +1,99 @@ +print(b"hello world".index(b"ll")) +print(b"hello world".index(b"ll", None)) +print(b"hello world".index(b"ll", 1)) +print(b"hello world".index(b"ll", 1, None)) +print(b"hello world".index(b"ll", None, None)) +print(b"hello world".index(b"ll", 1, -1)) +try: + print(b"hello world".index(b"ll", 1, 1)) +except ValueError: + print("ValueError") +try: + print(b"hello world".index(b"ll", 1, 2)) +except ValueError: + print("ValueError") +try: + print(b"hello world".index(b"ll", 1, 3)) +except ValueError: + print("ValueError") +print(b"hello world".index(b"ll", 1, 4)) +print(b"hello world".index(b"ll", 1, 5)) +print(b"hello world".index(b"ll", -100)) +print(b"0000".index(b'0')) +print(b"0000".index(b'0', 0)) +print(b"0000".index(b'0', 1)) +print(b"0000".index(b'0', 2)) +print(b"0000".index(b'0', 3)) +try: + print(b"0000".index(b'0', 4)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(b'0', 5)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(b'-1', 3)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(b'1', 3)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(b'1', 4)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(b'1', 5)) +except ValueError: + print("ValueError") +print(b"0000".index(ord(b'0'))) +print(b"0000".index(ord(b'0'), 0)) +print(b"0000".index(ord(b'0'), 1)) +print(b"0000".index(ord(b'0'), 2)) +print(b"0000".index(ord(b'0'), 3)) +try: + print(b"0000".index(ord(b'0'), 4)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(ord(b'0'), 5)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(ord(b'x'), 3)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(ord(b'1'), 3)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(ord(b'1'), 4)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(ord(b'1'), 5)) +except ValueError: + print("ValueError") + +# Non-ascii values (make sure not treated as unicode-like) +print(b"\x80abc".index(b"a", 1)) + +# Int-like conversion. +print(b"00\x0000".index(b'0', True)) + +# Out of bounds int. +try: + print(b"0000".index(b'0', -1)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(b'0', 256)) +except ValueError: + print("ValueError") +try: + print(b"0000".index(b'0', 91273611)) +except ValueError: + print("ValueError") diff --git a/tests/basics/string_find.py b/tests/basics/string_find.py index f9fcad3e57..f37064ad29 100644 --- a/tests/basics/string_find.py +++ b/tests/basics/string_find.py @@ -24,6 +24,7 @@ print("0000".find('1', 5)) print("aaaaaaaaaaa".find("bbb", 9, 2)) try: + # Only works on bytes/bytearray. 'abc'.find(1) except TypeError: print('TypeError') diff --git a/tests/basics/string_index.py b/tests/basics/string_index.py index 31f6900e6c..328f2dff54 100644 --- a/tests/basics/string_index.py +++ b/tests/basics/string_index.py @@ -76,3 +76,9 @@ except ValueError: print("Raised ValueError") else: print("Did not raise ValueError") + +try: + # Only works on bytes/bytearray. + 'abc'.index(1) +except TypeError: + print('TypeError')