From 438c88dd2fbf8250883882188cca513ce534271f Mon Sep 17 00:00:00 2001 From: Damien George Date: Sat, 22 Feb 2014 19:25:23 +0000 Subject: [PATCH] Add arbitrary precision integer support. Some functionality is still missing (eg and, or, bit shift), and some things are buggy (eg subtract). --- py/binary.c | 1 - py/mpconfig.h | 1 + py/mpz.c | 1009 ++++++++++++++++++++++++++++++++++++++++++ py/mpz.h | 65 +++ py/obj.h | 3 + py/objint.c | 1 + py/objint.h | 6 +- py/objint_longlong.c | 1 + py/objint_mpz.c | 181 ++++++++ py/py.mk | 2 + stm/mpconfigport.h | 1 + unix/mpconfigport.h | 2 +- 12 files changed, 1267 insertions(+), 6 deletions(-) create mode 100644 py/mpz.c create mode 100644 py/mpz.h create mode 100644 py/objint_mpz.c diff --git a/py/binary.c b/py/binary.c index b28eb6426c..20b38b8b14 100644 --- a/py/binary.c +++ b/py/binary.c @@ -6,7 +6,6 @@ #include "mpconfig.h" #include "qstr.h" #include "obj.h" -#include "objint.h" #include "binary.h" // Helpers to work with binary-encoded data diff --git a/py/mpconfig.h b/py/mpconfig.h index 00e2439e4d..34c83d3245 100644 --- a/py/mpconfig.h +++ b/py/mpconfig.h @@ -68,6 +68,7 @@ // Long int implementation #define MICROPY_LONGINT_IMPL_NONE (0) #define MICROPY_LONGINT_IMPL_LONGLONG (1) +#define MICROPY_LONGINT_IMPL_MPZ (2) #ifndef MICROPY_LONGINT_IMPL #define MICROPY_LONGINT_IMPL (MICROPY_LONGINT_IMPL_NONE) diff --git a/py/mpz.c b/py/mpz.c new file mode 100644 index 0000000000..ab049ba890 --- /dev/null +++ b/py/mpz.c @@ -0,0 +1,1009 @@ +#include +#include +#include +#include +#include + +#include "misc.h" +#include "mpconfig.h" +#include "mpz.h" + +#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ + +#define DIG_SIZE (15) +#define DIG_MASK ((1 << DIG_SIZE) - 1) + +/* + definition of normalise: + ? +*/ + +/* compares i with j + returns sign(i - j) + assumes i, j are normalised +*/ +int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) { + if (ilen < jlen) { return -1; } + if (ilen > jlen) { return 1; } + + for (idig += ilen, jdig += ilen; ilen > 0; --ilen) { + int cmp = *(--idig) - *(--jdig); + if (cmp < 0) { return -1; } + if (cmp > 0) { return 1; } + } + + return 0; +} + +/* computes i = j >> n + returns number of digits in i + assumes enough memory in i; assumes normalised j + can have i, j pointing to same memory +*/ +uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { + uint n_whole = n / DIG_SIZE; + uint n_part = n % DIG_SIZE; + + if (n_whole >= jlen) { + return 0; + } + + jdig += n_whole; + jlen -= n_whole; + + for (uint i = jlen; i > 0; --i, ++idig, ++jdig) { + mpz_dbl_dig_t d = *jdig; + if (i > 1) + d |= jdig[1] << DIG_SIZE; + d >>= n_part; + *idig = d & DIG_MASK; + } + + if (idig[-1] == 0) { + --jlen; + } + + return jlen; +} + +/* computes i = j + k + returns number of digits in i + assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen + can have i, j, k pointing to same memory +*/ +uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { + mpz_dig_t *oidig = idig; + mpz_dbl_dig_t carry = 0; + + jlen -= klen; + + for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { + carry += *jdig + *kdig; + *idig = carry & DIG_MASK; + carry >>= DIG_SIZE; + } + + for (; jlen > 0; --jlen, ++idig, ++jdig) { + carry += *jdig; + *idig = carry & DIG_MASK; + carry >>= DIG_SIZE; + } + + if (carry != 0) { + *idig++ = carry; + } + + return idig - oidig; +} + +/* computes i = j - k + returns number of digits in i + assumes enough memory in i; assumes normalised j, k; assumes j >= k + can have i, j, k pointing to same memory +*/ +uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { + mpz_dig_t *oidig = idig; + mpz_dbl_dig_signed_t borrow = 0; + + jlen -= klen; + + for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { + borrow += *jdig - *kdig; + *idig = borrow & DIG_MASK; + borrow >>= DIG_SIZE; + } + + for (; jlen > 0; --jlen, ++idig, ++kdig) { + borrow += *jdig; + *idig = borrow & DIG_MASK; + borrow >>= DIG_SIZE; + } + + for (--idig; idig >= oidig && *idig == 0; --idig) { + } + + return idig + 1 - oidig; +} + +/* computes i = i * d1 + d2 + returns number of digits in i + assumes enough memory in i; assumes normalised i; assumes dmul != 0 +*/ +uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) { + mpz_dig_t *oidig = idig; + mpz_dbl_dig_t carry = dadd; + + for (; ilen > 0; --ilen, ++idig) { + carry += *idig * dmul; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2 + *idig = carry & DIG_MASK; + carry >>= DIG_SIZE; + } + + if (carry != 0) { + *idig++ = carry; + } + + return idig - oidig; +} + +/* computes i = j * k + returns number of digits in i + assumes enough memory in i; assumes i is zeroed; assumes normalised j, k + can have j, k point to same memory +*/ +uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) { + mpz_dig_t *oidig = idig; + uint ilen = 0; + + for (; klen > 0; --klen, ++idig, ++kdig) { + mpz_dig_t *id = idig; + mpz_dbl_dig_t carry = 0; + + uint jl = jlen; + for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) { + carry += *id + *jd * *kdig; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2 + *id = carry & DIG_MASK; + carry >>= DIG_SIZE; + } + + if (carry != 0) { + *id++ = carry; + } + + ilen = id - oidig; + } + + return ilen; +} + +/* natural_div - quo * den + new_num = old_num (ie num is replaced with rem) + assumes den != 0 + assumes num_dig has enough memory to be extended by 1 digit + assumes quo_dig has enough memory (as many digits as num) + assumes quo_dig is filled with zeros + modifies den_dig memory, but restors it to original state at end +*/ + +void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) { + mpz_dig_t *orig_num_dig = num_dig; + mpz_dig_t *orig_quo_dig = quo_dig; + mpz_dig_t norm_shift = 0; + mpz_dbl_dig_t lead_den_digit; + + // handle simple cases + { + int cmp = mpn_cmp(num_dig, *num_len, den_dig, den_len); + if (cmp == 0) { + *num_len = 0; + quo_dig[0] = 1; + *quo_len = 1; + return; + } else if (cmp < 0) { + // numerator remains the same + *quo_len = 0; + return; + } + } + + // count number of leading zeros in leading digit of denominator + { + mpz_dig_t d = den_dig[den_len - 1]; + while ((d & (1 << (DIG_SIZE - 1))) == 0) { + d <<= 1; + ++norm_shift; + } + } + + // normalise denomenator (leading bit of leading digit is 1) + for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) { + mpz_dig_t d = *den; + *den = ((d << norm_shift) | carry) & DIG_MASK; + carry = d >> (DIG_SIZE - norm_shift); + } + + // now need to shift numerator by same amount as denominator + // first, increase length of numerator in case we need more room to shift + num_dig[*num_len] = 0; + ++(*num_len); + for (mpz_dig_t *num = num_dig, carry = 0; num < num_dig + *num_len; ++num) { + mpz_dig_t n = *num; + *num = ((n << norm_shift) | carry) & DIG_MASK; + carry = n >> (DIG_SIZE - norm_shift); + } + + // cache the leading digit of the denominator + lead_den_digit = den_dig[den_len - 1]; + + // point num_dig to last digit in numerator + num_dig += *num_len - 1; + + // calculate number of digits in quotient + *quo_len = *num_len - den_len; + + // point to last digit to store for quotient + quo_dig += *quo_len - 1; + + // keep going while we have enough digits to divide + while (*num_len > den_len) { + mpz_dbl_dig_t quo = (*num_dig << DIG_SIZE) | num_dig[-1]; + + // get approximate quotient + quo /= lead_den_digit; + + // multiply quo by den and subtract from num get remainder + { + mpz_dbl_dig_signed_t borrow = 0; + + for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { + borrow += *n - quo * *d; // will overflow if DIG_SIZE >= 16 + *n = borrow & DIG_MASK; + borrow >>= DIG_SIZE; + } + borrow += *num_dig; // will overflow if DIG_SIZE >= 16 + *num_dig = borrow & DIG_MASK; + borrow >>= DIG_SIZE; + + // adjust quotient if it is too big + for (; borrow != 0; --quo) { + mpz_dbl_dig_t carry = 0; + for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { + carry += *n + *d; + *n = carry & DIG_MASK; + carry >>= DIG_SIZE; + } + carry += *num_dig; + *num_dig = carry & DIG_MASK; + carry >>= DIG_SIZE; + + borrow += carry; + } + } + + // store this digit of the quotient + *quo_dig = quo & DIG_MASK; + --quo_dig; + + // move down to next digit of numerator + --num_dig; + --(*num_len); + } + + // unnormalise denomenator + for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) { + mpz_dig_t d = *den; + *den = ((d >> norm_shift) | carry) & DIG_MASK; + carry = d << (DIG_SIZE - norm_shift); + } + + // unnormalise numerator (remainder now) + for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) { + mpz_dig_t n = *num; + *num = ((n >> norm_shift) | carry) & DIG_MASK; + carry = n << (DIG_SIZE - norm_shift); + } + + // strip trailing zeros + + while (*quo_len > 0 && orig_quo_dig[*quo_len - 1] == 0) { + --(*quo_len); + } + + while (*num_len > 0 && orig_num_dig[*num_len - 1] == 0) { + --(*num_len); + } +} + +#define MIN_ALLOC (4) +#define ALIGN_ALLOC (2) +#define NUM_DIG_FOR_INT (sizeof(int) * 8 / DIG_SIZE + 1) + +static const uint log_base2_floor[] = { + 0, + 0, 1, 1, 2, + 2, 2, 2, 3, + 3, 3, 3, 3, + 3, 3, 3, 4, + 4, 4, 4, 4, + 4, 4, 4, 4, + 4, 4, 4, 4, + 4, 4, 4, 5 +}; + +bool mpz_int_is_sml_int(int i) { + return -(1 << DIG_SIZE) < i && i < (1 << DIG_SIZE); +} + +void mpz_init_zero(mpz_t *z) { + z->alloc = 0; + z->neg = 0; + z->len = 0; + z->dig = NULL; +} + +void mpz_init_from_int(mpz_t *z, machine_int_t val) { + mpz_init_zero(z); + mpz_set_from_int(z, val); +} + +void mpz_deinit(mpz_t *z) { + if (z != NULL) { + m_del(mpz_dig_t, z->dig, z->alloc); + } +} + +mpz_t *mpz_zero(void) { + mpz_t *z = m_new_obj(mpz_t); + mpz_init_zero(z); + return z; +} + +mpz_t *mpz_from_int(machine_int_t val) { + mpz_t *z = mpz_zero(); + mpz_set_from_int(z, val); + return z; +} + +mpz_t *mpz_from_str(const char *str, uint len, bool neg, uint base) { + mpz_t *z = mpz_zero(); + mpz_set_from_str(z, str, len, neg, base); + return z; +} + +void mpz_free(mpz_t *z) { + if (z != NULL) { + m_del(mpz_dig_t, z->dig, z->alloc); + m_del_obj(mpz_t, z); + } +} + +STATIC void mpz_need_dig(mpz_t *z, uint need) { + uint alloc; + if (need < MIN_ALLOC) { + alloc = MIN_ALLOC; + } else { + alloc = (need + ALIGN_ALLOC) & (~(ALIGN_ALLOC - 1)); + } + + if (z->dig == NULL || z->alloc < alloc) { + z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, alloc); + z->alloc = alloc; + } +} + +mpz_t *mpz_clone(const mpz_t *src) { + mpz_t *z = m_new_obj(mpz_t); + z->alloc = src->alloc; + z->neg = src->neg; + z->len = src->len; + if (src->dig == NULL) { + z->dig = NULL; + } else { + z->dig = m_new(mpz_dig_t, z->alloc); + memcpy(z->dig, src->dig, src->alloc * sizeof(mpz_dig_t)); + } + return z; +} + +void mpz_set(mpz_t *dest, const mpz_t *src) { + mpz_need_dig(dest, src->len); + dest->neg = src->neg; + dest->len = src->len; + memcpy(dest->dig, src->dig, src->len * sizeof(mpz_dig_t)); +} + +void mpz_set_from_int(mpz_t *z, machine_int_t val) { + mpz_need_dig(z, NUM_DIG_FOR_INT); + + if (val < 0) { + z->neg = 1; + val = -val; + } else { + z->neg = 0; + } + + z->len = 0; + while (val > 0) { + z->dig[z->len++] = val & DIG_MASK; + val >>= DIG_SIZE; + } +} + +// returns number of bytes from str that were processed +uint mpz_set_from_str(mpz_t *z, const char *str, uint len, bool neg, uint base) { + assert(base < 36); + + const char *cur = str; + const char *top = str + len; + + mpz_need_dig(z, len * 8 / DIG_SIZE + 1); + + if (neg) { + z->neg = 1; + } else { + z->neg = 0; + } + + z->len = 0; + for (; cur < top; ++cur) { // XXX UTF8 next char + //uint v = char_to_numeric(cur#); // XXX UTF8 get char + uint v = *cur; + if ('0' <= v && v <= '9') { + v -= '0'; + } else if ('A' <= v && v <= 'Z') { + v -= 'A' - 10; + } else if ('a' <= v && v <= 'z') { + v -= 'a' - 10; + } else { + break; + } + if (v >= base) { + break; + } + z->len = mpn_mul_dig_add_dig(z->dig, z->len, base, v); + } + + return cur - str; +} + +bool mpz_is_zero(const mpz_t *z) { + return z->len == 0; +} + +bool mpz_is_pos(const mpz_t *z) { + return z->len > 0 && z->neg == 0; +} + +bool mpz_is_neg(const mpz_t *z) { + return z->len > 0 && z->neg != 0; +} + +bool mpz_is_odd(const mpz_t *z) { + return z->len > 0 && (z->dig[0] & 1) != 0; +} + +bool mpz_is_even(const mpz_t *z) { + return z->len == 0 || (z->dig[0] & 1) == 0; +} + +int mpz_cmp(const mpz_t *z1, const mpz_t *z2) { + int cmp = z2->neg - z1->neg; + if (cmp != 0) { + return cmp; + } + cmp = mpn_cmp(z1->dig, z1->len, z2->dig, z2->len); + if (z1->neg != 0) { + cmp = -cmp; + } + return cmp; +} + +int mpz_cmp_sml_int(const mpz_t *z, int sml_int) { + int cmp; + if (z->neg == 0) { + if (sml_int < 0) return 1; + if (sml_int == 0) { + if (z->len == 0) return 0; + return 1; + } + if (z->len == 0) return -1; + assert(sml_int < (1 << DIG_SIZE)); + if (z->len != 1) return 1; + cmp = z->dig[0] - sml_int; + } else { + if (sml_int > 0) return -1; + if (sml_int == 0) { + if (z->len == 0) return 0; + return -1; + } + if (z->len == 0) return 1; + assert(sml_int > -(1 << DIG_SIZE)); + if (z->len != 1) return -1; + cmp = -z->dig[0] - sml_int; + } + if (cmp < 0) return -1; + if (cmp > 0) return 1; + return 0; +} + +/* not finished +mpz_t *mpz_shl(mpz_t *dest, const mpz_t *lhs, int rhs) +{ + if (dest != lhs) + dest = mpz_set(dest, lhs); + + if (dest.len == 0 || rhs == 0) + return dest; + + if (rhs < 0) + return mpz_shr(dest, dest, -rhs); + + printf("mpz_shl: not implemented\n"); + + return dest; +} + +mpz_t *mpz_shr(mpz_t *dest, const mpz_t *lhs, int rhs) +{ + if (dest != lhs) + dest = mpz_set(dest, lhs); + + if (dest.len == 0 || rhs == 0) + return dest; + + if (rhs < 0) + return mpz_shl(dest, dest, -rhs); + + dest.len = mpn_shr(dest.len, dest.dig, rhs); + dest.dig[dest.len .. dest->alloc] = 0; + + return dest; +} +*/ + + +#if 0 +these functions are unused + +/* returns abs(z) +*/ +mpz_t *mpz_abs(const mpz_t *z) { + mpz_t *z2 = mpz_clone(z); + z2->neg = 0; + return z2; +} + +/* returns -z +*/ +mpz_t *mpz_neg(const mpz_t *z) { + mpz_t *z2 = mpz_clone(z); + z2->neg = 1 - z2->neg; + return z2; +} + +/* returns lhs + rhs + can have lhs, rhs the same +*/ +mpz_t *mpz_add(const mpz_t *lhs, const mpz_t *rhs) { + mpz_t *z = mpz_zero(); + mpz_add_inpl(z, lhs, rhs); + return z; +} + +/* returns lhs - rhs + can have lhs, rhs the same +*/ +mpz_t *mpz_sub(const mpz_t *lhs, const mpz_t *rhs) { + mpz_t *z = mpz_zero(); + mpz_sub_inpl(z, lhs, rhs); + return z; +} + +/* returns lhs * rhs + can have lhs, rhs the same +*/ +mpz_t *mpz_mul(const mpz_t *lhs, const mpz_t *rhs) { + mpz_t *z = mpz_zero(); + mpz_mul_inpl(z, lhs, rhs); + return z; +} + +/* returns lhs ** rhs + can have lhs, rhs the same +*/ +mpz_t *mpz_pow(const mpz_t *lhs, const mpz_t *rhs) { + mpz_t *z = mpz_zero(); + mpz_pow_inpl(z, lhs, rhs); + return z; +} +#endif + +/* computes dest = abs(z) + can have dest, z the same +*/ +void mpz_abs_inpl(mpz_t *dest, const mpz_t *z) { + if (dest != z) { + mpz_set(dest, z); + } + dest->neg = 0; +} + +/* computes dest = -z + can have dest, z the same +*/ +void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) { + if (dest != z) { + mpz_set(dest, z); + } + dest->neg = 1 - dest->neg; +} + +/* computes dest = lhs + rhs + can have dest, lhs, rhs the same +*/ +void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { + if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) { + const mpz_t *temp = lhs; + lhs = rhs; + rhs = temp; + } + + if (lhs->neg == rhs->neg) { + mpz_need_dig(dest, lhs->len + 1); + dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } else { + mpz_need_dig(dest, lhs->len); + dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } + + dest->neg = lhs->neg; +} + +/* computes dest = lhs - rhs + can have dest, lhs, rhs the same +*/ +void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { + bool neg = false; + + if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) { + const mpz_t *temp = lhs; + lhs = rhs; + rhs = temp; + neg = true; + } + + if (lhs->neg != rhs->neg) { + mpz_need_dig(dest, lhs->len + 1); + dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } else { + mpz_need_dig(dest, lhs->len); + dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } + + if (neg) { + dest->neg = 1 - lhs->neg; + } else { + dest->neg = lhs->neg; + } +} + +/* computes dest = lhs * rhs + can have dest, lhs, rhs the same +*/ +void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) +{ + if (lhs->len == 0 || rhs->len == 0) { + return mpz_set_from_int(dest, 0); + } + + mpz_t *temp = NULL; + if (lhs == dest) { + lhs = temp = mpz_clone(lhs); + if (rhs == dest) { + rhs = lhs; + } + } else if (rhs == dest) { + rhs = temp = mpz_clone(rhs); + } + + mpz_need_dig(dest, lhs->len + rhs->len); // min mem l+r-1, max mem l+r + memset(dest->dig, 0, dest->alloc * sizeof(mpz_dig_t)); + dest->len = mpn_mul(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + + if (lhs->neg == rhs->neg) { + dest->neg = 0; + } else { + dest->neg = 1; + } + + mpz_free(temp); +} + +/* computes dest = lhs ** rhs + can have dest, lhs, rhs the same +*/ +void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { + if (lhs->len == 0 || rhs->neg != 0) { + return mpz_set_from_int(dest, 0); + } + + if (rhs->len == 0) { + return mpz_set_from_int(dest, 1); + } + + mpz_t *x = mpz_clone(lhs); + mpz_t *n = mpz_clone(rhs); + + mpz_set_from_int(dest, 1); + + while (n->len > 0) { + if (mpz_is_odd(n)) { + mpz_mul_inpl(dest, dest, x); + } + mpz_mul_inpl(x, x, x); + n->len = mpn_shr(n->dig, n->dig, n->len, 1); + } + + mpz_free(x); + mpz_free(n); +} + +/* computes gcd(z1, z2) + based on Knuth's modified gcd algorithm (I think?) + gcd(z1, z2) >= 0 + gcd(0, 0) = 0 + gcd(z, 0) = abs(z) +*/ +mpz_t *mpz_gcd(const mpz_t *z1, const mpz_t *z2) { + if (z1->len == 0) { + mpz_t *a = mpz_clone(z2); + a->neg = 0; + return a; + } else if (z2->len == 0) { + mpz_t *a = mpz_clone(z1); + a->neg = 0; + return a; + } + + mpz_t *a = mpz_clone(z1); + mpz_t *b = mpz_clone(z2); + mpz_t c; mpz_init_zero(&c); + a->neg = 0; + b->neg = 0; + + for (;;) { + if (mpz_cmp(a, b) < 0) { + if (a->len == 0) { + mpz_free(a); + mpz_deinit(&c); + return b; + } + mpz_t *t = a; a = b; b = t; + } + if (!(b->len >= 2 || (b->len == 1 && b->dig[0] > 1))) { // compute b > 0; could be mpz_cmp_small_int(b, 1) > 0 + break; + } + mpz_set(&c, b); + do { + mpz_add_inpl(&c, &c, &c); + } while (mpz_cmp(&c, a) <= 0); + c.len = mpn_shr(c.dig, c.dig, c.len, 1); + mpz_sub_inpl(a, a, &c); + } + + mpz_deinit(&c); + + if (b->len == 1 && b->dig[0] == 1) { // compute b == 1; could be mpz_cmp_small_int(b, 1) == 0 + mpz_free(a); + return b; + } else { + mpz_free(b); + return a; + } +} + +/* computes lcm(z1, z2) + = abs(z1) / gcd(z1, z2) * abs(z2) + lcm(z1, z1) >= 0 + lcm(0, 0) = 0 + lcm(z, 0) = 0 +*/ +mpz_t *mpz_lcm(const mpz_t *z1, const mpz_t *z2) +{ + if (z1->len == 0 || z2->len == 0) + return mpz_zero(); + + mpz_t *gcd = mpz_gcd(z1, z2); + mpz_t *quo = mpz_zero(); + mpz_t *rem = mpz_zero(); + mpz_divmod_inpl(quo, rem, z1, gcd); + mpz_mul_inpl(rem, quo, z2); + mpz_free(gcd); + mpz_free(quo); + rem->neg = 0; + return rem; +} + +/* computes new integers in quo and rem such that: + quo * rhs + rem = lhs + 0 <= rem < rhs + can have lhs, rhs the same +*/ +void mpz_divmod(const mpz_t *lhs, const mpz_t *rhs, mpz_t **quo, mpz_t **rem) { + *quo = mpz_zero(); + *rem = mpz_zero(); + mpz_divmod_inpl(*quo, *rem, lhs, rhs); +} + +/* computes new integers in quo and rem such that: + quo * rhs + rem = lhs + 0 <= rem < rhs + can have lhs, rhs the same +*/ +void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const mpz_t *rhs) { + if (rhs->len == 0) { + mpz_set_from_int(dest_quo, 0); + mpz_set_from_int(dest_rem, 0); + return; + } + + mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary? + memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t)); + dest_quo->len = 0; + mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary? + mpz_set(dest_rem, lhs); + //rhs->dig[rhs->len] = 0; + mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len); + + if (lhs->neg != rhs->neg) { + dest_quo->neg = 1; + } +} + +#if 0 +these functions are unused + +/* computes floor(lhs / rhs) + can have lhs, rhs the same +*/ +mpz_t *mpz_div(const mpz_t *lhs, const mpz_t *rhs) { + mpz_t *quo = mpz_zero(); + mpz_t rem; mpz_init_zero(&rem); + mpz_divmod_inpl(quo, &rem, lhs, rhs); + mpz_deinit(&rem); + return quo; +} + +/* computes lhs % rhs ( >= 0) + can have lhs, rhs the same +*/ +mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs) { + mpz_t quo; mpz_init_zero(&quo); + mpz_t *rem = mpz_zero(); + mpz_divmod_inpl(&quo, rem, lhs, rhs); + mpz_deinit(&quo); + return rem; +} +#endif + +int mpz_as_int(const mpz_t *i) { + int val = 0; + mpz_dig_t *d = i->dig + i->len; + + while (--d >= i->dig) + { + int oldval = val; + val = (val << DIG_SIZE) | *d; + if (val < oldval) + { + if (i->neg == 0) { + return 0x7fffffff; + } else { + return 0x80000000; + } + } + } + + if (i->neg != 0) { + val = -val; + } + + return val; +} + +machine_float_t mpz_as_float(const mpz_t *i) { + machine_float_t val = 0; + mpz_dig_t *d = i->dig + i->len; + + while (--d >= i->dig) { + val = val * (1 << DIG_SIZE) + *d; + } + + if (i->neg != 0) { + val = -val; + } + + return val; +} + +uint mpz_as_str_size(const mpz_t *i, uint base) { + if (base < 2 || base > 32) { + return 0; + } + + return i->len * DIG_SIZE / log_base2_floor[base] + 2 + 1; // +1 for null byte termination +} + +char *mpz_as_str(const mpz_t *i, uint base) { + char *s = m_new(char, mpz_as_str_size(i, base)); + mpz_as_str_inpl(i, base, s); + return s; +} + +// assumes enough space as calculated by mpz_as_str_size +// returns length of string, not including null byte +uint mpz_as_str_inpl(const mpz_t *i, uint base, char *str) { + if (str == NULL || base < 2 || base > 32) { + str[0] = 0; + return 0; + } + + uint ilen = i->len; + + if (ilen == 0) { + str[0] = '0'; + str[1] = 0; + return 1; + } + + // make a copy of mpz digits + mpz_dig_t *dig = m_new(mpz_dig_t, ilen); + memcpy(dig, i->dig, ilen * sizeof(mpz_dig_t)); + + // convert + char *s = str; + bool done; + do { + mpz_dig_t *d = dig + ilen; + mpz_dbl_dig_t a = 0; + + // compute next remainder + while (--d >= dig) { + a = (a << DIG_SIZE) | *d; + *d = a / base; + a %= base; + } + + // convert to character + a += '0'; + if (a > '9') { + a += 'a' - '9' - 1; + } + *s++ = a; + + // check if number is zero + done = true; + for (d = dig; d < dig + ilen; ++d) { + if (*d != 0) { + done = false; + break; + } + } + } while (!done); + + if (i->neg != 0) { + *s++ = '-'; + } + + // reverse string + for (char *u = str, *v = s - 1; u < v; ++u, --v) { + char temp = *u; + *u = *v; + *v = temp; + } + + s[0] = 0; // null termination + + return s - str; +} + +#endif // MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ diff --git a/py/mpz.h b/py/mpz.h new file mode 100644 index 0000000000..eabad62831 --- /dev/null +++ b/py/mpz.h @@ -0,0 +1,65 @@ +typedef uint16_t mpz_dig_t; +typedef uint32_t mpz_dbl_dig_t; +typedef int32_t mpz_dbl_dig_signed_t; + +typedef struct _mpz_t { + struct { + machine_uint_t neg : 1; + machine_uint_t alloc : 31; + }; + machine_uint_t len; + mpz_dig_t *dig; +} mpz_t; + +bool mpz_int_is_sml_int(int i); + +void mpz_init_zero(mpz_t *z); +void mpz_init_from_int(mpz_t *z, machine_int_t val); +void mpz_deinit(mpz_t *z); + +mpz_t *mpz_zero(); +mpz_t *mpz_from_int(machine_int_t i); +mpz_t *mpz_from_str(const char *str, uint len, bool neg, uint base); +void mpz_free(mpz_t *z); + +mpz_t *mpz_clone(const mpz_t *src); + +void mpz_set(mpz_t *dest, const mpz_t *src); +void mpz_set_from_int(mpz_t *z, machine_int_t src); +uint mpz_set_from_str(mpz_t *z, const char *str, uint len, bool neg, uint base); + +bool mpz_is_zero(const mpz_t *z); +bool mpz_is_pos(const mpz_t *z); +bool mpz_is_neg(const mpz_t *z); +bool mpz_is_odd(const mpz_t *z); +bool mpz_is_even(const mpz_t *z); + +int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs); +int mpz_cmp_sml_int(const mpz_t *lhs, int sml_int); + +mpz_t *mpz_abs(const mpz_t *z); +mpz_t *mpz_neg(const mpz_t *z); +mpz_t *mpz_add(const mpz_t *lhs, const mpz_t *rhs); +mpz_t *mpz_sub(const mpz_t *lhs, const mpz_t *rhs); +mpz_t *mpz_mul(const mpz_t *lhs, const mpz_t *rhs); +mpz_t *mpz_pow(const mpz_t *lhs, const mpz_t *rhs); + +void mpz_abs_inpl(mpz_t *dest, const mpz_t *z); +void mpz_neg_inpl(mpz_t *dest, const mpz_t *z); +void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); +void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); +void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); +void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); + +mpz_t *mpz_gcd(const mpz_t *z1, const mpz_t *z2); +mpz_t *mpz_lcm(const mpz_t *z1, const mpz_t *z2); +void mpz_divmod(const mpz_t *lhs, const mpz_t *rhs, mpz_t **quo, mpz_t **rem); +void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const mpz_t *rhs); +mpz_t *mpz_div(const mpz_t *lhs, const mpz_t *rhs); +mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs); + +int mpz_as_int(const mpz_t *z); +machine_float_t mpz_as_float(const mpz_t *z); +uint mpz_as_str_size(const mpz_t *z, uint base); +char *mpz_as_str(const mpz_t *z, uint base); +uint mpz_as_str_inpl(const mpz_t *z, uint base, char *str); diff --git a/py/obj.h b/py/obj.h index c2b127c328..88202bbcc0 100644 --- a/py/obj.h +++ b/py/obj.h @@ -225,6 +225,9 @@ mp_obj_t mp_obj_new_cell(mp_obj_t obj); mp_obj_t mp_obj_new_int(machine_int_t value); mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value); mp_obj_t mp_obj_new_int_from_long_str(const char *s); +#if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE +mp_obj_t mp_obj_new_int_from_ll(long long val); +#endif mp_obj_t mp_obj_new_str(const byte* data, uint len, bool make_qstr_if_not_already); mp_obj_t mp_obj_new_bytes(const byte* data, uint len); #if MICROPY_ENABLE_FLOAT diff --git a/py/objint.c b/py/objint.c index 0caaab649b..490b4340bb 100644 --- a/py/objint.c +++ b/py/objint.c @@ -9,6 +9,7 @@ #include "qstr.h" #include "obj.h" #include "parsenum.h" +#include "mpz.h" #include "objint.h" // This dispatcher function is expected to be independent of the implementation diff --git a/py/objint.h b/py/objint.h index 00f9e51d4a..53ee49e7db 100644 --- a/py/objint.h +++ b/py/objint.h @@ -2,13 +2,11 @@ typedef struct _mp_obj_int_t { mp_obj_base_t base; #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_LONGLONG mp_longint_impl_t val; +#elif MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ + mpz_t mpz; #endif } mp_obj_int_t; void int_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t self_in, mp_print_kind_t kind); mp_obj_t int_unary_op(int op, mp_obj_t o_in); mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in); - -#if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE -mp_obj_t mp_obj_new_int_from_ll(long long val); -#endif diff --git a/py/objint_longlong.c b/py/objint_longlong.c index d07f72a555..eca2951be6 100644 --- a/py/objint_longlong.c +++ b/py/objint_longlong.c @@ -8,6 +8,7 @@ #include "mpconfig.h" #include "qstr.h" #include "obj.h" +#include "mpz.h" #include "objint.h" #include "runtime0.h" diff --git a/py/objint_mpz.c b/py/objint_mpz.c new file mode 100644 index 0000000000..a0889da9e4 --- /dev/null +++ b/py/objint_mpz.c @@ -0,0 +1,181 @@ +#include +#include +#include +#include + +#include "nlr.h" +#include "misc.h" +#include "mpconfig.h" +#include "qstr.h" +#include "obj.h" +#include "mpz.h" +#include "objint.h" +#include "runtime0.h" + +#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ + +STATIC mp_obj_int_t *mp_obj_int_new_mpz(void) { + mp_obj_int_t *o = m_new_obj(mp_obj_int_t); + o->base.type = &int_type; + mpz_init_zero(&o->mpz); + return o; +} + +void int_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t self_in, mp_print_kind_t kind) { + if (MP_OBJ_IS_SMALL_INT(self_in)) { + print(env, INT_FMT, MP_OBJ_SMALL_INT_VALUE(self_in)); + } else { + // TODO would rather not allocate memory to print... + mp_obj_int_t *self = self_in; + char *str = mpz_as_str(&self->mpz, 10); + print(env, "%s", str); + m_free(str, 0); + } +} + +mp_obj_t int_unary_op(int op, mp_obj_t o_in) { + mp_obj_int_t *o = o_in; + switch (op) { + case RT_UNARY_OP_BOOL: return MP_BOOL(!mpz_is_zero(&o->mpz)); + case RT_UNARY_OP_POSITIVE: return o_in; + case RT_UNARY_OP_NEGATIVE: { mp_obj_int_t *o2 = mp_obj_int_new_mpz(); mpz_neg_inpl(&o2->mpz, &o->mpz); return o2; } + //case RT_UNARY_OP_INVERT: ~ not implemented for mpz + default: return NULL; // op not supported + } +} + +mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { + mpz_t *zlhs = &((mp_obj_int_t*)lhs_in)->mpz; + mpz_t *zrhs; + + if (MP_OBJ_IS_SMALL_INT(rhs_in)) { + zrhs = mpz_from_int(MP_OBJ_SMALL_INT_VALUE(rhs_in)); + } else if (MP_OBJ_IS_TYPE(rhs_in, &int_type)) { + zrhs = &((mp_obj_int_t*)rhs_in)->mpz; + } else { + return MP_OBJ_NULL; + } + + if (op == RT_BINARY_OP_TRUE_DIVIDE || op == RT_BINARY_OP_INPLACE_TRUE_DIVIDE) { + machine_float_t flhs = mpz_as_float(zlhs); + machine_float_t frhs = mpz_as_float(zrhs); + return mp_obj_new_float(flhs / frhs); + + } else if (op <= RT_BINARY_OP_POWER) { + mp_obj_int_t *res = mp_obj_int_new_mpz(); + + switch (op) { + case RT_BINARY_OP_ADD: + case RT_BINARY_OP_INPLACE_ADD: + mpz_add_inpl(&res->mpz, zlhs, zrhs); + break; + case RT_BINARY_OP_SUBTRACT: + case RT_BINARY_OP_INPLACE_SUBTRACT: + mpz_sub_inpl(&res->mpz, zlhs, zrhs); + break; + case RT_BINARY_OP_MULTIPLY: + case RT_BINARY_OP_INPLACE_MULTIPLY: + mpz_mul_inpl(&res->mpz, zlhs, zrhs); + break; + case RT_BINARY_OP_FLOOR_DIVIDE: + case RT_BINARY_OP_INPLACE_FLOOR_DIVIDE: { + mpz_t rem; mpz_init_zero(&rem); + mpz_divmod_inpl(&res->mpz, &rem, zlhs, zrhs); + mpz_deinit(&rem); + break; + } + + //case RT_BINARY_OP_MODULO: + //case RT_BINARY_OP_INPLACE_MODULO: + + //case RT_BINARY_OP_AND: + //case RT_BINARY_OP_INPLACE_AND: + //case RT_BINARY_OP_OR: + //case RT_BINARY_OP_INPLACE_OR: + //case RT_BINARY_OP_XOR: + //case RT_BINARY_OP_INPLACE_XOR: + + //case RT_BINARY_OP_LSHIFT: + //case RT_BINARY_OP_INPLACE_LSHIFT: + //case RT_BINARY_OP_RSHIFT: + //case RT_BINARY_OP_INPLACE_RSHIFT: + + case RT_BINARY_OP_POWER: + case RT_BINARY_OP_INPLACE_POWER: + mpz_pow_inpl(&res->mpz, zlhs, zrhs); + break; + + default: + return MP_OBJ_NULL; + } + + return res; + + } else { + int cmp = mpz_cmp(zlhs, zrhs); + switch (op) { + case RT_BINARY_OP_LESS: + return MP_BOOL(cmp < 0); + case RT_BINARY_OP_MORE: + return MP_BOOL(cmp > 0); + case RT_BINARY_OP_LESS_EQUAL: + return MP_BOOL(cmp <= 0); + case RT_BINARY_OP_MORE_EQUAL: + return MP_BOOL(cmp >= 0); + case RT_BINARY_OP_EQUAL: + return MP_BOOL(cmp == 0); + case RT_BINARY_OP_NOT_EQUAL: + return MP_BOOL(cmp != 0); + + default: + return MP_OBJ_NULL; + } + } +} + +mp_obj_t mp_obj_new_int(machine_int_t value) { + if (MP_OBJ_FITS_SMALL_INT(value)) { + return MP_OBJ_NEW_SMALL_INT(value); + } + return mp_obj_new_int_from_ll(value); +} + +mp_obj_t mp_obj_new_int_from_ll(long long val) { + mp_obj_int_t *o = mp_obj_int_new_mpz(); + mpz_set_from_int(&o->mpz, val); + return o; +} + +mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value) { + // SMALL_INT accepts only signed numbers, of one bit less size + // than word size, which totals 2 bits less for unsigned numbers. + if ((value & (WORD_MSBIT_HIGH | (WORD_MSBIT_HIGH >> 1))) == 0) { + return MP_OBJ_NEW_SMALL_INT(value); + } + return mp_obj_new_int_from_ll(value); +} + +mp_obj_t mp_obj_new_int_from_long_str(const char *str) { + mp_obj_int_t *o = mp_obj_int_new_mpz(); + uint len = strlen(str); + uint n = mpz_set_from_str(&o->mpz, str, len, false, 10); + if (n != len) { + nlr_jump(mp_obj_new_exception_msg(&mp_type_SyntaxError, "invalid syntax for number")); + } + return o; +} + +machine_int_t mp_obj_int_get(mp_obj_t self_in) { + if (MP_OBJ_IS_SMALL_INT(self_in)) { + return MP_OBJ_SMALL_INT_VALUE(self_in); + } + mp_obj_int_t *self = self_in; + return mpz_as_int(&self->mpz); +} + +machine_int_t mp_obj_int_get_checked(mp_obj_t self_in) { + // TODO: Check overflow + return mp_obj_int_get(self_in); +} + +#endif diff --git a/py/py.mk b/py/py.mk index 199c3aadc8..0285bc05fa 100644 --- a/py/py.mk +++ b/py/py.mk @@ -15,6 +15,7 @@ PY_O_BASENAME = \ qstr.o \ vstr.o \ unicode.o \ + mpz.o \ lexer.o \ lexerstr.o \ lexerunix.o \ @@ -51,6 +52,7 @@ PY_O_BASENAME = \ objgetitemiter.o \ objint.o \ objint_longlong.o \ + objint_mpz.o \ objlist.o \ objmap.o \ objmodule.o \ diff --git a/stm/mpconfigport.h b/stm/mpconfigport.h index c5614f4bb8..33ebae8fbc 100644 --- a/stm/mpconfigport.h +++ b/stm/mpconfigport.h @@ -7,6 +7,7 @@ #define MICROPY_ENABLE_GC (1) #define MICROPY_ENABLE_REPL_HELPERS (1) #define MICROPY_ENABLE_FLOAT (1) +#define MICROPY_LONGINT_IMPL (MICROPY_LONGINT_IMPL_MPZ) #define MICROPY_PATH_MAX (128) // type definitions for the specific machine diff --git a/unix/mpconfigport.h b/unix/mpconfigport.h index 5b2503f4da..456ec02d5b 100644 --- a/unix/mpconfigport.h +++ b/unix/mpconfigport.h @@ -14,7 +14,7 @@ #define MICROPY_ENABLE_LEXER_UNIX (1) #define MICROPY_ENABLE_SOURCE_LINE (1) #define MICROPY_ENABLE_FLOAT (1) -#define MICROPY_LONGINT_IMPL (MICROPY_LONGINT_IMPL_LONGLONG) +#define MICROPY_LONGINT_IMPL (MICROPY_LONGINT_IMPL_MPZ) #define MICROPY_PATH_MAX (PATH_MAX) // type definitions for the specific machine