bigint: Calculate 1*R mod m without multiplication by 1*RR.
Save two private-modulus Montgomery multiplications per RSA exponentiation at the cost of approximately two modulus-wide XORs. The new `oneR()` is extracted from the Montgomery RR setup. Remove the use of `One<RR>` in `elem_exp_consttime`.
This commit is contained in:
parent
81e17e4b10
commit
25112e9546
@ -166,17 +166,7 @@ where
|
|||||||
|
|
||||||
// r *= 2.
|
// r *= 2.
|
||||||
fn elem_double<M, AF>(r: &mut Elem<M, AF>, m: &Modulus<M>) {
|
fn elem_double<M, AF>(r: &mut Elem<M, AF>, m: &Modulus<M>) {
|
||||||
prefixed_extern! {
|
limb::limbs_double_mod(&mut r.limbs, m.limbs())
|
||||||
fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
|
|
||||||
}
|
|
||||||
unsafe {
|
|
||||||
LIMBS_shl_mod(
|
|
||||||
r.limbs.as_mut_ptr(),
|
|
||||||
r.limbs.as_ptr(),
|
|
||||||
m.limbs().as_ptr(),
|
|
||||||
m.limbs().len(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: This is currently unused, but we intend to eventually use this to
|
// TODO: This is currently unused, but we intend to eventually use this to
|
||||||
@ -289,28 +279,13 @@ impl<M> One<M, RR> {
|
|||||||
let m_bits = m.len_bits().as_usize_bits();
|
let m_bits = m.len_bits().as_usize_bits();
|
||||||
let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
|
let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
|
||||||
|
|
||||||
// base = 2**r - m.
|
// base = 2**r (mod m) == R (mod m).
|
||||||
let mut base = m.zero();
|
let mut base = m.zero();
|
||||||
limb::limbs_negative_odd(&mut base.limbs, m.limbs());
|
m.oneR(&mut base.limbs);
|
||||||
|
|
||||||
// Correct base to 2**(lg m) (mod m).
|
// Double `base` so that base == 2*R (mod m), i.e. `2` in Montgomery
|
||||||
let lg_m = m.len_bits().as_usize_bits();
|
// form (`elem_exp_vartime()` requires the base to be in Montgomery
|
||||||
let leading_zero_bits_in_m = r - lg_m;
|
// form). Then compute
|
||||||
if leading_zero_bits_in_m != 0 {
|
|
||||||
debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
|
|
||||||
// `limbs_negative_odd` flipped all the leading zero bits to ones.
|
|
||||||
// Flip them back.
|
|
||||||
*base.limbs.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Double `base` so that base == R == 2**r (mod m). For normal moduli
|
|
||||||
// that have the high bit of the highest limb set, this requires one
|
|
||||||
// doubling. Unusual moduli require more doublings but we are less
|
|
||||||
// concerned about the performance of those.
|
|
||||||
//
|
|
||||||
// Then double `base` again so that base == 2*R (mod m), i.e. `2` in
|
|
||||||
// Montgomery form (`elem_exp_vartime()` requires the base to be in
|
|
||||||
// Montgomery form). Then compute
|
|
||||||
// RR = R**2 == base**r == R**r == (2**r)**r (mod m).
|
// RR = R**2 == base**r == R**r == (2**r)**r (mod m).
|
||||||
//
|
//
|
||||||
// Take advantage of the fact that `elem_double` is faster than
|
// Take advantage of the fact that `elem_double` is faster than
|
||||||
@ -321,7 +296,7 @@ impl<M> One<M, RR> {
|
|||||||
const LG_BASE: usize = 2; // Doubling vs. squaring trade-off.
|
const LG_BASE: usize = 2; // Doubling vs. squaring trade-off.
|
||||||
debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.
|
debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.
|
||||||
|
|
||||||
let doublings = leading_zero_bits_in_m + LG_BASE;
|
let doublings = LG_BASE;
|
||||||
// `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
|
// `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
|
||||||
// since we require the modulus to have at least `MODULUS_MIN_LIMBS`
|
// since we require the modulus to have at least `MODULUS_MIN_LIMBS`
|
||||||
// limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
|
// limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
|
||||||
@ -407,11 +382,8 @@ pub(crate) fn elem_exp_vartime<M>(
|
|||||||
pub fn elem_exp_consttime<M>(
|
pub fn elem_exp_consttime<M>(
|
||||||
base: Elem<M, R>,
|
base: Elem<M, R>,
|
||||||
exponent: &PrivateExponent,
|
exponent: &PrivateExponent,
|
||||||
m: &OwnedModulusWithOne<M>,
|
m: &Modulus<M>,
|
||||||
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
|
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
|
||||||
let oneRR = m.oneRR();
|
|
||||||
let m = &m.modulus();
|
|
||||||
|
|
||||||
use crate::{bssl, limb::Window};
|
use crate::{bssl, limb::Window};
|
||||||
|
|
||||||
const WINDOW_BITS: usize = 5;
|
const WINDOW_BITS: usize = 5;
|
||||||
@ -459,13 +431,7 @@ pub fn elem_exp_consttime<M>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// table[0] = base**0 (i.e. 1).
|
// table[0] = base**0 (i.e. 1).
|
||||||
{
|
m.oneR(entry_mut(&mut table, 0, num_limbs));
|
||||||
let acc = entry_mut(&mut table, 0, num_limbs);
|
|
||||||
// `table` was initialized to zero and hasn't changed.
|
|
||||||
debug_assert!(acc.iter().all(|&value| value == 0));
|
|
||||||
acc[0] = 1;
|
|
||||||
limbs_mont_mul(acc, &oneRR.0.limbs, m.limbs(), m.n0(), m.cpu_features());
|
|
||||||
}
|
|
||||||
|
|
||||||
entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
|
entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
|
||||||
for i in 2..TABLE_ENTRIES {
|
for i in 2..TABLE_ENTRIES {
|
||||||
@ -502,13 +468,10 @@ pub fn elem_exp_consttime<M>(
|
|||||||
pub fn elem_exp_consttime<M>(
|
pub fn elem_exp_consttime<M>(
|
||||||
base: Elem<M, R>,
|
base: Elem<M, R>,
|
||||||
exponent: &PrivateExponent,
|
exponent: &PrivateExponent,
|
||||||
m: &OwnedModulusWithOne<M>,
|
m: &Modulus<M>,
|
||||||
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
|
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
|
||||||
use crate::limb::LIMB_BYTES;
|
use crate::limb::LIMB_BYTES;
|
||||||
|
|
||||||
let oneRR = m.oneRR();
|
|
||||||
let m = &m.modulus();
|
|
||||||
|
|
||||||
// Pretty much all the math here requires CPU feature detection to have
|
// Pretty much all the math here requires CPU feature detection to have
|
||||||
// been done. `cpu_features` isn't threaded through all the internal
|
// been done. `cpu_features` isn't threaded through all the internal
|
||||||
// functions, so just make it clear that it has been done at this point.
|
// functions, so just make it clear that it has been done at this point.
|
||||||
@ -659,11 +622,7 @@ pub fn elem_exp_consttime<M>(
|
|||||||
// All entries in `table` will be Montgomery encoded.
|
// All entries in `table` will be Montgomery encoded.
|
||||||
|
|
||||||
// acc = table[0] = base**0 (i.e. 1).
|
// acc = table[0] = base**0 (i.e. 1).
|
||||||
// `acc` was initialized to zero and hasn't changed. Change it to 1 and then Montgomery
|
m.oneR(acc);
|
||||||
// encode it.
|
|
||||||
debug_assert!(acc.iter().all(|&value| value == 0));
|
|
||||||
acc[0] = 1;
|
|
||||||
limbs_mont_mul(acc, &oneRR.0.limbs, m_cached, n0, cpu_features);
|
|
||||||
scatter(table, acc, 0, num_limbs);
|
scatter(table, acc, 0, num_limbs);
|
||||||
|
|
||||||
// acc = base**1 (i.e. base).
|
// acc = base**1 (i.e. base).
|
||||||
@ -834,7 +793,7 @@ mod tests {
|
|||||||
.expect("valid exponent")
|
.expect("valid exponent")
|
||||||
};
|
};
|
||||||
let base = into_encoded(base, &m_);
|
let base = into_encoded(base, &m_);
|
||||||
let actual_result = elem_exp_consttime(base, &e, &m_).unwrap();
|
let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
|
||||||
assert_elem_eq(&actual_result, &expected_result);
|
assert_elem_eq(&actual_result, &expected_result);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -225,6 +225,36 @@ pub struct Modulus<'a, M> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<M> Modulus<'_, M> {
|
impl<M> Modulus<'_, M> {
|
||||||
|
pub(super) fn oneR(&self, out: &mut [Limb]) {
|
||||||
|
assert_eq!(self.limbs.len(), out.len());
|
||||||
|
|
||||||
|
let r = self.limbs.len() * LIMB_BITS;
|
||||||
|
|
||||||
|
// out = 2**r - m where m = self.
|
||||||
|
limb::limbs_negative_odd(out, self.limbs);
|
||||||
|
|
||||||
|
let lg_m = self.len_bits().as_usize_bits();
|
||||||
|
let leading_zero_bits_in_m = r - lg_m;
|
||||||
|
|
||||||
|
// When m's length is a multiple of LIMB_BITS, which is the case we
|
||||||
|
// most want to optimize for, then we already have
|
||||||
|
// out == 2**r - m == 2**r (mod m).
|
||||||
|
if leading_zero_bits_in_m != 0 {
|
||||||
|
debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
|
||||||
|
// Correct out to 2**(lg m) (mod m). `limbs_negative_odd` flipped
|
||||||
|
// all the leading zero bits to ones. Flip them back.
|
||||||
|
*out.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;
|
||||||
|
|
||||||
|
// Now we have out == 2**(lg m) (mod m). Keep doubling until we get
|
||||||
|
// to 2**r (mod m).
|
||||||
|
for _ in 0..leading_zero_bits_in_m {
|
||||||
|
limb::limbs_double_mod(out, self.limbs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now out == 2**r (mod m) == 1*R.
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: XXX Avoid duplication with `Modulus`.
|
// TODO: XXX Avoid duplication with `Modulus`.
|
||||||
pub(super) fn zero<E>(&self) -> Elem<M, E> {
|
pub(super) fn zero<E>(&self) -> Elem<M, E> {
|
||||||
Elem {
|
Elem {
|
||||||
|
11
src/limb.rs
11
src/limb.rs
@ -350,6 +350,17 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
|
|||||||
unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) }
|
unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// r *= 2 (mod m).
|
||||||
|
pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) {
|
||||||
|
assert_eq!(r.len(), m.len());
|
||||||
|
prefixed_extern! {
|
||||||
|
fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
LIMBS_shl_mod(r.as_mut_ptr(), r.as_ptr(), m.as_ptr(), m.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// *r = -a, assuming a is odd.
|
// *r = -a, assuming a is odd.
|
||||||
pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) {
|
pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) {
|
||||||
debug_assert_eq!(r.len(), a.len());
|
debug_assert_eq!(r.len(), a.len());
|
||||||
|
@ -468,7 +468,7 @@ fn elem_exp_consttime<M>(
|
|||||||
// in the Smooth CRT-RSA paper.
|
// in the Smooth CRT-RSA paper.
|
||||||
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
|
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
|
||||||
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
|
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
|
||||||
bigint::elem_exp_consttime(c_mod_m, &p.exponent, &p.modulus)
|
bigint::elem_exp_consttime(c_mod_m, &p.exponent, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Type-level representations of the different moduli used in RSA signing, in
|
// Type-level representations of the different moduli used in RSA signing, in
|
||||||
|
Loading…
x
Reference in New Issue
Block a user