diff --git a/src/arithmetic/bigint/modulus.rs b/src/arithmetic/bigint/modulus.rs index 3b88c9fb4..25e5f9ea8 100644 --- a/src/arithmetic/bigint/modulus.rs +++ b/src/arithmetic/bigint/modulus.rs @@ -12,7 +12,7 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use super::{super::n0::N0, BoxedLimbs, Elem, PublicModulus, SmallerModulus, Unencoded}; +use super::{super::n0::N0, BoxedLimbs, Elem, PublicModulus, Unencoded}; use crate::{ bits::BitLength, cpu, error, @@ -146,16 +146,19 @@ impl OwnedModulus { }) } - pub fn to_elem(&self, l: &Modulus) -> Elem - where - M: SmallerModulus, - { + pub fn to_elem(&self, l: &Modulus) -> Result, error::Unspecified> { + if self.len_bits() > l.len_bits() + || (self.limbs.len() == l.limbs().len() + && limb::limbs_less_than_limbs_consttime(&self.limbs, l.limbs()) != LimbMask::True) + { + return Err(error::Unspecified); + } let mut limbs = BoxedLimbs::zero(l.limbs.len()); limbs[..self.limbs.len()].copy_from_slice(&self.limbs); - Elem { + Ok(Elem { limbs, encoding: PhantomData, - } + }) } pub fn modulus(&self) -> Modulus { Modulus { diff --git a/src/rsa/keypair.rs b/src/rsa/keypair.rs index 657a7bc93..bf9920e51 100644 --- a/src/rsa/keypair.rs +++ b/src/rsa/keypair.rs @@ -317,8 +317,14 @@ impl KeyPair { // 0 < q < p < n. We check that q and p are close to sqrt(n) and then // assume that these preconditions are enough to let us assume that // checking p * q == 0 (mod n) is equivalent to checking p * q == n. - let q_mod_n = q.modulus.to_elem(n); - let p_mod_n = p.modulus.to_elem(n); + let q_mod_n = q + .modulus + .to_elem(n) + .map_err(|error::Unspecified| KeyRejected::inconsistent_components())?; + let p_mod_n = p + .modulus + .to_elem(n) + .map_err(|error::Unspecified| KeyRejected::inconsistent_components())?; let p_mod_n = bigint::elem_mul(n_one, p_mod_n, n); let pq_mod_n = bigint::elem_mul(&q_mod_n, p_mod_n, n); if !pq_mod_n.is_zero() { @@ -586,7 +592,7 @@ impl KeyPair { // Modular arithmetic is used simply to avoid implementing // non-modular arithmetic. let h = bigint::elem_widen(h, n); - let q_mod_n = self.q.modulus.to_elem(n); + let q_mod_n = self.q.modulus.to_elem(n)?; let q_mod_n = bigint::elem_mul(n_one, q_mod_n, n); let q_times_h = bigint::elem_mul(&q_mod_n, h, n); let m_2 = bigint::elem_widen(m_2, n);