diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index e177806fa..ca065fe07 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -289,10 +289,19 @@ impl One { let m_bits = m.len_bits().as_usize_bits(); let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS; - // base = 2**(lg m - 1). - let bit = m_bits - 1; + // base = 2**r - m. let mut base = m.zero(); - base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS); + limb::limbs_negative_odd(&mut base.limbs, m.limbs()); + + // Correct base to 2**(lg m) (mod m). + let lg_m = m.len_bits().as_usize_bits(); + let leading_zero_bits_in_m = r - lg_m; + if leading_zero_bits_in_m != 0 { + debug_assert!(leading_zero_bits_in_m < LIMB_BITS); + // `limbs_negative_odd` flipped all the leading zero bits to ones. + // Flip them back. + *base.limbs.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m; + } // Double `base` so that base == R == 2**r (mod m). For normal moduli // that have the high bit of the highest limb set, this requires one @@ -312,7 +321,7 @@ impl One { const LG_BASE: usize = 2; // Doubling vs. squaring trade-off. debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0. - let doublings = r - bit + LG_BASE; + let doublings = leading_zero_bits_in_m + LG_BASE; // `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`) // since we require the modulus to have at least `MODULUS_MIN_LIMBS` // limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus diff --git a/src/limb.rs b/src/limb.rs index 582510112..8dd53099e 100644 --- a/src/limb.rs +++ b/src/limb.rs @@ -350,6 +350,19 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) { unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) } } +// *r = -a, assuming a is odd. +pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) { + debug_assert_eq!(r.len(), a.len()); + // Two's complement step 1: flip all the bits. + // The compiler should optimize this to vectorized (a ^ !0). + r.iter_mut().zip(a.iter()).for_each(|(r, &a)| { + *r = !a; + }); + // Two's complement step 2: Add one. Since `a` is odd, `r` is even. Thus we + // can use a bitwise or for addition. + r[0] |= 1; +} + prefixed_extern! { fn LIMBS_are_zero(a: *const Limb, num_limbs: c::size_t) -> LimbMask; fn LIMBS_less_than(a: *const Limb, b: *const Limb, num_limbs: c::size_t) -> LimbMask;