diff --git a/userspace/Cargo.lock b/userspace/Cargo.lock index 181fea9c..1d9157dd 100644 --- a/userspace/Cargo.lock +++ b/userspace/Cargo.lock @@ -2592,10 +2592,14 @@ dependencies = [ name = "rsh" version = "0.1.0" dependencies = [ + "aead", "aes", + "aes-gcm", "bytemuck", + "chacha20poly1305", "clap", "cross", + "crypto-common", "ed25519-dalek", "libterm", "log", diff --git a/userspace/Cargo.toml b/userspace/Cargo.toml index a43abc00..33407406 100644 --- a/userspace/Cargo.toml +++ b/userspace/Cargo.toml @@ -62,6 +62,7 @@ chacha20poly1305 = { version = "0.10.1", default-features = false, features = [" aes-gcm = { version = "0.10.1", default-features = false, features = ["alloc"] } aead = { version = "0.5.2", default-features = false, features = ["alloc"] } sha2 = { version = "0.10.9" } +crypto-common = "0.1.6" webpki-roots = "1.0.1" raqote = { version = "0.8.3", default-features = false } diff --git a/userspace/tools/rsh/Cargo.toml b/userspace/tools/rsh/Cargo.toml index 56bec9c6..cba60d4f 100644 --- a/userspace/tools/rsh/Cargo.toml +++ b/userspace/tools/rsh/Cargo.toml @@ -19,9 +19,12 @@ x25519-dalek.workspace = true ed25519-dalek = { workspace = true, features = ["rand_core", "pem"] } sha2.workspace = true log.workspace = true - -rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" } -aes = { version = "0.8.4" } +aes.workspace = true +aes-gcm.workspace = true +chacha20poly1305.workspace = true +crypto-common.workspace = true +aead.workspace = true +rand.workspace = true [lints] workspace = true diff --git a/userspace/tools/rsh/src/crypt/client.rs b/userspace/tools/rsh/src/crypt/client.rs index d8139bf3..47503cee 100644 --- a/userspace/tools/rsh/src/crypt/client.rs +++ b/userspace/tools/rsh/src/crypt/client.rs @@ -1,87 +1,34 @@ use std::{ io::{self, Read, Write}, net::{SocketAddr, TcpStream}, - os::fd::{AsRawFd, RawFd}, - time::Duration, }; +use rand::RngCore; use x25519_dalek::{EphemeralSecret, PublicKey}; use crate::{ - crypt::{self, signature::VerificationMethod}, + crypt::{ + self, + signature::VerificationMethod, + stream::ClientSocket, + symmetric::{data::Iv, SymmetricCipher}, + V1_CLIENT_RANDOM_LEN, + }, proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder}, }; use super::{ config::ClientConfig, signature::{self, SignatureMethod}, - symmetric::{self, SymmetricCipher}, - ClientNegotiationMessage, ServerNegotiationMessage, + symmetric, ClientNegotiationMessage, ServerNegotiationMessage, }; -pub const MESSAGE_SIZE_MAX: usize = 4096; -const FRAMING_BUFFER_SIZE: usize = 8192; - -pub struct ClientSocket { - pub(crate) stream: TcpStream, - pub(crate) remote: SocketAddr, - pub(crate) buffer: [u8; MESSAGE_SIZE_MAX], - pub(crate) recv_buf: FramingBuffer, - - pub(crate) signer: SignatureMethod, - pub(crate) verifier: VerificationMethod, - pub(crate) symmetric: SymmetricCipher, -} - -pub enum Message { - Data(T), - Incomplete, - Closed, -} - -pub(crate) struct FramingBuffer { - buffer: [u8; FRAMING_BUFFER_SIZE], - len: usize, -} - -impl FramingBuffer { - pub fn new() -> Self { - Self { - buffer: [0; FRAMING_BUFFER_SIZE], - len: 0, - } - } - - pub fn get_mut(&mut self) -> &mut [u8] { - &mut self.buffer[self.len..] - } - - pub fn advance(&mut self, len: usize) { - self.len += len; - assert!(self.len <= self.buffer.len()); - } - - pub fn pop(&mut self, buffer: &mut [u8]) -> Option { - if self.len < size_of::() { - return None; - } - let len = u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize; - if self.len < size_of::() + len { - return None; - } - // TODO check dst len - buffer[..len].copy_from_slice(&self.buffer[size_of::()..len + size_of::()]); - self.buffer.copy_within(len + size_of::()..self.len, 0); - self.len -= len + size_of::(); - Some(len) - } -} - -struct Negotiation { +pub(crate) struct Negotiation { stream: TcpStream, remote: SocketAddr, buffer: [u8; 512], config: ClientConfig, + client_random: [u8; V1_CLIENT_RANDOM_LEN], } #[derive(Debug, thiserror::Error)] @@ -110,87 +57,6 @@ pub enum Error { ServerKeyRejected, } -impl ClientSocket { - pub fn connect(address: SocketAddr, config: ClientConfig) -> Result { - let stream = TcpStream::connect(address)?; - stream.set_read_timeout(Some(Duration::from_secs(1)))?; - stream.set_write_timeout(Some(Duration::from_secs(1)))?; - Negotiation::new(stream, address, config).perform() - } - - pub fn remote_address(&self) -> SocketAddr { - self.remote - } - - pub fn write_all(&mut self, message: &E) -> Result<(), Error> { - let mut buf = [0; MESSAGE_SIZE_MAX - 256]; - let mut encoder = Encoder::new(&mut buf); - message.encode(&mut encoder)?; - - // Insert signature - let (payload, rest) = encoder.split_mut(); - let payload_len = payload.len(); - let signature_len = self.signer.sign(payload, rest)?; - - let len = self.symmetric.encrypt( - &buf[..payload_len + signature_len], - &mut self.buffer[size_of::()..], - )?; - let len_bytes: u16 = len.try_into().unwrap(); - self.buffer[..size_of::()].copy_from_slice(&len_bytes.to_le_bytes()); - self.stream - .write_all(&self.buffer[..len + size_of::()])?; - Ok(()) - } - - pub fn poll_read<'de, D: Decode<'de>>( - &mut self, - buffer: &'de mut [u8], - ) -> Result, Error> { - if self.poll()? == 0 { - return Ok(Message::Closed); - } - match self.read(buffer)? { - Some(message) => Ok(Message::Data(message)), - None => Ok(Message::Incomplete), - } - } - - pub fn poll(&mut self) -> Result { - let dst = self.recv_buf.get_mut(); - if dst.is_empty() { - todo!() - } - let len = self.stream.read(dst)?; - self.recv_buf.advance(len); - Ok(len) - } - - pub fn read<'de, D: Decode<'de>>(&mut self, buffer: &'de mut [u8]) -> Result, Error> { - if let Some(len) = self.recv_buf.pop(&mut self.buffer) { - let data_len = self.symmetric.decrypt(&self.buffer[..len], buffer)?; - - let mut decoder = Decoder::new(&buffer[..data_len]); - let message = D::decode(&mut decoder)?; - - // Verify signature - let (payload, signature) = decoder.split(); - self.verifier.verify(payload, signature)?; - - Ok(Some(message)) - } else { - // Buffer doesn't contain a full message yet - Ok(None) - } - } -} - -impl AsRawFd for ClientSocket { - fn as_raw_fd(&self) -> RawFd { - self.stream.as_raw_fd() - } -} - impl Negotiation { pub fn new(stream: TcpStream, remote: SocketAddr, config: ClientConfig) -> Self { Self { @@ -198,9 +64,15 @@ impl Negotiation { remote, buffer: [0; 512], config, + client_random: [0; V1_CLIENT_RANDOM_LEN], } } + fn fill_client_random(&mut self) { + let mut rng = rand::rngs::OsRng; + rng.fill_bytes(&mut self.client_random); + } + fn hello(&mut self, recv_buf: &mut [u8]) -> Result<(SignatureMethod, u8, u8), Error> { log::info!("Send ClientHello v1"); self.send(None, &ClientNegotiationMessage::Hello { protocol: 1 })?; @@ -212,13 +84,13 @@ impl Negotiation { }; log::info!("Server ciphersuites:"); - for &cipher in hello.symmetric_ciphersuites { - if let Some(name) = crypt::ciphersuite_name(cipher) { - log::info!(" * {name:?} ({cipher:#x})"); - } else { - log::info!(" * {cipher:#x}"); - } - } + // for &cipher in hello.symmetric_ciphersuites { + // if let Some(name) = crypt::ciphersuite_name(cipher) { + // log::info!(" * {name:?} ({cipher:#x})"); + // } else { + // log::info!(" * {cipher:#x}"); + // } + // } log::info!("Server signature algorithms:"); for &sig in hello.sig_algos { if let Some(name) = crypt::sig_algo_name(sig) { @@ -248,6 +120,9 @@ impl Negotiation { ciphersuite: u8, kex_algorithm: u8, ) -> Result { + // Fill the client random buffer + self.fill_client_random(); + let sig_algorithm = signer.algorithm(); let key_data = signer.verifying_key_bytes(); @@ -256,8 +131,11 @@ impl Negotiation { crypt::signature::fingerprint_sha256(sig_algorithm_name, &key_data); log::info!("Offer {offered_fingerprint}"); - let ciphersuite_name = crypt::ciphersuite_name(ciphersuite).unwrap_or("???"); + let ciphersuite_name = symmetric::ciphersuite(ciphersuite) + .map(|t| t.1) + .unwrap_or("???"); log::info!("With ciphersuite {ciphersuite_name:?} ({ciphersuite:#x})"); + let client_random = self.client_random; self.send( None, @@ -265,6 +143,7 @@ impl Negotiation { kex_algorithm, sig_algorithm, ciphersuite, + client_random: &client_random, key_data: &key_data, }, )?; @@ -315,12 +194,24 @@ impl Negotiation { if server_public_key.len() != 32 { todo!() } + + let (ciphersuite_algo, ciphersuite_name) = symmetric::ciphersuite(ciphersuite).ok_or( + Error::Symmetric(symmetric::Error::InvalidCiphersuite(ciphersuite)), + )?; + let key_size = ciphersuite_algo.key_size(); + + log::info!("Using ciphersuite: {ciphersuite_name}"); + let mut server_key = [0; 32]; server_key.copy_from_slice(server_public_key); let server_key = PublicKey::from(server_key); let shared = ephemeral.diffie_hellman(&server_key); - let cipher = SymmetricCipher::new(ciphersuite, shared.as_bytes())?; + let cipher = SymmetricCipher::from_ciphersuite( + ciphersuite_algo, + &shared.as_bytes()[..key_size], + Iv::from(self.client_random), + )?; Ok(cipher) } @@ -348,16 +239,13 @@ impl Negotiation { self.finish(&mut recv_buf, &mut signer, &mut verifier)?; log::info!("Established"); - - Ok(ClientSocket { - stream: self.stream, - remote: self.remote, - buffer: [0; MESSAGE_SIZE_MAX], - recv_buf: FramingBuffer::new(), + Ok(ClientSocket::from_parts( + self.stream, + self.remote, signer, verifier, symmetric, - }) + )) } fn recv<'de>( diff --git a/userspace/tools/rsh/src/crypt/config.rs b/userspace/tools/rsh/src/crypt/config.rs index 7d4dc18f..63b3096d 100644 --- a/userspace/tools/rsh/src/crypt/config.rs +++ b/userspace/tools/rsh/src/crypt/config.rs @@ -1,6 +1,9 @@ use std::{collections::HashSet, path::PathBuf}; -use crate::crypt::{V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB, V1_KEX_X25519_DALEK}; +use crate::crypt::{ + symmetric::{V1_CIPHER_AES_128_GCM, V1_CIPHER_AES_256_GCM, V1_CIPHER_CHACHA20POLY1305}, + V1_KEX_X25519_DALEK, +}; use super::{ sig_algo_name, @@ -22,7 +25,11 @@ fn default_select_kex_algorithm(offer: &[u8]) -> Option { } fn default_select_ciphersuite(offer: &[u8]) -> Option { - const ACCEPTED: &[u8] = &[V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB]; + const ACCEPTED: &[u8] = &[ + V1_CIPHER_CHACHA20POLY1305, + V1_CIPHER_AES_256_GCM, + V1_CIPHER_AES_128_GCM, + ]; for accepted in ACCEPTED { if offer.contains(accepted) { @@ -41,7 +48,11 @@ fn default_offer_kex_algorithms() -> &'static [u8] { } fn default_offer_ciphersuites() -> &'static [u8] { - &[V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB] + &[ + V1_CIPHER_CHACHA20POLY1305, + V1_CIPHER_AES_256_GCM, + V1_CIPHER_AES_128_GCM, + ] } pub struct ClientConfig { diff --git a/userspace/tools/rsh/src/crypt/mod.rs b/userspace/tools/rsh/src/crypt/mod.rs index 6ec1b4a8..2ea74f30 100644 --- a/userspace/tools/rsh/src/crypt/mod.rs +++ b/userspace/tools/rsh/src/crypt/mod.rs @@ -4,18 +4,17 @@ pub mod client; pub mod config; pub mod server; pub mod signature; +pub mod stream; pub mod symmetric; pub mod util; -pub const V1_CIPHER_NULL: u8 = 0x00; -pub const V1_CIPHER_AES_256_ECB: u8 = 0x10; -pub const V1_CIPHER_AES_256_CBC: u8 = 0x11; - // v1 supports only one DH algo pub const V1_KEX_X25519_DALEK: u8 = 0x10; pub const V1_SIG_ED25519: u8 = 0x10; +pub const V1_CLIENT_RANDOM_LEN: usize = 16; + #[derive(Debug)] pub enum ServerReject { NoReason, @@ -44,6 +43,7 @@ pub enum ClientNegotiationMessage<'a> { kex_algorithm: u8, sig_algorithm: u8, ciphersuite: u8, + client_random: &'a [u8; V1_CLIENT_RANDOM_LEN], key_data: &'a [u8], }, DHPublicKey(&'a [u8]), @@ -99,6 +99,7 @@ impl Encode for ClientNegotiationMessage<'_> { kex_algorithm, sig_algorithm, ciphersuite, + client_random, key_data, } => { buffer.write(&[ @@ -107,6 +108,7 @@ impl Encode for ClientNegotiationMessage<'_> { sig_algorithm, ciphersuite, ])?; + buffer.write(client_random)?; buffer.write_variable_bytes(key_data) } Self::DHPublicKey(key_data) => { @@ -130,11 +132,13 @@ impl<'de> Decode<'de> for ClientNegotiationMessage<'de> { let kex_algorithm = buffer.read_u8()?; let sig_algorithm = buffer.read_u8()?; let ciphersuite = buffer.read_u8()?; + let client_random = buffer.read_bytes(V1_CLIENT_RANDOM_LEN)?.try_into().unwrap(); let key_data = buffer.read_variable_bytes()?; Ok(Self::StartKex { kex_algorithm, sig_algorithm, ciphersuite, + client_random, key_data, }) } @@ -203,15 +207,6 @@ impl<'de> Decode<'de> for ServerNegotiationMessage<'de> { } } -pub fn ciphersuite_name(cipher: u8) -> Option<&'static str> { - match cipher { - V1_CIPHER_NULL => Some("null"), - V1_CIPHER_AES_256_ECB => Some("aes-256-ecb"), - V1_CIPHER_AES_256_CBC => Some("aes-256-cbc"), - _ => None, - } -} - pub fn sig_algo_name(sig: u8) -> Option<&'static str> { match sig { V1_SIG_ED25519 => Some("ed25519"), diff --git a/userspace/tools/rsh/src/crypt/server.rs b/userspace/tools/rsh/src/crypt/server.rs index 78d82d52..9d69ec3b 100644 --- a/userspace/tools/rsh/src/crypt/server.rs +++ b/userspace/tools/rsh/src/crypt/server.rs @@ -10,14 +10,13 @@ use x25519_dalek::{EphemeralSecret, PublicKey}; use crate::{ crypt::{ - client::MESSAGE_SIZE_MAX, sig_algo_name, signature::fingerprint_sha256, ServerHello, - ServerNegotiationMessage, + sig_algo_name, signature::fingerprint_sha256, stream::ClientSocket, symmetric::data::Iv, + ServerHello, ServerNegotiationMessage, V1_CLIENT_RANDOM_LEN, }, proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder}, }; use super::{ - client::{ClientSocket, FramingBuffer}, config::ServerConfig, signature::{SignatureMethod, VerificationMethod}, symmetric::{self, SymmetricCipher}, @@ -72,7 +71,7 @@ enum NegotiationOutcome { enum NegotiationState { None, Hello, - StartKex(u8, u8), + StartKex(u8, u8, [u8; V1_CLIENT_RANDOM_LEN]), DHExchange, } @@ -124,15 +123,13 @@ impl ServerSocket { Ok(NegotiationOutcome::Accepted) => { self.poll.remove(&fd)?; let client = self.pending.remove(&fd).unwrap(); - Ok(Some(ClientSocket { - stream: client.stream, - signer: client.signer.unwrap(), - recv_buf: FramingBuffer::new(), - remote: address, - verifier: client.verifier.unwrap(), - symmetric: client.symmetric.unwrap(), - buffer: [0; MESSAGE_SIZE_MAX], - })) + Ok(Some(ClientSocket::from_parts( + client.stream, + address, + client.signer.unwrap(), + client.verifier.unwrap(), + client.symmetric.unwrap(), + ))) } Err(error) => { log::error!("{address}: {error}"); @@ -177,14 +174,19 @@ impl PendingClient { kex_algorithm, sig_algorithm, ciphersuite, + client_random, key_data, }, NegotiationState::Hello, ) => { log::debug!("{address}: StartKex {{ ... }}"); + log::debug!("{address}: client random: {client_random:02x?}"); + + // TODO entropy test on client random let sig_algorithm_name = sig_algo_name(sig_algorithm).unwrap_or("???"); let their_fingerprint = fingerprint_sha256(sig_algorithm_name, key_data); + let client_random = *client_random; log::info!("{address}: their fingerprint {their_fingerprint}"); let verifier = match config @@ -210,17 +212,22 @@ impl PendingClient { })?; self.signer = Some(signer); self.verifier = Some(verifier); - self.state = NegotiationState::StartKex(ciphersuite, kex_algorithm); + self.state = NegotiationState::StartKex(ciphersuite, kex_algorithm, client_random); Ok(NegotiationOutcome::Pending) } ( ClientNegotiationMessage::DHPublicKey(key_data), - NegotiationState::StartKex(ciphersuite, _kex_algorithm), + NegotiationState::StartKex(ciphersuite, _kex_algorithm, client_random), ) => { - if key_data.len() != 32 { - todo!() - } log::debug!("{address}: DHPublicKey {{ ... }}"); + let (ciphersuite_algo, ciphersuite_name) = symmetric::ciphersuite(ciphersuite) + .ok_or(Error::Symmetric(symmetric::Error::InvalidCiphersuite( + ciphersuite, + )))?; + let key_size = ciphersuite_algo.key_size(); + + log::debug!("{address}: selected ciphersuite {ciphersuite_name}"); + let mut rng = rand::thread_rng(); let mut their_public = [0; 32]; their_public.copy_from_slice(key_data); @@ -228,7 +235,14 @@ impl PendingClient { let ephemeral = EphemeralSecret::random_from_rng(&mut rng); let public = PublicKey::from(&ephemeral); let shared = ephemeral.diffie_hellman(&their_public); - let symmetric = SymmetricCipher::new(ciphersuite, shared.as_bytes())?; + + let iv = Iv::from(client_random); + let symmetric = SymmetricCipher::from_ciphersuite( + ciphersuite_algo, + &shared.as_bytes()[..key_size], + iv, + )?; + self.send(&ServerNegotiationMessage::DHPublicKey(public.as_bytes()))?; self.state = NegotiationState::DHExchange; self.symmetric = Some(symmetric); diff --git a/userspace/tools/rsh/src/crypt/stream.rs b/userspace/tools/rsh/src/crypt/stream.rs new file mode 100644 index 00000000..2055d0e8 --- /dev/null +++ b/userspace/tools/rsh/src/crypt/stream.rs @@ -0,0 +1,206 @@ +use std::{ + io::{self, Read, Write}, + net::{SocketAddr, TcpStream}, + os::fd::{AsRawFd, RawFd}, + time::Duration, +}; + +use crate::{ + crypt::{ + client::{self, Negotiation}, + config::ClientConfig, + signature::{SignatureMethod, VerificationMethod}, + symmetric::{ + self, + data::{InboundEncryptedMessage, OutboundPlainMessage}, + SymmetricCipher, + }, + }, + proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder}, +}; + +pub const MESSAGE_SIZE_MAX: usize = 4096; +const FRAMING_BUFFER_SIZE: usize = 8192; +const HEADER_SIZE: usize = size_of::() + size_of::(); + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("I/O: {0}")] + Io(#[from] io::Error), + #[error("Encode error: {0}")] + Encode(#[from] EncodeError), + #[error("Decode error: {0}")] + Decode(#[from] DecodeError), + #[error("Cipher error: {0}")] + Symmetric(#[from] symmetric::Error), + #[error("Receive buffer too small")] + BufferTooSmall, +} + +pub struct ClientSocket { + pub(crate) stream: TcpStream, + pub(crate) remote: SocketAddr, + pub(crate) buffer: [u8; MESSAGE_SIZE_MAX], + pub(crate) recv_buf: FramingBuffer, + + // TODO: only use those for signature verification during negotiation? + #[allow(unused)] + pub(crate) signer: SignatureMethod, + #[allow(unused)] + pub(crate) verifier: VerificationMethod, + pub(crate) symmetric: SymmetricCipher, + + pub(crate) tx_seq: u64, +} + +pub enum Message { + Data(T), + Incomplete, + Closed, +} + +pub(crate) struct FramingBuffer { + buffer: [u8; FRAMING_BUFFER_SIZE], + len: usize, +} + +impl FramingBuffer { + pub fn new() -> Self { + Self { + buffer: [0; FRAMING_BUFFER_SIZE], + len: 0, + } + } + + pub fn get_mut(&mut self) -> &mut [u8] { + &mut self.buffer[self.len..] + } + + pub fn advance(&mut self, len: usize) { + self.len += len; + assert!(self.len <= self.buffer.len()); + } + + pub fn pop(&mut self, buffer: &mut [u8]) -> Option<(usize, u64)> { + if self.len < HEADER_SIZE { + return None; + } + let len = u16::from_be_bytes([self.buffer[0], self.buffer[1]]) as usize; + if self.len < HEADER_SIZE + len { + return None; + } + let mut seq = [0; 8]; + seq.copy_from_slice(&self.buffer[2..HEADER_SIZE]); + let seq = u64::from_be_bytes(seq); + // TODO check dst len + buffer[..len].copy_from_slice(&self.buffer[HEADER_SIZE..len + HEADER_SIZE]); + self.buffer.copy_within(len + HEADER_SIZE..self.len, 0); + self.len -= len + HEADER_SIZE; + Some((len, seq)) + } +} + +impl ClientSocket { + pub(crate) fn from_parts( + stream: TcpStream, + remote: SocketAddr, + signer: SignatureMethod, + verifier: VerificationMethod, + symmetric: SymmetricCipher, + ) -> Self { + ClientSocket { + stream, + remote, + signer, + verifier, + symmetric, + buffer: [0; MESSAGE_SIZE_MAX], + recv_buf: FramingBuffer::new(), + tx_seq: 0x1234567887654321, + } + } + + pub fn connect(address: SocketAddr, config: ClientConfig) -> Result { + let stream = TcpStream::connect(address)?; + stream.set_read_timeout(Some(Duration::from_secs(1)))?; + stream.set_write_timeout(Some(Duration::from_secs(1)))?; + Negotiation::new(stream, address, config).perform() + } + + pub fn remote_address(&self) -> SocketAddr { + self.remote + } + + pub fn write_all(&mut self, message: &E) -> Result<(), Error> { + let mut buf = [0; MESSAGE_SIZE_MAX - 256]; + + let mut encoder = Encoder::new(&mut buf[HEADER_SIZE..]); + message.encode(&mut encoder)?; + let (payload, _) = encoder.split_mut(); + + let seq = self.tx_seq; + self.tx_seq = self.tx_seq.wrapping_add(payload.len() as u64); + + let outbound_plain = OutboundPlainMessage::new(payload); + let outbound_encrypted = self.symmetric.encrypter.encrypt(seq, outbound_plain)?; + + let total_len = outbound_encrypted.payload.len(); + let len: u16 = total_len.try_into().unwrap(); + + buf[0..2].copy_from_slice(&len.to_be_bytes()); + buf[2..HEADER_SIZE].copy_from_slice(&seq.to_be_bytes()); + buf[HEADER_SIZE..HEADER_SIZE + outbound_encrypted.payload.len()] + .copy_from_slice(&outbound_encrypted.payload[..]); + + self.stream.write_all(&buf[..total_len + HEADER_SIZE])?; + + Ok(()) + } + + pub fn poll_read<'de, D: Decode<'de>>( + &mut self, + buffer: &'de mut [u8], + ) -> Result, Error> { + if self.poll()? == 0 { + return Ok(Message::Closed); + } + match self.read(buffer)? { + Some(message) => Ok(Message::Data(message)), + None => Ok(Message::Incomplete), + } + } + + pub fn poll(&mut self) -> Result { + let dst = self.recv_buf.get_mut(); + if dst.is_empty() { + todo!() + } + let len = self.stream.read(dst)?; + self.recv_buf.advance(len); + Ok(len) + } + + pub fn read<'de, D: Decode<'de>>(&mut self, buffer: &'de mut [u8]) -> Result, Error> { + if let Some((len, seq)) = self.recv_buf.pop(&mut self.buffer) { + if len > buffer.len() { + return Err(Error::BufferTooSmall); + } + buffer[..len].copy_from_slice(&self.buffer[..len]); + let inbound_encrypted = InboundEncryptedMessage::from(&mut buffer[..len]); + let inbound_plain = self.symmetric.decrypter.decrypt(seq, inbound_encrypted)?; + let mut decoder = Decoder::new(inbound_plain.payload); + let message = D::decode(&mut decoder)?; + + Ok(Some(message)) + } else { + // Buffer doesn't contain a full message yet + Ok(None) + } + } +} + +impl AsRawFd for ClientSocket { + fn as_raw_fd(&self) -> RawFd { + self.stream.as_raw_fd() + } +} diff --git a/userspace/tools/rsh/src/crypt/symmetric.rs b/userspace/tools/rsh/src/crypt/symmetric.rs deleted file mode 100644 index 73a7b4ab..00000000 --- a/userspace/tools/rsh/src/crypt/symmetric.rs +++ /dev/null @@ -1,410 +0,0 @@ -use aes::{ - cipher::{BlockDecrypt, BlockEncrypt, KeyInit}, - Aes256, Block, -}; - -use crate::crypt::{util, V1_CIPHER_NULL}; - -use super::{V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB}; - -pub trait AesBlockMode { - fn encryption(&mut self, aes: &Aes256, block: Block) -> Block; - fn decryption(&mut self, aes: &Aes256, block: Block) -> Block; -} - -pub struct Aes256BlockCipher { - aes: Aes256, - mode: M, -} - -pub struct CipherModeEcb; -pub struct CipherModeCbc { - iv_encrypt: Block, - iv_decrypt: Block, -} - -pub enum SymmetricCipher { - Aes256Ecb(Aes256BlockCipher), - Aes256Cbc(Aes256BlockCipher), - Null, -} - -pub struct Pkcs7Padder<'src> { - block: Block, - src: &'src [u8], - extra: bool, -} - -pub struct Pkcs7Unpadder<'dst> { - pos: usize, - dst: &'dst mut [u8], - block_count: usize, - index: usize, -} - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("Invalid symmetric encryption key")] - InvalidKey, - #[error("Malformed ciphertext")] - InvalidCiphertext, - #[error("Payload too large: buffer size {0}, payload size {1}")] - MessageTooLarge(usize, usize), -} - -impl<'src> Pkcs7Padder<'src> { - pub fn new(src: &'src [u8]) -> Self { - Self { - block: Block::from([0; 16]), - extra: false, - src, - } - } -} - -impl<'dst> Pkcs7Unpadder<'dst> { - pub fn new(dst: &'dst mut [u8], block_count: usize) -> Self { - Self { - pos: 0, - dst, - block_count, - index: 0, - } - } - - fn write_dst(&mut self, data: &[u8]) -> Result<(), Error> { - if self.pos + data.len() > self.dst.len() { - Err(Error::MessageTooLarge( - self.dst.len(), - self.pos + data.len(), - )) - } else { - self.dst[self.pos..self.pos + data.len()].copy_from_slice(data); - self.pos += data.len(); - Ok(()) - } - } - - pub fn push(&mut self, block: Block) -> Result<(), Error> { - // Cases: 1 padded block - // 1 full block, 1 16b - // 2 full blocks, 1 16b - // 1 full, 1 padded - if self.index >= self.block_count { - return Ok(()); - } - match self.block_count { - 0 => return Ok(()), - _ if self.index < self.block_count - 1 => { - self.write_dst(&block)?; - } - // Last block - _ => { - if *block != [16; 16] { - let pad = block[15] as usize; - if pad > 15 { - return Err(Error::InvalidCiphertext); - } - let len = 16 - pad; - self.write_dst(&block[..len])?; - } - } - } - self.index += 1; - Ok(()) - } - - pub fn finish(self) -> &'dst mut [u8] { - &mut self.dst[..self.pos] - } -} - -impl Iterator for Pkcs7Padder<'_> { - type Item = Block; - - fn next(&mut self) -> Option { - if self.src.is_empty() { - if self.extra { - self.block.fill(16); - self.extra = false; - Some(self.block) - } else { - None - } - } else { - let len = core::cmp::min(self.src.len(), 16); - // Need an extra block to indicate that the last block was full - self.extra = len == 16; - - self.block[..len].copy_from_slice(&self.src[..len]); - self.block[len..].fill(16 - len as u8); - self.src = &self.src[len..]; - - Some(self.block) - } - } -} - -impl Aes256BlockCipher { - pub fn new(key: &[u8], mode: M) -> Result { - let aes = Aes256::new_from_slice(key).map_err(|_| Error::InvalidKey)?; - Ok(Self { aes, mode }) - } - - pub fn encrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result { - let full_len = (src.len() + 16) & !15; - if dst.len() < full_len { - return Err(Error::MessageTooLarge(dst.len(), full_len)); - } - let mut dst_pos = 0; - for block in Pkcs7Padder::new(src) { - let block = self.mode.encryption(&self.aes, block); - dst[dst_pos..dst_pos + 16].copy_from_slice(&block[..]); - dst_pos += 16; - } - debug_assert_eq!(full_len, dst_pos); - Ok(full_len) - } - - pub fn decrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result { - if src.len() % 16 != 0 { - return Err(Error::InvalidCiphertext); - } - let mut unpadder = Pkcs7Unpadder::new(dst, src.len() / 16); - for i in 0..src.len() / 16 { - let block = Block::clone_from_slice(&src[i * 16..i * 16 + 16]); - let block = self.mode.decryption(&self.aes, block); - unpadder.push(block)?; - } - let len = unpadder.finish().len(); - Ok(len) - } -} - -impl AesBlockMode for CipherModeEcb { - fn encryption(&mut self, aes: &Aes256, mut block: Block) -> Block { - aes.encrypt_block(&mut block); - block - } - - fn decryption(&mut self, aes: &Aes256, mut block: Block) -> Block { - aes.decrypt_block(&mut block); - block - } -} - -impl AesBlockMode for CipherModeCbc { - fn encryption(&mut self, aes: &Aes256, block: Block) -> Block { - let mut block = util::xor16b(block, self.iv_encrypt); - aes.encrypt_block(&mut block); - self.iv_encrypt = block; - block - } - - fn decryption(&mut self, aes: &Aes256, ciphertext: Block) -> Block { - let mut block = ciphertext; - aes.decrypt_block(&mut block); - let block = util::xor16b(block, self.iv_decrypt); - self.iv_decrypt = ciphertext; - block - } -} - -impl SymmetricCipher { - pub fn new(suite: u8, shared_key: &[u8]) -> Result { - match suite { - V1_CIPHER_NULL => Ok(Self::Null), - V1_CIPHER_AES_256_ECB => { - Aes256BlockCipher::new(shared_key, CipherModeEcb).map(Self::Aes256Ecb) - } - V1_CIPHER_AES_256_CBC => Aes256BlockCipher::new( - shared_key, - CipherModeCbc { - iv_encrypt: [0; 16].into(), - iv_decrypt: [0; 16].into(), - }, - ) - .map(Self::Aes256Cbc), - _ => unreachable!(), - } - } - - pub fn encrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result { - match self { - Self::Aes256Ecb(cipher) => cipher.encrypt(src, dst), - Self::Aes256Cbc(cipher) => cipher.encrypt(src, dst), - Self::Null => { - if src.len() > dst.len() { - return Err(Error::MessageTooLarge(dst.len(), src.len())); - } - - dst[..src.len()].copy_from_slice(src); - Ok(src.len()) - } - } - } - - pub fn decrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result { - match self { - Self::Aes256Ecb(cipher) => cipher.decrypt(src, dst), - Self::Aes256Cbc(cipher) => cipher.decrypt(src, dst), - Self::Null => { - if src.len() > dst.len() { - return Err(Error::MessageTooLarge(dst.len(), src.len())); - } - - dst[..src.len()].copy_from_slice(src); - Ok(src.len()) - } - } - } -} - -#[cfg(test)] -mod tests { - use aes::Block; - - use super::{Aes256BlockCipher, CipherModeCbc, Pkcs7Padder, Pkcs7Unpadder}; - - fn pad_unpad<'d>(text: &[u8], buffer: &'d mut [u8]) -> &'d [u8] { - let blocks = Pkcs7Padder::new(text).collect::>(); - let mut unpad = Pkcs7Unpadder::new(buffer, blocks.len()); - for block in blocks { - unpad.push(block).unwrap(); - } - unpad.finish() - } - - #[test] - fn test_pkcs7_pad() { - // 19 bytes - let src = b"This is a test text"; - let blocks = Pkcs7Padder::new(src).collect::>(); - assert_eq!(blocks.len(), 2); - assert_eq!(&*blocks[0], b"This is a test t"); - assert_eq!(&blocks[1][..3], b"ext"); - assert_eq!(&blocks[1][3..], &[13; 13]); - - // 32 bytes - let src = b"1234567812345678ABCDEFGHIJKLMNOP"; - let blocks = Pkcs7Padder::new(src).collect::>(); - assert_eq!(blocks.len(), 3); - assert_eq!(&*blocks[0], b"1234567812345678"); - assert_eq!(&*blocks[1], b"ABCDEFGHIJKLMNOP"); - assert_eq!(&*blocks[2], &[16; 16]); - } - - #[test] - fn test_pkcs7_unpad() { - // 19 bytes - let mut src = [0; 32]; - let mut dst = [0; 32]; - src[..19].copy_from_slice(b"This is a test text"); - src[19..].fill(13); - let mut unpad = Pkcs7Unpadder::new(&mut dst, 2); - for i in 0..2 { - let block = Block::from_slice(&src[i * 16..i * 16 + 16]); - unpad.push(*block).unwrap(); - } - assert_eq!(unpad.finish(), b"This is a test text"); - - // 32 bytes - let mut src = [0; 48]; - let mut dst = [0; 32]; - src[..32].copy_from_slice(b"1234567812345678ABCDEFGHIJKLMNOP"); - src[32..].fill(16); - let mut unpad = Pkcs7Unpadder::new(&mut dst, 3); - for i in 0..3 { - let block = Block::from_slice(&src[i * 16..i * 16 + 16]); - unpad.push(*block).unwrap(); - } - assert_eq!(unpad.finish(), &src[..32]); - } - - #[test] - fn test_pkcs7_pad_reversible() { - let text = "1234567890ABCDEF"; - let mut buffer = [0; 256]; - - for i in 0..text.len() { - let text = &text.as_bytes()[..i]; - let output = pad_unpad(text, &mut buffer); - assert_eq!(output, text); - } - } - - #[test] - fn test_aes256cbc_encrypt_decrypt() { - let messages = [&b"Hello"[..], b"This is a test text", b"Another message!"]; - let key = b"1234ABCD1234ABCD1234ABCD1234ABCD"; - let mut encrypted = [0; 256]; - let mut decrypted = [0; 256]; - let mut enc_cipher = Aes256BlockCipher::new( - key, - CipherModeCbc { - iv_encrypt: [0; 16].into(), - iv_decrypt: [0; 16].into(), - }, - ) - .unwrap(); - let mut dec_cipher = Aes256BlockCipher::new( - key, - CipherModeCbc { - iv_encrypt: [0; 16].into(), - iv_decrypt: [0; 16].into(), - }, - ) - .unwrap(); - - for text in messages { - let len = enc_cipher.encrypt(text, &mut encrypted).unwrap(); - let len = dec_cipher - .decrypt(&encrypted[..len], &mut decrypted) - .unwrap(); - assert_eq!(&decrypted[..len], text); - } - } - - #[test] - fn test_aes256cbc_large_message() { - let data = include_bytes!("../../tests/test-message.dat"); - let key = b"1234ABCD1234ABCD1234ABCD1234ABCD"; - let mut encrypt_buffer = [0; 512]; - let mut decrypt_buffer = [0; 512]; - let mut enc_cipher = Aes256BlockCipher::new( - key, - CipherModeCbc { - iv_encrypt: [0; 16].into(), - iv_decrypt: [0; 16].into(), - }, - ) - .unwrap(); - let mut dec_cipher = Aes256BlockCipher::new( - key, - CipherModeCbc { - iv_encrypt: [0; 16].into(), - iv_decrypt: [0; 16].into(), - }, - ) - .unwrap(); - - let mut position = 0; - while position < data.len() { - let count = (data.len() - position).min(400); - let enc_len = enc_cipher - .encrypt(&data[position..position + count], &mut encrypt_buffer) - .unwrap(); - let dec_len = dec_cipher - .decrypt(&encrypt_buffer[..enc_len], &mut decrypt_buffer) - .unwrap(); - - assert_eq!(dec_len, count); - assert_eq!( - &decrypt_buffer[..dec_len], - &data[position..position + count] - ); - - position += count; - } - } -} diff --git a/userspace/tools/rsh/src/crypt/symmetric/aes_gcm.rs b/userspace/tools/rsh/src/crypt/symmetric/aes_gcm.rs new file mode 100644 index 00000000..672bc5b7 --- /dev/null +++ b/userspace/tools/rsh/src/crypt/symmetric/aes_gcm.rs @@ -0,0 +1,122 @@ +use std::marker::PhantomData; + +use aead::{KeyInit, KeySizeUser}; +use aes::cipher::{BlockCipher, BlockEncrypt}; +use aes_gcm::{aead::AeadMutInPlace, AesGcm}; +use crypto_common::{typenum, BlockSizeUser}; + +use crate::crypt::symmetric::{ + data::{ + DecryptionBuffer, EncryptionBuffer, InboundEncryptedMessage, InboundPlainMessage, Iv, + Nonce, NonceSize, OutboundEncryptedMessage, OutboundPlainMessage, PrefixedPayload, + }, + make_aad, CiphersuiteAlgorithm, Decrypter, Encrypter, Error, +}; + +pub struct AesGcmAlgorithm(PhantomData); +pub struct AesGcmCipher(AesGcm, Iv); + +pub const AES_GCM_OVERHEAD: usize = 16; + +pub static AES128GCM: AesGcmAlgorithm = AesGcmAlgorithm(PhantomData); +pub static AES256GCM: AesGcmAlgorithm = AesGcmAlgorithm(PhantomData); + +impl CiphersuiteAlgorithm for AesGcmAlgorithm +where + Aes: BlockCipher + BlockEncrypt + BlockSizeUser + KeyInit + 'static, + AesGcm: KeySizeUser + KeyInit, +{ + fn encrypter(&self, key: &[u8], iv: &Iv) -> Result, Error> { + Ok(Box::new(AesGcmCipher::( + AesGcm::new_from_slice(key).map_err(|_| Error::InvalidKeySize)?, + *iv, + ))) + } + + fn decrypter(&self, key: &[u8], iv: &Iv) -> Result, Error> { + Ok(Box::new(AesGcmCipher::( + AesGcm::new_from_slice(key).map_err(|_| Error::InvalidKeySize)?, + *iv, + ))) + } + + fn key_size(&self) -> usize { + AesGcm::::key_size() + } +} + +impl Encrypter for AesGcmCipher +where + Aes: BlockCipher + BlockEncrypt + BlockSizeUser, +{ + fn encrypt( + &mut self, + seq: u64, + msg: OutboundPlainMessage<'_>, + ) -> Result { + let total_len = msg.payload.len() + AES_GCM_OVERHEAD; + let mut payload = PrefixedPayload::with_capacity(total_len); + + payload.extend_from_slice(msg.payload); + let nonce = aes_gcm::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_aad(total_len, seq); + + self.0 + .encrypt_in_place(&nonce, &aad, &mut EncryptionBuffer(&mut payload)) + .map_err(|_| Error::EncryptError) + .map(|_| OutboundEncryptedMessage::new(payload)) + } +} + +impl Decrypter for AesGcmCipher +where + Aes: BlockCipher + BlockEncrypt + BlockSizeUser, +{ + fn decrypt<'a>( + &mut self, + seq: u64, + mut msg: InboundEncryptedMessage<'a>, + ) -> Result, Error> { + let payload = &msg.payload; + let nonce = aes_gcm::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_aad(payload.len(), seq); + let payload = &mut msg.payload; + + self.0 + .decrypt_in_place(&nonce, &aad, &mut DecryptionBuffer(payload)) + .map_err(|_| Error::DecryptError)?; + + Ok(msg.into_plain_message()) + } +} + +#[cfg(test)] +mod tests { + use crate::crypt::symmetric::{ + aes_gcm::{AES128GCM, AES256GCM}, + data::{InboundEncryptedMessage, Iv, OutboundPlainMessage}, + CiphersuiteAlgorithm, + }; + + #[test] + fn test_aes_gcm_encrypt_decrypt() { + const ALGORITHMS: &[(&dyn CiphersuiteAlgorithm, &[u8])] = &[ + (&AES128GCM, b"1234123412341234"), + (&AES256GCM, b"12341234123412345678567856785678"), + ]; + let plaintext = b"This is a message"; + let iv = Iv::from(*b"1111111111111111"); + + for (aes, key) in ALGORITHMS { + let mut encrypter = aes.encrypter(key, &iv).unwrap(); + let mut decrypter = aes.decrypter(key, &iv).unwrap(); + + let out_plain = OutboundPlainMessage::new(plaintext); + let mut out_encrypted = encrypter.encrypt(1234, out_plain).unwrap(); + let in_encrypted = InboundEncryptedMessage::from(&mut out_encrypted.payload[..]); + let in_plain = decrypter.decrypt(1234, in_encrypted).unwrap(); + + assert_eq!(in_plain.payload, plaintext); + } + } +} diff --git a/userspace/tools/rsh/src/crypt/symmetric/chacha20poly1305.rs b/userspace/tools/rsh/src/crypt/symmetric/chacha20poly1305.rs new file mode 100644 index 00000000..62937812 --- /dev/null +++ b/userspace/tools/rsh/src/crypt/symmetric/chacha20poly1305.rs @@ -0,0 +1,102 @@ +use aead::{AeadMutInPlace, KeyInit, KeySizeUser}; +use chacha20poly1305::ChaCha20Poly1305; + +use crate::crypt::symmetric::{ + data::{ + DecryptionBuffer, EncryptionBuffer, InboundEncryptedMessage, InboundPlainMessage, Iv, + Nonce, OutboundEncryptedMessage, OutboundPlainMessage, PrefixedPayload, + }, + make_aad, CiphersuiteAlgorithm, Decrypter, Encrypter, Error, +}; + +pub struct Chacha20Poly1305Algorithm; +pub struct Chacha20Poly1305Cipher(ChaCha20Poly1305, Iv); + +pub const CHACHA20POLY1305_OVERHEAD: usize = 16; + +pub static CHACHA20POLY1305: Chacha20Poly1305Algorithm = Chacha20Poly1305Algorithm; + +impl CiphersuiteAlgorithm for Chacha20Poly1305Algorithm { + fn encrypter(&self, key: &[u8], iv: &Iv) -> Result, Error> { + Ok(Box::new(Chacha20Poly1305Cipher( + ChaCha20Poly1305::new_from_slice(key).map_err(|_| Error::InvalidKeySize)?, + *iv, + ))) + } + + fn decrypter(&self, key: &[u8], iv: &Iv) -> Result, Error> { + Ok(Box::new(Chacha20Poly1305Cipher( + ChaCha20Poly1305::new_from_slice(key).map_err(|_| Error::InvalidKeySize)?, + *iv, + ))) + } + + fn key_size(&self) -> usize { + ChaCha20Poly1305::key_size() + } +} + +impl Encrypter for Chacha20Poly1305Cipher { + fn encrypt( + &mut self, + seq: u64, + msg: OutboundPlainMessage<'_>, + ) -> Result { + let total_len = msg.payload.len() + CHACHA20POLY1305_OVERHEAD; + let mut payload = PrefixedPayload::with_capacity(total_len); + + payload.extend_from_slice(msg.payload); + let nonce = chacha20poly1305::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_aad(total_len, seq); + + self.0 + .encrypt_in_place(&nonce, &aad, &mut EncryptionBuffer(&mut payload)) + .map_err(|_| Error::EncryptError) + .map(|_| OutboundEncryptedMessage::new(payload)) + } +} + +impl Decrypter for Chacha20Poly1305Cipher { + fn decrypt<'a>( + &mut self, + seq: u64, + mut msg: InboundEncryptedMessage<'a>, + ) -> Result, Error> { + let payload = &msg.payload; + let nonce = chacha20poly1305::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_aad(payload.len(), seq); + let payload = &mut msg.payload; + + self.0 + .decrypt_in_place(&nonce, &aad, &mut DecryptionBuffer(payload)) + .map_err(|_| Error::DecryptError)?; + + Ok(msg.into_plain_message()) + } +} + +#[cfg(test)] +mod tests { + use crate::crypt::symmetric::{ + chacha20poly1305::CHACHA20POLY1305, + data::{InboundEncryptedMessage, Iv, OutboundPlainMessage}, + CiphersuiteAlgorithm, + }; + + #[test] + fn test_encrypt_decrypt() { + let plaintext = b"This is a message"; + let key = b"11112222333344441111222233334444"; + let iv = Iv::from(*b"1111111111111111"); + + let mut encrypter = CHACHA20POLY1305.encrypter(key, &iv).unwrap(); + let mut decrypter = CHACHA20POLY1305.decrypter(key, &iv).unwrap(); + + let out_plain = OutboundPlainMessage::new(plaintext); + let mut out_encrypted = encrypter.encrypt(1234, out_plain).unwrap(); + let in_encrypted = InboundEncryptedMessage::from(&mut out_encrypted.payload[..]); + let in_plain = decrypter.decrypt(1234, in_encrypted).unwrap(); + + assert_eq!(in_plain.payload, plaintext); + } +} diff --git a/userspace/tools/rsh/src/crypt/symmetric/data.rs b/userspace/tools/rsh/src/crypt/symmetric/data.rs new file mode 100644 index 00000000..51e875ac --- /dev/null +++ b/userspace/tools/rsh/src/crypt/symmetric/data.rs @@ -0,0 +1,206 @@ +use std::ops::{Deref, DerefMut}; + +use crypto_common::typenum; + +use crate::crypt::V1_CLIENT_RANDOM_LEN; + +pub type NonceSize = typenum::U12; +pub const NONCE_SIZE: usize = 12; + +pub struct Nonce(pub [u8; NONCE_SIZE]); +#[derive(Clone, Copy)] +pub struct Iv(pub [u8; V1_CLIENT_RANDOM_LEN]); + +pub struct BorrowedPayload<'a>(&'a mut [u8]); +pub struct PrefixedPayload(Vec); + +pub struct EncryptionBuffer<'a>(pub(crate) &'a mut PrefixedPayload); +pub struct DecryptionBuffer<'a, 'p>(pub(crate) &'a mut BorrowedPayload<'p>); + +pub struct InboundPlainMessage<'a> { + pub payload: &'a [u8], +} + +pub struct InboundEncryptedMessage<'a> { + pub payload: BorrowedPayload<'a>, +} + +pub struct OutboundPlainMessage<'a> { + pub payload: &'a [u8], +} + +pub struct OutboundEncryptedMessage { + pub payload: PrefixedPayload, +} + +impl From<[u8; V1_CLIENT_RANDOM_LEN]> for Iv { + fn from(value: [u8; V1_CLIENT_RANDOM_LEN]) -> Self { + Self(value) + } +} + +impl From<[u8; NONCE_SIZE]> for Nonce { + fn from(value: [u8; NONCE_SIZE]) -> Self { + Self(value) + } +} + +impl Nonce { + pub fn new(iv: &Iv, seq: u64) -> Self { + let mut data = [0; NONCE_SIZE]; + data[0..8].copy_from_slice(&seq.to_be_bytes()); + data.iter_mut().zip(iv.0.iter()).for_each(|(data, iv)| { + *data ^= *iv; + }); + Self(data) + } +} + +impl<'a> InboundEncryptedMessage<'a> { + pub fn into_plain_message(self) -> InboundPlainMessage<'a> { + InboundPlainMessage { + payload: self.payload.into_inner(), + } + } +} + +impl<'a> From<&'a mut [u8]> for InboundEncryptedMessage<'a> { + fn from(value: &'a mut [u8]) -> Self { + Self { + payload: BorrowedPayload(value), + } + } +} + +impl<'a> OutboundPlainMessage<'a> { + pub fn new(payload: &'a [u8]) -> Self { + Self { payload } + } +} + +impl OutboundEncryptedMessage { + pub fn new(payload: PrefixedPayload) -> Self { + Self { payload } + } +} + +impl<'a> BorrowedPayload<'a> { + pub fn new(payload: &'a mut [u8]) -> Self { + Self(payload) + } + + pub fn into_inner(self) -> &'a mut [u8] { + self.0 + } + + pub fn truncate(&mut self, len: usize) { + if len >= self.len() { + return; + } + + self.0 = core::mem::take(&mut self.0).split_at_mut(len).0; + } +} + +impl Deref for BorrowedPayload<'_> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl DerefMut for BorrowedPayload<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0 + } +} + +impl PrefixedPayload { + pub fn with_capacity(capacity: usize) -> Self { + Self(Vec::with_capacity(capacity)) + } + + pub fn extend_from_slice(&mut self, slice: &[u8]) { + self.0.extend_from_slice(slice); + } + + pub fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } +} + +impl Deref for PrefixedPayload { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for PrefixedPayload { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl AsRef<[u8]> for EncryptionBuffer<'_> { + fn as_ref(&self) -> &[u8] { + &self.0[..] + } +} + +impl AsMut<[u8]> for EncryptionBuffer<'_> { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.0[..] + } +} + +impl aead::Buffer for EncryptionBuffer<'_> { + fn len(&self) -> usize { + self.0.len() + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } + + fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> { + self.0.extend_from_slice(other); + Ok(()) + } +} + +impl AsRef<[u8]> for DecryptionBuffer<'_, '_> { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl AsMut<[u8]> for DecryptionBuffer<'_, '_> { + fn as_mut(&mut self) -> &mut [u8] { + self.0.as_mut() + } +} + +impl aead::Buffer for DecryptionBuffer<'_, '_> { + fn len(&self) -> usize { + self.0.len() + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } + + fn extend_from_slice(&mut self, _other: &[u8]) -> aead::Result<()> { + todo!("Not used in decryption") + } +} diff --git a/userspace/tools/rsh/src/crypt/symmetric/mod.rs b/userspace/tools/rsh/src/crypt/symmetric/mod.rs new file mode 100644 index 00000000..a98b3d79 --- /dev/null +++ b/userspace/tools/rsh/src/crypt/symmetric/mod.rs @@ -0,0 +1,92 @@ +use crate::crypt::symmetric::data::{ + InboundEncryptedMessage, InboundPlainMessage, Iv, OutboundEncryptedMessage, + OutboundPlainMessage, +}; + +pub mod aes_gcm; +pub mod chacha20poly1305; +pub mod data; + +pub const V1_CIPHER_AES_128_GCM: u8 = 0x10; +pub const V1_CIPHER_AES_256_GCM: u8 = 0x11; +pub const V1_CIPHER_CHACHA20POLY1305: u8 = 0x12; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Invalid symmetric ciphersuite: {0:#02x}")] + InvalidCiphersuite(u8), + #[error("Invalid key size")] + InvalidKeySize, + #[error("Decryption error")] + DecryptError, + #[error("Encryption error")] + EncryptError, +} + +pub struct SymmetricCipher { + pub encrypter: Box, + pub decrypter: Box, +} + +pub trait CiphersuiteAlgorithm { + fn encrypter(&self, key: &[u8], iv: &Iv) -> Result, Error>; + fn decrypter(&self, key: &[u8], iv: &Iv) -> Result, Error>; + fn key_size(&self) -> usize; +} + +pub trait Encrypter { + fn encrypt( + &mut self, + seq: u64, + msg: OutboundPlainMessage<'_>, + ) -> Result; +} + +pub trait Decrypter { + fn decrypt<'a>( + &mut self, + seq: u64, + msg: InboundEncryptedMessage<'a>, + ) -> Result, Error>; +} + +pub fn make_aad(total_len: usize, seq: u64) -> [u8; 4] { + [ + total_len as u8, + (total_len >> 8) as u8, + seq as u8, + (seq >> 8) as u8, + ] +} + +pub fn ciphersuite(id: u8) -> Option<(&'static dyn CiphersuiteAlgorithm, &'static str)> { + match id { + V1_CIPHER_AES_128_GCM => Some((&aes_gcm::AES128GCM, "aes-128-gcm")), + V1_CIPHER_AES_256_GCM => Some((&aes_gcm::AES256GCM, "aes-256-gcm")), + V1_CIPHER_CHACHA20POLY1305 => { + Some((&chacha20poly1305::CHACHA20POLY1305, "chacha20poly1305")) + } + _ => None, + } +} + +impl SymmetricCipher { + pub fn from_ciphersuite( + ciphersuite: &dyn CiphersuiteAlgorithm, + key: &[u8], + iv: Iv, + ) -> Result { + let encrypter = ciphersuite.encrypter(key, &iv)?; + let decrypter = ciphersuite.decrypter(key, &iv)?; + + Ok(Self { + encrypter, + decrypter, + }) + } + + pub fn from_ciphersuite_id(id: u8, key: &[u8], iv: Iv) -> Result { + let ciphersuite = ciphersuite(id).ok_or(Error::InvalidCiphersuite(id))?.0; + Self::from_ciphersuite(ciphersuite, key, iv) + } +} diff --git a/userspace/tools/rsh/src/main.rs b/userspace/tools/rsh/src/main.rs index 4f01e274..67d2995f 100644 --- a/userspace/tools/rsh/src/main.rs +++ b/userspace/tools/rsh/src/main.rs @@ -11,9 +11,10 @@ use clap::Parser; use cross::io::Poll; use rsh::{ crypt::{ - client::{self, ClientSocket, Message}, + client, config::{ClientConfig, SimpleClientKeyStore}, signature::{SignEd25519, SignatureMethod}, + stream::{self, ClientSocket, Message}, }, proto::{ServerMessage, StreamIndex}, }; @@ -39,8 +40,10 @@ pub enum Error { Disconnected(String), #[error("Timed out")] Timeout, - #[error("Socket error: {0}")] - Socket(#[from] client::Error), + #[error("Client error: {0}")] + Client(#[from] client::Error), + #[error("Stream error: {0}")] + Stream(#[from] stream::Error), #[error("Aborted by user")] Abort, } diff --git a/userspace/tools/rsh/src/server.rs b/userspace/tools/rsh/src/server.rs index 5d74337d..eafa5fa2 100644 --- a/userspace/tools/rsh/src/server.rs +++ b/userspace/tools/rsh/src/server.rs @@ -13,9 +13,10 @@ use cross::io::{PidFd, Pipe, Poll}; use crate::{ crypt::{ - client::{self, ClientSocket}, + client, config::ServerConfig, server::{self, ServerSocket}, + stream::{self, ClientSocket}, }, proto::{ClientMessage, ServerMessage, StreamIndex, TerminalInfo}, }; @@ -28,8 +29,10 @@ pub enum Error { Io(#[from] io::Error), #[error("Socket error: {0}")] Socket(#[from] server::Error), - #[error("Client socket error: {0}")] + #[error("Client error: {0}")] Client(#[from] client::Error), + #[error("Client stream error: {0}")] + Stream(#[from] stream::Error), } pub trait Session: Sized {