From 57365d855734142deb030ebcd00c10efcedf554b Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 12 May 2021 17:02:06 +0200 Subject: [PATCH] py/objarray: Prohibit comparison of mismatching types. Array equality is defined as each element being equal but to keep code size down MicroPython implements a binary comparison. This can only be used correctly for elements with the same binary layout though so turn it into an NotImplementedError when comparing types for which the binary comparison yielded incorrect results: types with different sizes, and floating point numbers because nan != nan. --- py/objarray.c | 25 +++++++++++++++++++++++- tests/basics/array1.py | 9 +++++++++ tests/basics/array_micropython.py | 12 ++++++++++++ tests/cpydiff/module_array_comparison.py | 9 +++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 tests/cpydiff/module_array_comparison.py diff --git a/py/objarray.c b/py/objarray.c index c27366d720..2f1f68d81a 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -258,6 +258,16 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) { } } +STATIC int typecode_for_comparison(int typecode) { + if (typecode == BYTEARRAY_TYPECODE) { + typecode = 'B'; + } + if (typecode <= 'Z') { + typecode += 32; // to lowercase + } + return typecode; +} + STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) { mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in); switch (op) { @@ -319,7 +329,20 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs if (!mp_get_buffer(rhs_in, &rhs_bufinfo, MP_BUFFER_READ)) { return mp_const_false; } - return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len)); + // mp_seq_cmp_bytes is used so only compatible representations can be correctly compared. + // The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so + // just check if the typecodes are compatible; for testing equality the types should have the + // same code except for signedness, and not be floating point because nan never equals nan. + // Note that typecode_for_comparison always returns lowercase letters to save code size. + // No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that. + const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode); + const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode); + if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd') { + return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len)); + } + // mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison' + // for MP_BINARY_OP_EQUAL but that is incompatible with CPython. + mp_raise_NotImplementedError(NULL); } default: diff --git a/tests/basics/array1.py b/tests/basics/array1.py index 5b3f475786..15789e2c99 100644 --- a/tests/basics/array1.py +++ b/tests/basics/array1.py @@ -41,14 +41,23 @@ except ValueError: # equality (CPython requires both sides are array) print(bytes(array.array('b', [0x61, 0x62, 0x63])) == b'abc') print(array.array('b', [0x61, 0x62, 0x63]) == b'abc') +print(array.array('B', [0x61, 0x62, 0x63]) == b'abc') print(array.array('b', [0x61, 0x62, 0x63]) != b'abc') print(array.array('b', [0x61, 0x62, 0x63]) == b'xyz') print(array.array('b', [0x61, 0x62, 0x63]) != b'xyz') print(b'abc' == array.array('b', [0x61, 0x62, 0x63])) +print(b'abc' == array.array('B', [0x61, 0x62, 0x63])) print(b'abc' != array.array('b', [0x61, 0x62, 0x63])) print(b'xyz' == array.array('b', [0x61, 0x62, 0x63])) print(b'xyz' != array.array('b', [0x61, 0x62, 0x63])) +compatible_typecodes = [] +for t in ["b", "h", "i", "l", "q"]: + compatible_typecodes.append((t, t)) + compatible_typecodes.append((t, t.upper())) +for a, b in compatible_typecodes: + print(array.array(a, [1, 2]) == array.array(b, [1, 2])) + class X(array.array): pass diff --git a/tests/basics/array_micropython.py b/tests/basics/array_micropython.py index 6b3dc7a93b..44dc1d83d8 100644 --- a/tests/basics/array_micropython.py +++ b/tests/basics/array_micropython.py @@ -17,3 +17,15 @@ print(a[0]) a = array.array('P') a.append(1) print(a[0]) + +# comparison between mismatching binary layouts is not implemented +typecodes = ["b", "h", "i", "l", "q", "P", "O", "S", "f", "d"] +for a in typecodes: + for b in typecodes: + if a == b and a not in ["f", "d"]: + continue + try: + array.array(a) == array.array(b) + print('FAIL') + except NotImplementedError: + pass diff --git a/tests/cpydiff/module_array_comparison.py b/tests/cpydiff/module_array_comparison.py new file mode 100644 index 0000000000..a442af3f5b --- /dev/null +++ b/tests/cpydiff/module_array_comparison.py @@ -0,0 +1,9 @@ +""" +categories: Modules,array +description: Comparison between different typecodes not supported +cause: Code size +workaround: Compare individual elements +""" +import array + +array.array("b", [1, 2]) == array.array("i", [1, 2])