diff --git a/py/objarray.c b/py/objarray.c index 2f1f68d81a..5bd9454716 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -258,12 +258,13 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) { } } -STATIC int typecode_for_comparison(int typecode) { +STATIC int typecode_for_comparison(int typecode, bool *is_unsigned) { if (typecode == BYTEARRAY_TYPECODE) { typecode = 'B'; } if (typecode <= 'Z') { typecode += 32; // to lowercase + *is_unsigned = true; } return typecode; } @@ -322,7 +323,11 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs return mp_const_false; } - case MP_BINARY_OP_EQUAL: { + case MP_BINARY_OP_EQUAL: + case MP_BINARY_OP_LESS: + case MP_BINARY_OP_LESS_EQUAL: + case MP_BINARY_OP_MORE: + case MP_BINARY_OP_MORE_EQUAL: { mp_buffer_info_t lhs_bufinfo; mp_buffer_info_t rhs_bufinfo; array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ); @@ -333,11 +338,13 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs // The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so // just check if the typecodes are compatible; for testing equality the types should have the // same code except for signedness, and not be floating point because nan never equals nan. + // For > and < the types should be the same and unsigned. // Note that typecode_for_comparison always returns lowercase letters to save code size. // No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that. - const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode); - const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode); - if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd') { + bool is_unsigned = false; + const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode, &is_unsigned); + const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode, &is_unsigned); + if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd' && (op == MP_BINARY_OP_EQUAL || is_unsigned)) { return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len)); } // mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison' diff --git a/tests/basics/array1.py b/tests/basics/array1.py index 15789e2c99..f21ad4bd75 100644 --- a/tests/basics/array1.py +++ b/tests/basics/array1.py @@ -66,3 +66,24 @@ print(X('b', [0x61, 0x62, 0x63]) == b'abc') print(X('b', [0x61, 0x62, 0x63]) != b'abc') print(X('b', [0x61, 0x62, 0x63]) == array.array('b', [0x61, 0x62, 0x63])) print(X('b', [0x61, 0x62, 0x63]) != array.array('b', [0x61, 0x62, 0x63])) + +# other comparisons +for typecode in ["B", "H", "I", "L", "Q"]: + a = array.array(typecode, [1, 1]) + print(a < a) + print(a <= a) + print(a > a) + print(a >= a) + + al = array.array(typecode, [1, 0]) + ab = array.array(typecode, [1, 2]) + + print(a < al) + print(a <= al) + print(a > al) + print(a >= al) + + print(a < ab) + print(a <= ab) + print(a > ab) + print(a >= ab) diff --git a/tests/basics/bytearray1.py b/tests/basics/bytearray1.py index b598500264..d12292e879 100644 --- a/tests/basics/bytearray1.py +++ b/tests/basics/bytearray1.py @@ -27,6 +27,26 @@ print(bytearray([1]) == b"1") print(b"1" == bytearray([1])) print(bytearray() == bytearray()) +b1 = bytearray([1, 2, 3]) +b2 = bytearray([1, 2, 3]) +b3 = bytearray([1, 3]) +print(b1 == b2) +print(b2 != b3) +print(b1 <= b2) +print(b1 <= b3) +print(b1 < b3) +print(b1 >= b2) +print(b3 >= b2) +print(b3 > b2) +print(b1 != b2) +print(b2 == b3) +print(b1 > b2) +print(b1 > b3) +print(b1 >= b3) +print(b1 < b2) +print(b3 < b2) +print(b3 <= b2) + # comparison with other type should return False print(bytearray() == 1)