From 25112e95469db5717473e01d2d161962b6dcb9dd Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Sat, 11 Nov 2023 16:33:13 -0800 Subject: [PATCH] bigint: Calculate 1*R mod m without multiplication by 1*RR. Save two private-modulus Montgomery multiplications per RSA exponentiation at the cost of approximately two modulus-wide XORs. The new `oneR()` is extracted from the Montgomery RR setup. Remove the use of `One` in `elem_exp_consttime`. --- src/arithmetic/bigint.rs | 65 ++++++-------------------------- src/arithmetic/bigint/modulus.rs | 30 +++++++++++++++ src/limb.rs | 11 ++++++ src/rsa/keypair.rs | 2 +- 4 files changed, 54 insertions(+), 54 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index ca065fe07..c71e1c9a2 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -166,17 +166,7 @@ where // r *= 2. fn elem_double(r: &mut Elem, m: &Modulus) { - prefixed_extern! { - fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t); - } - unsafe { - LIMBS_shl_mod( - r.limbs.as_mut_ptr(), - r.limbs.as_ptr(), - m.limbs().as_ptr(), - m.limbs().len(), - ); - } + limb::limbs_double_mod(&mut r.limbs, m.limbs()) } // TODO: This is currently unused, but we intend to eventually use this to @@ -289,28 +279,13 @@ impl One { let m_bits = m.len_bits().as_usize_bits(); let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS; - // base = 2**r - m. + // base = 2**r (mod m) == R (mod m). let mut base = m.zero(); - limb::limbs_negative_odd(&mut base.limbs, m.limbs()); + m.oneR(&mut base.limbs); - // Correct base to 2**(lg m) (mod m). - let lg_m = m.len_bits().as_usize_bits(); - let leading_zero_bits_in_m = r - lg_m; - if leading_zero_bits_in_m != 0 { - debug_assert!(leading_zero_bits_in_m < LIMB_BITS); - // `limbs_negative_odd` flipped all the leading zero bits to ones. - // Flip them back. 
- *base.limbs.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m; - } - - // Double `base` so that base == R == 2**r (mod m). For normal moduli - // that have the high bit of the highest limb set, this requires one - // doubling. Unusual moduli require more doublings but we are less - // concerned about the performance of those. - // - // Then double `base` again so that base == 2*R (mod m), i.e. `2` in - // Montgomery form (`elem_exp_vartime()` requires the base to be in - // Montgomery form). Then compute + // Double `base` so that base == 2*R (mod m), i.e. `2` in Montgomery + // form (`elem_exp_vartime()` requires the base to be in Montgomery + // form). Then compute // RR = R**2 == base**r == R**r == (2**r)**r (mod m). // // Take advantage of the fact that `elem_double` is faster than @@ -321,7 +296,7 @@ impl One { const LG_BASE: usize = 2; // Doubling vs. squaring trade-off. debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0. - let doublings = leading_zero_bits_in_m + LG_BASE; + let doublings = LG_BASE; // `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`) // since we require the modulus to have at least `MODULUS_MIN_LIMBS` // limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus @@ -407,11 +382,8 @@ pub(crate) fn elem_exp_vartime( pub fn elem_exp_consttime( base: Elem, exponent: &PrivateExponent, - m: &OwnedModulusWithOne, + m: &Modulus, ) -> Result, error::Unspecified> { - let oneRR = m.oneRR(); - let m = &m.modulus(); - use crate::{bssl, limb::Window}; const WINDOW_BITS: usize = 5; @@ -459,13 +431,7 @@ pub fn elem_exp_consttime( } // table[0] = base**0 (i.e. 1). - { - let acc = entry_mut(&mut table, 0, num_limbs); - // `table` was initialized to zero and hasn't changed. 
- debug_assert!(acc.iter().all(|&value| value == 0)); - acc[0] = 1; - limbs_mont_mul(acc, &oneRR.0.limbs, m.limbs(), m.n0(), m.cpu_features()); - } + m.oneR(entry_mut(&mut table, 0, num_limbs)); entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs); for i in 2..TABLE_ENTRIES { @@ -502,13 +468,10 @@ pub fn elem_exp_consttime( pub fn elem_exp_consttime( base: Elem, exponent: &PrivateExponent, - m: &OwnedModulusWithOne, + m: &Modulus, ) -> Result, error::Unspecified> { use crate::limb::LIMB_BYTES; - let oneRR = m.oneRR(); - let m = &m.modulus(); - // Pretty much all the math here requires CPU feature detection to have // been done. `cpu_features` isn't threaded through all the internal // functions, so just make it clear that it has been done at this point. @@ -659,11 +622,7 @@ pub fn elem_exp_consttime( // All entries in `table` will be Montgomery encoded. // acc = table[0] = base**0 (i.e. 1). - // `acc` was initialized to zero and hasn't changed. Change it to 1 and then Montgomery - // encode it. - debug_assert!(acc.iter().all(|&value| value == 0)); - acc[0] = 1; - limbs_mont_mul(acc, &oneRR.0.limbs, m_cached, n0, cpu_features); + m.oneR(acc); scatter(table, acc, 0, num_limbs); // acc = base**1 (i.e. base). @@ -834,7 +793,7 @@ mod tests { .expect("valid exponent") }; let base = into_encoded(base, &m_); - let actual_result = elem_exp_consttime(base, &e, &m_).unwrap(); + let actual_result = elem_exp_consttime(base, &e, &m).unwrap(); assert_elem_eq(&actual_result, &expected_result); Ok(()) diff --git a/src/arithmetic/bigint/modulus.rs b/src/arithmetic/bigint/modulus.rs index 4eb8e8b91..807cb70c0 100644 --- a/src/arithmetic/bigint/modulus.rs +++ b/src/arithmetic/bigint/modulus.rs @@ -225,6 +225,36 @@ pub struct Modulus<'a, M> { } impl Modulus<'_, M> { + pub(super) fn oneR(&self, out: &mut [Limb]) { + assert_eq!(self.limbs.len(), out.len()); + + let r = self.limbs.len() * LIMB_BITS; + + // out = 2**r - m where m = self. 
+ limb::limbs_negative_odd(out, self.limbs); + + let lg_m = self.len_bits().as_usize_bits(); + let leading_zero_bits_in_m = r - lg_m; + + // When m's length is a multiple of LIMB_BITS, which is the case we + // most want to optimize for, then we already have + // out == 2**r - m == 2**r (mod m). + if leading_zero_bits_in_m != 0 { + debug_assert!(leading_zero_bits_in_m < LIMB_BITS); + // Correct out to 2**(lg m) (mod m). `limbs_negative_odd` flipped + // all the leading zero bits to ones. Flip them back. + *out.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m; + + // Now we have out == 2**(lg m) (mod m). Keep doubling until we get + // to 2**r (mod m). + for _ in 0..leading_zero_bits_in_m { + limb::limbs_double_mod(out, self.limbs) + } + } + + // Now out == 2**r (mod m) == 1*R. + } + // TODO: XXX Avoid duplication with `Modulus`. pub(super) fn zero(&self) -> Elem { Elem { diff --git a/src/limb.rs b/src/limb.rs index 8dd53099e..ee139d626 100644 --- a/src/limb.rs +++ b/src/limb.rs @@ -350,6 +350,17 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) { unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) } } +// r *= 2 (mod m). +pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) { + assert_eq!(r.len(), m.len()); + prefixed_extern! { + fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t); + } + unsafe { + LIMBS_shl_mod(r.as_mut_ptr(), r.as_ptr(), m.as_ptr(), m.len()); + } +} + // *r = -a, assuming a is odd. pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) { debug_assert_eq!(r.len(), a.len()); diff --git a/src/rsa/keypair.rs b/src/rsa/keypair.rs index 819c6678e..f9a9e16eb 100644 --- a/src/rsa/keypair.rs +++ b/src/rsa/keypair.rs @@ -468,7 +468,7 @@ fn elem_exp_consttime( // in the Smooth CRT-RSA paper. 
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); - bigint::elem_exp_consttime(c_mod_m, &p.exponent, &p.modulus) + bigint::elem_exp_consttime(c_mod_m, &p.exponent, m) } // Type-level representations of the different moduli used in RSA signing, in