248 lines
7.0 KiB
Rust

#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
use std::{
io::{self, stdin, stdout, Read, Stdin, Stdout, Write as IoWrite},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket},
os::fd::AsRawFd,
process::ExitCode,
time::{Duration, Instant},
};
use clap::Parser;
use cross::io::Poll;
use libterm::{RawMode, RawTerminal};
use rsh::{
crypt::ClientEncryptedSocket, proto::{ClientMessage, ServerMessage, TerminalInfo}, socket::{MessageSocket, PacketSocket}, ClientSocket
};
/// Maximum time allowed since the last ping seen by the encrypted socket
/// before the session is considered dead (checked by `Client::check_timeout`).
pub const PING_TIMEOUT: Duration = Duration::from_millis(3_000);
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("I/O error: {0}")]
Io(#[from] io::Error),
#[error("Protocol error: {0}")]
Protocol(#[from] rsh::Error),
#[error("Terminal error: {0}")]
Terminal(#[from] libterm::Error),
#[error("Disconnected by the server: {0}")]
Disconnected(String),
#[error("Timed out")]
Timeout,
}
// Command-line arguments. Plain `//` comments are used on purpose: `///` doc
// comments would be picked up by clap as `--help` text and change the output.
#[derive(Parser, Debug)]
struct Args {
    // Address of the rsh server (the port is fixed in `run`).
    remote: IpAddr,
}
/// A connected remote-shell session bound to the local terminal.
///
/// NOTE: fields are dropped in declaration order, so `_raw` (the raw-mode
/// guard) is dropped last. `Drop for Client` runs before any field is dropped.
pub struct Client {
/// Readiness poller; `connect` registers the socket first and stdin after the handshake.
poll: Poll,
/// Encrypted UDP session with the server.
socket: ClientSocket<ClientEncryptedSocket<UdpSocket>>,
/// Local input forwarded to the server.
stdin: Stdin,
/// Local output receiving the server's data.
stdout: Stdout,
/// Whether `Drop` must send a `Bye`; cleared when the server sends its own `Bye`.
need_bye: bool,
/// Raw-mode guard obtained from `RawMode::enter`; presumably restores the
/// terminal when dropped (NOTE(review): confirm in libterm).
_raw: RawMode,
}
/// One thing that happened while polling; payloads borrow the caller's buffer.
pub enum Event<'buf> {
    /// Bytes typed on the local terminal.
    Stdin(&'buf [u8]),
    /// Output received from the server.
    Data(&'buf [u8]),
    /// The session ended; carries the reason text (may be empty).
    Disconnected(&'buf str),
}
impl Client {
pub fn connect(remote: SocketAddr) -> Result<Self, Error> {
let mut poll = Poll::new()?;
let stdin = stdin();
let stdout = stdout();
let local: SocketAddr = match remote {
SocketAddr::V4(_) => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(),
SocketAddr::V6(_) => SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0).into(),
};
let info = terminal_info(&stdout)?;
let socket = UdpSocket::bind(local)?;
socket.connect(remote)?;
let mut socket = ClientEncryptedSocket::new(socket);
poll.add(&socket)?;
socket.try_connect_blocking(&mut poll, Duration::from_secs(1))?;
let mut socket = ClientSocket::new(socket);
// Do the handshake thing
Self::handshake(&mut poll, &mut socket, info, 5)?;
// TODO handle bye if something fails here
poll.add(&stdin)?;
let _raw = unsafe { RawMode::enter(&stdin)? };
Ok(Self {
poll,
socket,
stdout,
stdin,
need_bye: true,
_raw,
})
}
pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result<Option<Event<'b>>, Error> {
// let mut buf = [0; 16];
let event = loop {
let Some(event) = self.poll.wait(Some(Duration::from_millis(500)))? else {
self.check_timeout()?;
continue;
};
match event {
fd if fd == self.socket.as_raw_fd() => {
let message = match self.socket.recv_from(buffer) {
Ok((message, _)) => message,
Err(rsh::Error::Ping) => return Ok(None),
Err(error) => return Err(error.into())
};
match message {
ServerMessage::Bye(reason) => {
// No need for a bye
self.need_bye = false;
break Ok(Some(Event::Disconnected(reason)));
}
ServerMessage::Output(data) => {
break Ok(Some(Event::Data(data)));
}
// Ignore this one
ServerMessage::Hello => break Ok(None),
}
}
fd if fd == self.stdin.as_raw_fd() => {
let len = self.stdin.read(&mut buffer[..pty_max])?;
break Ok(Some(Event::Stdin(&buffer[..len])));
}
_ => unreachable!()
}
};
self.check_timeout()?;
event
}
pub fn run(mut self) -> Result<String, Error> {
let mut recv_buf = [0; 256];
let mut send_buf = [0; 256];
loop {
let Some(event) = self.poll(&mut recv_buf, 64)? else {
continue;
};
match event {
Event::Data(data) => {
self.stdout.write_all(data)?;
self.stdout.flush()?;
}
Event::Stdin(data) => {
self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
}
Event::Disconnected(reason) => {
break Ok(reason.into());
}
}
}
}
fn check_timeout(&mut self) -> Result<(), Error> {
let now = Instant::now();
if now - self.socket.as_inner_mut().last_ping() >= PING_TIMEOUT {
Err(Error::Timeout)
} else {
Ok(())
}
}
fn handshake<S: PacketSocket>(
poll: &mut Poll,
socket: &mut ClientSocket<S>,
terminal: TerminalInfo,
attempts: usize,
) -> Result<(), Error> {
assert_ne!(attempts, 0);
let mut timeout = Duration::from_millis(500);
for _ in 0..attempts - 1 {
if Self::try_handshake(poll, socket, terminal, timeout).is_ok() {
return Ok(());
}
timeout *= 2;
}
Self::try_handshake(poll, socket, terminal, timeout)
}
fn try_handshake<S: PacketSocket>(
poll: &mut Poll,
socket: &mut ClientSocket<S>,
terminal: TerminalInfo,
timeout: Duration,
) -> Result<(), Error> {
let mut buffer = [0; 512];
socket.send(&mut buffer, &ClientMessage::Hello(terminal))?;
if poll.wait(Some(timeout))?.is_none() {
return Err(Error::Timeout);
};
let (message, _) = socket.recv_from(&mut buffer)?;
match message {
ServerMessage::Hello => Ok(()),
ServerMessage::Bye(reason) => Err(Error::Disconnected(reason.into())),
_ => Err(Error::Disconnected("Invalid message received".into())),
}
}
}
impl Drop for Client {
    /// Sends a best-effort `Bye`, unless the server already ended the session
    /// with its own `Bye`.
    fn drop(&mut self) {
        if !self.need_bye {
            return;
        }
        let mut scratch = [0; 32];
        // Failures are deliberately ignored: nothing useful can be done in drop.
        let _ = self
            .socket
            .send(&mut scratch, &ClientMessage::Bye("".into()));
    }
}
/// Reads the current terminal dimensions from `stdout` (via `raw_size`,
/// which yields `(columns, rows)`) and packages them for the handshake.
fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
    let (cols, height) = stdout.raw_size()?;
    let info = TerminalInfo {
        columns: cols as _,
        rows: height as _,
    };
    Ok(info)
}
/// Connects to port 7777 on the requested host and runs the session.
///
/// A non-empty disconnect reason is printed to stderr once the client has
/// been dropped.
fn run(args: Args) -> Result<(), Error> {
    const SERVER_PORT: u16 = 7777;
    let client = Client::connect(SocketAddr::new(args.remote, SERVER_PORT))?;
    let reason = client.run()?;
    if reason.is_empty() {
        return Ok(());
    }
    eprintln!("\nDisconnected: {reason}");
    Ok(())
}
/// Entry point: sets up logging, parses arguments and maps the result of
/// `run` to the process exit code.
fn main() -> ExitCode {
    env_logger::init();
    let args = Args::parse();
    match run(args) {
        Ok(()) => ExitCode::SUCCESS,
        Err(error) => {
            eprintln!("Error: {error}");
            ExitCode::FAILURE
        }
    }
}