rsh: implement dumb kex+aes256

This commit is contained in:
Mark Poliakov 2024-11-02 01:00:42 +02:00
parent bcf1e74a04
commit 99c1dd51ae
10 changed files with 1004 additions and 185 deletions

58
userspace/Cargo.lock generated
View File

@ -16,6 +16,17 @@ dependencies = [
name = "abi-lib"
version = "0.1.0"
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "anstyle"
version = "1.0.9"
@ -102,6 +113,16 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
]
[[package]]
name = "clap"
version = "4.5.20"
@ -522,6 +543,15 @@ dependencies = [
"yggdrasil-rt",
]
[[package]]
name = "inout"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
dependencies = [
"generic-array",
]
[[package]]
name = "itoa"
version = "1.0.11"
@ -1011,11 +1041,14 @@ dependencies = [
name = "rsh"
version = "0.1.0"
dependencies = [
"aes",
"bytemuck",
"clap",
"cross",
"libterm",
"rand 0.8.5",
"thiserror",
"x25519-dalek",
]
[[package]]
@ -1633,6 +1666,17 @@ dependencies = [
"memchr",
]
[[package]]
name = "x25519-dalek"
version = "2.0.1"
source = "git+https://git.alnyan.me/yggdrasil/curve25519-dalek.git?branch=alnyan%2Fyggdrasil#5f4dbb09259347077d3a5024ba60c77efde93a3a"
dependencies = [
"curve25519-dalek",
"rand_core 0.6.4 (git+https://git.alnyan.me/yggdrasil/rand.git?branch=alnyan%2Fyggdrasil-rng_core-0.6.4)",
"serde",
"zeroize",
]
[[package]]
name = "yasync"
version = "0.1.0"
@ -1710,3 +1754,17 @@ name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
dependencies = [
"zeroize_derive",
]
[[package]]
name = "zeroize_derive"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.86",
]

View File

@ -8,8 +8,12 @@ name = "rshd"
path = "src/rshd/main.rs"
[dependencies]
clap.workspace = true
libterm.workspace = true
thiserror.workspace = true
cross.workspace = true
clap.workspace = true
thiserror.workspace = true
bytemuck.workspace = true
x25519-dalek.workspace = true
rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" }
aes = { version = "0.8.4" }

View File

@ -0,0 +1,206 @@
use std::{
net::SocketAddr,
os::fd::{AsRawFd, RawFd}, time::Instant,
};
use aes::{cipher::KeyInit, Aes256};
use rand::prelude::Distribution;
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{
crypt::{
decrypt_blocked, encrypt_blocked, ClientNegotiationMessage, ServerNegotiationMessage,
V1_CIPHER_AES_256_CBC, V1_KEX_X25519_DALEK,
},
socket::{MessageSocket, PacketSocket, SocketWrapper},
Error,
};
use super::{ClientMessageProxy, ServerMessageProxy};
enum ClientState {
PreNegotioation,
Connected(Aes256),
}
pub struct ClientEncryptedSocket<S: PacketSocket> {
transport: SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>,
state: ClientState,
last_ping: Instant
}
impl<S: PacketSocket> ClientEncryptedSocket<S> {
pub fn new(transport: S) -> Self {
Self {
transport: SocketWrapper::new(transport),
state: ClientState::PreNegotioation,
last_ping: Instant::now(),
}
}
pub fn last_ping(&self) -> Instant {
self.last_ping
}
pub fn try_connect_blocking(&mut self) -> Result<(), Error> {
let ClientState::PreNegotioation = self.state else {
return Err(Error::InvalidState);
};
let mut buf = [0; 256];
// Send Hello
self.transport
.send(&mut buf, &ClientNegotiationMessage::Hello { protocol: 1 })?;
// Wait for Server Hello
let (message, _) = self.transport.recv_from(&mut buf)?;
let _hello = match message {
ServerNegotiationMessage::Hello(hello) => hello,
ServerNegotiationMessage::Reject(_reason) => {
return Err(Error::Disconnected);
},
_ => {
return Err(Error::Disconnected);
},
};
// TODO select kex, ciphersuite
let ciphersuite = V1_CIPHER_AES_256_CBC;
let kex_algo = V1_KEX_X25519_DALEK;
// Initiate key exchange
self.transport.send(
&mut buf,
&ClientNegotiationMessage::StartKex {
kex_algo,
ciphersuite,
},
)?;
// Wait for server to accept
let (message, _) = self.transport.recv_from(&mut buf)?;
match message {
ServerNegotiationMessage::StartKex => (),
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
_ => return Err(Error::Disconnected),
};
// Generate an ephemeral key
let mut rng = rand::thread_rng();
let secret = EphemeralSecret::random_from_rng(&mut rng);
let public = PublicKey::from(&secret).to_bytes();
// Send it to the server
self.transport.send(
&mut buf,
&ClientNegotiationMessage::PublicKey(true, &public),
)?;
// Wait for server to respond with its own key
let (message, _) = self.transport.recv_from(&mut buf)?;
let mut remote = [0; 32];
match message {
ServerNegotiationMessage::PublicKey(true, key) if key.len() == 32 => {
remote.copy_from_slice(key)
}
ServerNegotiationMessage::PublicKey(_, _key) => return Err(Error::Disconnected),
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
_ => return Err(Error::Disconnected),
};
let remote = PublicKey::from(remote);
let shared = secret.diffie_hellman(&remote);
// Negotiation done
self.transport
.send(&mut buf, &ClientNegotiationMessage::Agreed)?;
// Wait for server's Agreed
let (message, _) = self.transport.recv_from(&mut buf)?;
match message {
ServerNegotiationMessage::Agreed => (),
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
_ => return Err(Error::Disconnected),
}
let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap();
self.state = ClientState::Connected(aes);
Ok(())
}
}
impl<S: PacketSocket> PacketSocket for ClientEncryptedSocket<S> {
type Error = Error;
fn recv_from(&mut self, output: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
match &self.state {
ClientState::Connected(aes) => {
assert!(output.len() >= 16);
loop {
let mut buffer = [0; 256];
let (message, peer) = self.transport.recv_from(&mut buffer)?;
match message {
ServerNegotiationMessage::CipherText(data) => {
let len = decrypt_blocked(aes, data, output)?;
break Ok((len, peer))
}
ServerNegotiationMessage::Ping => {
self.transport.send(&mut buffer, &ClientNegotiationMessage::Pong)?;
self.last_ping = Instant::now();
// TODO this is a workaround
break Err(Error::Ping);
}
ServerNegotiationMessage::Kick(_reason) => {
break Err(Error::NotConnected)
},
// TODO send disconnect and leave
_ => break Err(Error::NotConnected),
}
}
}
_ => Err(Error::NotConnected),
}
}
fn send_to(&mut self, data: &[u8], _addr: &SocketAddr) -> Result<usize, Self::Error> {
self.send(data)
}
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error> {
let mut buffer = [0; 256];
let mut send_buffer = [0; 256];
match &self.state {
ClientState::Connected(aes) => {
let len = encrypt_blocked(aes, data, &mut buffer)?;
assert_eq!(len % 16, 0);
self.transport.send(
&mut send_buffer,
&ClientNegotiationMessage::CipherText(&buffer[..len]),
)?;
Ok(data.len())
}
_ => Err(Error::NotConnected),
}
}
}
impl<S: PacketSocket> Drop for ClientEncryptedSocket<S> {
fn drop(&mut self) {
match self.state {
ClientState::PreNegotioation => (),
_ => {
let mut buffer = [0; 32];
self.transport
.send(&mut buffer, &ClientNegotiationMessage::Disconnect(0))
.ok();
}
}
}
}
impl<S: PacketSocket> AsRawFd for ClientEncryptedSocket<S> {
fn as_raw_fd(&self) -> RawFd {
self.transport.as_raw_fd()
}
}

View File

@ -0,0 +1,262 @@
use aes::{
cipher::{BlockDecrypt, BlockEncrypt},
Aes256, Block,
};
use crate::{
proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy},
Error,
};
pub mod client;
pub mod server;
pub use client::ClientEncryptedSocket;
pub use server::ServerEncryptedSocket;
pub const V1_CIPHER_AES_256_CBC: u8 = 0x10;
pub const V1_CIPHER_AES_256_CTR: u8 = 0x11;
pub const V1_CIPHER_AES_256_GCM: u8 = 0x12;
pub const V1_CIPHER_CHACHA20_POLY1305: u8 = 0x15;
// v1 supports only one DH algo
pub const V1_KEX_X25519_DALEK: u8 = 0x10;
pub enum ServerReject {
NoReason,
UnsupportedKexAlgo,
UnsupportedCiphersuite,
InvalidParameter,
}
// v1 protocol
pub struct ServerHello<'a> {
// TODO server will send its public fingerprint and advertise
// supported asymmetric key formats
pub symmetric_ciphersuites: &'a [u8],
pub kex_algos: &'a [u8],
}
pub enum ClientNegotiationMessage<'a> {
Hello { protocol: u8 },
StartKex { kex_algo: u8, ciphersuite: u8 },
PublicKey(bool, &'a [u8]),
Agreed,
Pong,
CipherText(&'a [u8]),
Disconnect(u8),
}
pub enum ServerNegotiationMessage<'a> {
Hello(ServerHello<'a>),
Reject(ServerReject),
StartKex,
PublicKey(bool, &'a [u8]),
Agreed,
Kick(u8),
Ping,
CipherText(&'a [u8]),
}
pub struct ClientMessageProxy;
pub struct ServerMessageProxy;
impl MessageProxy for ClientMessageProxy {
type Type<'de> = ClientNegotiationMessage<'de>;
}
impl MessageProxy for ServerMessageProxy {
type Type<'de> = ServerNegotiationMessage<'de>;
}
impl ClientNegotiationMessage<'_> {
const TAG_HELLO: u8 = 0x90;
const TAG_START_KEX: u8 = 0x91;
const TAG_PUBLIC_KEY: u8 = 0x92;
const TAG_AGREED: u8 = 0x93;
const TAG_CIPHERTEXT: u8 = 0xA0;
const TAG_DISCONNECT: u8 = 0xA1;
const TAG_PONG: u8 = 0xA2;
}
impl Encode for ClientNegotiationMessage<'_> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
&Self::Hello { protocol } => buffer.write(&[Self::TAG_HELLO, protocol]),
&Self::StartKex {
kex_algo,
ciphersuite,
} => buffer.write(&[Self::TAG_START_KEX, kex_algo, ciphersuite]),
Self::PublicKey(end, data) => {
buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?;
buffer.write_variable_bytes(data)
}
Self::Agreed => buffer.write(&[Self::TAG_AGREED]),
Self::CipherText(data) => {
buffer.write(&[Self::TAG_CIPHERTEXT])?;
buffer.write_variable_bytes(data)
}
&Self::Disconnect(reason) => buffer.write(&[Self::TAG_DISCONNECT, reason]),
Self::Pong => buffer.write(&[Self::TAG_PONG]),
}
}
}
impl<'de> Decode<'de> for ClientNegotiationMessage<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_HELLO => {
let protocol = buffer.read_u8()?;
Ok(Self::Hello { protocol })
}
Self::TAG_START_KEX => {
let kex_algo = buffer.read_u8()?;
let ciphersuite = buffer.read_u8()?;
Ok(Self::StartKex {
kex_algo,
ciphersuite,
})
}
Self::TAG_PUBLIC_KEY => {
let end = buffer.read_u8()? != 0;
let data = buffer.read_variable_bytes()?;
Ok(Self::PublicKey(end, data))
}
Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText),
Self::TAG_AGREED => Ok(Self::Agreed),
Self::TAG_PONG => Ok(Self::Pong),
Self::TAG_DISCONNECT => buffer.read_u8().map(Self::Disconnect),
_ => Err(DecodeError::InvalidMessage),
}
}
}
impl ServerNegotiationMessage<'_> {
const TAG_HELLO: u8 = 0xB0;
const TAG_REJECT: u8 = 0xB1;
const TAG_START_KEX: u8 = 0xB2;
const TAG_PUBLIC_KEY: u8 = 0xB3;
const TAG_AGREED: u8 = 0xB4;
const TAG_CIPHERTEXT: u8 = 0xC0;
const TAG_PING: u8 = 0xC1;
const TAG_KICK: u8 = 0xC2;
}
impl Encode for ServerHello<'_> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
buffer.write_variable_bytes(self.kex_algos)?;
buffer.write_variable_bytes(self.symmetric_ciphersuites)?;
Ok(())
}
}
impl<'de> Decode<'de> for ServerHello<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let kex_algos = buffer.read_variable_bytes()?;
let symmetric_ciphersuites = buffer.read_variable_bytes()?;
Ok(Self {
kex_algos,
symmetric_ciphersuites,
})
}
}
impl Encode for ServerNegotiationMessage<'_> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::Hello(hello) => {
buffer.write(&[Self::TAG_HELLO])?;
hello.encode(buffer)
}
Self::Reject(_reason) => todo!(),
Self::StartKex => buffer.write(&[Self::TAG_START_KEX]),
Self::PublicKey(end, data) => {
buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?;
buffer.write_variable_bytes(data)
}
Self::Agreed => buffer.write(&[Self::TAG_AGREED]),
Self::CipherText(data) => {
buffer.write(&[Self::TAG_CIPHERTEXT])?;
buffer.write_variable_bytes(data)
}
&Self::Kick(reason) => buffer.write(&[Self::TAG_KICK, reason]),
&Self::Ping => buffer.write(&[Self::TAG_PING]),
}
}
}
impl<'de> Decode<'de> for ServerNegotiationMessage<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_HELLO => ServerHello::decode(buffer).map(Self::Hello),
Self::TAG_REJECT => todo!(),
Self::TAG_START_KEX => Ok(Self::StartKex),
Self::TAG_PUBLIC_KEY => {
let end = buffer.read_u8()? != 0;
let data = buffer.read_variable_bytes()?;
Ok(Self::PublicKey(end, data))
}
Self::TAG_AGREED => Ok(Self::Agreed),
Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText),
Self::TAG_KICK => buffer.read_u8().map(Self::Kick),
Self::TAG_PING => Ok(Self::Ping),
_ => Err(DecodeError::InvalidMessage),
}
}
}
fn decrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
if src.len() % 16 != 0 || dst.len() < src.len() {
todo!();
}
let mut pos = 0;
let mut out = 0;
let mut block = Block::from([0; 16]);
while pos != src.len() {
block.copy_from_slice(&src[pos..pos + 16]);
aes.decrypt_block(&mut block);
let len = block[0] as usize;
if len > 15 {
todo!();
}
dst[out..out + len].copy_from_slice(&block[1..1 + len]);
out += len;
pos += 16;
}
Ok(out)
}
fn encrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
if src.len() >= (dst.len() / 16) * 15 {
todo!();
}
let mut pos = 0;
let mut out = 0;
let mut block = Block::from([0; 16]);
while pos != src.len() {
let len = core::cmp::min(src.len() - pos, 15);
block[0] = len as u8;
block[1..1 + len].copy_from_slice(&src[pos..pos + len]);
// TODO pad with random
if len != 15 {
block[1 + len..].fill(0);
}
aes.encrypt_block(&mut block);
dst[out..out + 16].copy_from_slice(&block);
pos += len;
out += 16;
}
Ok(out)
}

View File

@ -0,0 +1,202 @@
use std::{
collections::HashMap,
mem,
net::SocketAddr,
os::fd::{AsRawFd, RawFd},
};
use aes::{cipher::KeyInit, Aes256};
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{
crypt::{decrypt_blocked, ServerHello, ServerNegotiationMessage},
socket::{
MessageSocket, MultiplexedSocket, MultiplexedSocketEvent, PacketSocket, SocketWrapper,
},
Error,
};
use super::{encrypt_blocked, ClientMessageProxy, ClientNegotiationMessage, ServerMessageProxy};
pub enum ServerPeerTransport {
PreNegotiation,
Negotiation(EphemeralSecret),
Connected(Aes256, usize),
}
pub struct ServerEncryptedSocket<S: PacketSocket> {
transport: SocketWrapper<S, ClientMessageProxy, ServerMessageProxy>,
peers: HashMap<SocketAddr, ServerPeerTransport>,
buffer: [u8; 256],
}
impl<S: PacketSocket> ServerEncryptedSocket<S> {
pub fn new(transport: S) -> Self {
Self {
transport: SocketWrapper::new(transport),
peers: HashMap::new(),
buffer: [0; 256],
}
}
pub fn remove_client(&mut self, remote: &SocketAddr) {
let mut buf = [0; 32];
if self.peers.remove(remote).is_some() {
self.transport
.send_to(remote, &mut buf, &ServerNegotiationMessage::Kick(0))
.ok();
}
}
pub fn ping_clients(&mut self, limit: usize) -> Vec<SocketAddr> {
let mut send_buf = [0; 32];
let mut removed = vec![];
for (remote, state) in self.peers.iter_mut() {
match state {
ServerPeerTransport::Connected(_, missed) => {
self.transport.send_to(remote, &mut send_buf, &ServerNegotiationMessage::Ping).ok();
if *missed >= limit {
removed.push(*remote);
}
*missed += 1;
}
_ => (),
}
}
for entry in removed.iter() {
self.peers.remove(entry);
}
removed
}
}
impl<S: PacketSocket> MultiplexedSocket for ServerEncryptedSocket<S> {
fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error> {
let mut buf = [0; 256];
if let Some(ServerPeerTransport::Connected(aes, _)) = self.peers.get(remote) {
let len = encrypt_blocked(aes, data, &mut self.buffer)?;
assert_eq!(len % 16, 0);
self.transport.send_to(
remote,
&mut buf,
&ServerNegotiationMessage::CipherText(&self.buffer[..len]),
)
} else {
Err(Error::NotConnected)
}
}
fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result<MultiplexedSocketEvent<'a>, Error> {
let (message, remote) = self.transport.recv_from(&mut self.buffer)?;
let mut buf = [0; 256];
if let Some(mut state) = self.peers.get_mut(&remote) {
match (message, &mut state) {
// TODO check kex params
(
ClientNegotiationMessage::StartKex { .. },
ServerPeerTransport::PreNegotiation,
) => {
let mut rng = rand::thread_rng();
let secret = EphemeralSecret::random_from_rng(&mut rng);
*state = ServerPeerTransport::Negotiation(secret);
self.transport.send_to(
&remote,
&mut buf,
&ServerNegotiationMessage::StartKex,
)?;
Ok(MultiplexedSocketEvent::None(remote))
}
(
ClientNegotiationMessage::PublicKey(true, data),
ServerPeerTransport::Negotiation(secret),
) if data.len() == 32 => {
let public = PublicKey::from(&*secret);
let mut remote_key = [0; 32];
remote_key.copy_from_slice(data);
let remote_key = PublicKey::from(remote_key);
state.negotiate(&remote_key);
// Send public key to the client
self.transport.send_to(
&remote,
&mut buf,
&ServerNegotiationMessage::PublicKey(true, public.as_bytes()),
)?;
Ok(MultiplexedSocketEvent::None(remote))
}
(ClientNegotiationMessage::Agreed, ServerPeerTransport::Connected(_, _)) => {
self.transport
.send_to(&remote, &mut buf, &ServerNegotiationMessage::Agreed)?;
Ok(MultiplexedSocketEvent::None(remote))
}
(
ClientNegotiationMessage::CipherText(data),
ServerPeerTransport::Connected(aes, _),
) => {
let len = decrypt_blocked(aes, data, buffer)?;
Ok(MultiplexedSocketEvent::ClientData(remote, &buffer[..len]))
}
(
ClientNegotiationMessage::Pong,
ServerPeerTransport::Connected(_, missed)
) => {
*missed = 0;
Ok(MultiplexedSocketEvent::None(remote))
}
(ClientNegotiationMessage::Disconnect(_reason), _) => {
eprintln!("Peer disconnected: {remote}");
self.peers.remove(&remote);
Ok(MultiplexedSocketEvent::ClientDisconnected(remote))
}
// Misbehavior
_ => {
self.peers.remove(&remote);
Ok(MultiplexedSocketEvent::ClientDisconnected(remote))
}
}
} else {
let ClientNegotiationMessage::Hello { protocol: 1 } = message else {
eprintln!("Unhandled client message");
return Ok(MultiplexedSocketEvent::None(remote));
};
eprintln!("Client Hello");
self.peers
.insert(remote, ServerPeerTransport::PreNegotiation);
// Reply with hello
let hello = ServerHello {
kex_algos: &[],
symmetric_ciphersuites: &[],
};
self.transport
.send_to(&remote, &mut buf, &ServerNegotiationMessage::Hello(hello))?;
Ok(MultiplexedSocketEvent::None(remote))
}
}
}
impl<S: PacketSocket> AsRawFd for ServerEncryptedSocket<S> {
fn as_raw_fd(&self) -> RawFd {
self.transport.as_raw_fd()
}
}
impl ServerPeerTransport {
fn negotiate(&mut self, public: &PublicKey) {
let Self::Negotiation(secret) = mem::replace(self, ServerPeerTransport::PreNegotiation)
else {
panic!();
};
let shared = secret.diffie_hellman(public);
let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap();
*self = ServerPeerTransport::Connected(aes, 0);
}
}

View File

@ -2,20 +2,15 @@
#![feature(generic_const_exprs)]
#![allow(incomplete_features)]
use std::{
io,
marker::PhantomData,
net::{SocketAddr, UdpSocket},
ops::{Deref, DerefMut},
os::fd::{AsRawFd, RawFd},
};
use std::io;
use proto::{
ClientMessageProxy, Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy,
ServerMessageProxy,
};
use proto::{DecodeError, EncodeError};
pub mod proto;
pub mod socket;
pub mod crypt;
pub use socket::{ClientSocket, ServerSocket};
#[derive(Debug, thiserror::Error)]
pub enum Error {
@ -27,106 +22,12 @@ pub enum Error {
Decode(#[from] DecodeError),
#[error("Encode error: {0}")]
Encode(#[from] EncodeError),
}
pub struct SocketWrapper<Rx: MessageProxy, Tx: MessageProxy> {
socket: UdpSocket,
_pd: PhantomData<(Rx, Tx)>,
}
pub struct ClientSocket(SocketWrapper<ServerMessageProxy, ClientMessageProxy>);
pub struct ServerSocket(SocketWrapper<ClientMessageProxy, ServerMessageProxy>);
impl<Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<Rx, Tx> {
pub fn new(socket: UdpSocket) -> Self {
Self {
socket,
_pd: PhantomData,
}
}
pub fn send(&self, buffer: &mut [u8], message: &Tx::Type<'_>) -> Result<(), Error> {
let mut enc = Encoder::new(buffer);
message.encode(&mut enc)?;
let message = enc.get();
let amount = self.socket.send(message)?;
if amount == message.len() {
Ok(())
} else {
Err(Error::Truncated)
}
}
pub fn send_to(
&mut self,
remote: &SocketAddr,
buffer: &mut [u8],
message: &Tx::Type<'_>,
) -> Result<(), Error> {
let mut enc = Encoder::new(buffer);
message.encode(&mut enc)?;
let message = enc.get();
let amount = self.socket.send_to(message, remote)?;
if amount == message.len() {
Ok(())
} else {
Err(Error::Truncated)
}
}
pub fn recv_from<'de>(
&mut self,
buffer: &'de mut [u8],
) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
let (len, remote) = self.socket.recv_from(buffer)?;
let mut dec = Decoder::new(&buffer[..len]);
let message = Rx::Type::<'de>::decode(&mut dec)?;
Ok((message, remote))
}
}
impl<Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<Rx, Tx> {
fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd()
}
}
impl ClientSocket {
pub fn new(socket: UdpSocket) -> Self {
Self(SocketWrapper::new(socket))
}
}
impl Deref for ClientSocket {
type Target = SocketWrapper<ServerMessageProxy, ClientMessageProxy>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for ClientSocket {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl ServerSocket {
pub fn new(socket: UdpSocket) -> Self {
Self(SocketWrapper::new(socket))
}
}
impl Deref for ServerSocket {
type Target = SocketWrapper<ClientMessageProxy, ServerMessageProxy>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for ServerSocket {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
#[error("Peer not connected")]
NotConnected,
#[error("Ping")]
Ping,
#[error("Invalid socket state")]
InvalidState,
#[error("Disconnected by remote peer")]
Disconnected,
}

View File

@ -12,8 +12,7 @@ use clap::Parser;
use cross::io::Poll;
use libterm::{RawMode, RawTerminal};
use rsh::{
proto::{ClientMessage, ServerMessage, TerminalInfo},
ClientSocket,
crypt::ClientEncryptedSocket, proto::{ClientMessage, ServerMessage, TerminalInfo}, socket::{MessageSocket, PacketSocket}, ClientSocket
};
pub const PING_TIMEOUT: Duration = Duration::from_secs(3);
@ -39,11 +38,10 @@ struct Args {
pub struct Client {
poll: Poll,
socket: ClientSocket,
socket: ClientSocket<ClientEncryptedSocket<UdpSocket>>,
stdin: Stdin,
stdout: Stdout,
need_bye: bool,
last_ping: Instant,
_raw: RawMode,
}
@ -51,7 +49,6 @@ pub enum Event<'b> {
Stdin(&'b [u8]),
Data(&'b [u8]),
Disconnected(&'b str),
Ping,
}
impl Client {
@ -69,6 +66,8 @@ impl Client {
let socket = UdpSocket::bind(local)?;
socket.connect(remote)?;
let mut socket = ClientEncryptedSocket::new(socket);
socket.try_connect_blocking()?;
let mut socket = ClientSocket::new(socket);
poll.add(&*socket)?;
@ -82,7 +81,6 @@ impl Client {
let _raw = unsafe { RawMode::enter(&stdin)? };
Ok(Self {
last_ping: Instant::now(),
poll,
socket,
stdout,
@ -102,7 +100,11 @@ impl Client {
match event {
fd if fd == self.socket.as_raw_fd() => {
let (message, _) = self.socket.recv_from(buffer)?;
let message = match self.socket.recv_from(buffer) {
Ok((message, _)) => message,
Err(rsh::Error::Ping) => return Ok(None),
Err(error) => return Err(error.into())
};
match message {
ServerMessage::Bye(reason) => {
// No need for a bye
@ -112,9 +114,6 @@ impl Client {
ServerMessage::Output(data) => {
break Ok(Some(Event::Data(data)));
}
ServerMessage::Ping => {
break Ok(Some(Event::Ping));
}
// Ignore this one
ServerMessage::Hello => break Ok(None),
}
@ -148,10 +147,6 @@ impl Client {
Event::Stdin(data) => {
self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
}
Event::Ping => {
self.last_ping = Instant::now();
self.socket.send(&mut send_buf, &ClientMessage::Pong)?;
}
Event::Disconnected(reason) => {
break Ok(reason.into());
}
@ -161,16 +156,16 @@ impl Client {
fn check_timeout(&mut self) -> Result<(), Error> {
let now = Instant::now();
if now - self.last_ping >= PING_TIMEOUT {
if now - self.socket.as_inner_mut().last_ping() >= PING_TIMEOUT {
Err(Error::Timeout)
} else {
Ok(())
}
}
fn handshake(
fn handshake<S: PacketSocket>(
poll: &mut Poll,
socket: &mut ClientSocket,
socket: &mut ClientSocket<S>,
terminal: TerminalInfo,
attempts: usize,
) -> Result<(), Error> {
@ -187,9 +182,9 @@ impl Client {
Self::try_handshake(poll, socket, terminal, timeout)
}
fn try_handshake(
fn try_handshake<S: PacketSocket>(
poll: &mut Poll,
socket: &mut ClientSocket,
socket: &mut ClientSocket<S>,
terminal: TerminalInfo,
timeout: Duration,
) -> Result<(), Error> {

View File

@ -11,7 +11,6 @@ pub struct ServerMessageProxy;
pub enum ClientMessage<'a> {
Hello(TerminalInfo),
Bye(&'a str),
Pong,
Input(&'a [u8]),
}
@ -19,7 +18,6 @@ pub enum ClientMessage<'a> {
pub enum ServerMessage<'a> {
Hello,
Bye(&'a str),
Ping,
Output(&'a [u8]),
}
@ -68,6 +66,10 @@ impl<'a> Encoder<'a> {
Self { buffer, pos: 0 }
}
pub fn reset(&mut self) {
self.pos = 0;
}
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
let len: u32 = bytes
.len()
@ -145,7 +147,6 @@ impl MessageProxy for ServerMessageProxy {
impl ClientMessage<'_> {
const TAG_HELLO: u8 = 0x80;
const TAG_BYE: u8 = 0x81;
const TAG_PONG: u8 = 0x82;
const TAG_INPUT: u8 = 0x90;
}
@ -176,9 +177,6 @@ impl<'a> Encode for ClientMessage<'a> {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
}
Self::Pong => {
buffer.write(&[Self::TAG_PONG])
}
Self::Input(data) => {
buffer.write(&[Self::TAG_INPUT])?;
buffer.write_variable_bytes(data)
@ -197,9 +195,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
}
Self::TAG_PONG => {
Ok(Self::Pong)
}
Self::TAG_INPUT => {
buffer.read_variable_bytes().map(Self::Input)
}
@ -211,7 +206,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
impl ServerMessage<'_> {
const TAG_HELLO: u8 = 0x10;
const TAG_BYE: u8 = 0x11;
const TAG_PING: u8 = 0x12;
const TAG_OUTPUT: u8 = 0x20;
}
@ -223,8 +217,6 @@ impl<'a> Encode for ServerMessage<'a> {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
}
// TODO sequence number
Self::Ping => buffer.write(&[Self::TAG_PING]),
Self::Output(data) => {
buffer.write(&[Self::TAG_OUTPUT])?;
buffer.write_variable_bytes(data)
@ -241,7 +233,6 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
}
Self::TAG_PING => Ok(Self::Ping),
Self::TAG_OUTPUT => {
buffer.read_variable_bytes().map(Self::Output)
},

View File

@ -13,8 +13,7 @@ use std::{
use cross::io::{Poll, TimerFd};
use rsh::{
proto::{ClientMessage, ServerMessage, TerminalInfo},
Error, ServerSocket,
crypt::ServerEncryptedSocket, proto::{ClientMessage, Decode, Decoder, Encoder, ServerMessage, TerminalInfo}, socket::{MessageSocket, MultiplexedSocket, MultiplexedSocketEvent}, Error, ServerSocket
};
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
@ -29,7 +28,7 @@ pub struct Session {
pub struct Server {
poll: Poll,
timer: TimerFd,
socket: ServerSocket,
socket: ServerEncryptedSocket<UdpSocket>,
addr_to_session: HashMap<SocketAddr, RawFd>,
pty_to_session: HashMap<RawFd, Session>,
@ -98,8 +97,8 @@ impl Server {
pub fn new(listen_addr: SocketAddr) -> Result<Self, Error> {
let mut poll = Poll::new()?;
let timer = TimerFd::new()?;
let socket = UdpSocket::bind(listen_addr).map(ServerSocket::new)?;
poll.add(&*socket)?;
let socket = UdpSocket::bind(listen_addr).map(ServerEncryptedSocket::new)?;
poll.add(&socket)?;
poll.add(&timer)?;
Ok(Self {
poll,
@ -115,13 +114,27 @@ impl Server {
match fd {
fd if fd == self.socket.as_raw_fd() => {
let (message, remote) = match self.socket.recv_from(buffer) {
Ok((message, remote)) => (message, remote),
Err(error @ (Error::Decode(_) | Error::Truncated)) => {
eprintln!("Receive error: {error}");
let event = self.socket.recv_from(buffer)?;
let (message, remote) = match event {
MultiplexedSocketEvent::ClientDisconnected(remote) => {
self.remove_session_by_remote(remote).ok();
return Ok(None)
},
Err(error) => return Err(error),
}
MultiplexedSocketEvent::None(_) => return Ok(None),
MultiplexedSocketEvent::ClientData(peer, data) => {
let mut decoder = Decoder::new(data);
let message = ClientMessage::decode(&mut decoder);
(message, peer)
}
};
let message = match message {
Ok(message) => message,
Err(error) => {
eprintln!("Decode error: {error}");
return Ok(None);
}
};
let event = match message {
@ -136,11 +149,6 @@ impl Server {
{
Event::SessionInput(*fd, remote, data)
}
ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => {
let session = self.pty_to_session.get_mut(fd).unwrap();
session.timeouts = 0;
return Ok(None);
}
_ => return Ok(None),
};
@ -176,7 +184,7 @@ impl Server {
eprintln!("PTY write error: {error}");
self.remove_session_by_fd(fd)?;
self.socket
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
.ok();
}
}
@ -190,13 +198,13 @@ impl Server {
Ok(session) => {
self.register_session(remote, session)?;
self.socket
.send_to(&remote, &mut send_buf, &ServerMessage::Hello)
.send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
.ok();
}
Err(err) => {
eprintln!("PTY open error: {err}");
self.socket
.send_to(
.send_message_to(
&remote,
&mut send_buf,
&ServerMessage::Bye("PTY open error"),
@ -208,48 +216,37 @@ impl Server {
Event::Pty(fd, remote, event) => match event {
PtyEvent::Data(data) => {
self.socket
.send_to(&remote, &mut send_buf, &ServerMessage::Output(data))
.send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data))
.ok();
},
PtyEvent::Err(error) => {
eprintln!("PTY read error: {error}");
self.remove_session_by_fd(fd)?;
self.socket
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
.ok();
},
PtyEvent::Closed => {
println!("End of PTY for {remote}");
self.remove_session_by_fd(fd)?;
self.socket
.send_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
.ok();
},
},
Event::Tick => {
// Restart the timer
self.update_client_timeouts(&mut send_buf)?;
self.update_client_timeouts()?;
self.timer.start(PING_INTERVAL)?;
}
}
}
}
fn update_client_timeouts(&mut self, send_buf: &mut [u8]) -> Result<(), Error> {
let mut removed = vec![];
for (remote, fd) in self.addr_to_session.iter() {
let session = self.pty_to_session.get_mut(&fd).unwrap();
self.socket.send_to(remote, send_buf, &ServerMessage::Ping).ok();
session.timeouts += 1;
if session.timeouts >= 10 {
removed.push((*remote, *fd));
}
}
for (remote, fd) in removed {
eprintln!("Client {remote} timed out");
self.remove_session_by_fd(fd)?;
self.socket.send_to(&remote, send_buf, &ServerMessage::Bye("Timed out".into())).ok();
fn update_client_timeouts(&mut self) -> Result<(), Error> {
let removed = self.socket.ping_clients(8);
for entry in removed {
self.remove_session_by_remote(entry).ok();
}
Ok(())
}
@ -268,6 +265,7 @@ impl Server {
// NOTE: this will block the whole server while the process finishes.
session.shell.wait().ok();
self.addr_to_session.remove(&session.remote).unwrap();
self.socket.remove_client(&session.remote);
self.poll.remove(&fd)?;
Ok(Some(session))
} else {

202
userspace/rsh/src/socket.rs Normal file
View File

@ -0,0 +1,202 @@
use std::{
io,
marker::PhantomData,
net::{SocketAddr, UdpSocket},
ops::{Deref, DerefMut},
os::fd::{AsRawFd, RawFd},
};
use crate::{
proto::{
ClientMessageProxy, Decode, Decoder, Encode, Encoder, MessageProxy, ServerMessageProxy,
},
Error,
};
pub struct SocketWrapper<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> {
socket: S,
_pd: PhantomData<(Rx, Tx)>,
}
pub trait MessageSocket<Rx: MessageProxy, Tx: MessageProxy>: AsRawFd {
fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error>;
fn send_to(
&mut self,
addr: &SocketAddr,
buffer: &mut [u8],
data: &Tx::Type<'_>,
) -> Result<(), Error>;
fn recv_from<'de>(
&mut self,
buffer: &'de mut [u8],
) -> Result<(Rx::Type<'de>, SocketAddr), Error>;
}
pub trait PacketSocket: AsRawFd {
type Error: Into<Error>;
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error>;
fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error>;
fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error>;
}
pub enum MultiplexedSocketEvent<'a> {
ClientData(SocketAddr, &'a [u8]),
ClientDisconnected(SocketAddr),
None(SocketAddr),
}
pub trait MultiplexedSocket: AsRawFd {
fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result<MultiplexedSocketEvent<'a>, Error>;
fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error>;
fn send_message_to<Tx: Encode>(
&mut self,
remote: &SocketAddr,
buffer: &mut [u8],
data: &Tx,
) -> Result<(), Error> {
let mut encoder = Encoder::new(buffer);
data.encode(&mut encoder)?;
self.send_to(remote, encoder.get())
}
}
impl PacketSocket for UdpSocket {
type Error = io::Error;
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error> {
UdpSocket::send(self, data)
}
fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error> {
UdpSocket::send_to(self, data, addr)
}
fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
UdpSocket::recv_from(self, data)
}
}
pub struct ClientSocket<S: PacketSocket>(SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>);
pub struct ServerSocket<S: MultiplexedSocket>(S);
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> MessageSocket<Rx, Tx>
for SocketWrapper<S, Rx, Tx>
{
fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error> {
let mut enc = Encoder::new(buffer);
data.encode(&mut enc)?;
let message = enc.get();
let amount = self.socket.send(message).map_err(S::Error::into)?;
if amount == message.len() {
Ok(())
} else {
Err(Error::Truncated)
}
}
fn send_to(
&mut self,
addr: &SocketAddr,
buffer: &mut [u8],
data: &Tx::Type<'_>,
) -> Result<(), Error> {
let mut enc = Encoder::new(buffer);
data.encode(&mut enc)?;
let message = enc.get();
let amount = self.socket.send_to(message, addr).map_err(S::Error::into)?;
if amount == message.len() {
Ok(())
} else {
Err(Error::Truncated)
}
}
fn recv_from<'de>(
&mut self,
buffer: &'de mut [u8],
) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
let (len, remote) = self.socket.recv_from(buffer).map_err(S::Error::into)?;
let mut dec = Decoder::new(&buffer[..len]);
let message = Rx::Type::<'de>::decode(&mut dec)?;
Ok((message, remote))
}
}
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<S, Rx, Tx> {
pub fn new(socket: S) -> Self {
Self {
socket,
_pd: PhantomData,
}
}
pub fn into_inner(self) -> S {
self.socket
}
pub fn as_inner_mut(&mut self) -> &mut S {
&mut self.socket
}
}
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<S, Rx, Tx> {
fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd()
}
}
impl<S: PacketSocket> ClientSocket<S> {
pub fn new(socket: S) -> Self {
Self(SocketWrapper::new(socket))
}
}
impl<S: PacketSocket> Deref for ClientSocket<S> {
type Target = SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<S: PacketSocket> DerefMut for ClientSocket<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<S: MultiplexedSocket> ServerSocket<S> {
pub fn new(socket: S) -> Self {
Self(socket)
}
}
impl<S: MultiplexedSocket> Deref for ServerSocket<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<S: MultiplexedSocket> DerefMut for ServerSocket<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
// impl<S: PacketSocket> Deref for ServerSocket<S> {
// type Target = SocketWrapper<S, ClientMessageProxy, ServerMessageProxy>;
//
// fn deref(&self) -> &Self::Target {
// &self.0
// }
// }
//
// impl<S: PacketSocket> DerefMut for ServerSocket<S> {
// fn deref_mut(&mut self) -> &mut Self::Target {
// &mut self.0
// }
// }