From 4b57330465b98df30ef8a10e19a0e197b5797550 Mon Sep 17 00:00:00 2001 From: Damien George Date: Fri, 12 May 2023 23:17:20 +1000 Subject: [PATCH] py/objstr: Return unsupported binop instead of raising TypeError. So that user types can implement reverse operators and have them work with str on the left-hand-side, eg `"a" + UserType()`. Signed-off-by: Damien George --- py/objstr.c | 11 ++++++++++- tests/basics/class_reverse_op.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/py/objstr.c b/py/objstr.c index 4d9dca04af..e6c5ee71cf 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -403,7 +403,16 @@ mp_obj_t mp_obj_str_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_i } else { // LHS is str and RHS has an incompatible type // (except if operation is EQUAL, but that's handled by mp_obj_equal) - bad_implicit_conversion(rhs_in); + + // CONTAINS must fail with a bad-implicit-conversion exception, because + // otherwise mp_binary_op() will fallback to `list(lhs).__contains__(rhs)`. + if (op == MP_BINARY_OP_CONTAINS) { + bad_implicit_conversion(rhs_in); + } + + // All other operations are not supported, and may be handled by another + // type, eg for reverse operations. + return MP_OBJ_NULL; } switch (op) { diff --git a/tests/basics/class_reverse_op.py b/tests/basics/class_reverse_op.py index b0dae5f8a3..11aba6aada 100644 --- a/tests/basics/class_reverse_op.py +++ b/tests/basics/class_reverse_op.py @@ -1,5 +1,7 @@ -class A: +# Test reverse operators. +# Test user type with integers. +class A: def __init__(self, v): self.v = v @@ -14,5 +16,33 @@ class A: def __repr__(self): return "A({})".format(self.v) + print(A(3) + 1) print(2 + A(5)) + + +# Test user type with strings. +class B: + def __init__(self, v): + self.v = v + + def __repr__(self): + return "B({})".format(self.v) + + def __ror__(self, o): + return B(o + "|" + self.v) + + def __radd__(self, o): + return B(o + "+" + self.v) + + def __rmul__(self, o): + return B(o + "*" + self.v) + + def __rtruediv__(self, o): + return B(o + "/" + self.v) + + +print("a" | B("b")) +print("a" + B("b")) +print("a" * B("b")) +print("a" / B("b"))