From 2ac1364688cd3ee313661e82a336663551986fc8 Mon Sep 17 00:00:00 2001 From: Damien George Date: Tue, 3 Oct 2017 17:56:27 +1100 Subject: [PATCH] py/objset: Check that RHS of a binary op is a set/frozenset. CPython docs explicitly state that the RHS of a set/frozenset binary op must be a set to prevent user errors. It also preserves commutativity of the ops, eg: "abc" & set() is a TypeError, and so should be set() & "abc". This change actually decreases unix (x64) code by 160 bytes; it increases stm32 by 4 bytes and esp8266 by 28 bytes (but previous patch already introduced a much large saving). --- py/objset.c | 4 ++++ tests/basics/set_binop.py | 12 ++++++++++++ tests/misc/non_compliant.py | 12 ------------ tests/misc/non_compliant.py.exp | 2 -- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/py/objset.c b/py/objset.c index 6dede887c9..80ed263340 100644 --- a/py/objset.c +++ b/py/objset.c @@ -463,6 +463,10 @@ STATIC mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { #else bool update = true; #endif + if (op != MP_BINARY_OP_IN && !is_set_or_frozenset(rhs)) { + // For all ops except containment the RHS must be a set/frozenset + return MP_OBJ_NULL; + } switch (op) { case MP_BINARY_OP_OR: return set_union(lhs, rhs); diff --git a/tests/basics/set_binop.py b/tests/basics/set_binop.py index 7848920b6a..bc76533b1f 100644 --- a/tests/basics/set_binop.py +++ b/tests/basics/set_binop.py @@ -47,6 +47,18 @@ s1 = s2 = set('abc') s1 -= set('ad') print(s1 is s2, len(s1)) +# RHS must be a set +try: + print(set('12') >= '1') +except TypeError: + print('TypeError') + +# RHS must be a set +try: + print(set('12') <= '123') +except TypeError: + print('TypeError') + # unsupported operator try: set('abc') * 2 diff --git a/tests/misc/non_compliant.py b/tests/misc/non_compliant.py index b4c90e9fcf..152633c3b7 100644 --- a/tests/misc/non_compliant.py +++ b/tests/misc/non_compliant.py @@ -39,18 +39,6 @@ try: except NotImplementedError: print('NotImplementedError') -# should raise type error -try: - print(set('12') >= '1') -except TypeError: - print('TypeError') - -# should raise type error -try: - print(set('12') <= '123') -except TypeError: - print('TypeError') - # uPy raises TypeError, shold be ValueError try: '%c' % b'\x01\x02' diff --git a/tests/misc/non_compliant.py.exp b/tests/misc/non_compliant.py.exp index ba5590acc0..9c157fd5bd 100644 --- a/tests/misc/non_compliant.py.exp +++ b/tests/misc/non_compliant.py.exp @@ -3,8 +3,6 @@ AttributeError TypeError NotImplementedError NotImplementedError -True -True TypeError, ValueError NotImplementedError NotImplementedError