diff --git a/py/mpz.c b/py/mpz.c index 51d0c96890..e0d249c214 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -531,60 +531,37 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di quo /= lead_den_digit; // Multiply quo by den and subtract from num to get remainder. - // We have different code here to handle different compile-time - // configurations of mpz: - // - // 1. DIG_SIZE is stricly less than half the number of bits - // available in mpz_dbl_dig_t. In this case we can use a - // slightly more optimal (in time and space) routine that - // uses the extra bits in mpz_dbl_dig_signed_t to store a - // sign bit. - // - // 2. DIG_SIZE is exactly half the number of bits available in - // mpz_dbl_dig_t. In this (common) case we need to be careful - // not to overflow the borrow variable. And the shifting of - // borrow needs some special logic (it's a shift right with - // round up). - // + // Must be careful with overflow of the borrow variable. Both + // borrow and low_digs are signed values and need signed right-shift, + // but x is unsigned and may take a full-range value. const mpz_dig_t *d = den_dig; mpz_dbl_dig_t d_norm = 0; - mpz_dbl_dig_t borrow = 0; + mpz_dbl_dig_signed_t borrow = 0; for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + // Get the next digit in (den). d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + // Multiply the next digit in (quo * den). mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); - #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2 - borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2 - *n = borrow & DIG_MASK; - borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE; - #else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2 - if (x >= *n || *n - x <= borrow) { - borrow += x - (mpz_dbl_dig_t)*n; - *n = (-borrow) & DIG_MASK; - borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up - } else { - *n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK; - borrow = 0; - } - #endif + // Compute the low DIG_MASK bits of the next digit in (num - quo * den) + mpz_dbl_dig_signed_t low_digs = (borrow & DIG_MASK) + *n - (x & DIG_MASK); + // Store the digit result for (num). + *n = low_digs & DIG_MASK; + // Compute the borrow, shifted right before summing to avoid overflow. + borrow = (borrow >> DIG_SIZE) - (x >> DIG_SIZE) + (low_digs >> DIG_SIZE); } - #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2 - // Borrow was negative in the above for-loop, make it positive for next if-block. - borrow = -borrow; - #endif - // At this point we have either: // // 1. quo was the correct value and the most-sig-digit of num is exactly - // cancelled by borrow (borrow == *num_dig). In this case there is + // cancelled by borrow (borrow + *num_dig == 0). In this case there is // nothing more to do. // // 2. quo was too large, we subtracted too many den from num, and the - // most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1). + // most-sig-digit of num is less than needed (borrow + *num_dig < 0). // In this case we must reduce quo and add back den to num until the // carry from this operation cancels out the borrow. // - borrow -= *num_dig; + borrow += *num_dig; for (; borrow != 0; --quo) { d = den_dig; d_norm = 0; @@ -595,7 +572,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di *n = carry & DIG_MASK; carry >>= DIG_SIZE; } - borrow -= carry; + borrow += carry; } // store this digit of the quotient diff --git a/tests/basics/int_big_div.py b/tests/basics/int_big_div.py index 642f051d41..29fd405970 100644 --- a/tests/basics/int_big_div.py +++ b/tests/basics/int_big_div.py @@ -8,3 +8,7 @@ x = 0x8000000000000000 print((x + 1) // x) x = 0x86c60128feff5330 print((x + 1) // x) + +# these check edge cases where borrow overflows +print((2 ** 48 - 1) ** 2 // (2 ** 48 - 1)) +print((2 ** 256 - 2 ** 32) ** 2 // (2 ** 256 - 2 ** 32))