From 2c8dab7ab4ec0884c6428afc613d9dcc322d8c6d Mon Sep 17 00:00:00 2001 From: Jim Mussared Date: Mon, 7 Nov 2022 12:55:31 +1100 Subject: [PATCH] py/objarray: Detect bytearray(str) without an encoding. This prevents a very subtle bug caused by writing e.g. `bytearray('\xfd')` which gives you `(0xc3, 0xbd)`. This work was funded through GitHub Sponsors. Signed-off-by: Jim Mussared --- py/objarray.c | 8 ++++++++ py/objstr.c | 4 ++++ tests/basics/bytearray_construct.py | 5 +++++ tests/micropython/viper_addr.py | 2 +- 4 files changed, 18 insertions(+), 1 deletion(-) diff --git a/py/objarray.c b/py/objarray.c index 42fc0749d2..c660705389 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -192,6 +192,14 @@ STATIC mp_obj_t bytearray_make_new(const mp_obj_type_t *type_in, size_t n_args, return MP_OBJ_FROM_PTR(o); } else { // 1 arg: construct the bytearray from that + if (mp_obj_is_str(args[0]) && n_args == 1) { + #if MICROPY_ERROR_REPORTING <= MICROPY_ERROR_REPORTING_TERSE + // Match bytes_make_new. + mp_raise_TypeError(MP_ERROR_TEXT("wrong number of arguments")); + #else + mp_raise_TypeError(MP_ERROR_TEXT("string argument without an encoding")); + #endif + } return array_construct(BYTEARRAY_TYPECODE, args[0]); } } diff --git a/py/objstr.c b/py/objstr.c index 8c639e7354..bd3e16e7f2 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -233,7 +233,11 @@ STATIC mp_obj_t bytes_make_new(const mp_obj_type_t *type_in, size_t n_args, size if (mp_obj_is_str(args[0])) { if (n_args < 2 || n_args > 3) { + #if MICROPY_ERROR_REPORTING <= MICROPY_ERROR_REPORTING_TERSE goto wrong_args; + #else + mp_raise_TypeError(MP_ERROR_TEXT("string argument without an encoding")); + #endif } GET_STR_DATA_LEN(args[0], str_data, str_len); GET_STR_HASH(args[0], str_hash); diff --git a/tests/basics/bytearray_construct.py b/tests/basics/bytearray_construct.py index 75fdc41178..eb4d4e641f 100644 --- a/tests/basics/bytearray_construct.py +++ b/tests/basics/bytearray_construct.py @@ -5,3 +5,8 @@ print(bytearray('1234', 'utf-8')) print(bytearray('12345', 'utf-8', 'strict')) print(bytearray((1, 2))) print(bytearray([1, 2])) + +try: + print(bytearray('1234')) +except TypeError: + print("TypeError") diff --git a/tests/micropython/viper_addr.py b/tests/micropython/viper_addr.py index 84bc6c002e..8e79fadb2a 100644 --- a/tests/micropython/viper_addr.py +++ b/tests/micropython/viper_addr.py @@ -21,7 +21,7 @@ def memsum(src: ptr8, n: int) -> int: # create array and get its address -ar = bytearray("0000") +ar = bytearray(b"0000") addr = get_addr(ar) print(type(ar)) print(type(addr))