From a8a61926279ec88088d56a4d1b62f229f6fc1e4a Mon Sep 17 00:00:00 2001 From: Mark Poliakov <mark@alnyan.me> Date: Sat, 2 Nov 2024 20:22:53 +0200 Subject: [PATCH] rsh: better server modularity --- userspace/rsh/src/crypt/mod.rs | 28 -- userspace/rsh/src/crypt/signature.rs | 2 +- userspace/rsh/src/lib.rs | 5 +- userspace/rsh/src/main.rs | 4 +- userspace/rsh/src/rshd/main.rs | 354 +++++++------------------- userspace/rsh/src/server.rs | 367 +++++++++++++++++++++++++++ 6 files changed, 465 insertions(+), 295 deletions(-) create mode 100644 userspace/rsh/src/server.rs diff --git a/userspace/rsh/src/crypt/mod.rs b/userspace/rsh/src/crypt/mod.rs index 1473dc3f..f37d4c13 100644 --- a/userspace/rsh/src/crypt/mod.rs +++ b/userspace/rsh/src/crypt/mod.rs @@ -354,37 +354,9 @@ pub fn ciphersuite_name(cipher: u8) -> Option<&'static str> { } } -// pub fn hash_algo_name(hash: u8) -> Option<&'static str> { -// match hash { -// V1_HASH_SHA256 => Some("sha256"), -// _ => None, -// } -// } - pub fn sig_algo_name(sig: u8) -> Option<&'static str> { match sig { V1_SIG_ED25519 => Some("ed25519"), _ => None, } } - -// impl fmt::Display for PublicKeyFingerprint<'_> { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// let hash = match hash_algo_name(self.hash) { -// Some(name) => name, -// None => "unknown", -// }; -// let sig = match sig_algo_name(self.sig) { -// Some(name) => name, -// None => "unknown", -// }; -// write!(f, "{} {} {} ", sig, self.key_bits, hash)?; -// for (i, byte) in self.hash_data.iter().enumerate() { -// if i != 0 { -// write!(f, ":")?; -// } -// write!(f, "{:02x}", *byte)?; -// } -// Ok(()) -// } -// } diff --git a/userspace/rsh/src/crypt/signature.rs b/userspace/rsh/src/crypt/signature.rs index fe66e1ff..91b80f61 100644 --- a/userspace/rsh/src/crypt/signature.rs +++ b/userspace/rsh/src/crypt/signature.rs @@ -79,7 +79,7 @@ impl SignEd25519 { } pub fn load_signing_key<P: AsRef<Path>>(path: P) -> Result<Self, Error> { - let signing_key = ed25519_dalek::SigningKey::read_pkcs8_pem_file(path).unwrap(); + let signing_key = ed25519_dalek::SigningKey::read_pkcs8_pem_file(path).map_err(|_| Error::InvalidKey)?; let verifying_key = signing_key.verifying_key(); Ok(Self { signing_key, diff --git a/userspace/rsh/src/lib.rs b/userspace/rsh/src/lib.rs index 2d064730..7e07e065 100644 --- a/userspace/rsh/src/lib.rs +++ b/userspace/rsh/src/lib.rs @@ -1,5 +1,5 @@ #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))] -#![feature(generic_const_exprs, portable_simd)] +#![feature(generic_const_exprs, portable_simd, if_let_guard)] #![allow(incomplete_features)] use std::io; @@ -9,12 +9,13 @@ use proto::{DecodeError, EncodeError}; pub mod proto; pub mod socket; pub mod crypt; +pub mod server; pub use socket::{ClientSocket, ServerSocket}; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("I/O error")] + #[error("I/O error: {0}")] Io(#[from] io::Error), #[error("Could not send a message fully")] Truncated, diff --git a/userspace/rsh/src/main.rs b/userspace/rsh/src/main.rs index 05ea91f7..de3a49df 100644 --- a/userspace/rsh/src/main.rs +++ b/userspace/rsh/src/main.rs @@ -31,6 +31,8 @@ pub enum Error { struct Args { #[clap(short, long)] key: PathBuf, + #[clap(short = 'P', long, default_value_t = 77)] + port: u16, remote: IpAddr, } @@ -223,7 +225,7 @@ fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> { } fn run(args: Args) -> Result<(), Error> { - let remote = SocketAddr::new(args.remote, 77); + let remote = SocketAddr::new(args.remote, args.port); let ed25519 = SignEd25519::load_signing_key(args.key).unwrap(); let key = SignatureMethod::Ed25519(ed25519); let config = ClientConfig { diff --git a/userspace/rsh/src/rshd/main.rs b/userspace/rsh/src/rshd/main.rs index b53ac62d..de599feb 100644 --- a/userspace/rsh/src/rshd/main.rs +++ b/userspace/rsh/src/rshd/main.rs @@ -1,15 +1,18 @@ #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os, rustc_private))] #![feature(if_let_guard)] use std::{ - collections::{HashMap, HashSet}, fs::File, io::{Read, Write}, net::{SocketAddr, UdpSocket}, os::fd::{AsRawFd, FromRawFd, RawFd}, path::PathBuf, process::{Child, Command, ExitCode, Stdio}, str::FromStr, time::Duration + collections::HashSet, + net::SocketAddr, + path::PathBuf, + process::ExitCode, + str::FromStr, + time::Duration, }; use clap::Parser; -use cross::io::{Poll, TimerFd}; use rsh::{ - crypt::{server::ServerConfig, ServerEncryptedSocket, SimpleServerKeyStore}, - proto::{ClientMessage, Decode, Decoder, ServerMessage, TerminalInfo}, - socket::{MultiplexedSocket, MultiplexedSocketEvent}, + crypt::{server::ServerConfig, SimpleServerKeyStore}, + server::Server, Error, }; @@ -19,287 +22,112 @@ pub const PING_INTERVAL: Duration = Duration::from_millis(500); struct Args { #[clap(short = 'P', long, help = "rsh listen port", default_value_t = 77)] port: u16, - #[clap(short = 'S', long, help = "where rsh will load private keys from", default_value = "/etc/rsh")] - keystore: PathBuf + #[clap( + short = 'S', + long, + help = "where rsh will load private keys from", + default_value = "/etc/rsh" + )] + keystore: PathBuf, } -pub struct Session { - pty_master: File, +#[cfg(target_os = "yggdrasil")] +pub struct YggdrasilSession { + pty_master: std::fs::File, + fds: [std::os::fd::RawFd; 1], remote: SocketAddr, - shell: Child, + shell: std::process::Child, } -pub struct Server { - poll: Poll, - timer: TimerFd, - socket: ServerEncryptedSocket<UdpSocket>, +#[cfg(target_os = "yggdrasil")] +impl rsh::server::Session for YggdrasilSession { + type Error = std::io::Error; - addr_to_session: HashMap<SocketAddr, RawFd>, - pty_to_session: HashMap<RawFd, Session>, -} - -pub enum PtyEvent<'b> { - Data(&'b [u8]), - Err(Error), - Closed, -} - -pub enum Event<'b> { - NewClient(SocketAddr, TerminalInfo), - SessionInput(RawFd, SocketAddr, &'b [u8]), - ClientBye(SocketAddr, &'b str), - Pty(RawFd, SocketAddr, PtyEvent<'b>), - Tick, -} - -impl Session { - pub fn open(info: &TerminalInfo, remote: SocketAddr) -> Result<Self, Error> { - #[cfg(target_os = "yggdrasil")] - { - use std::os::yggdrasil::{ - self, - io::terminal::{create_pty, TerminalSize}, - process::CommandExt, - }; - // TODO unix version - let (pty_master, pty_slave) = create_pty( - Default::default(), - TerminalSize { - columns: info.columns as _, - rows: info.rows as _, + fn open(remote: &SocketAddr, terminal: &rsh::proto::TerminalInfo) -> Result<Self, Self::Error> { + use std::{ + os::{ + fd::{AsRawFd, FromRawFd}, + yggdrasil::{ + self, + io::terminal::{create_pty, TerminalSize}, + process::CommandExt, }, - )?; + }, + process::{Command, Stdio}, + }; - let pty_slave_fd = pty_slave.as_raw_fd(); - let group_id = yggdrasil::process::create_process_group(); - let shell = unsafe { - Command::new("/bin/sh") - .arg("-l") - .stdin(Stdio::from_raw_fd(pty_slave_fd)) - .stdout(Stdio::from_raw_fd(pty_slave_fd)) - .stderr(Stdio::from_raw_fd(pty_slave_fd)) - .process_group(group_id) - .gain_terminal(0) - .spawn()? - }; + let remote = *remote; + // TODO unix version + let (pty_master, pty_slave) = create_pty( + Default::default(), + TerminalSize { + columns: terminal.columns as _, + rows: terminal.rows as _, + }, + )?; - Ok(Self { - pty_master, - shell, - remote, - }) - } - #[cfg(unix)] - { - todo!() - } - } -} + let pty_slave_fd = pty_slave.as_raw_fd(); + let group_id = yggdrasil::process::create_process_group(); + let shell = unsafe { + Command::new("/bin/sh") + .arg("-l") + .stdin(Stdio::from_raw_fd(pty_slave_fd)) + .stdout(Stdio::from_raw_fd(pty_slave_fd)) + .stderr(Stdio::from_raw_fd(pty_slave_fd)) + .process_group(group_id) + .gain_terminal(0) + .spawn()? + }; + + let fds = [pty_master.as_raw_fd()]; -impl Server { - pub fn new(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> { - let mut poll = Poll::new()?; - let timer = TimerFd::new()?; - let socket = UdpSocket::bind(listen_addr)?; - let socket = ServerEncryptedSocket::new_with_config(socket, crypto_config); - poll.add(&socket)?; - poll.add(&timer)?; Ok(Self { - poll, - socket, - timer, - addr_to_session: HashMap::new(), - pty_to_session: HashMap::new(), + pty_master, + shell, + remote, + fds, }) } - pub fn poll<'b>( - &mut self, - buffer: &'b mut [u8], - pty_max: usize, - ) -> Result<Option<Event<'b>>, Error> { - let fd = self.poll.wait(None)?.unwrap(); - - match fd { - fd if fd == self.socket.as_raw_fd() => { - let event = self.socket.recv_from(buffer)?; - - let (message, remote) = match event { - MultiplexedSocketEvent::ClientDisconnected(remote) => { - self.remove_session_by_remote(remote).ok(); - return Ok(None); - } - MultiplexedSocketEvent::None(_) => return Ok(None), - MultiplexedSocketEvent::ClientData(peer, data) => { - let mut decoder = Decoder::new(data); - let message = ClientMessage::decode(&mut decoder); - (message, peer) - } - MultiplexedSocketEvent::Error(_) => return Ok(None), - }; - - let message = match message { - Ok(message) => message, - Err(error) => { - eprintln!("Decode error: {error}"); - return Ok(None); - } - }; - - let event = match message { - ClientMessage::Hello(terminal) - if self.addr_to_session.get(&remote).is_none() => - { - Event::NewClient(remote, terminal) - } - ClientMessage::Bye(reason) => Event::ClientBye(remote, reason), - ClientMessage::Input(data) - if let Some(fd) = self.addr_to_session.get(&remote) => - { - Event::SessionInput(*fd, remote, data) - } - _ => return Ok(None), - }; - - Ok(Some(event)) - } - fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)), - fd => { - let session = self.pty_to_session.get_mut(&fd).unwrap(); - let event = match session.pty_master.read(&mut buffer[..pty_max]) { - Ok(0) => PtyEvent::Closed, - Ok(len) => PtyEvent::Data(&buffer[..len]), - Err(e) => PtyEvent::Err(e.into()), - }; - Ok(Some(Event::Pty(fd, session.remote, event))) - } - } - } - - pub fn run(mut self) -> Result<(), Error> { - self.timer.start(PING_INTERVAL)?; - let mut recv_buf = [0; 256]; - let mut send_buf = [0; 256]; - - loop { - let Some(event) = self.poll(&mut recv_buf, 128)? else { - continue; - }; - - match event { - Event::SessionInput(fd, remote, data) => { - let session = self.pty_to_session.get_mut(&fd).unwrap(); - if let Err(error) = session.pty_master.write(&data) { - eprintln!("PTY write error: {error}"); - self.socket - .send_message_to( - &remote, - &mut send_buf, - &ServerMessage::Bye("PTY error"), - ) - .ok(); - self.remove_session_by_fd(fd)?; - } - } - Event::ClientBye(remote, reason) => { - println!("Client {remote} disconnected: {reason}"); - self.remove_session_by_remote(remote)?; - } - Event::NewClient(remote, terminal) => { - println!("New client: {remote}"); - match Session::open(&terminal, remote) { - Ok(session) => { - self.register_session(remote, session)?; - self.socket - .send_message_to(&remote, &mut send_buf, &ServerMessage::Hello) - .ok(); - } - Err(err) => { - eprintln!("PTY open error: {err}"); - self.socket - .send_message_to( - &remote, - &mut send_buf, - &ServerMessage::Bye("PTY open error"), - ) - .ok(); - self.socket.remove_client(&remote); - } - } - } - Event::Pty(fd, remote, event) => match event { - PtyEvent::Data(data) => { - self.socket - .send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data)) - .ok(); - } - PtyEvent::Err(error) => { - eprintln!("PTY read error: {error}"); - self.socket - .send_message_to( - &remote, - &mut send_buf, - &ServerMessage::Bye("PTY error"), - ) - .ok(); - self.remove_session_by_fd(fd)?; - } - PtyEvent::Closed => { - println!("End of PTY for {remote}"); - self.socket - .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("")) - .ok(); - self.remove_session_by_fd(fd)?; - } - }, - Event::Tick => { - // Restart the timer - self.update_client_timeouts()?; - self.timer.start(PING_INTERVAL)?; - } - } - } - } - - fn update_client_timeouts(&mut self) -> Result<(), Error> { - let removed = self.socket.ping_clients(8); - for entry in removed { - log::debug!("Client timed out: {entry}"); - self.remove_session_by_remote(entry).ok(); - } + fn close(mut self) -> Result<(), Self::Error> { + self.shell.wait()?; Ok(()) } - fn register_session(&mut self, remote: SocketAddr, session: Session) -> Result<(), Error> { - let fd = session.pty_master.as_raw_fd(); - self.addr_to_session.insert(remote, fd); - self.pty_to_session.insert(fd, session); - self.poll.add(&fd).map_err(Error::from) + fn peer(&self) -> SocketAddr { + self.remote } - fn remove_session_by_fd(&mut self, fd: RawFd) -> Result<Option<Session>, Error> { - if let Some(mut session) = self.pty_to_session.remove(&fd) { - // TODO: implement kernel support for pidfd or something, to poll the exit status of - // the task instead of doing it here. - // NOTE: this will block the whole server while the process finishes. - session.shell.wait().ok(); - self.addr_to_session.remove(&session.remote).unwrap(); - self.socket.remove_client(&session.remote); - self.poll.remove(&fd)?; - Ok(Some(session)) - } else { - Ok(None) - } + fn handle_input<'s, S: rsh::socket::PacketSocket>( + &mut self, + input: &[u8], + _client: rsh::server::SessionClient<'s, S>, + ) -> Result<bool, Self::Error> { + use std::io::Write; + self.pty_master.write_all(input)?; + Ok(false) } - fn remove_session_by_remote(&mut self, remote: SocketAddr) -> Result<Option<Session>, Error> { - let Some(fd) = self.addr_to_session.get(&remote).copied() else { - return Ok(None); - }; - self.remove_session_by_fd(fd) + fn read_output( + &mut self, + fd: std::os::fd::RawFd, + buffer: &mut [u8], + ) -> Result<usize, Self::Error> { + use std::io::Read; + assert_eq!(fd, self.fds[0]); + self.pty_master.read(buffer) + } + + fn event_fds(&self) -> &[std::os::fd::RawFd] { + &self.fds } } +#[cfg(unix)] +pub type SessionImpl = rsh::server::EchoSession; +#[cfg(target_os = "yggdrasil")] +pub type SessionImpl = YggdrasilSession; + fn run(args: Args) -> Result<(), Error> { let keystore = Box::new(SimpleServerKeyStore { path: args.keystore, @@ -312,7 +140,7 @@ fn run(args: Args) -> Result<(), Error> { ..Default::default() }; let listen_addr = SocketAddr::from_str(&format!("0.0.0.0:{}", args.port)).unwrap(); - let server = Server::new(listen_addr, server_config)?; + let server = Server::<_, SessionImpl>::listen_udp(listen_addr, server_config)?; server.run() } diff --git a/userspace/rsh/src/server.rs b/userspace/rsh/src/server.rs new file mode 100644 index 00000000..e059c858 --- /dev/null +++ b/userspace/rsh/src/server.rs @@ -0,0 +1,367 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + fmt, + net::{SocketAddr, UdpSocket}, + os::fd::{AsRawFd, RawFd}, + time::Duration, +}; + +use cross::io::{Poll, TimerFd}; + +use crate::{ + crypt::{server::ServerConfig, ServerEncryptedSocket}, + proto::{ClientMessage, Decode, Decoder, ServerMessage, TerminalInfo}, + socket::{MultiplexedSocket, MultiplexedSocketEvent, PacketSocket}, + Error, +}; + +pub const PING_INTERVAL: Duration = Duration::from_millis(500); + +pub trait Session: Sized { + type Error: fmt::Display; + + fn open(peer: &SocketAddr, terminal: &TerminalInfo) -> Result<Self, Self::Error>; + fn peer(&self) -> SocketAddr; + fn handle_input<'s, S: PacketSocket>( + &mut self, + input: &[u8], + client: SessionClient<'s, S>, + ) -> Result<bool, Self::Error>; + fn read_output(&mut self, fd: RawFd, buffer: &mut [u8]) -> Result<usize, Self::Error>; + fn event_fds(&self) -> &[RawFd]; + fn close(self) -> Result<(), Self::Error>; +} + +pub struct SessionClient<'s, S: PacketSocket> { + address: SocketAddr, + transport: &'s mut ServerEncryptedSocket<S>, + send_buf: &'s mut [u8], +} + +pub struct EchoSession { + peer: SocketAddr, +} + +enum SessionEvent<'b, T: Session> { + Data(&'b [u8]), + Err(T::Error), + Closed, +} + +enum Event<'b, T: Session> { + NewClient(SocketAddr, TerminalInfo), + SessionInput(u64, SocketAddr, &'b [u8]), + ClientBye(SocketAddr, &'b str), + SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>), + Tick, +} + +pub struct Server<S: PacketSocket, T: Session> { + poll: Poll, + timer: TimerFd, + socket: ServerEncryptedSocket<S>, + + last_session_key: u64, + sessions: HashMap<u64, T>, + peer_to_session: HashMap<SocketAddr, u64>, + session_event_map: HashMap<RawFd, u64>, +} + +impl<T: Session> Server<UdpSocket, T> { + pub fn listen_udp(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> { + let mut poll = Poll::new()?; + let timer = TimerFd::new()?; + let socket = UdpSocket::bind(listen_addr)?; + let socket = ServerEncryptedSocket::new_with_config(socket, crypto_config); + poll.add(&socket)?; + poll.add(&timer)?; + Ok(Self { + poll, + socket, + timer, + last_session_key: 1, + sessions: HashMap::new(), + peer_to_session: HashMap::new(), + session_event_map: HashMap::new(), + }) + } + + fn poll<'b>( + &mut self, + buffer: &'b mut [u8], + pty_max: usize, + ) -> Result<Option<Event<'b, T>>, Error> { + let fd = self.poll.wait(None)?.unwrap(); + + match fd { + fd if fd == self.socket.as_raw_fd() => { + let event = self.socket.recv_from(buffer)?; + + let (message, remote) = match event { + MultiplexedSocketEvent::ClientDisconnected(remote) => { + self.remove_session_by_remote(remote).ok(); + return Ok(None); + } + MultiplexedSocketEvent::None(_) => return Ok(None), + MultiplexedSocketEvent::ClientData(peer, data) => { + let mut decoder = Decoder::new(data); + let message = ClientMessage::decode(&mut decoder); + (message, peer) + } + MultiplexedSocketEvent::Error(_) => return Ok(None), + }; + + let message = match message { + Ok(message) => message, + Err(error) => { + log::warn!("Decode error: {error}"); + return Ok(None); + } + }; + + let event = match message { + ClientMessage::Hello(terminal) + if self.peer_to_session.get(&remote).is_none() => + { + Event::NewClient(remote, terminal) + } + ClientMessage::Bye(reason) => Event::ClientBye(remote, reason), + ClientMessage::Input(data) + if let Some(fd) = self.peer_to_session.get(&remote) => + { + Event::SessionInput(*fd, remote, data) + } + _ => return Ok(None), + }; + + Ok(Some(event)) + } + fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)), + fd => { + // Otherwise the event comes from a session + let key = *self.session_event_map.get(&fd).unwrap(); + let session = self.sessions.get_mut(&key).unwrap(); + let event = match session.read_output(fd, &mut buffer[..pty_max]) { + Ok(0) => SessionEvent::Closed, + Ok(len) => SessionEvent::Data(&buffer[..len]), + Err(e) => SessionEvent::Err(e), + }; + Ok(Some(Event::SessionEvent(fd, session.peer(), event))) + } + } + } + + pub fn run(mut self) -> Result<(), Error> { + self.timer.start(PING_INTERVAL)?; + let mut recv_buf = [0; 256]; + let mut send_buf = [0; 256]; + + loop { + let Some(event) = self.poll(&mut recv_buf, 128)? else { + continue; + }; + + match event { + Event::SessionInput(key, remote, data) => { + let session = self.sessions.get_mut(&key).unwrap(); + let peer = SessionClient { + address: remote, + send_buf: &mut send_buf, + transport: &mut self.socket, + }; + match session.handle_input(data, peer) { + Ok(false) => (), + Ok(true) => { + log::debug!("{remote}: session closed"); + self.socket + .send_message_to( + &remote, + &mut send_buf, + &ServerMessage::Bye("Session closed"), + ) + .ok(); + self.remove_session_by_key(key)?; + } + Err(error) => { + log::error!("{remote}: session input error: {error}"); + self.socket + .send_message_to( + &remote, + &mut send_buf, + &ServerMessage::Bye("Session error"), + ) + .ok(); + self.remove_session_by_key(key)?; + } + } + } + Event::ClientBye(remote, reason) => { + log::debug!("Client {remote} disconnected: {reason}"); + self.remove_session_by_remote(remote)?; + } + Event::NewClient(remote, terminal) => { + log::debug!("New client: {remote}"); + match T::open(&remote, &terminal) { + Ok(session) => { + self.register_session(remote, session)?; + self.socket + .send_message_to(&remote, &mut send_buf, &ServerMessage::Hello) + .ok(); + } + Err(err) => { + log::error!("PTY open error: {err}"); + self.socket + .send_message_to( + &remote, + &mut send_buf, + &ServerMessage::Bye("PTY open error"), + ) + .ok(); + self.socket.remove_client(&remote); + } + } + } + Event::SessionEvent(fd, remote, event) => match event { + SessionEvent::Data(data) => { + self.socket + .send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data)) + .ok(); + } + SessionEvent::Err(error) => { + log::error!("Session output read error: {error}"); + self.socket + .send_message_to( + &remote, + &mut send_buf, + &ServerMessage::Bye("Session error"), + ) + .ok(); + self.remove_session_by_fd(fd)?; + } + SessionEvent::Closed => { + log::debug!("Session closed for {remote}"); + self.socket + .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye("")) + .ok(); + self.remove_session_by_fd(fd)?; + } + }, + Event::Tick => { + // Restart the timer + self.update_client_timeouts()?; + self.timer.start(PING_INTERVAL)?; + } + } + } + } + + fn update_client_timeouts(&mut self) -> Result<(), Error> { + let removed = self.socket.ping_clients(8); + for entry in removed { + log::debug!("Client timed out: {entry}"); + self.remove_session_by_remote(entry).ok(); + } + Ok(()) + } + + fn register_session(&mut self, remote: SocketAddr, session: T) -> Result<(), Error> { + let (key, session) = loop { + let key = self.last_session_key; + self.last_session_key += 1; + match self.sessions.entry(key) { + Entry::Occupied(_) => continue, + Entry::Vacant(entry) => { + let session = entry.insert(session); + break (key, session); + } + } + }; + for fd in session.event_fds() { + self.poll.add(fd)?; + self.session_event_map.insert(*fd, key); + } + self.peer_to_session.insert(remote, key); + Ok(()) + } + + fn remove_session_by_key(&mut self, key: u64) -> Result<(), Error> { + let Some(session) = self.sessions.remove(&key) else { + return Ok(()); + }; + + for fd in session.event_fds() { + self.poll.remove(fd)?; + self.session_event_map.remove(fd); + } + self.peer_to_session.remove(&session.peer()).unwrap(); + self.socket.remove_client(&session.peer()); + + if let Err(error) = session.close() { + log::warn!("Session close error: {error}"); + } + + Ok(()) + } + + fn remove_session_by_fd(&mut self, fd: RawFd) -> Result<(), Error> { + let Some(key) = self.session_event_map.get(&fd).copied() else { + return Ok(()); + }; + self.remove_session_by_key(key) + } + + fn remove_session_by_remote(&mut self, remote: SocketAddr) -> Result<(), Error> { + let Some(key) = self.peer_to_session.get(&remote).copied() else { + return Ok(()); + }; + self.remove_session_by_key(key) + } +} + +impl<'s, S: PacketSocket> SessionClient<'s, S> { + pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> { + self.send_message(&ServerMessage::Output(data)) + } + + pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> { + self.transport + .send_message_to(&self.address, self.send_buf, message) + } +} + +impl Session for EchoSession { + type Error = Error; + + fn open(peer: &SocketAddr, _terminal: &TerminalInfo) -> Result<Self, Self::Error> { + Ok(Self { peer: *peer }) + } + + fn close(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn event_fds(&self) -> &[RawFd] { + &[] + } + + fn read_output(&mut self, _fd: RawFd, _buffer: &mut [u8]) -> Result<usize, Self::Error> { + Ok(0) + } + + fn handle_input<'s, S: PacketSocket>( + &mut self, + input: &[u8], + mut client: SessionClient<'s, S>, + ) -> Result<bool, Self::Error> { + if input.contains(&b'\x04') { + return Ok(true); + } + log::debug!("{:02x?}", input); + client.send_data(input)?; + Ok(false) + } + + fn peer(&self) -> SocketAddr { + self.peer + } +}