diff --git a/py/obj.c b/py/obj.c index 55754f9be2..6aa0abf0d3 100644 --- a/py/obj.c +++ b/py/obj.c @@ -189,7 +189,7 @@ bool mp_obj_is_callable(mp_obj_t o_in) { return mp_obj_instance_is_callable(o_in); } -// This function implements the '==' operator (and so the inverse of '!='). +// This function implements the '==' and '!=' operators. // // From the Python language reference: // (https://docs.python.org/3/reference/expressions.html#not-in) @@ -202,67 +202,89 @@ bool mp_obj_is_callable(mp_obj_t o_in) { // Furthermore, from the v3.4.2 code for object.c: "Practical amendments: If rich // comparison returns NotImplemented, == and != are decided by comparing the object // pointer." -bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { - // Float (and complex) NaN is never equal to anything, not even itself, - // so we must have a special check here to cover those cases. - if (o1 == o2 - #if MICROPY_PY_BUILTINS_FLOAT - && !mp_obj_is_float(o1) - #endif - #if MICROPY_PY_BUILTINS_COMPLEX - && !mp_obj_is_type(o1, &mp_type_complex) - #endif - ) { - return true; - } - if (o1 == mp_const_none || o2 == mp_const_none) { - return false; - } +mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2) { + mp_obj_t local_true = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_false : mp_const_true; + mp_obj_t local_false = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_true : mp_const_false; + int pass_number = 0; - // fast path for small ints - if (mp_obj_is_small_int(o1)) { - if (mp_obj_is_small_int(o2)) { - // both SMALL_INT, and not equal if we get here - return false; - } else { - mp_obj_t temp = o2; o2 = o1; o1 = temp; - // o2 is now the SMALL_INT, o1 is not - // fall through to generic op - } + // Shortcut for very common cases + if (o1 == o2 && + (mp_obj_is_small_int(o1) || !(mp_obj_get_type(o1)->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST))) { + return local_true; } // fast path for strings if (mp_obj_is_str(o1)) { if (mp_obj_is_str(o2)) { // both strings, use special function - return mp_obj_str_equal(o1, o2); - } else { - // a string is never equal to anything else - goto str_cmp_err; - } - } else if (mp_obj_is_str(o2)) { - // o1 is not a string (else caught above), so the objects are not equal - str_cmp_err: + return mp_obj_str_equal(o1, o2) ? local_true : local_false; #if MICROPY_PY_STR_BYTES_CMP_WARN - if (mp_obj_is_type(o1, &mp_type_bytes) || mp_obj_is_type(o2, &mp_type_bytes)) { + } else if (mp_obj_is_type(o2, &mp_type_bytes)) { + str_bytes_cmp: mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str"); - } + return local_false; #endif - return false; + } else { + goto skip_one_pass; + } + #if MICROPY_PY_STR_BYTES_CMP_WARN + } else if (mp_obj_is_str(o2) && mp_obj_is_type(o1, &mp_type_bytes)) { + // o1 is not a string (else caught above), so the objects are not equal + goto str_bytes_cmp; + #endif + } + + // fast path for small ints + if (mp_obj_is_small_int(o1)) { + if (mp_obj_is_small_int(o2)) { + // both SMALL_INT, and not equal if we get here + return local_false; + } else { + goto skip_one_pass; + } } // generic type, call binary_op(MP_BINARY_OP_EQUAL) - const mp_obj_type_t *type = mp_obj_get_type(o1); - if (type->binary_op != NULL) { - mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2); - if (r != MP_OBJ_NULL) { - return r == mp_const_true ? true : false; + while (pass_number < 2) { + const mp_obj_type_t *type = mp_obj_get_type(o1); + // If a full equality test is not needed and the other object is a different + // type then we don't need to bother trying the comparison. + if (type->binary_op != NULL && + ((type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST) || mp_obj_get_type(o2) == type)) { + // CPython is asymmetric: it will try __eq__ if there's no __ne__ but not the + // other way around. If the class doesn't need a full test we can skip __ne__. + if (op == MP_BINARY_OP_NOT_EQUAL && (type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST)) { + mp_obj_t r = type->binary_op(MP_BINARY_OP_NOT_EQUAL, o1, o2); + if (r != MP_OBJ_NULL) { + return r; + } + } + + // Try calling __eq__. + mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2); + if (r != MP_OBJ_NULL) { + if (op == MP_BINARY_OP_EQUAL) { + return r; + } else { + return mp_obj_is_true(r) ? local_true : local_false; + } + } } + + skip_one_pass: + // Try the other way around if none of the above worked + ++pass_number; + mp_obj_t temp = o1; + o1 = o2; + o2 = temp; } - // equality not implemented, and objects are not the same object, so - // they are defined as not equal - return false; + // equality not implemented, so fall back to pointer conparison + return (o1 == o2) ? local_true : local_false; +} + +bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { + return mp_obj_is_true(mp_obj_equal_not_equal(MP_BINARY_OP_EQUAL, o1, o2)); } mp_int_t mp_obj_get_int(mp_const_obj_t arg) { diff --git a/py/obj.h b/py/obj.h index 2bc72b5867..2e91b6b150 100644 --- a/py/obj.h +++ b/py/obj.h @@ -445,8 +445,14 @@ typedef mp_obj_t (*mp_fun_var_t)(size_t n, const mp_obj_t *); typedef mp_obj_t (*mp_fun_kw_t)(size_t n, const mp_obj_t *, mp_map_t *); // Flags for type behaviour (mp_obj_type_t.flags) +// If MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST is clear then all the following hold: +// (a) the type only implements the __eq__ operator and not the __ne__ operator; +// (b) __eq__ returns a boolean result (False or True); +// (c) __eq__ is reflexive (A==A is True); +// (d) the type can't be equal to an instance of any different class that also clears this flag. #define MP_TYPE_FLAG_IS_SUBCLASSED (0x0001) #define MP_TYPE_FLAG_HAS_SPECIAL_ACCESSORS (0x0002) +#define MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST (0x0004) typedef enum { PRINT_STR = 0, @@ -729,6 +735,7 @@ void mp_obj_print_exception(const mp_print_t *print, mp_obj_t exc); bool mp_obj_is_true(mp_obj_t arg); bool mp_obj_is_callable(mp_obj_t o_in); +mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2); bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2); static inline bool mp_obj_is_integer(mp_const_obj_t o) { return mp_obj_is_int(o) || mp_obj_is_bool(o); } // returns true if o is bool, small int or long int diff --git a/py/objarray.c b/py/objarray.c index c19617d4e1..51f924ba4e 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -558,6 +558,7 @@ const mp_obj_type_t mp_type_array = { const mp_obj_type_t mp_type_bytearray = { { &mp_type_type }, .name = MP_QSTR_bytearray, + .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = array_print, .make_new = bytearray_make_new, .getiter = array_iterator_new, diff --git a/py/objcomplex.c b/py/objcomplex.c index bf6fb51dc5..0c87f544fb 100644 --- a/py/objcomplex.c +++ b/py/objcomplex.c @@ -148,6 +148,7 @@ STATIC void complex_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { const mp_obj_type_t mp_type_complex = { { &mp_type_type }, .name = MP_QSTR_complex, + .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = complex_print, .make_new = complex_make_new, .unary_op = complex_unary_op, diff --git a/py/objfloat.c b/py/objfloat.c index 3da549bb25..6181d5f578 100644 --- a/py/objfloat.c +++ b/py/objfloat.c @@ -186,6 +186,7 @@ STATIC mp_obj_t float_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs const mp_obj_type_t mp_type_float = { { &mp_type_type }, .name = MP_QSTR_float, + .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = float_print, .make_new = float_make_new, .unary_op = float_unary_op, diff --git a/py/objset.c b/py/objset.c index 6f5bcc032a..a18b7e051e 100644 --- a/py/objset.c +++ b/py/objset.c @@ -564,6 +564,7 @@ STATIC MP_DEFINE_CONST_DICT(frozenset_locals_dict, frozenset_locals_dict_table); const mp_obj_type_t mp_type_frozenset = { { &mp_type_type }, .name = MP_QSTR_frozenset, + .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = set_print, .make_new = set_make_new, .unary_op = set_unary_op, diff --git a/py/objtype.c b/py/objtype.c index c574dcdfe2..ae0fe6caef 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -467,7 +467,7 @@ const byte mp_binary_op_method_name[MP_BINARY_OP_NUM_RUNTIME] = { [MP_BINARY_OP_EQUAL] = MP_QSTR___eq__, [MP_BINARY_OP_LESS_EQUAL] = MP_QSTR___le__, [MP_BINARY_OP_MORE_EQUAL] = MP_QSTR___ge__, - // MP_BINARY_OP_NOT_EQUAL, // a != b calls a == b and inverts result + [MP_BINARY_OP_NOT_EQUAL] = MP_QSTR___ne__, [MP_BINARY_OP_CONTAINS] = MP_QSTR___contains__, // If an inplace method is not found a normal method will be used as a fallback @@ -1100,7 +1100,7 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict) // TODO might need to make a copy of locals_dict; at least that's how CPython does it // Basic validation of base classes - uint16_t base_flags = 0; + uint16_t base_flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST; size_t bases_len; mp_obj_t *bases_items; mp_obj_tuple_get(bases_tuple, &bases_len, &bases_items); diff --git a/py/runtime.c b/py/runtime.c index db044cf7c4..4a718c1e23 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -323,19 +323,8 @@ mp_obj_t mp_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { // deal with == and != for all types if (op == MP_BINARY_OP_EQUAL || op == MP_BINARY_OP_NOT_EQUAL) { - if (mp_obj_equal(lhs, rhs)) { - if (op == MP_BINARY_OP_EQUAL) { - return mp_const_true; - } else { - return mp_const_false; - } - } else { - if (op == MP_BINARY_OP_EQUAL) { - return mp_const_false; - } else { - return mp_const_true; - } - } + // mp_obj_equal_not_equal supports a bunch of shortcuts + return mp_obj_equal_not_equal(op, lhs, rhs); } // deal with exception_match for all types diff --git a/tests/basics/special_comparisons.py b/tests/basics/special_comparisons.py new file mode 100644 index 0000000000..2ebd7e224e --- /dev/null +++ b/tests/basics/special_comparisons.py @@ -0,0 +1,33 @@ +class A: + def __eq__(self, other): + print("A __eq__ called") + return True + +class B: + def __ne__(self, other): + print("B __ne__ called") + return True + +class C: + def __eq__(self, other): + print("C __eq__ called") + return False + +class D: + def __ne__(self, other): + print("D __ne__ called") + return False + +a = A() +b = B() +c = C() +d = D() + +def test(s): + print(s) + print(eval(s)) + +for x in 'abcd': + for y in 'abcd': + test('{} == {}'.format(x,y)) + test('{} != {}'.format(x,y)) diff --git a/tests/basics/special_comparisons2.py b/tests/basics/special_comparisons2.py new file mode 100644 index 0000000000..c29dc72ce6 --- /dev/null +++ b/tests/basics/special_comparisons2.py @@ -0,0 +1,31 @@ +class E: + def __repr__(self): + return "E" + + def __eq__(self, other): + print('E eq', other) + return 123 + +class F: + def __repr__(self): + return "F" + + def __ne__(self, other): + print('F ne', other) + return -456 + +print(E() != F()) +print(F() != E()) + +tests = (None, 0, 1, 'a') + +for val in tests: + print('==== testing', val) + print(E() == val) + print(val == E()) + print(E() != val) + print(val != E()) + print(F() == val) + print(val == F()) + print(F() != val) + print(val != F())