From 4fe3e493b1a62381db15b724f77d565ff2666120 Mon Sep 17 00:00:00 2001 From: Damien George Date: Mon, 25 Jul 2022 15:23:48 +1000 Subject: [PATCH] py/obj: Make mp_obj_get_complex_maybe call mp_obj_get_float_maybe first. This commit simplifies mp_obj_get_complex_maybe() by first calling mp_obj_get_float_maybe() to handle the cases corresponding to floats. Only if that fails does it attempt to extra a full complex number. This reduces code size and also means that mp_obj_get_complex_maybe() now supports user-defined classes defining __float__; in particular this allows user-defined classes to be used as arguments to cmath-module function. Furthermore, complex_make_new() can now be simplified to directly call mp_obj_get_complex(), instead of mp_obj_get_complex_maybe() followed by mp_obj_get_float(). This also improves error messages from complex with an invalid argument, it now raises "can't convert to complex" rather than "can't convert to float". Signed-off-by: Damien George --- py/obj.c | 17 +---------------- py/objcomplex.c | 7 ++----- tests/float/cmath_dunder.py | 21 +++++++++++++++++++++ tests/float/complex_dunder.py | 6 ++++++ tests/float/math_dunder.py | 15 +++++++++++++++ 5 files changed, 45 insertions(+), 21 deletions(-) create mode 100644 tests/float/cmath_dunder.py create mode 100644 tests/float/math_dunder.py diff --git a/py/obj.c b/py/obj.c index 5a05ea58c5..b461fe50aa 100644 --- a/py/obj.c +++ b/py/obj.c @@ -383,22 +383,7 @@ mp_float_t mp_obj_get_float(mp_obj_t arg) { #if MICROPY_PY_BUILTINS_COMPLEX bool mp_obj_get_complex_maybe(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) { - if (arg == mp_const_false) { - *real = 0; - *imag = 0; - } else if (arg == mp_const_true) { - *real = 1; - *imag = 0; - } else if (mp_obj_is_small_int(arg)) { - *real = (mp_float_t)MP_OBJ_SMALL_INT_VALUE(arg); - *imag = 0; - #if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE - } else if (mp_obj_is_exact_type(arg, &mp_type_int)) { - *real = mp_obj_int_as_float_impl(arg); - *imag = 0; - #endif - } else if (mp_obj_is_float(arg)) { - *real = mp_obj_float_get(arg); + if (mp_obj_get_float_maybe(arg, real)) { *imag = 0; } else if (mp_obj_is_type(arg, &mp_type_complex)) { mp_obj_complex_get(arg, real, imag); diff --git a/py/objcomplex.c b/py/objcomplex.c index 3c4cb66140..4aa598a0bc 100644 --- a/py/objcomplex.c +++ b/py/objcomplex.c @@ -89,11 +89,8 @@ STATIC mp_obj_t complex_make_new(const mp_obj_type_t *type_in, size_t n_args, si return args[0]; } else { mp_float_t real, imag; - if (mp_obj_get_complex_maybe(args[0], &real, &imag)) { - return mp_obj_new_complex(real, imag); - } - // something else, try to cast it to a complex - return mp_obj_new_complex(mp_obj_get_float(args[0]), 0); + mp_obj_get_complex(args[0], &real, &imag); + return mp_obj_new_complex(real, imag); } case 2: diff --git a/tests/float/cmath_dunder.py b/tests/float/cmath_dunder.py new file mode 100644 index 0000000000..3526341510 --- /dev/null +++ b/tests/float/cmath_dunder.py @@ -0,0 +1,21 @@ +# test that cmath functions support user classes with __float__ and __complex__ + +try: + import cmath +except ImportError: + print("SKIP") + raise SystemExit + + +class TestFloat: + def __float__(self): + return 1.0 + + +class TestComplex: + def __complex__(self): + return 1j + 10 + + +for clas in TestFloat, TestComplex: + print("%.5g" % cmath.phase(clas())) diff --git a/tests/float/complex_dunder.py b/tests/float/complex_dunder.py index 128dc69293..975d829b47 100644 --- a/tests/float/complex_dunder.py +++ b/tests/float/complex_dunder.py @@ -1,6 +1,11 @@ # test __complex__ function support +class TestFloat: + def __float__(self): + return 1.0 + + class TestComplex: def __complex__(self): return 1j + 10 @@ -20,6 +25,7 @@ class Test: pass +print(complex(TestFloat())) print(complex(TestComplex())) try: diff --git a/tests/float/math_dunder.py b/tests/float/math_dunder.py new file mode 100644 index 0000000000..33ea7f7c1c --- /dev/null +++ b/tests/float/math_dunder.py @@ -0,0 +1,15 @@ +# test that math functions support user classes with __float__ + +try: + import math +except ImportError: + print("SKIP") + raise SystemExit + + +class TestFloat: + def __float__(self): + return 1.0 + + +print("%.5g" % math.exp(TestFloat()))