bigint: Use a better Montgomery RR doubling-vs-squaring trade-off.
Clarify how the math works, and use a slightly better trade-off of doubling vs squaring. On 64-bit targets RSA verification is now less than 10% faster. On 32-bit targets it's over 20% faster. I expect that we can improve the performance further by optimizing the doubling implementation. Also the new implementation avoids allocating/cloning any temporary `Elem`s, unlike the previous implementation.
This commit is contained in:
parent
90dd9218cd
commit
4f825b77eb
@ -48,7 +48,6 @@ use crate::{
|
||||
bits::BitLength,
|
||||
c, cpu, error,
|
||||
limb::{self, Limb, LimbMask, LIMB_BITS},
|
||||
polyfill::u64_from_usize,
|
||||
};
|
||||
use alloc::vec;
|
||||
use core::{marker::PhantomData, num::NonZeroU64};
|
||||
@ -276,46 +275,72 @@ impl<M> One<M, RR> {
|
||||
// is correct because R**2 will still be a multiple of the latter as
|
||||
// `N0::LIMBS_USED` is either one or two.
|
||||
fn newRR(m: &Modulus<M>) -> Self {
|
||||
let m_bits = m.len_bits().as_usize_bits();
|
||||
let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
|
||||
// The number of limbs in the numbers involved.
|
||||
let w = m.limbs().len();
|
||||
|
||||
// base = 2**r (mod m) == R (mod m).
|
||||
let mut base = m.zero();
|
||||
m.oneR(&mut base.limbs);
|
||||
// The length of the numbers involved, in bits. R = 2**r.
|
||||
let r = w * LIMB_BITS;
|
||||
|
||||
// Double `base` so that base == 2*R (mod m), i.e. `2` in Montgomery
|
||||
// form (`elem_exp_vartime()` requires the base to be in Montgomery
|
||||
// form). Then compute
|
||||
// RR = R**2 == base**r == R**r == (2**r)**r (mod m).
|
||||
let mut acc: Elem<M, R> = m.zero();
|
||||
m.oneR(&mut acc.limbs);
|
||||
|
||||
// 2**t * R can be calculated by t doublings starting with R.
|
||||
//
|
||||
// Take advantage of the fact that `elem_double` is faster than
|
||||
// `elem_squared` by replacing some of the early squarings with
|
||||
// doublings.
|
||||
// TODO: Benchmark doubling vs. squaring performance to determine the
|
||||
// optimal value of `LG_BASE`.
|
||||
const LG_BASE: usize = 2; // Doubling vs. squaring trade-off.
|
||||
debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.
|
||||
|
||||
let doublings = LG_BASE;
|
||||
// `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
|
||||
// since we require the modulus to have at least `MODULUS_MIN_LIMBS`
|
||||
// limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
|
||||
// `r / LG_BASE` is non-zero.
|
||||
// Choose a t that divides r and where t doublings are cheaper than 1 squaring.
|
||||
//
|
||||
// The maximum value of `r` is determined by
|
||||
// `MODULUS_MAX_LIMBS * LIMB_BITS`. Further `r` is a multiple of
|
||||
// `LIMB_BITS` so the maximum Hamming Weight is bounded by
|
||||
// `MODULUS_MAX_LIMBS`. For the common case of {2048, 4096, 8192}-bit
|
||||
// moduli the Hamming weight is 1. For the other common case of 3072
|
||||
// the Hamming weight is 2.
|
||||
let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap();
|
||||
for _ in 0..doublings {
|
||||
elem_double(&mut base, m)
|
||||
// We could choose other values of t than w. But if t < d then the exponentiation that
|
||||
// follows would require multiplications. Normally d is 1 (i.e. the modulus length is a
|
||||
// power of two: RSA 1024, 2048, 4096, 8192) or 3 (RSA 1536, 3072).
|
||||
//
|
||||
// XXX(perf): Currently t = w / 2 is slightly faster. TODO(perf): Optimize `elem_double`
|
||||
// and re-run benchmarks to rebalance this.
|
||||
let t = w;
|
||||
let z = w.trailing_zeros();
|
||||
let d = w >> z;
|
||||
debug_assert_eq!(w, d * (1 << z));
|
||||
debug_assert!(d <= t);
|
||||
debug_assert!(t < r);
|
||||
for _ in 0..t {
|
||||
elem_double(&mut acc, m);
|
||||
}
|
||||
|
||||
// Because t | r:
|
||||
//
|
||||
// MontExp(2**t * R, r / t)
|
||||
// = (2**t)**(r / t) * R (mod m) by definition of MontExp.
|
||||
// = (2**t)**(1/t * r) * R (mod m)
|
||||
// = (2**(t * 1/t))**r * R (mod m)
|
||||
// = (2**1)**r * R (mod m)
|
||||
// = 2**r * R (mod m)
|
||||
// = R * R (mod m)
|
||||
// = RR
|
||||
//
|
||||
// Like BoringSSL, use t = w (`m.limbs.len()`) which ensures that the exponent is a power
|
||||
// of two. Consequently, there will be no multiplications in the Montgomery exponentiation;
|
||||
// there will only be lg(r / t) squarings.
|
||||
//
|
||||
// lg(r / t)
|
||||
// = lg((w * 2**b) / t)
|
||||
// = lg((t * 2**b) / t)
|
||||
// = lg(2**b)
|
||||
// = b
|
||||
// TODO(MSRV:1.67): const B: u32 = LIMB_BITS.ilog2();
|
||||
const B: u32 = if cfg!(target_pointer_width = "64") {
|
||||
6
|
||||
} else if cfg!(target_pointer_width = "32") {
|
||||
5
|
||||
} else {
|
||||
panic!("unsupported target_pointer_width")
|
||||
};
|
||||
#[allow(clippy::assertions_on_constants)]
|
||||
const _LIMB_BITS_IS_2_POW_B: () = assert!(LIMB_BITS == 1 << B);
|
||||
debug_assert_eq!(r, t * (1 << B));
|
||||
for _ in 0..B {
|
||||
acc = elem_squared(acc, m);
|
||||
}
|
||||
let RR = elem_exp_vartime(base, exponent, m);
|
||||
|
||||
Self(Elem {
|
||||
limbs: RR.limbs,
|
||||
limbs: acc.limbs,
|
||||
encoding: PhantomData, // PhantomData<RR>
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user