From a8a61926279ec88088d56a4d1b62f229f6fc1e4a Mon Sep 17 00:00:00 2001
From: Mark Poliakov <mark@alnyan.me>
Date: Sat, 2 Nov 2024 20:22:53 +0200
Subject: [PATCH] rsh: better server modularity

---
 userspace/rsh/src/crypt/mod.rs       |  28 --
 userspace/rsh/src/crypt/signature.rs |   2 +-
 userspace/rsh/src/lib.rs             |   5 +-
 userspace/rsh/src/main.rs            |   4 +-
 userspace/rsh/src/rshd/main.rs       | 354 +++++++-------------------
 userspace/rsh/src/server.rs          | 367 +++++++++++++++++++++++++++
 6 files changed, 465 insertions(+), 295 deletions(-)
 create mode 100644 userspace/rsh/src/server.rs

diff --git a/userspace/rsh/src/crypt/mod.rs b/userspace/rsh/src/crypt/mod.rs
index 1473dc3f..f37d4c13 100644
--- a/userspace/rsh/src/crypt/mod.rs
+++ b/userspace/rsh/src/crypt/mod.rs
@@ -354,37 +354,9 @@ pub fn ciphersuite_name(cipher: u8) -> Option<&'static str> {
     }
 }
 
-// pub fn hash_algo_name(hash: u8) -> Option<&'static str> {
-//     match hash {
-//         V1_HASH_SHA256 => Some("sha256"),
-//         _ => None,
-//     }
-// }
-
 pub fn sig_algo_name(sig: u8) -> Option<&'static str> {
     match sig {
         V1_SIG_ED25519 => Some("ed25519"),
         _ => None,
     }
 }
-
-// impl fmt::Display for PublicKeyFingerprint<'_> {
-//     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-//         let hash = match hash_algo_name(self.hash) {
-//             Some(name) => name,
-//             None => "unknown",
-//         };
-//         let sig = match sig_algo_name(self.sig) {
-//             Some(name) => name,
-//             None => "unknown",
-//         };
-//         write!(f, "{} {} {} ", sig, self.key_bits, hash)?;
-//         for (i, byte) in self.hash_data.iter().enumerate() {
-//             if i != 0 {
-//                 write!(f, ":")?;
-//             }
-//             write!(f, "{:02x}", *byte)?;
-//         }
-//         Ok(())
-//     }
-// }
diff --git a/userspace/rsh/src/crypt/signature.rs b/userspace/rsh/src/crypt/signature.rs
index fe66e1ff..91b80f61 100644
--- a/userspace/rsh/src/crypt/signature.rs
+++ b/userspace/rsh/src/crypt/signature.rs
@@ -79,7 +79,7 @@ impl SignEd25519 {
     }
 
     pub fn load_signing_key<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
-        let signing_key = ed25519_dalek::SigningKey::read_pkcs8_pem_file(path).unwrap();
+        let signing_key = ed25519_dalek::SigningKey::read_pkcs8_pem_file(path).map_err(|_| Error::InvalidKey)?;
         let verifying_key = signing_key.verifying_key();
         Ok(Self {
             signing_key,
diff --git a/userspace/rsh/src/lib.rs b/userspace/rsh/src/lib.rs
index 2d064730..7e07e065 100644
--- a/userspace/rsh/src/lib.rs
+++ b/userspace/rsh/src/lib.rs
@@ -1,5 +1,5 @@
 #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
-#![feature(generic_const_exprs, portable_simd)]
+#![feature(generic_const_exprs, portable_simd, if_let_guard)]
 #![allow(incomplete_features)]
 
 use std::io;
@@ -9,12 +9,13 @@ use proto::{DecodeError, EncodeError};
 pub mod proto;
 pub mod socket;
 pub mod crypt;
+pub mod server;
 
 pub use socket::{ClientSocket, ServerSocket};
 
 #[derive(Debug, thiserror::Error)]
 pub enum Error {
-    #[error("I/O error")]
+    #[error("I/O error: {0}")]
     Io(#[from] io::Error),
     #[error("Could not send a message fully")]
     Truncated,
diff --git a/userspace/rsh/src/main.rs b/userspace/rsh/src/main.rs
index 05ea91f7..de3a49df 100644
--- a/userspace/rsh/src/main.rs
+++ b/userspace/rsh/src/main.rs
@@ -31,6 +31,8 @@ pub enum Error {
 struct Args {
     #[clap(short, long)]
     key: PathBuf,
+    #[clap(short = 'P', long, default_value_t = 77)]
+    port: u16,
     remote: IpAddr,
 }
 
@@ -223,7 +225,7 @@ fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
 }
 
 fn run(args: Args) -> Result<(), Error> {
-    let remote = SocketAddr::new(args.remote, 77);
+    let remote = SocketAddr::new(args.remote, args.port);
     let ed25519 = SignEd25519::load_signing_key(args.key).unwrap();
     let key = SignatureMethod::Ed25519(ed25519);
     let config = ClientConfig {
diff --git a/userspace/rsh/src/rshd/main.rs b/userspace/rsh/src/rshd/main.rs
index b53ac62d..de599feb 100644
--- a/userspace/rsh/src/rshd/main.rs
+++ b/userspace/rsh/src/rshd/main.rs
@@ -1,15 +1,18 @@
 #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os, rustc_private))]
 #![feature(if_let_guard)]
 use std::{
-    collections::{HashMap, HashSet}, fs::File, io::{Read, Write}, net::{SocketAddr, UdpSocket}, os::fd::{AsRawFd, FromRawFd, RawFd}, path::PathBuf, process::{Child, Command, ExitCode, Stdio}, str::FromStr, time::Duration
+    collections::HashSet,
+    net::SocketAddr,
+    path::PathBuf,
+    process::ExitCode,
+    str::FromStr,
+    time::Duration,
 };
 
 use clap::Parser;
-use cross::io::{Poll, TimerFd};
 use rsh::{
-    crypt::{server::ServerConfig, ServerEncryptedSocket, SimpleServerKeyStore},
-    proto::{ClientMessage, Decode, Decoder, ServerMessage, TerminalInfo},
-    socket::{MultiplexedSocket, MultiplexedSocketEvent},
+    crypt::{server::ServerConfig, SimpleServerKeyStore},
+    server::Server,
     Error,
 };
 
@@ -19,287 +22,112 @@ pub const PING_INTERVAL: Duration = Duration::from_millis(500);
 struct Args {
     #[clap(short = 'P', long, help = "rsh listen port", default_value_t = 77)]
     port: u16,
-    #[clap(short = 'S', long, help = "where rsh will load private keys from", default_value = "/etc/rsh")]
-    keystore: PathBuf
+    #[clap(
+        short = 'S',
+        long,
+        help = "where rsh will load private keys from",
+        default_value = "/etc/rsh"
+    )]
+    keystore: PathBuf,
 }
 
-pub struct Session {
-    pty_master: File,
+#[cfg(target_os = "yggdrasil")]
+pub struct YggdrasilSession {
+    pty_master: std::fs::File,
+    fds: [std::os::fd::RawFd; 1],
     remote: SocketAddr,
-    shell: Child,
+    shell: std::process::Child,
 }
 
-pub struct Server {
-    poll: Poll,
-    timer: TimerFd,
-    socket: ServerEncryptedSocket<UdpSocket>,
+#[cfg(target_os = "yggdrasil")]
+impl rsh::server::Session for YggdrasilSession {
+    type Error = std::io::Error;
 
-    addr_to_session: HashMap<SocketAddr, RawFd>,
-    pty_to_session: HashMap<RawFd, Session>,
-}
-
-pub enum PtyEvent<'b> {
-    Data(&'b [u8]),
-    Err(Error),
-    Closed,
-}
-
-pub enum Event<'b> {
-    NewClient(SocketAddr, TerminalInfo),
-    SessionInput(RawFd, SocketAddr, &'b [u8]),
-    ClientBye(SocketAddr, &'b str),
-    Pty(RawFd, SocketAddr, PtyEvent<'b>),
-    Tick,
-}
-
-impl Session {
-    pub fn open(info: &TerminalInfo, remote: SocketAddr) -> Result<Self, Error> {
-        #[cfg(target_os = "yggdrasil")]
-        {
-            use std::os::yggdrasil::{
-                self,
-                io::terminal::{create_pty, TerminalSize},
-                process::CommandExt,
-            };
-            // TODO unix version
-            let (pty_master, pty_slave) = create_pty(
-                Default::default(),
-                TerminalSize {
-                    columns: info.columns as _,
-                    rows: info.rows as _,
+    fn open(remote: &SocketAddr, terminal: &rsh::proto::TerminalInfo) -> Result<Self, Self::Error> {
+        use std::{
+            os::{
+                fd::{AsRawFd, FromRawFd},
+                yggdrasil::{
+                    self,
+                    io::terminal::{create_pty, TerminalSize},
+                    process::CommandExt,
                 },
-            )?;
+            },
+            process::{Command, Stdio},
+        };
 
-            let pty_slave_fd = pty_slave.as_raw_fd();
-            let group_id = yggdrasil::process::create_process_group();
-            let shell = unsafe {
-                Command::new("/bin/sh")
-                    .arg("-l")
-                    .stdin(Stdio::from_raw_fd(pty_slave_fd))
-                    .stdout(Stdio::from_raw_fd(pty_slave_fd))
-                    .stderr(Stdio::from_raw_fd(pty_slave_fd))
-                    .process_group(group_id)
-                    .gain_terminal(0)
-                    .spawn()?
-            };
+        let remote = *remote;
+        // TODO unix version
+        let (pty_master, pty_slave) = create_pty(
+            Default::default(),
+            TerminalSize {
+                columns: terminal.columns as _,
+                rows: terminal.rows as _,
+            },
+        )?;
 
-            Ok(Self {
-                pty_master,
-                shell,
-                remote,
-            })
-        }
-        #[cfg(unix)]
-        {
-            todo!()
-        }
-    }
-}
+        let pty_slave_fd = pty_slave.as_raw_fd();
+        let group_id = yggdrasil::process::create_process_group();
+        let shell = unsafe {
+            Command::new("/bin/sh")
+                .arg("-l")
+                .stdin(Stdio::from_raw_fd(pty_slave_fd))
+                .stdout(Stdio::from_raw_fd(pty_slave_fd))
+                .stderr(Stdio::from_raw_fd(pty_slave_fd))
+                .process_group(group_id)
+                .gain_terminal(0)
+                .spawn()?
+        };
+
+        let fds = [pty_master.as_raw_fd()];
 
-impl Server {
-    pub fn new(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> {
-        let mut poll = Poll::new()?;
-        let timer = TimerFd::new()?;
-        let socket = UdpSocket::bind(listen_addr)?;
-        let socket = ServerEncryptedSocket::new_with_config(socket, crypto_config);
-        poll.add(&socket)?;
-        poll.add(&timer)?;
         Ok(Self {
-            poll,
-            socket,
-            timer,
-            addr_to_session: HashMap::new(),
-            pty_to_session: HashMap::new(),
+            pty_master,
+            shell,
+            remote,
+            fds,
         })
     }
 
-    pub fn poll<'b>(
-        &mut self,
-        buffer: &'b mut [u8],
-        pty_max: usize,
-    ) -> Result<Option<Event<'b>>, Error> {
-        let fd = self.poll.wait(None)?.unwrap();
-
-        match fd {
-            fd if fd == self.socket.as_raw_fd() => {
-                let event = self.socket.recv_from(buffer)?;
-
-                let (message, remote) = match event {
-                    MultiplexedSocketEvent::ClientDisconnected(remote) => {
-                        self.remove_session_by_remote(remote).ok();
-                        return Ok(None);
-                    }
-                    MultiplexedSocketEvent::None(_) => return Ok(None),
-                    MultiplexedSocketEvent::ClientData(peer, data) => {
-                        let mut decoder = Decoder::new(data);
-                        let message = ClientMessage::decode(&mut decoder);
-                        (message, peer)
-                    }
-                    MultiplexedSocketEvent::Error(_) => return Ok(None),
-                };
-
-                let message = match message {
-                    Ok(message) => message,
-                    Err(error) => {
-                        eprintln!("Decode error: {error}");
-                        return Ok(None);
-                    }
-                };
-
-                let event = match message {
-                    ClientMessage::Hello(terminal)
-                        if self.addr_to_session.get(&remote).is_none() =>
-                    {
-                        Event::NewClient(remote, terminal)
-                    }
-                    ClientMessage::Bye(reason) => Event::ClientBye(remote, reason),
-                    ClientMessage::Input(data)
-                        if let Some(fd) = self.addr_to_session.get(&remote) =>
-                    {
-                        Event::SessionInput(*fd, remote, data)
-                    }
-                    _ => return Ok(None),
-                };
-
-                Ok(Some(event))
-            }
-            fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)),
-            fd => {
-                let session = self.pty_to_session.get_mut(&fd).unwrap();
-                let event = match session.pty_master.read(&mut buffer[..pty_max]) {
-                    Ok(0) => PtyEvent::Closed,
-                    Ok(len) => PtyEvent::Data(&buffer[..len]),
-                    Err(e) => PtyEvent::Err(e.into()),
-                };
-                Ok(Some(Event::Pty(fd, session.remote, event)))
-            }
-        }
-    }
-
-    pub fn run(mut self) -> Result<(), Error> {
-        self.timer.start(PING_INTERVAL)?;
-        let mut recv_buf = [0; 256];
-        let mut send_buf = [0; 256];
-
-        loop {
-            let Some(event) = self.poll(&mut recv_buf, 128)? else {
-                continue;
-            };
-
-            match event {
-                Event::SessionInput(fd, remote, data) => {
-                    let session = self.pty_to_session.get_mut(&fd).unwrap();
-                    if let Err(error) = session.pty_master.write(&data) {
-                        eprintln!("PTY write error: {error}");
-                        self.socket
-                            .send_message_to(
-                                &remote,
-                                &mut send_buf,
-                                &ServerMessage::Bye("PTY error"),
-                            )
-                            .ok();
-                        self.remove_session_by_fd(fd)?;
-                    }
-                }
-                Event::ClientBye(remote, reason) => {
-                    println!("Client {remote} disconnected: {reason}");
-                    self.remove_session_by_remote(remote)?;
-                }
-                Event::NewClient(remote, terminal) => {
-                    println!("New client: {remote}");
-                    match Session::open(&terminal, remote) {
-                        Ok(session) => {
-                            self.register_session(remote, session)?;
-                            self.socket
-                                .send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
-                                .ok();
-                        }
-                        Err(err) => {
-                            eprintln!("PTY open error: {err}");
-                            self.socket
-                                .send_message_to(
-                                    &remote,
-                                    &mut send_buf,
-                                    &ServerMessage::Bye("PTY open error"),
-                                )
-                                .ok();
-                            self.socket.remove_client(&remote);
-                        }
-                    }
-                }
-                Event::Pty(fd, remote, event) => match event {
-                    PtyEvent::Data(data) => {
-                        self.socket
-                            .send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data))
-                            .ok();
-                    }
-                    PtyEvent::Err(error) => {
-                        eprintln!("PTY read error: {error}");
-                        self.socket
-                            .send_message_to(
-                                &remote,
-                                &mut send_buf,
-                                &ServerMessage::Bye("PTY error"),
-                            )
-                            .ok();
-                        self.remove_session_by_fd(fd)?;
-                    }
-                    PtyEvent::Closed => {
-                        println!("End of PTY for {remote}");
-                        self.socket
-                            .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
-                            .ok();
-                        self.remove_session_by_fd(fd)?;
-                    }
-                },
-                Event::Tick => {
-                    // Restart the timer
-                    self.update_client_timeouts()?;
-                    self.timer.start(PING_INTERVAL)?;
-                }
-            }
-        }
-    }
-
-    fn update_client_timeouts(&mut self) -> Result<(), Error> {
-        let removed = self.socket.ping_clients(8);
-        for entry in removed {
-            log::debug!("Client timed out: {entry}");
-            self.remove_session_by_remote(entry).ok();
-        }
+    fn close(mut self) -> Result<(), Self::Error> {
+        self.shell.wait()?;
         Ok(())
     }
 
-    fn register_session(&mut self, remote: SocketAddr, session: Session) -> Result<(), Error> {
-        let fd = session.pty_master.as_raw_fd();
-        self.addr_to_session.insert(remote, fd);
-        self.pty_to_session.insert(fd, session);
-        self.poll.add(&fd).map_err(Error::from)
+    fn peer(&self) -> SocketAddr {
+        self.remote
     }
 
-    fn remove_session_by_fd(&mut self, fd: RawFd) -> Result<Option<Session>, Error> {
-        if let Some(mut session) = self.pty_to_session.remove(&fd) {
-            // TODO: implement kernel support for pidfd or something, to poll the exit status of
-            //       the task instead of doing it here.
-            // NOTE: this will block the whole server while the process finishes.
-            session.shell.wait().ok();
-            self.addr_to_session.remove(&session.remote).unwrap();
-            self.socket.remove_client(&session.remote);
-            self.poll.remove(&fd)?;
-            Ok(Some(session))
-        } else {
-            Ok(None)
-        }
+    fn handle_input<'s, S: rsh::socket::PacketSocket>(
+        &mut self,
+        input: &[u8],
+        _client: rsh::server::SessionClient<'s, S>,
+    ) -> Result<bool, Self::Error> {
+        use std::io::Write;
+        self.pty_master.write_all(input)?;
+        Ok(false)
     }
 
-    fn remove_session_by_remote(&mut self, remote: SocketAddr) -> Result<Option<Session>, Error> {
-        let Some(fd) = self.addr_to_session.get(&remote).copied() else {
-            return Ok(None);
-        };
-        self.remove_session_by_fd(fd)
+    fn read_output(
+        &mut self,
+        fd: std::os::fd::RawFd,
+        buffer: &mut [u8],
+    ) -> Result<usize, Self::Error> {
+        use std::io::Read;
+        assert_eq!(fd, self.fds[0]);
+        self.pty_master.read(buffer)
+    }
+
+    fn event_fds(&self) -> &[std::os::fd::RawFd] {
+        &self.fds
     }
 }
 
+#[cfg(unix)]
+pub type SessionImpl = rsh::server::EchoSession;
+#[cfg(target_os = "yggdrasil")]
+pub type SessionImpl = YggdrasilSession;
+
 fn run(args: Args) -> Result<(), Error> {
     let keystore = Box::new(SimpleServerKeyStore {
         path: args.keystore,
@@ -312,7 +140,7 @@ fn run(args: Args) -> Result<(), Error> {
         ..Default::default()
     };
     let listen_addr = SocketAddr::from_str(&format!("0.0.0.0:{}", args.port)).unwrap();
-    let server = Server::new(listen_addr, server_config)?;
+    let server = Server::<_, SessionImpl>::listen_udp(listen_addr, server_config)?;
     server.run()
 }
 
diff --git a/userspace/rsh/src/server.rs b/userspace/rsh/src/server.rs
new file mode 100644
index 00000000..e059c858
--- /dev/null
+++ b/userspace/rsh/src/server.rs
@@ -0,0 +1,367 @@
+use std::{
+    collections::{hash_map::Entry, HashMap},
+    fmt,
+    net::{SocketAddr, UdpSocket},
+    os::fd::{AsRawFd, RawFd},
+    time::Duration,
+};
+
+use cross::io::{Poll, TimerFd};
+
+use crate::{
+    crypt::{server::ServerConfig, ServerEncryptedSocket},
+    proto::{ClientMessage, Decode, Decoder, ServerMessage, TerminalInfo},
+    socket::{MultiplexedSocket, MultiplexedSocketEvent, PacketSocket},
+    Error,
+};
+
+pub const PING_INTERVAL: Duration = Duration::from_millis(500);
+
+pub trait Session: Sized {
+    type Error: fmt::Display;
+
+    fn open(peer: &SocketAddr, terminal: &TerminalInfo) -> Result<Self, Self::Error>;
+    fn peer(&self) -> SocketAddr;
+    fn handle_input<'s, S: PacketSocket>(
+        &mut self,
+        input: &[u8],
+        client: SessionClient<'s, S>,
+    ) -> Result<bool, Self::Error>;
+    fn read_output(&mut self, fd: RawFd, buffer: &mut [u8]) -> Result<usize, Self::Error>;
+    fn event_fds(&self) -> &[RawFd];
+    fn close(self) -> Result<(), Self::Error>;
+}
+
+pub struct SessionClient<'s, S: PacketSocket> {
+    address: SocketAddr,
+    transport: &'s mut ServerEncryptedSocket<S>,
+    send_buf: &'s mut [u8],
+}
+
+pub struct EchoSession {
+    peer: SocketAddr,
+}
+
+enum SessionEvent<'b, T: Session> {
+    Data(&'b [u8]),
+    Err(T::Error),
+    Closed,
+}
+
+enum Event<'b, T: Session> {
+    NewClient(SocketAddr, TerminalInfo),
+    SessionInput(u64, SocketAddr, &'b [u8]),
+    ClientBye(SocketAddr, &'b str),
+    SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>),
+    Tick,
+}
+
+pub struct Server<S: PacketSocket, T: Session> {
+    poll: Poll,
+    timer: TimerFd,
+    socket: ServerEncryptedSocket<S>,
+
+    last_session_key: u64,
+    sessions: HashMap<u64, T>,
+    peer_to_session: HashMap<SocketAddr, u64>,
+    session_event_map: HashMap<RawFd, u64>,
+}
+
+impl<T: Session> Server<UdpSocket, T> {
+    pub fn listen_udp(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> {
+        let mut poll = Poll::new()?;
+        let timer = TimerFd::new()?;
+        let socket = UdpSocket::bind(listen_addr)?;
+        let socket = ServerEncryptedSocket::new_with_config(socket, crypto_config);
+        poll.add(&socket)?;
+        poll.add(&timer)?;
+        Ok(Self {
+            poll,
+            socket,
+            timer,
+            last_session_key: 1,
+            sessions: HashMap::new(),
+            peer_to_session: HashMap::new(),
+            session_event_map: HashMap::new(),
+        })
+    }
+
+    fn poll<'b>(
+        &mut self,
+        buffer: &'b mut [u8],
+        pty_max: usize,
+    ) -> Result<Option<Event<'b, T>>, Error> {
+        let fd = self.poll.wait(None)?.unwrap();
+
+        match fd {
+            fd if fd == self.socket.as_raw_fd() => {
+                let event = self.socket.recv_from(buffer)?;
+
+                let (message, remote) = match event {
+                    MultiplexedSocketEvent::ClientDisconnected(remote) => {
+                        self.remove_session_by_remote(remote).ok();
+                        return Ok(None);
+                    }
+                    MultiplexedSocketEvent::None(_) => return Ok(None),
+                    MultiplexedSocketEvent::ClientData(peer, data) => {
+                        let mut decoder = Decoder::new(data);
+                        let message = ClientMessage::decode(&mut decoder);
+                        (message, peer)
+                    }
+                    MultiplexedSocketEvent::Error(_) => return Ok(None),
+                };
+
+                let message = match message {
+                    Ok(message) => message,
+                    Err(error) => {
+                        log::warn!("Decode error: {error}");
+                        return Ok(None);
+                    }
+                };
+
+                let event = match message {
+                    ClientMessage::Hello(terminal)
+                        if self.peer_to_session.get(&remote).is_none() =>
+                    {
+                        Event::NewClient(remote, terminal)
+                    }
+                    ClientMessage::Bye(reason) => Event::ClientBye(remote, reason),
+                    ClientMessage::Input(data)
+                        if let Some(fd) = self.peer_to_session.get(&remote) =>
+                    {
+                        Event::SessionInput(*fd, remote, data)
+                    }
+                    _ => return Ok(None),
+                };
+
+                Ok(Some(event))
+            }
+            fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)),
+            fd => {
+                // Otherwise the event comes from a session
+                let key = *self.session_event_map.get(&fd).unwrap();
+                let session = self.sessions.get_mut(&key).unwrap();
+                let event = match session.read_output(fd, &mut buffer[..pty_max]) {
+                    Ok(0) => SessionEvent::Closed,
+                    Ok(len) => SessionEvent::Data(&buffer[..len]),
+                    Err(e) => SessionEvent::Err(e),
+                };
+                Ok(Some(Event::SessionEvent(fd, session.peer(), event)))
+            }
+        }
+    }
+
+    pub fn run(mut self) -> Result<(), Error> {
+        self.timer.start(PING_INTERVAL)?;
+        let mut recv_buf = [0; 256];
+        let mut send_buf = [0; 256];
+
+        loop {
+            let Some(event) = self.poll(&mut recv_buf, 128)? else {
+                continue;
+            };
+
+            match event {
+                Event::SessionInput(key, remote, data) => {
+                    let session = self.sessions.get_mut(&key).unwrap();
+                    let peer = SessionClient {
+                        address: remote,
+                        send_buf: &mut send_buf,
+                        transport: &mut self.socket,
+                    };
+                    match session.handle_input(data, peer) {
+                        Ok(false) => (),
+                        Ok(true) => {
+                            log::debug!("{remote}: session closed");
+                            self.socket
+                                .send_message_to(
+                                    &remote,
+                                    &mut send_buf,
+                                    &ServerMessage::Bye("Session closed"),
+                                )
+                                .ok();
+                            self.remove_session_by_key(key)?;
+                        }
+                        Err(error) => {
+                            log::error!("{remote}: session input error: {error}");
+                            self.socket
+                                .send_message_to(
+                                    &remote,
+                                    &mut send_buf,
+                                    &ServerMessage::Bye("Session error"),
+                                )
+                                .ok();
+                            self.remove_session_by_key(key)?;
+                        }
+                    }
+                }
+                Event::ClientBye(remote, reason) => {
+                    log::debug!("Client {remote} disconnected: {reason}");
+                    self.remove_session_by_remote(remote)?;
+                }
+                Event::NewClient(remote, terminal) => {
+                    log::debug!("New client: {remote}");
+                    match T::open(&remote, &terminal) {
+                        Ok(session) => {
+                            self.register_session(remote, session)?;
+                            self.socket
+                                .send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
+                                .ok();
+                        }
+                        Err(err) => {
+                            log::error!("PTY open error: {err}");
+                            self.socket
+                                .send_message_to(
+                                    &remote,
+                                    &mut send_buf,
+                                    &ServerMessage::Bye("PTY open error"),
+                                )
+                                .ok();
+                            self.socket.remove_client(&remote);
+                        }
+                    }
+                }
+                Event::SessionEvent(fd, remote, event) => match event {
+                    SessionEvent::Data(data) => {
+                        self.socket
+                            .send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data))
+                            .ok();
+                    }
+                    SessionEvent::Err(error) => {
+                        log::error!("Session output read error: {error}");
+                        self.socket
+                            .send_message_to(
+                                &remote,
+                                &mut send_buf,
+                                &ServerMessage::Bye("Session error"),
+                            )
+                            .ok();
+                        self.remove_session_by_fd(fd)?;
+                    }
+                    SessionEvent::Closed => {
+                        log::debug!("Session closed for {remote}");
+                        self.socket
+                            .send_message_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
+                            .ok();
+                        self.remove_session_by_fd(fd)?;
+                    }
+                },
+                Event::Tick => {
+                    // Restart the timer
+                    self.update_client_timeouts()?;
+                    self.timer.start(PING_INTERVAL)?;
+                }
+            }
+        }
+    }
+
+    fn update_client_timeouts(&mut self) -> Result<(), Error> {
+        let removed = self.socket.ping_clients(8);
+        for entry in removed {
+            log::debug!("Client timed out: {entry}");
+            self.remove_session_by_remote(entry).ok();
+        }
+        Ok(())
+    }
+
+    fn register_session(&mut self, remote: SocketAddr, session: T) -> Result<(), Error> {
+        let (key, session) = loop {
+            let key = self.last_session_key;
+            self.last_session_key += 1;
+            match self.sessions.entry(key) {
+                Entry::Occupied(_) => continue,
+                Entry::Vacant(entry) => {
+                    let session = entry.insert(session);
+                    break (key, session);
+                }
+            }
+        };
+        for fd in session.event_fds() {
+            self.poll.add(fd)?;
+            self.session_event_map.insert(*fd, key);
+        }
+        self.peer_to_session.insert(remote, key);
+        Ok(())
+    }
+
+    fn remove_session_by_key(&mut self, key: u64) -> Result<(), Error> {
+        let Some(session) = self.sessions.remove(&key) else {
+            return Ok(());
+        };
+
+        for fd in session.event_fds() {
+            self.poll.remove(fd)?;
+            self.session_event_map.remove(fd);
+        }
+        self.peer_to_session.remove(&session.peer()).unwrap();
+        self.socket.remove_client(&session.peer());
+
+        if let Err(error) = session.close() {
+            log::warn!("Session close error: {error}");
+        }
+
+        Ok(())
+    }
+
+    fn remove_session_by_fd(&mut self, fd: RawFd) -> Result<(), Error> {
+        let Some(key) = self.session_event_map.get(&fd).copied() else {
+            return Ok(());
+        };
+        self.remove_session_by_key(key)
+    }
+
+    fn remove_session_by_remote(&mut self, remote: SocketAddr) -> Result<(), Error> {
+        let Some(key) = self.peer_to_session.get(&remote).copied() else {
+            return Ok(());
+        };
+        self.remove_session_by_key(key)
+    }
+}
+
+impl<'s, S: PacketSocket> SessionClient<'s, S> {
+    pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
+        self.send_message(&ServerMessage::Output(data))
+    }
+
+    pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> {
+        self.transport
+            .send_message_to(&self.address, self.send_buf, message)
+    }
+}
+
+impl Session for EchoSession {
+    type Error = Error;
+
+    fn open(peer: &SocketAddr, _terminal: &TerminalInfo) -> Result<Self, Self::Error> {
+        Ok(Self { peer: *peer })
+    }
+
+    fn close(self) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn event_fds(&self) -> &[RawFd] {
+        &[]
+    }
+
+    fn read_output(&mut self, _fd: RawFd, _buffer: &mut [u8]) -> Result<usize, Self::Error> {
+        Ok(0)
+    }
+
+    fn handle_input<'s, S: PacketSocket>(
+        &mut self,
+        input: &[u8],
+        mut client: SessionClient<'s, S>,
+    ) -> Result<bool, Self::Error> {
+        if input.contains(&b'\x04') {
+            return Ok(true);
+        }
+        log::debug!("{:02x?}", input);
+        client.send_data(input)?;
+        Ok(false)
+    }
+
+    fn peer(&self) -> SocketAddr {
+        self.peer
+    }
+}