Made Scalar::bits return an iterator rather than an array (#451)

Addresses issue #448 that Scalar::bits may leave unzeroed bits on the stack
This commit is contained in:
Michael Rosenberg 2022-12-08 16:37:42 -05:00 committed by GitHub
parent 42e93d7faf
commit 0b72bb5dc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 15 deletions

View File

@ -335,17 +335,23 @@ impl<'a, 'b> Mul<&'b Scalar> for &'a MontgomeryPoint {
W: FieldElement::one(),
};
let bits: [i8; 256] = scalar.bits();
for i in (0..255).rev() {
let choice: u8 = (bits[i + 1] ^ bits[i]) as u8;
// Go through the bits from most to least significant, using a sliding window of 2
let mut bits = scalar.bits_le().rev();
let mut prev_bit = bits.next().unwrap();
for cur_bit in bits {
let choice: u8 = (prev_bit ^ cur_bit) as u8;
debug_assert!(choice == 0 || choice == 1);
ProjectivePoint::conditional_swap(&mut x0, &mut x1, choice.into());
differential_add_and_double(&mut x0, &mut x1, &affine_u);
prev_bit = cur_bit;
}
ProjectivePoint::conditional_swap(&mut x0, &mut x1, Choice::from(bits[0] as u8));
// The final value of prev_bit above is scalar.bits()[0], i.e., the LSB of scalar
ProjectivePoint::conditional_swap(&mut x0, &mut x1, Choice::from(prev_bit as u8));
// Don't leave the bit in the stack
prev_bit.zeroize();
x0.as_affine()
}

View File

@ -847,16 +847,14 @@ impl Scalar {
ret
}
/// Get the bits of the scalar.
pub(crate) fn bits(&self) -> [i8; 256] {
let mut bits = [0i8; 256];
#[allow(clippy::needless_range_loop)]
for i in 0..256 {
// As i runs from 0..256, the bottom 3 bits index the bit,
// while the upper bits index the byte.
bits[i] = ((self.bytes[i >> 3] >> (i & 7)) & 1u8) as i8;
}
bits
/// Get the bits of the scalar, in little-endian order
pub(crate) fn bits_le(&self) -> impl DoubleEndedIterator<Item = bool> + '_ {
(0..256).map(|i| {
// As i runs from 0..256, the bottom 3 bits index the bit, while the upper bits index
// the byte. Since self.bytes is little-endian at the byte level, this iterator is
// little-endian on the bit level
((self.bytes[i >> 3] >> (i & 7)) & 1u8) == 1
})
}
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.