diff --git a/userspace/Cargo.lock b/userspace/Cargo.lock index eb66cc74..124b93ea 100644 --- a/userspace/Cargo.lock +++ b/userspace/Cargo.lock @@ -16,6 +16,17 @@ dependencies = [ name = "abi-lib" version = "0.1.0" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "anstyle" version = "1.0.9" @@ -102,6 +113,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.20" @@ -522,6 +543,15 @@ dependencies = [ "yggdrasil-rt", ] +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1011,11 +1041,14 @@ dependencies = [ name = "rsh" version = "0.1.0" dependencies = [ + "aes", "bytemuck", "clap", "cross", "libterm", + "rand 0.8.5", "thiserror", + "x25519-dalek", ] [[package]] @@ -1633,6 +1666,17 @@ dependencies = [ "memchr", ] +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "git+https://git.alnyan.me/yggdrasil/curve25519-dalek.git?branch=alnyan%2Fyggdrasil#5f4dbb09259347077d3a5024ba60c77efde93a3a" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4 (git+https://git.alnyan.me/yggdrasil/rand.git?branch=alnyan%2Fyggdrasil-rng_core-0.6.4)", + "serde", + "zeroize", +] + [[package]] name = "yasync" version = "0.1.0" @@ -1710,3 +1754,17 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.86", +] diff --git a/userspace/rsh/Cargo.toml b/userspace/rsh/Cargo.toml index 9b72a83c..ff754904 100644 --- a/userspace/rsh/Cargo.toml +++ b/userspace/rsh/Cargo.toml @@ -8,8 +8,12 @@ name = "rshd" path = "src/rshd/main.rs" [dependencies] -clap.workspace = true libterm.workspace = true -thiserror.workspace = true cross.workspace = true + +clap.workspace = true +thiserror.workspace = true bytemuck.workspace = true +x25519-dalek.workspace = true +rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" } +aes = { version = "0.8.4" } diff --git a/userspace/rsh/src/crypt/client.rs b/userspace/rsh/src/crypt/client.rs new file mode 100644 index 00000000..9f1dd91b --- /dev/null +++ b/userspace/rsh/src/crypt/client.rs @@ -0,0 +1,206 @@ +use std::{ + net::SocketAddr, + os::fd::{AsRawFd, RawFd}, time::Instant, +}; + +use aes::{cipher::KeyInit, Aes256}; +use rand::prelude::Distribution; +use x25519_dalek::{EphemeralSecret, PublicKey}; + +use crate::{ + crypt::{ + decrypt_blocked, encrypt_blocked, ClientNegotiationMessage, ServerNegotiationMessage, + V1_CIPHER_AES_256_CBC, V1_KEX_X25519_DALEK, + }, + socket::{MessageSocket, PacketSocket, SocketWrapper}, + Error, +}; + +use super::{ClientMessageProxy, ServerMessageProxy}; + +enum ClientState { + PreNegotioation, + Connected(Aes256), +} + +pub struct ClientEncryptedSocket { + transport: SocketWrapper, + state: ClientState, + last_ping: Instant +} + +impl ClientEncryptedSocket { + pub fn new(transport: S) -> Self { + Self { + transport: SocketWrapper::new(transport), + state: ClientState::PreNegotioation, + last_ping: Instant::now(), + } + } + + pub fn last_ping(&self) -> Instant { + self.last_ping + } + + pub fn try_connect_blocking(&mut self) -> Result<(), Error> { + let ClientState::PreNegotioation = self.state else { + return Err(Error::InvalidState); + }; + let mut buf = [0; 256]; + + // Send Hello + self.transport + .send(&mut buf, &ClientNegotiationMessage::Hello { protocol: 1 })?; + + // Wait for Server Hello + let (message, _) = self.transport.recv_from(&mut buf)?; + let _hello = match message { + ServerNegotiationMessage::Hello(hello) => hello, + ServerNegotiationMessage::Reject(_reason) => { + return Err(Error::Disconnected); + }, + _ => { + return Err(Error::Disconnected); + }, + }; + + // TODO select kex, ciphersuite + let ciphersuite = V1_CIPHER_AES_256_CBC; + let kex_algo = V1_KEX_X25519_DALEK; + + // Initiate key exchange + self.transport.send( + &mut buf, + &ClientNegotiationMessage::StartKex { + kex_algo, + ciphersuite, + }, + )?; + + // Wait for server to accept + let (message, _) = self.transport.recv_from(&mut buf)?; + match message { + ServerNegotiationMessage::StartKex => (), + ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected), + _ => return Err(Error::Disconnected), + }; + + // Generate an ephemeral key + let mut rng = rand::thread_rng(); + let secret = EphemeralSecret::random_from_rng(&mut rng); + let public = PublicKey::from(&secret).to_bytes(); + + // Send it to the server + self.transport.send( + &mut buf, + &ClientNegotiationMessage::PublicKey(true, &public), + )?; + + // Wait for server to respond with its own key + let (message, _) = self.transport.recv_from(&mut buf)?; + let mut remote = [0; 32]; + match message { + ServerNegotiationMessage::PublicKey(true, key) if key.len() == 32 => { + remote.copy_from_slice(key) + } + ServerNegotiationMessage::PublicKey(_, _key) => return Err(Error::Disconnected), + ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected), + _ => return Err(Error::Disconnected), + }; + let remote = PublicKey::from(remote); + + let shared = secret.diffie_hellman(&remote); + + // Negotiation done + self.transport + .send(&mut buf, &ClientNegotiationMessage::Agreed)?; + + // Wait for server's Agreed + let (message, _) = self.transport.recv_from(&mut buf)?; + match message { + ServerNegotiationMessage::Agreed => (), + ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected), + _ => return Err(Error::Disconnected), + } + + let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap(); + + self.state = ClientState::Connected(aes); + Ok(()) + } +} + +impl PacketSocket for ClientEncryptedSocket { + type Error = Error; + + fn recv_from(&mut self, output: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { + match &self.state { + ClientState::Connected(aes) => { + assert!(output.len() >= 16); + loop { + let mut buffer = [0; 256]; + let (message, peer) = self.transport.recv_from(&mut buffer)?; + match message { + ServerNegotiationMessage::CipherText(data) => { + let len = decrypt_blocked(aes, data, output)?; + break Ok((len, peer)) + } + ServerNegotiationMessage::Ping => { + self.transport.send(&mut buffer, &ClientNegotiationMessage::Pong)?; + self.last_ping = Instant::now(); + // TODO this is a workaround + break Err(Error::Ping); + } + ServerNegotiationMessage::Kick(_reason) => { + break Err(Error::NotConnected) + }, + // TODO send disconnect and leave + _ => break Err(Error::NotConnected), + } + } + } + _ => Err(Error::NotConnected), + } + } + + fn send_to(&mut self, data: &[u8], _addr: &SocketAddr) -> Result { + self.send(data) + } + + fn send(&mut self, data: &[u8]) -> Result { + let mut buffer = [0; 256]; + let mut send_buffer = [0; 256]; + match &self.state { + ClientState::Connected(aes) => { + let len = encrypt_blocked(aes, data, &mut buffer)?; + assert_eq!(len % 16, 0); + self.transport.send( + &mut send_buffer, + &ClientNegotiationMessage::CipherText(&buffer[..len]), + )?; + Ok(data.len()) + } + _ => Err(Error::NotConnected), + } + } +} + +impl Drop for ClientEncryptedSocket { + fn drop(&mut self) { + match self.state { + ClientState::PreNegotioation => (), + _ => { + let mut buffer = [0; 32]; + self.transport + .send(&mut buffer, &ClientNegotiationMessage::Disconnect(0)) + .ok(); + } + } + } +} + +impl AsRawFd for ClientEncryptedSocket { + fn as_raw_fd(&self) -> RawFd { + self.transport.as_raw_fd() + } +} diff --git a/userspace/rsh/src/crypt/mod.rs b/userspace/rsh/src/crypt/mod.rs new file mode 100644 index 00000000..b124af0d --- /dev/null +++ b/userspace/rsh/src/crypt/mod.rs @@ -0,0 +1,262 @@ +use aes::{ + cipher::{BlockDecrypt, BlockEncrypt}, + Aes256, Block, +}; + +use crate::{ + proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy}, + Error, +}; + +pub mod client; +pub mod server; + +pub use client::ClientEncryptedSocket; +pub use server::ServerEncryptedSocket; + +pub const V1_CIPHER_AES_256_CBC: u8 = 0x10; +pub const V1_CIPHER_AES_256_CTR: u8 = 0x11; +pub const V1_CIPHER_AES_256_GCM: u8 = 0x12; +pub const V1_CIPHER_CHACHA20_POLY1305: u8 = 0x15; + +// v1 supports only one DH algo +pub const V1_KEX_X25519_DALEK: u8 = 0x10; + +pub enum ServerReject { + NoReason, + UnsupportedKexAlgo, + UnsupportedCiphersuite, + InvalidParameter, +} + +// v1 protocol +pub struct ServerHello<'a> { + // TODO server will send its public fingerprint and advertise + // supported asymmetric key formats + pub symmetric_ciphersuites: &'a [u8], + pub kex_algos: &'a [u8], +} + +pub enum ClientNegotiationMessage<'a> { + Hello { protocol: u8 }, + StartKex { kex_algo: u8, ciphersuite: u8 }, + PublicKey(bool, &'a [u8]), + Agreed, + Pong, + CipherText(&'a [u8]), + Disconnect(u8), +} + +pub enum ServerNegotiationMessage<'a> { + Hello(ServerHello<'a>), + Reject(ServerReject), + StartKex, + PublicKey(bool, &'a [u8]), + Agreed, + Kick(u8), + Ping, + CipherText(&'a [u8]), +} + +pub struct ClientMessageProxy; +pub struct ServerMessageProxy; + +impl MessageProxy for ClientMessageProxy { + type Type<'de> = ClientNegotiationMessage<'de>; +} +impl MessageProxy for ServerMessageProxy { + type Type<'de> = ServerNegotiationMessage<'de>; +} + +impl ClientNegotiationMessage<'_> { + const TAG_HELLO: u8 = 0x90; + const TAG_START_KEX: u8 = 0x91; + const TAG_PUBLIC_KEY: u8 = 0x92; + const TAG_AGREED: u8 = 0x93; + + const TAG_CIPHERTEXT: u8 = 0xA0; + const TAG_DISCONNECT: u8 = 0xA1; + const TAG_PONG: u8 = 0xA2; +} + +impl Encode for ClientNegotiationMessage<'_> { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> { + match self { + &Self::Hello { protocol } => buffer.write(&[Self::TAG_HELLO, protocol]), + &Self::StartKex { + kex_algo, + ciphersuite, + } => buffer.write(&[Self::TAG_START_KEX, kex_algo, ciphersuite]), + Self::PublicKey(end, data) => { + buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?; + buffer.write_variable_bytes(data) + } + Self::Agreed => buffer.write(&[Self::TAG_AGREED]), + Self::CipherText(data) => { + buffer.write(&[Self::TAG_CIPHERTEXT])?; + buffer.write_variable_bytes(data) + } + &Self::Disconnect(reason) => buffer.write(&[Self::TAG_DISCONNECT, reason]), + Self::Pong => buffer.write(&[Self::TAG_PONG]), + } + } +} + +impl<'de> Decode<'de> for ClientNegotiationMessage<'de> { + fn decode(buffer: &mut Decoder<'de>) -> Result { + let tag = buffer.read_u8()?; + match tag { + Self::TAG_HELLO => { + let protocol = buffer.read_u8()?; + Ok(Self::Hello { protocol }) + } + Self::TAG_START_KEX => { + let kex_algo = buffer.read_u8()?; + let ciphersuite = buffer.read_u8()?; + Ok(Self::StartKex { + kex_algo, + ciphersuite, + }) + } + Self::TAG_PUBLIC_KEY => { + let end = buffer.read_u8()? != 0; + let data = buffer.read_variable_bytes()?; + Ok(Self::PublicKey(end, data)) + } + Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText), + Self::TAG_AGREED => Ok(Self::Agreed), + Self::TAG_PONG => Ok(Self::Pong), + Self::TAG_DISCONNECT => buffer.read_u8().map(Self::Disconnect), + _ => Err(DecodeError::InvalidMessage), + } + } +} + +impl ServerNegotiationMessage<'_> { + const TAG_HELLO: u8 = 0xB0; + const TAG_REJECT: u8 = 0xB1; + const TAG_START_KEX: u8 = 0xB2; + const TAG_PUBLIC_KEY: u8 = 0xB3; + const TAG_AGREED: u8 = 0xB4; + + const TAG_CIPHERTEXT: u8 = 0xC0; + const TAG_PING: u8 = 0xC1; + const TAG_KICK: u8 = 0xC2; +} + +impl Encode for ServerHello<'_> { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> { + buffer.write_variable_bytes(self.kex_algos)?; + buffer.write_variable_bytes(self.symmetric_ciphersuites)?; + Ok(()) + } +} + +impl<'de> Decode<'de> for ServerHello<'de> { + fn decode(buffer: &mut Decoder<'de>) -> Result { + let kex_algos = buffer.read_variable_bytes()?; + let symmetric_ciphersuites = buffer.read_variable_bytes()?; + Ok(Self { + kex_algos, + symmetric_ciphersuites, + }) + } +} + +impl Encode for ServerNegotiationMessage<'_> { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> { + match self { + Self::Hello(hello) => { + buffer.write(&[Self::TAG_HELLO])?; + hello.encode(buffer) + } + Self::Reject(_reason) => todo!(), + Self::StartKex => buffer.write(&[Self::TAG_START_KEX]), + Self::PublicKey(end, data) => { + buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?; + buffer.write_variable_bytes(data) + } + Self::Agreed => buffer.write(&[Self::TAG_AGREED]), + Self::CipherText(data) => { + buffer.write(&[Self::TAG_CIPHERTEXT])?; + buffer.write_variable_bytes(data) + } + &Self::Kick(reason) => buffer.write(&[Self::TAG_KICK, reason]), + &Self::Ping => buffer.write(&[Self::TAG_PING]), + } + } +} + +impl<'de> Decode<'de> for ServerNegotiationMessage<'de> { + fn decode(buffer: &mut Decoder<'de>) -> Result { + let tag = buffer.read_u8()?; + match tag { + Self::TAG_HELLO => ServerHello::decode(buffer).map(Self::Hello), + Self::TAG_REJECT => todo!(), + Self::TAG_START_KEX => Ok(Self::StartKex), + Self::TAG_PUBLIC_KEY => { + let end = buffer.read_u8()? != 0; + let data = buffer.read_variable_bytes()?; + Ok(Self::PublicKey(end, data)) + } + Self::TAG_AGREED => Ok(Self::Agreed), + Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText), + Self::TAG_KICK => buffer.read_u8().map(Self::Kick), + Self::TAG_PING => Ok(Self::Ping), + + _ => Err(DecodeError::InvalidMessage), + } + } +} + +fn decrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result { + if src.len() % 16 != 0 || dst.len() < src.len() { + todo!(); + } + + let mut pos = 0; + let mut out = 0; + let mut block = Block::from([0; 16]); + while pos != src.len() { + block.copy_from_slice(&src[pos..pos + 16]); + aes.decrypt_block(&mut block); + + let len = block[0] as usize; + if len > 15 { + todo!(); + } + + dst[out..out + len].copy_from_slice(&block[1..1 + len]); + + out += len; + pos += 16; + } + + Ok(out) +} + +fn encrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result { + if src.len() >= (dst.len() / 16) * 15 { + todo!(); + } + + let mut pos = 0; + let mut out = 0; + let mut block = Block::from([0; 16]); + while pos != src.len() { + let len = core::cmp::min(src.len() - pos, 15); + block[0] = len as u8; + block[1..1 + len].copy_from_slice(&src[pos..pos + len]); + // TODO pad with random + if len != 15 { + block[1 + len..].fill(0); + } + aes.encrypt_block(&mut block); + + dst[out..out + 16].copy_from_slice(&block); + + pos += len; + out += 16; + } + Ok(out) +} diff --git a/userspace/rsh/src/crypt/server.rs b/userspace/rsh/src/crypt/server.rs new file mode 100644 index 00000000..032cffcb --- /dev/null +++ b/userspace/rsh/src/crypt/server.rs @@ -0,0 +1,202 @@ +use std::{ + collections::HashMap, + mem, + net::SocketAddr, + os::fd::{AsRawFd, RawFd}, +}; + +use aes::{cipher::KeyInit, Aes256}; +use x25519_dalek::{EphemeralSecret, PublicKey}; + +use crate::{ + crypt::{decrypt_blocked, ServerHello, ServerNegotiationMessage}, + socket::{ + MessageSocket, MultiplexedSocket, MultiplexedSocketEvent, PacketSocket, SocketWrapper, + }, + Error, +}; + +use super::{encrypt_blocked, ClientMessageProxy, ClientNegotiationMessage, ServerMessageProxy}; + +pub enum ServerPeerTransport { + PreNegotiation, + Negotiation(EphemeralSecret), + Connected(Aes256, usize), +} + +pub struct ServerEncryptedSocket { + transport: SocketWrapper, + peers: HashMap, + buffer: [u8; 256], +} + +impl ServerEncryptedSocket { + pub fn new(transport: S) -> Self { + Self { + transport: SocketWrapper::new(transport), + peers: HashMap::new(), + buffer: [0; 256], + } + } + + pub fn remove_client(&mut self, remote: &SocketAddr) { + let mut buf = [0; 32]; + if self.peers.remove(remote).is_some() { + self.transport + .send_to(remote, &mut buf, &ServerNegotiationMessage::Kick(0)) + .ok(); + } + } + + pub fn ping_clients(&mut self, limit: usize) -> Vec { + let mut send_buf = [0; 32]; + let mut removed = vec![]; + for (remote, state) in self.peers.iter_mut() { + match state { + ServerPeerTransport::Connected(_, missed) => { + self.transport.send_to(remote, &mut send_buf, &ServerNegotiationMessage::Ping).ok(); + + if *missed >= limit { + removed.push(*remote); + } + + *missed += 1; + } + _ => (), + } + } + + for entry in removed.iter() { + self.peers.remove(entry); + } + + removed + } +} + +impl MultiplexedSocket for ServerEncryptedSocket { + fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error> { + let mut buf = [0; 256]; + if let Some(ServerPeerTransport::Connected(aes, _)) = self.peers.get(remote) { + let len = encrypt_blocked(aes, data, &mut self.buffer)?; + assert_eq!(len % 16, 0); + self.transport.send_to( + remote, + &mut buf, + &ServerNegotiationMessage::CipherText(&self.buffer[..len]), + ) + } else { + Err(Error::NotConnected) + } + } + + fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result, Error> { + let (message, remote) = self.transport.recv_from(&mut self.buffer)?; + + let mut buf = [0; 256]; + + if let Some(mut state) = self.peers.get_mut(&remote) { + match (message, &mut state) { + // TODO check kex params + ( + ClientNegotiationMessage::StartKex { .. }, + ServerPeerTransport::PreNegotiation, + ) => { + let mut rng = rand::thread_rng(); + let secret = EphemeralSecret::random_from_rng(&mut rng); + *state = ServerPeerTransport::Negotiation(secret); + self.transport.send_to( + &remote, + &mut buf, + &ServerNegotiationMessage::StartKex, + )?; + Ok(MultiplexedSocketEvent::None(remote)) + } + ( + ClientNegotiationMessage::PublicKey(true, data), + ServerPeerTransport::Negotiation(secret), + ) if data.len() == 32 => { + let public = PublicKey::from(&*secret); + let mut remote_key = [0; 32]; + remote_key.copy_from_slice(data); + let remote_key = PublicKey::from(remote_key); + state.negotiate(&remote_key); + + // Send public key to the client + self.transport.send_to( + &remote, + &mut buf, + &ServerNegotiationMessage::PublicKey(true, public.as_bytes()), + )?; + + Ok(MultiplexedSocketEvent::None(remote)) + } + (ClientNegotiationMessage::Agreed, ServerPeerTransport::Connected(_, _)) => { + self.transport + .send_to(&remote, &mut buf, &ServerNegotiationMessage::Agreed)?; + Ok(MultiplexedSocketEvent::None(remote)) + } + ( + ClientNegotiationMessage::CipherText(data), + ServerPeerTransport::Connected(aes, _), + ) => { + let len = decrypt_blocked(aes, data, buffer)?; + Ok(MultiplexedSocketEvent::ClientData(remote, &buffer[..len])) + } + ( + ClientNegotiationMessage::Pong, + ServerPeerTransport::Connected(_, missed) + ) => { + *missed = 0; + Ok(MultiplexedSocketEvent::None(remote)) + } + (ClientNegotiationMessage::Disconnect(_reason), _) => { + eprintln!("Peer disconnected: {remote}"); + self.peers.remove(&remote); + Ok(MultiplexedSocketEvent::ClientDisconnected(remote)) + } + // Misbehavior + _ => { + self.peers.remove(&remote); + Ok(MultiplexedSocketEvent::ClientDisconnected(remote)) + } + } + } else { + let ClientNegotiationMessage::Hello { protocol: 1 } = message else { + eprintln!("Unhandled client message"); + return Ok(MultiplexedSocketEvent::None(remote)); + }; + eprintln!("Client Hello"); + self.peers + .insert(remote, ServerPeerTransport::PreNegotiation); + + // Reply with hello + let hello = ServerHello { + kex_algos: &[], + symmetric_ciphersuites: &[], + }; + self.transport + .send_to(&remote, &mut buf, &ServerNegotiationMessage::Hello(hello))?; + + Ok(MultiplexedSocketEvent::None(remote)) + } + } +} + +impl AsRawFd for ServerEncryptedSocket { + fn as_raw_fd(&self) -> RawFd { + self.transport.as_raw_fd() + } +} + +impl ServerPeerTransport { + fn negotiate(&mut self, public: &PublicKey) { + let Self::Negotiation(secret) = mem::replace(self, ServerPeerTransport::PreNegotiation) + else { + panic!(); + }; + let shared = secret.diffie_hellman(public); + let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap(); + *self = ServerPeerTransport::Connected(aes, 0); + } +} diff --git a/userspace/rsh/src/lib.rs b/userspace/rsh/src/lib.rs index b71d8630..1b98dc70 100644 --- a/userspace/rsh/src/lib.rs +++ b/userspace/rsh/src/lib.rs @@ -2,20 +2,15 @@ #![feature(generic_const_exprs)] #![allow(incomplete_features)] -use std::{ - io, - marker::PhantomData, - net::{SocketAddr, UdpSocket}, - ops::{Deref, DerefMut}, - os::fd::{AsRawFd, RawFd}, -}; +use std::io; -use proto::{ - ClientMessageProxy, Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy, - ServerMessageProxy, -}; +use proto::{DecodeError, EncodeError}; pub mod proto; +pub mod socket; +pub mod crypt; + +pub use socket::{ClientSocket, ServerSocket}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -27,106 +22,12 @@ pub enum Error { Decode(#[from] DecodeError), #[error("Encode error: {0}")] Encode(#[from] EncodeError), -} - -pub struct SocketWrapper { - socket: UdpSocket, - _pd: PhantomData<(Rx, Tx)>, -} - -pub struct ClientSocket(SocketWrapper); -pub struct ServerSocket(SocketWrapper); - -impl SocketWrapper { - pub fn new(socket: UdpSocket) -> Self { - Self { - socket, - _pd: PhantomData, - } - } - - pub fn send(&self, buffer: &mut [u8], message: &Tx::Type<'_>) -> Result<(), Error> { - let mut enc = Encoder::new(buffer); - message.encode(&mut enc)?; - let message = enc.get(); - let amount = self.socket.send(message)?; - if amount == message.len() { - Ok(()) - } else { - Err(Error::Truncated) - } - } - - pub fn send_to( - &mut self, - remote: &SocketAddr, - buffer: &mut [u8], - message: &Tx::Type<'_>, - ) -> Result<(), Error> { - let mut enc = Encoder::new(buffer); - message.encode(&mut enc)?; - let message = enc.get(); - let amount = self.socket.send_to(message, remote)?; - if amount == message.len() { - Ok(()) - } else { - Err(Error::Truncated) - } - } - - pub fn recv_from<'de>( - &mut self, - buffer: &'de mut [u8], - ) -> Result<(Rx::Type<'de>, SocketAddr), Error> { - let (len, remote) = self.socket.recv_from(buffer)?; - let mut dec = Decoder::new(&buffer[..len]); - let message = Rx::Type::<'de>::decode(&mut dec)?; - Ok((message, remote)) - } -} - -impl AsRawFd for SocketWrapper { - fn as_raw_fd(&self) -> RawFd { - self.socket.as_raw_fd() - } -} - -impl ClientSocket { - pub fn new(socket: UdpSocket) -> Self { - Self(SocketWrapper::new(socket)) - } -} - -impl Deref for ClientSocket { - type Target = SocketWrapper; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for ClientSocket { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl ServerSocket { - pub fn new(socket: UdpSocket) -> Self { - Self(SocketWrapper::new(socket)) - } -} - -impl Deref for ServerSocket { - type Target = SocketWrapper; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for ServerSocket { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } + #[error("Peer not connected")] + NotConnected, + #[error("Ping")] + Ping, + #[error("Invalid socket state")] + InvalidState, + #[error("Disconnected by remote peer")] + Disconnected, } diff --git a/userspace/rsh/src/main.rs b/userspace/rsh/src/main.rs index 89f87e78..14bed354 100644 --- a/userspace/rsh/src/main.rs +++ b/userspace/rsh/src/main.rs @@ -12,8 +12,7 @@ use clap::Parser; use cross::io::Poll; use libterm::{RawMode, RawTerminal}; use rsh::{ - proto::{ClientMessage, ServerMessage, TerminalInfo}, - ClientSocket, + crypt::ClientEncryptedSocket, proto::{ClientMessage, ServerMessage, TerminalInfo}, socket::{MessageSocket, PacketSocket}, ClientSocket }; pub const PING_TIMEOUT: Duration = Duration::from_secs(3); @@ -39,11 +38,10 @@ struct Args { pub struct Client { poll: Poll, - socket: ClientSocket, + socket: ClientSocket>, stdin: Stdin, stdout: Stdout, need_bye: bool, - last_ping: Instant, _raw: RawMode, } @@ -51,7 +49,6 @@ pub enum Event<'b> { Stdin(&'b [u8]), Data(&'b [u8]), Disconnected(&'b str), - Ping, } impl Client { @@ -69,6 +66,8 @@ impl Client { let socket = UdpSocket::bind(local)?; socket.connect(remote)?; + let mut socket = ClientEncryptedSocket::new(socket); + socket.try_connect_blocking()?; let mut socket = ClientSocket::new(socket); poll.add(&*socket)?; @@ -82,7 +81,6 @@ impl Client { let _raw = unsafe { RawMode::enter(&stdin)? }; Ok(Self { - last_ping: Instant::now(), poll, socket, stdout, @@ -102,7 +100,11 @@ impl Client { match event { fd if fd == self.socket.as_raw_fd() => { - let (message, _) = self.socket.recv_from(buffer)?; + let message = match self.socket.recv_from(buffer) { + Ok((message, _)) => message, + Err(rsh::Error::Ping) => return Ok(None), + Err(error) => return Err(error.into()) + }; match message { ServerMessage::Bye(reason) => { // No need for a bye @@ -112,9 +114,6 @@ impl Client { ServerMessage::Output(data) => { break Ok(Some(Event::Data(data))); } - ServerMessage::Ping => { - break Ok(Some(Event::Ping)); - } // Ignore this one ServerMessage::Hello => break Ok(None), } @@ -148,10 +147,6 @@ impl Client { Event::Stdin(data) => { self.socket.send(&mut send_buf, &ClientMessage::Input(data))?; } - Event::Ping => { - self.last_ping = Instant::now(); - self.socket.send(&mut send_buf, &ClientMessage::Pong)?; - } Event::Disconnected(reason) => { break Ok(reason.into()); } @@ -161,16 +156,16 @@ impl Client { fn check_timeout(&mut self) -> Result<(), Error> { let now = Instant::now(); - if now - self.last_ping >= PING_TIMEOUT { + if now - self.socket.as_inner_mut().last_ping() >= PING_TIMEOUT { Err(Error::Timeout) } else { Ok(()) } } - fn handshake( + fn handshake( poll: &mut Poll, - socket: &mut ClientSocket, + socket: &mut ClientSocket, terminal: TerminalInfo, attempts: usize, ) -> Result<(), Error> { @@ -187,9 +182,9 @@ impl Client { Self::try_handshake(poll, socket, terminal, timeout) } - fn try_handshake( + fn try_handshake( poll: &mut Poll, - socket: &mut ClientSocket, + socket: &mut ClientSocket, terminal: TerminalInfo, timeout: Duration, ) -> Result<(), Error> { diff --git a/userspace/rsh/src/proto.rs b/userspace/rsh/src/proto.rs index b1dcc899..09cdd47e 100644 --- a/userspace/rsh/src/proto.rs +++ b/userspace/rsh/src/proto.rs @@ -11,7 +11,6 @@ pub struct ServerMessageProxy; pub enum ClientMessage<'a> { Hello(TerminalInfo), Bye(&'a str), - Pong, Input(&'a [u8]), } @@ -19,7 +18,6 @@ pub enum ClientMessage<'a> { pub enum ServerMessage<'a> { Hello, Bye(&'a str), - Ping, Output(&'a [u8]), } @@ -68,6 +66,10 @@ impl<'a> Encoder<'a> { Self { buffer, pos: 0 } } + pub fn reset(&mut self) { + self.pos = 0; + } + pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> { let len: u32 = bytes .len() @@ -145,7 +147,6 @@ impl MessageProxy for ServerMessageProxy { impl ClientMessage<'_> { const TAG_HELLO: u8 = 0x80; const TAG_BYE: u8 = 0x81; - const TAG_PONG: u8 = 0x82; const TAG_INPUT: u8 = 0x90; } @@ -176,9 +177,6 @@ impl<'a> Encode for ClientMessage<'a> { buffer.write(&[Self::TAG_BYE])?; buffer.write_str(reason) } - Self::Pong => { - buffer.write(&[Self::TAG_PONG]) - } Self::Input(data) => { buffer.write(&[Self::TAG_INPUT])?; buffer.write_variable_bytes(data) @@ -197,9 +195,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> { Self::TAG_BYE => { buffer.read_str().map(Self::Bye) } - Self::TAG_PONG => { - Ok(Self::Pong) - } Self::TAG_INPUT => { buffer.read_variable_bytes().map(Self::Input) } @@ -211,7 +206,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> { impl ServerMessage<'_> { const TAG_HELLO: u8 = 0x10; const TAG_BYE: u8 = 0x11; - const TAG_PING: u8 = 0x12; const TAG_OUTPUT: u8 = 0x20; } @@ -223,8 +217,6 @@ impl<'a> Encode for ServerMessage<'a> { buffer.write(&[Self::TAG_BYE])?; buffer.write_str(reason) } - // TODO sequence number - Self::Ping => buffer.write(&[Self::TAG_PING]), Self::Output(data) => { buffer.write(&[Self::TAG_OUTPUT])?; buffer.write_variable_bytes(data) @@ -241,7 +233,6 @@ impl<'de> Decode<'de> for ServerMessage<'de> { Self::TAG_BYE => { buffer.read_str().map(Self::Bye) } - Self::TAG_PING => Ok(Self::Ping), Self::TAG_OUTPUT => { buffer.read_variable_bytes().map(Self::Output) }, diff --git a/userspace/rsh/src/rshd/main.rs b/userspace/rsh/src/rshd/main.rs index d77d6648..d16d886c 100644 --- a/userspace/rsh/src/rshd/main.rs +++ b/userspace/rsh/src/rshd/main.rs @@ -13,8 +13,7 @@ use std::{ use cross::io::{Poll, TimerFd}; use rsh::{ - proto::{ClientMessage, ServerMessage, TerminalInfo}, - Error, ServerSocket, + crypt::ServerEncryptedSocket, proto::{ClientMessage, Decode, Decoder, Encoder, ServerMessage, TerminalInfo}, socket::{MessageSocket, MultiplexedSocket, MultiplexedSocketEvent}, Error, ServerSocket }; pub const PING_INTERVAL: Duration = Duration::from_millis(500); @@ -29,7 +28,7 @@ pub struct Session { pub struct Server { poll: Poll, timer: TimerFd, - socket: ServerSocket, + socket: ServerEncryptedSocket, addr_to_session: HashMap, pty_to_session: HashMap, @@ -98,8 +97,8 @@ impl Server { pub fn new(listen_addr: SocketAddr) -> Result { let mut poll = Poll::new()?; let timer = TimerFd::new()?; - let socket = UdpSocket::bind(listen_addr).map(ServerSocket::new)?; - poll.add(&*socket)?; + let socket = UdpSocket::bind(listen_addr).map(ServerEncryptedSocket::new)?; + poll.add(&socket)?; poll.add(&timer)?; Ok(Self { poll, @@ -115,13 +114,27 @@ impl Server { match fd { fd if fd == self.socket.as_raw_fd() => { - let (message, remote) = match self.socket.recv_from(buffer) { - Ok((message, remote)) => (message, remote), - Err(error @ (Error::Decode(_) | Error::Truncated)) => { - eprintln!("Receive error: {error}"); + let event = self.socket.recv_from(buffer)?; + + let (message, remote) = match event { + MultiplexedSocketEvent::ClientDisconnected(remote) => { + self.remove_session_by_remote(remote).ok(); return Ok(None) - }, - Err(error) => return Err(error), + } + MultiplexedSocketEvent::None(_) => return Ok(None), + MultiplexedSocketEvent::ClientData(peer, data) => { + let mut decoder = Decoder::new(data); + let message = ClientMessage::decode(&mut decoder); + (message, peer) + } + }; + + let message = match message { + Ok(message) => message, + Err(error) => { + eprintln!("Decode error: {error}"); + return Ok(None); + } }; let event = match message { @@ -136,11 +149,6 @@ impl Server { { Event::SessionInput(*fd, remote, data) } - ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => { - let session = self.pty_to_session.get_mut(fd).unwrap(); - session.timeouts = 0; - return Ok(None); - } _ => return Ok(None), }; @@ -176,7 +184,7 @@ impl Server { eprintln!("PTY write error: {error}"); self.remove_session_by_fd(fd)?; self.socket - .send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) + .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) .ok(); } } @@ -190,13 +198,13 @@ impl Server { Ok(session) => { self.register_session(remote, session)?; self.socket - .send_to(&remote, &mut send_buf, &ServerMessage::Hello) + .send_message_to(&remote, &mut send_buf, &ServerMessage::Hello) .ok(); } Err(err) => { eprintln!("PTY open error: {err}"); self.socket - .send_to( + .send_message_to( &remote, &mut send_buf, &ServerMessage::Bye("PTY open error"), @@ -208,48 +216,37 @@ impl Server { Event::Pty(fd, remote, event) => match event { PtyEvent::Data(data) => { self.socket - .send_to(&remote, &mut send_buf, &ServerMessage::Output(data)) + .send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data)) .ok(); }, PtyEvent::Err(error) => { eprintln!("PTY read error: {error}"); self.remove_session_by_fd(fd)?; self.socket - .send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) + .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) .ok(); }, PtyEvent::Closed => { println!("End of PTY for {remote}"); self.remove_session_by_fd(fd)?; self.socket - .send_to(&remote, &mut send_buf, &ServerMessage::Bye("")) + .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("")) .ok(); }, }, Event::Tick => { // Restart the timer - self.update_client_timeouts(&mut send_buf)?; + self.update_client_timeouts()?; self.timer.start(PING_INTERVAL)?; } } } } - fn update_client_timeouts(&mut self, send_buf: &mut [u8]) -> Result<(), Error> { - let mut removed = vec![]; - for (remote, fd) in self.addr_to_session.iter() { - let session = self.pty_to_session.get_mut(&fd).unwrap(); - self.socket.send_to(remote, send_buf, &ServerMessage::Ping).ok(); - session.timeouts += 1; - if session.timeouts >= 10 { - removed.push((*remote, *fd)); - } - } - - for (remote, fd) in removed { - eprintln!("Client {remote} timed out"); - self.remove_session_by_fd(fd)?; - self.socket.send_to(&remote, send_buf, &ServerMessage::Bye("Timed out".into())).ok(); + fn update_client_timeouts(&mut self) -> Result<(), Error> { + let removed = self.socket.ping_clients(8); + for entry in removed { + self.remove_session_by_remote(entry).ok(); } Ok(()) } @@ -268,6 +265,7 @@ impl Server { // NOTE: this will block the whole server while the process finishes. session.shell.wait().ok(); self.addr_to_session.remove(&session.remote).unwrap(); + self.socket.remove_client(&session.remote); self.poll.remove(&fd)?; Ok(Some(session)) } else { diff --git a/userspace/rsh/src/socket.rs b/userspace/rsh/src/socket.rs new file mode 100644 index 00000000..ddeb62bf --- /dev/null +++ b/userspace/rsh/src/socket.rs @@ -0,0 +1,202 @@ +use std::{ + io, + marker::PhantomData, + net::{SocketAddr, UdpSocket}, + ops::{Deref, DerefMut}, + os::fd::{AsRawFd, RawFd}, +}; + +use crate::{ + proto::{ + ClientMessageProxy, Decode, Decoder, Encode, Encoder, MessageProxy, ServerMessageProxy, + }, + Error, +}; + +pub struct SocketWrapper { + socket: S, + _pd: PhantomData<(Rx, Tx)>, +} + +pub trait MessageSocket: AsRawFd { + fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error>; + fn send_to( + &mut self, + addr: &SocketAddr, + buffer: &mut [u8], + data: &Tx::Type<'_>, + ) -> Result<(), Error>; + fn recv_from<'de>( + &mut self, + buffer: &'de mut [u8], + ) -> Result<(Rx::Type<'de>, SocketAddr), Error>; +} + +pub trait PacketSocket: AsRawFd { + type Error: Into; + + fn send(&mut self, data: &[u8]) -> Result; + fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result; + fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error>; +} + +pub enum MultiplexedSocketEvent<'a> { + ClientData(SocketAddr, &'a [u8]), + ClientDisconnected(SocketAddr), + None(SocketAddr), +} + +pub trait MultiplexedSocket: AsRawFd { + fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result, Error>; + fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error>; + + fn send_message_to( + &mut self, + remote: &SocketAddr, + buffer: &mut [u8], + data: &Tx, + ) -> Result<(), Error> { + let mut encoder = Encoder::new(buffer); + data.encode(&mut encoder)?; + self.send_to(remote, encoder.get()) + } +} + +impl PacketSocket for UdpSocket { + type Error = io::Error; + + fn send(&mut self, data: &[u8]) -> Result { + UdpSocket::send(self, data) + } + + fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result { + UdpSocket::send_to(self, data, addr) + } + + fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { + UdpSocket::recv_from(self, data) + } +} + +pub struct ClientSocket(SocketWrapper); +pub struct ServerSocket(S); + +impl MessageSocket + for SocketWrapper +{ + fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error> { + let mut enc = Encoder::new(buffer); + data.encode(&mut enc)?; + let message = enc.get(); + let amount = self.socket.send(message).map_err(S::Error::into)?; + if amount == message.len() { + Ok(()) + } else { + Err(Error::Truncated) + } + } + + fn send_to( + &mut self, + addr: &SocketAddr, + buffer: &mut [u8], + data: &Tx::Type<'_>, + ) -> Result<(), Error> { + let mut enc = Encoder::new(buffer); + data.encode(&mut enc)?; + let message = enc.get(); + let amount = self.socket.send_to(message, addr).map_err(S::Error::into)?; + if amount == message.len() { + Ok(()) + } else { + Err(Error::Truncated) + } + } + + fn recv_from<'de>( + &mut self, + buffer: &'de mut [u8], + ) -> Result<(Rx::Type<'de>, SocketAddr), Error> { + let (len, remote) = self.socket.recv_from(buffer).map_err(S::Error::into)?; + let mut dec = Decoder::new(&buffer[..len]); + let message = Rx::Type::<'de>::decode(&mut dec)?; + Ok((message, remote)) + } +} + +impl SocketWrapper { + pub fn new(socket: S) -> Self { + Self { + socket, + _pd: PhantomData, + } + } + + pub fn into_inner(self) -> S { + self.socket + } + + pub fn as_inner_mut(&mut self) -> &mut S { + &mut self.socket + } +} + +impl AsRawFd for SocketWrapper { + fn as_raw_fd(&self) -> RawFd { + self.socket.as_raw_fd() + } +} + +impl ClientSocket { + pub fn new(socket: S) -> Self { + Self(SocketWrapper::new(socket)) + } +} + +impl Deref for ClientSocket { + type Target = SocketWrapper; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ClientSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl ServerSocket { + pub fn new(socket: S) -> Self { + Self(socket) + } +} + +impl Deref for ServerSocket { + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ServerSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +// impl Deref for ServerSocket { +// type Target = SocketWrapper; +// +// fn deref(&self) -> &Self::Target { +// &self.0 +// } +// } +// +// impl DerefMut for ServerSocket { +// fn deref_mut(&mut self) -> &mut Self::Target { +// &mut self.0 +// } +// }