diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 8631d653a..6ea6eb2fe 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -20,7 +20,7 @@ pub mod bigint; pub mod montgomery; mod n0; -#[cfg(feature = "alloc")] +#[cfg(all(test, feature = "alloc"))] mod nonnegative; #[allow(dead_code)] diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index f01a380b4..c1bcc16fd 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -42,7 +42,6 @@ pub(crate) use self::{ private_exponent::PrivateExponent, }; use super::n0::N0; -pub(crate) use super::nonnegative::Nonnegative; use crate::{ arithmetic::montgomery::*, bits::BitLength, @@ -703,21 +702,6 @@ pub fn elem_verify_equal_consttime( } } -// TODO: Move these methods from `Nonnegative` to `Modulus`. -impl Nonnegative { - pub fn verify_less_than_modulus(&self, m: &Modulus) -> Result<(), error::Unspecified> { - if self.limbs().len() > m.limbs().len() { - return Err(error::Unspecified); - } - if self.limbs().len() == m.limbs().len() { - if limb::limbs_less_than_limbs_consttime(self.limbs(), m.limbs()) != LimbMask::True { - return Err(error::Unspecified); - } - } - Ok(()) - } -} - /// r *= a fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) { debug_assert_eq!(r.len(), m.len()); @@ -789,7 +773,7 @@ prefixed_extern! { #[cfg(test)] mod tests { - use super::*; + use super::{super::nonnegative::Nonnegative, *}; use crate::test; // Type-level representation of an arbitrary modulus. diff --git a/src/arithmetic/bigint/modulus.rs b/src/arithmetic/bigint/modulus.rs index 25e5f9ea8..b0750639a 100644 --- a/src/arithmetic/bigint/modulus.rs +++ b/src/arithmetic/bigint/modulus.rs @@ -146,13 +146,18 @@ impl OwnedModulus { }) } - pub fn to_elem(&self, l: &Modulus) -> Result, error::Unspecified> { + pub fn verify_less_than(&self, l: &Modulus) -> Result<(), error::Unspecified> { if self.len_bits() > l.len_bits() || (self.limbs.len() == l.limbs().len() && limb::limbs_less_than_limbs_consttime(&self.limbs, l.limbs()) != LimbMask::True) { return Err(error::Unspecified); } + Ok(()) + } + + pub fn to_elem(&self, l: &Modulus) -> Result, error::Unspecified> { + self.verify_less_than(l)?; let mut limbs = BoxedLimbs::zero(l.limbs.len()); limbs[..self.limbs.len()].copy_from_slice(&self.limbs); Ok(Elem { diff --git a/src/arithmetic/nonnegative.rs b/src/arithmetic/nonnegative.rs index 87ad7c970..0a3b0a594 100644 --- a/src/arithmetic/nonnegative.rs +++ b/src/arithmetic/nonnegative.rs @@ -14,7 +14,7 @@ use crate::{ bits, error, - limb::{self, Limb, LimbMask, LIMB_BYTES}, + limb::{self, Limb, LIMB_BYTES}, }; use alloc::{vec, vec::Vec}; @@ -37,10 +37,6 @@ impl Nonnegative { Ok((Self { limbs }, r_bits)) } - #[inline] - pub fn is_odd(&self) -> bool { - limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True - } #[inline] pub fn limbs(&self) -> &[Limb] { &self.limbs diff --git a/src/rsa/keypair.rs b/src/rsa/keypair.rs index 71af675ba..05fb0c911 100644 --- a/src/rsa/keypair.rs +++ b/src/rsa/keypair.rs @@ -338,18 +338,15 @@ impl KeyPair { // First, validate `2**half_n_bits < d`. Since 2**half_n_bits has a bit // length of half_n_bits + 1, this check gives us 2**half_n_bits <= d, // and knowing d is odd makes the inequality strict. - let (d, d_bits) = bigint::Nonnegative::from_be_bytes_with_bit_length(d) - .map_err(|_| error::KeyRejected::invalid_encoding())?; - if !(n_bits.half_rounded_up() < d_bits) { + let d = bigint::OwnedModulus::::from_be_bytes(d, cpu_features) + .map_err(|_| error::KeyRejected::invalid_component())?; + if !(n_bits.half_rounded_up() < d.len_bits()) { return Err(KeyRejected::inconsistent_components()); } // XXX: This check should be `d < LCM(p - 1, q - 1)`, but we don't have // a good way of calculating LCM, so it is omitted, as explained above. - d.verify_less_than_modulus(n) + d.verify_less_than(n) .map_err(|error::Unspecified| KeyRejected::inconsistent_components())?; - if !d.is_odd() { - return Err(KeyRejected::invalid_component()); - } // Step 6.b is omitted as explained above. @@ -501,6 +498,8 @@ enum P {} #[derive(Copy, Clone)] enum Q {} +enum D {} + impl KeyPair { /// Computes the signature of `msg` and writes it into `signature`. ///