From 9ec1caf42e7733b5141b7aecf1b6e58834a94bf7 Mon Sep 17 00:00:00 2001 From: Damien George Date: Mon, 10 Feb 2020 21:41:22 +1100 Subject: [PATCH] py: Expand type equality flags to 3 separate ones, fix bool/namedtuple. Both bool and namedtuple will check against other types for equality; int, float and complex for bool, and tuple for namedtuple. So to make them work after the recent commit 3aab54bf434e7f025a91ea05052f1bac439fad8c they would need MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST set. But that makes all bool and namedtuple equality checks less efficient because mp_obj_equal_not_equal() could no longer short-cut x==x, and would need to try __ne__. To improve this, this commit splits the MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST flags into 3 separate flags to give types more fine-grained control over how their equality behaves. These new flags are then used to fix bool and namedtuple equality. Fixes issue #5615 and #5620. --- py/obj.c | 6 +++--- py/obj.h | 15 +++++++++------ py/objarray.c | 2 +- py/objbool.c | 1 + py/objcomplex.c | 2 +- py/objfloat.c | 2 +- py/objnamedtuple.c | 1 + py/objset.c | 2 +- py/objtype.c | 3 ++- 9 files changed, 20 insertions(+), 14 deletions(-) diff --git a/py/obj.c b/py/obj.c index 6aa0abf0d3..382815b3f0 100644 --- a/py/obj.c +++ b/py/obj.c @@ -209,7 +209,7 @@ mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2) { // Shortcut for very common cases if (o1 == o2 && - (mp_obj_is_small_int(o1) || !(mp_obj_get_type(o1)->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST))) { + (mp_obj_is_small_int(o1) || !(mp_obj_get_type(o1)->flags & MP_TYPE_FLAG_EQ_NOT_REFLEXIVE))) { return local_true; } @@ -250,10 +250,10 @@ mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2) { // If a full equality test is not needed and the other object is a different // type then we don't need to bother trying the comparison. if (type->binary_op != NULL && - ((type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST) || mp_obj_get_type(o2) == type)) { + ((type->flags & MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE) || mp_obj_get_type(o2) == type)) { // CPython is asymmetric: it will try __eq__ if there's no __ne__ but not the // other way around. If the class doesn't need a full test we can skip __ne__. - if (op == MP_BINARY_OP_NOT_EQUAL && (type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST)) { + if (op == MP_BINARY_OP_NOT_EQUAL && (type->flags & MP_TYPE_FLAG_EQ_HAS_NEQ_TEST)) { mp_obj_t r = type->binary_op(MP_BINARY_OP_NOT_EQUAL, o1, o2); if (r != MP_OBJ_NULL) { return r; diff --git a/py/obj.h b/py/obj.h index 2e91b6b150..e3eb85e138 100644 --- a/py/obj.h +++ b/py/obj.h @@ -445,14 +445,17 @@ typedef mp_obj_t (*mp_fun_var_t)(size_t n, const mp_obj_t *); typedef mp_obj_t (*mp_fun_kw_t)(size_t n, const mp_obj_t *, mp_map_t *); // Flags for type behaviour (mp_obj_type_t.flags) -// If MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST is clear then all the following hold: -// (a) the type only implements the __eq__ operator and not the __ne__ operator; -// (b) __eq__ returns a boolean result (False or True); -// (c) __eq__ is reflexive (A==A is True); -// (d) the type can't be equal to an instance of any different class that also clears this flag. +// If MP_TYPE_FLAG_EQ_NOT_REFLEXIVE is clear then __eq__ is reflexive (A==A returns True). +// If MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE is clear then the type can't be equal to an +// instance of any different class that also clears this flag. If this flag is set +// then the type may check for equality against a different type. +// If MP_TYPE_FLAG_EQ_HAS_NEQ_TEST is clear then the type only implements the __eq__ +// operator and not the __ne__ operator. If it's set then __ne__ may be implemented. #define MP_TYPE_FLAG_IS_SUBCLASSED (0x0001) #define MP_TYPE_FLAG_HAS_SPECIAL_ACCESSORS (0x0002) -#define MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST (0x0004) +#define MP_TYPE_FLAG_EQ_NOT_REFLEXIVE (0x0040) +#define MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE (0x0080) +#define MP_TYPE_FLAG_EQ_HAS_NEQ_TEST (0x0010) typedef enum { PRINT_STR = 0, diff --git a/py/objarray.c b/py/objarray.c index 51f924ba4e..a6e6fca1c9 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -557,8 +557,8 @@ const mp_obj_type_t mp_type_array = { #if MICROPY_PY_BUILTINS_BYTEARRAY const mp_obj_type_t mp_type_bytearray = { { &mp_type_type }, + .flags = MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE, .name = MP_QSTR_bytearray, - .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = array_print, .make_new = bytearray_make_new, .getiter = array_iterator_new, diff --git a/py/objbool.c b/py/objbool.c index 4c046ac8f2..1183a7ec92 100644 --- a/py/objbool.c +++ b/py/objbool.c @@ -86,6 +86,7 @@ STATIC mp_obj_t bool_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_ const mp_obj_type_t mp_type_bool = { { &mp_type_type }, + .flags = MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE, // can match all numeric types .name = MP_QSTR_bool, .print = bool_print, .make_new = bool_make_new, diff --git a/py/objcomplex.c b/py/objcomplex.c index 0c87f544fb..2031c1c03a 100644 --- a/py/objcomplex.c +++ b/py/objcomplex.c @@ -147,8 +147,8 @@ STATIC void complex_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { const mp_obj_type_t mp_type_complex = { { &mp_type_type }, + .flags = MP_TYPE_FLAG_EQ_NOT_REFLEXIVE | MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE, .name = MP_QSTR_complex, - .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = complex_print, .make_new = complex_make_new, .unary_op = complex_unary_op, diff --git a/py/objfloat.c b/py/objfloat.c index 6181d5f578..347f692a48 100644 --- a/py/objfloat.c +++ b/py/objfloat.c @@ -185,8 +185,8 @@ STATIC mp_obj_t float_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs const mp_obj_type_t mp_type_float = { { &mp_type_type }, + .flags = MP_TYPE_FLAG_EQ_NOT_REFLEXIVE | MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE, .name = MP_QSTR_float, - .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = float_print, .make_new = float_make_new, .unary_op = float_unary_op, diff --git a/py/objnamedtuple.c b/py/objnamedtuple.c index 54d8493c37..3e0f661397 100644 --- a/py/objnamedtuple.c +++ b/py/objnamedtuple.c @@ -155,6 +155,7 @@ mp_obj_namedtuple_type_t *mp_obj_new_namedtuple_base(size_t n_fields, mp_obj_t * STATIC mp_obj_t mp_obj_new_namedtuple_type(qstr name, size_t n_fields, mp_obj_t *fields) { mp_obj_namedtuple_type_t *o = mp_obj_new_namedtuple_base(n_fields, fields); o->base.base.type = &mp_type_type; + o->base.flags = MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE; // can match tuple o->base.name = name; o->base.print = namedtuple_print; o->base.make_new = namedtuple_make_new; diff --git a/py/objset.c b/py/objset.c index a18b7e051e..48d6caf4a8 100644 --- a/py/objset.c +++ b/py/objset.c @@ -563,8 +563,8 @@ STATIC MP_DEFINE_CONST_DICT(frozenset_locals_dict, frozenset_locals_dict_table); const mp_obj_type_t mp_type_frozenset = { { &mp_type_type }, + .flags = MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE, .name = MP_QSTR_frozenset, - .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST, .print = set_print, .make_new = set_make_new, .unary_op = set_unary_op, diff --git a/py/objtype.c b/py/objtype.c index ae0fe6caef..57cfeeccbc 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -1100,7 +1100,8 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict) // TODO might need to make a copy of locals_dict; at least that's how CPython does it // Basic validation of base classes - uint16_t base_flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST; + uint16_t base_flags = MP_TYPE_FLAG_EQ_NOT_REFLEXIVE + | MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE | MP_TYPE_FLAG_EQ_HAS_NEQ_TEST; size_t bases_len; mp_obj_t *bases_items; mp_obj_tuple_get(bases_tuple, &bases_len, &bases_items);