248 lines
7.0 KiB
Rust
248 lines
7.0 KiB
Rust
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
|
|
|
|
use std::{
|
|
io::{self, stdin, stdout, Read, Stdin, Stdout, Write as IoWrite},
|
|
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket},
|
|
os::fd::AsRawFd,
|
|
process::ExitCode,
|
|
time::{Duration, Instant},
|
|
};
|
|
|
|
use clap::Parser;
|
|
use cross::io::Poll;
|
|
use libterm::{RawMode, RawTerminal};
|
|
use rsh::{
|
|
crypt::ClientEncryptedSocket, proto::{ClientMessage, ServerMessage, TerminalInfo}, socket::{MessageSocket, PacketSocket}, ClientSocket
|
|
};
|
|
|
|
/// Longest time allowed since the socket's last recorded ping before the
/// session is abandoned with [`Error::Timeout`].
pub const PING_TIMEOUT: Duration = Duration::from_millis(3_000);
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum Error {
|
|
#[error("I/O error: {0}")]
|
|
Io(#[from] io::Error),
|
|
#[error("Protocol error: {0}")]
|
|
Protocol(#[from] rsh::Error),
|
|
#[error("Terminal error: {0}")]
|
|
Terminal(#[from] libterm::Error),
|
|
#[error("Disconnected by the server: {0}")]
|
|
Disconnected(String),
|
|
#[error("Timed out")]
|
|
Timeout,
|
|
}
|
|
|
|
#[derive(Debug, Parser)]
|
|
struct Args {
|
|
remote: IpAddr,
|
|
}
|
|
|
|
pub struct Client {
|
|
poll: Poll,
|
|
socket: ClientSocket<ClientEncryptedSocket<UdpSocket>>,
|
|
stdin: Stdin,
|
|
stdout: Stdout,
|
|
need_bye: bool,
|
|
_raw: RawMode,
|
|
}
|
|
|
|
/// One thing that happened while the session was being polled.
///
/// Payloads borrow from the caller-supplied buffer passed to `Client::poll`.
pub enum Event<'a> {
    /// Bytes just read from local stdin.
    Stdin(&'a [u8]),
    /// Output bytes received from the server.
    Data(&'a [u8]),
    /// The server ended the session; holds its reason text (may be empty).
    Disconnected(&'a str),
}
|
|
|
|
impl Client {
|
|
pub fn connect(remote: SocketAddr) -> Result<Self, Error> {
|
|
let mut poll = Poll::new()?;
|
|
let stdin = stdin();
|
|
let stdout = stdout();
|
|
let local: SocketAddr = match remote {
|
|
SocketAddr::V4(_) => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(),
|
|
SocketAddr::V6(_) => SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0).into(),
|
|
};
|
|
|
|
let info = terminal_info(&stdout)?;
|
|
|
|
let socket = UdpSocket::bind(local)?;
|
|
socket.connect(remote)?;
|
|
|
|
let mut socket = ClientEncryptedSocket::new(socket);
|
|
|
|
poll.add(&socket)?;
|
|
|
|
socket.try_connect_blocking(&mut poll, Duration::from_secs(1))?;
|
|
|
|
let mut socket = ClientSocket::new(socket);
|
|
|
|
// Do the handshake thing
|
|
Self::handshake(&mut poll, &mut socket, info, 5)?;
|
|
|
|
// TODO handle bye if something fails here
|
|
poll.add(&stdin)?;
|
|
|
|
let _raw = unsafe { RawMode::enter(&stdin)? };
|
|
|
|
Ok(Self {
|
|
poll,
|
|
socket,
|
|
stdout,
|
|
stdin,
|
|
need_bye: true,
|
|
_raw,
|
|
})
|
|
}
|
|
|
|
pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result<Option<Event<'b>>, Error> {
|
|
// let mut buf = [0; 16];
|
|
let event = loop {
|
|
let Some(event) = self.poll.wait(Some(Duration::from_millis(500)))? else {
|
|
self.check_timeout()?;
|
|
continue;
|
|
};
|
|
|
|
match event {
|
|
fd if fd == self.socket.as_raw_fd() => {
|
|
let message = match self.socket.recv_from(buffer) {
|
|
Ok((message, _)) => message,
|
|
Err(rsh::Error::Ping) => return Ok(None),
|
|
Err(error) => return Err(error.into())
|
|
};
|
|
match message {
|
|
ServerMessage::Bye(reason) => {
|
|
// No need for a bye
|
|
self.need_bye = false;
|
|
break Ok(Some(Event::Disconnected(reason)));
|
|
}
|
|
ServerMessage::Output(data) => {
|
|
break Ok(Some(Event::Data(data)));
|
|
}
|
|
// Ignore this one
|
|
ServerMessage::Hello => break Ok(None),
|
|
}
|
|
}
|
|
fd if fd == self.stdin.as_raw_fd() => {
|
|
let len = self.stdin.read(&mut buffer[..pty_max])?;
|
|
break Ok(Some(Event::Stdin(&buffer[..len])));
|
|
}
|
|
_ => unreachable!()
|
|
}
|
|
};
|
|
|
|
self.check_timeout()?;
|
|
|
|
event
|
|
}
|
|
|
|
pub fn run(mut self) -> Result<String, Error> {
|
|
let mut recv_buf = [0; 256];
|
|
let mut send_buf = [0; 256];
|
|
loop {
|
|
let Some(event) = self.poll(&mut recv_buf, 64)? else {
|
|
continue;
|
|
};
|
|
|
|
match event {
|
|
Event::Data(data) => {
|
|
self.stdout.write_all(data)?;
|
|
self.stdout.flush()?;
|
|
}
|
|
Event::Stdin(data) => {
|
|
self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
|
|
}
|
|
Event::Disconnected(reason) => {
|
|
break Ok(reason.into());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn check_timeout(&mut self) -> Result<(), Error> {
|
|
let now = Instant::now();
|
|
if now - self.socket.as_inner_mut().last_ping() >= PING_TIMEOUT {
|
|
Err(Error::Timeout)
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
fn handshake<S: PacketSocket>(
|
|
poll: &mut Poll,
|
|
socket: &mut ClientSocket<S>,
|
|
terminal: TerminalInfo,
|
|
attempts: usize,
|
|
) -> Result<(), Error> {
|
|
assert_ne!(attempts, 0);
|
|
let mut timeout = Duration::from_millis(500);
|
|
|
|
for _ in 0..attempts - 1 {
|
|
if Self::try_handshake(poll, socket, terminal, timeout).is_ok() {
|
|
return Ok(());
|
|
}
|
|
timeout *= 2;
|
|
}
|
|
|
|
Self::try_handshake(poll, socket, terminal, timeout)
|
|
}
|
|
|
|
fn try_handshake<S: PacketSocket>(
|
|
poll: &mut Poll,
|
|
socket: &mut ClientSocket<S>,
|
|
terminal: TerminalInfo,
|
|
timeout: Duration,
|
|
) -> Result<(), Error> {
|
|
let mut buffer = [0; 512];
|
|
socket.send(&mut buffer, &ClientMessage::Hello(terminal))?;
|
|
|
|
if poll.wait(Some(timeout))?.is_none() {
|
|
return Err(Error::Timeout);
|
|
};
|
|
let (message, _) = socket.recv_from(&mut buffer)?;
|
|
match message {
|
|
ServerMessage::Hello => Ok(()),
|
|
ServerMessage::Bye(reason) => Err(Error::Disconnected(reason.into())),
|
|
_ => Err(Error::Disconnected("Invalid message received".into())),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for Client {
|
|
fn drop(&mut self) {
|
|
if self.need_bye {
|
|
let mut buf = [0; 32];
|
|
self.socket
|
|
.send(&mut buf, &ClientMessage::Bye("".into()))
|
|
.ok();
|
|
}
|
|
}
|
|
}
|
|
|
|
fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
|
|
let (columns, rows) = stdout.raw_size()?;
|
|
Ok(TerminalInfo {
|
|
columns: columns as _,
|
|
rows: rows as _,
|
|
})
|
|
}
|
|
|
|
fn run(args: Args) -> Result<(), Error> {
|
|
let remote = SocketAddr::new(args.remote, 7777);
|
|
let reason = Client::connect(remote)?.run()?;
|
|
if !reason.is_empty() {
|
|
eprintln!("\nDisconnected: {reason}");
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn main() -> ExitCode {
|
|
env_logger::init();
|
|
let args = Args::parse();
|
|
|
|
if let Err(error) = run(args) {
|
|
eprintln!("Error: {error}");
|
|
ExitCode::FAILURE
|
|
} else {
|
|
ExitCode::SUCCESS
|
|
}
|
|
}
|