From 4f825b77eb78656dcc41e66406cef76c17369d59 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 15 Nov 2023 19:55:33 -0800 Subject: [PATCH] bigint: Use a better Montgomery RR doubling-vs-squaring trade-off. Clarify how the math works, and use a slightly better trade-off of doubling vs squaring. On 64-bit targets RSA verification is now less than 10% faster. On 32-bit targets its over 20% faster. I expect that we can improve the performance further by optimizing the doubling implementation. Also the new implementation avoids allocating/cloning any temporary `Elem`s, unlike the previous implementation. --- src/arithmetic/bigint.rs | 93 +++++++++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 34 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index c71e1c9a2..319738e93 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -48,7 +48,6 @@ use crate::{ bits::BitLength, c, cpu, error, limb::{self, Limb, LimbMask, LIMB_BITS}, - polyfill::u64_from_usize, }; use alloc::vec; use core::{marker::PhantomData, num::NonZeroU64}; @@ -276,46 +275,72 @@ impl One { // is correct because R**2 will still be a multiple of the latter as // `N0::LIMBS_USED` is either one or two. fn newRR(m: &Modulus) -> Self { - let m_bits = m.len_bits().as_usize_bits(); - let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS; + // The number of limbs in the numbers involved. + let w = m.limbs().len(); - // base = 2**r (mod m) == R (mod m). - let mut base = m.zero(); - m.oneR(&mut base.limbs); + // The length of the numbers involved, in bits. R = 2**r. + let r = w * LIMB_BITS; - // Double `base` so that base == 2*R (mod m), i.e. `2` in Montgomery - // form (`elem_exp_vartime()` requires the base to be in Montgomery - // form). Then compute - // RR = R**2 == base**r == R**r == (2**r)**r (mod m). + let mut acc: Elem = m.zero(); + m.oneR(&mut acc.limbs); + + // 2**t * R can be calculated by t doublings starting with R. // - // Take advantage of the fact that `elem_double` is faster than - // `elem_squared` by replacing some of the early squarings with - // doublings. - // TODO: Benchmark doubling vs. squaring performance to determine the - // optimal value of `LG_BASE`. - const LG_BASE: usize = 2; // Doubling vs. squaring trade-off. - debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0. - - let doublings = LG_BASE; - // `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`) - // since we require the modulus to have at least `MODULUS_MIN_LIMBS` - // limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus - // `r / LG_BASE` is non-zero. + // Choose a t that divides r and where t doublings are cheaper than 1 squaring. // - // The maximum value of `r` is determined by - // `MODULUS_MAX_LIMBS * LIMB_BITS`. Further `r` is a multiple of - // `LIMB_BITS` so the maximum Hamming Weight is bounded by - // `MODULUS_MAX_LIMBS`. For the common case of {2048, 4096, 8192}-bit - // moduli the Hamming weight is 1. For the other common case of 3072 - // the Hamming weight is 2. - let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap(); - for _ in 0..doublings { - elem_double(&mut base, m) + // We could choose other values of t than w. But if t < d then the exponentiation that + // follows would require multiplications. Normally d is 1 (i.e. the modulus length is a + // power of two: RSA 1024, 2048, 4097, 8192) or 3 (RSA 1536, 3072). + // + // XXX(perf): Currently t = w / 2 is slightly faster. TODO(perf): Optimize `elem_double` + // and re-run benchmarks to rebalance this. + let t = w; + let z = w.trailing_zeros(); + let d = w >> z; + debug_assert_eq!(w, d * (1 << z)); + debug_assert!(d <= t); + debug_assert!(t < r); + for _ in 0..t { + elem_double(&mut acc, m); + } + + // Because t | r: + // + // MontExp(2**t * R, r / t) + // = (2**t)**(r / t) * R (mod m) by definition of MontExp. + // = (2**t)**(1/t * r) * R (mod m) + // = (2**(t * 1/t))**r * R (mod m) + // = (2**1)**r * R (mod m) + // = 2**r * R (mod m) + // = R * R (mod m) + // = RR + // + // Like BoringSSL, use t = w (`m.limbs.len()`) which ensures that the exponent is a power + // of two. Consequently, there will be no multiplications in the Montgomery exponentiation; + // there will only be lg(r / t) squarings. + // + // lg(r / t) + // = lg((w * 2**b) / t) + // = lg((t * 2**b) / t) + // = lg(2**b) + // = b + // TODO(MSRV:1.67): const B: u32 = LIMB_BITS.ilog2(); + const B: u32 = if cfg!(target_pointer_width = "64") { + 6 + } else if cfg!(target_pointer_width = "32") { + 5 + } else { + panic!("unsupported target_pointer_width") + }; + #[allow(clippy::assertions_on_constants)] + const _LIMB_BITS_IS_2_POW_B: () = assert!(LIMB_BITS == 1 << B); + debug_assert_eq!(r, t * (1 << B)); + for _ in 0..B { + acc = elem_squared(acc, m); } - let RR = elem_exp_vartime(base, exponent, m); Self(Elem { - limbs: RR.limbs, + limbs: acc.limbs, encoding: PhantomData, // PhantomData }) }