bigint: NFC: Take oneRR out of OwnedModulus.

`PublicModulus` and `PrivatePrime` are basically duplicates of
`OwnedModulusWithOne`. In the future we would like to create an
`OwnedModulus` that doesn't need 1RR to be calculated. Also in the
future we'd like to be able to "take" 1RR from a public modulus.
This change is a step towards those ends.
This commit is contained in:
Brian Smith 2023-11-22 15:59:21 -08:00
parent 986fe1f5ff
commit 6de27244ff
5 changed files with 64 additions and 70 deletions

View File

@ -38,7 +38,7 @@
use self::boxed_limbs::BoxedLimbs; use self::boxed_limbs::BoxedLimbs;
pub(crate) use self::{ pub(crate) use self::{
modulus::{Modulus, OwnedModulusWithOne, MODULUS_MAX_LIMBS}, modulus::{Modulus, OwnedModulus, MODULUS_MAX_LIMBS},
private_exponent::PrivateExponent, private_exponent::PrivateExponent,
}; };
use super::n0::N0; use super::n0::N0;
@ -274,7 +274,7 @@ impl<M> One<M, RR> {
// values, using `LIMB_BITS` here, rather than `N0::LIMBS_USED * LIMB_BITS`, // values, using `LIMB_BITS` here, rather than `N0::LIMBS_USED * LIMB_BITS`,
// is correct because R**2 will still be a multiple of the latter as // is correct because R**2 will still be a multiple of the latter as
// `N0::LIMBS_USED` is either one or two. // `N0::LIMBS_USED` is either one or two.
fn newRR(m: &Modulus<M>) -> Self { pub(crate) fn newRR(m: &Modulus<M>) -> Self {
// The number of limbs in the numbers involved. // The number of limbs in the numbers involved.
let w = m.limbs().len(); let w = m.limbs().len();
@ -808,8 +808,8 @@ mod tests {
|section, test_case| { |section, test_case| {
assert_eq!(section, ""); assert_eq!(section, "");
let m_ = consume_modulus::<M>(test_case, "M", cpu_features); let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus(); let m = m.modulus();
let expected_result = consume_elem(test_case, "ModExp", &m); let expected_result = consume_elem(test_case, "ModExp", &m);
let base = consume_elem(test_case, "A", &m); let base = consume_elem(test_case, "A", &m);
let e = { let e = {
@ -817,7 +817,7 @@ mod tests {
PrivateExponent::from_be_bytes_for_test_only(untrusted::Input::from(&bytes), &m) PrivateExponent::from_be_bytes_for_test_only(untrusted::Input::from(&bytes), &m)
.expect("valid exponent") .expect("valid exponent")
}; };
let base = into_encoded(base, &m_); let base = into_encoded(base, &m);
let actual_result = elem_exp_consttime(base, &e, &m).unwrap(); let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
assert_elem_eq(&actual_result, &expected_result); assert_elem_eq(&actual_result, &expected_result);
@ -838,14 +838,14 @@ mod tests {
|section, test_case| { |section, test_case| {
assert_eq!(section, ""); assert_eq!(section, "");
let m_ = consume_modulus::<M>(test_case, "M", cpu_features); let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus(); let m = m.modulus();
let expected_result = consume_elem(test_case, "ModMul", &m); let expected_result = consume_elem(test_case, "ModMul", &m);
let a = consume_elem(test_case, "A", &m); let a = consume_elem(test_case, "A", &m);
let b = consume_elem(test_case, "B", &m); let b = consume_elem(test_case, "B", &m);
let b = into_encoded(b, &m_); let b = into_encoded(b, &m);
let a = into_encoded(a, &m_); let a = into_encoded(a, &m);
let actual_result = elem_mul(&a, b, &m); let actual_result = elem_mul(&a, b, &m);
let actual_result = actual_result.into_unencoded(&m); let actual_result = actual_result.into_unencoded(&m);
assert_elem_eq(&actual_result, &expected_result); assert_elem_eq(&actual_result, &expected_result);
@ -863,12 +863,12 @@ mod tests {
|section, test_case| { |section, test_case| {
assert_eq!(section, ""); assert_eq!(section, "");
let m_ = consume_modulus::<M>(test_case, "M", cpu_features); let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus(); let m = m.modulus();
let expected_result = consume_elem(test_case, "ModSquare", &m); let expected_result = consume_elem(test_case, "ModSquare", &m);
let a = consume_elem(test_case, "A", &m); let a = consume_elem(test_case, "A", &m);
let a = into_encoded(a, &m_); let a = into_encoded(a, &m);
let actual_result = elem_squared(a, &m); let actual_result = elem_squared(a, &m);
let actual_result = actual_result.into_unencoded(&m); let actual_result = actual_result.into_unencoded(&m);
assert_elem_eq(&actual_result, &expected_result); assert_elem_eq(&actual_result, &expected_result);
@ -896,7 +896,7 @@ mod tests {
let other_modulus_len_bits = m_.len_bits(); let other_modulus_len_bits = m_.len_bits();
let actual_result = elem_reduced(&a, &m, other_modulus_len_bits); let actual_result = elem_reduced(&a, &m, other_modulus_len_bits);
let oneRR = m_.oneRR(); let oneRR = One::newRR(&m);
let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m); let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
assert_elem_eq(&actual_result, &expected_result); assert_elem_eq(&actual_result, &expected_result);
@ -930,7 +930,7 @@ mod tests {
#[test] #[test]
fn test_modulus_debug() { fn test_modulus_debug() {
let modulus = OwnedModulusWithOne::<M>::from_be_bytes( let modulus = OwnedModulus::<M>::from_be_bytes(
untrusted::Input::from(&[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS]), untrusted::Input::from(&[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS]),
cpu::features(), cpu::features(),
) )
@ -965,9 +965,9 @@ mod tests {
test_case: &mut test::TestCase, test_case: &mut test::TestCase,
name: &str, name: &str,
cpu_features: cpu::Features, cpu_features: cpu::Features,
) -> OwnedModulusWithOne<M> { ) -> OwnedModulus<M> {
let value = test_case.consume_bytes(name); let value = test_case.consume_bytes(name);
OwnedModulusWithOne::from_be_bytes(untrusted::Input::from(&value), cpu_features).unwrap() OwnedModulus::from_be_bytes(untrusted::Input::from(&value), cpu_features).unwrap()
} }
fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative { fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
@ -983,7 +983,8 @@ mod tests {
} }
} }
fn into_encoded<M>(a: Elem<M, Unencoded>, m: &OwnedModulusWithOne<M>) -> Elem<M, R> { fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
elem_mul(m.oneRR().as_ref(), a, &m.modulus()) let oneRR = One::newRR(m);
elem_mul(oneRR.as_ref(), a, m)
} }
} }

View File

@ -12,10 +12,7 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
use super::{ use super::{super::n0::N0, BoxedLimbs, Elem, PublicModulus, SmallerModulus, Unencoded};
super::{montgomery::RR, n0::N0},
BoxedLimbs, Elem, One, PublicModulus, SmallerModulus, Unencoded,
};
use crate::{ use crate::{
bits::BitLength, bits::BitLength,
cpu, error, cpu, error,
@ -37,7 +34,7 @@ pub const MODULUS_MAX_LIMBS: usize = super::super::BIGINT_MODULUS_MAX_LIMBS;
/// for efficient Montgomery multiplication modulo *m*. The value must be odd /// for efficient Montgomery multiplication modulo *m*. The value must be odd
/// and larger than 2. The larger-than-1 requirement is imposed, at least, by /// and larger than 2. The larger-than-1 requirement is imposed, at least, by
/// the modular inversion code. /// the modular inversion code.
pub struct OwnedModulusWithOne<M> { pub struct OwnedModulus<M> {
limbs: BoxedLimbs<M>, // Also `value >= 3`. limbs: BoxedLimbs<M>, // Also `value >= 3`.
// n0 * N == -1 (mod r). // n0 * N == -1 (mod r).
@ -77,26 +74,23 @@ pub struct OwnedModulusWithOne<M> {
// calculations instead of double-precision `u64` calculations. // calculations instead of double-precision `u64` calculations.
n0: N0, n0: N0,
oneRR: One<M, RR>,
len_bits: BitLength, len_bits: BitLength,
cpu_features: cpu::Features, cpu_features: cpu::Features,
} }
impl<M: PublicModulus> Clone for OwnedModulusWithOne<M> { impl<M: PublicModulus> Clone for OwnedModulus<M> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
limbs: self.limbs.clone(), limbs: self.limbs.clone(),
n0: self.n0, n0: self.n0,
oneRR: self.oneRR.clone(),
len_bits: self.len_bits, len_bits: self.len_bits,
cpu_features: self.cpu_features, cpu_features: self.cpu_features,
} }
} }
} }
impl<M: PublicModulus> core::fmt::Debug for OwnedModulusWithOne<M> { impl<M: PublicModulus> core::fmt::Debug for OwnedModulus<M> {
fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> { fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> {
fmt.debug_struct("Modulus") fmt.debug_struct("Modulus")
// TODO: Print modulus value. // TODO: Print modulus value.
@ -104,7 +98,7 @@ impl<M: PublicModulus> core::fmt::Debug for OwnedModulusWithOne<M> {
} }
} }
impl<M> OwnedModulusWithOne<M> { impl<M> OwnedModulus<M> {
pub(crate) fn from_be_bytes( pub(crate) fn from_be_bytes(
input: untrusted::Input, input: untrusted::Input,
cpu_features: cpu::Features, cpu_features: cpu::Features,
@ -151,31 +145,15 @@ impl<M> OwnedModulusWithOne<M> {
}; };
let len_bits = limb::limbs_minimal_bits(&n); let len_bits = limb::limbs_minimal_bits(&n);
let oneRR = {
let partial = Modulus {
limbs: &n,
n0,
len_bits,
m: PhantomData,
cpu_features,
};
One::newRR(&partial)
};
Ok(Self { Ok(Self {
limbs: n, limbs: n,
n0, n0,
oneRR,
len_bits, len_bits,
cpu_features, cpu_features,
}) })
} }
pub fn oneRR(&self) -> &One<M, RR> {
&self.oneRR
}
pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded> pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
where where
M: SmallerModulus<L>, M: SmallerModulus<L>,
@ -202,7 +180,7 @@ impl<M> OwnedModulusWithOne<M> {
} }
} }
impl<M: PublicModulus> OwnedModulusWithOne<M> { impl<M: PublicModulus> OwnedModulus<M> {
pub fn be_bytes(&self) -> LeadingZerosStripped<impl ExactSizeIterator<Item = u8> + Clone + '_> { pub fn be_bytes(&self) -> LeadingZerosStripped<impl ExactSizeIterator<Item = u8> + Clone + '_> {
LeadingZerosStripped::new(limb::unstripped_be_bytes(&self.limbs)) LeadingZerosStripped::new(limb::unstripped_be_bytes(&self.limbs))
} }

View File

@ -18,7 +18,10 @@ use super::{
/// RSA PKCS#1 1.5 signatures. /// RSA PKCS#1 1.5 signatures.
use crate::{ use crate::{
arithmetic::{bigint, montgomery::R}, arithmetic::{
bigint,
montgomery::{R, RR},
},
bits::BitLength, bits::BitLength,
cpu, digest, cpu, digest,
error::{self, KeyRejected}, error::{self, KeyRejected},
@ -281,8 +284,8 @@ impl KeyPair {
cpu_features, cpu_features,
)?; )?;
let n_one = public_key.inner().n().value().oneRR(); let n_one = public_key.inner().n().oneRR();
let n = &public_key.inner().n().value().modulus(); let n = &public_key.inner().n().value();
// 6.4.1.4.3 says to skip 6.4.1.2.1 Step 2. // 6.4.1.4.3 says to skip 6.4.1.2.1 Step 2.
@ -316,7 +319,7 @@ impl KeyPair {
// checking p * q == 0 (mod n) is equivalent to checking p * q == n. // checking p * q == 0 (mod n) is equivalent to checking p * q == n.
let q_mod_n = q.modulus.to_elem(n); let q_mod_n = q.modulus.to_elem(n);
let p_mod_n = p.modulus.to_elem(n); let p_mod_n = p.modulus.to_elem(n);
let p_mod_n = bigint::elem_mul(n_one.as_ref(), p_mod_n, n); let p_mod_n = bigint::elem_mul(n_one, p_mod_n, n);
let pq_mod_n = bigint::elem_mul(&q_mod_n, p_mod_n, n); let pq_mod_n = bigint::elem_mul(&q_mod_n, p_mod_n, n);
if !pq_mod_n.is_zero() { if !pq_mod_n.is_zero() {
return Err(KeyRejected::inconsistent_components()); return Err(KeyRejected::inconsistent_components());
@ -357,9 +360,9 @@ impl KeyPair {
// with an even modulus. // with an even modulus.
// Step 7.f. // Step 7.f.
let qInv = bigint::elem_mul(p.modulus.oneRR().as_ref(), qInv, pm); let qInv = bigint::elem_mul(p.oneRR.as_ref(), qInv, pm);
let q_mod_p = bigint::elem_reduced(&q_mod_n, pm, q.modulus.len_bits()); let q_mod_p = bigint::elem_reduced(&q_mod_n, pm, q.modulus.len_bits());
let q_mod_p = bigint::elem_mul(p.modulus.oneRR().as_ref(), q_mod_p, pm); let q_mod_p = bigint::elem_mul(p.oneRR.as_ref(), q_mod_p, pm);
bigint::verify_inverses_consttime(&qInv, q_mod_p, pm) bigint::verify_inverses_consttime(&qInv, q_mod_p, pm)
.map_err(|error::Unspecified| KeyRejected::inconsistent_components())?; .map_err(|error::Unspecified| KeyRejected::inconsistent_components())?;
@ -397,7 +400,8 @@ impl signature::KeyPair for KeyPair {
} }
struct PrivatePrime<M> { struct PrivatePrime<M> {
modulus: bigint::OwnedModulusWithOne<M>, modulus: bigint::OwnedModulus<M>,
oneRR: bigint::One<M, RR>,
exponent: bigint::PrivateExponent, exponent: bigint::PrivateExponent,
} }
@ -410,7 +414,7 @@ impl<M> PrivatePrime<M> {
n_bits: BitLength, n_bits: BitLength,
cpu_features: cpu::Features, cpu_features: cpu::Features,
) -> Result<Self, KeyRejected> { ) -> Result<Self, KeyRejected> {
let p = bigint::OwnedModulusWithOne::from_be_bytes(p, cpu_features)?; let p = bigint::OwnedModulus::from_be_bytes(p, cpu_features)?;
// 5.c / 5.g: // 5.c / 5.g:
// //
@ -445,8 +449,11 @@ impl<M> PrivatePrime<M> {
return Err(error::KeyRejected::private_modulus_len_not_multiple_of_512_bits()); return Err(error::KeyRejected::private_modulus_len_not_multiple_of_512_bits());
} }
let oneRR = bigint::One::newRR(&p.modulus());
Ok(Self { Ok(Self {
modulus: p, modulus: p,
oneRR,
exponent: dP, exponent: dP,
}) })
} }
@ -461,8 +468,8 @@ fn elem_exp_consttime<M>(
let c_mod_m = bigint::elem_reduced(c, m, other_prime_len_bits); let c_mod_m = bigint::elem_reduced(c, m, other_prime_len_bits);
// We could precompute `oneRRR = elem_squared(&p.oneRR`) as mentioned // We could precompute `oneRRR = elem_squared(&p.oneRR`) as mentioned
// in the Smooth CRT-RSA paper. // in the Smooth CRT-RSA paper.
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); let c_mod_m = bigint::elem_mul(p.oneRR.as_ref(), c_mod_m, m);
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); let c_mod_m = bigint::elem_mul(p.oneRR.as_ref(), c_mod_m, m);
bigint::elem_exp_consttime(c_mod_m, &p.exponent, m) bigint::elem_exp_consttime(c_mod_m, &p.exponent, m)
} }
@ -537,9 +544,8 @@ impl KeyPair {
// RFC 8017 Section 5.1.2: RSADP, using the Chinese Remainder Theorem // RFC 8017 Section 5.1.2: RSADP, using the Chinese Remainder Theorem
// with Garner's algorithm. // with Garner's algorithm.
let n = self.public.inner().n().value(); let n = &self.public.inner().n().value();
let n_one = n.oneRR(); let n_one = self.public.inner().n().oneRR();
let n = &n.modulus();
// Step 1. The value zero is also rejected. // Step 1. The value zero is also rejected.
let base = bigint::Elem::from_be_bytes_padded(untrusted::Input::from(base), n)?; let base = bigint::Elem::from_be_bytes_padded(untrusted::Input::from(base), n)?;
@ -568,7 +574,7 @@ impl KeyPair {
// non-modular arithmetic. // non-modular arithmetic.
let h = bigint::elem_widen(h, n); let h = bigint::elem_widen(h, n);
let q_mod_n = self.q.modulus.to_elem(n); let q_mod_n = self.q.modulus.to_elem(n);
let q_mod_n = bigint::elem_mul(n_one.as_ref(), q_mod_n, n); let q_mod_n = bigint::elem_mul(n_one, q_mod_n, n);
let q_times_h = bigint::elem_mul(&q_mod_n, h, n); let q_times_h = bigint::elem_mul(&q_mod_n, h, n);
let m_2 = bigint::elem_widen(m_2, n); let m_2 = bigint::elem_widen(m_2, n);
let m = bigint::elem_add(m_2, q_times_h, n); let m = bigint::elem_add(m_2, q_times_h, n);

View File

@ -145,7 +145,7 @@ impl Inner {
base: untrusted::Input, base: untrusted::Input,
out_buffer: &'out mut [u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN], out_buffer: &'out mut [u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN],
) -> Result<&'out [u8], error::Unspecified> { ) -> Result<&'out [u8], error::Unspecified> {
let n = &self.n.value().modulus(); let n = &self.n.value();
// The encoded value of the base must be the same length as the modulus, // The encoded value of the base must be the same length as the modulus,
// in bytes. // in bytes.
@ -177,10 +177,9 @@ impl Inner {
// The exponent was already checked to be odd. // The exponent was already checked to be odd.
debug_assert_ne!(exponent_without_low_bit, self.e.value()); debug_assert_ne!(exponent_without_low_bit, self.e.value());
let n_ = self.n.value(); let n = &self.n.value();
let n = &n_.modulus();
let base_r = bigint::elem_mul(n_.oneRR().as_ref(), base.clone(), n); let base_r = bigint::elem_mul(self.n.oneRR(), base.clone(), n);
// During RSA public key operations the exponent is almost always either // During RSA public key operations the exponent is almost always either
// 65537 (0b10000000000000001) or 3 (0b11), both of which have a Hamming // 65537 (0b10000000000000001) or 3 (0b11), both of which have a Hamming

View File

@ -1,10 +1,15 @@
use crate::{arithmetic::bigint, bits, cpu, error, rsa::N}; use crate::{
arithmetic::{bigint, montgomery::RR},
bits, cpu, error,
rsa::N,
};
use core::ops::RangeInclusive; use core::ops::RangeInclusive;
/// The modulus (n) of an RSA public key. /// The modulus (n) of an RSA public key.
#[derive(Clone)] #[derive(Clone)]
pub struct PublicModulus { pub struct PublicModulus {
value: bigint::OwnedModulusWithOne<N>, value: bigint::OwnedModulus<N>,
oneRR: bigint::One<N, RR>,
} }
/* /*
@ -32,7 +37,7 @@ impl PublicModulus {
const MIN_BITS: bits::BitLength = bits::BitLength::from_usize_bits(1024); const MIN_BITS: bits::BitLength = bits::BitLength::from_usize_bits(1024);
// Step 3 / Step c for `n` (out of order). // Step 3 / Step c for `n` (out of order).
let value = bigint::OwnedModulusWithOne::from_be_bytes(n, cpu_features)?; let value = bigint::OwnedModulus::from_be_bytes(n, cpu_features)?;
let bits = value.len_bits(); let bits = value.len_bits();
// Step 1 / Step a. XXX: SP800-56Br1 and SP800-89 require the length of // Step 1 / Step a. XXX: SP800-56Br1 and SP800-89 require the length of
@ -47,8 +52,9 @@ impl PublicModulus {
if bits > max_bits { if bits > max_bits {
return Err(error::KeyRejected::too_large()); return Err(error::KeyRejected::too_large());
} }
let oneRR = bigint::One::newRR(&value.modulus());
Ok(Self { value }) Ok(Self { value, oneRR })
} }
/// The big-endian encoding of the modulus. /// The big-endian encoding of the modulus.
@ -63,7 +69,11 @@ impl PublicModulus {
self.value.len_bits() self.value.len_bits()
} }
pub(super) fn value(&self) -> &bigint::OwnedModulusWithOne<N> { pub(super) fn value(&self) -> bigint::Modulus<N> {
&self.value self.value.modulus()
}
pub(super) fn oneRR(&self) -> &bigint::Elem<N, RR> {
self.oneRR.as_ref()
} }
} }