py/objarray: Prohibit comparison of mismatching types.

Array equality is defined as each element being equal, but to keep
code size down MicroPython implements a binary comparison.  This
can only be used correctly for elements with the same binary layout,
so raise a NotImplementedError when comparing types
for which the binary comparison would yield incorrect results: types
with different sizes, and floating point numbers because nan != nan.
pull/7250/head
stijn 2021-05-12 17:02:06 +02:00 zatwierdzone przez Damien George
rodzic 6affcb0104
commit 57365d8557
4 zmienionych plików z 54 dodań i 1 usunięć

Wyświetl plik

@ -258,6 +258,16 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
}
}
// Normalise a buffer typecode so that two typecodes can be checked for
// compatibility with a single integer comparison.  The bytearray typecode is
// treated as 'B', and the uppercase and lowercase forms of a letter (e.g. 'B'
// and 'b') collapse to the same lowercase value.
STATIC int typecode_for_comparison(int typecode) {
    const int code = (typecode == BYTEARRAY_TYPECODE) ? 'B' : typecode;
    // Adding 32 turns an ASCII uppercase letter into its lowercase form.  No
    // lower bound is checked, so every code <= 'Z' is shifted.
    return (code <= 'Z') ? code + 32 : code;
}
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
switch (op) {
@ -319,7 +329,20 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
if (!mp_get_buffer(rhs_in, &rhs_bufinfo, MP_BUFFER_READ)) {
return mp_const_false;
}
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
// mp_seq_cmp_bytes is used so only compatible representations can be correctly compared.
// The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so
// just check if the typecodes are compatible; for testing equality the types should have the
// same code except for signedness, and not be floating point because nan never equals nan.
// Note that typecode_for_comparison always returns lowercase letters to save code size.
// No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that.
const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode);
const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode);
if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd') {
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
}
// mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison'
// for MP_BINARY_OP_EQUAL but that is incompatible with CPython.
mp_raise_NotImplementedError(NULL);
}
default:

Wyświetl plik

@ -41,14 +41,23 @@ except ValueError:
# equality (CPython requires both sides are array)
print(bytes(array.array('b', [0x61, 0x62, 0x63])) == b'abc')
print(array.array('b', [0x61, 0x62, 0x63]) == b'abc')
print(array.array('B', [0x61, 0x62, 0x63]) == b'abc')
print(array.array('b', [0x61, 0x62, 0x63]) != b'abc')
print(array.array('b', [0x61, 0x62, 0x63]) == b'xyz')
print(array.array('b', [0x61, 0x62, 0x63]) != b'xyz')
print(b'abc' == array.array('b', [0x61, 0x62, 0x63]))
print(b'abc' == array.array('B', [0x61, 0x62, 0x63]))
print(b'abc' != array.array('b', [0x61, 0x62, 0x63]))
print(b'xyz' == array.array('b', [0x61, 0x62, 0x63]))
print(b'xyz' != array.array('b', [0x61, 0x62, 0x63]))
compatible_typecodes = []
for t in ["b", "h", "i", "l", "q"]:
compatible_typecodes.append((t, t))
compatible_typecodes.append((t, t.upper()))
for a, b in compatible_typecodes:
print(array.array(a, [1, 2]) == array.array(b, [1, 2]))
class X(array.array):
pass

Wyświetl plik

@ -17,3 +17,15 @@ print(a[0])
a = array.array('P')
a.append(1)
print(a[0])
# comparison between mismatching binary layouts is not implemented
typecodes = ["b", "h", "i", "l", "q", "P", "O", "S", "f", "d"]
for a in typecodes:
for b in typecodes:
if a == b and a not in ["f", "d"]:
continue
try:
array.array(a) == array.array(b)
print('FAIL')
except NotImplementedError:
pass

Wyświetl plik

@ -0,0 +1,9 @@
"""
categories: Modules,array
description: Comparison between different typecodes not supported
cause: Code size
workaround: Compare individual elements
"""
import array
array.array("b", [1, 2]) == array.array("i", [1, 2])