rsh: implement dumb kex+aes256
This commit is contained in:
parent
bcf1e74a04
commit
99c1dd51ae
58
userspace/Cargo.lock
generated
58
userspace/Cargo.lock
generated
@ -16,6 +16,17 @@ dependencies = [
|
||||
name = "abi-lib"
|
||||
version = "0.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "aes"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.9"
|
||||
@ -102,6 +113,16 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"inout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.20"
|
||||
@ -522,6 +543,15 @@ dependencies = [
|
||||
"yggdrasil-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inout"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.11"
|
||||
@ -1011,11 +1041,14 @@ dependencies = [
|
||||
name = "rsh"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"bytemuck",
|
||||
"clap",
|
||||
"cross",
|
||||
"libterm",
|
||||
"rand 0.8.5",
|
||||
"thiserror",
|
||||
"x25519-dalek",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1633,6 +1666,17 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x25519-dalek"
|
||||
version = "2.0.1"
|
||||
source = "git+https://git.alnyan.me/yggdrasil/curve25519-dalek.git?branch=alnyan%2Fyggdrasil#5f4dbb09259347077d3a5024ba60c77efde93a3a"
|
||||
dependencies = [
|
||||
"curve25519-dalek",
|
||||
"rand_core 0.6.4 (git+https://git.alnyan.me/yggdrasil/rand.git?branch=alnyan%2Fyggdrasil-rng_core-0.6.4)",
|
||||
"serde",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yasync"
|
||||
version = "0.1.0"
|
||||
@ -1710,3 +1754,17 @@ name = "zeroize"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
|
||||
dependencies = [
|
||||
"zeroize_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize_derive"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.86",
|
||||
]
|
||||
|
@ -8,8 +8,12 @@ name = "rshd"
|
||||
path = "src/rshd/main.rs"
|
||||
|
||||
[dependencies]
|
||||
clap.workspace = true
|
||||
libterm.workspace = true
|
||||
thiserror.workspace = true
|
||||
cross.workspace = true
|
||||
|
||||
clap.workspace = true
|
||||
thiserror.workspace = true
|
||||
bytemuck.workspace = true
|
||||
x25519-dalek.workspace = true
|
||||
rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" }
|
||||
aes = { version = "0.8.4" }
|
||||
|
206
userspace/rsh/src/crypt/client.rs
Normal file
206
userspace/rsh/src/crypt/client.rs
Normal file
@ -0,0 +1,206 @@
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
os::fd::{AsRawFd, RawFd}, time::Instant,
|
||||
};
|
||||
|
||||
use aes::{cipher::KeyInit, Aes256};
|
||||
use rand::prelude::Distribution;
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
use crate::{
|
||||
crypt::{
|
||||
decrypt_blocked, encrypt_blocked, ClientNegotiationMessage, ServerNegotiationMessage,
|
||||
V1_CIPHER_AES_256_CBC, V1_KEX_X25519_DALEK,
|
||||
},
|
||||
socket::{MessageSocket, PacketSocket, SocketWrapper},
|
||||
Error,
|
||||
};
|
||||
|
||||
use super::{ClientMessageProxy, ServerMessageProxy};
|
||||
|
||||
enum ClientState {
|
||||
PreNegotioation,
|
||||
Connected(Aes256),
|
||||
}
|
||||
|
||||
pub struct ClientEncryptedSocket<S: PacketSocket> {
|
||||
transport: SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>,
|
||||
state: ClientState,
|
||||
last_ping: Instant
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> ClientEncryptedSocket<S> {
|
||||
pub fn new(transport: S) -> Self {
|
||||
Self {
|
||||
transport: SocketWrapper::new(transport),
|
||||
state: ClientState::PreNegotioation,
|
||||
last_ping: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn last_ping(&self) -> Instant {
|
||||
self.last_ping
|
||||
}
|
||||
|
||||
pub fn try_connect_blocking(&mut self) -> Result<(), Error> {
|
||||
let ClientState::PreNegotioation = self.state else {
|
||||
return Err(Error::InvalidState);
|
||||
};
|
||||
let mut buf = [0; 256];
|
||||
|
||||
// Send Hello
|
||||
self.transport
|
||||
.send(&mut buf, &ClientNegotiationMessage::Hello { protocol: 1 })?;
|
||||
|
||||
// Wait for Server Hello
|
||||
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||
let _hello = match message {
|
||||
ServerNegotiationMessage::Hello(hello) => hello,
|
||||
ServerNegotiationMessage::Reject(_reason) => {
|
||||
return Err(Error::Disconnected);
|
||||
},
|
||||
_ => {
|
||||
return Err(Error::Disconnected);
|
||||
},
|
||||
};
|
||||
|
||||
// TODO select kex, ciphersuite
|
||||
let ciphersuite = V1_CIPHER_AES_256_CBC;
|
||||
let kex_algo = V1_KEX_X25519_DALEK;
|
||||
|
||||
// Initiate key exchange
|
||||
self.transport.send(
|
||||
&mut buf,
|
||||
&ClientNegotiationMessage::StartKex {
|
||||
kex_algo,
|
||||
ciphersuite,
|
||||
},
|
||||
)?;
|
||||
|
||||
// Wait for server to accept
|
||||
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||
match message {
|
||||
ServerNegotiationMessage::StartKex => (),
|
||||
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
|
||||
_ => return Err(Error::Disconnected),
|
||||
};
|
||||
|
||||
// Generate an ephemeral key
|
||||
let mut rng = rand::thread_rng();
|
||||
let secret = EphemeralSecret::random_from_rng(&mut rng);
|
||||
let public = PublicKey::from(&secret).to_bytes();
|
||||
|
||||
// Send it to the server
|
||||
self.transport.send(
|
||||
&mut buf,
|
||||
&ClientNegotiationMessage::PublicKey(true, &public),
|
||||
)?;
|
||||
|
||||
// Wait for server to respond with its own key
|
||||
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||
let mut remote = [0; 32];
|
||||
match message {
|
||||
ServerNegotiationMessage::PublicKey(true, key) if key.len() == 32 => {
|
||||
remote.copy_from_slice(key)
|
||||
}
|
||||
ServerNegotiationMessage::PublicKey(_, _key) => return Err(Error::Disconnected),
|
||||
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
|
||||
_ => return Err(Error::Disconnected),
|
||||
};
|
||||
let remote = PublicKey::from(remote);
|
||||
|
||||
let shared = secret.diffie_hellman(&remote);
|
||||
|
||||
// Negotiation done
|
||||
self.transport
|
||||
.send(&mut buf, &ClientNegotiationMessage::Agreed)?;
|
||||
|
||||
// Wait for server's Agreed
|
||||
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||
match message {
|
||||
ServerNegotiationMessage::Agreed => (),
|
||||
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
|
||||
_ => return Err(Error::Disconnected),
|
||||
}
|
||||
|
||||
let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap();
|
||||
|
||||
self.state = ClientState::Connected(aes);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> PacketSocket for ClientEncryptedSocket<S> {
|
||||
type Error = Error;
|
||||
|
||||
fn recv_from(&mut self, output: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
|
||||
match &self.state {
|
||||
ClientState::Connected(aes) => {
|
||||
assert!(output.len() >= 16);
|
||||
loop {
|
||||
let mut buffer = [0; 256];
|
||||
let (message, peer) = self.transport.recv_from(&mut buffer)?;
|
||||
match message {
|
||||
ServerNegotiationMessage::CipherText(data) => {
|
||||
let len = decrypt_blocked(aes, data, output)?;
|
||||
break Ok((len, peer))
|
||||
}
|
||||
ServerNegotiationMessage::Ping => {
|
||||
self.transport.send(&mut buffer, &ClientNegotiationMessage::Pong)?;
|
||||
self.last_ping = Instant::now();
|
||||
// TODO this is a workaround
|
||||
break Err(Error::Ping);
|
||||
}
|
||||
ServerNegotiationMessage::Kick(_reason) => {
|
||||
break Err(Error::NotConnected)
|
||||
},
|
||||
// TODO send disconnect and leave
|
||||
_ => break Err(Error::NotConnected),
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => Err(Error::NotConnected),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_to(&mut self, data: &[u8], _addr: &SocketAddr) -> Result<usize, Self::Error> {
|
||||
self.send(data)
|
||||
}
|
||||
|
||||
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error> {
|
||||
let mut buffer = [0; 256];
|
||||
let mut send_buffer = [0; 256];
|
||||
match &self.state {
|
||||
ClientState::Connected(aes) => {
|
||||
let len = encrypt_blocked(aes, data, &mut buffer)?;
|
||||
assert_eq!(len % 16, 0);
|
||||
self.transport.send(
|
||||
&mut send_buffer,
|
||||
&ClientNegotiationMessage::CipherText(&buffer[..len]),
|
||||
)?;
|
||||
Ok(data.len())
|
||||
}
|
||||
_ => Err(Error::NotConnected),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> Drop for ClientEncryptedSocket<S> {
|
||||
fn drop(&mut self) {
|
||||
match self.state {
|
||||
ClientState::PreNegotioation => (),
|
||||
_ => {
|
||||
let mut buffer = [0; 32];
|
||||
self.transport
|
||||
.send(&mut buffer, &ClientNegotiationMessage::Disconnect(0))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> AsRawFd for ClientEncryptedSocket<S> {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.transport.as_raw_fd()
|
||||
}
|
||||
}
|
262
userspace/rsh/src/crypt/mod.rs
Normal file
262
userspace/rsh/src/crypt/mod.rs
Normal file
@ -0,0 +1,262 @@
|
||||
use aes::{
|
||||
cipher::{BlockDecrypt, BlockEncrypt},
|
||||
Aes256, Block,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy},
|
||||
Error,
|
||||
};
|
||||
|
||||
pub mod client;
|
||||
pub mod server;
|
||||
|
||||
pub use client::ClientEncryptedSocket;
|
||||
pub use server::ServerEncryptedSocket;
|
||||
|
||||
pub const V1_CIPHER_AES_256_CBC: u8 = 0x10;
|
||||
pub const V1_CIPHER_AES_256_CTR: u8 = 0x11;
|
||||
pub const V1_CIPHER_AES_256_GCM: u8 = 0x12;
|
||||
pub const V1_CIPHER_CHACHA20_POLY1305: u8 = 0x15;
|
||||
|
||||
// v1 supports only one DH algo
|
||||
pub const V1_KEX_X25519_DALEK: u8 = 0x10;
|
||||
|
||||
pub enum ServerReject {
|
||||
NoReason,
|
||||
UnsupportedKexAlgo,
|
||||
UnsupportedCiphersuite,
|
||||
InvalidParameter,
|
||||
}
|
||||
|
||||
// v1 protocol
|
||||
pub struct ServerHello<'a> {
|
||||
// TODO server will send its public fingerprint and advertise
|
||||
// supported asymmetric key formats
|
||||
pub symmetric_ciphersuites: &'a [u8],
|
||||
pub kex_algos: &'a [u8],
|
||||
}
|
||||
|
||||
pub enum ClientNegotiationMessage<'a> {
|
||||
Hello { protocol: u8 },
|
||||
StartKex { kex_algo: u8, ciphersuite: u8 },
|
||||
PublicKey(bool, &'a [u8]),
|
||||
Agreed,
|
||||
Pong,
|
||||
CipherText(&'a [u8]),
|
||||
Disconnect(u8),
|
||||
}
|
||||
|
||||
pub enum ServerNegotiationMessage<'a> {
|
||||
Hello(ServerHello<'a>),
|
||||
Reject(ServerReject),
|
||||
StartKex,
|
||||
PublicKey(bool, &'a [u8]),
|
||||
Agreed,
|
||||
Kick(u8),
|
||||
Ping,
|
||||
CipherText(&'a [u8]),
|
||||
}
|
||||
|
||||
pub struct ClientMessageProxy;
|
||||
pub struct ServerMessageProxy;
|
||||
|
||||
impl MessageProxy for ClientMessageProxy {
|
||||
type Type<'de> = ClientNegotiationMessage<'de>;
|
||||
}
|
||||
impl MessageProxy for ServerMessageProxy {
|
||||
type Type<'de> = ServerNegotiationMessage<'de>;
|
||||
}
|
||||
|
||||
impl ClientNegotiationMessage<'_> {
|
||||
const TAG_HELLO: u8 = 0x90;
|
||||
const TAG_START_KEX: u8 = 0x91;
|
||||
const TAG_PUBLIC_KEY: u8 = 0x92;
|
||||
const TAG_AGREED: u8 = 0x93;
|
||||
|
||||
const TAG_CIPHERTEXT: u8 = 0xA0;
|
||||
const TAG_DISCONNECT: u8 = 0xA1;
|
||||
const TAG_PONG: u8 = 0xA2;
|
||||
}
|
||||
|
||||
impl Encode for ClientNegotiationMessage<'_> {
|
||||
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||
match self {
|
||||
&Self::Hello { protocol } => buffer.write(&[Self::TAG_HELLO, protocol]),
|
||||
&Self::StartKex {
|
||||
kex_algo,
|
||||
ciphersuite,
|
||||
} => buffer.write(&[Self::TAG_START_KEX, kex_algo, ciphersuite]),
|
||||
Self::PublicKey(end, data) => {
|
||||
buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?;
|
||||
buffer.write_variable_bytes(data)
|
||||
}
|
||||
Self::Agreed => buffer.write(&[Self::TAG_AGREED]),
|
||||
Self::CipherText(data) => {
|
||||
buffer.write(&[Self::TAG_CIPHERTEXT])?;
|
||||
buffer.write_variable_bytes(data)
|
||||
}
|
||||
&Self::Disconnect(reason) => buffer.write(&[Self::TAG_DISCONNECT, reason]),
|
||||
Self::Pong => buffer.write(&[Self::TAG_PONG]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de> for ClientNegotiationMessage<'de> {
|
||||
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||
let tag = buffer.read_u8()?;
|
||||
match tag {
|
||||
Self::TAG_HELLO => {
|
||||
let protocol = buffer.read_u8()?;
|
||||
Ok(Self::Hello { protocol })
|
||||
}
|
||||
Self::TAG_START_KEX => {
|
||||
let kex_algo = buffer.read_u8()?;
|
||||
let ciphersuite = buffer.read_u8()?;
|
||||
Ok(Self::StartKex {
|
||||
kex_algo,
|
||||
ciphersuite,
|
||||
})
|
||||
}
|
||||
Self::TAG_PUBLIC_KEY => {
|
||||
let end = buffer.read_u8()? != 0;
|
||||
let data = buffer.read_variable_bytes()?;
|
||||
Ok(Self::PublicKey(end, data))
|
||||
}
|
||||
Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText),
|
||||
Self::TAG_AGREED => Ok(Self::Agreed),
|
||||
Self::TAG_PONG => Ok(Self::Pong),
|
||||
Self::TAG_DISCONNECT => buffer.read_u8().map(Self::Disconnect),
|
||||
_ => Err(DecodeError::InvalidMessage),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerNegotiationMessage<'_> {
|
||||
const TAG_HELLO: u8 = 0xB0;
|
||||
const TAG_REJECT: u8 = 0xB1;
|
||||
const TAG_START_KEX: u8 = 0xB2;
|
||||
const TAG_PUBLIC_KEY: u8 = 0xB3;
|
||||
const TAG_AGREED: u8 = 0xB4;
|
||||
|
||||
const TAG_CIPHERTEXT: u8 = 0xC0;
|
||||
const TAG_PING: u8 = 0xC1;
|
||||
const TAG_KICK: u8 = 0xC2;
|
||||
}
|
||||
|
||||
impl Encode for ServerHello<'_> {
|
||||
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||
buffer.write_variable_bytes(self.kex_algos)?;
|
||||
buffer.write_variable_bytes(self.symmetric_ciphersuites)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de> for ServerHello<'de> {
|
||||
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||
let kex_algos = buffer.read_variable_bytes()?;
|
||||
let symmetric_ciphersuites = buffer.read_variable_bytes()?;
|
||||
Ok(Self {
|
||||
kex_algos,
|
||||
symmetric_ciphersuites,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for ServerNegotiationMessage<'_> {
|
||||
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||
match self {
|
||||
Self::Hello(hello) => {
|
||||
buffer.write(&[Self::TAG_HELLO])?;
|
||||
hello.encode(buffer)
|
||||
}
|
||||
Self::Reject(_reason) => todo!(),
|
||||
Self::StartKex => buffer.write(&[Self::TAG_START_KEX]),
|
||||
Self::PublicKey(end, data) => {
|
||||
buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?;
|
||||
buffer.write_variable_bytes(data)
|
||||
}
|
||||
Self::Agreed => buffer.write(&[Self::TAG_AGREED]),
|
||||
Self::CipherText(data) => {
|
||||
buffer.write(&[Self::TAG_CIPHERTEXT])?;
|
||||
buffer.write_variable_bytes(data)
|
||||
}
|
||||
&Self::Kick(reason) => buffer.write(&[Self::TAG_KICK, reason]),
|
||||
&Self::Ping => buffer.write(&[Self::TAG_PING]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de> for ServerNegotiationMessage<'de> {
|
||||
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||
let tag = buffer.read_u8()?;
|
||||
match tag {
|
||||
Self::TAG_HELLO => ServerHello::decode(buffer).map(Self::Hello),
|
||||
Self::TAG_REJECT => todo!(),
|
||||
Self::TAG_START_KEX => Ok(Self::StartKex),
|
||||
Self::TAG_PUBLIC_KEY => {
|
||||
let end = buffer.read_u8()? != 0;
|
||||
let data = buffer.read_variable_bytes()?;
|
||||
Ok(Self::PublicKey(end, data))
|
||||
}
|
||||
Self::TAG_AGREED => Ok(Self::Agreed),
|
||||
Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText),
|
||||
Self::TAG_KICK => buffer.read_u8().map(Self::Kick),
|
||||
Self::TAG_PING => Ok(Self::Ping),
|
||||
|
||||
_ => Err(DecodeError::InvalidMessage),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
|
||||
if src.len() % 16 != 0 || dst.len() < src.len() {
|
||||
todo!();
|
||||
}
|
||||
|
||||
let mut pos = 0;
|
||||
let mut out = 0;
|
||||
let mut block = Block::from([0; 16]);
|
||||
while pos != src.len() {
|
||||
block.copy_from_slice(&src[pos..pos + 16]);
|
||||
aes.decrypt_block(&mut block);
|
||||
|
||||
let len = block[0] as usize;
|
||||
if len > 15 {
|
||||
todo!();
|
||||
}
|
||||
|
||||
dst[out..out + len].copy_from_slice(&block[1..1 + len]);
|
||||
|
||||
out += len;
|
||||
pos += 16;
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn encrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
|
||||
if src.len() >= (dst.len() / 16) * 15 {
|
||||
todo!();
|
||||
}
|
||||
|
||||
let mut pos = 0;
|
||||
let mut out = 0;
|
||||
let mut block = Block::from([0; 16]);
|
||||
while pos != src.len() {
|
||||
let len = core::cmp::min(src.len() - pos, 15);
|
||||
block[0] = len as u8;
|
||||
block[1..1 + len].copy_from_slice(&src[pos..pos + len]);
|
||||
// TODO pad with random
|
||||
if len != 15 {
|
||||
block[1 + len..].fill(0);
|
||||
}
|
||||
aes.encrypt_block(&mut block);
|
||||
|
||||
dst[out..out + 16].copy_from_slice(&block);
|
||||
|
||||
pos += len;
|
||||
out += 16;
|
||||
}
|
||||
Ok(out)
|
||||
}
|
202
userspace/rsh/src/crypt/server.rs
Normal file
202
userspace/rsh/src/crypt/server.rs
Normal file
@ -0,0 +1,202 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
mem,
|
||||
net::SocketAddr,
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
};
|
||||
|
||||
use aes::{cipher::KeyInit, Aes256};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
use crate::{
|
||||
crypt::{decrypt_blocked, ServerHello, ServerNegotiationMessage},
|
||||
socket::{
|
||||
MessageSocket, MultiplexedSocket, MultiplexedSocketEvent, PacketSocket, SocketWrapper,
|
||||
},
|
||||
Error,
|
||||
};
|
||||
|
||||
use super::{encrypt_blocked, ClientMessageProxy, ClientNegotiationMessage, ServerMessageProxy};
|
||||
|
||||
pub enum ServerPeerTransport {
|
||||
PreNegotiation,
|
||||
Negotiation(EphemeralSecret),
|
||||
Connected(Aes256, usize),
|
||||
}
|
||||
|
||||
pub struct ServerEncryptedSocket<S: PacketSocket> {
|
||||
transport: SocketWrapper<S, ClientMessageProxy, ServerMessageProxy>,
|
||||
peers: HashMap<SocketAddr, ServerPeerTransport>,
|
||||
buffer: [u8; 256],
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> ServerEncryptedSocket<S> {
|
||||
pub fn new(transport: S) -> Self {
|
||||
Self {
|
||||
transport: SocketWrapper::new(transport),
|
||||
peers: HashMap::new(),
|
||||
buffer: [0; 256],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_client(&mut self, remote: &SocketAddr) {
|
||||
let mut buf = [0; 32];
|
||||
if self.peers.remove(remote).is_some() {
|
||||
self.transport
|
||||
.send_to(remote, &mut buf, &ServerNegotiationMessage::Kick(0))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ping_clients(&mut self, limit: usize) -> Vec<SocketAddr> {
|
||||
let mut send_buf = [0; 32];
|
||||
let mut removed = vec![];
|
||||
for (remote, state) in self.peers.iter_mut() {
|
||||
match state {
|
||||
ServerPeerTransport::Connected(_, missed) => {
|
||||
self.transport.send_to(remote, &mut send_buf, &ServerNegotiationMessage::Ping).ok();
|
||||
|
||||
if *missed >= limit {
|
||||
removed.push(*remote);
|
||||
}
|
||||
|
||||
*missed += 1;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
for entry in removed.iter() {
|
||||
self.peers.remove(entry);
|
||||
}
|
||||
|
||||
removed
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> MultiplexedSocket for ServerEncryptedSocket<S> {
|
||||
fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error> {
|
||||
let mut buf = [0; 256];
|
||||
if let Some(ServerPeerTransport::Connected(aes, _)) = self.peers.get(remote) {
|
||||
let len = encrypt_blocked(aes, data, &mut self.buffer)?;
|
||||
assert_eq!(len % 16, 0);
|
||||
self.transport.send_to(
|
||||
remote,
|
||||
&mut buf,
|
||||
&ServerNegotiationMessage::CipherText(&self.buffer[..len]),
|
||||
)
|
||||
} else {
|
||||
Err(Error::NotConnected)
|
||||
}
|
||||
}
|
||||
|
||||
fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result<MultiplexedSocketEvent<'a>, Error> {
|
||||
let (message, remote) = self.transport.recv_from(&mut self.buffer)?;
|
||||
|
||||
let mut buf = [0; 256];
|
||||
|
||||
if let Some(mut state) = self.peers.get_mut(&remote) {
|
||||
match (message, &mut state) {
|
||||
// TODO check kex params
|
||||
(
|
||||
ClientNegotiationMessage::StartKex { .. },
|
||||
ServerPeerTransport::PreNegotiation,
|
||||
) => {
|
||||
let mut rng = rand::thread_rng();
|
||||
let secret = EphemeralSecret::random_from_rng(&mut rng);
|
||||
*state = ServerPeerTransport::Negotiation(secret);
|
||||
self.transport.send_to(
|
||||
&remote,
|
||||
&mut buf,
|
||||
&ServerNegotiationMessage::StartKex,
|
||||
)?;
|
||||
Ok(MultiplexedSocketEvent::None(remote))
|
||||
}
|
||||
(
|
||||
ClientNegotiationMessage::PublicKey(true, data),
|
||||
ServerPeerTransport::Negotiation(secret),
|
||||
) if data.len() == 32 => {
|
||||
let public = PublicKey::from(&*secret);
|
||||
let mut remote_key = [0; 32];
|
||||
remote_key.copy_from_slice(data);
|
||||
let remote_key = PublicKey::from(remote_key);
|
||||
state.negotiate(&remote_key);
|
||||
|
||||
// Send public key to the client
|
||||
self.transport.send_to(
|
||||
&remote,
|
||||
&mut buf,
|
||||
&ServerNegotiationMessage::PublicKey(true, public.as_bytes()),
|
||||
)?;
|
||||
|
||||
Ok(MultiplexedSocketEvent::None(remote))
|
||||
}
|
||||
(ClientNegotiationMessage::Agreed, ServerPeerTransport::Connected(_, _)) => {
|
||||
self.transport
|
||||
.send_to(&remote, &mut buf, &ServerNegotiationMessage::Agreed)?;
|
||||
Ok(MultiplexedSocketEvent::None(remote))
|
||||
}
|
||||
(
|
||||
ClientNegotiationMessage::CipherText(data),
|
||||
ServerPeerTransport::Connected(aes, _),
|
||||
) => {
|
||||
let len = decrypt_blocked(aes, data, buffer)?;
|
||||
Ok(MultiplexedSocketEvent::ClientData(remote, &buffer[..len]))
|
||||
}
|
||||
(
|
||||
ClientNegotiationMessage::Pong,
|
||||
ServerPeerTransport::Connected(_, missed)
|
||||
) => {
|
||||
*missed = 0;
|
||||
Ok(MultiplexedSocketEvent::None(remote))
|
||||
}
|
||||
(ClientNegotiationMessage::Disconnect(_reason), _) => {
|
||||
eprintln!("Peer disconnected: {remote}");
|
||||
self.peers.remove(&remote);
|
||||
Ok(MultiplexedSocketEvent::ClientDisconnected(remote))
|
||||
}
|
||||
// Misbehavior
|
||||
_ => {
|
||||
self.peers.remove(&remote);
|
||||
Ok(MultiplexedSocketEvent::ClientDisconnected(remote))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let ClientNegotiationMessage::Hello { protocol: 1 } = message else {
|
||||
eprintln!("Unhandled client message");
|
||||
return Ok(MultiplexedSocketEvent::None(remote));
|
||||
};
|
||||
eprintln!("Client Hello");
|
||||
self.peers
|
||||
.insert(remote, ServerPeerTransport::PreNegotiation);
|
||||
|
||||
// Reply with hello
|
||||
let hello = ServerHello {
|
||||
kex_algos: &[],
|
||||
symmetric_ciphersuites: &[],
|
||||
};
|
||||
self.transport
|
||||
.send_to(&remote, &mut buf, &ServerNegotiationMessage::Hello(hello))?;
|
||||
|
||||
Ok(MultiplexedSocketEvent::None(remote))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> AsRawFd for ServerEncryptedSocket<S> {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.transport.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerPeerTransport {
|
||||
fn negotiate(&mut self, public: &PublicKey) {
|
||||
let Self::Negotiation(secret) = mem::replace(self, ServerPeerTransport::PreNegotiation)
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
let shared = secret.diffie_hellman(public);
|
||||
let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap();
|
||||
*self = ServerPeerTransport::Connected(aes, 0);
|
||||
}
|
||||
}
|
@ -2,20 +2,15 @@
|
||||
#![feature(generic_const_exprs)]
|
||||
#![allow(incomplete_features)]
|
||||
|
||||
use std::{
|
||||
io,
|
||||
marker::PhantomData,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
ops::{Deref, DerefMut},
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
};
|
||||
use std::io;
|
||||
|
||||
use proto::{
|
||||
ClientMessageProxy, Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy,
|
||||
ServerMessageProxy,
|
||||
};
|
||||
use proto::{DecodeError, EncodeError};
|
||||
|
||||
pub mod proto;
|
||||
pub mod socket;
|
||||
pub mod crypt;
|
||||
|
||||
pub use socket::{ClientSocket, ServerSocket};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
@ -27,106 +22,12 @@ pub enum Error {
|
||||
Decode(#[from] DecodeError),
|
||||
#[error("Encode error: {0}")]
|
||||
Encode(#[from] EncodeError),
|
||||
}
|
||||
|
||||
pub struct SocketWrapper<Rx: MessageProxy, Tx: MessageProxy> {
|
||||
socket: UdpSocket,
|
||||
_pd: PhantomData<(Rx, Tx)>,
|
||||
}
|
||||
|
||||
pub struct ClientSocket(SocketWrapper<ServerMessageProxy, ClientMessageProxy>);
|
||||
pub struct ServerSocket(SocketWrapper<ClientMessageProxy, ServerMessageProxy>);
|
||||
|
||||
impl<Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<Rx, Tx> {
|
||||
pub fn new(socket: UdpSocket) -> Self {
|
||||
Self {
|
||||
socket,
|
||||
_pd: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(&self, buffer: &mut [u8], message: &Tx::Type<'_>) -> Result<(), Error> {
|
||||
let mut enc = Encoder::new(buffer);
|
||||
message.encode(&mut enc)?;
|
||||
let message = enc.get();
|
||||
let amount = self.socket.send(message)?;
|
||||
if amount == message.len() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::Truncated)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_to(
|
||||
&mut self,
|
||||
remote: &SocketAddr,
|
||||
buffer: &mut [u8],
|
||||
message: &Tx::Type<'_>,
|
||||
) -> Result<(), Error> {
|
||||
let mut enc = Encoder::new(buffer);
|
||||
message.encode(&mut enc)?;
|
||||
let message = enc.get();
|
||||
let amount = self.socket.send_to(message, remote)?;
|
||||
if amount == message.len() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::Truncated)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn recv_from<'de>(
|
||||
&mut self,
|
||||
buffer: &'de mut [u8],
|
||||
) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
|
||||
let (len, remote) = self.socket.recv_from(buffer)?;
|
||||
let mut dec = Decoder::new(&buffer[..len]);
|
||||
let message = Rx::Type::<'de>::decode(&mut dec)?;
|
||||
Ok((message, remote))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<Rx, Tx> {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.socket.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientSocket {
|
||||
pub fn new(socket: UdpSocket) -> Self {
|
||||
Self(SocketWrapper::new(socket))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ClientSocket {
|
||||
type Target = SocketWrapper<ServerMessageProxy, ClientMessageProxy>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for ClientSocket {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerSocket {
|
||||
pub fn new(socket: UdpSocket) -> Self {
|
||||
Self(SocketWrapper::new(socket))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ServerSocket {
|
||||
type Target = SocketWrapper<ClientMessageProxy, ServerMessageProxy>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for ServerSocket {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
#[error("Peer not connected")]
|
||||
NotConnected,
|
||||
#[error("Ping")]
|
||||
Ping,
|
||||
#[error("Invalid socket state")]
|
||||
InvalidState,
|
||||
#[error("Disconnected by remote peer")]
|
||||
Disconnected,
|
||||
}
|
||||
|
@ -12,8 +12,7 @@ use clap::Parser;
|
||||
use cross::io::Poll;
|
||||
use libterm::{RawMode, RawTerminal};
|
||||
use rsh::{
|
||||
proto::{ClientMessage, ServerMessage, TerminalInfo},
|
||||
ClientSocket,
|
||||
crypt::ClientEncryptedSocket, proto::{ClientMessage, ServerMessage, TerminalInfo}, socket::{MessageSocket, PacketSocket}, ClientSocket
|
||||
};
|
||||
|
||||
pub const PING_TIMEOUT: Duration = Duration::from_secs(3);
|
||||
@ -39,11 +38,10 @@ struct Args {
|
||||
|
||||
pub struct Client {
|
||||
poll: Poll,
|
||||
socket: ClientSocket,
|
||||
socket: ClientSocket<ClientEncryptedSocket<UdpSocket>>,
|
||||
stdin: Stdin,
|
||||
stdout: Stdout,
|
||||
need_bye: bool,
|
||||
last_ping: Instant,
|
||||
_raw: RawMode,
|
||||
}
|
||||
|
||||
@ -51,7 +49,6 @@ pub enum Event<'b> {
|
||||
Stdin(&'b [u8]),
|
||||
Data(&'b [u8]),
|
||||
Disconnected(&'b str),
|
||||
Ping,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
@ -69,6 +66,8 @@ impl Client {
|
||||
let socket = UdpSocket::bind(local)?;
|
||||
socket.connect(remote)?;
|
||||
|
||||
let mut socket = ClientEncryptedSocket::new(socket);
|
||||
socket.try_connect_blocking()?;
|
||||
let mut socket = ClientSocket::new(socket);
|
||||
|
||||
poll.add(&*socket)?;
|
||||
@ -82,7 +81,6 @@ impl Client {
|
||||
let _raw = unsafe { RawMode::enter(&stdin)? };
|
||||
|
||||
Ok(Self {
|
||||
last_ping: Instant::now(),
|
||||
poll,
|
||||
socket,
|
||||
stdout,
|
||||
@ -102,7 +100,11 @@ impl Client {
|
||||
|
||||
match event {
|
||||
fd if fd == self.socket.as_raw_fd() => {
|
||||
let (message, _) = self.socket.recv_from(buffer)?;
|
||||
let message = match self.socket.recv_from(buffer) {
|
||||
Ok((message, _)) => message,
|
||||
Err(rsh::Error::Ping) => return Ok(None),
|
||||
Err(error) => return Err(error.into())
|
||||
};
|
||||
match message {
|
||||
ServerMessage::Bye(reason) => {
|
||||
// No need for a bye
|
||||
@ -112,9 +114,6 @@ impl Client {
|
||||
ServerMessage::Output(data) => {
|
||||
break Ok(Some(Event::Data(data)));
|
||||
}
|
||||
ServerMessage::Ping => {
|
||||
break Ok(Some(Event::Ping));
|
||||
}
|
||||
// Ignore this one
|
||||
ServerMessage::Hello => break Ok(None),
|
||||
}
|
||||
@ -148,10 +147,6 @@ impl Client {
|
||||
Event::Stdin(data) => {
|
||||
self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
|
||||
}
|
||||
Event::Ping => {
|
||||
self.last_ping = Instant::now();
|
||||
self.socket.send(&mut send_buf, &ClientMessage::Pong)?;
|
||||
}
|
||||
Event::Disconnected(reason) => {
|
||||
break Ok(reason.into());
|
||||
}
|
||||
@ -161,16 +156,16 @@ impl Client {
|
||||
|
||||
fn check_timeout(&mut self) -> Result<(), Error> {
|
||||
let now = Instant::now();
|
||||
if now - self.last_ping >= PING_TIMEOUT {
|
||||
if now - self.socket.as_inner_mut().last_ping() >= PING_TIMEOUT {
|
||||
Err(Error::Timeout)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake(
|
||||
fn handshake<S: PacketSocket>(
|
||||
poll: &mut Poll,
|
||||
socket: &mut ClientSocket,
|
||||
socket: &mut ClientSocket<S>,
|
||||
terminal: TerminalInfo,
|
||||
attempts: usize,
|
||||
) -> Result<(), Error> {
|
||||
@ -187,9 +182,9 @@ impl Client {
|
||||
Self::try_handshake(poll, socket, terminal, timeout)
|
||||
}
|
||||
|
||||
fn try_handshake(
|
||||
fn try_handshake<S: PacketSocket>(
|
||||
poll: &mut Poll,
|
||||
socket: &mut ClientSocket,
|
||||
socket: &mut ClientSocket<S>,
|
||||
terminal: TerminalInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<(), Error> {
|
||||
|
@ -11,7 +11,6 @@ pub struct ServerMessageProxy;
|
||||
pub enum ClientMessage<'a> {
|
||||
Hello(TerminalInfo),
|
||||
Bye(&'a str),
|
||||
Pong,
|
||||
Input(&'a [u8]),
|
||||
}
|
||||
|
||||
@ -19,7 +18,6 @@ pub enum ClientMessage<'a> {
|
||||
pub enum ServerMessage<'a> {
|
||||
Hello,
|
||||
Bye(&'a str),
|
||||
Ping,
|
||||
Output(&'a [u8]),
|
||||
}
|
||||
|
||||
@ -68,6 +66,10 @@ impl<'a> Encoder<'a> {
|
||||
Self { buffer, pos: 0 }
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.pos = 0;
|
||||
}
|
||||
|
||||
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
|
||||
let len: u32 = bytes
|
||||
.len()
|
||||
@ -145,7 +147,6 @@ impl MessageProxy for ServerMessageProxy {
|
||||
impl ClientMessage<'_> {
|
||||
const TAG_HELLO: u8 = 0x80;
|
||||
const TAG_BYE: u8 = 0x81;
|
||||
const TAG_PONG: u8 = 0x82;
|
||||
const TAG_INPUT: u8 = 0x90;
|
||||
}
|
||||
|
||||
@ -176,9 +177,6 @@ impl<'a> Encode for ClientMessage<'a> {
|
||||
buffer.write(&[Self::TAG_BYE])?;
|
||||
buffer.write_str(reason)
|
||||
}
|
||||
Self::Pong => {
|
||||
buffer.write(&[Self::TAG_PONG])
|
||||
}
|
||||
Self::Input(data) => {
|
||||
buffer.write(&[Self::TAG_INPUT])?;
|
||||
buffer.write_variable_bytes(data)
|
||||
@ -197,9 +195,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
|
||||
Self::TAG_BYE => {
|
||||
buffer.read_str().map(Self::Bye)
|
||||
}
|
||||
Self::TAG_PONG => {
|
||||
Ok(Self::Pong)
|
||||
}
|
||||
Self::TAG_INPUT => {
|
||||
buffer.read_variable_bytes().map(Self::Input)
|
||||
}
|
||||
@ -211,7 +206,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
|
||||
impl ServerMessage<'_> {
|
||||
const TAG_HELLO: u8 = 0x10;
|
||||
const TAG_BYE: u8 = 0x11;
|
||||
const TAG_PING: u8 = 0x12;
|
||||
const TAG_OUTPUT: u8 = 0x20;
|
||||
}
|
||||
|
||||
@ -223,8 +217,6 @@ impl<'a> Encode for ServerMessage<'a> {
|
||||
buffer.write(&[Self::TAG_BYE])?;
|
||||
buffer.write_str(reason)
|
||||
}
|
||||
// TODO sequence number
|
||||
Self::Ping => buffer.write(&[Self::TAG_PING]),
|
||||
Self::Output(data) => {
|
||||
buffer.write(&[Self::TAG_OUTPUT])?;
|
||||
buffer.write_variable_bytes(data)
|
||||
@ -241,7 +233,6 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
|
||||
Self::TAG_BYE => {
|
||||
buffer.read_str().map(Self::Bye)
|
||||
}
|
||||
Self::TAG_PING => Ok(Self::Ping),
|
||||
Self::TAG_OUTPUT => {
|
||||
buffer.read_variable_bytes().map(Self::Output)
|
||||
},
|
||||
|
@ -13,8 +13,7 @@ use std::{
|
||||
|
||||
use cross::io::{Poll, TimerFd};
|
||||
use rsh::{
|
||||
proto::{ClientMessage, ServerMessage, TerminalInfo},
|
||||
Error, ServerSocket,
|
||||
crypt::ServerEncryptedSocket, proto::{ClientMessage, Decode, Decoder, Encoder, ServerMessage, TerminalInfo}, socket::{MessageSocket, MultiplexedSocket, MultiplexedSocketEvent}, Error, ServerSocket
|
||||
};
|
||||
|
||||
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
|
||||
@ -29,7 +28,7 @@ pub struct Session {
|
||||
pub struct Server {
|
||||
poll: Poll,
|
||||
timer: TimerFd,
|
||||
socket: ServerSocket,
|
||||
socket: ServerEncryptedSocket<UdpSocket>,
|
||||
|
||||
addr_to_session: HashMap<SocketAddr, RawFd>,
|
||||
pty_to_session: HashMap<RawFd, Session>,
|
||||
@ -98,8 +97,8 @@ impl Server {
|
||||
pub fn new(listen_addr: SocketAddr) -> Result<Self, Error> {
|
||||
let mut poll = Poll::new()?;
|
||||
let timer = TimerFd::new()?;
|
||||
let socket = UdpSocket::bind(listen_addr).map(ServerSocket::new)?;
|
||||
poll.add(&*socket)?;
|
||||
let socket = UdpSocket::bind(listen_addr).map(ServerEncryptedSocket::new)?;
|
||||
poll.add(&socket)?;
|
||||
poll.add(&timer)?;
|
||||
Ok(Self {
|
||||
poll,
|
||||
@ -115,13 +114,27 @@ impl Server {
|
||||
|
||||
match fd {
|
||||
fd if fd == self.socket.as_raw_fd() => {
|
||||
let (message, remote) = match self.socket.recv_from(buffer) {
|
||||
Ok((message, remote)) => (message, remote),
|
||||
Err(error @ (Error::Decode(_) | Error::Truncated)) => {
|
||||
eprintln!("Receive error: {error}");
|
||||
let event = self.socket.recv_from(buffer)?;
|
||||
|
||||
let (message, remote) = match event {
|
||||
MultiplexedSocketEvent::ClientDisconnected(remote) => {
|
||||
self.remove_session_by_remote(remote).ok();
|
||||
return Ok(None)
|
||||
},
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
MultiplexedSocketEvent::None(_) => return Ok(None),
|
||||
MultiplexedSocketEvent::ClientData(peer, data) => {
|
||||
let mut decoder = Decoder::new(data);
|
||||
let message = ClientMessage::decode(&mut decoder);
|
||||
(message, peer)
|
||||
}
|
||||
};
|
||||
|
||||
let message = match message {
|
||||
Ok(message) => message,
|
||||
Err(error) => {
|
||||
eprintln!("Decode error: {error}");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let event = match message {
|
||||
@ -136,11 +149,6 @@ impl Server {
|
||||
{
|
||||
Event::SessionInput(*fd, remote, data)
|
||||
}
|
||||
ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => {
|
||||
let session = self.pty_to_session.get_mut(fd).unwrap();
|
||||
session.timeouts = 0;
|
||||
return Ok(None);
|
||||
}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
@ -176,7 +184,7 @@ impl Server {
|
||||
eprintln!("PTY write error: {error}");
|
||||
self.remove_session_by_fd(fd)?;
|
||||
self.socket
|
||||
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
@ -190,13 +198,13 @@ impl Server {
|
||||
Ok(session) => {
|
||||
self.register_session(remote, session)?;
|
||||
self.socket
|
||||
.send_to(&remote, &mut send_buf, &ServerMessage::Hello)
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
|
||||
.ok();
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("PTY open error: {err}");
|
||||
self.socket
|
||||
.send_to(
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("PTY open error"),
|
||||
@ -208,48 +216,37 @@ impl Server {
|
||||
Event::Pty(fd, remote, event) => match event {
|
||||
PtyEvent::Data(data) => {
|
||||
self.socket
|
||||
.send_to(&remote, &mut send_buf, &ServerMessage::Output(data))
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data))
|
||||
.ok();
|
||||
},
|
||||
PtyEvent::Err(error) => {
|
||||
eprintln!("PTY read error: {error}");
|
||||
self.remove_session_by_fd(fd)?;
|
||||
self.socket
|
||||
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||
.ok();
|
||||
},
|
||||
PtyEvent::Closed => {
|
||||
println!("End of PTY for {remote}");
|
||||
self.remove_session_by_fd(fd)?;
|
||||
self.socket
|
||||
.send_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
|
||||
.ok();
|
||||
},
|
||||
},
|
||||
Event::Tick => {
|
||||
// Restart the timer
|
||||
self.update_client_timeouts(&mut send_buf)?;
|
||||
self.update_client_timeouts()?;
|
||||
self.timer.start(PING_INTERVAL)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_client_timeouts(&mut self, send_buf: &mut [u8]) -> Result<(), Error> {
|
||||
let mut removed = vec![];
|
||||
for (remote, fd) in self.addr_to_session.iter() {
|
||||
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
||||
self.socket.send_to(remote, send_buf, &ServerMessage::Ping).ok();
|
||||
session.timeouts += 1;
|
||||
if session.timeouts >= 10 {
|
||||
removed.push((*remote, *fd));
|
||||
}
|
||||
}
|
||||
|
||||
for (remote, fd) in removed {
|
||||
eprintln!("Client {remote} timed out");
|
||||
self.remove_session_by_fd(fd)?;
|
||||
self.socket.send_to(&remote, send_buf, &ServerMessage::Bye("Timed out".into())).ok();
|
||||
fn update_client_timeouts(&mut self) -> Result<(), Error> {
|
||||
let removed = self.socket.ping_clients(8);
|
||||
for entry in removed {
|
||||
self.remove_session_by_remote(entry).ok();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -268,6 +265,7 @@ impl Server {
|
||||
// NOTE: this will block the whole server while the process finishes.
|
||||
session.shell.wait().ok();
|
||||
self.addr_to_session.remove(&session.remote).unwrap();
|
||||
self.socket.remove_client(&session.remote);
|
||||
self.poll.remove(&fd)?;
|
||||
Ok(Some(session))
|
||||
} else {
|
||||
|
202
userspace/rsh/src/socket.rs
Normal file
202
userspace/rsh/src/socket.rs
Normal file
@ -0,0 +1,202 @@
|
||||
use std::{
|
||||
io,
|
||||
marker::PhantomData,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
ops::{Deref, DerefMut},
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
proto::{
|
||||
ClientMessageProxy, Decode, Decoder, Encode, Encoder, MessageProxy, ServerMessageProxy,
|
||||
},
|
||||
Error,
|
||||
};
|
||||
|
||||
pub struct SocketWrapper<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> {
|
||||
socket: S,
|
||||
_pd: PhantomData<(Rx, Tx)>,
|
||||
}
|
||||
|
||||
pub trait MessageSocket<Rx: MessageProxy, Tx: MessageProxy>: AsRawFd {
|
||||
fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error>;
|
||||
fn send_to(
|
||||
&mut self,
|
||||
addr: &SocketAddr,
|
||||
buffer: &mut [u8],
|
||||
data: &Tx::Type<'_>,
|
||||
) -> Result<(), Error>;
|
||||
fn recv_from<'de>(
|
||||
&mut self,
|
||||
buffer: &'de mut [u8],
|
||||
) -> Result<(Rx::Type<'de>, SocketAddr), Error>;
|
||||
}
|
||||
|
||||
pub trait PacketSocket: AsRawFd {
|
||||
type Error: Into<Error>;
|
||||
|
||||
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error>;
|
||||
fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error>;
|
||||
fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error>;
|
||||
}
|
||||
|
||||
pub enum MultiplexedSocketEvent<'a> {
|
||||
ClientData(SocketAddr, &'a [u8]),
|
||||
ClientDisconnected(SocketAddr),
|
||||
None(SocketAddr),
|
||||
}
|
||||
|
||||
pub trait MultiplexedSocket: AsRawFd {
|
||||
fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result<MultiplexedSocketEvent<'a>, Error>;
|
||||
fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error>;
|
||||
|
||||
fn send_message_to<Tx: Encode>(
|
||||
&mut self,
|
||||
remote: &SocketAddr,
|
||||
buffer: &mut [u8],
|
||||
data: &Tx,
|
||||
) -> Result<(), Error> {
|
||||
let mut encoder = Encoder::new(buffer);
|
||||
data.encode(&mut encoder)?;
|
||||
self.send_to(remote, encoder.get())
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketSocket for UdpSocket {
|
||||
type Error = io::Error;
|
||||
|
||||
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error> {
|
||||
UdpSocket::send(self, data)
|
||||
}
|
||||
|
||||
fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error> {
|
||||
UdpSocket::send_to(self, data, addr)
|
||||
}
|
||||
|
||||
fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
|
||||
UdpSocket::recv_from(self, data)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ClientSocket<S: PacketSocket>(SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>);
|
||||
pub struct ServerSocket<S: MultiplexedSocket>(S);
|
||||
|
||||
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> MessageSocket<Rx, Tx>
|
||||
for SocketWrapper<S, Rx, Tx>
|
||||
{
|
||||
fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error> {
|
||||
let mut enc = Encoder::new(buffer);
|
||||
data.encode(&mut enc)?;
|
||||
let message = enc.get();
|
||||
let amount = self.socket.send(message).map_err(S::Error::into)?;
|
||||
if amount == message.len() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::Truncated)
|
||||
}
|
||||
}
|
||||
|
||||
fn send_to(
|
||||
&mut self,
|
||||
addr: &SocketAddr,
|
||||
buffer: &mut [u8],
|
||||
data: &Tx::Type<'_>,
|
||||
) -> Result<(), Error> {
|
||||
let mut enc = Encoder::new(buffer);
|
||||
data.encode(&mut enc)?;
|
||||
let message = enc.get();
|
||||
let amount = self.socket.send_to(message, addr).map_err(S::Error::into)?;
|
||||
if amount == message.len() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::Truncated)
|
||||
}
|
||||
}
|
||||
|
||||
fn recv_from<'de>(
|
||||
&mut self,
|
||||
buffer: &'de mut [u8],
|
||||
) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
|
||||
let (len, remote) = self.socket.recv_from(buffer).map_err(S::Error::into)?;
|
||||
let mut dec = Decoder::new(&buffer[..len]);
|
||||
let message = Rx::Type::<'de>::decode(&mut dec)?;
|
||||
Ok((message, remote))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<S, Rx, Tx> {
|
||||
pub fn new(socket: S) -> Self {
|
||||
Self {
|
||||
socket,
|
||||
_pd: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> S {
|
||||
self.socket
|
||||
}
|
||||
|
||||
pub fn as_inner_mut(&mut self) -> &mut S {
|
||||
&mut self.socket
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<S, Rx, Tx> {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.socket.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> ClientSocket<S> {
|
||||
pub fn new(socket: S) -> Self {
|
||||
Self(SocketWrapper::new(socket))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> Deref for ClientSocket<S> {
|
||||
type Target = SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PacketSocket> DerefMut for ClientSocket<S> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: MultiplexedSocket> ServerSocket<S> {
|
||||
pub fn new(socket: S) -> Self {
|
||||
Self(socket)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: MultiplexedSocket> Deref for ServerSocket<S> {
|
||||
type Target = S;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: MultiplexedSocket> DerefMut for ServerSocket<S> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
// impl<S: PacketSocket> Deref for ServerSocket<S> {
|
||||
// type Target = SocketWrapper<S, ClientMessageProxy, ServerMessageProxy>;
|
||||
//
|
||||
// fn deref(&self) -> &Self::Target {
|
||||
// &self.0
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// impl<S: PacketSocket> DerefMut for ServerSocket<S> {
|
||||
// fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
// &mut self.0
|
||||
// }
|
||||
// }
|
Loading…
x
Reference in New Issue
Block a user