bigint: Use a better Montgomery RR doubling-vs-squaring trade-off.

Clarify how the math works, and use a slightly better trade-off of
doubling vs squaring. On 64-bit targets RSA verification is now
less than 10% faster. On 32-bit targets it's over 20% faster. I
expect that we can improve the performance further by optimizing
the doubling implementation.

Also the new implementation avoids allocating/cloning any temporary
`Elem`s, unlike the previous implementation.
This commit is contained in:
Brian Smith 2023-11-15 19:55:33 -08:00
parent 90dd9218cd
commit 4f825b77eb

View File

@ -48,7 +48,6 @@ use crate::{
bits::BitLength,
c, cpu, error,
limb::{self, Limb, LimbMask, LIMB_BITS},
polyfill::u64_from_usize,
};
use alloc::vec;
use core::{marker::PhantomData, num::NonZeroU64};
@ -276,46 +275,72 @@ impl<M> One<M, RR> {
// is correct because R**2 will still be a multiple of the latter as
// `N0::LIMBS_USED` is either one or two.
fn newRR(m: &Modulus<M>) -> Self {
let m_bits = m.len_bits().as_usize_bits();
let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
// The number of limbs in the numbers involved.
let w = m.limbs().len();
// base = 2**r (mod m) == R (mod m).
let mut base = m.zero();
m.oneR(&mut base.limbs);
// The length of the numbers involved, in bits. R = 2**r.
let r = w * LIMB_BITS;
// Double `base` so that base == 2*R (mod m), i.e. `2` in Montgomery
// form (`elem_exp_vartime()` requires the base to be in Montgomery
// form). Then compute
// RR = R**2 == base**r == R**r == (2**r)**r (mod m).
let mut acc: Elem<M, R> = m.zero();
m.oneR(&mut acc.limbs);
// 2**t * R can be calculated by t doublings starting with R.
//
// Take advantage of the fact that `elem_double` is faster than
// `elem_squared` by replacing some of the early squarings with
// doublings.
// TODO: Benchmark doubling vs. squaring performance to determine the
// optimal value of `LG_BASE`.
const LG_BASE: usize = 2; // Doubling vs. squaring trade-off.
debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.
let doublings = LG_BASE;
// `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
// since we require the modulus to have at least `MODULUS_MIN_LIMBS`
// limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
// `r / LG_BASE` is non-zero.
// Choose a t that divides r and where t doublings are cheaper than 1 squaring.
//
// The maximum value of `r` is determined by
// `MODULUS_MAX_LIMBS * LIMB_BITS`. Further `r` is a multiple of
// `LIMB_BITS` so the maximum Hamming Weight is bounded by
// `MODULUS_MAX_LIMBS`. For the common case of {2048, 4096, 8192}-bit
// moduli the Hamming weight is 1. For the other common case of 3072
// the Hamming weight is 2.
let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap();
for _ in 0..doublings {
elem_double(&mut base, m)
// We could choose other values of t than w. But if t < d then the exponentiation that
// follows would require multiplications. Normally d is 1 (i.e. the modulus length is a
// power of two: RSA 1024, 2048, 4096, 8192) or 3 (RSA 1536, 3072).
//
// XXX(perf): Currently t = w / 2 is slightly faster. TODO(perf): Optimize `elem_double`
// and re-run benchmarks to rebalance this.
let t = w;
let z = w.trailing_zeros();
let d = w >> z;
debug_assert_eq!(w, d * (1 << z));
debug_assert!(d <= t);
debug_assert!(t < r);
for _ in 0..t {
elem_double(&mut acc, m);
}
// Because t | r:
//
// MontExp(2**t * R, r / t)
// = (2**t)**(r / t) * R (mod m) by definition of MontExp.
// = (2**t)**(1/t * r) * R (mod m)
// = (2**(t * 1/t))**r * R (mod m)
// = (2**1)**r * R (mod m)
// = 2**r * R (mod m)
// = R * R (mod m)
// = RR
//
// Like BoringSSL, use t = w (`m.limbs.len()`) which ensures that the exponent is a power
// of two. Consequently, there will be no multiplications in the Montgomery exponentiation;
// there will only be lg(r / t) squarings.
//
// lg(r / t)
// = lg((w * 2**b) / t)
// = lg((t * 2**b) / t)
// = lg(2**b)
// = b
// TODO(MSRV:1.67): const B: u32 = LIMB_BITS.ilog2();
const B: u32 = if cfg!(target_pointer_width = "64") {
6
} else if cfg!(target_pointer_width = "32") {
5
} else {
panic!("unsupported target_pointer_width")
};
#[allow(clippy::assertions_on_constants)]
const _LIMB_BITS_IS_2_POW_B: () = assert!(LIMB_BITS == 1 << B);
debug_assert_eq!(r, t * (1 << B));
for _ in 0..B {
acc = elem_squared(acc, m);
}
let RR = elem_exp_vartime(base, exponent, m);
Self(Elem {
limbs: RR.limbs,
limbs: acc.limbs,
encoding: PhantomData, // PhantomData<RR>
})
}