From 8c656754aa2892cbd36968bfaab1ff7033edeb0f Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Mon, 27 Aug 2018 10:32:21 +1000 Subject: [PATCH] py/modmath: Add math.factorial, optimised and non-opt implementations. This commit adds the math.factorial function in two variants: - squared difference, which is faster than the naive version, relatively compact, and non-recursive; - a mildly optimised recursive version, faster than the above one. There are some more optimisations that could be done, but they tend to take more code, and more storage space. The recursive version seems like a sensible compromise. The new function is disabled by default, and uses the non-optimised version by default if it is enabled. The options are MICROPY_PY_MATH_FACTORIAL and MICROPY_OPT_MATH_FACTORIAL. --- ports/unix/mpconfigport_coverage.h | 2 + py/modmath.c | 69 +++++++++++++++++++++++++++- py/mpconfig.h | 11 +++++ tests/float/math_factorial_intbig.py | 14 ++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 tests/float/math_factorial_intbig.py diff --git a/ports/unix/mpconfigport_coverage.h b/ports/unix/mpconfigport_coverage.h index a4225e930c..504259cff0 100644 --- a/ports/unix/mpconfigport_coverage.h +++ b/ports/unix/mpconfigport_coverage.h @@ -32,6 +32,7 @@ #include +#define MICROPY_OPT_MATH_FACTORIAL (1) #define MICROPY_FLOAT_HIGH_QUALITY_HASH (1) #define MICROPY_ENABLE_SCHEDULER (1) #define MICROPY_READER_VFS (1) @@ -41,6 +42,7 @@ #define MICROPY_PY_BUILTINS_HELP (1) #define MICROPY_PY_BUILTINS_HELP_MODULES (1) #define MICROPY_PY_SYS_GETSIZEOF (1) +#define MICROPY_PY_MATH_FACTORIAL (1) #define MICROPY_PY_URANDOM_EXTRA_FUNCS (1) #define MICROPY_PY_IO_BUFFEREDWRITER (1) #define MICROPY_PY_IO_RESOURCE_STREAM (1) diff --git a/py/modmath.c b/py/modmath.c index 6072c780a5..d106f240c8 100644 --- a/py/modmath.c +++ b/py/modmath.c @@ -169,7 +169,7 @@ MATH_FUN_1(gamma, tgamma) // lgamma(x): return the natural logarithm of the gamma function of x MATH_FUN_1(lgamma, lgamma) #endif -//TODO: factorial, fsum +//TODO: fsum // Function that takes a variable number of arguments @@ -232,6 +232,70 @@ STATIC mp_obj_t mp_math_degrees(mp_obj_t x_obj) { } STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_degrees_obj, mp_math_degrees); +#if MICROPY_PY_MATH_FACTORIAL + +#if MICROPY_OPT_MATH_FACTORIAL + +// factorial(x): slightly efficient recursive implementation +STATIC mp_obj_t mp_math_factorial_inner(mp_uint_t start, mp_uint_t end) { + if (start == end) { + return mp_obj_new_int(start); + } else if (end - start == 1) { + return mp_binary_op(MP_BINARY_OP_MULTIPLY, MP_OBJ_NEW_SMALL_INT(start), MP_OBJ_NEW_SMALL_INT(end)); + } else if (end - start == 2) { + mp_obj_t left = MP_OBJ_NEW_SMALL_INT(start); + mp_obj_t middle = MP_OBJ_NEW_SMALL_INT(start + 1); + mp_obj_t right = MP_OBJ_NEW_SMALL_INT(end); + mp_obj_t tmp = mp_binary_op(MP_BINARY_OP_MULTIPLY, left, middle); + return mp_binary_op(MP_BINARY_OP_MULTIPLY, tmp, right); + } else { + mp_uint_t middle = start + ((end - start) >> 1); + mp_obj_t left = mp_math_factorial_inner(start, middle); + mp_obj_t right = mp_math_factorial_inner(middle + 1, end); + return mp_binary_op(MP_BINARY_OP_MULTIPLY, left, right); + } +} +STATIC mp_obj_t mp_math_factorial(mp_obj_t x_obj) { + mp_int_t max = mp_obj_get_int(x_obj); + if (max < 0) { + mp_raise_msg(&mp_type_ValueError, "negative factorial"); + } else if (max == 0) { + return MP_OBJ_NEW_SMALL_INT(1); + } + return mp_math_factorial_inner(1, max); +} + +#else + +// factorial(x): squared difference implementation +// based on http://www.luschny.de/math/factorial/index.html +STATIC mp_obj_t mp_math_factorial(mp_obj_t x_obj) { + mp_int_t max = mp_obj_get_int(x_obj); + if (max < 0) { + mp_raise_msg(&mp_type_ValueError, "negative factorial"); + } else if (max <= 1) { + return MP_OBJ_NEW_SMALL_INT(1); + } + mp_int_t h = max >> 1; + mp_int_t q = h * h; + mp_int_t r = q << 1; + if (max & 1) { + r *= max; + } + mp_obj_t prod = MP_OBJ_NEW_SMALL_INT(r); + for (mp_int_t num = 1; num < max - 2; num += 2) { + q -= num; + prod = mp_binary_op(MP_BINARY_OP_MULTIPLY, prod, MP_OBJ_NEW_SMALL_INT(q)); + } + return prod; +} + +#endif + +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_factorial_obj, mp_math_factorial); + +#endif + STATIC const mp_rom_map_elem_t mp_module_math_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_math) }, { MP_ROM_QSTR(MP_QSTR_e), mp_const_float_e }, @@ -274,6 +338,9 @@ STATIC const mp_rom_map_elem_t mp_module_math_globals_table[] = { { MP_ROM_QSTR(MP_QSTR_trunc), MP_ROM_PTR(&mp_math_trunc_obj) }, { MP_ROM_QSTR(MP_QSTR_radians), MP_ROM_PTR(&mp_math_radians_obj) }, { MP_ROM_QSTR(MP_QSTR_degrees), MP_ROM_PTR(&mp_math_degrees_obj) }, + #if MICROPY_PY_MATH_FACTORIAL + { MP_ROM_QSTR(MP_QSTR_factorial), MP_ROM_PTR(&mp_math_factorial_obj) }, + #endif #if MICROPY_PY_MATH_SPECIAL_FUNCTIONS { MP_ROM_QSTR(MP_QSTR_erf), MP_ROM_PTR(&mp_math_erf_obj) }, { MP_ROM_QSTR(MP_QSTR_erfc), MP_ROM_PTR(&mp_math_erfc_obj) }, diff --git a/py/mpconfig.h b/py/mpconfig.h index 8f14114057..cd2f2acdf5 100644 --- a/py/mpconfig.h +++ b/py/mpconfig.h @@ -407,6 +407,12 @@ #define MICROPY_OPT_MPZ_BITWISE (0) #endif + +// Whether math.factorial is large, fast and recursive (1) or small and slow (0). +#ifndef MICROPY_OPT_MATH_FACTORIAL +#define MICROPY_OPT_MATH_FACTORIAL (0) +#endif + /*****************************************************************************/ /* Python internal features */ @@ -988,6 +994,11 @@ typedef double mp_float_t; #define MICROPY_PY_MATH_SPECIAL_FUNCTIONS (0) #endif +// Whether to provide math.factorial function +#ifndef MICROPY_PY_MATH_FACTORIAL +#define MICROPY_PY_MATH_FACTORIAL (0) +#endif + // Whether to provide "cmath" module #ifndef MICROPY_PY_CMATH #define MICROPY_PY_CMATH (0) diff --git a/tests/float/math_factorial_intbig.py b/tests/float/math_factorial_intbig.py new file mode 100644 index 0000000000..19d853df2a --- /dev/null +++ b/tests/float/math_factorial_intbig.py @@ -0,0 +1,14 @@ +try: + import math + math.factorial +except (ImportError, AttributeError): + print('SKIP') + raise SystemExit + + +for fun in (math.factorial,): + for x in range(-1, 30): + try: + print('%d' % fun(x)) + except ValueError as e: + print('ValueError')