From 3e605b3b11ca415eb30334fa5f9b084405502354 Mon Sep 17 00:00:00 2001 From: Mark Poliakov Date: Fri, 1 Nov 2024 18:44:41 +0200 Subject: [PATCH] rsh: better protocol handling --- userspace/Cargo.lock | 6 - userspace/rsh/Cargo.toml | 4 - userspace/rsh/src/lib.rs | 73 +++++----- userspace/rsh/src/main.rs | 64 +++++---- userspace/rsh/src/proto.rs | 256 ++++++++++++++++++++++++++++++--- userspace/rsh/src/rshd/main.rs | 147 ++++++++++--------- 6 files changed, 397 insertions(+), 153 deletions(-) diff --git a/userspace/Cargo.lock b/userspace/Cargo.lock index bf229970..eb66cc74 100644 --- a/userspace/Cargo.lock +++ b/userspace/Cargo.lock @@ -1014,10 +1014,7 @@ dependencies = [ "bytemuck", "clap", "cross", - "flexbuffers", "libterm", - "serde", - "smallvec", "thiserror", ] @@ -1200,9 +1197,6 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -dependencies = [ - "serde", -] [[package]] name = "spin" diff --git a/userspace/rsh/Cargo.toml b/userspace/rsh/Cargo.toml index 07b06a48..9b72a83c 100644 --- a/userspace/rsh/Cargo.toml +++ b/userspace/rsh/Cargo.toml @@ -9,11 +9,7 @@ path = "src/rshd/main.rs" [dependencies] clap.workspace = true -flexbuffers.workspace = true libterm.workspace = true -serde.workspace = true thiserror.workspace = true cross.workspace = true bytemuck.workspace = true - -smallvec = { version = "1.13.2", features = ["serde"] } diff --git a/userspace/rsh/src/lib.rs b/userspace/rsh/src/lib.rs index 0cbca488..b71d8630 100644 --- a/userspace/rsh/src/lib.rs +++ b/userspace/rsh/src/lib.rs @@ -1,4 +1,6 @@ #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))] +#![feature(generic_const_exprs)] +#![allow(incomplete_features)] use std::{ io, @@ -8,44 +10,46 @@ use std::{ os::fd::{AsRawFd, RawFd}, }; -use proto::{ClientMessage, ServerMessage}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use proto::{ + ClientMessageProxy, Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy, + ServerMessageProxy, +}; pub mod proto; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("Deserialization error: {0}")] - Deserialization(#[from] flexbuffers::DeserializationError), - #[error("Serialization error: {0}")] - Serialization(#[from] flexbuffers::SerializationError), #[error("I/O error")] Io(#[from] io::Error), #[error("Could not send a message fully")] Truncated, + #[error("Decode error: {0}")] + Decode(#[from] DecodeError), + #[error("Encode error: {0}")] + Encode(#[from] EncodeError), } -pub struct SocketWrapper { +pub struct SocketWrapper { socket: UdpSocket, - buffer: [u8; 256], _pd: PhantomData<(Rx, Tx)>, } -pub struct ClientSocket(SocketWrapper); -pub struct ServerSocket(SocketWrapper); +pub struct ClientSocket(SocketWrapper); +pub struct ServerSocket(SocketWrapper); -impl SocketWrapper { +impl SocketWrapper { pub fn new(socket: UdpSocket) -> Self { Self { socket, - buffer: [0; 256], _pd: PhantomData, } } - pub fn send(&mut self, message: &Tx) -> Result<(), Error> { - let message = flexbuffers::to_vec(message)?; - let amount = self.socket.send(&message)?; + pub fn send(&self, buffer: &mut [u8], message: &Tx::Type<'_>) -> Result<(), Error> { + let mut enc = Encoder::new(buffer); + message.encode(&mut enc)?; + let message = enc.get(); + let amount = self.socket.send(message)?; if amount == message.len() { Ok(()) } else { @@ -53,9 +57,16 @@ impl SocketWrapper { } } - pub fn send_to(&mut self, remote: &SocketAddr, message: &Tx) -> Result<(), Error> { - let message = flexbuffers::to_vec(message)?; - let amount = self.socket.send_to(&message, remote)?; + pub fn send_to( + &mut self, + remote: &SocketAddr, + buffer: &mut [u8], + message: &Tx::Type<'_>, + ) -> Result<(), Error> { + let mut enc = Encoder::new(buffer); + message.encode(&mut enc)?; + let message = enc.get(); + let amount = self.socket.send_to(message, remote)?; if amount == message.len() { Ok(()) } else { @@ -63,26 +74,18 @@ impl SocketWrapper { } } - pub fn recv_from(&mut self) -> Result<(Rx, SocketAddr), Error> - where - Rx: DeserializeOwned, - { - let (len, remote) = self.socket.recv_from(&mut self.buffer)?; - let message = flexbuffers::from_slice(&self.buffer[..len])?; - Ok((message, remote)) - } - - pub fn recv_from_with<'de>(&mut self, buffer: &'de mut [u8]) -> Result<(Rx, SocketAddr), Error> - where - Rx: Deserialize<'de> + 'de, - { + pub fn recv_from<'de>( + &mut self, + buffer: &'de mut [u8], + ) -> Result<(Rx::Type<'de>, SocketAddr), Error> { let (len, remote) = self.socket.recv_from(buffer)?; - let message = flexbuffers::from_slice(&buffer[..len])?; + let mut dec = Decoder::new(&buffer[..len]); + let message = Rx::Type::<'de>::decode(&mut dec)?; Ok((message, remote)) } } -impl AsRawFd for SocketWrapper { +impl AsRawFd for SocketWrapper { fn as_raw_fd(&self) -> RawFd { self.socket.as_raw_fd() } @@ -95,7 +98,7 @@ impl ClientSocket { } impl Deref for ClientSocket { - type Target = SocketWrapper; + type Target = SocketWrapper; fn deref(&self) -> &Self::Target { &self.0 @@ -115,7 +118,7 @@ impl ServerSocket { } impl Deref for ServerSocket { - type Target = SocketWrapper; + type Target = SocketWrapper; fn deref(&self) -> &Self::Target { &self.0 diff --git a/userspace/rsh/src/main.rs b/userspace/rsh/src/main.rs index 56b2a2de..89f87e78 100644 --- a/userspace/rsh/src/main.rs +++ b/userspace/rsh/src/main.rs @@ -12,7 +12,7 @@ use clap::Parser; use cross::io::Poll; use libterm::{RawMode, RawTerminal}; use rsh::{ - proto::{ClientData, ClientMessage, ServerData, ServerMessage, TerminalInfo}, + proto::{ClientMessage, ServerMessage, TerminalInfo}, ClientSocket, }; @@ -44,13 +44,14 @@ pub struct Client { stdout: Stdout, need_bye: bool, last_ping: Instant, - _raw: RawMode + _raw: RawMode, } -pub enum Event { - Stdin(ClientData), - Data(ServerData), - Disconnected(String), +pub enum Event<'b> { + Stdin(&'b [u8]), + Data(&'b [u8]), + Disconnected(&'b str), + Ping, } impl Client { @@ -91,8 +92,8 @@ impl Client { }) } - pub fn poll(&mut self) -> Result { - let mut buf = [0; 16]; + pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result>, Error> { + // let mut buf = [0; 16]; let event = loop { let Some(event) = self.poll.wait(Some(Duration::from_millis(500)))? else { self.check_timeout()?; @@ -101,29 +102,26 @@ impl Client { match event { fd if fd == self.socket.as_raw_fd() => { - let (message, _) = self.socket.recv_from()?; + let (message, _) = self.socket.recv_from(buffer)?; match message { ServerMessage::Bye(reason) => { // No need for a bye self.need_bye = false; - break Ok(Event::Disconnected(reason)); + break Ok(Some(Event::Disconnected(reason))); } ServerMessage::Output(data) => { - break Ok(Event::Data(data)); + break Ok(Some(Event::Data(data))); } ServerMessage::Ping => { - self.last_ping = Instant::now(); - self.socket.send(&ClientMessage::Pong).ok(); - continue; + break Ok(Some(Event::Ping)); } // Ignore this one - ServerMessage::Hello => continue, + ServerMessage::Hello => break Ok(None), } } fd if fd == self.stdin.as_raw_fd() => { - let len = self.stdin.read(&mut buf)?; - let data = ClientData::from_slice(&buf[..len]); - break Ok(Event::Stdin(data)); + let len = self.stdin.read(&mut buffer[..pty_max])?; + break Ok(Some(Event::Stdin(&buffer[..len]))); } _ => unreachable!() } @@ -135,19 +133,27 @@ impl Client { } pub fn run(mut self) -> Result { + let mut recv_buf = [0; 256]; + let mut send_buf = [0; 256]; loop { - let event = self.poll()?; + let Some(event) = self.poll(&mut recv_buf, 64)? else { + continue; + }; match event { Event::Data(data) => { - self.stdout.write_all(&data).ok(); - self.stdout.flush().ok(); + self.stdout.write_all(data)?; + self.stdout.flush()?; } Event::Stdin(data) => { - self.socket.send(&ClientMessage::Input(data))?; + self.socket.send(&mut send_buf, &ClientMessage::Input(data))?; + } + Event::Ping => { + self.last_ping = Instant::now(); + self.socket.send(&mut send_buf, &ClientMessage::Pong)?; } Event::Disconnected(reason) => { - break Ok(reason); + break Ok(reason.into()); } } } @@ -187,15 +193,16 @@ impl Client { terminal: TerminalInfo, timeout: Duration, ) -> Result<(), Error> { - socket.send(&ClientMessage::Hello(terminal))?; + let mut buffer = [0; 512]; + socket.send(&mut buffer, &ClientMessage::Hello(terminal))?; if poll.wait(Some(timeout))?.is_none() { return Err(Error::Timeout); }; - let (message, _) = socket.recv_from()?; + let (message, _) = socket.recv_from(&mut buffer)?; match message { ServerMessage::Hello => Ok(()), - ServerMessage::Bye(reason) => Err(Error::Disconnected(reason)), + ServerMessage::Bye(reason) => Err(Error::Disconnected(reason.into())), _ => Err(Error::Disconnected("Invalid message received".into())), } } @@ -204,7 +211,10 @@ impl Client { impl Drop for Client { fn drop(&mut self) { if self.need_bye { - self.socket.send(&ClientMessage::Bye("".into())).ok(); + let mut buf = [0; 32]; + self.socket + .send(&mut buf, &ClientMessage::Bye("".into())) + .ok(); } } } diff --git a/userspace/rsh/src/proto.rs b/userspace/rsh/src/proto.rs index 6487a2a8..b1dcc899 100644 --- a/userspace/rsh/src/proto.rs +++ b/userspace/rsh/src/proto.rs @@ -1,27 +1,251 @@ -use serde::{Deserialize, Serialize}; -use smallvec::SmallVec; - -pub type ServerData = SmallVec<[u8; 64]>; -pub type ClientData = SmallVec<[u8; 16]>; - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy)] pub struct TerminalInfo { pub columns: u32, - pub rows: u32 + pub rows: u32, } -#[derive(Debug, Serialize, Deserialize)] -pub enum ClientMessage { +pub struct ClientMessageProxy; +pub struct ServerMessageProxy; + +#[derive(Debug)] +pub enum ClientMessage<'a> { Hello(TerminalInfo), - Bye(String), + Bye(&'a str), Pong, - Input(ClientData), + Input(&'a [u8]), } -#[derive(Debug, Serialize, Deserialize)] -pub enum ServerMessage { +#[derive(Debug)] +pub enum ServerMessage<'a> { Hello, - Bye(String), + Bye(&'a str), Ping, - Output(ServerData), + Output(&'a [u8]), +} + +#[derive(Debug, thiserror::Error)] +pub enum EncodeError { + #[error("Message too large")] + MessageTooLarge, + #[error("Value too long")] + ValueTooLong, +} + +#[derive(Debug, thiserror::Error)] +pub enum DecodeError { + #[error("Truncated message received")] + Truncated, + #[error("Malformed message received")] + InvalidMessage, + #[error("Malformed string in the message")] + InvalidString(core::str::Utf8Error), +} + +pub struct Encoder<'a> { + buffer: &'a mut [u8], + pos: usize, +} + +pub struct Decoder<'a> { + buffer: &'a [u8], + pos: usize, +} + +pub trait Encode { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError>; +} + +pub trait Decode<'de>: Sized + 'de { + fn decode(buffer: &mut Decoder<'de>) -> Result; +} + +pub trait MessageProxy { + type Type<'de>: Encode + Decode<'de>; +} + +impl<'a> Encoder<'a> { + pub const fn new(buffer: &'a mut [u8]) -> Self { + Self { buffer, pos: 0 } + } + + pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> { + let len: u32 = bytes + .len() + .try_into() + .map_err(|_| EncodeError::ValueTooLong)?; + self.write(&len.to_le_bytes())?; + self.write(bytes) + } + + pub fn write_str(&mut self, s: &str) -> Result<(), EncodeError> { + self.write_variable_bytes(s.as_bytes()) + } + + pub fn write(&mut self, bytes: &[u8]) -> Result<(), EncodeError> { + if self.pos + bytes.len() > self.buffer.len() { + return Err(EncodeError::ValueTooLong); + } + self.buffer[self.pos..self.pos + bytes.len()].copy_from_slice(bytes); + self.pos += bytes.len(); + Ok(()) + } + + pub fn get(&self) -> &[u8] { + &self.buffer[..self.pos] + } +} + +impl<'a> Decoder<'a> { + pub fn new(buffer: &'a [u8]) -> Self { + Self { buffer, pos: 0 } + } + + pub fn read_u8(&mut self) -> Result { + if self.pos + 1 > self.buffer.len() { + return Err(DecodeError::Truncated); + } + let byte = self.buffer[self.pos]; + self.pos += 1; + Ok(byte) + } + + pub fn read_le_u32(&mut self) -> Result { + let bytes = self.read_bytes(size_of::())?; + Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) + } + + pub fn read_bytes(&mut self, len: usize) -> Result<&'a [u8], DecodeError> { + if self.pos + len > self.buffer.len() { + return Err(DecodeError::Truncated); + } + let slice = &self.buffer[self.pos..self.pos + len]; + self.pos += len; + Ok(slice) + } + + pub fn read_variable_bytes(&mut self) -> Result<&'a [u8], DecodeError> { + let len = self.read_le_u32()?; + self.read_bytes(len as usize) + } + + pub fn read_str(&mut self) -> Result<&'a str, DecodeError> { + let slice = self.read_variable_bytes()?; + core::str::from_utf8(slice).map_err(DecodeError::InvalidString) + } +} + +impl MessageProxy for ClientMessageProxy { + type Type<'de> = ClientMessage<'de>; +} + +impl MessageProxy for ServerMessageProxy { + type Type<'de> = ServerMessage<'de>; +} + +impl ClientMessage<'_> { + const TAG_HELLO: u8 = 0x80; + const TAG_BYE: u8 = 0x81; + const TAG_PONG: u8 = 0x82; + const TAG_INPUT: u8 = 0x90; +} + +impl Encode for TerminalInfo { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> { + buffer.write(&self.columns.to_le_bytes())?; + buffer.write(&self.rows.to_le_bytes())?; + Ok(()) + } +} + +impl<'de> Decode<'de> for TerminalInfo { + fn decode(buffer: &mut Decoder<'de>) -> Result { + let columns = buffer.read_le_u32()?; + let rows = buffer.read_le_u32()?; + Ok(Self { columns, rows }) + } +} + +impl<'a> Encode for ClientMessage<'a> { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> { + match self { + Self::Hello(info) => { + buffer.write(&[Self::TAG_HELLO])?; + info.encode(buffer) + } + Self::Bye(reason) => { + buffer.write(&[Self::TAG_BYE])?; + buffer.write_str(reason) + } + Self::Pong => { + buffer.write(&[Self::TAG_PONG]) + } + Self::Input(data) => { + buffer.write(&[Self::TAG_INPUT])?; + buffer.write_variable_bytes(data) + } + } + } +} + +impl<'de> Decode<'de> for ClientMessage<'de> { + fn decode(buffer: &mut Decoder<'de>) -> Result { + let tag = buffer.read_u8()?; + match tag { + Self::TAG_HELLO => { + TerminalInfo::decode(buffer).map(Self::Hello) + } + Self::TAG_BYE => { + buffer.read_str().map(Self::Bye) + } + Self::TAG_PONG => { + Ok(Self::Pong) + } + Self::TAG_INPUT => { + buffer.read_variable_bytes().map(Self::Input) + } + _ => Err(DecodeError::InvalidMessage) + } + } +} + +impl ServerMessage<'_> { + const TAG_HELLO: u8 = 0x10; + const TAG_BYE: u8 = 0x11; + const TAG_PING: u8 = 0x12; + const TAG_OUTPUT: u8 = 0x20; +} + +impl<'a> Encode for ServerMessage<'a> { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> { + match self { + Self::Hello => buffer.write(&[Self::TAG_HELLO]), + Self::Bye(reason) => { + buffer.write(&[Self::TAG_BYE])?; + buffer.write_str(reason) + } + // TODO sequence number + Self::Ping => buffer.write(&[Self::TAG_PING]), + Self::Output(data) => { + buffer.write(&[Self::TAG_OUTPUT])?; + buffer.write_variable_bytes(data) + } + } + } +} + +impl<'de> Decode<'de> for ServerMessage<'de> { + fn decode(buffer: &mut Decoder<'de>) -> Result { + let tag = buffer.read_u8()?; + match tag { + Self::TAG_HELLO => Ok(Self::Hello), + Self::TAG_BYE => { + buffer.read_str().map(Self::Bye) + } + Self::TAG_PING => Ok(Self::Ping), + Self::TAG_OUTPUT => { + buffer.read_variable_bytes().map(Self::Output) + }, + _ => Err(DecodeError::InvalidMessage), + } + } } diff --git a/userspace/rsh/src/rshd/main.rs b/userspace/rsh/src/rshd/main.rs index e97459d6..d77d6648 100644 --- a/userspace/rsh/src/rshd/main.rs +++ b/userspace/rsh/src/rshd/main.rs @@ -13,7 +13,7 @@ use std::{ use cross::io::{Poll, TimerFd}; use rsh::{ - proto::{ClientData, ClientMessage, ServerData, ServerMessage, TerminalInfo}, + proto::{ClientMessage, ServerMessage, TerminalInfo}, Error, ServerSocket, }; @@ -35,17 +35,18 @@ pub struct Server { pty_to_session: HashMap, } -pub enum PtyEvent { - Data(ServerData), +pub enum PtyEvent<'b> { + Data(&'b [u8]), Err(Error), Closed, } -pub enum Event { +pub enum Event<'b> { NewClient(SocketAddr, TerminalInfo), - SessionInput(RawFd, SocketAddr, ClientData), - ClientBye(SocketAddr, String), - Pty(RawFd, SocketAddr, PtyEvent), + SessionInput(RawFd, SocketAddr, &'b [u8]), + ClientBye(SocketAddr, &'b str), + Pty(RawFd, SocketAddr, PtyEvent<'b>), + Tick, } impl Session { @@ -109,59 +110,64 @@ impl Server { }) } - pub fn poll(&mut self) -> Result { - let mut buf = [0; 64]; - loop { - let fd = self.poll.wait(None)?.unwrap(); + pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result>, Error> { + let fd = self.poll.wait(None)?.unwrap(); - match fd { - fd if fd == self.socket.as_raw_fd() => { - let (message, remote) = self.socket.recv_from()?; + match fd { + fd if fd == self.socket.as_raw_fd() => { + let (message, remote) = match self.socket.recv_from(buffer) { + Ok((message, remote)) => (message, remote), + Err(error @ (Error::Decode(_) | Error::Truncated)) => { + eprintln!("Receive error: {error}"); + return Ok(None) + }, + Err(error) => return Err(error), + }; - let event = match message { - ClientMessage::Hello(terminal) - if self.addr_to_session.get(&remote).is_none() => - { - Event::NewClient(remote, terminal) - } - ClientMessage::Bye(reason) => Event::ClientBye(remote, reason), - ClientMessage::Input(data) - if let Some(fd) = self.addr_to_session.get(&remote) => - { - Event::SessionInput(*fd, remote, data) - } - ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => { - let session = self.pty_to_session.get_mut(fd).unwrap(); - session.timeouts = 0; - continue - }, - _ => continue, - }; + let event = match message { + ClientMessage::Hello(terminal) + if self.addr_to_session.get(&remote).is_none() => + { + Event::NewClient(remote, terminal) + } + ClientMessage::Bye(reason) => Event::ClientBye(remote, reason), + ClientMessage::Input(data) + if let Some(fd) = self.addr_to_session.get(&remote) => + { + Event::SessionInput(*fd, remote, data) + } + ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => { + let session = self.pty_to_session.get_mut(fd).unwrap(); + session.timeouts = 0; + return Ok(None); + } + _ => return Ok(None), + }; - break Ok(event); - } - fd if fd == self.timer.as_raw_fd() => { - // Restart the timer - self.update_client_timeouts()?; - self.timer.start(PING_INTERVAL)?; - } - fd => { - let session = self.pty_to_session.get_mut(&fd).unwrap(); - let event = match session.pty_master.read(&mut buf) { - Ok(0) => PtyEvent::Closed, - Ok(len) => PtyEvent::Data(buf[..len].into()), - Err(e) => PtyEvent::Err(e.into()), - }; - break Ok(Event::Pty(fd, session.remote, event)); - } + Ok(Some(event)) + } + fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)), + fd => { + let session = self.pty_to_session.get_mut(&fd).unwrap(); + let event = match session.pty_master.read(&mut buffer[..pty_max]) { + Ok(0) => PtyEvent::Closed, + Ok(len) => PtyEvent::Data(&buffer[..len]), + Err(e) => PtyEvent::Err(e.into()), + }; + Ok(Some(Event::Pty(fd, session.remote, event))) } } } pub fn run(mut self) -> Result<(), Error> { self.timer.start(PING_INTERVAL)?; + let mut recv_buf = [0; 256]; + let mut send_buf = [0; 256]; + loop { - let event = self.poll()?; + let Some(event) = self.poll(&mut recv_buf, 128)? else { + continue; + }; match event { Event::SessionInput(fd, remote, data) => { @@ -170,7 +176,7 @@ impl Server { eprintln!("PTY write error: {error}"); self.remove_session_by_fd(fd)?; self.socket - .send_to(&remote, &ServerMessage::Bye(format!("PTY error: {error}"))) + .send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) .ok(); } } @@ -183,46 +189,57 @@ impl Server { match Session::open(&terminal, remote) { Ok(session) => { self.register_session(remote, session)?; - self.socket.send_to(&remote, &ServerMessage::Hello).ok(); + self.socket + .send_to(&remote, &mut send_buf, &ServerMessage::Hello) + .ok(); } Err(err) => { eprintln!("PTY open error: {err}"); self.socket - .send_to(&remote, &ServerMessage::Bye("PTY open error".into())) + .send_to( + &remote, + &mut send_buf, + &ServerMessage::Bye("PTY open error"), + ) .ok(); } } } Event::Pty(fd, remote, event) => match event { + PtyEvent::Data(data) => { + self.socket + .send_to(&remote, &mut send_buf, &ServerMessage::Output(data)) + .ok(); + }, PtyEvent::Err(error) => { eprintln!("PTY read error: {error}"); self.remove_session_by_fd(fd)?; self.socket - .send_to(&remote, &ServerMessage::Bye(format!("PTY error: {error}"))) + .send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error")) .ok(); - } - PtyEvent::Data(data) => { - self.socket - .send_to(&remote, &ServerMessage::Output(data)) - .ok(); - } + }, PtyEvent::Closed => { println!("End of PTY for {remote}"); self.remove_session_by_fd(fd)?; self.socket - .send_to(&remote, &ServerMessage::Bye("".into())) + .send_to(&remote, &mut send_buf, &ServerMessage::Bye("")) .ok(); - } + }, }, + Event::Tick => { + // Restart the timer + self.update_client_timeouts(&mut send_buf)?; + self.timer.start(PING_INTERVAL)?; + } } } } - fn update_client_timeouts(&mut self) -> Result<(), Error> { + fn update_client_timeouts(&mut self, send_buf: &mut [u8]) -> Result<(), Error> { let mut removed = vec![]; for (remote, fd) in self.addr_to_session.iter() { let session = self.pty_to_session.get_mut(&fd).unwrap(); - self.socket.send_to(remote, &ServerMessage::Ping).ok(); + self.socket.send_to(remote, send_buf, &ServerMessage::Ping).ok(); session.timeouts += 1; if session.timeouts >= 10 { removed.push((*remote, *fd)); @@ -232,7 +249,7 @@ impl Server { for (remote, fd) in removed { eprintln!("Client {remote} timed out"); self.remove_session_by_fd(fd)?; - self.socket.send_to(&remote, &ServerMessage::Bye("Timed out".into())).ok(); + self.socket.send_to(&remote, send_buf, &ServerMessage::Bye("Timed out".into())).ok(); } Ok(()) }