From 5c603bd0fd53e7dbd709883f289d1e580c86e33d Mon Sep 17 00:00:00 2001 From: Paul Sokolovsky Date: Thu, 7 Sep 2017 00:10:10 +0300 Subject: [PATCH] py/objlist: Properly implement comparison with incompatible types. Should raise TypeError, unless it's (in)equality comparison. --- py/objlist.c | 26 ++++++++++++-------------- tests/basics/list_compare.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/py/objlist.c b/py/objlist.c index 6ac33e80e6..d70867dedc 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -87,18 +87,6 @@ STATIC mp_obj_t list_make_new(const mp_obj_type_t *type_in, size_t n_args, size_ } } -// Don't pass MP_BINARY_OP_NOT_EQUAL here -STATIC bool list_cmp_helper(mp_uint_t op, mp_obj_t self_in, mp_obj_t another_in) { - mp_check_self(MP_OBJ_IS_TYPE(self_in, &mp_type_list)); - if (!MP_OBJ_IS_TYPE(another_in, &mp_type_list)) { - return false; - } - mp_obj_list_t *self = MP_OBJ_TO_PTR(self_in); - mp_obj_list_t *another = MP_OBJ_TO_PTR(another_in); - - return mp_seq_cmp_objs(op, self->items, self->len, another->items, another->len); -} - STATIC mp_obj_t list_unary_op(mp_unary_op_t op, mp_obj_t self_in) { mp_obj_list_t *self = MP_OBJ_TO_PTR(self_in); switch (op) { @@ -146,8 +134,18 @@ STATIC mp_obj_t list_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { case MP_BINARY_OP_LESS: case MP_BINARY_OP_LESS_EQUAL: case MP_BINARY_OP_MORE: - case MP_BINARY_OP_MORE_EQUAL: - return mp_obj_new_bool(list_cmp_helper(op, lhs, rhs)); + case MP_BINARY_OP_MORE_EQUAL: { + if (!MP_OBJ_IS_TYPE(rhs, &mp_type_list)) { + if (op == MP_BINARY_OP_EQUAL) { + return mp_const_false; + } + return MP_OBJ_NULL; // op not supported + } + + mp_obj_list_t *another = MP_OBJ_TO_PTR(rhs); + bool res = mp_seq_cmp_objs(op, o->items, o->len, another->items, another->len); + return mp_obj_new_bool(res); + } default: return MP_OBJ_NULL; // op not supported diff --git a/tests/basics/list_compare.py b/tests/basics/list_compare.py index eea8814247..fd656c7f15 100644 --- a/tests/basics/list_compare.py +++ b/tests/basics/list_compare.py @@ -48,3 +48,13 @@ print([1] <= [1, 0]) print([1] <= [1, -1]) print([1, 0] <= [1]) print([1, -1] <= [1]) + + +print([] == {}) +print([] != {}) +print([1] == (1,)) + +try: + print([] < {}) +except TypeError: + print("TypeError")