bigint: Calculate 1*R mod m without multiplication by 1*RR.

Save two private-modulus Montgomery multiplications per RSA exponentiation
at the cost of approximately two modulus-wide XORs.

The new new `oneR()` is extracted from the Montgomery RR setup.

Remove the use of `One<RR>` in `elem_exp_consttime`.
This commit is contained in:
Brian Smith 2023-11-11 16:33:13 -08:00
parent 81e17e4b10
commit 25112e9546
4 changed files with 54 additions and 54 deletions

View File

@ -166,17 +166,7 @@ where
// r *= 2. // r *= 2.
fn elem_double<M, AF>(r: &mut Elem<M, AF>, m: &Modulus<M>) { fn elem_double<M, AF>(r: &mut Elem<M, AF>, m: &Modulus<M>) {
prefixed_extern! { limb::limbs_double_mod(&mut r.limbs, m.limbs())
fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
}
unsafe {
LIMBS_shl_mod(
r.limbs.as_mut_ptr(),
r.limbs.as_ptr(),
m.limbs().as_ptr(),
m.limbs().len(),
);
}
} }
// TODO: This is currently unused, but we intend to eventually use this to // TODO: This is currently unused, but we intend to eventually use this to
@ -289,28 +279,13 @@ impl<M> One<M, RR> {
let m_bits = m.len_bits().as_usize_bits(); let m_bits = m.len_bits().as_usize_bits();
let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS; let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
// base = 2**r - m. // base = 2**r (mod m) == R (mod m).
let mut base = m.zero(); let mut base = m.zero();
limb::limbs_negative_odd(&mut base.limbs, m.limbs()); m.oneR(&mut base.limbs);
// Correct base to 2**(lg m) (mod m). // Double `base` so that base == 2*R (mod m), i.e. `2` in Montgomery
let lg_m = m.len_bits().as_usize_bits(); // form (`elem_exp_vartime()` requires the base to be in Montgomery
let leading_zero_bits_in_m = r - lg_m; // form). Then compute
if leading_zero_bits_in_m != 0 {
debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
// `limbs_negative_odd` flipped all the leading zero bits to ones.
// Flip them back.
*base.limbs.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;
}
// Double `base` so that base == R == 2**r (mod m). For normal moduli
// that have the high bit of the highest limb set, this requires one
// doubling. Unusual moduli require more doublings but we are less
// concerned about the performance of those.
//
// Then double `base` again so that base == 2*R (mod m), i.e. `2` in
// Montgomery form (`elem_exp_vartime()` requires the base to be in
// Montgomery form). Then compute
// RR = R**2 == base**r == R**r == (2**r)**r (mod m). // RR = R**2 == base**r == R**r == (2**r)**r (mod m).
// //
// Take advantage of the fact that `elem_double` is faster than // Take advantage of the fact that `elem_double` is faster than
@ -321,7 +296,7 @@ impl<M> One<M, RR> {
const LG_BASE: usize = 2; // Doubling vs. squaring trade-off. const LG_BASE: usize = 2; // Doubling vs. squaring trade-off.
debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0. debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.
let doublings = leading_zero_bits_in_m + LG_BASE; let doublings = LG_BASE;
// `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`) // `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
// since we require the modulus to have at least `MODULUS_MIN_LIMBS` // since we require the modulus to have at least `MODULUS_MIN_LIMBS`
// limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus // limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
@ -407,11 +382,8 @@ pub(crate) fn elem_exp_vartime<M>(
pub fn elem_exp_consttime<M>( pub fn elem_exp_consttime<M>(
base: Elem<M, R>, base: Elem<M, R>,
exponent: &PrivateExponent, exponent: &PrivateExponent,
m: &OwnedModulusWithOne<M>, m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> { ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
let oneRR = m.oneRR();
let m = &m.modulus();
use crate::{bssl, limb::Window}; use crate::{bssl, limb::Window};
const WINDOW_BITS: usize = 5; const WINDOW_BITS: usize = 5;
@ -459,13 +431,7 @@ pub fn elem_exp_consttime<M>(
} }
// table[0] = base**0 (i.e. 1). // table[0] = base**0 (i.e. 1).
{ m.oneR(entry_mut(&mut table, 0, num_limbs));
let acc = entry_mut(&mut table, 0, num_limbs);
// `table` was initialized to zero and hasn't changed.
debug_assert!(acc.iter().all(|&value| value == 0));
acc[0] = 1;
limbs_mont_mul(acc, &oneRR.0.limbs, m.limbs(), m.n0(), m.cpu_features());
}
entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs); entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
for i in 2..TABLE_ENTRIES { for i in 2..TABLE_ENTRIES {
@ -502,13 +468,10 @@ pub fn elem_exp_consttime<M>(
pub fn elem_exp_consttime<M>( pub fn elem_exp_consttime<M>(
base: Elem<M, R>, base: Elem<M, R>,
exponent: &PrivateExponent, exponent: &PrivateExponent,
m: &OwnedModulusWithOne<M>, m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> { ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
use crate::limb::LIMB_BYTES; use crate::limb::LIMB_BYTES;
let oneRR = m.oneRR();
let m = &m.modulus();
// Pretty much all the math here requires CPU feature detection to have // Pretty much all the math here requires CPU feature detection to have
// been done. `cpu_features` isn't threaded through all the internal // been done. `cpu_features` isn't threaded through all the internal
// functions, so just make it clear that it has been done at this point. // functions, so just make it clear that it has been done at this point.
@ -659,11 +622,7 @@ pub fn elem_exp_consttime<M>(
// All entries in `table` will be Montgomery encoded. // All entries in `table` will be Montgomery encoded.
// acc = table[0] = base**0 (i.e. 1). // acc = table[0] = base**0 (i.e. 1).
// `acc` was initialized to zero and hasn't changed. Change it to 1 and then Montgomery m.oneR(acc);
// encode it.
debug_assert!(acc.iter().all(|&value| value == 0));
acc[0] = 1;
limbs_mont_mul(acc, &oneRR.0.limbs, m_cached, n0, cpu_features);
scatter(table, acc, 0, num_limbs); scatter(table, acc, 0, num_limbs);
// acc = base**1 (i.e. base). // acc = base**1 (i.e. base).
@ -834,7 +793,7 @@ mod tests {
.expect("valid exponent") .expect("valid exponent")
}; };
let base = into_encoded(base, &m_); let base = into_encoded(base, &m_);
let actual_result = elem_exp_consttime(base, &e, &m_).unwrap(); let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
assert_elem_eq(&actual_result, &expected_result); assert_elem_eq(&actual_result, &expected_result);
Ok(()) Ok(())

View File

@ -225,6 +225,36 @@ pub struct Modulus<'a, M> {
} }
impl<M> Modulus<'_, M> { impl<M> Modulus<'_, M> {
pub(super) fn oneR(&self, out: &mut [Limb]) {
assert_eq!(self.limbs.len(), out.len());
let r = self.limbs.len() * LIMB_BITS;
// out = 2**r - m where m = self.
limb::limbs_negative_odd(out, self.limbs);
let lg_m = self.len_bits().as_usize_bits();
let leading_zero_bits_in_m = r - lg_m;
// When m's length is a multiple of LIMB_BITS, which is the case we
// most want to optimize for, then we already have
// out == 2**r - m == 2**r (mod m).
if leading_zero_bits_in_m != 0 {
debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
// Correct out to 2**(lg m) (mod m). `limbs_negative_odd` flipped
// all the leading zero bits to ones. Flip them back.
*out.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;
// Now we have out == 2**(lg m) (mod m). Keep doubling until we get
// to 2**r (mod m).
for _ in 0..leading_zero_bits_in_m {
limb::limbs_double_mod(out, self.limbs)
}
}
// Now out == 2**r (mod m) == 1*R.
}
// TODO: XXX Avoid duplication with `Modulus`. // TODO: XXX Avoid duplication with `Modulus`.
pub(super) fn zero<E>(&self) -> Elem<M, E> { pub(super) fn zero<E>(&self) -> Elem<M, E> {
Elem { Elem {

View File

@ -350,6 +350,17 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) } unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) }
} }
// r *= 2 (mod m).
pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) {
assert_eq!(r.len(), m.len());
prefixed_extern! {
fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
}
unsafe {
LIMBS_shl_mod(r.as_mut_ptr(), r.as_ptr(), m.as_ptr(), m.len());
}
}
// *r = -a, assuming a is odd. // *r = -a, assuming a is odd.
pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) { pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) {
debug_assert_eq!(r.len(), a.len()); debug_assert_eq!(r.len(), a.len());

View File

@ -468,7 +468,7 @@ fn elem_exp_consttime<M>(
// in the Smooth CRT-RSA paper. // in the Smooth CRT-RSA paper.
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
bigint::elem_exp_consttime(c_mod_m, &p.exponent, &p.modulus) bigint::elem_exp_consttime(c_mod_m, &p.exponent, m)
} }
// Type-level representations of the different moduli used in RSA signing, in // Type-level representations of the different moduli used in RSA signing, in