#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))] #![feature(let_chains)] use std::{ io::{stderr, stdout, IsTerminal, Read, Stderr, Write}, ops::{Deref, DerefMut}, os::fd::AsRawFd, }; use clap::Parser; use cross::io::Poll; use rsh::{ crypt::{ client::{self, ClientSocket, Message}, config::{ClientConfig, SimpleClientKeyStore}, signature::{SignEd25519, SignatureMethod}, }, proto::{ServerMessage, StreamIndex}, }; use std::{ io::{self, stdin, Stdin, Stdout}, net::{IpAddr, SocketAddr}, path::PathBuf, process::ExitCode, time::Duration, }; use libterm::{RawMode, RawTerminal}; use rsh::proto::{ClientMessage, TerminalInfo}; pub const PING_TIMEOUT: Duration = Duration::from_secs(3); #[derive(Debug, thiserror::Error)] pub enum Error { #[error("I/O error: {0}")] Io(#[from] io::Error), #[error("Terminal error: {0}")] Terminal(#[from] libterm::Error), #[error("Disconnected by the server: {0}")] Disconnected(String), #[error("Timed out")] Timeout, #[error("Socket error: {0}")] Socket(#[from] client::Error), #[error("Aborted by user")] Abort, } #[derive(Debug, Parser)] struct Args { #[clap(short, long)] key: PathBuf, #[clap(short = 'P', long, default_value_t = 77)] port: u16, remote: IpAddr, command: Vec, } struct RawStdin { stdin: Stdin, raw: Option, } pub struct Client { poll: Poll, socket: ClientSocket, stdin: RawStdin, stdout: Stdout, stderr: Stderr, last0: u8, last1: u8, } pub enum Event<'a, 'b> { Data(&'a [u8]), Stdin(&'b [u8]), Disconnected(&'b str), } impl RawStdin { pub fn open() -> Result { let stdin = stdin(); let raw = if stdin.is_terminal() { Some(unsafe { RawMode::enter(&stdin) }?) } else { None }; Ok(Self { stdin, raw }) } } impl Deref for RawStdin { type Target = Stdin; fn deref(&self) -> &Self::Target { &self.stdin } } impl DerefMut for RawStdin { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.stdin } } impl AsRawFd for RawStdin { fn as_raw_fd(&self) -> std::os::fd::RawFd { self.stdin.as_raw_fd() } } impl Drop for RawStdin { fn drop(&mut self) { if let Some(raw) = &mut self.raw { unsafe { raw.leave(&self.stdin) }; } } } impl Client { pub fn connect(remote: SocketAddr, crypto_config: ClientConfig) -> Result { let mut poll = Poll::new()?; let mut socket = ClientSocket::connect(remote, crypto_config)?; let stdin = RawStdin::open()?; let stdout = stdout(); let stderr = stderr(); poll.add(&stdin)?; poll.add(&socket)?; let info = terminal_info(&stdout)?; Self::handshake(&mut socket, info)?; Ok(Self { stdin, stdout, stderr, socket, poll, last0: 0, last1: 0, }) } fn update_last(&mut self, buffer: &[u8]) -> Result<(), Error> { if buffer.len() >= 2 { self.last0 = buffer[buffer.len() - 2]; self.last1 = buffer[buffer.len() - 1]; } else if buffer.len() >= 1 { self.last0 = self.last1; self.last1 = buffer[buffer.len() - 1]; } else { self.last0 = self.last1; self.last1 = 0; } if self.last0 == b'\x1B' && self.last1 == b'~' { Err(Error::Abort) } else { Ok(()) } } pub fn run(mut self) -> Result { let mut recv_buf = [0; 512]; loop { if let Some(message) = self.socket.read(&mut recv_buf)? { match message { ServerMessage::Bye(reason) => return Ok(reason.into()), ServerMessage::Output(StreamIndex::Stdout, data) => { self.stdout.write_all(data).ok(); self.stdout.flush().ok(); continue; } ServerMessage::Output(StreamIndex::Stderr, data) => { self.stderr.write_all(data).ok(); self.stderr.flush().ok(); continue; } _ => continue, } } let fd = self.poll.wait(None)?.unwrap(); if fd == self.socket.as_raw_fd() { if self.socket.poll()? == 0 { return Ok("".into()); } } else if self.stdin.as_raw_fd() == fd { let len = self.stdin.read(&mut recv_buf)?; self.update_last(&recv_buf[..len])?; self.socket .write_all(&ClientMessage::Input(&recv_buf[..len]))?; } else { unreachable!() } } } fn handshake(socket: &mut ClientSocket, info: TerminalInfo) -> Result<(), Error> { let mut recv_buf = [0; 256]; socket.write_all(&ClientMessage::OpenSession(info))?; loop { if let Some(message) = socket.read(&mut recv_buf)? { match message { ServerMessage::SessionOpen => return Ok(()), _ => return Err(Error::Disconnected("Unexpected server message".into())), } } if socket.poll()? == 0 { todo!() } } } } fn terminal_info(stdout: &Stdout) -> Result { let (columns, rows) = stdout.raw_size()?; Ok(TerminalInfo { columns: columns as _, rows: rows as _, }) } fn run_terminal(remote: SocketAddr, config: ClientConfig) -> Result<(), Error> { let reason = Client::connect(remote, config)?.run()?; if !reason.is_empty() { eprintln!("\nDisconnected: {reason}"); } Ok(()) } fn run_command( remote: SocketAddr, config: ClientConfig, command: Vec, ) -> Result { let mut poll = Poll::new()?; let mut buffer = [0; 512]; let mut command_string = String::new(); for (i, word) in command.iter().enumerate() { if i != 0 { command_string.push(' '); } command_string.push_str(word); } let mut stdin = stdin(); let mut stdout = stdout(); let mut stderr = stderr(); let mut socket = ClientSocket::connect(remote, config)?; poll.add(&socket)?; poll.add(&stdin)?; socket.write_all(&ClientMessage::RunCommand(command_string.as_str()))?; loop { let fd = poll.wait(None)?.unwrap(); match fd { _ if fd == socket.as_raw_fd() => { let message = match socket.poll_read(&mut buffer)? { Message::Data(data) => data, Message::Incomplete => continue, Message::Closed => break, }; match message { ServerMessage::Output(StreamIndex::Stdout, output) => { stdout.write_all(output).ok(); stdout.flush().ok(); } ServerMessage::Output(StreamIndex::Stderr, output) => { stderr.write_all(output).ok(); stderr.flush().ok(); } _ => todo!(), } } _ if fd == stdin.as_raw_fd() => { let len = stdin.read(&mut buffer)?; if len == 0 { poll.remove(&stdin)?; socket.write_all(&ClientMessage::CloseStdin)?; } else { socket.write_all(&ClientMessage::Input(&buffer[..len]))?; } } _ => unreachable!(), } } Ok(ExitCode::SUCCESS) } fn run(args: Args) -> Result { let remote = SocketAddr::new(args.remote, args.port); let ed25519 = SignEd25519::load_signing_key(args.key).unwrap(); let key = SignatureMethod::Ed25519(ed25519); let config = ClientConfig::with_default_algorithms(SimpleClientKeyStore::new(key)); if args.command.is_empty() { run_terminal(remote, config).map(|_| ExitCode::SUCCESS) } else { run_command(remote, config, args.command) } } fn main() -> ExitCode { env_logger::init(); let args = Args::parse(); match run(args) { Ok(status) => status, Err(error) => { eprintln!("Error: {error}"); ExitCode::FAILURE } } }