From b1147381c9c18280b30d286ea222e00ed47f4fae Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 11 Oct 2023 11:04:32 -0700 Subject: [PATCH] Generalize `array_flatten` into an `ArrayFlatten` trait. --- src/aead/aes_gcm.rs | 10 ++++++---- src/aead/chacha20_poly1305.rs | 4 ++-- src/aead/gcm/gcm_nohw.rs | 4 ++-- src/polyfill.rs | 2 +- src/polyfill/array_flatten.rs | 24 +++++++++++++++++------- 5 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/aead/aes_gcm.rs b/src/aead/aes_gcm.rs index 9976d5d07..0a2b19405 100644 --- a/src/aead/aes_gcm.rs +++ b/src/aead/aes_gcm.rs @@ -19,7 +19,7 @@ use super::{ }; use crate::{ aead, cpu, error, - polyfill::{self, array_flatten}, + polyfill::{self, ArrayFlatten}, }; use core::ops::RangeFrom; @@ -245,9 +245,11 @@ fn finish( // Authenticate the final block containing the input lengths. let aad_bits = polyfill::u64_from_usize(aad_len) << 3; let ciphertext_bits = polyfill::u64_from_usize(in_out_len) << 3; - gcm_ctx.update_block(Block::from(&array_flatten( - [aad_bits, ciphertext_bits].map(u64::to_be_bytes), - ))); + gcm_ctx.update_block(Block::from( + &[aad_bits, ciphertext_bits] + .map(u64::to_be_bytes) + .array_flatten(), + )); // Finalize the tag and return it. gcm_ctx.pre_finish(|pre_tag| { diff --git a/src/aead/chacha20_poly1305.rs b/src/aead/chacha20_poly1305.rs index 4e3b7b031..10c707bf3 100644 --- a/src/aead/chacha20_poly1305.rs +++ b/src/aead/chacha20_poly1305.rs @@ -18,7 +18,7 @@ use super::{ }; use crate::{ aead, cpu, error, - polyfill::{self, array_flatten}, + polyfill::{self, ArrayFlatten}, }; use core::ops::RangeFrom; @@ -213,7 +213,7 @@ fn finish(mut auth: poly1305::Context, aad_len: usize, in_out_len: usize) -> Tag let block: [[u8; 8]; 2] = [aad_len, in_out_len] .map(polyfill::u64_from_usize) .map(u64::to_le_bytes); - auth.update(&array_flatten(block)); + auth.update(&block.array_flatten()); auth.finish() } diff --git a/src/aead/gcm/gcm_nohw.rs b/src/aead/gcm/gcm_nohw.rs index e3bdd9160..48414fef4 100644 --- a/src/aead/gcm/gcm_nohw.rs +++ b/src/aead/gcm/gcm_nohw.rs @@ -23,7 +23,7 @@ // Unlike the BearSSL notes, we use u128 in the 64-bit implementation. use super::{Block, Xi, BLOCK_LEN}; -use crate::polyfill::{array_flatten, ChunksFixed}; +use crate::polyfill::{ArrayFlatten, ChunksFixed}; #[cfg(target_pointer_width = "64")] fn gcm_mul64_nohw(a: u64, b: u64) -> (u64, u64) { @@ -242,5 +242,5 @@ fn with_swapped_xi(Xi(xi): &mut Xi, f: impl FnOnce(&mut [u64; 2])) { let mut swapped: [u64; 2] = [unswapped[1], unswapped[0]]; f(&mut swapped); let reswapped = [swapped[1], swapped[0]]; - *xi = Block::from(&array_flatten(reswapped.map(u64::to_be_bytes))) + *xi = Block::from(&reswapped.map(u64::to_be_bytes).array_flatten()) } diff --git a/src/polyfill.rs b/src/polyfill.rs index c957b9430..1521f9fe4 100644 --- a/src/polyfill.rs +++ b/src/polyfill.rs @@ -39,7 +39,7 @@ mod test; mod unwrap_const; pub use self::{ - array_flat_map::ArrayFlatMap, array_flatten::array_flatten, chunks_fixed::*, + array_flat_map::ArrayFlatMap, array_flatten::ArrayFlatten, chunks_fixed::*, unwrap_const::unwrap_const, }; diff --git a/src/polyfill/array_flatten.rs b/src/polyfill/array_flatten.rs index a99612b83..b99db2a0c 100644 --- a/src/polyfill/array_flatten.rs +++ b/src/polyfill/array_flatten.rs @@ -12,11 +12,21 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -/// Returns the flattened form of `a` -#[inline(always)] -pub fn array_flatten(a: [[T; 8]; 2]) -> [T; 16] { - let [[a0, a1, a2, a3, a4, a5, a6, a7], [b0, b1, b2, b3, b4, b5, b6, b7]] = a; - [ - a0, a1, a2, a3, a4, a5, a6, a7, b0, b1, b2, b3, b4, b5, b6, b7, - ] +pub trait ArrayFlatten { + type Output; + + /// Returns the flattened form of `a` + fn array_flatten(self) -> Self::Output; +} + +impl ArrayFlatten for [[T; 8]; 2] { + type Output = [T; 16]; + + #[inline(always)] + fn array_flatten(self) -> Self::Output { + let [[a0, a1, a2, a3, a4, a5, a6, a7], [b0, b1, b2, b3, b4, b5, b6, b7]] = self; + [ + a0, a1, a2, a3, a4, a5, a6, a7, b0, b1, b2, b3, b4, b5, b6, b7, + ] + } }