rsh: better protocol handling

This commit is contained in:
Mark Poliakov 2024-11-01 18:44:41 +02:00
parent 2d9cc793e0
commit 3e605b3b11
6 changed files with 397 additions and 153 deletions

6
userspace/Cargo.lock generated
View File

@ -1014,10 +1014,7 @@ dependencies = [
"bytemuck", "bytemuck",
"clap", "clap",
"cross", "cross",
"flexbuffers",
"libterm", "libterm",
"serde",
"smallvec",
"thiserror", "thiserror",
] ]
@ -1200,9 +1197,6 @@ name = "smallvec"
version = "1.13.2" version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "spin" name = "spin"

View File

@ -9,11 +9,7 @@ path = "src/rshd/main.rs"
[dependencies] [dependencies]
clap.workspace = true clap.workspace = true
flexbuffers.workspace = true
libterm.workspace = true libterm.workspace = true
serde.workspace = true
thiserror.workspace = true thiserror.workspace = true
cross.workspace = true cross.workspace = true
bytemuck.workspace = true bytemuck.workspace = true
smallvec = { version = "1.13.2", features = ["serde"] }

View File

@ -1,4 +1,6 @@
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))] #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
#![feature(generic_const_exprs)]
#![allow(incomplete_features)]
use std::{ use std::{
io, io,
@ -8,44 +10,46 @@ use std::{
os::fd::{AsRawFd, RawFd}, os::fd::{AsRawFd, RawFd},
}; };
use proto::{ClientMessage, ServerMessage}; use proto::{
use serde::{de::DeserializeOwned, Deserialize, Serialize}; ClientMessageProxy, Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy,
ServerMessageProxy,
};
pub mod proto; pub mod proto;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum Error { pub enum Error {
#[error("Deserialization error: {0}")]
Deserialization(#[from] flexbuffers::DeserializationError),
#[error("Serialization error: {0}")]
Serialization(#[from] flexbuffers::SerializationError),
#[error("I/O error")] #[error("I/O error")]
Io(#[from] io::Error), Io(#[from] io::Error),
#[error("Could not send a message fully")] #[error("Could not send a message fully")]
Truncated, Truncated,
#[error("Decode error: {0}")]
Decode(#[from] DecodeError),
#[error("Encode error: {0}")]
Encode(#[from] EncodeError),
} }
pub struct SocketWrapper<Rx, Tx: Serialize> { pub struct SocketWrapper<Rx: MessageProxy, Tx: MessageProxy> {
socket: UdpSocket, socket: UdpSocket,
buffer: [u8; 256],
_pd: PhantomData<(Rx, Tx)>, _pd: PhantomData<(Rx, Tx)>,
} }
pub struct ClientSocket(SocketWrapper<ServerMessage, ClientMessage>); pub struct ClientSocket(SocketWrapper<ServerMessageProxy, ClientMessageProxy>);
pub struct ServerSocket(SocketWrapper<ClientMessage, ServerMessage>); pub struct ServerSocket(SocketWrapper<ClientMessageProxy, ServerMessageProxy>);
impl<Rx, Tx: Serialize> SocketWrapper<Rx, Tx> { impl<Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<Rx, Tx> {
pub fn new(socket: UdpSocket) -> Self { pub fn new(socket: UdpSocket) -> Self {
Self { Self {
socket, socket,
buffer: [0; 256],
_pd: PhantomData, _pd: PhantomData,
} }
} }
pub fn send(&mut self, message: &Tx) -> Result<(), Error> { pub fn send(&self, buffer: &mut [u8], message: &Tx::Type<'_>) -> Result<(), Error> {
let message = flexbuffers::to_vec(message)?; let mut enc = Encoder::new(buffer);
let amount = self.socket.send(&message)?; message.encode(&mut enc)?;
let message = enc.get();
let amount = self.socket.send(message)?;
if amount == message.len() { if amount == message.len() {
Ok(()) Ok(())
} else { } else {
@ -53,9 +57,16 @@ impl<Rx, Tx: Serialize> SocketWrapper<Rx, Tx> {
} }
} }
pub fn send_to(&mut self, remote: &SocketAddr, message: &Tx) -> Result<(), Error> { pub fn send_to(
let message = flexbuffers::to_vec(message)?; &mut self,
let amount = self.socket.send_to(&message, remote)?; remote: &SocketAddr,
buffer: &mut [u8],
message: &Tx::Type<'_>,
) -> Result<(), Error> {
let mut enc = Encoder::new(buffer);
message.encode(&mut enc)?;
let message = enc.get();
let amount = self.socket.send_to(message, remote)?;
if amount == message.len() { if amount == message.len() {
Ok(()) Ok(())
} else { } else {
@ -63,26 +74,18 @@ impl<Rx, Tx: Serialize> SocketWrapper<Rx, Tx> {
} }
} }
pub fn recv_from(&mut self) -> Result<(Rx, SocketAddr), Error> pub fn recv_from<'de>(
where &mut self,
Rx: DeserializeOwned, buffer: &'de mut [u8],
{ ) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
let (len, remote) = self.socket.recv_from(&mut self.buffer)?;
let message = flexbuffers::from_slice(&self.buffer[..len])?;
Ok((message, remote))
}
pub fn recv_from_with<'de>(&mut self, buffer: &'de mut [u8]) -> Result<(Rx, SocketAddr), Error>
where
Rx: Deserialize<'de> + 'de,
{
let (len, remote) = self.socket.recv_from(buffer)?; let (len, remote) = self.socket.recv_from(buffer)?;
let message = flexbuffers::from_slice(&buffer[..len])?; let mut dec = Decoder::new(&buffer[..len]);
let message = Rx::Type::<'de>::decode(&mut dec)?;
Ok((message, remote)) Ok((message, remote))
} }
} }
impl<Rx, Tx: Serialize> AsRawFd for SocketWrapper<Rx, Tx> { impl<Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<Rx, Tx> {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd() self.socket.as_raw_fd()
} }
@ -95,7 +98,7 @@ impl ClientSocket {
} }
impl Deref for ClientSocket { impl Deref for ClientSocket {
type Target = SocketWrapper<ServerMessage, ClientMessage>; type Target = SocketWrapper<ServerMessageProxy, ClientMessageProxy>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
@ -115,7 +118,7 @@ impl ServerSocket {
} }
impl Deref for ServerSocket { impl Deref for ServerSocket {
type Target = SocketWrapper<ClientMessage, ServerMessage>; type Target = SocketWrapper<ClientMessageProxy, ServerMessageProxy>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0

View File

@ -12,7 +12,7 @@ use clap::Parser;
use cross::io::Poll; use cross::io::Poll;
use libterm::{RawMode, RawTerminal}; use libterm::{RawMode, RawTerminal};
use rsh::{ use rsh::{
proto::{ClientData, ClientMessage, ServerData, ServerMessage, TerminalInfo}, proto::{ClientMessage, ServerMessage, TerminalInfo},
ClientSocket, ClientSocket,
}; };
@ -44,13 +44,14 @@ pub struct Client {
stdout: Stdout, stdout: Stdout,
need_bye: bool, need_bye: bool,
last_ping: Instant, last_ping: Instant,
_raw: RawMode _raw: RawMode,
} }
pub enum Event { pub enum Event<'b> {
Stdin(ClientData), Stdin(&'b [u8]),
Data(ServerData), Data(&'b [u8]),
Disconnected(String), Disconnected(&'b str),
Ping,
} }
impl Client { impl Client {
@ -91,8 +92,8 @@ impl Client {
}) })
} }
pub fn poll(&mut self) -> Result<Event, Error> { pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result<Option<Event<'b>>, Error> {
let mut buf = [0; 16]; // let mut buf = [0; 16];
let event = loop { let event = loop {
let Some(event) = self.poll.wait(Some(Duration::from_millis(500)))? else { let Some(event) = self.poll.wait(Some(Duration::from_millis(500)))? else {
self.check_timeout()?; self.check_timeout()?;
@ -101,29 +102,26 @@ impl Client {
match event { match event {
fd if fd == self.socket.as_raw_fd() => { fd if fd == self.socket.as_raw_fd() => {
let (message, _) = self.socket.recv_from()?; let (message, _) = self.socket.recv_from(buffer)?;
match message { match message {
ServerMessage::Bye(reason) => { ServerMessage::Bye(reason) => {
// No need for a bye // No need for a bye
self.need_bye = false; self.need_bye = false;
break Ok(Event::Disconnected(reason)); break Ok(Some(Event::Disconnected(reason)));
} }
ServerMessage::Output(data) => { ServerMessage::Output(data) => {
break Ok(Event::Data(data)); break Ok(Some(Event::Data(data)));
} }
ServerMessage::Ping => { ServerMessage::Ping => {
self.last_ping = Instant::now(); break Ok(Some(Event::Ping));
self.socket.send(&ClientMessage::Pong).ok();
continue;
} }
// Ignore this one // Ignore this one
ServerMessage::Hello => continue, ServerMessage::Hello => break Ok(None),
} }
} }
fd if fd == self.stdin.as_raw_fd() => { fd if fd == self.stdin.as_raw_fd() => {
let len = self.stdin.read(&mut buf)?; let len = self.stdin.read(&mut buffer[..pty_max])?;
let data = ClientData::from_slice(&buf[..len]); break Ok(Some(Event::Stdin(&buffer[..len])));
break Ok(Event::Stdin(data));
} }
_ => unreachable!() _ => unreachable!()
} }
@ -135,19 +133,27 @@ impl Client {
} }
pub fn run(mut self) -> Result<String, Error> { pub fn run(mut self) -> Result<String, Error> {
let mut recv_buf = [0; 256];
let mut send_buf = [0; 256];
loop { loop {
let event = self.poll()?; let Some(event) = self.poll(&mut recv_buf, 64)? else {
continue;
};
match event { match event {
Event::Data(data) => { Event::Data(data) => {
self.stdout.write_all(&data).ok(); self.stdout.write_all(data)?;
self.stdout.flush().ok(); self.stdout.flush()?;
} }
Event::Stdin(data) => { Event::Stdin(data) => {
self.socket.send(&ClientMessage::Input(data))?; self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
}
Event::Ping => {
self.last_ping = Instant::now();
self.socket.send(&mut send_buf, &ClientMessage::Pong)?;
} }
Event::Disconnected(reason) => { Event::Disconnected(reason) => {
break Ok(reason); break Ok(reason.into());
} }
} }
} }
@ -187,15 +193,16 @@ impl Client {
terminal: TerminalInfo, terminal: TerminalInfo,
timeout: Duration, timeout: Duration,
) -> Result<(), Error> { ) -> Result<(), Error> {
socket.send(&ClientMessage::Hello(terminal))?; let mut buffer = [0; 512];
socket.send(&mut buffer, &ClientMessage::Hello(terminal))?;
if poll.wait(Some(timeout))?.is_none() { if poll.wait(Some(timeout))?.is_none() {
return Err(Error::Timeout); return Err(Error::Timeout);
}; };
let (message, _) = socket.recv_from()?; let (message, _) = socket.recv_from(&mut buffer)?;
match message { match message {
ServerMessage::Hello => Ok(()), ServerMessage::Hello => Ok(()),
ServerMessage::Bye(reason) => Err(Error::Disconnected(reason)), ServerMessage::Bye(reason) => Err(Error::Disconnected(reason.into())),
_ => Err(Error::Disconnected("Invalid message received".into())), _ => Err(Error::Disconnected("Invalid message received".into())),
} }
} }
@ -204,7 +211,10 @@ impl Client {
impl Drop for Client { impl Drop for Client {
fn drop(&mut self) { fn drop(&mut self) {
if self.need_bye { if self.need_bye {
self.socket.send(&ClientMessage::Bye("".into())).ok(); let mut buf = [0; 32];
self.socket
.send(&mut buf, &ClientMessage::Bye("".into()))
.ok();
} }
} }
} }

View File

@ -1,27 +1,251 @@
use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Copy)]
use smallvec::SmallVec;
pub type ServerData = SmallVec<[u8; 64]>;
pub type ClientData = SmallVec<[u8; 16]>;
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TerminalInfo { pub struct TerminalInfo {
pub columns: u32, pub columns: u32,
pub rows: u32 pub rows: u32,
} }
#[derive(Debug, Serialize, Deserialize)] pub struct ClientMessageProxy;
pub enum ClientMessage { pub struct ServerMessageProxy;
#[derive(Debug)]
pub enum ClientMessage<'a> {
Hello(TerminalInfo), Hello(TerminalInfo),
Bye(String), Bye(&'a str),
Pong, Pong,
Input(ClientData), Input(&'a [u8]),
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug)]
pub enum ServerMessage { pub enum ServerMessage<'a> {
Hello, Hello,
Bye(String), Bye(&'a str),
Ping, Ping,
Output(ServerData), Output(&'a [u8]),
}
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
#[error("Message too large")]
MessageTooLarge,
#[error("Value too long")]
ValueTooLong,
}
#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
#[error("Truncated message received")]
Truncated,
#[error("Malformed message received")]
InvalidMessage,
#[error("Malformed string in the message")]
InvalidString(core::str::Utf8Error),
}
pub struct Encoder<'a> {
buffer: &'a mut [u8],
pos: usize,
}
pub struct Decoder<'a> {
buffer: &'a [u8],
pos: usize,
}
pub trait Encode {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError>;
}
pub trait Decode<'de>: Sized + 'de {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError>;
}
pub trait MessageProxy {
type Type<'de>: Encode + Decode<'de>;
}
impl<'a> Encoder<'a> {
pub const fn new(buffer: &'a mut [u8]) -> Self {
Self { buffer, pos: 0 }
}
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
let len: u32 = bytes
.len()
.try_into()
.map_err(|_| EncodeError::ValueTooLong)?;
self.write(&len.to_le_bytes())?;
self.write(bytes)
}
pub fn write_str(&mut self, s: &str) -> Result<(), EncodeError> {
self.write_variable_bytes(s.as_bytes())
}
pub fn write(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
if self.pos + bytes.len() > self.buffer.len() {
return Err(EncodeError::ValueTooLong);
}
self.buffer[self.pos..self.pos + bytes.len()].copy_from_slice(bytes);
self.pos += bytes.len();
Ok(())
}
pub fn get(&self) -> &[u8] {
&self.buffer[..self.pos]
}
}
impl<'a> Decoder<'a> {
pub fn new(buffer: &'a [u8]) -> Self {
Self { buffer, pos: 0 }
}
pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
if self.pos + 1 > self.buffer.len() {
return Err(DecodeError::Truncated);
}
let byte = self.buffer[self.pos];
self.pos += 1;
Ok(byte)
}
pub fn read_le_u32(&mut self) -> Result<u32, DecodeError> {
let bytes = self.read_bytes(size_of::<u32>())?;
Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
}
pub fn read_bytes(&mut self, len: usize) -> Result<&'a [u8], DecodeError> {
if self.pos + len > self.buffer.len() {
return Err(DecodeError::Truncated);
}
let slice = &self.buffer[self.pos..self.pos + len];
self.pos += len;
Ok(slice)
}
pub fn read_variable_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
let len = self.read_le_u32()?;
self.read_bytes(len as usize)
}
pub fn read_str(&mut self) -> Result<&'a str, DecodeError> {
let slice = self.read_variable_bytes()?;
core::str::from_utf8(slice).map_err(DecodeError::InvalidString)
}
}
impl MessageProxy for ClientMessageProxy {
type Type<'de> = ClientMessage<'de>;
}
impl MessageProxy for ServerMessageProxy {
type Type<'de> = ServerMessage<'de>;
}
impl ClientMessage<'_> {
const TAG_HELLO: u8 = 0x80;
const TAG_BYE: u8 = 0x81;
const TAG_PONG: u8 = 0x82;
const TAG_INPUT: u8 = 0x90;
}
impl Encode for TerminalInfo {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
buffer.write(&self.columns.to_le_bytes())?;
buffer.write(&self.rows.to_le_bytes())?;
Ok(())
}
}
impl<'de> Decode<'de> for TerminalInfo {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let columns = buffer.read_le_u32()?;
let rows = buffer.read_le_u32()?;
Ok(Self { columns, rows })
}
}
impl<'a> Encode for ClientMessage<'a> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::Hello(info) => {
buffer.write(&[Self::TAG_HELLO])?;
info.encode(buffer)
}
Self::Bye(reason) => {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
}
Self::Pong => {
buffer.write(&[Self::TAG_PONG])
}
Self::Input(data) => {
buffer.write(&[Self::TAG_INPUT])?;
buffer.write_variable_bytes(data)
}
}
}
}
impl<'de> Decode<'de> for ClientMessage<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_HELLO => {
TerminalInfo::decode(buffer).map(Self::Hello)
}
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
}
Self::TAG_PONG => {
Ok(Self::Pong)
}
Self::TAG_INPUT => {
buffer.read_variable_bytes().map(Self::Input)
}
_ => Err(DecodeError::InvalidMessage)
}
}
}
impl ServerMessage<'_> {
const TAG_HELLO: u8 = 0x10;
const TAG_BYE: u8 = 0x11;
const TAG_PING: u8 = 0x12;
const TAG_OUTPUT: u8 = 0x20;
}
impl<'a> Encode for ServerMessage<'a> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::Hello => buffer.write(&[Self::TAG_HELLO]),
Self::Bye(reason) => {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
}
// TODO sequence number
Self::Ping => buffer.write(&[Self::TAG_PING]),
Self::Output(data) => {
buffer.write(&[Self::TAG_OUTPUT])?;
buffer.write_variable_bytes(data)
}
}
}
}
impl<'de> Decode<'de> for ServerMessage<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_HELLO => Ok(Self::Hello),
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
}
Self::TAG_PING => Ok(Self::Ping),
Self::TAG_OUTPUT => {
buffer.read_variable_bytes().map(Self::Output)
},
_ => Err(DecodeError::InvalidMessage),
}
}
} }

View File

@ -13,7 +13,7 @@ use std::{
use cross::io::{Poll, TimerFd}; use cross::io::{Poll, TimerFd};
use rsh::{ use rsh::{
proto::{ClientData, ClientMessage, ServerData, ServerMessage, TerminalInfo}, proto::{ClientMessage, ServerMessage, TerminalInfo},
Error, ServerSocket, Error, ServerSocket,
}; };
@ -35,17 +35,18 @@ pub struct Server {
pty_to_session: HashMap<RawFd, Session>, pty_to_session: HashMap<RawFd, Session>,
} }
pub enum PtyEvent { pub enum PtyEvent<'b> {
Data(ServerData), Data(&'b [u8]),
Err(Error), Err(Error),
Closed, Closed,
} }
pub enum Event { pub enum Event<'b> {
NewClient(SocketAddr, TerminalInfo), NewClient(SocketAddr, TerminalInfo),
SessionInput(RawFd, SocketAddr, ClientData), SessionInput(RawFd, SocketAddr, &'b [u8]),
ClientBye(SocketAddr, String), ClientBye(SocketAddr, &'b str),
Pty(RawFd, SocketAddr, PtyEvent), Pty(RawFd, SocketAddr, PtyEvent<'b>),
Tick,
} }
impl Session { impl Session {
@ -109,14 +110,19 @@ impl Server {
}) })
} }
pub fn poll(&mut self) -> Result<Event, Error> { pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result<Option<Event<'b>>, Error> {
let mut buf = [0; 64];
loop {
let fd = self.poll.wait(None)?.unwrap(); let fd = self.poll.wait(None)?.unwrap();
match fd { match fd {
fd if fd == self.socket.as_raw_fd() => { fd if fd == self.socket.as_raw_fd() => {
let (message, remote) = self.socket.recv_from()?; let (message, remote) = match self.socket.recv_from(buffer) {
Ok((message, remote)) => (message, remote),
Err(error @ (Error::Decode(_) | Error::Truncated)) => {
eprintln!("Receive error: {error}");
return Ok(None)
},
Err(error) => return Err(error),
};
let event = match message { let event = match message {
ClientMessage::Hello(terminal) ClientMessage::Hello(terminal)
@ -133,35 +139,35 @@ impl Server {
ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => { ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => {
let session = self.pty_to_session.get_mut(fd).unwrap(); let session = self.pty_to_session.get_mut(fd).unwrap();
session.timeouts = 0; session.timeouts = 0;
continue return Ok(None);
}, }
_ => continue, _ => return Ok(None),
}; };
break Ok(event); Ok(Some(event))
}
fd if fd == self.timer.as_raw_fd() => {
// Restart the timer
self.update_client_timeouts()?;
self.timer.start(PING_INTERVAL)?;
} }
fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)),
fd => { fd => {
let session = self.pty_to_session.get_mut(&fd).unwrap(); let session = self.pty_to_session.get_mut(&fd).unwrap();
let event = match session.pty_master.read(&mut buf) { let event = match session.pty_master.read(&mut buffer[..pty_max]) {
Ok(0) => PtyEvent::Closed, Ok(0) => PtyEvent::Closed,
Ok(len) => PtyEvent::Data(buf[..len].into()), Ok(len) => PtyEvent::Data(&buffer[..len]),
Err(e) => PtyEvent::Err(e.into()), Err(e) => PtyEvent::Err(e.into()),
}; };
break Ok(Event::Pty(fd, session.remote, event)); Ok(Some(Event::Pty(fd, session.remote, event)))
}
} }
} }
} }
pub fn run(mut self) -> Result<(), Error> { pub fn run(mut self) -> Result<(), Error> {
self.timer.start(PING_INTERVAL)?; self.timer.start(PING_INTERVAL)?;
let mut recv_buf = [0; 256];
let mut send_buf = [0; 256];
loop { loop {
let event = self.poll()?; let Some(event) = self.poll(&mut recv_buf, 128)? else {
continue;
};
match event { match event {
Event::SessionInput(fd, remote, data) => { Event::SessionInput(fd, remote, data) => {
@ -170,7 +176,7 @@ impl Server {
eprintln!("PTY write error: {error}"); eprintln!("PTY write error: {error}");
self.remove_session_by_fd(fd)?; self.remove_session_by_fd(fd)?;
self.socket self.socket
.send_to(&remote, &ServerMessage::Bye(format!("PTY error: {error}"))) .send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
.ok(); .ok();
} }
} }
@ -183,46 +189,57 @@ impl Server {
match Session::open(&terminal, remote) { match Session::open(&terminal, remote) {
Ok(session) => { Ok(session) => {
self.register_session(remote, session)?; self.register_session(remote, session)?;
self.socket.send_to(&remote, &ServerMessage::Hello).ok(); self.socket
.send_to(&remote, &mut send_buf, &ServerMessage::Hello)
.ok();
} }
Err(err) => { Err(err) => {
eprintln!("PTY open error: {err}"); eprintln!("PTY open error: {err}");
self.socket self.socket
.send_to(&remote, &ServerMessage::Bye("PTY open error".into())) .send_to(
&remote,
&mut send_buf,
&ServerMessage::Bye("PTY open error"),
)
.ok(); .ok();
} }
} }
} }
Event::Pty(fd, remote, event) => match event { Event::Pty(fd, remote, event) => match event {
PtyEvent::Data(data) => {
self.socket
.send_to(&remote, &mut send_buf, &ServerMessage::Output(data))
.ok();
},
PtyEvent::Err(error) => { PtyEvent::Err(error) => {
eprintln!("PTY read error: {error}"); eprintln!("PTY read error: {error}");
self.remove_session_by_fd(fd)?; self.remove_session_by_fd(fd)?;
self.socket self.socket
.send_to(&remote, &ServerMessage::Bye(format!("PTY error: {error}"))) .send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
.ok(); .ok();
} },
PtyEvent::Data(data) => {
self.socket
.send_to(&remote, &ServerMessage::Output(data))
.ok();
}
PtyEvent::Closed => { PtyEvent::Closed => {
println!("End of PTY for {remote}"); println!("End of PTY for {remote}");
self.remove_session_by_fd(fd)?; self.remove_session_by_fd(fd)?;
self.socket self.socket
.send_to(&remote, &ServerMessage::Bye("".into())) .send_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
.ok(); .ok();
}
}, },
},
Event::Tick => {
// Restart the timer
self.update_client_timeouts(&mut send_buf)?;
self.timer.start(PING_INTERVAL)?;
}
} }
} }
} }
fn update_client_timeouts(&mut self) -> Result<(), Error> { fn update_client_timeouts(&mut self, send_buf: &mut [u8]) -> Result<(), Error> {
let mut removed = vec![]; let mut removed = vec![];
for (remote, fd) in self.addr_to_session.iter() { for (remote, fd) in self.addr_to_session.iter() {
let session = self.pty_to_session.get_mut(&fd).unwrap(); let session = self.pty_to_session.get_mut(&fd).unwrap();
self.socket.send_to(remote, &ServerMessage::Ping).ok(); self.socket.send_to(remote, send_buf, &ServerMessage::Ping).ok();
session.timeouts += 1; session.timeouts += 1;
if session.timeouts >= 10 { if session.timeouts >= 10 {
removed.push((*remote, *fd)); removed.push((*remote, *fd));
@ -232,7 +249,7 @@ impl Server {
for (remote, fd) in removed { for (remote, fd) in removed {
eprintln!("Client {remote} timed out"); eprintln!("Client {remote} timed out");
self.remove_session_by_fd(fd)?; self.remove_session_by_fd(fd)?;
self.socket.send_to(&remote, &ServerMessage::Bye("Timed out".into())).ok(); self.socket.send_to(&remote, send_buf, &ServerMessage::Bye("Timed out".into())).ok();
} }
Ok(()) Ok(())
} }