rsh: better protocol handling
This commit is contained in:
parent
2d9cc793e0
commit
3e605b3b11
6
userspace/Cargo.lock
generated
6
userspace/Cargo.lock
generated
@ -1014,10 +1014,7 @@ dependencies = [
|
|||||||
"bytemuck",
|
"bytemuck",
|
||||||
"clap",
|
"clap",
|
||||||
"cross",
|
"cross",
|
||||||
"flexbuffers",
|
|
||||||
"libterm",
|
"libterm",
|
||||||
"serde",
|
|
||||||
"smallvec",
|
|
||||||
"thiserror",
|
"thiserror",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1200,9 +1197,6 @@ name = "smallvec"
|
|||||||
version = "1.13.2"
|
version = "1.13.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
||||||
dependencies = [
|
|
||||||
"serde",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "spin"
|
name = "spin"
|
||||||
|
@ -9,11 +9,7 @@ path = "src/rshd/main.rs"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap.workspace = true
|
clap.workspace = true
|
||||||
flexbuffers.workspace = true
|
|
||||||
libterm.workspace = true
|
libterm.workspace = true
|
||||||
serde.workspace = true
|
|
||||||
thiserror.workspace = true
|
thiserror.workspace = true
|
||||||
cross.workspace = true
|
cross.workspace = true
|
||||||
bytemuck.workspace = true
|
bytemuck.workspace = true
|
||||||
|
|
||||||
smallvec = { version = "1.13.2", features = ["serde"] }
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
|
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
|
||||||
|
#![feature(generic_const_exprs)]
|
||||||
|
#![allow(incomplete_features)]
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
io,
|
io,
|
||||||
@ -8,44 +10,46 @@ use std::{
|
|||||||
os::fd::{AsRawFd, RawFd},
|
os::fd::{AsRawFd, RawFd},
|
||||||
};
|
};
|
||||||
|
|
||||||
use proto::{ClientMessage, ServerMessage};
|
use proto::{
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
ClientMessageProxy, Decode, DecodeError, Decoder, Encode, EncodeError, Encoder, MessageProxy,
|
||||||
|
ServerMessageProxy,
|
||||||
|
};
|
||||||
|
|
||||||
pub mod proto;
|
pub mod proto;
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
#[error("Deserialization error: {0}")]
|
|
||||||
Deserialization(#[from] flexbuffers::DeserializationError),
|
|
||||||
#[error("Serialization error: {0}")]
|
|
||||||
Serialization(#[from] flexbuffers::SerializationError),
|
|
||||||
#[error("I/O error")]
|
#[error("I/O error")]
|
||||||
Io(#[from] io::Error),
|
Io(#[from] io::Error),
|
||||||
#[error("Could not send a message fully")]
|
#[error("Could not send a message fully")]
|
||||||
Truncated,
|
Truncated,
|
||||||
|
#[error("Decode error: {0}")]
|
||||||
|
Decode(#[from] DecodeError),
|
||||||
|
#[error("Encode error: {0}")]
|
||||||
|
Encode(#[from] EncodeError),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SocketWrapper<Rx, Tx: Serialize> {
|
pub struct SocketWrapper<Rx: MessageProxy, Tx: MessageProxy> {
|
||||||
socket: UdpSocket,
|
socket: UdpSocket,
|
||||||
buffer: [u8; 256],
|
|
||||||
_pd: PhantomData<(Rx, Tx)>,
|
_pd: PhantomData<(Rx, Tx)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ClientSocket(SocketWrapper<ServerMessage, ClientMessage>);
|
pub struct ClientSocket(SocketWrapper<ServerMessageProxy, ClientMessageProxy>);
|
||||||
pub struct ServerSocket(SocketWrapper<ClientMessage, ServerMessage>);
|
pub struct ServerSocket(SocketWrapper<ClientMessageProxy, ServerMessageProxy>);
|
||||||
|
|
||||||
impl<Rx, Tx: Serialize> SocketWrapper<Rx, Tx> {
|
impl<Rx: MessageProxy, Tx: MessageProxy> SocketWrapper<Rx, Tx> {
|
||||||
pub fn new(socket: UdpSocket) -> Self {
|
pub fn new(socket: UdpSocket) -> Self {
|
||||||
Self {
|
Self {
|
||||||
socket,
|
socket,
|
||||||
buffer: [0; 256],
|
|
||||||
_pd: PhantomData,
|
_pd: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send(&mut self, message: &Tx) -> Result<(), Error> {
|
pub fn send(&self, buffer: &mut [u8], message: &Tx::Type<'_>) -> Result<(), Error> {
|
||||||
let message = flexbuffers::to_vec(message)?;
|
let mut enc = Encoder::new(buffer);
|
||||||
let amount = self.socket.send(&message)?;
|
message.encode(&mut enc)?;
|
||||||
|
let message = enc.get();
|
||||||
|
let amount = self.socket.send(message)?;
|
||||||
if amount == message.len() {
|
if amount == message.len() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
@ -53,9 +57,16 @@ impl<Rx, Tx: Serialize> SocketWrapper<Rx, Tx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_to(&mut self, remote: &SocketAddr, message: &Tx) -> Result<(), Error> {
|
pub fn send_to(
|
||||||
let message = flexbuffers::to_vec(message)?;
|
&mut self,
|
||||||
let amount = self.socket.send_to(&message, remote)?;
|
remote: &SocketAddr,
|
||||||
|
buffer: &mut [u8],
|
||||||
|
message: &Tx::Type<'_>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let mut enc = Encoder::new(buffer);
|
||||||
|
message.encode(&mut enc)?;
|
||||||
|
let message = enc.get();
|
||||||
|
let amount = self.socket.send_to(message, remote)?;
|
||||||
if amount == message.len() {
|
if amount == message.len() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
@ -63,26 +74,18 @@ impl<Rx, Tx: Serialize> SocketWrapper<Rx, Tx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn recv_from(&mut self) -> Result<(Rx, SocketAddr), Error>
|
pub fn recv_from<'de>(
|
||||||
where
|
&mut self,
|
||||||
Rx: DeserializeOwned,
|
buffer: &'de mut [u8],
|
||||||
{
|
) -> Result<(Rx::Type<'de>, SocketAddr), Error> {
|
||||||
let (len, remote) = self.socket.recv_from(&mut self.buffer)?;
|
|
||||||
let message = flexbuffers::from_slice(&self.buffer[..len])?;
|
|
||||||
Ok((message, remote))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn recv_from_with<'de>(&mut self, buffer: &'de mut [u8]) -> Result<(Rx, SocketAddr), Error>
|
|
||||||
where
|
|
||||||
Rx: Deserialize<'de> + 'de,
|
|
||||||
{
|
|
||||||
let (len, remote) = self.socket.recv_from(buffer)?;
|
let (len, remote) = self.socket.recv_from(buffer)?;
|
||||||
let message = flexbuffers::from_slice(&buffer[..len])?;
|
let mut dec = Decoder::new(&buffer[..len]);
|
||||||
|
let message = Rx::Type::<'de>::decode(&mut dec)?;
|
||||||
Ok((message, remote))
|
Ok((message, remote))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Rx, Tx: Serialize> AsRawFd for SocketWrapper<Rx, Tx> {
|
impl<Rx: MessageProxy, Tx: MessageProxy> AsRawFd for SocketWrapper<Rx, Tx> {
|
||||||
fn as_raw_fd(&self) -> RawFd {
|
fn as_raw_fd(&self) -> RawFd {
|
||||||
self.socket.as_raw_fd()
|
self.socket.as_raw_fd()
|
||||||
}
|
}
|
||||||
@ -95,7 +98,7 @@ impl ClientSocket {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for ClientSocket {
|
impl Deref for ClientSocket {
|
||||||
type Target = SocketWrapper<ServerMessage, ClientMessage>;
|
type Target = SocketWrapper<ServerMessageProxy, ClientMessageProxy>;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.0
|
&self.0
|
||||||
@ -115,7 +118,7 @@ impl ServerSocket {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for ServerSocket {
|
impl Deref for ServerSocket {
|
||||||
type Target = SocketWrapper<ClientMessage, ServerMessage>;
|
type Target = SocketWrapper<ClientMessageProxy, ServerMessageProxy>;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.0
|
&self.0
|
||||||
|
@ -12,7 +12,7 @@ use clap::Parser;
|
|||||||
use cross::io::Poll;
|
use cross::io::Poll;
|
||||||
use libterm::{RawMode, RawTerminal};
|
use libterm::{RawMode, RawTerminal};
|
||||||
use rsh::{
|
use rsh::{
|
||||||
proto::{ClientData, ClientMessage, ServerData, ServerMessage, TerminalInfo},
|
proto::{ClientMessage, ServerMessage, TerminalInfo},
|
||||||
ClientSocket,
|
ClientSocket,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -44,13 +44,14 @@ pub struct Client {
|
|||||||
stdout: Stdout,
|
stdout: Stdout,
|
||||||
need_bye: bool,
|
need_bye: bool,
|
||||||
last_ping: Instant,
|
last_ping: Instant,
|
||||||
_raw: RawMode
|
_raw: RawMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum Event {
|
pub enum Event<'b> {
|
||||||
Stdin(ClientData),
|
Stdin(&'b [u8]),
|
||||||
Data(ServerData),
|
Data(&'b [u8]),
|
||||||
Disconnected(String),
|
Disconnected(&'b str),
|
||||||
|
Ping,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
@ -91,8 +92,8 @@ impl Client {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn poll(&mut self) -> Result<Event, Error> {
|
pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result<Option<Event<'b>>, Error> {
|
||||||
let mut buf = [0; 16];
|
// let mut buf = [0; 16];
|
||||||
let event = loop {
|
let event = loop {
|
||||||
let Some(event) = self.poll.wait(Some(Duration::from_millis(500)))? else {
|
let Some(event) = self.poll.wait(Some(Duration::from_millis(500)))? else {
|
||||||
self.check_timeout()?;
|
self.check_timeout()?;
|
||||||
@ -101,29 +102,26 @@ impl Client {
|
|||||||
|
|
||||||
match event {
|
match event {
|
||||||
fd if fd == self.socket.as_raw_fd() => {
|
fd if fd == self.socket.as_raw_fd() => {
|
||||||
let (message, _) = self.socket.recv_from()?;
|
let (message, _) = self.socket.recv_from(buffer)?;
|
||||||
match message {
|
match message {
|
||||||
ServerMessage::Bye(reason) => {
|
ServerMessage::Bye(reason) => {
|
||||||
// No need for a bye
|
// No need for a bye
|
||||||
self.need_bye = false;
|
self.need_bye = false;
|
||||||
break Ok(Event::Disconnected(reason));
|
break Ok(Some(Event::Disconnected(reason)));
|
||||||
}
|
}
|
||||||
ServerMessage::Output(data) => {
|
ServerMessage::Output(data) => {
|
||||||
break Ok(Event::Data(data));
|
break Ok(Some(Event::Data(data)));
|
||||||
}
|
}
|
||||||
ServerMessage::Ping => {
|
ServerMessage::Ping => {
|
||||||
self.last_ping = Instant::now();
|
break Ok(Some(Event::Ping));
|
||||||
self.socket.send(&ClientMessage::Pong).ok();
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
// Ignore this one
|
// Ignore this one
|
||||||
ServerMessage::Hello => continue,
|
ServerMessage::Hello => break Ok(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fd if fd == self.stdin.as_raw_fd() => {
|
fd if fd == self.stdin.as_raw_fd() => {
|
||||||
let len = self.stdin.read(&mut buf)?;
|
let len = self.stdin.read(&mut buffer[..pty_max])?;
|
||||||
let data = ClientData::from_slice(&buf[..len]);
|
break Ok(Some(Event::Stdin(&buffer[..len])));
|
||||||
break Ok(Event::Stdin(data));
|
|
||||||
}
|
}
|
||||||
_ => unreachable!()
|
_ => unreachable!()
|
||||||
}
|
}
|
||||||
@ -135,19 +133,27 @@ impl Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn run(mut self) -> Result<String, Error> {
|
pub fn run(mut self) -> Result<String, Error> {
|
||||||
|
let mut recv_buf = [0; 256];
|
||||||
|
let mut send_buf = [0; 256];
|
||||||
loop {
|
loop {
|
||||||
let event = self.poll()?;
|
let Some(event) = self.poll(&mut recv_buf, 64)? else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
Event::Data(data) => {
|
Event::Data(data) => {
|
||||||
self.stdout.write_all(&data).ok();
|
self.stdout.write_all(data)?;
|
||||||
self.stdout.flush().ok();
|
self.stdout.flush()?;
|
||||||
}
|
}
|
||||||
Event::Stdin(data) => {
|
Event::Stdin(data) => {
|
||||||
self.socket.send(&ClientMessage::Input(data))?;
|
self.socket.send(&mut send_buf, &ClientMessage::Input(data))?;
|
||||||
|
}
|
||||||
|
Event::Ping => {
|
||||||
|
self.last_ping = Instant::now();
|
||||||
|
self.socket.send(&mut send_buf, &ClientMessage::Pong)?;
|
||||||
}
|
}
|
||||||
Event::Disconnected(reason) => {
|
Event::Disconnected(reason) => {
|
||||||
break Ok(reason);
|
break Ok(reason.into());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -187,15 +193,16 @@ impl Client {
|
|||||||
terminal: TerminalInfo,
|
terminal: TerminalInfo,
|
||||||
timeout: Duration,
|
timeout: Duration,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
socket.send(&ClientMessage::Hello(terminal))?;
|
let mut buffer = [0; 512];
|
||||||
|
socket.send(&mut buffer, &ClientMessage::Hello(terminal))?;
|
||||||
|
|
||||||
if poll.wait(Some(timeout))?.is_none() {
|
if poll.wait(Some(timeout))?.is_none() {
|
||||||
return Err(Error::Timeout);
|
return Err(Error::Timeout);
|
||||||
};
|
};
|
||||||
let (message, _) = socket.recv_from()?;
|
let (message, _) = socket.recv_from(&mut buffer)?;
|
||||||
match message {
|
match message {
|
||||||
ServerMessage::Hello => Ok(()),
|
ServerMessage::Hello => Ok(()),
|
||||||
ServerMessage::Bye(reason) => Err(Error::Disconnected(reason)),
|
ServerMessage::Bye(reason) => Err(Error::Disconnected(reason.into())),
|
||||||
_ => Err(Error::Disconnected("Invalid message received".into())),
|
_ => Err(Error::Disconnected("Invalid message received".into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -204,7 +211,10 @@ impl Client {
|
|||||||
impl Drop for Client {
|
impl Drop for Client {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if self.need_bye {
|
if self.need_bye {
|
||||||
self.socket.send(&ClientMessage::Bye("".into())).ok();
|
let mut buf = [0; 32];
|
||||||
|
self.socket
|
||||||
|
.send(&mut buf, &ClientMessage::Bye("".into()))
|
||||||
|
.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,27 +1,251 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
#[derive(Debug, Clone, Copy)]
|
||||||
use smallvec::SmallVec;
|
|
||||||
|
|
||||||
pub type ServerData = SmallVec<[u8; 64]>;
|
|
||||||
pub type ClientData = SmallVec<[u8; 16]>;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
|
||||||
pub struct TerminalInfo {
|
pub struct TerminalInfo {
|
||||||
pub columns: u32,
|
pub columns: u32,
|
||||||
pub rows: u32
|
pub rows: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
pub struct ClientMessageProxy;
|
||||||
pub enum ClientMessage {
|
pub struct ServerMessageProxy;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ClientMessage<'a> {
|
||||||
Hello(TerminalInfo),
|
Hello(TerminalInfo),
|
||||||
Bye(String),
|
Bye(&'a str),
|
||||||
Pong,
|
Pong,
|
||||||
Input(ClientData),
|
Input(&'a [u8]),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug)]
|
||||||
pub enum ServerMessage {
|
pub enum ServerMessage<'a> {
|
||||||
Hello,
|
Hello,
|
||||||
Bye(String),
|
Bye(&'a str),
|
||||||
Ping,
|
Ping,
|
||||||
Output(ServerData),
|
Output(&'a [u8]),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum EncodeError {
|
||||||
|
#[error("Message too large")]
|
||||||
|
MessageTooLarge,
|
||||||
|
#[error("Value too long")]
|
||||||
|
ValueTooLong,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum DecodeError {
|
||||||
|
#[error("Truncated message received")]
|
||||||
|
Truncated,
|
||||||
|
#[error("Malformed message received")]
|
||||||
|
InvalidMessage,
|
||||||
|
#[error("Malformed string in the message")]
|
||||||
|
InvalidString(core::str::Utf8Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Encoder<'a> {
|
||||||
|
buffer: &'a mut [u8],
|
||||||
|
pos: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Decoder<'a> {
|
||||||
|
buffer: &'a [u8],
|
||||||
|
pos: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Encode {
|
||||||
|
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Decode<'de>: Sized + 'de {
|
||||||
|
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait MessageProxy {
|
||||||
|
type Type<'de>: Encode + Decode<'de>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Encoder<'a> {
|
||||||
|
pub const fn new(buffer: &'a mut [u8]) -> Self {
|
||||||
|
Self { buffer, pos: 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
|
||||||
|
let len: u32 = bytes
|
||||||
|
.len()
|
||||||
|
.try_into()
|
||||||
|
.map_err(|_| EncodeError::ValueTooLong)?;
|
||||||
|
self.write(&len.to_le_bytes())?;
|
||||||
|
self.write(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_str(&mut self, s: &str) -> Result<(), EncodeError> {
|
||||||
|
self.write_variable_bytes(s.as_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
|
||||||
|
if self.pos + bytes.len() > self.buffer.len() {
|
||||||
|
return Err(EncodeError::ValueTooLong);
|
||||||
|
}
|
||||||
|
self.buffer[self.pos..self.pos + bytes.len()].copy_from_slice(bytes);
|
||||||
|
self.pos += bytes.len();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self) -> &[u8] {
|
||||||
|
&self.buffer[..self.pos]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Decoder<'a> {
|
||||||
|
pub fn new(buffer: &'a [u8]) -> Self {
|
||||||
|
Self { buffer, pos: 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
|
||||||
|
if self.pos + 1 > self.buffer.len() {
|
||||||
|
return Err(DecodeError::Truncated);
|
||||||
|
}
|
||||||
|
let byte = self.buffer[self.pos];
|
||||||
|
self.pos += 1;
|
||||||
|
Ok(byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_le_u32(&mut self) -> Result<u32, DecodeError> {
|
||||||
|
let bytes = self.read_bytes(size_of::<u32>())?;
|
||||||
|
Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_bytes(&mut self, len: usize) -> Result<&'a [u8], DecodeError> {
|
||||||
|
if self.pos + len > self.buffer.len() {
|
||||||
|
return Err(DecodeError::Truncated);
|
||||||
|
}
|
||||||
|
let slice = &self.buffer[self.pos..self.pos + len];
|
||||||
|
self.pos += len;
|
||||||
|
Ok(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_variable_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
|
||||||
|
let len = self.read_le_u32()?;
|
||||||
|
self.read_bytes(len as usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_str(&mut self) -> Result<&'a str, DecodeError> {
|
||||||
|
let slice = self.read_variable_bytes()?;
|
||||||
|
core::str::from_utf8(slice).map_err(DecodeError::InvalidString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MessageProxy for ClientMessageProxy {
|
||||||
|
type Type<'de> = ClientMessage<'de>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MessageProxy for ServerMessageProxy {
|
||||||
|
type Type<'de> = ServerMessage<'de>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientMessage<'_> {
|
||||||
|
const TAG_HELLO: u8 = 0x80;
|
||||||
|
const TAG_BYE: u8 = 0x81;
|
||||||
|
const TAG_PONG: u8 = 0x82;
|
||||||
|
const TAG_INPUT: u8 = 0x90;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encode for TerminalInfo {
|
||||||
|
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||||
|
buffer.write(&self.columns.to_le_bytes())?;
|
||||||
|
buffer.write(&self.rows.to_le_bytes())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Decode<'de> for TerminalInfo {
|
||||||
|
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||||
|
let columns = buffer.read_le_u32()?;
|
||||||
|
let rows = buffer.read_le_u32()?;
|
||||||
|
Ok(Self { columns, rows })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Encode for ClientMessage<'a> {
|
||||||
|
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||||
|
match self {
|
||||||
|
Self::Hello(info) => {
|
||||||
|
buffer.write(&[Self::TAG_HELLO])?;
|
||||||
|
info.encode(buffer)
|
||||||
|
}
|
||||||
|
Self::Bye(reason) => {
|
||||||
|
buffer.write(&[Self::TAG_BYE])?;
|
||||||
|
buffer.write_str(reason)
|
||||||
|
}
|
||||||
|
Self::Pong => {
|
||||||
|
buffer.write(&[Self::TAG_PONG])
|
||||||
|
}
|
||||||
|
Self::Input(data) => {
|
||||||
|
buffer.write(&[Self::TAG_INPUT])?;
|
||||||
|
buffer.write_variable_bytes(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Decode<'de> for ClientMessage<'de> {
|
||||||
|
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||||
|
let tag = buffer.read_u8()?;
|
||||||
|
match tag {
|
||||||
|
Self::TAG_HELLO => {
|
||||||
|
TerminalInfo::decode(buffer).map(Self::Hello)
|
||||||
|
}
|
||||||
|
Self::TAG_BYE => {
|
||||||
|
buffer.read_str().map(Self::Bye)
|
||||||
|
}
|
||||||
|
Self::TAG_PONG => {
|
||||||
|
Ok(Self::Pong)
|
||||||
|
}
|
||||||
|
Self::TAG_INPUT => {
|
||||||
|
buffer.read_variable_bytes().map(Self::Input)
|
||||||
|
}
|
||||||
|
_ => Err(DecodeError::InvalidMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerMessage<'_> {
|
||||||
|
const TAG_HELLO: u8 = 0x10;
|
||||||
|
const TAG_BYE: u8 = 0x11;
|
||||||
|
const TAG_PING: u8 = 0x12;
|
||||||
|
const TAG_OUTPUT: u8 = 0x20;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Encode for ServerMessage<'a> {
|
||||||
|
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||||
|
match self {
|
||||||
|
Self::Hello => buffer.write(&[Self::TAG_HELLO]),
|
||||||
|
Self::Bye(reason) => {
|
||||||
|
buffer.write(&[Self::TAG_BYE])?;
|
||||||
|
buffer.write_str(reason)
|
||||||
|
}
|
||||||
|
// TODO sequence number
|
||||||
|
Self::Ping => buffer.write(&[Self::TAG_PING]),
|
||||||
|
Self::Output(data) => {
|
||||||
|
buffer.write(&[Self::TAG_OUTPUT])?;
|
||||||
|
buffer.write_variable_bytes(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Decode<'de> for ServerMessage<'de> {
|
||||||
|
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||||
|
let tag = buffer.read_u8()?;
|
||||||
|
match tag {
|
||||||
|
Self::TAG_HELLO => Ok(Self::Hello),
|
||||||
|
Self::TAG_BYE => {
|
||||||
|
buffer.read_str().map(Self::Bye)
|
||||||
|
}
|
||||||
|
Self::TAG_PING => Ok(Self::Ping),
|
||||||
|
Self::TAG_OUTPUT => {
|
||||||
|
buffer.read_variable_bytes().map(Self::Output)
|
||||||
|
},
|
||||||
|
_ => Err(DecodeError::InvalidMessage),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ use std::{
|
|||||||
|
|
||||||
use cross::io::{Poll, TimerFd};
|
use cross::io::{Poll, TimerFd};
|
||||||
use rsh::{
|
use rsh::{
|
||||||
proto::{ClientData, ClientMessage, ServerData, ServerMessage, TerminalInfo},
|
proto::{ClientMessage, ServerMessage, TerminalInfo},
|
||||||
Error, ServerSocket,
|
Error, ServerSocket,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -35,17 +35,18 @@ pub struct Server {
|
|||||||
pty_to_session: HashMap<RawFd, Session>,
|
pty_to_session: HashMap<RawFd, Session>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum PtyEvent {
|
pub enum PtyEvent<'b> {
|
||||||
Data(ServerData),
|
Data(&'b [u8]),
|
||||||
Err(Error),
|
Err(Error),
|
||||||
Closed,
|
Closed,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum Event {
|
pub enum Event<'b> {
|
||||||
NewClient(SocketAddr, TerminalInfo),
|
NewClient(SocketAddr, TerminalInfo),
|
||||||
SessionInput(RawFd, SocketAddr, ClientData),
|
SessionInput(RawFd, SocketAddr, &'b [u8]),
|
||||||
ClientBye(SocketAddr, String),
|
ClientBye(SocketAddr, &'b str),
|
||||||
Pty(RawFd, SocketAddr, PtyEvent),
|
Pty(RawFd, SocketAddr, PtyEvent<'b>),
|
||||||
|
Tick,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@ -109,14 +110,19 @@ impl Server {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn poll(&mut self) -> Result<Event, Error> {
|
pub fn poll<'b>(&mut self, buffer: &'b mut [u8], pty_max: usize) -> Result<Option<Event<'b>>, Error> {
|
||||||
let mut buf = [0; 64];
|
|
||||||
loop {
|
|
||||||
let fd = self.poll.wait(None)?.unwrap();
|
let fd = self.poll.wait(None)?.unwrap();
|
||||||
|
|
||||||
match fd {
|
match fd {
|
||||||
fd if fd == self.socket.as_raw_fd() => {
|
fd if fd == self.socket.as_raw_fd() => {
|
||||||
let (message, remote) = self.socket.recv_from()?;
|
let (message, remote) = match self.socket.recv_from(buffer) {
|
||||||
|
Ok((message, remote)) => (message, remote),
|
||||||
|
Err(error @ (Error::Decode(_) | Error::Truncated)) => {
|
||||||
|
eprintln!("Receive error: {error}");
|
||||||
|
return Ok(None)
|
||||||
|
},
|
||||||
|
Err(error) => return Err(error),
|
||||||
|
};
|
||||||
|
|
||||||
let event = match message {
|
let event = match message {
|
||||||
ClientMessage::Hello(terminal)
|
ClientMessage::Hello(terminal)
|
||||||
@ -133,35 +139,35 @@ impl Server {
|
|||||||
ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => {
|
ClientMessage::Pong if let Some(fd) = self.addr_to_session.get(&remote) => {
|
||||||
let session = self.pty_to_session.get_mut(fd).unwrap();
|
let session = self.pty_to_session.get_mut(fd).unwrap();
|
||||||
session.timeouts = 0;
|
session.timeouts = 0;
|
||||||
continue
|
return Ok(None);
|
||||||
},
|
}
|
||||||
_ => continue,
|
_ => return Ok(None),
|
||||||
};
|
};
|
||||||
|
|
||||||
break Ok(event);
|
Ok(Some(event))
|
||||||
}
|
|
||||||
fd if fd == self.timer.as_raw_fd() => {
|
|
||||||
// Restart the timer
|
|
||||||
self.update_client_timeouts()?;
|
|
||||||
self.timer.start(PING_INTERVAL)?;
|
|
||||||
}
|
}
|
||||||
|
fd if fd == self.timer.as_raw_fd() => Ok(Some(Event::Tick)),
|
||||||
fd => {
|
fd => {
|
||||||
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
||||||
let event = match session.pty_master.read(&mut buf) {
|
let event = match session.pty_master.read(&mut buffer[..pty_max]) {
|
||||||
Ok(0) => PtyEvent::Closed,
|
Ok(0) => PtyEvent::Closed,
|
||||||
Ok(len) => PtyEvent::Data(buf[..len].into()),
|
Ok(len) => PtyEvent::Data(&buffer[..len]),
|
||||||
Err(e) => PtyEvent::Err(e.into()),
|
Err(e) => PtyEvent::Err(e.into()),
|
||||||
};
|
};
|
||||||
break Ok(Event::Pty(fd, session.remote, event));
|
Ok(Some(Event::Pty(fd, session.remote, event)))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run(mut self) -> Result<(), Error> {
|
pub fn run(mut self) -> Result<(), Error> {
|
||||||
self.timer.start(PING_INTERVAL)?;
|
self.timer.start(PING_INTERVAL)?;
|
||||||
|
let mut recv_buf = [0; 256];
|
||||||
|
let mut send_buf = [0; 256];
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let event = self.poll()?;
|
let Some(event) = self.poll(&mut recv_buf, 128)? else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
Event::SessionInput(fd, remote, data) => {
|
Event::SessionInput(fd, remote, data) => {
|
||||||
@ -170,7 +176,7 @@ impl Server {
|
|||||||
eprintln!("PTY write error: {error}");
|
eprintln!("PTY write error: {error}");
|
||||||
self.remove_session_by_fd(fd)?;
|
self.remove_session_by_fd(fd)?;
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &ServerMessage::Bye(format!("PTY error: {error}")))
|
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -183,46 +189,57 @@ impl Server {
|
|||||||
match Session::open(&terminal, remote) {
|
match Session::open(&terminal, remote) {
|
||||||
Ok(session) => {
|
Ok(session) => {
|
||||||
self.register_session(remote, session)?;
|
self.register_session(remote, session)?;
|
||||||
self.socket.send_to(&remote, &ServerMessage::Hello).ok();
|
self.socket
|
||||||
|
.send_to(&remote, &mut send_buf, &ServerMessage::Hello)
|
||||||
|
.ok();
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
eprintln!("PTY open error: {err}");
|
eprintln!("PTY open error: {err}");
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &ServerMessage::Bye("PTY open error".into()))
|
.send_to(
|
||||||
|
&remote,
|
||||||
|
&mut send_buf,
|
||||||
|
&ServerMessage::Bye("PTY open error"),
|
||||||
|
)
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Event::Pty(fd, remote, event) => match event {
|
Event::Pty(fd, remote, event) => match event {
|
||||||
|
PtyEvent::Data(data) => {
|
||||||
|
self.socket
|
||||||
|
.send_to(&remote, &mut send_buf, &ServerMessage::Output(data))
|
||||||
|
.ok();
|
||||||
|
},
|
||||||
PtyEvent::Err(error) => {
|
PtyEvent::Err(error) => {
|
||||||
eprintln!("PTY read error: {error}");
|
eprintln!("PTY read error: {error}");
|
||||||
self.remove_session_by_fd(fd)?;
|
self.remove_session_by_fd(fd)?;
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &ServerMessage::Bye(format!("PTY error: {error}")))
|
.send_to(&remote, &mut send_buf, &ServerMessage::Bye("PTY error"))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
},
|
||||||
PtyEvent::Data(data) => {
|
|
||||||
self.socket
|
|
||||||
.send_to(&remote, &ServerMessage::Output(data))
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
PtyEvent::Closed => {
|
PtyEvent::Closed => {
|
||||||
println!("End of PTY for {remote}");
|
println!("End of PTY for {remote}");
|
||||||
self.remove_session_by_fd(fd)?;
|
self.remove_session_by_fd(fd)?;
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(&remote, &ServerMessage::Bye("".into()))
|
.send_to(&remote, &mut send_buf, &ServerMessage::Bye(""))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
Event::Tick => {
|
||||||
|
// Restart the timer
|
||||||
|
self.update_client_timeouts(&mut send_buf)?;
|
||||||
|
self.timer.start(PING_INTERVAL)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_client_timeouts(&mut self) -> Result<(), Error> {
|
fn update_client_timeouts(&mut self, send_buf: &mut [u8]) -> Result<(), Error> {
|
||||||
let mut removed = vec![];
|
let mut removed = vec![];
|
||||||
for (remote, fd) in self.addr_to_session.iter() {
|
for (remote, fd) in self.addr_to_session.iter() {
|
||||||
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
let session = self.pty_to_session.get_mut(&fd).unwrap();
|
||||||
self.socket.send_to(remote, &ServerMessage::Ping).ok();
|
self.socket.send_to(remote, send_buf, &ServerMessage::Ping).ok();
|
||||||
session.timeouts += 1;
|
session.timeouts += 1;
|
||||||
if session.timeouts >= 10 {
|
if session.timeouts >= 10 {
|
||||||
removed.push((*remote, *fd));
|
removed.push((*remote, *fd));
|
||||||
@ -232,7 +249,7 @@ impl Server {
|
|||||||
for (remote, fd) in removed {
|
for (remote, fd) in removed {
|
||||||
eprintln!("Client {remote} timed out");
|
eprintln!("Client {remote} timed out");
|
||||||
self.remove_session_by_fd(fd)?;
|
self.remove_session_by_fd(fd)?;
|
||||||
self.socket.send_to(&remote, &ServerMessage::Bye("Timed out".into())).ok();
|
self.socket.send_to(&remote, send_buf, &ServerMessage::Bye("Timed out".into())).ok();
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user