From ea7031faff9efe6803d8a8f67ad2e3b4a6d390e3 Mon Sep 17 00:00:00 2001 From: Damien George Date: Fri, 12 May 2023 23:16:37 +1000 Subject: [PATCH] py/runtime: If inplace binop fails then try corresponding normal binop. The code that handles inplace-operator to normal-binary-operator fallback is moved in this commit from py/objtype.c to py/runtime.c, making it apply to all types, not just user classes. Signed-off-by: Damien George --- py/objtype.c | 13 +------------ py/runtime.c | 9 +++++++++ tests/basics/class_reverse_op.py | 5 +++++ tests/basics/list_mult.py | 10 ++++++++++ tests/basics/string_mult.py | 10 ++++++++++ 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/py/objtype.c b/py/objtype.c index 88d1c8ebe0..04bdf5acd7 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -534,7 +534,6 @@ STATIC mp_obj_t instance_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t // Note: For ducktyping, CPython does not look in the instance members or use // __getattr__ or __getattribute__. It only looks in the class dictionary. mp_obj_instance_t *lhs = MP_OBJ_TO_PTR(lhs_in); -retry:; qstr op_name = mp_binary_op_method_name[op]; /* Still try to lookup native slot if (op_name == 0) { @@ -559,22 +558,12 @@ retry:; res = mp_call_method_n_kw(1, 0, dest); res = op == MP_BINARY_OP_CONTAINS ? mp_obj_new_bool(mp_obj_is_true(res)) : res; } else { - // If this was an inplace method, fallback to normal method - // https://docs.python.org/3/reference/datamodel.html#object.__iadd__ : - // "If a specific method is not defined, the augmented assignment - // falls back to the normal methods." - if (op >= MP_BINARY_OP_INPLACE_OR && op <= MP_BINARY_OP_INPLACE_POWER) { - op -= MP_BINARY_OP_INPLACE_OR - MP_BINARY_OP_OR; - goto retry; - } return MP_OBJ_NULL; // op not supported } #if MICROPY_PY_BUILTINS_NOTIMPLEMENTED // NotImplemented means "try other fallbacks (like calling __rop__ - // instead of __op__) and if nothing works, raise TypeError". As - // MicroPython doesn't implement any fallbacks, signal to raise - // TypeError right away. + // instead of __op__) and if nothing works, raise TypeError". if (res == mp_const_notimplemented) { return MP_OBJ_NULL; // op not supported } diff --git a/py/runtime.c b/py/runtime.c index 47e094763c..3434d9cc48 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -625,6 +625,15 @@ generic_binary_op: } } + // If this was an inplace method, fallback to the corresponding normal method. + // https://docs.python.org/3/reference/datamodel.html#object.__iadd__ : + // "If a specific method is not defined, the augmented assignment falls back + // to the normal methods." + if (op >= MP_BINARY_OP_INPLACE_OR && op <= MP_BINARY_OP_INPLACE_POWER) { + op += MP_BINARY_OP_OR - MP_BINARY_OP_INPLACE_OR; + goto generic_binary_op; + } + #if MICROPY_PY_REVERSE_SPECIAL_METHODS if (op >= MP_BINARY_OP_OR && op <= MP_BINARY_OP_POWER) { mp_obj_t t = rhs; diff --git a/tests/basics/class_reverse_op.py b/tests/basics/class_reverse_op.py index 11aba6aada..915b3a9ef3 100644 --- a/tests/basics/class_reverse_op.py +++ b/tests/basics/class_reverse_op.py @@ -46,3 +46,8 @@ print("a" | B("b")) print("a" + B("b")) print("a" * B("b")) print("a" / B("b")) + +x = "a"; x |= B("b"); print(x) +x = "a"; x += B("b"); print(x) +x = "a"; x *= B("b"); print(x) +x = "a"; x /= B("b"); print(x) diff --git a/tests/basics/list_mult.py b/tests/basics/list_mult.py index 548f88534e..125c548eec 100644 --- a/tests/basics/list_mult.py +++ b/tests/basics/list_mult.py @@ -11,6 +11,16 @@ a = [1, 2, 3] c = a * 3 print(a, c) +# check inplace multiplication +a = [4, 5, 6] +a *= 3 +print(a) + +# check reverse inplace multiplication +a = 3 +a *= [7, 8, 9] +print(a) + # unsupported type on RHS try: [] * None diff --git a/tests/basics/string_mult.py b/tests/basics/string_mult.py index c0713c1d3a..5a7d822947 100644 --- a/tests/basics/string_mult.py +++ b/tests/basics/string_mult.py @@ -10,3 +10,13 @@ for i in (-4, -2, 0, 2, 4): a = '123' c = a * 3 print(a, c) + +# check inplace multiplication +a = '456' +a *= 3 +print(a) + +# check reverse inplace multiplication +a = 3 +a *= '789' +print(a)