327 lines
8.6 KiB
Rust

#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
#![feature(let_chains)]
use std::{
io::{stderr, stdout, IsTerminal, Read, Stderr, Write},
ops::{Deref, DerefMut},
os::fd::AsRawFd,
};
use clap::Parser;
use cross::io::Poll;
use rsh::{
crypt::{
client::{self, ClientSocket, Message},
config::{ClientConfig, SimpleClientKeyStore},
signature::{SignEd25519, SignatureMethod},
},
proto::{ServerMessage, StreamIndex},
};
use std::{
io::{self, stdin, Stdin, Stdout},
net::{IpAddr, SocketAddr},
path::PathBuf,
process::ExitCode,
time::Duration,
};
use libterm::{RawMode, RawTerminal};
use rsh::proto::{ClientMessage, TerminalInfo};
/// Three-second timeout; going by the name, presumably the limit for a ping reply.
/// NOTE(review): not referenced by the code visible in this file.
pub const PING_TIMEOUT: Duration = Duration::from_millis(3_000);
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("I/O error: {0}")]
Io(#[from] io::Error),
#[error("Terminal error: {0}")]
Terminal(#[from] libterm::Error),
#[error("Disconnected by the server: {0}")]
Disconnected(String),
#[error("Timed out")]
Timeout,
#[error("Socket error: {0}")]
Socket(#[from] client::Error),
#[error("Aborted by user")]
Abort,
}
// Command-line arguments.
// Plain `//` comments are used deliberately here: clap's derive would turn `///`
// doc comments into `--help` text. Field order matters for the positional
// arguments (`remote`, then `command`).
#[derive(Debug, Parser)]
struct Args {
    // Path to the Ed25519 signing key file used to authenticate the client.
    #[clap(short, long)]
    key: PathBuf,
    // Remote port, 77 unless overridden with -P/--port.
    #[clap(short = 'P', long, default_value_t = 77)]
    port: u16,
    // IP address of the remote host.
    remote: IpAddr,
    // Command words to run remotely; when empty, an interactive terminal
    // session is opened instead (see `run`).
    command: Vec<String>,
}
/// Standard input, switched into raw mode while this value is alive when it
/// is a terminal. Raw mode is left again in `Drop`.
struct RawStdin {
    /// The process's standard-input handle.
    stdin: Stdin,
    /// Handle returned by `RawMode::enter`; `None` when stdin is not a
    /// terminal, in which case there is nothing to undo on drop.
    raw: Option<RawMode>,
}
/// An interactive remote-terminal session: the connection to the server plus
/// the local stdio streams it is bridged to.
///
/// Fields are dropped in declaration order, so the socket is dropped before
/// `stdin` leaves raw mode.
pub struct Client {
    /// Readiness notifier watching `socket` and `stdin`.
    poll: Poll,
    /// Connection to the remote server.
    socket: ClientSocket,
    /// Local input, in raw mode when it is a terminal.
    stdin: RawStdin,
    /// Destination for the remote's stdout stream.
    stdout: Stdout,
    /// Destination for the remote's stderr stream.
    stderr: Stderr,
    /// Second-to-last input byte seen; used to detect the `ESC ~` abort sequence.
    last0: u8,
    /// Last input byte seen; used to detect the `ESC ~` abort sequence.
    last1: u8,
}
/// Event type for a session loop.
/// NOTE(review): not used by the code visible in this file; the variant
/// descriptions below are inferred from their names — confirm with callers.
pub enum Event<'a, 'b> {
    /// Presumably bytes received from the remote side.
    Data(&'a [u8]),
    /// Presumably bytes read from local stdin.
    Stdin(&'b [u8]),
    /// Presumably a disconnect reason.
    Disconnected(&'b str),
}
impl RawStdin {
    /// Takes the process's standard input and, if it is a terminal, switches
    /// it into raw mode until the returned value is dropped.
    ///
    /// # Errors
    ///
    /// Propagates a failure to enter raw mode as [`Error`].
    pub fn open() -> Result<Self, Error> {
        let stdin = stdin();
        let mut raw = None;
        if stdin.is_terminal() {
            // SAFETY: NOTE(review): the exact contract of `RawMode::enter` is
            // defined by `libterm`; the mode is left again on this same handle
            // in `Drop for RawStdin`.
            raw = Some(unsafe { RawMode::enter(&stdin) }?);
        }
        Ok(Self { stdin, raw })
    }
}
impl Deref for RawStdin {
type Target = Stdin;
fn deref(&self) -> &Self::Target {
&self.stdin
}
}
impl DerefMut for RawStdin {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stdin
}
}
/// Delegates to the wrapped stdin handle so the wrapper can be registered
/// with `Poll` and compared against the descriptors it reports.
impl AsRawFd for RawStdin {
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        AsRawFd::as_raw_fd(&self.stdin)
    }
}
impl Drop for RawStdin {
    /// Leaves raw mode via `RawMode::leave` if `open` entered it; otherwise
    /// does nothing.
    fn drop(&mut self) {
        let Some(mode) = self.raw.as_mut() else {
            return;
        };
        // SAFETY: `mode` came from `RawMode::enter` on this same `stdin`
        // handle in `RawStdin::open`. NOTE(review): the full contract of
        // `leave` is defined by `libterm` — confirm.
        unsafe { mode.leave(&self.stdin) };
    }
}
impl Client {
pub fn connect(remote: SocketAddr, crypto_config: ClientConfig) -> Result<Self, Error> {
let mut poll = Poll::new()?;
let mut socket = ClientSocket::connect(remote, crypto_config)?;
let stdin = RawStdin::open()?;
let stdout = stdout();
let stderr = stderr();
poll.add(&stdin)?;
poll.add(&socket)?;
let info = terminal_info(&stdout)?;
Self::handshake(&mut socket, info)?;
Ok(Self {
stdin,
stdout,
stderr,
socket,
poll,
last0: 0,
last1: 0,
})
}
fn update_last(&mut self, buffer: &[u8]) -> Result<(), Error> {
if buffer.len() >= 2 {
self.last0 = buffer[buffer.len() - 2];
self.last1 = buffer[buffer.len() - 1];
} else if buffer.len() >= 1 {
self.last0 = self.last1;
self.last1 = buffer[buffer.len() - 1];
} else {
self.last0 = self.last1;
self.last1 = 0;
}
if self.last0 == b'\x1B' && self.last1 == b'~' {
Err(Error::Abort)
} else {
Ok(())
}
}
pub fn run(mut self) -> Result<String, Error> {
let mut recv_buf = [0; 512];
loop {
if let Some(message) = self.socket.read(&mut recv_buf)? {
match message {
ServerMessage::Bye(reason) => return Ok(reason.into()),
ServerMessage::Output(StreamIndex::Stdout, data) => {
self.stdout.write_all(data).ok();
self.stdout.flush().ok();
continue;
}
ServerMessage::Output(StreamIndex::Stderr, data) => {
self.stderr.write_all(data).ok();
self.stderr.flush().ok();
continue;
}
_ => continue,
}
}
let fd = self.poll.wait(None)?.unwrap();
if fd == self.socket.as_raw_fd() {
if self.socket.poll()? == 0 {
return Ok("".into());
}
} else if self.stdin.as_raw_fd() == fd {
let len = self.stdin.read(&mut recv_buf)?;
self.update_last(&recv_buf[..len])?;
self.socket
.write_all(&ClientMessage::Input(&recv_buf[..len]))?;
} else {
unreachable!()
}
}
}
fn handshake(socket: &mut ClientSocket, info: TerminalInfo) -> Result<(), Error> {
let mut recv_buf = [0; 256];
socket.write_all(&ClientMessage::OpenSession(info))?;
loop {
if let Some(message) = socket.read(&mut recv_buf)? {
match message {
ServerMessage::SessionOpen => return Ok(()),
_ => return Err(Error::Disconnected("Unexpected server message".into())),
}
}
if socket.poll()? == 0 {
todo!()
}
}
}
}
fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
let (columns, rows) = stdout.raw_size()?;
Ok(TerminalInfo {
columns: columns as _,
rows: rows as _,
})
}
/// Runs an interactive session against `remote`. When the session ends, the
/// server's disconnect reason, if any, is printed to stderr.
fn run_terminal(remote: SocketAddr, config: ClientConfig) -> Result<(), Error> {
    let client = Client::connect(remote, config)?;
    match client.run()? {
        reason if reason.is_empty() => {}
        reason => eprintln!("\nDisconnected: {reason}"),
    }
    Ok(())
}
/// Runs `command` on `remote` non-interactively.
///
/// Local stdin is forwarded to the command (followed by `CloseStdin` at EOF),
/// and the remote stdout/stderr are copied to ours. Runs until the connection
/// closes or the server sends `Bye`.
///
/// Returns [`ExitCode::SUCCESS`] on a clean shutdown; no remote exit status
/// is propagated here.
///
/// # Errors
///
/// Any connection, poll, socket or stdin error.
fn run_command(
    remote: SocketAddr,
    config: ClientConfig,
    command: Vec<String>,
) -> Result<ExitCode, Error> {
    let mut poll = Poll::new()?;
    let mut buffer = [0; 512];
    // Words are joined with single spaces, without any quoting.
    let command_string = command.join(" ");
    let mut stdin = stdin();
    let mut stdout = stdout();
    let mut stderr = stderr();
    let mut socket = ClientSocket::connect(remote, config)?;
    poll.add(&socket)?;
    poll.add(&stdin)?;
    socket.write_all(&ClientMessage::RunCommand(command_string.as_str()))?;
    loop {
        // No timeout is given; if `wait` still yields no descriptor, just
        // wait again instead of panicking.
        let Some(fd) = poll.wait(None)? else {
            continue;
        };
        if fd == socket.as_raw_fd() {
            let message = match socket.poll_read(&mut buffer)? {
                Message::Data(data) => data,
                Message::Incomplete => continue,
                Message::Closed => break,
            };
            match message {
                // Local output errors are deliberately ignored (best effort).
                ServerMessage::Output(StreamIndex::Stdout, output) => {
                    stdout.write_all(output).ok();
                    stdout.flush().ok();
                }
                ServerMessage::Output(StreamIndex::Stderr, output) => {
                    stderr.write_all(output).ok();
                    stderr.flush().ok();
                }
                // The server ended the session: report its reason (if any),
                // as `run_terminal` does, and stop.
                ServerMessage::Bye(reason) => {
                    let reason: String = reason.into();
                    if !reason.is_empty() {
                        eprintln!("Disconnected: {reason}");
                    }
                    break;
                }
                // Other messages are ignored here, matching `Client::run`.
                _ => (),
            }
        } else if fd == stdin.as_raw_fd() {
            let len = stdin.read(&mut buffer)?;
            if len == 0 {
                // EOF: stop watching stdin and tell the remote side.
                poll.remove(&stdin)?;
                socket.write_all(&ClientMessage::CloseStdin)?;
            } else {
                socket.write_all(&ClientMessage::Input(&buffer[..len]))?;
            }
        } else {
            unreachable!()
        }
    }
    Ok(ExitCode::SUCCESS)
}
fn run(args: Args) -> Result<ExitCode, Error> {
let remote = SocketAddr::new(args.remote, args.port);
let ed25519 = SignEd25519::load_signing_key(args.key).unwrap();
let key = SignatureMethod::Ed25519(ed25519);
let config = ClientConfig::with_default_algorithms(SimpleClientKeyStore::new(key));
if args.command.is_empty() {
run_terminal(remote, config).map(|_| ExitCode::SUCCESS)
} else {
run_command(remote, config, args.command)
}
}
fn main() -> ExitCode {
env_logger::init();
let args = Args::parse();
match run(args) {
Ok(status) => status,
Err(error) => {
eprintln!("Error: {error}");
ExitCode::FAILURE
}
}
}