rsh: implement dumb kex+aes256
This commit is contained in:
parent
bcf1e74a04
commit
99c1dd51ae
58
userspace/Cargo.lock
generated
58
userspace/Cargo.lock
generated
@ -16,6 +16,17 @@ dependencies = [
|
|||||||
name = "abi-lib"
|
name = "abi-lib"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aes"
|
||||||
|
version = "0.8.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"cipher",
|
||||||
|
"cpufeatures",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstyle"
|
name = "anstyle"
|
||||||
version = "1.0.9"
|
version = "1.0.9"
|
||||||
@ -102,6 +113,16 @@ version = "1.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cipher"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
|
||||||
|
dependencies = [
|
||||||
|
"crypto-common",
|
||||||
|
"inout",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap"
|
name = "clap"
|
||||||
version = "4.5.20"
|
version = "4.5.20"
|
||||||
@ -522,6 +543,15 @@ dependencies = [
|
|||||||
"yggdrasil-rt",
|
"yggdrasil-rt",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "inout"
|
||||||
|
version = "0.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itoa"
|
name = "itoa"
|
||||||
version = "1.0.11"
|
version = "1.0.11"
|
||||||
@ -1011,11 +1041,14 @@ dependencies = [
|
|||||||
name = "rsh"
|
name = "rsh"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"aes",
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"clap",
|
"clap",
|
||||||
"cross",
|
"cross",
|
||||||
"libterm",
|
"libterm",
|
||||||
|
"rand 0.8.5",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
"x25519-dalek",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1633,6 +1666,17 @@ dependencies = [
|
|||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "x25519-dalek"
|
||||||
|
version = "2.0.1"
|
||||||
|
source = "git+https://git.alnyan.me/yggdrasil/curve25519-dalek.git?branch=alnyan%2Fyggdrasil#5f4dbb09259347077d3a5024ba60c77efde93a3a"
|
||||||
|
dependencies = [
|
||||||
|
"curve25519-dalek",
|
||||||
|
"rand_core 0.6.4 (git+https://git.alnyan.me/yggdrasil/rand.git?branch=alnyan%2Fyggdrasil-rng_core-0.6.4)",
|
||||||
|
"serde",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yasync"
|
name = "yasync"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@ -1710,3 +1754,17 @@ name = "zeroize"
|
|||||||
version = "1.8.1"
|
version = "1.8.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
|
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
|
||||||
|
dependencies = [
|
||||||
|
"zeroize_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zeroize_derive"
|
||||||
|
version = "1.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.86",
|
||||||
|
]
|
||||||
|
@ -8,8 +8,12 @@ name = "rshd"
|
|||||||
path = "src/rshd/main.rs"
|
path = "src/rshd/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap.workspace = true
|
|
||||||
libterm.workspace = true
|
libterm.workspace = true
|
||||||
thiserror.workspace = true
|
|
||||||
cross.workspace = true
|
cross.workspace = true
|
||||||
|
|
||||||
|
clap.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
bytemuck.workspace = true
|
bytemuck.workspace = true
|
||||||
|
x25519-dalek.workspace = true
|
||||||
|
rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" }
|
||||||
|
aes = { version = "0.8.4" }
|
||||||
|
206
userspace/rsh/src/crypt/client.rs
Normal file
206
userspace/rsh/src/crypt/client.rs
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
use std::{
|
||||||
|
net::SocketAddr,
|
||||||
|
os::fd::{AsRawFd, RawFd}, time::Instant,
|
||||||
|
};
|
||||||
|
|
||||||
|
use aes::{cipher::KeyInit, Aes256};
|
||||||
|
use rand::prelude::Distribution;
|
||||||
|
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
crypt::{
|
||||||
|
decrypt_blocked, encrypt_blocked, ClientNegotiationMessage, ServerNegotiationMessage,
|
||||||
|
V1_CIPHER_AES_256_CBC, V1_KEX_X25519_DALEK,
|
||||||
|
},
|
||||||
|
socket::{MessageSocket, PacketSocket, SocketWrapper},
|
||||||
|
Error,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{ClientMessageProxy, ServerMessageProxy};
|
||||||
|
|
||||||
|
enum ClientState {
|
||||||
|
PreNegotioation,
|
||||||
|
Connected(Aes256),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClientEncryptedSocket<S: PacketSocket> {
|
||||||
|
transport: SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>,
|
||||||
|
state: ClientState,
|
||||||
|
last_ping: Instant
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> ClientEncryptedSocket<S> {
|
||||||
|
pub fn new(transport: S) -> Self {
|
||||||
|
Self {
|
||||||
|
transport: SocketWrapper::new(transport),
|
||||||
|
state: ClientState::PreNegotioation,
|
||||||
|
last_ping: Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn last_ping(&self) -> Instant {
|
||||||
|
self.last_ping
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn try_connect_blocking(&mut self) -> Result<(), Error> {
|
||||||
|
let ClientState::PreNegotioation = self.state else {
|
||||||
|
return Err(Error::InvalidState);
|
||||||
|
};
|
||||||
|
let mut buf = [0; 256];
|
||||||
|
|
||||||
|
// Send Hello
|
||||||
|
self.transport
|
||||||
|
.send(&mut buf, &ClientNegotiationMessage::Hello { protocol: 1 })?;
|
||||||
|
|
||||||
|
// Wait for Server Hello
|
||||||
|
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||||
|
let _hello = match message {
|
||||||
|
ServerNegotiationMessage::Hello(hello) => hello,
|
||||||
|
ServerNegotiationMessage::Reject(_reason) => {
|
||||||
|
return Err(Error::Disconnected);
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
return Err(Error::Disconnected);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO select kex, ciphersuite
|
||||||
|
let ciphersuite = V1_CIPHER_AES_256_CBC;
|
||||||
|
let kex_algo = V1_KEX_X25519_DALEK;
|
||||||
|
|
||||||
|
// Initiate key exchange
|
||||||
|
self.transport.send(
|
||||||
|
&mut buf,
|
||||||
|
&ClientNegotiationMessage::StartKex {
|
||||||
|
kex_algo,
|
||||||
|
ciphersuite,
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Wait for server to accept
|
||||||
|
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||||
|
match message {
|
||||||
|
ServerNegotiationMessage::StartKex => (),
|
||||||
|
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
|
||||||
|
_ => return Err(Error::Disconnected),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate an ephemeral key
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let secret = EphemeralSecret::random_from_rng(&mut rng);
|
||||||
|
let public = PublicKey::from(&secret).to_bytes();
|
||||||
|
|
||||||
|
// Send it to the server
|
||||||
|
self.transport.send(
|
||||||
|
&mut buf,
|
||||||
|
&ClientNegotiationMessage::PublicKey(true, &public),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Wait for server to respond with its own key
|
||||||
|
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||||
|
let mut remote = [0; 32];
|
||||||
|
match message {
|
||||||
|
ServerNegotiationMessage::PublicKey(true, key) if key.len() == 32 => {
|
||||||
|
remote.copy_from_slice(key)
|
||||||
|
}
|
||||||
|
ServerNegotiationMessage::PublicKey(_, _key) => return Err(Error::Disconnected),
|
||||||
|
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
|
||||||
|
_ => return Err(Error::Disconnected),
|
||||||
|
};
|
||||||
|
let remote = PublicKey::from(remote);
|
||||||
|
|
||||||
|
let shared = secret.diffie_hellman(&remote);
|
||||||
|
|
||||||
|
// Negotiation done
|
||||||
|
self.transport
|
||||||
|
.send(&mut buf, &ClientNegotiationMessage::Agreed)?;
|
||||||
|
|
||||||
|
// Wait for server's Agreed
|
||||||
|
let (message, _) = self.transport.recv_from(&mut buf)?;
|
||||||
|
match message {
|
||||||
|
ServerNegotiationMessage::Agreed => (),
|
||||||
|
ServerNegotiationMessage::Reject(_reason) => return Err(Error::Disconnected),
|
||||||
|
_ => return Err(Error::Disconnected),
|
||||||
|
}
|
||||||
|
|
||||||
|
let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap();
|
||||||
|
|
||||||
|
self.state = ClientState::Connected(aes);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> PacketSocket for ClientEncryptedSocket<S> {
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
fn recv_from(&mut self, output: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
|
||||||
|
match &self.state {
|
||||||
|
ClientState::Connected(aes) => {
|
||||||
|
assert!(output.len() >= 16);
|
||||||
|
loop {
|
||||||
|
let mut buffer = [0; 256];
|
||||||
|
let (message, peer) = self.transport.recv_from(&mut buffer)?;
|
||||||
|
match message {
|
||||||
|
ServerNegotiationMessage::CipherText(data) => {
|
||||||
|
let len = decrypt_blocked(aes, data, output)?;
|
||||||
|
break Ok((len, peer))
|
||||||
|
}
|
||||||
|
ServerNegotiationMessage::Ping => {
|
||||||
|
self.transport.send(&mut buffer, &ClientNegotiationMessage::Pong)?;
|
||||||
|
self.last_ping = Instant::now();
|
||||||
|
// TODO this is a workaround
|
||||||
|
break Err(Error::Ping);
|
||||||
|
}
|
||||||
|
ServerNegotiationMessage::Kick(_reason) => {
|
||||||
|
break Err(Error::NotConnected)
|
||||||
|
},
|
||||||
|
// TODO send disconnect and leave
|
||||||
|
_ => break Err(Error::NotConnected),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(Error::NotConnected),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send_to(&mut self, data: &[u8], _addr: &SocketAddr) -> Result<usize, Self::Error> {
|
||||||
|
self.send(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error> {
|
||||||
|
let mut buffer = [0; 256];
|
||||||
|
let mut send_buffer = [0; 256];
|
||||||
|
match &self.state {
|
||||||
|
ClientState::Connected(aes) => {
|
||||||
|
let len = encrypt_blocked(aes, data, &mut buffer)?;
|
||||||
|
assert_eq!(len % 16, 0);
|
||||||
|
self.transport.send(
|
||||||
|
&mut send_buffer,
|
||||||
|
&ClientNegotiationMessage::CipherText(&buffer[..len]),
|
||||||
|
)?;
|
||||||
|
Ok(data.len())
|
||||||
|
}
|
||||||
|
_ => Err(Error::NotConnected),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> Drop for ClientEncryptedSocket<S> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
match self.state {
|
||||||
|
ClientState::PreNegotioation => (),
|
||||||
|
_ => {
|
||||||
|
let mut buffer = [0; 32];
|
||||||
|
self.transport
|
||||||
|
.send(&mut buffer, &ClientNegotiationMessage::Disconnect(0))
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> AsRawFd for ClientEncryptedSocket<S> {
|
||||||
|
fn as_raw_fd(&self) -> RawFd {
|
||||||
|
self.transport.as_raw_fd()
|
||||||
|
}
|
||||||
|
}
|
262
userspace/rsh/src/crypt/mod.rs
Normal file
262
userspace/rsh/src/crypt/mod.rs
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
use aes::{
|
||||||
|
cipher::{BlockDecrypt, BlockEncrypt},
|
||||||
|
Aes256, Block,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy},
|
||||||
|
Error,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod client;
|
||||||
|
pub mod server;
|
||||||
|
|
||||||
|
pub use client::ClientEncryptedSocket;
|
||||||
|
pub use server::ServerEncryptedSocket;
|
||||||
|
|
||||||
|
pub const V1_CIPHER_AES_256_CBC: u8 = 0x10;
|
||||||
|
pub const V1_CIPHER_AES_256_CTR: u8 = 0x11;
|
||||||
|
pub const V1_CIPHER_AES_256_GCM: u8 = 0x12;
|
||||||
|
pub const V1_CIPHER_CHACHA20_POLY1305: u8 = 0x15;
|
||||||
|
|
||||||
|
// v1 supports only one DH algo
|
||||||
|
pub const V1_KEX_X25519_DALEK: u8 = 0x10;
|
||||||
|
|
||||||
|
pub enum ServerReject {
|
||||||
|
NoReason,
|
||||||
|
UnsupportedKexAlgo,
|
||||||
|
UnsupportedCiphersuite,
|
||||||
|
InvalidParameter,
|
||||||
|
}
|
||||||
|
|
||||||
|
// v1 protocol
|
||||||
|
pub struct ServerHello<'a> {
|
||||||
|
// TODO server will send its public fingerprint and advertise
|
||||||
|
// supported asymmetric key formats
|
||||||
|
pub symmetric_ciphersuites: &'a [u8],
|
||||||
|
pub kex_algos: &'a [u8],
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ClientNegotiationMessage<'a> {
|
||||||
|
Hello { protocol: u8 },
|
||||||
|
StartKex { kex_algo: u8, ciphersuite: u8 },
|
||||||
|
PublicKey(bool, &'a [u8]),
|
||||||
|
Agreed,
|
||||||
|
Pong,
|
||||||
|
CipherText(&'a [u8]),
|
||||||
|
Disconnect(u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ServerNegotiationMessage<'a> {
|
||||||
|
Hello(ServerHello<'a>),
|
||||||
|
Reject(ServerReject),
|
||||||
|
StartKex,
|
||||||
|
PublicKey(bool, &'a [u8]),
|
||||||
|
Agreed,
|
||||||
|
Kick(u8),
|
||||||
|
Ping,
|
||||||
|
CipherText(&'a [u8]),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClientMessageProxy;
|
||||||
|
pub struct ServerMessageProxy;
|
||||||
|
|
||||||
|
impl MessageProxy for ClientMessageProxy {
|
||||||
|
type Type<'de> = ClientNegotiationMessage<'de>;
|
||||||
|
}
|
||||||
|
impl MessageProxy for ServerMessageProxy {
|
||||||
|
type Type<'de> = ServerNegotiationMessage<'de>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientNegotiationMessage<'_> {
|
||||||
|
const TAG_HELLO: u8 = 0x90;
|
||||||
|
const TAG_START_KEX: u8 = 0x91;
|
||||||
|
const TAG_PUBLIC_KEY: u8 = 0x92;
|
||||||
|
const TAG_AGREED: u8 = 0x93;
|
||||||
|
|
||||||
|
const TAG_CIPHERTEXT: u8 = 0xA0;
|
||||||
|
const TAG_DISCONNECT: u8 = 0xA1;
|
||||||
|
const TAG_PONG: u8 = 0xA2;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encode for ClientNegotiationMessage<'_> {
|
||||||
|
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||||
|
match self {
|
||||||
|
&Self::Hello { protocol } => buffer.write(&[Self::TAG_HELLO, protocol]),
|
||||||
|
&Self::StartKex {
|
||||||
|
kex_algo,
|
||||||
|
ciphersuite,
|
||||||
|
} => buffer.write(&[Self::TAG_START_KEX, kex_algo, ciphersuite]),
|
||||||
|
Self::PublicKey(end, data) => {
|
||||||
|
buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?;
|
||||||
|
buffer.write_variable_bytes(data)
|
||||||
|
}
|
||||||
|
Self::Agreed => buffer.write(&[Self::TAG_AGREED]),
|
||||||
|
Self::CipherText(data) => {
|
||||||
|
buffer.write(&[Self::TAG_CIPHERTEXT])?;
|
||||||
|
buffer.write_variable_bytes(data)
|
||||||
|
}
|
||||||
|
&Self::Disconnect(reason) => buffer.write(&[Self::TAG_DISCONNECT, reason]),
|
||||||
|
Self::Pong => buffer.write(&[Self::TAG_PONG]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Decode<'de> for ClientNegotiationMessage<'de> {
|
||||||
|
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||||
|
let tag = buffer.read_u8()?;
|
||||||
|
match tag {
|
||||||
|
Self::TAG_HELLO => {
|
||||||
|
let protocol = buffer.read_u8()?;
|
||||||
|
Ok(Self::Hello { protocol })
|
||||||
|
}
|
||||||
|
Self::TAG_START_KEX => {
|
||||||
|
let kex_algo = buffer.read_u8()?;
|
||||||
|
let ciphersuite = buffer.read_u8()?;
|
||||||
|
Ok(Self::StartKex {
|
||||||
|
kex_algo,
|
||||||
|
ciphersuite,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Self::TAG_PUBLIC_KEY => {
|
||||||
|
let end = buffer.read_u8()? != 0;
|
||||||
|
let data = buffer.read_variable_bytes()?;
|
||||||
|
Ok(Self::PublicKey(end, data))
|
||||||
|
}
|
||||||
|
Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText),
|
||||||
|
Self::TAG_AGREED => Ok(Self::Agreed),
|
||||||
|
Self::TAG_PONG => Ok(Self::Pong),
|
||||||
|
Self::TAG_DISCONNECT => buffer.read_u8().map(Self::Disconnect),
|
||||||
|
_ => Err(DecodeError::InvalidMessage),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerNegotiationMessage<'_> {
|
||||||
|
const TAG_HELLO: u8 = 0xB0;
|
||||||
|
const TAG_REJECT: u8 = 0xB1;
|
||||||
|
const TAG_START_KEX: u8 = 0xB2;
|
||||||
|
const TAG_PUBLIC_KEY: u8 = 0xB3;
|
||||||
|
const TAG_AGREED: u8 = 0xB4;
|
||||||
|
|
||||||
|
const TAG_CIPHERTEXT: u8 = 0xC0;
|
||||||
|
const TAG_PING: u8 = 0xC1;
|
||||||
|
const TAG_KICK: u8 = 0xC2;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encode for ServerHello<'_> {
|
||||||
|
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||||
|
buffer.write_variable_bytes(self.kex_algos)?;
|
||||||
|
buffer.write_variable_bytes(self.symmetric_ciphersuites)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Decode<'de> for ServerHello<'de> {
|
||||||
|
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||||
|
let kex_algos = buffer.read_variable_bytes()?;
|
||||||
|
let symmetric_ciphersuites = buffer.read_variable_bytes()?;
|
||||||
|
Ok(Self {
|
||||||
|
kex_algos,
|
||||||
|
symmetric_ciphersuites,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encode for ServerNegotiationMessage<'_> {
|
||||||
|
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||||
|
match self {
|
||||||
|
Self::Hello(hello) => {
|
||||||
|
buffer.write(&[Self::TAG_HELLO])?;
|
||||||
|
hello.encode(buffer)
|
||||||
|
}
|
||||||
|
Self::Reject(_reason) => todo!(),
|
||||||
|
Self::StartKex => buffer.write(&[Self::TAG_START_KEX]),
|
||||||
|
Self::PublicKey(end, data) => {
|
||||||
|
buffer.write(&[Self::TAG_PUBLIC_KEY, *end as u8])?;
|
||||||
|
buffer.write_variable_bytes(data)
|
||||||
|
}
|
||||||
|
Self::Agreed => buffer.write(&[Self::TAG_AGREED]),
|
||||||
|
Self::CipherText(data) => {
|
||||||
|
buffer.write(&[Self::TAG_CIPHERTEXT])?;
|
||||||
|
buffer.write_variable_bytes(data)
|
||||||
|
}
|
||||||
|
&Self::Kick(reason) => buffer.write(&[Self::TAG_KICK, reason]),
|
||||||
|
&Self::Ping => buffer.write(&[Self::TAG_PING]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Decode<'de> for ServerNegotiationMessage<'de> {
|
||||||
|
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||||
|
let tag = buffer.read_u8()?;
|
||||||
|
match tag {
|
||||||
|
Self::TAG_HELLO => ServerHello::decode(buffer).map(Self::Hello),
|
||||||
|
Self::TAG_REJECT => todo!(),
|
||||||
|
Self::TAG_START_KEX => Ok(Self::StartKex),
|
||||||
|
Self::TAG_PUBLIC_KEY => {
|
||||||
|
let end = buffer.read_u8()? != 0;
|
||||||
|
let data = buffer.read_variable_bytes()?;
|
||||||
|
Ok(Self::PublicKey(end, data))
|
||||||
|
}
|
||||||
|
Self::TAG_AGREED => Ok(Self::Agreed),
|
||||||
|
Self::TAG_CIPHERTEXT => buffer.read_variable_bytes().map(Self::CipherText),
|
||||||
|
Self::TAG_KICK => buffer.read_u8().map(Self::Kick),
|
||||||
|
Self::TAG_PING => Ok(Self::Ping),
|
||||||
|
|
||||||
|
_ => Err(DecodeError::InvalidMessage),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
|
||||||
|
if src.len() % 16 != 0 || dst.len() < src.len() {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut pos = 0;
|
||||||
|
let mut out = 0;
|
||||||
|
let mut block = Block::from([0; 16]);
|
||||||
|
while pos != src.len() {
|
||||||
|
block.copy_from_slice(&src[pos..pos + 16]);
|
||||||
|
aes.decrypt_block(&mut block);
|
||||||
|
|
||||||
|
let len = block[0] as usize;
|
||||||
|
if len > 15 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[out..out + len].copy_from_slice(&block[1..1 + len]);
|
||||||
|
|
||||||
|
out += len;
|
||||||
|
pos += 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encrypt_blocked(aes: &Aes256, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
|
||||||
|
if src.len() >= (dst.len() / 16) * 15 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut pos = 0;
|
||||||
|
let mut out = 0;
|
||||||
|
let mut block = Block::from([0; 16]);
|
||||||
|
while pos != src.len() {
|
||||||
|
let len = core::cmp::min(src.len() - pos, 15);
|
||||||
|
block[0] = len as u8;
|
||||||
|
block[1..1 + len].copy_from_slice(&src[pos..pos + len]);
|
||||||
|
// TODO pad with random
|
||||||
|
if len != 15 {
|
||||||
|
block[1 + len..].fill(0);
|
||||||
|
}
|
||||||
|
aes.encrypt_block(&mut block);
|
||||||
|
|
||||||
|
dst[out..out + 16].copy_from_slice(&block);
|
||||||
|
|
||||||
|
pos += len;
|
||||||
|
out += 16;
|
||||||
|
}
|
||||||
|
Ok(out)
|
||||||
|
}
|
202
userspace/rsh/src/crypt/server.rs
Normal file
202
userspace/rsh/src/crypt/server.rs
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
mem,
|
||||||
|
net::SocketAddr,
|
||||||
|
os::fd::{AsRawFd, RawFd},
|
||||||
|
};
|
||||||
|
|
||||||
|
use aes::{cipher::KeyInit, Aes256};
|
||||||
|
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
crypt::{decrypt_blocked, ServerHello, ServerNegotiationMessage},
|
||||||
|
socket::{
|
||||||
|
MessageSocket, MultiplexedSocket, MultiplexedSocketEvent, PacketSocket, SocketWrapper,
|
||||||
|
},
|
||||||
|
Error,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{encrypt_blocked, ClientMessageProxy, ClientNegotiationMessage, ServerMessageProxy};
|
||||||
|
|
||||||
|
pub enum ServerPeerTransport {
|
||||||
|
PreNegotiation,
|
||||||
|
Negotiation(EphemeralSecret),
|
||||||
|
Connected(Aes256, usize),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ServerEncryptedSocket<S: PacketSocket> {
|
||||||
|
transport: SocketWrapper<S, ClientMessageProxy, ServerMessageProxy>,
|
||||||
|
peers: HashMap<SocketAddr, ServerPeerTransport>,
|
||||||
|
buffer: [u8; 256],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> ServerEncryptedSocket<S> {
|
||||||
|
pub fn new(transport: S) -> Self {
|
||||||
|
Self {
|
||||||
|
transport: SocketWrapper::new(transport),
|
||||||
|
peers: HashMap::new(),
|
||||||
|
buffer: [0; 256],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_client(&mut self, remote: &SocketAddr) {
|
||||||
|
let mut buf = [0; 32];
|
||||||
|
if self.peers.remove(remote).is_some() {
|
||||||
|
self.transport
|
||||||
|
.send_to(remote, &mut buf, &ServerNegotiationMessage::Kick(0))
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ping_clients(&mut self, limit: usize) -> Vec<SocketAddr> {
|
||||||
|
let mut send_buf = [0; 32];
|
||||||
|
let mut removed = vec![];
|
||||||
|
for (remote, state) in self.peers.iter_mut() {
|
||||||
|
match state {
|
||||||
|
ServerPeerTransport::Connected(_, missed) => {
|
||||||
|
self.transport.send_to(remote, &mut send_buf, &ServerNegotiationMessage::Ping).ok();
|
||||||
|
|
||||||
|
if *missed >= limit {
|
||||||
|
removed.push(*remote);
|
||||||
|
}
|
||||||
|
|
||||||
|
*missed += 1;
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for entry in removed.iter() {
|
||||||
|
self.peers.remove(entry);
|
||||||
|
}
|
||||||
|
|
||||||
|
removed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> MultiplexedSocket for ServerEncryptedSocket<S> {
|
||||||
|
fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error> {
|
||||||
|
let mut buf = [0; 256];
|
||||||
|
if let Some(ServerPeerTransport::Connected(aes, _)) = self.peers.get(remote) {
|
||||||
|
let len = encrypt_blocked(aes, data, &mut self.buffer)?;
|
||||||
|
assert_eq!(len % 16, 0);
|
||||||
|
self.transport.send_to(
|
||||||
|
remote,
|
||||||
|
&mut buf,
|
||||||
|
&ServerNegotiationMessage::CipherText(&self.buffer[..len]),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Err(Error::NotConnected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result<MultiplexedSocketEvent<'a>, Error> {
|
||||||
|
let (message, remote) = self.transport.recv_from(&mut self.buffer)?;
|
||||||
|
|
||||||
|
let mut buf = [0; 256];
|
||||||
|
|
||||||
|
if let Some(mut state) = self.peers.get_mut(&remote) {
|
||||||
|
match (message, &mut state) {
|
||||||
|
// TODO check kex params
|
||||||
|
(
|
||||||
|
ClientNegotiationMessage::StartKex { .. },
|
||||||
|
ServerPeerTransport::PreNegotiation,
|
||||||
|
) => {
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let secret = EphemeralSecret::random_from_rng(&mut rng);
|
||||||
|
*state = ServerPeerTransport::Negotiation(secret);
|
||||||
|
self.transport.send_to(
|
||||||
|
&remote,
|
||||||
|
&mut buf,
|
||||||
|
&ServerNegotiationMessage::StartKex,
|
||||||
|
)?;
|
||||||
|
Ok(MultiplexedSocketEvent::None(remote))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ClientNegotiationMessage::PublicKey(true, data),
|
||||||
|
ServerPeerTransport::Negotiation(secret),
|
||||||
|
) if data.len() == 32 => {
|
||||||
|
let public = PublicKey::from(&*secret);
|
||||||
|
let mut remote_key = [0; 32];
|
||||||
|
remote_key.copy_from_slice(data);
|
||||||
|
let remote_key = PublicKey::from(remote_key);
|
||||||
|
state.negotiate(&remote_key);
|
||||||
|
|
||||||
|
// Send public key to the client
|
||||||
|
self.transport.send_to(
|
||||||
|
&remote,
|
||||||
|
&mut buf,
|
||||||
|
&ServerNegotiationMessage::PublicKey(true, public.as_bytes()),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(MultiplexedSocketEvent::None(remote))
|
||||||
|
}
|
||||||
|
(ClientNegotiationMessage::Agreed, ServerPeerTransport::Connected(_, _)) => {
|
||||||
|
self.transport
|
||||||
|
.send_to(&remote, &mut buf, &ServerNegotiationMessage::Agreed)?;
|
||||||
|
Ok(MultiplexedSocketEvent::None(remote))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ClientNegotiationMessage::CipherText(data),
|
||||||
|
ServerPeerTransport::Connected(aes, _),
|
||||||
|
) => {
|
||||||
|
let len = decrypt_blocked(aes, data, buffer)?;
|
||||||
|
Ok(MultiplexedSocketEvent::ClientData(remote, &buffer[..len]))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ClientNegotiationMessage::Pong,
|
||||||
|
ServerPeerTransport::Connected(_, missed)
|
||||||
|
) => {
|
||||||
|
*missed = 0;
|
||||||
|
Ok(MultiplexedSocketEvent::None(remote))
|
||||||
|
}
|
||||||
|
(ClientNegotiationMessage::Disconnect(_reason), _) => {
|
||||||
|
eprintln!("Peer disconnected: {remote}");
|
||||||
|
self.peers.remove(&remote);
|
||||||
|
Ok(MultiplexedSocketEvent::ClientDisconnected(remote))
|
||||||
|
}
|
||||||
|
// Misbehavior
|
||||||
|
_ => {
|
||||||
|
self.peers.remove(&remote);
|
||||||
|
Ok(MultiplexedSocketEvent::ClientDisconnected(remote))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let ClientNegotiationMessage::Hello { protocol: 1 } = message else {
|
||||||
|
eprintln!("Unhandled client message");
|
||||||
|
return Ok(MultiplexedSocketEvent::None(remote));
|
||||||
|
};
|
||||||
|
eprintln!("Client Hello");
|
||||||
|
self.peers
|
||||||
|
.insert(remote, ServerPeerTransport::PreNegotiation);
|
||||||
|
|
||||||
|
// Reply with hello
|
||||||
|
let hello = ServerHello {
|
||||||
|
kex_algos: &[],
|
||||||
|
symmetric_ciphersuites: &[],
|
||||||
|
};
|
||||||
|
self.transport
|
||||||
|
.send_to(&remote, &mut buf, &ServerNegotiationMessage::Hello(hello))?;
|
||||||
|
|
||||||
|
Ok(MultiplexedSocketEvent::None(remote))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> AsRawFd for ServerEncryptedSocket<S> {
|
||||||
|
fn as_raw_fd(&self) -> RawFd {
|
||||||
|
self.transport.as_raw_fd()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerPeerTransport {
|
||||||
|
fn negotiate(&mut self, public: &PublicKey) {
|
||||||
|
let Self::Negotiation(secret) = mem::replace(self, ServerPeerTransport::PreNegotiation)
|
||||||
|
else {
|
||||||
|
panic!();
|
||||||
|
};
|
||||||
|
let shared = secret.diffie_hellman(public);
|
||||||
|
let aes = Aes256::new_from_slice(shared.as_bytes()).unwrap();
|
||||||
|
*self = ServerPeerTransport::Connected(aes, 0);
|
||||||
|
}
|
||||||
|
}
|
@ -2,20 +2,15 @@
|
|||||||
#![feature(generic_const_exprs)]
|
#![feature(generic_const_exprs)]
|
||||||
#![allow(incomplete_features)]
|
#![allow(incomplete_features)]
|
||||||
|
|
||||||
use std::{
|
use std::io;
|
||||||
io,
|
|
||||||
marker::PhantomData,
|
|
||||||
net::{SocketAddr, UdpSocket},
|
|
||||||
ops::{Deref, DerefMut},
|
|
||||||
os::fd::{AsRawFd, RawFd},
|
|
||||||
};
|
|
||||||
|
|
||||||
use proto::{
|
use proto::{DecodeError, EncodeError};
|
||||||
ClientMessageProxy, Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy,
|
|
||||||
ServerMessageProxy,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub mod proto;
|
pub mod proto;
|
||||||
|
pub mod socket;
|
||||||
|
pub mod crypt;
|
||||||
|
|
||||||
|
pub use socket::{ClientSocket, ServerSocket};
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
@ -27,106 +22,12 @@ pub enum Error {
|
|||||||
Decode(#[from] DecodeError),
|
Decode(#[from] DecodeError),
|
||||||
#[error("Encode error: {0}")]
|
#[error("Encode error: {0}")]
|
||||||
Encode(#[from] EncodeError),
|
Encode(#[from] EncodeError),
|
||||||
}
|
#[error("Peer not connected")]
|
||||||
|
NotConnected,
|
||||||
pub struct SocketWrapper<Rx: MessageProxy, Tx: MessageProxy> {
|
#[error("Ping")]
|
||||||
socket: UdpSocket,
|
Ping,
|
||||||
_pd: PhantomData<(Rx, Tx)>,
|
#[error("Invalid socket state")]
|
||||||
}
|
InvalidState,
|
||||||
|
#[error("Disconnected by remote peer")]
|
||||||
pub struct ClientSocket(SocketWrapper<ServerMessageProxy, ClientMessageProxy>);
|
Disconnected,
|
||||||
pub struct ServerSocket(SocketWrapper<ClientMessageProxy, ServerMessageProxy>);
|
|
||||||
|
|
||||||
impl<Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<Rx, Tx> {
|
|
||||||
pub fn new(socket: UdpSocket) -> Self {
|
|
||||||
Self {
|
|
||||||
socket,
|
|
||||||
_pd: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send(&self, buffer: &mut [u8], message: &Tx::Type<'_>) -> Result<(), Error> {
|
|
||||||
let mut enc = Encoder::new(buffer);
|
|
||||||
message.encode(&mut enc)?;
|
|
||||||
let message = enc.get();
|
|
||||||
let amount = self.socket.send(message)?;
|
|
||||||
if amount == message.len() {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(Error::Truncated)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_to(
|
|
||||||
&mut self,
|
|
||||||
remote: &SocketAddr,
|
|
||||||
buffer: &mut [u8],
|
|
||||||
message: &Tx::Type<'_>,
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
let mut enc = Encoder::new(buffer);
|
|
||||||
message.encode(&mut enc)?;
|
|
||||||
let message = enc.get();
|
|
||||||
let amount = self.socket.send_to(message, remote)?;
|
|
||||||
if amount == message.len() {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(Error::Truncated)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn recv_from<'de>(
|
|
||||||
&mut self,
|
|
||||||
buffer: &'de mut [u8],
|
|
||||||
) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
|
|
||||||
let (len, remote) = self.socket.recv_from(buffer)?;
|
|
||||||
let mut dec = Decoder::new(&buffer[..len]);
|
|
||||||
let message = Rx::Type::<'de>::decode(&mut dec)?;
|
|
||||||
Ok((message, remote))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<Rx, Tx> {
|
|
||||||
fn as_raw_fd(&self) -> RawFd {
|
|
||||||
self.socket.as_raw_fd()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ClientSocket {
|
|
||||||
pub fn new(socket: UdpSocket) -> Self {
|
|
||||||
Self(SocketWrapper::new(socket))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Deref for ClientSocket {
|
|
||||||
type Target = SocketWrapper<ServerMessageProxy, ClientMessageProxy>;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DerefMut for ClientSocket {
|
|
||||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
||||||
&mut self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ServerSocket {
|
|
||||||
pub fn new(socket: UdpSocket) -> Self {
|
|
||||||
Self(SocketWrapper::new(socket))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Deref for ServerSocket {
|
|
||||||
type Target = SocketWrapper<ClientMessageProxy, ServerMessageProxy>;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DerefMut for ServerSocket {
|
|
||||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
||||||
&mut self.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -12,8 +12,7 @@ use clap::Parser;
|
|||||||
use cross::io::Poll;
|
use cross::io::Poll;
|
||||||
use libterm::{RawMode, RawTerminal};
|
use libterm::{RawMode, RawTerminal};
|
||||||
use rsh::{
|
use rsh::{
|
||||||
proto::{ClientMessage, ServerMessage, TerminalInfo},
|
crypt::ClientEncryptedSocket, proto::{ClientMessage, ServerMessage, TerminalInfo}, socket::{MessageSocket, PacketSocket}, ClientSocket
|
||||||
ClientSocket,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const PING_TIMEOUT: Duration = Duration::from_secs(3);
|
pub const PING_TIMEOUT: Duration = Duration::from_secs(3);
|
||||||
@ -39,11 +38,10 @@ struct Args {
|
|||||||
|
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
poll: Poll,
|
poll: Poll,
|
||||||
socket: ClientSocket,
|
socket: ClientSocket<ClientEncryptedSocket<UdpSocket>>,
|
||||||
stdin: Stdin,
|
stdin: Stdin,
|
||||||
stdout: Stdout,
|
stdout: Stdout,
|
||||||
need_bye: bool,
|
need_bye: bool,
|
||||||
last_ping: Instant,
|
|
||||||
_raw: RawMode,
|
_raw: RawMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +49,6 @@ pub enum Event<'b> {
|
|||||||
Stdin(&'b [u8]),
|
Stdin(&'b [u8]),
|
||||||
Data(&'b [u8]),
|
Data(&'b [u8]),
|
||||||
Disconnected(&'b str),
|
Disconnected(&'b str),
|
||||||
Ping,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
@ -69,6 +66,8 @@ impl Client {
|
|||||||
let socket = UdpSocket::bind(local)?;
|
let socket = UdpSocket::bind(local)?;
|
||||||
socket.connect(remote)?;
|
socket.connect(remote)?;
|
||||||
|
|
||||||
|
let mut socket = ClientEncryptedSocket::new(socket);
|
||||||
|
socket.try_connect_blocking()?;
|
||||||
let mut socket = ClientSocket::new(socket);
|
let mut socket = ClientSocket::new(socket);
|
||||||
|
|
||||||
poll.add(&*socket)?;
|
poll.add(&*socket)?;
|
||||||
@ -82,7 +81,6 @@ impl Client {
|
|||||||
let _raw = unsafe { RawMode::enter(&stdin)? };
|
let _raw = unsafe { RawMode::enter(&stdin)? };
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
last_ping: Instant::now(),
|
|
||||||
poll,
|
poll,
|
||||||
socket,
|
socket,
|
||||||
stdout,
|
stdout,
|
||||||
@ -102,7 +100,11 @@ impl Client {
|
|||||||
|
|
||||||
match event {
|
match event {
|
||||||
fd if fd == self.socket.as_raw_fd() => {
|
fd if fd == self.socket.as_raw_fd() => {
|
||||||
let (message, _) = self.socket.recv_from(buffer)?;
|
let message = match self.socket.recv_from(buffer) {
|
||||||
|
Ok((message, _)) => message,
|
||||||
|
Err(rsh::Error::Ping) => return Ok(None),
|
||||||
|
Err(error) => return Err(error.into())
|
||||||
|
};
|
||||||
match message {
|
match message {
|
||||||
ServerMessage::Bye(reason) => {
|
ServerMessage::Bye(reason) => {
|
||||||
// No need for a bye
|
// No need for a bye
|
||||||
@ -112,9 +114,6 @@ impl Client {
|
|||||||
ServerMessage::Output(data) => {
|
ServerMessage::Output(data) => {
|
||||||
break Ok(Some(Event::Data(data)));
|
break Ok(Some(Event::Data(data)));
|
||||||
}
|
}
|
||||||
ServerMessage::Ping => {
|
|
||||||
break Ok(Some(Event::Ping));
|
|
||||||
}
|
|
||||||
// Ignore this one
|
// Ignore this one
|
||||||
ServerMessage::Hello => break Ok(None),
|
ServerMessage::Hello => break Ok(None),
|
||||||
}
|
}
|
||||||
@ -148,10 +147,6 @@ impl Client {
|
|||||||
Event::Stdin(data) => {
|
Event::Stdin(data) => {
|
||||||
self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
|
self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
|
||||||
}
|
}
|
||||||
Event::Ping => {
|
|
||||||
self.last_ping = Instant::now();
|
|
||||||
self.socket.send(&mut send_buf, &ClientMessage::Pong)?;
|
|
||||||
}
|
|
||||||
Event::Disconnected(reason) => {
|
Event::Disconnected(reason) => {
|
||||||
break Ok(reason.into());
|
break Ok(reason.into());
|
||||||
}
|
}
|
||||||
@ -161,16 +156,16 @@ impl Client {
|
|||||||
|
|
||||||
fn check_timeout(&mut self) -> Result<(), Error> {
|
fn check_timeout(&mut self) -> Result<(), Error> {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
if now - self.last_ping >= PING_TIMEOUT {
|
if now - self.socket.as_inner_mut().last_ping() >= PING_TIMEOUT {
|
||||||
Err(Error::Timeout)
|
Err(Error::Timeout)
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handshake(
|
fn handshake<S: PacketSocket>(
|
||||||
poll: &mut Poll,
|
poll: &mut Poll,
|
||||||
socket: &mut ClientSocket,
|
socket: &mut ClientSocket<S>,
|
||||||
terminal: TerminalInfo,
|
terminal: TerminalInfo,
|
||||||
attempts: usize,
|
attempts: usize,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
@ -187,9 +182,9 @@ impl Client {
|
|||||||
Self::try_handshake(poll, socket, terminal, timeout)
|
Self::try_handshake(poll, socket, terminal, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_handshake(
|
fn try_handshake<S: PacketSocket>(
|
||||||
poll: &mut Poll,
|
poll: &mut Poll,
|
||||||
socket: &mut ClientSocket,
|
socket: &mut ClientSocket<S>,
|
||||||
terminal: TerminalInfo,
|
terminal: TerminalInfo,
|
||||||
timeout: Duration,
|
timeout: Duration,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
|
@ -11,7 +11,6 @@ pub struct ServerMessageProxy;
|
|||||||
pub enum ClientMessage<'a> {
|
pub enum ClientMessage<'a> {
|
||||||
Hello(TerminalInfo),
|
Hello(TerminalInfo),
|
||||||
Bye(&'a str),
|
Bye(&'a str),
|
||||||
Pong,
|
|
||||||
Input(&'a [u8]),
|
Input(&'a [u8]),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -19,7 +18,6 @@ pub enum ClientMessage<'a> {
|
|||||||
pub enum ServerMessage<'a> {
|
pub enum ServerMessage<'a> {
|
||||||
Hello,
|
Hello,
|
||||||
Bye(&'a str),
|
Bye(&'a str),
|
||||||
Ping,
|
|
||||||
Output(&'a [u8]),
|
Output(&'a [u8]),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,6 +66,10 @@ impl<'a> Encoder<'a> {
|
|||||||
Self { buffer, pos: 0 }
|
Self { buffer, pos: 0 }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.pos = 0;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
|
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
|
||||||
let len: u32 = bytes
|
let len: u32 = bytes
|
||||||
.len()
|
.len()
|
||||||
@ -145,7 +147,6 @@ impl MessageProxy for ServerMessageProxy {
|
|||||||
impl ClientMessage<'_> {
|
impl ClientMessage<'_> {
|
||||||
const TAG_HELLO: u8 = 0x80;
|
const TAG_HELLO: u8 = 0x80;
|
||||||
const TAG_BYE: u8 = 0x81;
|
const TAG_BYE: u8 = 0x81;
|
||||||
const TAG_PONG: u8 = 0x82;
|
|
||||||
const TAG_INPUT: u8 = 0x90;
|
const TAG_INPUT: u8 = 0x90;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,9 +177,6 @@ impl<'a> Encode for ClientMessage<'a> {
|
|||||||
buffer.write(&[Self::TAG_BYE])?;
|
buffer.write(&[Self::TAG_BYE])?;
|
||||||
buffer.write_str(reason)
|
buffer.write_str(reason)
|
||||||
}
|
}
|
||||||
Self::Pong => {
|
|
||||||
buffer.write(&[Self::TAG_PONG])
|
|
||||||
}
|
|
||||||
Self::Input(data) => {
|
Self::Input(data) => {
|
||||||
buffer.write(&[Self::TAG_INPUT])?;
|
buffer.write(&[Self::TAG_INPUT])?;
|
||||||
buffer.write_variable_bytes(data)
|
buffer.write_variable_bytes(data)
|
||||||
@ -197,9 +195,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
|
|||||||
Self::TAG_BYE => {
|
Self::TAG_BYE => {
|
||||||
buffer.read_str().map(Self::Bye)
|
buffer.read_str().map(Self::Bye)
|
||||||
}
|
}
|
||||||
Self::TAG_PONG => {
|
|
||||||
Ok(Self::Pong)
|
|
||||||
}
|
|
||||||
Self::TAG_INPUT => {
|
Self::TAG_INPUT => {
|
||||||
buffer.read_variable_bytes().map(Self::Input)
|
buffer.read_variable_bytes().map(Self::Input)
|
||||||
}
|
}
|
||||||
@ -211,7 +206,6 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
|
|||||||
impl ServerMessage<'_> {
|
impl ServerMessage<'_> {
|
||||||
const TAG_HELLO: u8 = 0x10;
|
const TAG_HELLO: u8 = 0x10;
|
||||||
const TAG_BYE: u8 = 0x11;
|
const TAG_BYE: u8 = 0x11;
|
||||||
const TAG_PING: u8 = 0x12;
|
|
||||||
const TAG_OUTPUT: u8 = 0x20;
|
const TAG_OUTPUT: u8 = 0x20;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -223,8 +217,6 @@ impl<'a> Encode for ServerMessage<'a> {
|
|||||||
buffer.write(&[Self::TAG_BYE])?;
|
buffer.write(&[Self::TAG_BYE])?;
|
||||||
buffer.write_str(reason)
|
buffer.write_str(reason)
|
||||||
}
|
}
|
||||||
// TODO sequence number
|
|
||||||
Self::Ping => buffer.write(&[Self::TAG_PING]),
|
|
||||||
Self::Output(data) => {
|
Self::Output(data) => {
|
||||||
buffer.write(&[Self::TAG_OUTPUT])?;
|
buffer.write(&[Self::TAG_OUTPUT])?;
|
||||||
buffer.write_variable_bytes(data)
|
buffer.write_variable_bytes(data)
|
||||||
@ -241,7 +233,6 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
|
|||||||
Self::TAG_BYE => {
|
Self::TAG_BYE => {
|
||||||
buffer.read_str().map(Self::Bye)
|
buffer.read_str().map(Self::Bye)
|
||||||
}
|
}
|
||||||
Self::TAG_PING => Ok(Self::Ping),
|
|
||||||
Self::TAG_OUTPUT => {
|
Self::TAG_OUTPUT => {
|
||||||
buffer.read_variable_bytes().map(Self::Output)
|
buffer.read_variable_bytes().map(Self::Output)
|
||||||
},
|
},
|
||||||
|
@ -13,8 +13,7 @@ use std::{
|
|||||||
|
|
||||||
use cross::io::{Poll, TimerFd};
|
use cross::io::{Poll, TimerFd};
|
||||||
use rsh::{
|
use rsh::{
|
||||||
proto::{ClientMessage, ServerMessage, TerminalInfo},
|
crypt::ServerEncryptedSocket, proto::{ClientMessage, Decode, Decoder, Encoder, ServerMessage, TerminalInfo}, socket::{MessageSocket, MultiplexedSocket, MultiplexedSocketEvent}, Error, ServerSocket
|
||||||
Error, ServerSocket,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
|
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
|
||||||
@ -29,7 +28,7 @@ pub struct Session {
|
|||||||
pub struct Server {
|
pub struct Server {
|
||||||
poll: Poll,
|
poll: Poll,
|
||||||
timer: TimerFd,
|
timer: TimerFd,
|
||||||
socket: ServerSocket,
|
socket: ServerEncryptedSocket<UdpSocket>,
|
||||||
|
|
||||||
addr_to_session: HashMap<SocketAddr, RawFd>,
|
addr_to_session: HashMap<SocketAddr, RawFd>,
|
||||||
pty_to_session: HashMap<RawFd, Session>,
|
pty_to_session: HashMap<RawFd, Session>,
|
||||||
@ -98,8 +97,8 @@ impl Server {
|
|||||||
pub fn new(listen_addr: SocketAddr) -> Result<Self, Error> {
|
pub fn new(listen_addr: SocketAddr) -> Result<Self, Error> {
|
||||||
let mut poll = Poll::new()?;
|
let mut poll = Poll::new()?;
|
||||||
let timer = TimerFd::new()?;
|
let timer = TimerFd::new()?;
|
||||||
let socket = UdpSocket::bind(listen_addr).map(ServerSocket::new)?;
|
let socket = UdpSocket::bind(listen_addr).map(ServerEncryptedSocket::new)?;
|
||||||
poll.add(&*socket)?;
|
poll.add(&socket)?;
|
||||||
poll.add(&timer)?;
|
poll.add(&timer)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
poll,
|
poll,
|
||||||
@ -115,13 +114,27 @@ impl Server {
|
|||||||
|
|
||||||
match fd {
|
match fd {
|
||||||
fd if fd == self.socket.as_raw_fd() => {
|
fd if fd == self.socket.as_raw_fd() => {
|
||||||
let (message, remote) = match self.socket.recv_from(buffer) {
|
let event = self.socket.recv_from(buffer)?;
|
||||||
Ok((message, remote)) => (message, remote),
|
|
||||||
Err(error @ (Error::Decode(_) | Error::Truncated)) => {
|
let (message, remote) = match event {
|
||||||
eprintln!("Receive error: {error}");
|
MultiplexedSocketEvent::ClientDisconnected(remote) => {
|
||||||
|
self.remove_session_by_remote(remote).ok();
|
||||||
return Ok(None)
|
return Ok(None)
|
||||||
},
|
}
|
||||||
Err(error) => return Err(error),
|
MultiplexedSocketEvent::None(_) => return Ok(None),
|
||||||
|
MultiplexedSocketEvent::ClientData(peer, data) => {
|
||||||
|
let mut decoder = Decoder::new(data);
|
||||||
|
let message = ClientMessage::decode(&mut decoder);
|
||||||
|
(message, peer)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let message = match message {
|
||||||
|
Ok(message) => message,
|
||||||
|
Err(error) => {
|
||||||
|
eprintln!("Decode error: {error}");
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let event = match message {
|
let event = match message {
|
||||||
@ -136,11 +149,6 @@ impl Server {
|
|||||||
{
|
{
|
||||||
Event::SessionInput(*fd, remote, data)
|
Event::SessionInput(*fd, remote, data)
|
||||||
}
|
}
|
||||||
ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => {
|
|
||||||
let session = self.pty_to_session.get_mut(fd).unwrap();
|
|
||||||
session.timeouts = 0;
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
_ => return Ok(None),
|
_ => return Ok(None),
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -176,7 +184,7 @@ impl Server {
|
|||||||
eprintln!("PTY write error: {error}");
|
eprintln!("PTY write error: {error}");
|
||||||
self.remove_session_by_fd(fd)?;
|
self.remove_session_by_fd(fd)?;
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,13 +198,13 @@ impl Server {
|
|||||||
Ok(session) => {
|
Ok(session) => {
|
||||||
self.register_session(remote, session)?;
|
self.register_session(remote, session)?;
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &mut send_buf, &ServerMessage::Hello)
|
.send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
eprintln!("PTY open error: {err}");
|
eprintln!("PTY open error: {err}");
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(
|
.send_message_to(
|
||||||
&remote,
|
&remote,
|
||||||
&mut send_buf,
|
&mut send_buf,
|
||||||
&ServerMessage::Bye("PTY open error"),
|
&ServerMessage::Bye("PTY open error"),
|
||||||
@ -208,48 +216,37 @@ impl Server {
|
|||||||
Event::Pty(fd, remote, event) => match event {
|
Event::Pty(fd, remote, event) => match event {
|
||||||
PtyEvent::Data(data) => {
|
PtyEvent::Data(data) => {
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &mut send_buf, &ServerMessage::Output(data))
|
.send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data))
|
||||||
.ok();
|
.ok();
|
||||||
},
|
},
|
||||||
PtyEvent::Err(error) => {
|
PtyEvent::Err(error) => {
|
||||||
eprintln!("PTY read error: {error}");
|
eprintln!("PTY read error: {error}");
|
||||||
self.remove_session_by_fd(fd)?;
|
self.remove_session_by_fd(fd)?;
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||||
.ok();
|
.ok();
|
||||||
},
|
},
|
||||||
PtyEvent::Closed => {
|
PtyEvent::Closed => {
|
||||||
println!("End of PTY for {remote}");
|
println!("End of PTY for {remote}");
|
||||||
self.remove_session_by_fd(fd)?;
|
self.remove_session_by_fd(fd)?;
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
|
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
|
||||||
.ok();
|
.ok();
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Event::Tick => {
|
Event::Tick => {
|
||||||
// Restart the timer
|
// Restart the timer
|
||||||
self.update_client_timeouts(&mut send_buf)?;
|
self.update_client_timeouts()?;
|
||||||
self.timer.start(PING_INTERVAL)?;
|
self.timer.start(PING_INTERVAL)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_client_timeouts(&mut self, send_buf: &mut [u8]) -> Result<(), Error> {
|
fn update_client_timeouts(&mut self) -> Result<(), Error> {
|
||||||
let mut removed = vec![];
|
let removed = self.socket.ping_clients(8);
|
||||||
for (remote, fd) in self.addr_to_session.iter() {
|
for entry in removed {
|
||||||
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
self.remove_session_by_remote(entry).ok();
|
||||||
self.socket.send_to(remote, send_buf, &ServerMessage::Ping).ok();
|
|
||||||
session.timeouts += 1;
|
|
||||||
if session.timeouts >= 10 {
|
|
||||||
removed.push((*remote, *fd));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (remote, fd) in removed {
|
|
||||||
eprintln!("Client {remote} timed out");
|
|
||||||
self.remove_session_by_fd(fd)?;
|
|
||||||
self.socket.send_to(&remote, send_buf, &ServerMessage::Bye("Timed out".into())).ok();
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -268,6 +265,7 @@ impl Server {
|
|||||||
// NOTE: this will block the whole server while the process finishes.
|
// NOTE: this will block the whole server while the process finishes.
|
||||||
session.shell.wait().ok();
|
session.shell.wait().ok();
|
||||||
self.addr_to_session.remove(&session.remote).unwrap();
|
self.addr_to_session.remove(&session.remote).unwrap();
|
||||||
|
self.socket.remove_client(&session.remote);
|
||||||
self.poll.remove(&fd)?;
|
self.poll.remove(&fd)?;
|
||||||
Ok(Some(session))
|
Ok(Some(session))
|
||||||
} else {
|
} else {
|
||||||
|
202
userspace/rsh/src/socket.rs
Normal file
202
userspace/rsh/src/socket.rs
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
use std::{
|
||||||
|
io,
|
||||||
|
marker::PhantomData,
|
||||||
|
net::{SocketAddr, UdpSocket},
|
||||||
|
ops::{Deref, DerefMut},
|
||||||
|
os::fd::{AsRawFd, RawFd},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
proto::{
|
||||||
|
ClientMessageProxy, Decode, Decoder, Encode, Encoder, MessageProxy, ServerMessageProxy,
|
||||||
|
},
|
||||||
|
Error,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct SocketWrapper<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> {
|
||||||
|
socket: S,
|
||||||
|
_pd: PhantomData<(Rx, Tx)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait MessageSocket<Rx: MessageProxy, Tx: MessageProxy>: AsRawFd {
|
||||||
|
fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error>;
|
||||||
|
fn send_to(
|
||||||
|
&mut self,
|
||||||
|
addr: &SocketAddr,
|
||||||
|
buffer: &mut [u8],
|
||||||
|
data: &Tx::Type<'_>,
|
||||||
|
) -> Result<(), Error>;
|
||||||
|
fn recv_from<'de>(
|
||||||
|
&mut self,
|
||||||
|
buffer: &'de mut [u8],
|
||||||
|
) -> Result<(Rx::Type<'de>, SocketAddr), Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait PacketSocket: AsRawFd {
|
||||||
|
type Error: Into<Error>;
|
||||||
|
|
||||||
|
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error>;
|
||||||
|
fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error>;
|
||||||
|
fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum MultiplexedSocketEvent<'a> {
|
||||||
|
ClientData(SocketAddr, &'a [u8]),
|
||||||
|
ClientDisconnected(SocketAddr),
|
||||||
|
None(SocketAddr),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait MultiplexedSocket: AsRawFd {
|
||||||
|
fn recv_from<'a>(&mut self, buffer: &'a mut [u8]) -> Result<MultiplexedSocketEvent<'a>, Error>;
|
||||||
|
fn send_to(&mut self, remote: &SocketAddr, data: &[u8]) -> Result<(), Error>;
|
||||||
|
|
||||||
|
fn send_message_to<Tx: Encode>(
|
||||||
|
&mut self,
|
||||||
|
remote: &SocketAddr,
|
||||||
|
buffer: &mut [u8],
|
||||||
|
data: &Tx,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let mut encoder = Encoder::new(buffer);
|
||||||
|
data.encode(&mut encoder)?;
|
||||||
|
self.send_to(remote, encoder.get())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PacketSocket for UdpSocket {
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn send(&mut self, data: &[u8]) -> Result<usize, Self::Error> {
|
||||||
|
UdpSocket::send(self, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send_to(&mut self, data: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error> {
|
||||||
|
UdpSocket::send_to(self, data, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recv_from(&mut self, data: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
|
||||||
|
UdpSocket::recv_from(self, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClientSocket<S: PacketSocket>(SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>);
|
||||||
|
pub struct ServerSocket<S: MultiplexedSocket>(S);
|
||||||
|
|
||||||
|
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> MessageSocket<Rx, Tx>
|
||||||
|
for SocketWrapper<S, Rx, Tx>
|
||||||
|
{
|
||||||
|
fn send(&mut self, buffer: &mut [u8], data: &Tx::Type<'_>) -> Result<(), Error> {
|
||||||
|
let mut enc = Encoder::new(buffer);
|
||||||
|
data.encode(&mut enc)?;
|
||||||
|
let message = enc.get();
|
||||||
|
let amount = self.socket.send(message).map_err(S::Error::into)?;
|
||||||
|
if amount == message.len() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(Error::Truncated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send_to(
|
||||||
|
&mut self,
|
||||||
|
addr: &SocketAddr,
|
||||||
|
buffer: &mut [u8],
|
||||||
|
data: &Tx::Type<'_>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let mut enc = Encoder::new(buffer);
|
||||||
|
data.encode(&mut enc)?;
|
||||||
|
let message = enc.get();
|
||||||
|
let amount = self.socket.send_to(message, addr).map_err(S::Error::into)?;
|
||||||
|
if amount == message.len() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(Error::Truncated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recv_from<'de>(
|
||||||
|
&mut self,
|
||||||
|
buffer: &'de mut [u8],
|
||||||
|
) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
|
||||||
|
let (len, remote) = self.socket.recv_from(buffer).map_err(S::Error::into)?;
|
||||||
|
let mut dec = Decoder::new(&buffer[..len]);
|
||||||
|
let message = Rx::Type::<'de>::decode(&mut dec)?;
|
||||||
|
Ok((message, remote))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<S, Rx, Tx> {
|
||||||
|
pub fn new(socket: S) -> Self {
|
||||||
|
Self {
|
||||||
|
socket,
|
||||||
|
_pd: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_inner(self) -> S {
|
||||||
|
self.socket
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_inner_mut(&mut self) -> &mut S {
|
||||||
|
&mut self.socket
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket, Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<S, Rx, Tx> {
|
||||||
|
fn as_raw_fd(&self) -> RawFd {
|
||||||
|
self.socket.as_raw_fd()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> ClientSocket<S> {
|
||||||
|
pub fn new(socket: S) -> Self {
|
||||||
|
Self(SocketWrapper::new(socket))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> Deref for ClientSocket<S> {
|
||||||
|
type Target = SocketWrapper<S, ServerMessageProxy, ClientMessageProxy>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: PacketSocket> DerefMut for ClientSocket<S> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: MultiplexedSocket> ServerSocket<S> {
|
||||||
|
pub fn new(socket: S) -> Self {
|
||||||
|
Self(socket)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: MultiplexedSocket> Deref for ServerSocket<S> {
|
||||||
|
type Target = S;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: MultiplexedSocket> DerefMut for ServerSocket<S> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// impl<S: PacketSocket> Deref for ServerSocket<S> {
|
||||||
|
// type Target = SocketWrapper<S, ClientMessageProxy, ServerMessageProxy>;
|
||||||
|
//
|
||||||
|
// fn deref(&self) -> &Self::Target {
|
||||||
|
// &self.0
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// impl<S: PacketSocket> DerefMut for ServerSocket<S> {
|
||||||
|
// fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
// &mut self.0
|
||||||
|
// }
|
||||||
|
// }
|
Loading…
x
Reference in New Issue
Block a user