diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 331254c46..f01a380b4 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -56,18 +56,6 @@ mod boxed_limbs; mod modulus; mod private_exponent; -/// A modulus *s* that is smaller than another modulus *l* so every element of -/// ℤ/sℤ is also an element of ℤ/lℤ. -/// -/// # Safety -/// -/// Some logic may assume that the invariant holds when accessing limbs within -/// a value, e.g. by assuming the larger modulus has at least as many limbs. -/// TODO: Any such logic should be encapsulated here, or this trait should be -/// made non-`unsafe`. (In retrospect, this shouldn't have been made an `unsafe` -/// trait preemptively.) -pub unsafe trait SmallerModulus {} - pub trait PublicModulus {} /// Elements of ℤ/mℤ for some modulus *m*. @@ -224,13 +212,17 @@ where } } -pub fn elem_widen>( +pub fn elem_widen( a: Elem, m: &Modulus, -) -> Elem { + smaller_modulus_bits: BitLength, +) -> Result, error::Unspecified> { + if smaller_modulus_bits >= m.len_bits() { + return Err(error::Unspecified); + } let mut r = m.zero(); r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs); - r + Ok(r) } // TODO: Document why this works for all Montgomery factors. diff --git a/src/rsa/keypair.rs b/src/rsa/keypair.rs index bf9920e51..71af675ba 100644 --- a/src/rsa/keypair.rs +++ b/src/rsa/keypair.rs @@ -497,11 +497,9 @@ fn elem_exp_consttime( #[derive(Copy, Clone)] enum P {} -unsafe impl bigint::SmallerModulus for P {} #[derive(Copy, Clone)] enum Q {} -unsafe impl bigint::SmallerModulus for Q {} impl KeyPair { /// Computes the signature of `msg` and writes it into `signature`. @@ -591,11 +589,12 @@ impl KeyPair { // necessary because `h < p` and `p * q == n` implies `h * q < n`. // Modular arithmetic is used simply to avoid implementing // non-modular arithmetic. - let h = bigint::elem_widen(h, n); + let p_bits = self.p.modulus.len_bits(); + let h = bigint::elem_widen(h, n, p_bits)?; let q_mod_n = self.q.modulus.to_elem(n)?; let q_mod_n = bigint::elem_mul(n_one, q_mod_n, n); let q_times_h = bigint::elem_mul(&q_mod_n, h, n); - let m_2 = bigint::elem_widen(m_2, n); + let m_2 = bigint::elem_widen(m_2, n, q_bits)?; let m = bigint::elem_add(m_2, q_times_h, n); // Step 2.b.v isn't needed since there are only two primes.