diff --git a/py/objstr.c b/py/objstr.c index ea4f5ead24..be1f00e686 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -156,6 +156,43 @@ static bool chr_in_str(const char* const str, const size_t str_len, const char c return false; } +static mp_obj_t str_find(int n_args, const mp_obj_t *args) { + assert(2 <= n_args && n_args <= 4); + assert(MP_OBJ_IS_TYPE(args[0], &str_type)); + if (!MP_OBJ_IS_TYPE(args[1], &str_type)) { + nlr_jump(mp_obj_new_exception_msg_1_arg( + MP_QSTR_TypeError, + "Can't convert '%s' object to str implicitly", + mp_obj_get_type_str(args[1]))); + } + + const char* haystack = qstr_str(((mp_obj_str_t*)args[0])->qstr); + const char* needle = qstr_str(((mp_obj_str_t*)args[1])->qstr); + + ssize_t haystack_len = strlen(haystack); + ssize_t needle_len = strlen(needle); + + size_t start = 0; + size_t end = haystack_len; + /* TODO use a non-exception-throwing mp_get_index */ + if (n_args >= 3 && args[2] != mp_const_none) { + start = mp_get_index(&str_type, haystack_len, args[2]); + } + if (n_args >= 4 && args[3] != mp_const_none) { + end = mp_get_index(&str_type, haystack_len, args[3]); + } + + char *p = strstr(haystack + start, needle); + ssize_t pos = -1; + if (p) { + pos = p - haystack; + if (pos + needle_len > end) { + pos = -1; + } + } + return MP_OBJ_NEW_SMALL_INT(pos); +} + mp_obj_t str_strip(int n_args, const mp_obj_t *args) { assert(1 <= n_args && n_args <= 2); assert(MP_OBJ_IS_TYPE(args[0], &str_type)); @@ -239,11 +276,13 @@ mp_obj_t str_format(int n_args, const mp_obj_t *args) { return mp_obj_new_str(qstr_from_str_take(vstr->buf, vstr->alloc)); } +static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_find_obj, 2, 4, str_find); static MP_DEFINE_CONST_FUN_OBJ_2(str_join_obj, str_join); static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_strip_obj, 1, 2, str_strip); static MP_DEFINE_CONST_FUN_OBJ_VAR(str_format_obj, 1, str_format); static const mp_method_t str_type_methods[] = { + { "find", &str_find_obj }, { "join", &str_join_obj }, { "strip", &str_strip_obj }, { "format", &str_format_obj }, diff --git a/tests/basics/tests/string_find.py b/tests/basics/tests/string_find.py new file mode 100644 index 0000000000..90063228f8 --- /dev/null +++ b/tests/basics/tests/string_find.py @@ -0,0 +1,11 @@ +print("hello world".find("ll")) +print("hello world".find("ll", None)) +print("hello world".find("ll", 1)) +print("hello world".find("ll", 1, None)) +print("hello world".find("ll", None, None)) +print("hello world".find("ll", 1, -1)) +print("hello world".find("ll", 1, 1)) +print("hello world".find("ll", 1, 2)) +print("hello world".find("ll", 1, 3)) +print("hello world".find("ll", 1, 4)) +print("hello world".find("ll", 1, 5))