From 1ded8a2977a84777f339ea7e7c788b11d75a6cce Mon Sep 17 00:00:00 2001 From: Jon Bjarni Bjarnason Date: Tue, 12 Apr 2022 22:17:38 +0000 Subject: [PATCH] py/objtype: Convert result of user __contains__ method to bool. Per https://docs.python.org/3/reference/expressions.html#membership-test-operations For user-defined classes which define the contains() method, x in y returns True if y.contains(x) returns a true value, and False otherwise. Fixes issue #7884. --- py/objtype.c | 1 + tests/basics/class_contains.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/py/objtype.c b/py/objtype.c index 0977a67ced..e320adc8bb 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -549,6 +549,7 @@ retry:; } else if (dest[0] != MP_OBJ_NULL) { dest[2] = rhs_in; res = mp_call_method_n_kw(1, 0, dest); + res = op == MP_BINARY_OP_CONTAINS ? mp_obj_new_bool(mp_obj_is_true(res)) : res; } else { // If this was an inplace method, fallback to normal method // https://docs.python.org/3/reference/datamodel.html#object.__iadd__ : diff --git a/tests/basics/class_contains.py b/tests/basics/class_contains.py index b6dd3661cd..5fdb1db4c0 100644 --- a/tests/basics/class_contains.py +++ b/tests/basics/class_contains.py @@ -21,3 +21,27 @@ b = B([1, 2]) print(1 in b) print(2 in b) print(3 in b) + + +class C: + def __contains__(self, arg): + return arg + + +print(C().__contains__(0)) +print(C().__contains__(1)) +print(C().__contains__('')) +print(C().__contains__('foo')) +print(C().__contains__(None)) + +print(0 in C()) +print(1 in C()) +print('' in C()) +print('foo' in C()) +print(None in C()) + +print(0 not in C()) +print(1 not in C()) +print('' not in C()) +print('foo' not in C()) +print(None not in C())