rsh: better server modularity
This commit is contained in:
parent
f0fdeb1004
commit
a8a6192627
@ -354,37 +354,9 @@ pub fn ciphersuite_name(cipher: u8) -> Option<&'static str> {
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn hash_algo_name(hash: u8) -> Option<&'static str> {
|
||||
// match hash {
|
||||
// V1_HASH_SHA256 => Some("sha256"),
|
||||
// _ => None,
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn sig_algo_name(sig: u8) -> Option<&'static str> {
|
||||
match sig {
|
||||
V1_SIG_ED25519 => Some("ed25519"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// impl fmt::Display for PublicKeyFingerprint<'_> {
|
||||
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
// let hash = match hash_algo_name(self.hash) {
|
||||
// Some(name) => name,
|
||||
// None => "unknown",
|
||||
// };
|
||||
// let sig = match sig_algo_name(self.sig) {
|
||||
// Some(name) => name,
|
||||
// None => "unknown",
|
||||
// };
|
||||
// write!(f, "{} {} {} ", sig, self.key_bits, hash)?;
|
||||
// for (i, byte) in self.hash_data.iter().enumerate() {
|
||||
// if i != 0 {
|
||||
// write!(f, ":")?;
|
||||
// }
|
||||
// write!(f, "{:02x}", *byte)?;
|
||||
// }
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
|
@ -79,7 +79,7 @@ impl SignEd25519 {
|
||||
}
|
||||
|
||||
pub fn load_signing_key<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
|
||||
let signing_key = ed25519_dalek::SigningKey::read_pkcs8_pem_file(path).unwrap();
|
||||
let signing_key = ed25519_dalek::SigningKey::read_pkcs8_pem_file(path).map_err(|_| Error::InvalidKey)?;
|
||||
let verifying_key = signing_key.verifying_key();
|
||||
Ok(Self {
|
||||
signing_key,
|
||||
|
@ -1,5 +1,5 @@
|
||||
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
|
||||
#![feature(generic_const_exprs, portable_simd)]
|
||||
#![feature(generic_const_exprs, portable_simd, if_let_guard)]
|
||||
#![allow(incomplete_features)]
|
||||
|
||||
use std::io;
|
||||
@ -9,12 +9,13 @@ use proto::{DecodeError, EncodeError};
|
||||
pub mod proto;
|
||||
pub mod socket;
|
||||
pub mod crypt;
|
||||
pub mod server;
|
||||
|
||||
pub use socket::{ClientSocket, ServerSocket};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("I/O error")]
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
#[error("Could not send a message fully")]
|
||||
Truncated,
|
||||
|
@ -31,6 +31,8 @@ pub enum Error {
|
||||
struct Args {
|
||||
#[clap(short, long)]
|
||||
key: PathBuf,
|
||||
#[clap(short = 'P', long, default_value_t = 77)]
|
||||
port: u16,
|
||||
remote: IpAddr,
|
||||
}
|
||||
|
||||
@ -223,7 +225,7 @@ fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<(), Error> {
|
||||
let remote = SocketAddr::new(args.remote, 77);
|
||||
let remote = SocketAddr::new(args.remote, args.port);
|
||||
let ed25519 = SignEd25519::load_signing_key(args.key).unwrap();
|
||||
let key = SignatureMethod::Ed25519(ed25519);
|
||||
let config = ClientConfig {
|
||||
|
@ -1,15 +1,18 @@
|
||||
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os, rustc_private))]
|
||||
#![feature(if_let_guard)]
|
||||
use std::{
|
||||
collections::{HashMap, HashSet}, fs::File, io::{Read, Write}, net::{SocketAddr, UdpSocket}, os::fd::{AsRawFd, FromRawFd, RawFd}, path::PathBuf, process::{Child, Command, ExitCode, Stdio}, str::FromStr, time::Duration
|
||||
collections::HashSet,
|
||||
net::SocketAddr,
|
||||
path::PathBuf,
|
||||
process::ExitCode,
|
||||
str::FromStr,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use clap::Parser;
|
||||
use cross::io::{Poll, TimerFd};
|
||||
use rsh::{
|
||||
crypt::{server::ServerConfig, ServerEncryptedSocket, SimpleServerKeyStore},
|
||||
proto::{ClientMessage, Decode, Decoder, ServerMessage, TerminalInfo},
|
||||
socket::{MultiplexedSocket, MultiplexedSocketEvent},
|
||||
crypt::{server::ServerConfig, SimpleServerKeyStore},
|
||||
server::Server,
|
||||
Error,
|
||||
};
|
||||
|
||||
@ -19,287 +22,112 @@ pub const PING_INTERVAL: Duration = Duration::from_millis(500);
|
||||
struct Args {
|
||||
#[clap(short = 'P', long, help = "rsh listen port", default_value_t = 77)]
|
||||
port: u16,
|
||||
#[clap(short = 'S', long, help = "where rsh will load private keys from", default_value = "/etc/rsh")]
|
||||
keystore: PathBuf
|
||||
#[clap(
|
||||
short = 'S',
|
||||
long,
|
||||
help = "where rsh will load private keys from",
|
||||
default_value = "/etc/rsh"
|
||||
)]
|
||||
keystore: PathBuf,
|
||||
}
|
||||
|
||||
pub struct Session {
|
||||
pty_master: File,
|
||||
#[cfg(target_os = "yggdrasil")]
|
||||
pub struct YggdrasilSession {
|
||||
pty_master: std::fs::File,
|
||||
fds: [std::os::fd::RawFd; 1],
|
||||
remote: SocketAddr,
|
||||
shell: Child,
|
||||
shell: std::process::Child,
|
||||
}
|
||||
|
||||
pub struct Server {
|
||||
poll: Poll,
|
||||
timer: TimerFd,
|
||||
socket: ServerEncryptedSocket<UdpSocket>,
|
||||
#[cfg(target_os = "yggdrasil")]
|
||||
impl rsh::server::Session for YggdrasilSession {
|
||||
type Error = std::io::Error;
|
||||
|
||||
addr_to_session: HashMap<SocketAddr, RawFd>,
|
||||
pty_to_session: HashMap<RawFd, Session>,
|
||||
}
|
||||
|
||||
pub enum PtyEvent<'b> {
|
||||
Data(&'b [u8]),
|
||||
Err(Error),
|
||||
Closed,
|
||||
}
|
||||
|
||||
pub enum Event<'b> {
|
||||
NewClient(SocketAddr, TerminalInfo),
|
||||
SessionInput(RawFd, SocketAddr, &'b [u8]),
|
||||
ClientBye(SocketAddr, &'b str),
|
||||
Pty(RawFd, SocketAddr, PtyEvent<'b>),
|
||||
Tick,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn open(info: &TerminalInfo, remote: SocketAddr) -> Result<Self, Error> {
|
||||
#[cfg(target_os = "yggdrasil")]
|
||||
{
|
||||
use std::os::yggdrasil::{
|
||||
self,
|
||||
io::terminal::{create_pty, TerminalSize},
|
||||
process::CommandExt,
|
||||
};
|
||||
// TODO unix version
|
||||
let (pty_master, pty_slave) = create_pty(
|
||||
Default::default(),
|
||||
TerminalSize {
|
||||
columns: info.columns as _,
|
||||
rows: info.rows as _,
|
||||
fn open(remote: &SocketAddr, terminal: &rsh::proto::TerminalInfo) -> Result<Self, Self::Error> {
|
||||
use std::{
|
||||
os::{
|
||||
fd::{AsRawFd, FromRawFd},
|
||||
yggdrasil::{
|
||||
self,
|
||||
io::terminal::{create_pty, TerminalSize},
|
||||
process::CommandExt,
|
||||
},
|
||||
)?;
|
||||
},
|
||||
process::{Command, Stdio},
|
||||
};
|
||||
|
||||
let pty_slave_fd = pty_slave.as_raw_fd();
|
||||
let group_id = yggdrasil::process::create_process_group();
|
||||
let shell = unsafe {
|
||||
Command::new("/bin/sh")
|
||||
.arg("-l")
|
||||
.stdin(Stdio::from_raw_fd(pty_slave_fd))
|
||||
.stdout(Stdio::from_raw_fd(pty_slave_fd))
|
||||
.stderr(Stdio::from_raw_fd(pty_slave_fd))
|
||||
.process_group(group_id)
|
||||
.gain_terminal(0)
|
||||
.spawn()?
|
||||
};
|
||||
let remote = *remote;
|
||||
// TODO unix version
|
||||
let (pty_master, pty_slave) = create_pty(
|
||||
Default::default(),
|
||||
TerminalSize {
|
||||
columns: terminal.columns as _,
|
||||
rows: terminal.rows as _,
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
pty_master,
|
||||
shell,
|
||||
remote,
|
||||
})
|
||||
}
|
||||
#[cfg(unix)]
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
let pty_slave_fd = pty_slave.as_raw_fd();
|
||||
let group_id = yggdrasil::process::create_process_group();
|
||||
let shell = unsafe {
|
||||
Command::new("/bin/sh")
|
||||
.arg("-l")
|
||||
.stdin(Stdio::from_raw_fd(pty_slave_fd))
|
||||
.stdout(Stdio::from_raw_fd(pty_slave_fd))
|
||||
.stderr(Stdio::from_raw_fd(pty_slave_fd))
|
||||
.process_group(group_id)
|
||||
.gain_terminal(0)
|
||||
.spawn()?
|
||||
};
|
||||
|
||||
let fds = [pty_master.as_raw_fd()];
|
||||
|
||||
impl Server {
|
||||
pub fn new(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> {
|
||||
let mut poll = Poll::new()?;
|
||||
let timer = TimerFd::new()?;
|
||||
let socket = UdpSocket::bind(listen_addr)?;
|
||||
let socket = ServerEncryptedSocket::new_with_config(socket, crypto_config);
|
||||
poll.add(&socket)?;
|
||||
poll.add(&timer)?;
|
||||
Ok(Self {
|
||||
poll,
|
||||
socket,
|
||||
timer,
|
||||
addr_to_session: HashMap::new(),
|
||||
pty_to_session: HashMap::new(),
|
||||
pty_master,
|
||||
shell,
|
||||
remote,
|
||||
fds,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn poll<'b>(
|
||||
&mut self,
|
||||
buffer: &'b mut [u8],
|
||||
pty_max: usize,
|
||||
) -> Result<Option<Event<'b>>, Error> {
|
||||
let fd = self.poll.wait(None)?.unwrap();
|
||||
|
||||
match fd {
|
||||
fd if fd == self.socket.as_raw_fd() => {
|
||||
let event = self.socket.recv_from(buffer)?;
|
||||
|
||||
let (message, remote) = match event {
|
||||
MultiplexedSocketEvent::ClientDisconnected(remote) => {
|
||||
self.remove_session_by_remote(remote).ok();
|
||||
return Ok(None);
|
||||
}
|
||||
MultiplexedSocketEvent::None(_) => return Ok(None),
|
||||
MultiplexedSocketEvent::ClientData(peer, data) => {
|
||||
let mut decoder = Decoder::new(data);
|
||||
let message = ClientMessage::decode(&mut decoder);
|
||||
(message, peer)
|
||||
}
|
||||
MultiplexedSocketEvent::Error(_) => return Ok(None),
|
||||
};
|
||||
|
||||
let message = match message {
|
||||
Ok(message) => message,
|
||||
Err(error) => {
|
||||
eprintln!("Decode error: {error}");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let event = match message {
|
||||
ClientMessage::Hello(terminal)
|
||||
if self.addr_to_session.get(&remote).is_none() =>
|
||||
{
|
||||
Event::NewClient(remote, terminal)
|
||||
}
|
||||
ClientMessage::Bye(reason) => Event::ClientBye(remote, reason),
|
||||
ClientMessage::Input(data)
|
||||
if let Some(fd) = self.addr_to_session.get(&remote) =>
|
||||
{
|
||||
Event::SessionInput(*fd, remote, data)
|
||||
}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
Ok(Some(event))
|
||||
}
|
||||
fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)),
|
||||
fd => {
|
||||
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
||||
let event = match session.pty_master.read(&mut buffer[..pty_max]) {
|
||||
Ok(0) => PtyEvent::Closed,
|
||||
Ok(len) => PtyEvent::Data(&buffer[..len]),
|
||||
Err(e) => PtyEvent::Err(e.into()),
|
||||
};
|
||||
Ok(Some(Event::Pty(fd, session.remote, event)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(mut self) -> Result<(), Error> {
|
||||
self.timer.start(PING_INTERVAL)?;
|
||||
let mut recv_buf = [0; 256];
|
||||
let mut send_buf = [0; 256];
|
||||
|
||||
loop {
|
||||
let Some(event) = self.poll(&mut recv_buf, 128)? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
match event {
|
||||
Event::SessionInput(fd, remote, data) => {
|
||||
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
||||
if let Err(error) = session.pty_master.write(&data) {
|
||||
eprintln!("PTY write error: {error}");
|
||||
self.socket
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("PTY error"),
|
||||
)
|
||||
.ok();
|
||||
self.remove_session_by_fd(fd)?;
|
||||
}
|
||||
}
|
||||
Event::ClientBye(remote, reason) => {
|
||||
println!("Client {remote} disconnected: {reason}");
|
||||
self.remove_session_by_remote(remote)?;
|
||||
}
|
||||
Event::NewClient(remote, terminal) => {
|
||||
println!("New client: {remote}");
|
||||
match Session::open(&terminal, remote) {
|
||||
Ok(session) => {
|
||||
self.register_session(remote, session)?;
|
||||
self.socket
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
|
||||
.ok();
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("PTY open error: {err}");
|
||||
self.socket
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("PTY open error"),
|
||||
)
|
||||
.ok();
|
||||
self.socket.remove_client(&remote);
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::Pty(fd, remote, event) => match event {
|
||||
PtyEvent::Data(data) => {
|
||||
self.socket
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data))
|
||||
.ok();
|
||||
}
|
||||
PtyEvent::Err(error) => {
|
||||
eprintln!("PTY read error: {error}");
|
||||
self.socket
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("PTY error"),
|
||||
)
|
||||
.ok();
|
||||
self.remove_session_by_fd(fd)?;
|
||||
}
|
||||
PtyEvent::Closed => {
|
||||
println!("End of PTY for {remote}");
|
||||
self.socket
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
|
||||
.ok();
|
||||
self.remove_session_by_fd(fd)?;
|
||||
}
|
||||
},
|
||||
Event::Tick => {
|
||||
// Restart the timer
|
||||
self.update_client_timeouts()?;
|
||||
self.timer.start(PING_INTERVAL)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_client_timeouts(&mut self) -> Result<(), Error> {
|
||||
let removed = self.socket.ping_clients(8);
|
||||
for entry in removed {
|
||||
log::debug!("Client timed out: {entry}");
|
||||
self.remove_session_by_remote(entry).ok();
|
||||
}
|
||||
fn close(mut self) -> Result<(), Self::Error> {
|
||||
self.shell.wait()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_session(&mut self, remote: SocketAddr, session: Session) -> Result<(), Error> {
|
||||
let fd = session.pty_master.as_raw_fd();
|
||||
self.addr_to_session.insert(remote, fd);
|
||||
self.pty_to_session.insert(fd, session);
|
||||
self.poll.add(&fd).map_err(Error::from)
|
||||
fn peer(&self) -> SocketAddr {
|
||||
self.remote
|
||||
}
|
||||
|
||||
fn remove_session_by_fd(&mut self, fd: RawFd) -> Result<Option<Session>, Error> {
|
||||
if let Some(mut session) = self.pty_to_session.remove(&fd) {
|
||||
// TODO: implement kernel support for pidfd or something, to poll the exit status of
|
||||
// the task instead of doing it here.
|
||||
// NOTE: this will block the whole server while the process finishes.
|
||||
session.shell.wait().ok();
|
||||
self.addr_to_session.remove(&session.remote).unwrap();
|
||||
self.socket.remove_client(&session.remote);
|
||||
self.poll.remove(&fd)?;
|
||||
Ok(Some(session))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
fn handle_input<'s, S: rsh::socket::PacketSocket>(
|
||||
&mut self,
|
||||
input: &[u8],
|
||||
_client: rsh::server::SessionClient<'s, S>,
|
||||
) -> Result<bool, Self::Error> {
|
||||
use std::io::Write;
|
||||
self.pty_master.write_all(input)?;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn remove_session_by_remote(&mut self, remote: SocketAddr) -> Result<Option<Session>, Error> {
|
||||
let Some(fd) = self.addr_to_session.get(&remote).copied() else {
|
||||
return Ok(None);
|
||||
};
|
||||
self.remove_session_by_fd(fd)
|
||||
fn read_output(
|
||||
&mut self,
|
||||
fd: std::os::fd::RawFd,
|
||||
buffer: &mut [u8],
|
||||
) -> Result<usize, Self::Error> {
|
||||
use std::io::Read;
|
||||
assert_eq!(fd, self.fds[0]);
|
||||
self.pty_master.read(buffer)
|
||||
}
|
||||
|
||||
fn event_fds(&self) -> &[std::os::fd::RawFd] {
|
||||
&self.fds
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub type SessionImpl = rsh::server::EchoSession;
|
||||
#[cfg(target_os = "yggdrasil")]
|
||||
pub type SessionImpl = YggdrasilSession;
|
||||
|
||||
fn run(args: Args) -> Result<(), Error> {
|
||||
let keystore = Box::new(SimpleServerKeyStore {
|
||||
path: args.keystore,
|
||||
@ -312,7 +140,7 @@ fn run(args: Args) -> Result<(), Error> {
|
||||
..Default::default()
|
||||
};
|
||||
let listen_addr = SocketAddr::from_str(&format!("0.0.0.0:{}", args.port)).unwrap();
|
||||
let server = Server::new(listen_addr, server_config)?;
|
||||
let server = Server::<_, SessionImpl>::listen_udp(listen_addr, server_config)?;
|
||||
server.run()
|
||||
}
|
||||
|
||||
|
367
userspace/rsh/src/server.rs
Normal file
367
userspace/rsh/src/server.rs
Normal file
@ -0,0 +1,367 @@
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
fmt,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use cross::io::{Poll, TimerFd};
|
||||
|
||||
use crate::{
|
||||
crypt::{server::ServerConfig, ServerEncryptedSocket},
|
||||
proto::{ClientMessage, Decode, Decoder, ServerMessage, TerminalInfo},
|
||||
socket::{MultiplexedSocket, MultiplexedSocketEvent, PacketSocket},
|
||||
Error,
|
||||
};
|
||||
|
||||
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
|
||||
|
||||
pub trait Session: Sized {
|
||||
type Error: fmt::Display;
|
||||
|
||||
fn open(peer: &SocketAddr, terminal: &TerminalInfo) -> Result<Self, Self::Error>;
|
||||
fn peer(&self) -> SocketAddr;
|
||||
fn handle_input<'s, S: PacketSocket>(
|
||||
&mut self,
|
||||
input: &[u8],
|
||||
client: SessionClient<'s, S>,
|
||||
) -> Result<bool, Self::Error>;
|
||||
fn read_output(&mut self, fd: RawFd, buffer: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
fn event_fds(&self) -> &[RawFd];
|
||||
fn close(self) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub struct SessionClient<'s, S: PacketSocket> {
|
||||
address: SocketAddr,
|
||||
transport: &'s mut ServerEncryptedSocket<S>,
|
||||
send_buf: &'s mut [u8],
|
||||
}
|
||||
|
||||
pub struct EchoSession {
|
||||
peer: SocketAddr,
|
||||
}
|
||||
|
||||
enum SessionEvent<'b, T: Session> {
|
||||
Data(&'b [u8]),
|
||||
Err(T::Error),
|
||||
Closed,
|
||||
}
|
||||
|
||||
enum Event<'b, T: Session> {
|
||||
NewClient(SocketAddr, TerminalInfo),
|
||||
SessionInput(u64, SocketAddr, &'b [u8]),
|
||||
ClientBye(SocketAddr, &'b str),
|
||||
SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>),
|
||||
Tick,
|
||||
}
|
||||
|
||||
pub struct Server<S: PacketSocket, T: Session> {
|
||||
poll: Poll,
|
||||
timer: TimerFd,
|
||||
socket: ServerEncryptedSocket<S>,
|
||||
|
||||
last_session_key: u64,
|
||||
sessions: HashMap<u64, T>,
|
||||
peer_to_session: HashMap<SocketAddr, u64>,
|
||||
session_event_map: HashMap<RawFd, u64>,
|
||||
}
|
||||
|
||||
impl<T: Session> Server<UdpSocket, T> {
|
||||
pub fn listen_udp(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> {
|
||||
let mut poll = Poll::new()?;
|
||||
let timer = TimerFd::new()?;
|
||||
let socket = UdpSocket::bind(listen_addr)?;
|
||||
let socket = ServerEncryptedSocket::new_with_config(socket, crypto_config);
|
||||
poll.add(&socket)?;
|
||||
poll.add(&timer)?;
|
||||
Ok(Self {
|
||||
poll,
|
||||
socket,
|
||||
timer,
|
||||
last_session_key: 1,
|
||||
sessions: HashMap::new(),
|
||||
peer_to_session: HashMap::new(),
|
||||
session_event_map: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll<'b>(
|
||||
&mut self,
|
||||
buffer: &'b mut [u8],
|
||||
pty_max: usize,
|
||||
) -> Result<Option<Event<'b, T>>, Error> {
|
||||
let fd = self.poll.wait(None)?.unwrap();
|
||||
|
||||
match fd {
|
||||
fd if fd == self.socket.as_raw_fd() => {
|
||||
let event = self.socket.recv_from(buffer)?;
|
||||
|
||||
let (message, remote) = match event {
|
||||
MultiplexedSocketEvent::ClientDisconnected(remote) => {
|
||||
self.remove_session_by_remote(remote).ok();
|
||||
return Ok(None);
|
||||
}
|
||||
MultiplexedSocketEvent::None(_) => return Ok(None),
|
||||
MultiplexedSocketEvent::ClientData(peer, data) => {
|
||||
let mut decoder = Decoder::new(data);
|
||||
let message = ClientMessage::decode(&mut decoder);
|
||||
(message, peer)
|
||||
}
|
||||
MultiplexedSocketEvent::Error(_) => return Ok(None),
|
||||
};
|
||||
|
||||
let message = match message {
|
||||
Ok(message) => message,
|
||||
Err(error) => {
|
||||
log::warn!("Decode error: {error}");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let event = match message {
|
||||
ClientMessage::Hello(terminal)
|
||||
if self.peer_to_session.get(&remote).is_none() =>
|
||||
{
|
||||
Event::NewClient(remote, terminal)
|
||||
}
|
||||
ClientMessage::Bye(reason) => Event::ClientBye(remote, reason),
|
||||
ClientMessage::Input(data)
|
||||
if let Some(fd) = self.peer_to_session.get(&remote) =>
|
||||
{
|
||||
Event::SessionInput(*fd, remote, data)
|
||||
}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
Ok(Some(event))
|
||||
}
|
||||
fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)),
|
||||
fd => {
|
||||
// Otherwise the event comes from a session
|
||||
let key = *self.session_event_map.get(&fd).unwrap();
|
||||
let session = self.sessions.get_mut(&key).unwrap();
|
||||
let event = match session.read_output(fd, &mut buffer[..pty_max]) {
|
||||
Ok(0) => SessionEvent::Closed,
|
||||
Ok(len) => SessionEvent::Data(&buffer[..len]),
|
||||
Err(e) => SessionEvent::Err(e),
|
||||
};
|
||||
Ok(Some(Event::SessionEvent(fd, session.peer(), event)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(mut self) -> Result<(), Error> {
|
||||
self.timer.start(PING_INTERVAL)?;
|
||||
let mut recv_buf = [0; 256];
|
||||
let mut send_buf = [0; 256];
|
||||
|
||||
loop {
|
||||
let Some(event) = self.poll(&mut recv_buf, 128)? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
match event {
|
||||
Event::SessionInput(key, remote, data) => {
|
||||
let session = self.sessions.get_mut(&key).unwrap();
|
||||
let peer = SessionClient {
|
||||
address: remote,
|
||||
send_buf: &mut send_buf,
|
||||
transport: &mut self.socket,
|
||||
};
|
||||
match session.handle_input(data, peer) {
|
||||
Ok(false) => (),
|
||||
Ok(true) => {
|
||||
log::debug!("{remote}: session closed");
|
||||
self.socket
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("Session closed"),
|
||||
)
|
||||
.ok();
|
||||
self.remove_session_by_key(key)?;
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("{remote}: session input error: {error}");
|
||||
self.socket
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("Session error"),
|
||||
)
|
||||
.ok();
|
||||
self.remove_session_by_key(key)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::ClientBye(remote, reason) => {
|
||||
log::debug!("Client {remote} disconnected: {reason}");
|
||||
self.remove_session_by_remote(remote)?;
|
||||
}
|
||||
Event::NewClient(remote, terminal) => {
|
||||
log::debug!("New client: {remote}");
|
||||
match T::open(&remote, &terminal) {
|
||||
Ok(session) => {
|
||||
self.register_session(remote, session)?;
|
||||
self.socket
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
|
||||
.ok();
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("PTY open error: {err}");
|
||||
self.socket
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("PTY open error"),
|
||||
)
|
||||
.ok();
|
||||
self.socket.remove_client(&remote);
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::SessionEvent(fd, remote, event) => match event {
|
||||
SessionEvent::Data(data) => {
|
||||
self.socket
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Output(data))
|
||||
.ok();
|
||||
}
|
||||
SessionEvent::Err(error) => {
|
||||
log::error!("Session output read error: {error}");
|
||||
self.socket
|
||||
.send_message_to(
|
||||
&remote,
|
||||
&mut send_buf,
|
||||
&ServerMessage::Bye("Session error"),
|
||||
)
|
||||
.ok();
|
||||
self.remove_session_by_fd(fd)?;
|
||||
}
|
||||
SessionEvent::Closed => {
|
||||
log::debug!("Session closed for {remote}");
|
||||
self.socket
|
||||
.send_message_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
|
||||
.ok();
|
||||
self.remove_session_by_fd(fd)?;
|
||||
}
|
||||
},
|
||||
Event::Tick => {
|
||||
// Restart the timer
|
||||
self.update_client_timeouts()?;
|
||||
self.timer.start(PING_INTERVAL)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_client_timeouts(&mut self) -> Result<(), Error> {
|
||||
let removed = self.socket.ping_clients(8);
|
||||
for entry in removed {
|
||||
log::debug!("Client timed out: {entry}");
|
||||
self.remove_session_by_remote(entry).ok();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_session(&mut self, remote: SocketAddr, session: T) -> Result<(), Error> {
|
||||
let (key, session) = loop {
|
||||
let key = self.last_session_key;
|
||||
self.last_session_key += 1;
|
||||
match self.sessions.entry(key) {
|
||||
Entry::Occupied(_) => continue,
|
||||
Entry::Vacant(entry) => {
|
||||
let session = entry.insert(session);
|
||||
break (key, session);
|
||||
}
|
||||
}
|
||||
};
|
||||
for fd in session.event_fds() {
|
||||
self.poll.add(fd)?;
|
||||
self.session_event_map.insert(*fd, key);
|
||||
}
|
||||
self.peer_to_session.insert(remote, key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_session_by_key(&mut self, key: u64) -> Result<(), Error> {
|
||||
let Some(session) = self.sessions.remove(&key) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for fd in session.event_fds() {
|
||||
self.poll.remove(fd)?;
|
||||
self.session_event_map.remove(fd);
|
||||
}
|
||||
self.peer_to_session.remove(&session.peer()).unwrap();
|
||||
self.socket.remove_client(&session.peer());
|
||||
|
||||
if let Err(error) = session.close() {
|
||||
log::warn!("Session close error: {error}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_session_by_fd(&mut self, fd: RawFd) -> Result<(), Error> {
|
||||
let Some(key) = self.session_event_map.get(&fd).copied() else {
|
||||
return Ok(());
|
||||
};
|
||||
self.remove_session_by_key(key)
|
||||
}
|
||||
|
||||
fn remove_session_by_remote(&mut self, remote: SocketAddr) -> Result<(), Error> {
|
||||
let Some(key) = self.peer_to_session.get(&remote).copied() else {
|
||||
return Ok(());
|
||||
};
|
||||
self.remove_session_by_key(key)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'s, S: PacketSocket> SessionClient<'s, S> {
|
||||
pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
|
||||
self.send_message(&ServerMessage::Output(data))
|
||||
}
|
||||
|
||||
pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> {
|
||||
self.transport
|
||||
.send_message_to(&self.address, self.send_buf, message)
|
||||
}
|
||||
}
|
||||
|
||||
impl Session for EchoSession {
|
||||
type Error = Error;
|
||||
|
||||
fn open(peer: &SocketAddr, _terminal: &TerminalInfo) -> Result<Self, Self::Error> {
|
||||
Ok(Self { peer: *peer })
|
||||
}
|
||||
|
||||
fn close(self) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn event_fds(&self) -> &[RawFd] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn read_output(&mut self, _fd: RawFd, _buffer: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn handle_input<'s, S: PacketSocket>(
|
||||
&mut self,
|
||||
input: &[u8],
|
||||
mut client: SessionClient<'s, S>,
|
||||
) -> Result<bool, Self::Error> {
|
||||
if input.contains(&b'\x04') {
|
||||
return Ok(true);
|
||||
}
|
||||
log::debug!("{:02x?}", input);
|
||||
client.send_data(input)?;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn peer(&self) -> SocketAddr {
|
||||
self.peer
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user