diff --git a/userspace/Cargo.lock b/userspace/Cargo.lock index 124b93ea..664a4951 100644 --- a/userspace/Cargo.lock +++ b/userspace/Cargo.lock @@ -27,12 +27,64 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.6.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys 0.59.0", +] + [[package]] name = "arrayvec" version = "0.7.6" @@ -170,6 +222,12 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "colors" version = "0.1.0" @@ -337,6 +395,29 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4445909572dbd556c457c849c4ca58623d84b27c8fff1e74b0b4227d8b90d17b" +[[package]] +name = "env_filter" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -505,6 +586,12 @@ dependencies = [ "libm", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iced-x86" version = "1.21.0" @@ -552,6 +639,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itoa" version = "1.0.11" @@ -1018,6 +1111,35 @@ dependencies = [ "bitflags 2.6.0", ] +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + [[package]] name = "rsa" version = "0.9.6" @@ -1045,7 +1167,9 @@ dependencies = [ "bytemuck", "clap", "cross", + "env_logger", "libterm", + "log", "rand 0.8.5", "thiserror", "x25519-dalek", @@ -1475,6 +1599,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "version_check" version = "0.9.5" diff --git a/userspace/rsh/Cargo.toml b/userspace/rsh/Cargo.toml index ff754904..762cddd4 100644 --- a/userspace/rsh/Cargo.toml +++ b/userspace/rsh/Cargo.toml @@ -17,3 +17,5 @@ bytemuck.workspace = true x25519-dalek.workspace = true rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" } aes = { version = "0.8.4" } +log = "0.4.22" +env_logger = "0.11.5" diff --git a/userspace/rsh/src/crypt/client.rs b/userspace/rsh/src/crypt/client.rs index 9f1dd91b..a3e2130d 100644 --- a/userspace/rsh/src/crypt/client.rs +++ b/userspace/rsh/src/crypt/client.rs @@ -1,40 +1,75 @@ use std::{ net::SocketAddr, - os::fd::{AsRawFd, RawFd}, time::Instant, + os::fd::{AsRawFd, RawFd}, + time::{Duration, Instant}, }; -use aes::{cipher::KeyInit, Aes256}; -use rand::prelude::Distribution; -use x25519_dalek::{EphemeralSecret, PublicKey}; +use cross::io::Poll; +use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret}; use crate::{ - crypt::{ - decrypt_blocked, encrypt_blocked, ClientNegotiationMessage, ServerNegotiationMessage, - V1_CIPHER_AES_256_CBC, V1_KEX_X25519_DALEK, - }, + crypt::{ClientNegotiationMessage, ServerNegotiationMessage, V1_CIPHER_AES_256_CBC, V1_KEX_X25519_DALEK}, socket::{MessageSocket, PacketSocket, SocketWrapper}, Error, }; -use super::{ClientMessageProxy, ServerMessageProxy}; +use super::{ + symmetric::SymmetricCipher, ClientMessageProxy, ServerMessageProxy, V1_CIPHER_AES_256_ECB, +}; + +pub struct ClientConfig { + pub select_ciphersuite: fn(&[u8]) -> Option<u8>, +} enum ClientState { PreNegotioation, - Connected(Aes256), + StartKexSent(u8, u8), + ClientKeySent(u8, EphemeralSecret), + ServerKeyReceived(u8, SharedSecret), + Connected(SymmetricCipher), } pub struct ClientEncryptedSocket<S: PacketSocket> { transport: SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>, - state: ClientState, - last_ping: Instant + state: Option<ClientState>, + last_ping: Instant, + config: ClientConfig, +} + +fn select_ciphersuite_default(offered: &[u8]) -> Option<u8> { + // List of default accepted ciphers, ordered descending by their priority + const ACCEPTED: &[u8] = &[ + V1_CIPHER_AES_256_CBC, + V1_CIPHER_AES_256_ECB, + ]; + + for cipher in ACCEPTED { + if offered.contains(cipher) { + return Some(*cipher); + } + } + None +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + select_ciphersuite: select_ciphersuite_default + } + } } impl<S: PacketSocket> ClientEncryptedSocket<S> { pub fn new(transport: S) -> Self { + Self::new_with_config(transport, Default::default()) + } + + pub fn new_with_config(transport: S, config: ClientConfig) -> Self { Self { transport: SocketWrapper::new(transport), - state: ClientState::PreNegotioation, + state: None, last_ping: Instant::now(), + config } } @@ -42,122 +77,160 @@ impl<S: PacketSocket> ClientEncryptedSocket<S> { self.last_ping } - pub fn try_connect_blocking(&mut self) -> Result<(), Error> { - let ClientState::PreNegotioation = self.state else { + pub fn try_connect_blocking( + &mut self, + poll: &mut Poll, + timeout: Duration, + ) -> Result<(), Error> { + if self.state.is_some() { return Err(Error::InvalidState); }; - let mut buf = [0; 256]; - // Send Hello - self.transport - .send(&mut buf, &ClientNegotiationMessage::Hello { protocol: 1 })?; + self.state = Some(ClientState::PreNegotioation); + self.do_handshake_sequence(poll, timeout) + } - // Wait for Server Hello - let (message, _) = self.transport.recv_from(&mut buf)?; - let _hello = match message { - ServerNegotiationMessage::Hello(hello) => hello, - ServerNegotiationMessage::Reject(_reason) => { - return Err(Error::Disconnected); - }, - _ => { - return Err(Error::Disconnected); - }, - }; + fn do_handshake_sequence(&mut self, poll: &mut Poll, timeout: Duration) -> Result<(), Error> { + let mut recv_buf = [0; 256]; - // TODO select kex, ciphersuite - let ciphersuite = V1_CIPHER_AES_256_CBC; - let kex_algo = V1_KEX_X25519_DALEK; - - // Initiate key exchange self.transport.send( - &mut buf, - &ClientNegotiationMessage::StartKex { - kex_algo, - ciphersuite, - }, + &mut recv_buf, + &ClientNegotiationMessage::Hello { protocol: 1 }, )?; - // Wait for server to accept - let (message, _) = self.transport.recv_from(&mut buf)?; - match message { - ServerNegotiationMessage::StartKex => (), - ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected), - _ => return Err(Error::Disconnected), - }; + loop { + let Some(fd) = poll.wait(Some(timeout))? else { + return Err(Error::Timeout); + }; + assert_eq!(fd, self.transport.as_raw_fd()); - // Generate an ephemeral key - let mut rng = rand::thread_rng(); - let secret = EphemeralSecret::random_from_rng(&mut rng); - let public = PublicKey::from(&secret).to_bytes(); - - // Send it to the server - self.transport.send( - &mut buf, - &ClientNegotiationMessage::PublicKey(true, &public), - )?; - - // Wait for server to respond with its own key - let (message, _) = self.transport.recv_from(&mut buf)?; - let mut remote = [0; 32]; - match message { - ServerNegotiationMessage::PublicKey(true, key) if key.len() == 32 => { - remote.copy_from_slice(key) + let (message, _) = self.transport.recv_from(&mut recv_buf)?; + if self.update_handshake_sequence(&message)? { + break; } - ServerNegotiationMessage::PublicKey(_, _key) => return Err(Error::Disconnected), - ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected), - _ => return Err(Error::Disconnected), - }; - let remote = PublicKey::from(remote); - - let shared = secret.diffie_hellman(&remote); - - // Negotiation done - self.transport - .send(&mut buf, &ClientNegotiationMessage::Agreed)?; - - // Wait for server's Agreed - let (message, _) = self.transport.recv_from(&mut buf)?; - match message { - ServerNegotiationMessage::Agreed => (), - ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected), - _ => return Err(Error::Disconnected), } - let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap(); - - self.state = ClientState::Connected(aes); + assert!(matches!(self.state, Some(ClientState::Connected(_)))); Ok(()) } + + fn update_handshake_sequence( + &mut self, + message: &ServerNegotiationMessage, + ) -> Result<bool, Error> { + let mut send_buf = [0; 256]; + match (message, self.state.take()) { + // PreNegotioation -> StartKexSent + (ServerNegotiationMessage::Hello(info), Some(ClientState::PreNegotioation)) => { + // TODO kex algorithm selection + let kex_algo = V1_KEX_X25519_DALEK; + let ciphersuite = match (self.config.select_ciphersuite)(info.symmetric_ciphersuites) { + // Picked a ciphersuite + Some(ciphersuite) => ciphersuite, + // Server didn't offer anything acceptable + None => { + return Err(Error::UnacceptableCiphersuites); + } + }; + self.transport.send( + &mut send_buf, + &ClientNegotiationMessage::StartKex { + kex_algo, + ciphersuite, + }, + )?; + self.state = Some(ClientState::StartKexSent(kex_algo, ciphersuite)); + + Ok(false) + } + // StartKexSent -> ClientKeySent + ( + ServerNegotiationMessage::StartKex, + Some(ClientState::StartKexSent(_kex, ciphersuite)), + ) => { + let mut rng = rand::thread_rng(); + let secret = EphemeralSecret::random_from_rng(&mut rng); + let public = PublicKey::from(&secret); + + assert!(public.as_bytes().len() < 128); + self.transport.send( + &mut send_buf, + &ClientNegotiationMessage::PublicKey(true, public.as_bytes()), + )?; + + self.state = Some(ClientState::ClientKeySent(ciphersuite, secret)); + Ok(false) + } + // ClientKeySent -> ServerKeyReceived + ( + ServerNegotiationMessage::PublicKey(true, key_data), + Some(ClientState::ClientKeySent(ciphersuite, secret)), + ) if key_data.len() == 32 => { + let mut public = [0; 32]; + public.copy_from_slice(key_data); + let public = PublicKey::from(public); + let shared = secret.diffie_hellman(&public); + + self.transport + .send(&mut send_buf, &ClientNegotiationMessage::Agreed)?; + self.state = Some(ClientState::ServerKeyReceived(ciphersuite, shared)); + Ok(false) + } + // ServerKeyReceived -> Connected (or fail) + ( + ServerNegotiationMessage::Agreed, + Some(ClientState::ServerKeyReceived(ciphersuite, shared)), + ) => match SymmetricCipher::new(ciphersuite, shared.as_bytes()) { + Ok(cipher) => { + self.state = Some(ClientState::Connected(cipher)); + Ok(true) + } + Err(error) => { + self.transport + .send(&mut send_buf, &ClientNegotiationMessage::Disconnect(0)) + .ok(); + self.state = None; + Err(error) + } + }, + // *** -> None + (ServerNegotiationMessage::Kick(_reason), _) => { + // No need to send any disconnects, server already forgot us + self.state = None; + Err(Error::Disconnected) + } + (_, _) => { + self.transport + .send(&mut send_buf, &ClientNegotiationMessage::Disconnect(0)) + .ok(); + self.state = None; + Err(Error::Disconnected) + } + } + } } impl<S: PacketSocket> PacketSocket for ClientEncryptedSocket<S> { type Error = Error; fn recv_from(&mut self, output: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { - match &self.state { - ClientState::Connected(aes) => { - assert!(output.len() >= 16); - loop { - let mut buffer = [0; 256]; - let (message, peer) = self.transport.recv_from(&mut buffer)?; - match message { - ServerNegotiationMessage::CipherText(data) => { - let len = decrypt_blocked(aes, data, output)?; - break Ok((len, peer)) - } - ServerNegotiationMessage::Ping => { - self.transport.send(&mut buffer, &ClientNegotiationMessage::Pong)?; - self.last_ping = Instant::now(); - // TODO this is a workaround - break Err(Error::Ping); - } - ServerNegotiationMessage::Kick(_reason) => { - break Err(Error::NotConnected) - }, - // TODO send disconnect and leave - _ => break Err(Error::NotConnected), - } - } + let Some(ClientState::Connected(cipher)) = &mut self.state else { + return Err(Error::NotConnected); + }; + + let mut buffer = [0; 256]; + let (message, peer) = self.transport.recv_from(&mut buffer)?; + match message { + ServerNegotiationMessage::CipherText(ciphertext) => { + let len = cipher.decrypt(ciphertext, output)?; + Ok((len, peer)) + } + ServerNegotiationMessage::Ping => { + self.transport + .send(&mut buffer, &ClientNegotiationMessage::Pong)?; + self.last_ping = Instant::now(); + // TODO this is a workaround + Err(Error::Ping) } _ => Err(Error::NotConnected), } @@ -168,33 +241,28 @@ impl<S: PacketSocket> PacketSocket for ClientEncryptedSocket<S> { } fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error> { + let Some(ClientState::Connected(cipher)) = &mut self.state else { + return Err(Error::NotConnected); + }; + let mut buffer = [0; 256]; let mut send_buffer = [0; 256]; - match &self.state { - ClientState::Connected(aes) => { - let len = encrypt_blocked(aes, data, &mut buffer)?; - assert_eq!(len % 16, 0); - self.transport.send( - &mut send_buffer, - &ClientNegotiationMessage::CipherText(&buffer[..len]), - )?; - Ok(data.len()) - } - _ => Err(Error::NotConnected), - } + let len = cipher.encrypt(data, &mut buffer)?; + self.transport.send( + &mut send_buffer, + &ClientNegotiationMessage::CipherText(&buffer[..len]), + )?; + Ok(data.len()) } } impl<S: PacketSocket> Drop for ClientEncryptedSocket<S> { fn drop(&mut self) { - match self.state { - ClientState::PreNegotioation => (), - _ => { - let mut buffer = [0; 32]; - self.transport - .send(&mut buffer, &ClientNegotiationMessage::Disconnect(0)) - .ok(); - } + if self.state.is_some() { + let mut buffer = [0; 32]; + self.transport + .send(&mut buffer, &ClientNegotiationMessage::Disconnect(0)) + .ok(); } } } diff --git a/userspace/rsh/src/crypt/mod.rs b/userspace/rsh/src/crypt/mod.rs index b124af0d..3f19a796 100644 --- a/userspace/rsh/src/crypt/mod.rs +++ b/userspace/rsh/src/crypt/mod.rs @@ -1,23 +1,15 @@ -use aes::{ - cipher::{BlockDecrypt, BlockEncrypt}, - Aes256, Block, -}; - -use crate::{ - proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy}, - Error, -}; +use crate::proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy}; pub mod client; pub mod server; +pub mod symmetric; +pub mod util; pub use client::ClientEncryptedSocket; pub use server::ServerEncryptedSocket; -pub const V1_CIPHER_AES_256_CBC: u8 = 0x10; -pub const V1_CIPHER_AES_256_CTR: u8 = 0x11; -pub const V1_CIPHER_AES_256_GCM: u8 = 0x12; -pub const V1_CIPHER_CHACHA20_POLY1305: u8 = 0x15; +pub const V1_CIPHER_AES_256_ECB: u8 = 0x10; +pub const V1_CIPHER_AES_256_CBC: u8 = 0x11; // v1 supports only one DH algo pub const V1_KEX_X25519_DALEK: u8 = 0x10; @@ -209,54 +201,10 @@ impl<'de> Decode<'de> for ServerNegotiationMessage<'de> { } } -fn decrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> { - if src.len() % 16 != 0 || dst.len() < src.len() { - todo!(); +pub fn ciphersuite_name(cipher: u8) -> Option<&'static str> { + match cipher { + V1_CIPHER_AES_256_ECB => Some("aes-256-ecb"), + V1_CIPHER_AES_256_CBC => Some("aes-256-cbc"), + _ => None } - - let mut pos = 0; - let mut out = 0; - let mut block = Block::from([0; 16]); - while pos != src.len() { - block.copy_from_slice(&src[pos..pos + 16]); - aes.decrypt_block(&mut block); - - let len = block[0] as usize; - if len > 15 { - todo!(); - } - - dst[out..out + len].copy_from_slice(&block[1..1 + len]); - - out += len; - pos += 16; - } - - Ok(out) -} - -fn encrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> { - if src.len() >= (dst.len() / 16) * 15 { - todo!(); - } - - let mut pos = 0; - let mut out = 0; - let mut block = Block::from([0; 16]); - while pos != src.len() { - let len = core::cmp::min(src.len() - pos, 15); - block[0] = len as u8; - block[1..1 + len].copy_from_slice(&src[pos..pos + len]); - // TODO pad with random - if len != 15 { - block[1 + len..].fill(0); - } - aes.encrypt_block(&mut block); - - dst[out..out + 16].copy_from_slice(&block); - - pos += len; - out += 16; - } - Ok(out) } diff --git a/userspace/rsh/src/crypt/server.rs b/userspace/rsh/src/crypt/server.rs index 032cffcb..24626590 100644 --- a/userspace/rsh/src/crypt/server.rs +++ b/userspace/rsh/src/crypt/server.rs @@ -5,37 +5,71 @@ use std::{ os::fd::{AsRawFd, RawFd}, }; -use aes::{cipher::KeyInit, Aes256}; use x25519_dalek::{EphemeralSecret, PublicKey}; use crate::{ - crypt::{decrypt_blocked, ServerHello, ServerNegotiationMessage}, + crypt::{ + ciphersuite_name, ServerHello, ServerNegotiationMessage, V1_CIPHER_AES_256_CBC, + V1_CIPHER_AES_256_ECB, + }, socket::{ MessageSocket, MultiplexedSocket, MultiplexedSocketEvent, PacketSocket, SocketWrapper, }, Error, }; -use super::{encrypt_blocked, ClientMessageProxy, ClientNegotiationMessage, ServerMessageProxy}; +use super::{ + symmetric::SymmetricCipher, ClientMessageProxy, ClientNegotiationMessage, ServerMessageProxy, +}; + +pub struct ServerConfig { + pub accept_ciphersuite: fn(u8) -> bool, + pub offer_ciphersuites: fn() -> &'static [u8], +} pub enum ServerPeerTransport { PreNegotiation, - Negotiation(EphemeralSecret), - Connected(Aes256, usize), + Negotiation(EphemeralSecret, u8), + Connected(SymmetricCipher, usize), } pub struct ServerEncryptedSocket<S: PacketSocket> { transport: SocketWrapper<S, ClientMessageProxy, ServerMessageProxy>, peers: HashMap<SocketAddr, ServerPeerTransport>, buffer: [u8; 256], + config: ServerConfig, +} + +const DEFAULT_CIPHERSUITES: &[u8] = &[V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB]; + +fn accept_ciphersuite_default(ciphersuite: u8) -> bool { + DEFAULT_CIPHERSUITES.contains(&ciphersuite) +} + +fn offer_ciphersuites_default() -> &'static [u8] { + DEFAULT_CIPHERSUITES +} + +impl Default for ServerConfig { + fn default() -> Self { + Self { + accept_ciphersuite: accept_ciphersuite_default, + offer_ciphersuites: offer_ciphersuites_default, + } + } } impl<S: PacketSocket> ServerEncryptedSocket<S> { pub fn new(transport: S) -> Self { + Self::new_with_config(transport, Default::default()) + } + + pub fn new_with_config(transport: S, config: ServerConfig) -> Self { Self { transport: SocketWrapper::new(transport), peers: HashMap::new(), buffer: [0; 256], + config, } } @@ -54,7 +88,9 @@ impl<S: PacketSocket> ServerEncryptedSocket<S> { for (remote, state) in self.peers.iter_mut() { match state { ServerPeerTransport::Connected(_, missed) => { - self.transport.send_to(remote, &mut send_buf, &ServerNegotiationMessage::Ping).ok(); + self.transport + .send_to(remote, &mut send_buf, &ServerNegotiationMessage::Ping) + .ok(); if *missed >= limit { removed.push(*remote); @@ -77,9 +113,8 @@ impl<S: PacketSocket> ServerEncryptedSocket<S> { impl<S: PacketSocket> MultiplexedSocket for ServerEncryptedSocket<S> { fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error> { let mut buf = [0; 256]; - if let Some(ServerPeerTransport::Connected(aes, _)) = self.peers.get(remote) { - let len = encrypt_blocked(aes, data, &mut self.buffer)?; - assert_eq!(len % 16, 0); + if let Some(ServerPeerTransport::Connected(cipher, _)) = self.peers.get_mut(remote) { + let len = cipher.encrypt(data, &mut self.buffer)?; self.transport.send_to( remote, &mut buf, @@ -99,12 +134,21 @@ impl<S: PacketSocket> MultiplexedSocket for ServerEncryptedSocket<S> { match (message, &mut state) { // TODO check kex params ( - ClientNegotiationMessage::StartKex { .. }, + ClientNegotiationMessage::StartKex { ciphersuite, .. }, ServerPeerTransport::PreNegotiation, ) => { + let name = ciphersuite_name(ciphersuite); + if !(self.config.accept_ciphersuite)(ciphersuite) { + log::warn!("Kicking {remote}: cannot accept offered ciphersuite: {name:?}",); + self.remove_client(&remote); + return Ok(MultiplexedSocketEvent::None(remote)); + } + + log::debug!("{remote}: negotiated a ciphersuite: {name:?}"); + let mut rng = rand::thread_rng(); let secret = EphemeralSecret::random_from_rng(&mut rng); - *state = ServerPeerTransport::Negotiation(secret); + *state = ServerPeerTransport::Negotiation(secret, ciphersuite); self.transport.send_to( &remote, &mut buf, @@ -114,66 +158,77 @@ impl<S: PacketSocket> MultiplexedSocket for ServerEncryptedSocket<S> { } ( ClientNegotiationMessage::PublicKey(true, data), - ServerPeerTransport::Negotiation(secret), + ServerPeerTransport::Negotiation(secret, _), ) if data.len() == 32 => { let public = PublicKey::from(&*secret); let mut remote_key = [0; 32]; remote_key.copy_from_slice(data); let remote_key = PublicKey::from(remote_key); - state.negotiate(&remote_key); + match state.negotiate(&remote_key) { + Ok(()) => { + // Send public key to the client + self.transport.send_to( + &remote, + &mut buf, + &ServerNegotiationMessage::PublicKey(true, public.as_bytes()), + )?; - // Send public key to the client - self.transport.send_to( - &remote, - &mut buf, - &ServerNegotiationMessage::PublicKey(true, public.as_bytes()), - )?; - - Ok(MultiplexedSocketEvent::None(remote)) + Ok(MultiplexedSocketEvent::None(remote)) + } + Err(error) => { + log::warn!( + "Kicking {remote}: couldn't setup requested ciphersuite: {error}" + ); + self.remove_client(&remote); + Ok(MultiplexedSocketEvent::None(remote)) + } + } } (ClientNegotiationMessage::Agreed, ServerPeerTransport::Connected(_, _)) => { + log::debug!("{remote}: negotiated"); self.transport .send_to(&remote, &mut buf, &ServerNegotiationMessage::Agreed)?; Ok(MultiplexedSocketEvent::None(remote)) } ( ClientNegotiationMessage::CipherText(data), - ServerPeerTransport::Connected(aes, _), + ServerPeerTransport::Connected(cipher, _), ) => { - let len = decrypt_blocked(aes, data, buffer)?; + let len = cipher.decrypt(data, buffer)?; Ok(MultiplexedSocketEvent::ClientData(remote, &buffer[..len])) } - ( - ClientNegotiationMessage::Pong, - ServerPeerTransport::Connected(_, missed) - ) => { + (ClientNegotiationMessage::Pong, ServerPeerTransport::Connected(_, missed)) => { *missed = 0; Ok(MultiplexedSocketEvent::None(remote)) } (ClientNegotiationMessage::Disconnect(_reason), _) => { - eprintln!("Peer disconnected: {remote}"); + log::debug!("{remote}: disconnected"); self.peers.remove(&remote); Ok(MultiplexedSocketEvent::ClientDisconnected(remote)) } // Misbehavior _ => { + log::warn!("Kicking {remote}: unexpected message"); self.peers.remove(&remote); Ok(MultiplexedSocketEvent::ClientDisconnected(remote)) } } } else { let ClientNegotiationMessage::Hello { protocol: 1 } = message else { - eprintln!("Unhandled client message"); return Ok(MultiplexedSocketEvent::None(remote)); }; - eprintln!("Client Hello"); self.peers .insert(remote, ServerPeerTransport::PreNegotiation); // Reply with hello + let symmetric_ciphersuites = (self.config.offer_ciphersuites)(); + log::debug!("{remote}: offering ciphersuites:"); + for suite in symmetric_ciphersuites { + log::debug!("* {:?}", ciphersuite_name(*suite)); + } let hello = ServerHello { kex_algos: &[], - symmetric_ciphersuites: &[], + symmetric_ciphersuites, }; self.transport .send_to(&remote, &mut buf, &ServerNegotiationMessage::Hello(hello))?; @@ -190,13 +245,16 @@ impl<S: PacketSocket> AsRawFd for ServerEncryptedSocket<S> { } impl ServerPeerTransport { - fn negotiate(&mut self, public: &PublicKey) { - let Self::Negotiation(secret) = mem::replace(self, ServerPeerTransport::PreNegotiation) + fn negotiate(&mut self, public: &PublicKey) -> Result<(), Error> { + let Self::Negotiation(secret, symmetric) = + mem::replace(self, ServerPeerTransport::PreNegotiation) else { panic!(); }; let shared = secret.diffie_hellman(public); - let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap(); - *self = ServerPeerTransport::Connected(aes, 0); + let cipher = SymmetricCipher::new(symmetric, shared.as_bytes())?; + // let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap(); + *self = ServerPeerTransport::Connected(cipher, 0); + Ok(()) } } diff --git a/userspace/rsh/src/crypt/symmetric.rs b/userspace/rsh/src/crypt/symmetric.rs new file mode 100644 index 00000000..c599795d --- /dev/null +++ b/userspace/rsh/src/crypt/symmetric.rs @@ -0,0 +1,321 @@ +use aes::{ + cipher::{BlockDecrypt, BlockEncrypt, KeyInit}, + Aes256, Block, +}; + +use crate::{crypt::util, Error}; + +use super::{V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB}; + +pub trait AesBlockMode { + fn encryption(&mut self, aes: &Aes256, block: Block) -> Block; + fn decryption(&mut self, aes: &Aes256, block: Block) -> Block; +} + +pub struct Aes256BlockCipher<M: AesBlockMode> { + aes: Aes256, + mode: M, +} + +pub struct CipherModeEcb; +pub struct CipherModeCbc { + iv: Block, +} + +pub enum SymmetricCipher { + Aes256Ecb(Aes256BlockCipher<CipherModeEcb>), + Aes256Cbc(Aes256BlockCipher<CipherModeCbc>), +} + +pub struct Pkcs7Padder<'src> { + block: Block, + src: &'src [u8], + extra: bool, +} + +pub struct Pkcs7Unpadder<'dst> { + pos: usize, + dst: &'dst mut [u8], + block_count: usize, + index: usize, +} + +impl<'src> Pkcs7Padder<'src> { + pub fn new(src: &'src [u8]) -> Self { + Self { + block: Block::from([0; 16]), + extra: false, + src, + } + } +} + +impl<'dst> Pkcs7Unpadder<'dst> { + pub fn new(dst: &'dst mut [u8], block_count: usize) -> Self { + Self { + pos: 0, + dst, + block_count, + index: 0, + } + } + + fn write_dst(&mut self, data: &[u8]) -> Result<(), Error> { + if self.pos + data.len() > self.dst.len() { + Err(Error::MessageTooLarge( + self.dst.len(), + self.pos + data.len(), + )) + } else { + self.dst[self.pos..self.pos + data.len()].copy_from_slice(data); + self.pos += data.len(); + Ok(()) + } + } + + pub fn push(&mut self, block: Block) -> Result<(), Error> { + // Cases: 1 padded block + // 1 full block, 1 16b + // 2 full blocks, 1 16b + // 1 full, 1 padded + if self.index >= self.block_count { + return Ok(()); + } + match self.block_count { + 0 => return Ok(()), + _ if self.index < self.block_count - 1 => { + self.write_dst(&block)?; + } + // Last block + _ => { + if &*block != &[16; 16] { + let pad = block[15] as usize; + if pad > 15 { + return Err(Error::InvalidCiphertext); + } + let len = 16 - pad; + self.write_dst(&block[..len])?; + } + } + } + self.index += 1; + Ok(()) + } + + pub fn finish(self) -> &'dst mut [u8] { + &mut self.dst[..self.pos] + } +} + +impl Iterator for Pkcs7Padder<'_> { + type Item = Block; + + fn next(&mut self) -> Option<Self::Item> { + if self.src.is_empty() { + if self.extra { + self.block.fill(16); + self.extra = false; + Some(self.block) + } else { + None + } + } else { + let len = core::cmp::min(self.src.len(), 16); + // Need an extra block to indicate that the last block was full + self.extra = len == 16; + + self.block[..len].copy_from_slice(&self.src[..len]); + self.block[len..].fill(16 - len as u8); + self.src = &self.src[len..]; + + Some(self.block) + } + } +} + +impl<M: AesBlockMode> Aes256BlockCipher<M> { + pub fn new(key: &[u8], mode: M) -> Result<Self, Error> { + let aes = Aes256::new_from_slice(key).map_err(|_| Error::InvalidKey)?; + Ok(Self { aes, mode }) + } + + pub fn encrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> { + let full_len = (src.len() + 16) & !15; + if dst.len() < full_len { + return Err(Error::MessageTooLarge(dst.len(), full_len)); + } + let mut dst_pos = 0; + for block in Pkcs7Padder::new(src) { + let block = self.mode.encryption(&self.aes, block); + dst[dst_pos..dst_pos + 16].copy_from_slice(&block[..]); + dst_pos += 16; + } + debug_assert_eq!(full_len, dst_pos); + Ok(full_len) + } + + pub fn decrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> { + if src.len() % 16 != 0 { + return Err(Error::InvalidCiphertext); + } + let mut unpadder = Pkcs7Unpadder::new(dst, src.len() / 16); + for i in 0..src.len() / 16 { + let block = Block::clone_from_slice(&src[i * 16..i * 16 + 16]); + let block = self.mode.decryption(&self.aes, block); + unpadder.push(block)?; + } + let len = unpadder.finish().len(); + Ok(len) + } +} + +impl AesBlockMode for CipherModeEcb { + fn encryption(&mut self, aes: &Aes256, mut block: Block) -> Block { + aes.encrypt_block(&mut block); + block + } + + fn decryption(&mut self, aes: &Aes256, mut block: Block) -> Block { + aes.decrypt_block(&mut block); + block + } +} + +impl AesBlockMode for CipherModeCbc { + fn encryption(&mut self, aes: &Aes256, block: Block) -> Block { + let mut block = util::xor16b(block, self.iv); + aes.encrypt_block(&mut block); + self.iv = block; + block + } + + fn decryption(&mut self, aes: &Aes256, ciphertext: Block) -> Block { + let mut block = ciphertext; + aes.decrypt_block(&mut block); + let block = util::xor16b(block, self.iv); + self.iv = ciphertext; + block + } +} + +impl SymmetricCipher { + pub fn new(suite: u8, shared_key: &[u8]) -> Result<Self, Error> { + match suite { + V1_CIPHER_AES_256_ECB => { + Aes256BlockCipher::new(shared_key, CipherModeEcb).map(Self::Aes256Ecb) + } + V1_CIPHER_AES_256_CBC => { + Aes256BlockCipher::new(shared_key, CipherModeCbc { iv: [0; 16].into() }) + .map(Self::Aes256Cbc) + } + _ => unreachable!(), + } + } + + pub fn encrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> { + match self { + Self::Aes256Ecb(cipher) => cipher.encrypt(src, dst), + Self::Aes256Cbc(cipher) => cipher.encrypt(src, dst), + } + } + + pub fn decrypt(&mut self, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> { + match self { + Self::Aes256Ecb(cipher) => cipher.decrypt(src, dst), + Self::Aes256Cbc(cipher) => cipher.decrypt(src, dst), + } + } +} + +#[cfg(test)] +mod tests { + use aes::Block; + + use super::{Aes256BlockCipher, CipherModeCbc, Pkcs7Padder, Pkcs7Unpadder}; + + fn pad_unpad<'d>(text: &[u8], buffer: &'d mut [u8]) -> &'d [u8] { + let blocks = Pkcs7Padder::new(text).collect::<Vec<_>>(); + let mut unpad = Pkcs7Unpadder::new(buffer, blocks.len()); + for block in blocks { + unpad.push(block).unwrap(); + } + unpad.finish() + } + + #[test] + fn test_pkcs7_pad() { + // 19 bytes + let src = b"This is a test text"; + let blocks = Pkcs7Padder::new(src).collect::<Vec<_>>(); + assert_eq!(blocks.len(), 2); + assert_eq!(&*blocks[0], b"This is a test t"); + assert_eq!(&blocks[1][..3], b"ext"); + assert_eq!(&blocks[1][3..], &[13; 13]); + + // 32 bytes + let src = b"1234567812345678ABCDEFGHIJKLMNOP"; + let blocks = Pkcs7Padder::new(src).collect::<Vec<_>>(); + assert_eq!(blocks.len(), 3); + assert_eq!(&*blocks[0], b"1234567812345678"); + assert_eq!(&*blocks[1], b"ABCDEFGHIJKLMNOP"); + assert_eq!(&*blocks[2], &[16; 16]); + } + + #[test] + fn test_pkcs7_unpad() { + // 19 bytes + let mut src = [0; 32]; + let mut dst = [0; 32]; + src[..19].copy_from_slice(b"This is a test text"); + src[19..].fill(13); + let mut unpad = Pkcs7Unpadder::new(&mut dst, 2); + for i in 0..2 { + let block = Block::from_slice(&src[i * 16..i * 16 + 16]); + unpad.push(*block).unwrap(); + } + assert_eq!(unpad.finish(), b"This is a test text"); + + // 32 bytes + let mut src = [0; 48]; + let mut dst = [0; 32]; + src[..32].copy_from_slice(b"1234567812345678ABCDEFGHIJKLMNOP"); + src[32..].fill(16); + let mut unpad = Pkcs7Unpadder::new(&mut dst, 3); + for i in 0..3 { + let block = Block::from_slice(&src[i * 16..i * 16 + 16]); + unpad.push(*block).unwrap(); + } + assert_eq!(unpad.finish(), &src[..32]); + } + + #[test] + fn test_pkcs7_pad_reversible() { + let texts = [&b"Hello"[..], &[16; 16], &[32; 16], &[1; 16]]; + let mut buffer = [0; 256]; + + for text in texts { + let output = pad_unpad(text, &mut buffer); + assert_eq!(text, output); + } + } + + #[test] + fn test_aes256cbc_encrypt_decrypt() { + let messages = [&b"Hello"[..], b"This is a test text", b"Another message!"]; + let key = b"1234ABCD1234ABCD1234ABCD1234ABCD"; + let mut encrypted = [0; 256]; + let mut decrypted = [0; 256]; + let mut enc_cipher = + Aes256BlockCipher::new(key, CipherModeCbc { iv: [0; 16].into() }).unwrap(); + let mut dec_cipher = + Aes256BlockCipher::new(key, CipherModeCbc { iv: [0; 16].into() }).unwrap(); + + for text in messages { + let len = enc_cipher.encrypt(text, &mut encrypted).unwrap(); + let len = dec_cipher + .decrypt(&encrypted[..len], &mut decrypted) + .unwrap(); + assert_eq!(&decrypted[..len], text); + } + } +} diff --git a/userspace/rsh/src/crypt/util.rs b/userspace/rsh/src/crypt/util.rs new file mode 100644 index 00000000..6a5c871d --- /dev/null +++ b/userspace/rsh/src/crypt/util.rs @@ -0,0 +1,10 @@ +use core::simd; + +use aes::Block; + +pub fn xor16b(x: Block, y: Block) -> Block { + let x = simd::u8x16::from_array(x.into()); + let y = simd::u8x16::from_array(y.into()); + let z: [u8; 16] = (x ^ y).into(); + Block::from(z) +} diff --git a/userspace/rsh/src/lib.rs b/userspace/rsh/src/lib.rs index 1b98dc70..57e4d174 100644 --- a/userspace/rsh/src/lib.rs +++ b/userspace/rsh/src/lib.rs @@ -1,5 +1,5 @@ #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))] -#![feature(generic_const_exprs)] +#![feature(generic_const_exprs, portable_simd)] #![allow(incomplete_features)] use std::io; @@ -30,4 +30,14 @@ pub enum Error { InvalidState, #[error("Disconnected by remote peer")] Disconnected, + #[error("Message too large: buffer size {0}, message size {1}")] + MessageTooLarge(usize, usize), + #[error("Malformed ciphertext")] + InvalidCiphertext, + #[error("Malformed encryption key")] + InvalidKey, + #[error("Communication timed out")] + Timeout, + #[error("Cannot accept any of the offered ciphersuites")] + UnacceptableCiphersuites, } diff --git a/userspace/rsh/src/main.rs b/userspace/rsh/src/main.rs index 14bed354..60eba4f3 100644 --- a/userspace/rsh/src/main.rs +++ b/userspace/rsh/src/main.rs @@ -67,10 +67,12 @@ impl Client { socket.connect(remote)?; let mut socket = ClientEncryptedSocket::new(socket); - socket.try_connect_blocking()?; - let mut socket = ClientSocket::new(socket); - poll.add(&*socket)?; + poll.add(&socket)?; + + socket.try_connect_blocking(&mut poll, Duration::from_secs(1))?; + + let mut socket = ClientSocket::new(socket); // Do the handshake thing Self::handshake(&mut poll, &mut socket, info, 5)?; @@ -233,6 +235,7 @@ fn run(args: Args) -> Result<(), Error> { } fn main() -> ExitCode { + env_logger::init(); let args = Args::parse(); if let Err(error) = run(args) { diff --git a/userspace/rsh/src/rshd/main.rs b/userspace/rsh/src/rshd/main.rs index d16d886c..90467e4a 100644 --- a/userspace/rsh/src/rshd/main.rs +++ b/userspace/rsh/src/rshd/main.rs @@ -13,7 +13,10 @@ use std::{ use cross::io::{Poll, TimerFd}; use rsh::{ - crypt::ServerEncryptedSocket, proto::{ClientMessage, Decode, Decoder, Encoder, ServerMessage, TerminalInfo}, socket::{MessageSocket, MultiplexedSocket, MultiplexedSocketEvent}, Error, ServerSocket + crypt::ServerEncryptedSocket, + proto::{ClientMessage, Decode, Decoder, ServerMessage, TerminalInfo}, + socket::{MultiplexedSocket, MultiplexedSocketEvent}, + Error, }; pub const PING_INTERVAL: Duration = Duration::from_millis(500); @@ -22,7 +25,6 @@ pub struct Session { pty_master: File, remote: SocketAddr, shell: Child, - timeouts: usize, } pub struct Server { @@ -83,7 +85,6 @@ impl Session { pty_master, shell, remote, - timeouts: 0, }) } #[cfg(unix)] @@ -109,7 +110,11 @@ impl Server { }) } - pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result<Option<Event<'b>>, Error> { + pub fn poll<'b>( + &mut self, + buffer: &'b mut [u8], + pty_max: usize, + ) -> Result<Option<Event<'b>>, Error> { let fd = self.poll.wait(None)?.unwrap(); match fd { @@ -119,7 +124,7 @@ impl Server { let (message, remote) = match event { MultiplexedSocketEvent::ClientDisconnected(remote) => { self.remove_session_by_remote(remote).ok(); - return Ok(None) + return Ok(None); } MultiplexedSocketEvent::None(_) => return Ok(None), MultiplexedSocketEvent::ClientData(peer, data) => { @@ -184,7 +189,11 @@ impl Server { eprintln!("PTY write error: {error}"); self.remove_session_by_fd(fd)?; self.socket - .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) + .send_message_to( + &remote, + &mut send_buf, + &ServerMessage::Bye("PTY error"), + ) .ok(); } } @@ -218,21 +227,25 @@ impl Server { self.socket .send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data)) .ok(); - }, + } PtyEvent::Err(error) => { eprintln!("PTY read error: {error}"); self.remove_session_by_fd(fd)?; self.socket - .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) + .send_message_to( + &remote, + &mut send_buf, + &ServerMessage::Bye("PTY error"), + ) .ok(); - }, + } PtyEvent::Closed => { println!("End of PTY for {remote}"); self.remove_session_by_fd(fd)?; self.socket .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("")) .ok(); - }, + } }, Event::Tick => { // Restart the timer @@ -287,6 +300,10 @@ fn run() -> Result<(), Error> { } fn main() -> ExitCode { + env_logger::Builder::new() + .filter_level(log::LevelFilter::Debug) + .format_timestamp(None) + .init(); if let Err(error) = run() { eprintln!("Finished with error: {error}"); ExitCode::FAILURE