vfs/rsh: better pipes, piped command execution in rsh

This commit is contained in:
Mark Poliakov 2024-11-06 19:40:27 +02:00
parent a707a6e5f1
commit 2479702baf
21 changed files with 707 additions and 202 deletions

View File

@ -17,8 +17,8 @@ use libk_util::sync::IrqSafeSpinlock;
use yggdrasil_abi::{
error::Error,
io::{
DeviceRequest, DirectoryEntry, OpenOptions, RawFd, SeekFrom, TerminalOptions, TerminalSize,
TimerOptions,
DeviceRequest, DirectoryEntry, OpenOptions, PipeOptions, RawFd, SeekFrom, TerminalOptions,
TerminalSize, TimerOptions,
},
};
@ -71,7 +71,7 @@ pub enum File {
Block(BlockFile),
Char(CharFile),
Socket(SocketWrapper),
AnonymousPipe(PipeEnd),
AnonymousPipe(PipeEnd, AtomicBool),
Poll(FdPoll),
Timer(TimerFile),
Channel(ChannelDescriptor),
@ -103,11 +103,17 @@ pub struct FileSet {
impl File {
/// Constructs a pipe pair, returning its `(read, write)` ends
pub fn new_pipe_pair(capacity: usize) -> (Arc<Self>, Arc<Self>) {
pub fn new_pipe_pair(capacity: usize, options: PipeOptions) -> (Arc<Self>, Arc<Self>) {
let (read, write) = PipeEnd::new_pair(capacity);
(
Arc::new(Self::AnonymousPipe(read)),
Arc::new(Self::AnonymousPipe(write)),
Arc::new(Self::AnonymousPipe(
read,
AtomicBool::new(options.contains(PipeOptions::READ_NONBLOCKING)),
)),
Arc::new(Self::AnonymousPipe(
write,
AtomicBool::new(options.contains(PipeOptions::WRITE_NONBLOCKING)),
)),
)
}
@ -289,6 +295,7 @@ impl File {
Self::Socket(socket) => socket.poll_read(cx),
Self::Timer(timer) => timer.poll_read(cx),
Self::Pid(pid) => pid.poll_read(cx),
Self::AnonymousPipe(pipe, _) => pipe.poll_read(cx),
// Polling not implemented, return ready immediately (XXX ?)
_ => Poll::Ready(Err(Error::NotImplemented)),
}
@ -365,7 +372,9 @@ impl Read for File {
Self::Regular(file) => file.read(buf),
Self::Block(file) => file.read(buf),
Self::Char(file) => file.read(buf),
Self::AnonymousPipe(pipe) => pipe.read(buf),
Self::AnonymousPipe(pipe, nonblocking) => {
pipe.read(buf, nonblocking.load(Ordering::Acquire))
}
Self::PtySlave(half) => half.read(buf),
Self::PtyMaster(half) => half.read(buf),
Self::Timer(timer) => timer.read(buf),
@ -388,7 +397,7 @@ impl Write for File {
Self::Regular(file) => file.write(buf),
Self::Block(file) => file.write(buf),
Self::Char(file) => file.write(buf),
Self::AnonymousPipe(pipe) => pipe.write(buf),
Self::AnonymousPipe(pipe, _) => pipe.write(buf),
Self::PtySlave(half) => half.write(buf),
Self::PtyMaster(half) => half.write(buf),
Self::Timer(timer) => timer.write(buf),
@ -447,7 +456,7 @@ impl fmt::Debug for File {
.field("write", &file.write)
.finish_non_exhaustive(),
Self::Directory(_) => f.debug_struct("DirectoryFile").finish_non_exhaustive(),
Self::AnonymousPipe(_) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(),
Self::AnonymousPipe(_, _) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(),
Self::Poll(_) => f.debug_struct("Poll").finish_non_exhaustive(),
Self::Channel(_) => f.debug_struct("Channel").finish_non_exhaustive(),
Self::SharedMemory(_) => f.debug_struct("SharedMemory").finish_non_exhaustive(),

View File

@ -1,4 +1,5 @@
use core::{
future::poll_fn,
pin::Pin,
sync::atomic::{AtomicBool, Ordering},
task::{Context, Poll},
@ -6,9 +7,14 @@ use core::{
use alloc::{sync::Arc, vec, vec::Vec};
use futures_util::{task::AtomicWaker, Future};
use libk_util::sync::IrqSafeSpinlock;
use libk_util::{
sync::{IrqSafeSpinlock, IrqSafeSpinlockGuard},
waker::QueueWaker,
};
use yggdrasil_abi::error::Error;
use crate::vfs::FileReadiness;
struct PipeInner {
data: Vec<u8>,
capacity: usize,
@ -19,8 +25,8 @@ struct PipeInner {
pub struct Pipe {
inner: IrqSafeSpinlock<PipeInner>,
shutdown: AtomicBool,
read_notify: AtomicWaker,
write_notify: AtomicWaker,
read_notify: QueueWaker,
write_notify: QueueWaker,
}
pub enum PipeEnd {
@ -82,15 +88,15 @@ impl Pipe {
Self {
inner: IrqSafeSpinlock::new(PipeInner::new(capacity)),
shutdown: AtomicBool::new(false),
read_notify: AtomicWaker::new(),
write_notify: AtomicWaker::new(),
read_notify: QueueWaker::new(),
write_notify: QueueWaker::new(),
}
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Release);
self.read_notify.wake();
self.write_notify.wake();
self.read_notify.wake_all();
self.write_notify.wake_all();
}
pub fn blocking_write(&self, val: u8) -> impl Future<Output = Result<(), Error>> + '_ {
@ -105,23 +111,15 @@ impl Pipe {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut lock = self.pipe.inner.lock();
// Try fast path before acquiring write notify to avoid unnecessary contention
if self.pipe.shutdown.load(Ordering::Acquire) {
// TODO BrokenPipe
return Poll::Ready(Err(Error::ReadOnly));
} else if lock.try_write(self.val) {
self.pipe.read_notify.wake();
return Poll::Ready(Ok(()));
}
self.pipe.write_notify.register(cx.waker());
if self.pipe.shutdown.load(Ordering::Acquire) {
self.pipe.write_notify.remove(cx.waker());
Poll::Ready(Err(Error::ReadOnly))
} else if lock.try_write(self.val) {
self.pipe.read_notify.wake();
self.pipe.write_notify.remove(cx.waker());
self.pipe.read_notify.wake_one();
Poll::Ready(Ok(()))
} else {
self.pipe.write_notify.register(cx.waker());
Poll::Pending
}
}
@ -130,37 +128,54 @@ impl Pipe {
F { pipe: self, val }
}
pub fn blocking_read(&self) -> impl Future<Output = Option<u8>> + '_ {
struct F<'a> {
pipe: &'a Pipe,
fn poll_read_end(
&self,
cx: &mut Context<'_>,
) -> Poll<Option<IrqSafeSpinlockGuard<'_, PipeInner>>> {
let lock = self.inner.lock();
if lock.can_read() {
self.read_notify.remove(cx.waker());
Poll::Ready(Some(lock))
} else if self.shutdown.load(Ordering::Acquire) {
self.read_notify.remove(cx.waker());
Poll::Ready(None)
} else {
self.read_notify.register(cx.waker());
Poll::Pending
}
}
impl<'a> Future for F<'a> {
type Output = Option<u8>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut lock = self.pipe.inner.lock();
if let Some(val) = lock.try_read() {
self.pipe.write_notify.wake();
return Poll::Ready(Some(val));
} else if self.pipe.shutdown.load(Ordering::Acquire) {
return Poll::Ready(None);
async fn read_async(&self, buf: &mut [u8]) -> Result<usize, Error> {
let mut pos = 0;
let lock = poll_fn(|cx| self.poll_read_end(cx)).await;
match lock {
Some(mut lock) => {
while pos < buf.len()
&& let Some(byte) = lock.try_read()
{
buf[pos] = byte;
pos += 1;
}
self.pipe.read_notify.register(cx.waker());
if let Some(val) = lock.try_read() {
Poll::Ready(Some(val))
} else if self.pipe.shutdown.load(Ordering::Acquire) {
Poll::Ready(None)
} else {
Poll::Pending
if pos != 0 {
self.write_notify.wake_all();
}
Ok(pos)
}
None => Ok(0),
}
}
}
F { pipe: self }
impl FileReadiness for Pipe {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.shutdown.load(Ordering::Acquire) || self.inner.lock().can_read() {
self.read_notify.remove(cx.waker());
Poll::Ready(Ok(()))
} else {
self.read_notify.register(cx.waker());
Poll::Pending
}
}
}
@ -173,27 +188,29 @@ impl PipeEnd {
(read, write)
}
pub fn read(&self, buf: &mut [u8]) -> Result<usize, Error> {
pub fn read(&self, buf: &mut [u8], nonblocking: bool) -> Result<usize, Error> {
let PipeEnd::Read(read) = self else {
return Err(Error::InvalidOperation);
};
block! {
if nonblocking {
let mut lock = read.inner.lock();
let mut pos = 0;
let mut rem = buf.len();
while rem != 0 {
if let Some(val) = read.blocking_read().await {
buf[pos] = val;
pos += 1;
rem -= 1;
} else {
break;
}
while pos < buf.len()
&& let Some(ch) = lock.try_read()
{
buf[pos] = ch;
pos += 1;
}
Ok(pos)
}?
read.write_notify.wake_all();
if pos == 0 && !read.shutdown.load(Ordering::Acquire) {
Err(Error::WouldBlock)
} else {
Ok(pos)
}
} else {
block!(read.read_async(buf).await)?
}
}
pub fn write(&self, buf: &[u8]) -> Result<usize, Error> {
@ -216,6 +233,15 @@ impl PipeEnd {
}
}
impl FileReadiness for PipeEnd {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self {
Self::Read(pipe) => pipe.poll_read(cx),
Self::Write(_) => Poll::Ready(Err(Error::NotImplemented)),
}
}
}
impl Drop for PipeEnd {
fn drop(&mut self) {
match self {

View File

@ -1,8 +1,8 @@
pub(crate) use abi::{
error::Error,
io::{
DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PollControl, RawFd,
TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PipeOptions, PollControl,
RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
},
mem::MappingSource,
net::SocketType,

View File

@ -4,8 +4,8 @@ use abi::{
error::Error,
io::{
ChannelPublisherId, DeviceRequest, DirectoryEntry, FileAttr, FileMetadataUpdate, FileMode,
MessageDestination, OpenOptions, PollControl, RawFd, ReceivedMessageMetadata, SeekFrom,
SentMessage, TerminalOptions, TerminalSize, TimerOptions,
MessageDestination, OpenOptions, PipeOptions, PollControl, RawFd, ReceivedMessageMetadata,
SeekFrom, SentMessage, TerminalOptions, TerminalSize, TimerOptions,
},
process::ProcessId,
};
@ -299,12 +299,15 @@ pub(crate) fn create_poll_channel() -> Result<RawFd, Error> {
})
}
pub(crate) fn create_pipe(ends: &mut [MaybeUninit<RawFd>; 2]) -> Result<(), Error> {
pub(crate) fn create_pipe(
ends: &mut [MaybeUninit<RawFd>; 2],
options: PipeOptions,
) -> Result<(), Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |mut io| {
let (read, write) = File::new_pipe_pair(256);
let (read, write) = File::new_pipe_pair(256, options);
let read_fd = io.files.place_file(read, true)?;
let write_fd = io.files.place_file(write, true)?;

View File

@ -49,6 +49,13 @@ bitfield OpenOptions(u32) {
CREATE_EXCL: 5,
}
bitfield PipeOptions(u32) {
/// If set, read end of the pipe will return WouldBlock if there's no data in the pipe
READ_NONBLOCKING: 0,
/// If set, write end of the pipe will return WouldBlock if the pipe is full
WRITE_NONBLOCKING: 1,
}
enum FileType(u32) {
/// Regular file
File = 1,

View File

@ -105,7 +105,7 @@ syscall create_pid(pid: ProcessId) -> Result<RawFd>;
syscall create_pty(opts: &TerminalOptions, size: &TerminalSize, fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>;
syscall create_shared_memory(size: usize) -> Result<RawFd>;
syscall create_poll_channel() -> Result<RawFd>;
syscall create_pipe(fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>;
syscall create_pipe(fds: &mut [MaybeUninit<RawFd>; 2], options: PipeOptions) -> Result<()>;
syscall poll_channel_wait(
poll_fd: RawFd,

View File

@ -5,8 +5,8 @@ mod input;
mod terminal;
pub use crate::generated::{
DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PollControl,
RawFd, TimerOptions, UnmountOptions, UserId,
DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PipeOptions,
PollControl, RawFd, TimerOptions, UnmountOptions, UserId,
};
pub use channel::{ChannelPublisherId, MessageDestination, ReceivedMessageMetadata, SentMessage};
pub use device::DeviceRequest;

View File

@ -23,5 +23,5 @@ pub mod poll {
pub use abi::io::{
DirectoryEntry, FileAttr, FileMetadataUpdate, FileMetadataUpdateMode, FileMode, FileType,
OpenOptions, RawFd, SeekFrom, TimerOptions,
OpenOptions, PipeOptions, RawFd, SeekFrom, TimerOptions,
};

View File

@ -17,7 +17,8 @@ mod generated {
error::Error,
io::{
ChannelPublisherId, DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions,
PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
PipeOptions, PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions,
UnmountOptions,
},
mem::MappingSource,
net::SocketType,

View File

@ -1,10 +1,15 @@
use std::{
io,
io::{self, Read, Write},
os::fd::{AsRawFd, RawFd},
process::Stdio,
time::Duration,
};
use crate::sys::{self, Poll as SysPoll, TimerFd as SysTimerFd, PidFd as SysPidFd};
use crate::sys::{
self, PidFd as SysPidFd, Pipe as SysPipe, Poll as SysPoll, TimerFd as SysTimerFd,
};
use self::sys::PipeImpl;
#[repr(transparent)]
pub struct Poll(sys::PollImpl);
@ -15,6 +20,9 @@ pub struct TimerFd(sys::TimerFdImpl);
#[repr(transparent)]
pub struct PidFd(sys::PidFdImpl);
#[repr(transparent)]
pub struct Pipe(sys::PipeImpl);
impl Poll {
pub fn new() -> io::Result<Self> {
sys::PollImpl::new().map(Self)
@ -74,3 +82,36 @@ impl AsRawFd for PidFd {
self.0.as_raw_fd()
}
}
impl Pipe {
pub fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
let (read, write) = PipeImpl::new(read_nonblocking, write_nonblocking)?;
Ok((Self(read), Self(write)))
}
pub fn to_child_stdio(&self) -> Stdio {
self.0.to_child_stdio()
}
}
impl Read for Pipe {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for Pipe {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl AsRawFd for Pipe {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}

View File

@ -9,8 +9,9 @@ mod unix;
pub(crate) use unix::*;
use std::{
io,
io::{self, Read, Write},
os::fd::{AsRawFd, RawFd},
process::Stdio,
time::Duration,
};
@ -31,3 +32,8 @@ pub(crate) trait PidFd: Sized + AsRawFd {
fn new(pid: u32) -> io::Result<Self>;
fn exit_status(&self) -> io::Result<i32>;
}
pub(crate) trait Pipe: Read + Write + AsRawFd + Sized {
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)>;
fn to_child_stdio(&self) -> Stdio;
}

View File

@ -1,7 +1,9 @@
pub mod poll;
pub mod timer;
pub mod pid;
pub mod pipe;
pub use poll::PollImpl;
pub use timer::TimerFdImpl;
pub use pid::PidFdImpl;
pub use pipe::PipeImpl;

View File

@ -10,7 +10,7 @@ pub struct PidFdImpl {
impl PidFd for PidFdImpl {
fn new(pid: u32) -> io::Result<Self> {
let pid = pid as i32;
let fd = unsafe { libc::pidfd_open(pid) };
let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, 0) } as i32;
if fd < 0 {
return Err(io::Error::last_os_error());
}
@ -19,7 +19,7 @@ impl PidFd for PidFdImpl {
}
fn exit_status(&self) -> io::Result<i32> {
let status = 0;
let mut status = 0;
let res = unsafe { libc::waitpid(self.pid, &mut status, 0) };
if res < 0 {
return Err(io::Error::last_os_error());

View File

@ -0,0 +1,47 @@
use std::{io::{self, Read, Write}, os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}};
use crate::sys::Pipe;
pub struct PipeImpl {
fd: OwnedFd,
}
impl Pipe for PipeImpl {
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
let mut fds = [0; 2];
let res = unsafe { libc::pipe(fds.as_mut_ptr()) };
if res < 0 {
return Err(io::Error::last_os_error());
}
let read = unsafe { OwnedFd::from_raw_fd(fds[0]) };
let write = unsafe { OwnedFd::from_raw_fd(fds[1]) };
Ok((Self { fd: read }, Self { fd: write }))
}
fn to_child_stdio(&self) -> std::process::Stdio {
todo!()
}
}
impl Read for PipeImpl {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
todo!()
}
}
impl Write for PipeImpl {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
todo!()
}
fn flush(&mut self) -> io::Result<()> {
todo!()
}
}
impl AsRawFd for PipeImpl {
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}

View File

@ -1,7 +1,9 @@
pub mod poll;
pub mod timer;
pub mod pid;
pub mod pipe;
pub use poll::PollImpl;
pub use timer::TimerFdImpl;
pub use pid::PidFdImpl;
pub use pipe::PipeImpl;

View File

@ -0,0 +1,50 @@
use std::{
fs::File,
io::{self, Read, Write},
os::{
fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
yggdrasil::io::pipe::create_pipe_pair,
},
process::Stdio,
};
use crate::sys::Pipe;
pub struct PipeImpl {
file: File,
}
impl Pipe for PipeImpl {
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
let (read, write) = create_pipe_pair(read_nonblocking, write_nonblocking)?;
let read = unsafe { File::from_raw_fd(read.into_raw_fd()) };
let write = unsafe { File::from_raw_fd(write.into_raw_fd()) };
Ok((Self { file: read }, Self { file: write }))
}
fn to_child_stdio(&self) -> Stdio {
unsafe { Stdio::from_raw_fd(self.as_raw_fd()) }
}
}
impl Write for PipeImpl {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.file.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.file.flush()
}
}
impl Read for PipeImpl {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.file.read(buf)
}
}
impl AsRawFd for PipeImpl {
fn as_raw_fd(&self) -> RawFd {
self.file.as_raw_fd()
}
}

View File

@ -1,5 +1,5 @@
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
#![feature(generic_const_exprs, portable_simd, if_let_guard)]
#![feature(generic_const_exprs, portable_simd, if_let_guard, let_chains)]
#![allow(incomplete_features)]
pub mod socket;

View File

@ -2,7 +2,7 @@
#![feature(let_chains)]
use std::{
io::{stdout, IsTerminal, Read, Write},
io::{stderr, stdout, IsTerminal, Read, Stderr, Write},
ops::{Deref, DerefMut},
os::fd::AsRawFd,
};
@ -11,11 +11,11 @@ use clap::Parser;
use cross::io::Poll;
use rsh::{
crypt::{
client::{self, ClientSocket},
client::{self, ClientSocket, Message},
config::{ClientConfig, SimpleClientKeyStore},
signature::{SignEd25519, SignatureMethod},
},
proto::ServerMessage,
proto::{ServerMessage, StreamIndex},
};
use std::{
@ -70,8 +70,9 @@ enum Input {
pub struct Client {
poll: Poll,
socket: ClientSocket,
input: Input,
stdin: RawStdin,
stdout: Stdout,
stderr: Stderr,
need_bye: bool,
last0: u8,
last1: u8,
@ -124,30 +125,14 @@ impl Drop for RawStdin {
}
impl Client {
pub fn connect(
remote: SocketAddr,
crypto_config: ClientConfig,
command: Vec<String>,
) -> Result<Self, Error> {
pub fn connect(remote: SocketAddr, crypto_config: ClientConfig) -> Result<Self, Error> {
let mut poll = Poll::new()?;
let mut socket = ClientSocket::connect(remote, crypto_config)?;
let input = match command.is_empty() {
true => {
let stdin = RawStdin::open()?;
poll.add(&stdin)?;
Input::Stdin(stdin)
},
false => {
let mut bytes = vec![];
for command in command {
bytes.extend_from_slice(command.as_bytes());
bytes.push(b' ');
}
Input::Command(bytes)
},
};
let stdin = RawStdin::open()?;
let stdout = stdout();
let stderr = stderr();
poll.add(&stdin)?;
poll.add(&socket)?;
let info = terminal_info(&stdout)?;
@ -155,8 +140,9 @@ impl Client {
Self::handshake(&mut socket, info)?;
Ok(Self {
input,
stdin,
stdout,
stderr,
socket,
poll,
need_bye: false,
@ -187,19 +173,20 @@ impl Client {
pub fn run(mut self) -> Result<String, Error> {
let mut recv_buf = [0; 512];
if let Input::Command(command) = &self.input {
self.socket.write_all(&ClientMessage::Input(&command[..]))?;
}
loop {
if let Some(message) = self.socket.read(&mut recv_buf)? {
match message {
ServerMessage::Bye(reason) => return Ok(reason.into()),
ServerMessage::Output(data) => {
ServerMessage::Output(StreamIndex::Stdout, data) => {
self.stdout.write_all(data).ok();
self.stdout.flush().ok();
continue;
}
ServerMessage::Output(StreamIndex::Stderr, data) => {
self.stderr.write_all(data).ok();
self.stderr.flush().ok();
continue;
}
_ => continue,
}
}
@ -210,8 +197,8 @@ impl Client {
if self.socket.poll()? == 0 {
return Ok("".into());
}
} else if let Input::Stdin(stdin) = &mut self.input && stdin.as_raw_fd() == fd {
let len = stdin.read(&mut recv_buf)?;
} else if self.stdin.as_raw_fd() == fd {
let len = self.stdin.read(&mut recv_buf)?;
self.update_last(&recv_buf[..len])?;
self.socket
.write_all(&ClientMessage::Input(&recv_buf[..len]))?;
@ -247,27 +234,101 @@ fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
})
}
fn run(args: Args) -> Result<(), Error> {
fn run_terminal(remote: SocketAddr, config: ClientConfig) -> Result<(), Error> {
let reason = Client::connect(remote, config)?.run()?;
if !reason.is_empty() {
eprintln!("\nDisconnected: {reason}");
}
Ok(())
}
fn run_command(
remote: SocketAddr,
config: ClientConfig,
command: Vec<String>,
) -> Result<ExitCode, Error> {
let mut poll = Poll::new()?;
let mut buffer = [0; 512];
let mut command_string = String::new();
for (i, word) in command.iter().enumerate() {
if i != 0 {
command_string.push(' ');
}
command_string.push_str(word);
}
let mut stdin = stdin();
let mut stdout = stdout();
let mut stderr = stderr();
let mut socket = ClientSocket::connect(remote, config)?;
poll.add(&socket)?;
poll.add(&stdin)?;
socket.write_all(&ClientMessage::RunCommand(command_string.as_str()))?;
loop {
let fd = poll.wait(None)?.unwrap();
match fd {
_ if fd == socket.as_raw_fd() => {
let message = match socket.poll_read(&mut buffer)? {
Message::Data(data) => data,
Message::Incomplete => continue,
Message::Closed => break
};
match message {
ServerMessage::Output(StreamIndex::Stdout, output) => {
stdout.write_all(output).ok();
stdout.flush().ok();
}
ServerMessage::Output(StreamIndex::Stderr, output) => {
stderr.write_all(output).ok();
stderr.flush().ok();
}
_ => todo!()
}
}
_ if fd == stdin.as_raw_fd() => {
let len = stdin.read(&mut buffer)?;
if len == 0 {
poll.remove(&stdin)?;
socket.write_all(&ClientMessage::CloseStdin)?;
} else {
socket.write_all(&ClientMessage::Input(&buffer[..len]))?;
}
}
_ => unreachable!()
}
}
Ok(ExitCode::SUCCESS)
}
fn run(args: Args) -> Result<ExitCode, Error> {
let remote = SocketAddr::new(args.remote, args.port);
let ed25519 = SignEd25519::load_signing_key(args.key).unwrap();
let key = SignatureMethod::Ed25519(ed25519);
let config = ClientConfig::with_default_algorithms(SimpleClientKeyStore::new(key));
let reason = Client::connect(remote, config, args.command)?.run()?;
if !reason.is_empty() {
eprintln!("\nDisconnected: {reason}");
if args.command.is_empty() {
run_terminal(remote, config).map(|_| ExitCode::SUCCESS)
} else {
run_command(remote, config, args.command)
}
Ok(())
}
fn main() -> ExitCode {
env_logger::init();
let args = Args::parse();
if let Err(error) = run(args) {
eprintln!("Error: {error}");
ExitCode::FAILURE
} else {
ExitCode::SUCCESS
match run(args) {
Ok(status) => status,
Err(error) => {
eprintln!("Error: {error}");
ExitCode::FAILURE
}
}
}

View File

@ -13,6 +13,13 @@ pub enum ClientMessage<'a> {
RunCommand(&'a str),
Bye(&'a str),
Input(&'a [u8]),
CloseStdin,
}
#[derive(Debug, Clone, Copy)]
pub enum StreamIndex {
Stdout,
Stderr,
}
#[derive(Debug)]
@ -20,7 +27,7 @@ pub enum ServerMessage<'a> {
SessionOpen,
CommandStatus(i32),
Bye(&'a str),
Output(&'a [u8]),
Output(StreamIndex, &'a [u8]),
}
#[derive(Debug, thiserror::Error)]
@ -164,6 +171,32 @@ impl ClientMessage<'_> {
const TAG_RUN_COMMAND: u8 = 0x81;
const TAG_BYE: u8 = 0x82;
const TAG_INPUT: u8 = 0x90;
const TAG_CLOSE_STDIN: u8 = 0x91;
}
impl StreamIndex {
const TAG_STDOUT: u8 = 0x01;
const TAG_STDERR: u8 = 0x02;
}
impl Encode for StreamIndex {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::Stdout => buffer.write(&[Self::TAG_STDOUT]),
Self::Stderr => buffer.write(&[Self::TAG_STDERR]),
}
}
}
impl<'de> Decode<'de> for StreamIndex {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_STDOUT => Ok(Self::Stdout),
Self::TAG_STDERR => Ok(Self::Stderr),
_ => Err(DecodeError::InvalidMessage)
}
}
}
impl Encode for TerminalInfo {
@ -201,6 +234,7 @@ impl<'a> Encode for ClientMessage<'a> {
buffer.write(&[Self::TAG_INPUT])?;
buffer.write_variable_bytes(data)
}
Self::CloseStdin => buffer.write(&[Self::TAG_CLOSE_STDIN])
}
}
}
@ -221,6 +255,9 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
Self::TAG_INPUT => {
buffer.read_variable_bytes().map(Self::Input)
}
Self::TAG_CLOSE_STDIN => {
Ok(Self::CloseStdin)
}
_ => Err(DecodeError::InvalidMessage)
}
}
@ -244,8 +281,9 @@ impl<'a> Encode for ServerMessage<'a> {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
}
Self::Output(data) => {
Self::Output(index, data) => {
buffer.write(&[Self::TAG_OUTPUT])?;
index.encode(buffer)?;
buffer.write_variable_bytes(data)
}
}
@ -267,7 +305,9 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
buffer.read_str().map(Self::Bye)
}
Self::TAG_OUTPUT => {
buffer.read_variable_bytes().map(Self::Output)
let index = StreamIndex::decode(buffer)?;
let data = buffer.read_variable_bytes()?;
Ok(Self::Output(index, data))
},
_ => Err(DecodeError::InvalidMessage),
}

View File

@ -1,12 +1,15 @@
use std::{
collections::{hash_map::Entry, HashMap},
fmt, io,
fmt,
hint::unreachable_unchecked,
io::{self, Read, Stdout, Write},
net::SocketAddr,
os::fd::{AsRawFd, RawFd},
process::{Child, Command, Stdio},
time::Duration,
};
use cross::io::Poll;
use cross::io::{PidFd, Pipe, Poll};
use crate::{
crypt::{
@ -14,7 +17,7 @@ use crate::{
config::ServerConfig,
server::{self, ServerSocket},
},
proto::{ClientMessage, ServerMessage, TerminalInfo},
proto::{ClientMessage, ServerMessage, StreamIndex, TerminalInfo},
};
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
@ -52,22 +55,28 @@ pub struct EchoSession {
peer: SocketAddr,
}
struct ClientSession<T: Session> {
stream: ClientSocket,
session: Option<T>,
struct PendingCommand {
child: Child,
child_pid: Option<PidFd>,
stdin: Option<Pipe>,
stdout: Option<Pipe>,
stderr: Option<Pipe>,
}
// enum Event<'b, T: Session> {
// NewSession(SocketAddr, TerminalInfo),
// SessionInput(u64, SocketAddr, &'b [u8]),
// ClientBye(SocketAddr, &'b str),
// SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>),
// Tick,
// }
enum ClientSession<T: Session> {
None,
Terminal(T),
Command(PendingCommand),
}
struct SessionSet<T: Session> {
struct Client<T: Session> {
stream: ClientSocket,
session: ClientSession<T>,
}
struct ClientSet<T: Session> {
last_session_key: u64,
sessions: HashMap<u64, ClientSession<T>>,
clients: HashMap<u64, Client<T>>,
socket_fd_map: HashMap<RawFd, u64>,
session_fd_map: HashMap<RawFd, u64>,
}
@ -76,29 +85,63 @@ pub struct Server<T: Session> {
poll: Poll,
socket: ServerSocket,
sessions: SessionSet<T>,
clients: ClientSet<T>,
}
impl<T: Session> SessionSet<T> {
enum CommandEvent {
Output(StreamIndex, Result<usize, io::Error>),
Exited(Result<i32, io::Error>)
}
impl<T: Session> Client<T> {
fn make_command(text: &str) -> Result<PendingCommand, Error> {
let mut words = text.split(' ');
let program = words.next().unwrap();
let mut command = Command::new(program);
let (stdin_read, stdin_write) = Pipe::new(false, false)?;
let (stdout_read, stdout_write) = Pipe::new(true, false)?;
let (stderr_read, stderr_write) = Pipe::new(true, false)?;
command
.args(words)
.stdin(stdin_read.to_child_stdio())
.stdout(stdout_write.to_child_stdio())
.stderr(stderr_write.to_child_stdio());
let child = command.spawn()?;
let child_pid = PidFd::new(child.id())?;
Ok(PendingCommand {
child,
child_pid: Some(child_pid),
stdout: Some(stdout_read),
stderr: Some(stderr_read),
stdin: Some(stdin_write),
})
}
}
impl<T: Session> ClientSet<T> {
pub fn new() -> Self {
Self {
last_session_key: 1,
sessions: HashMap::new(),
clients: HashMap::new(),
socket_fd_map: HashMap::new(),
session_fd_map: HashMap::new(),
}
}
pub fn add_client(&mut self, stream: ClientSocket, poll: &mut Poll) -> Result<(), Error> {
let (key, session) = loop {
let (key, client) = loop {
let key = self.last_session_key;
self.last_session_key += 1;
match self.sessions.entry(key) {
match self.clients.entry(key) {
Entry::Vacant(entry) => {
let session = entry.insert(ClientSession {
let session = entry.insert(Client {
stream,
session: None,
session: ClientSession::None,
});
break (key, session);
}
@ -106,22 +149,39 @@ impl<T: Session> SessionSet<T> {
}
};
poll.add(&session.stream)?;
self.socket_fd_map.insert(session.stream.as_raw_fd(), key);
poll.add(&client.stream)?;
self.socket_fd_map.insert(client.stream.as_raw_fd(), key);
Ok(())
}
fn remove(&mut self, key: u64, poll: &mut Poll) -> Result<(), Error> {
if let Some(session) = self.sessions.remove(&key) {
poll.remove(&session.stream)?;
self.socket_fd_map.remove(&session.stream.as_raw_fd());
if let Some(client) = self.clients.remove(&key) {
poll.remove(&client.stream)?;
self.socket_fd_map.remove(&client.stream.as_raw_fd());
if let Some(session) = session.session {
for fd in session.event_fds() {
poll.remove(fd)?;
self.session_fd_map.remove(fd);
match client.session {
ClientSession::Terminal(terminal) => {
for fd in terminal.event_fds() {
poll.remove(fd)?;
self.session_fd_map.remove(fd);
}
}
ClientSession::Command(mut command) => {
if let Some(stdout) = command.stdout.take() {
poll.remove(&stdout)?;
self.session_fd_map.remove(&stdout.as_raw_fd());
}
if let Some(stderr) = command.stderr.take() {
poll.remove(&stderr)?;
self.session_fd_map.remove(&stderr.as_raw_fd());
}
if let Some(child_pid) = command.child_pid.take() {
poll.remove(&child_pid)?;
self.session_fd_map.remove(&child_pid.as_raw_fd());
}
}
ClientSession::None => (),
}
}
Ok(())
@ -131,10 +191,10 @@ impl<T: Session> SessionSet<T> {
let mut buffer = [0; 512];
if let Some(&key) = self.socket_fd_map.get(&fd) {
let session = self.sessions.get_mut(&key).unwrap();
let peer = session.stream.remote_address();
let client = self.clients.get_mut(&key).unwrap();
let peer = client.stream.remote_address();
let mut closed = match session.stream.poll() {
let mut closed = match client.stream.poll() {
Ok(0) => true,
Ok(_) => false,
Err(error) => {
@ -144,7 +204,7 @@ impl<T: Session> SessionSet<T> {
};
loop {
let message = match session.stream.read(&mut buffer) {
let message = match client.stream.read(&mut buffer) {
Ok(Some(message)) => message,
Ok(None) => break,
Err(error) => {
@ -154,8 +214,33 @@ impl<T: Session> SessionSet<T> {
}
};
match (&message, &mut session.session) {
(ClientMessage::OpenSession(terminal), None) => {
match (&message, &mut client.session) {
(ClientMessage::RunCommand(command), ClientSession::None) => {
log::info!("{peer}: cmd {command:?}");
let command = match Client::<T>::make_command(command) {
Ok(command) => command,
Err(error) => {
log::error!("{peer}: command error: {error}");
return self.remove(key, poll);
}
};
let stdout_fd = command.stdout.as_ref().unwrap().as_raw_fd();
let stderr_fd = command.stderr.as_ref().unwrap().as_raw_fd();
let child_pid_fd = command.child_pid.as_ref().unwrap().as_raw_fd();
poll.add(&stdout_fd)?;
poll.add(&stderr_fd)?;
poll.add(&child_pid_fd)?;
self.session_fd_map.insert(stdout_fd, key);
self.session_fd_map.insert(stderr_fd, key);
self.session_fd_map.insert(child_pid_fd, key);
client.session = ClientSession::Command(command);
}
(ClientMessage::OpenSession(terminal), ClientSession::None) => {
log::info!("{peer}: new session");
let terminal = match T::open(&peer, &terminal) {
Ok(session) => session,
@ -165,15 +250,33 @@ impl<T: Session> SessionSet<T> {
return Ok(());
}
};
let terminal = session.session.insert(terminal);
let terminal = client.session.set_terminal(terminal);
for fd in terminal.event_fds() {
poll.add(fd)?;
self.session_fd_map.insert(*fd, key);
}
session.stream.write_all(&ServerMessage::SessionOpen).ok();
client.stream.write_all(&ServerMessage::SessionOpen).ok();
}
(ClientMessage::Input(data), Some(terminal)) => {
let client = SessionClient { transport: &mut session.stream };
(ClientMessage::Input(data), ClientSession::Command(command)) => {
if let Some(stdin) = command.stdin.as_mut() {
match stdin.write_all(data) {
Ok(()) => {
stdin.flush().ok();
},
Err(error) => {
log::error!("{peer}: stdin error: {error}");
return self.remove(key, poll);
}
}
} else {
log::warn!("{peer}: stdin already closed");
}
}
(ClientMessage::Input(data), ClientSession::Terminal(terminal)) => {
let client = SessionClient {
transport: &mut client.stream,
};
match terminal.handle_input(data, client) {
Ok(false) => (),
Ok(true) => {
@ -186,6 +289,10 @@ impl<T: Session> SessionSet<T> {
}
}
}
(ClientMessage::CloseStdin, ClientSession::Command(command)) => {
// Drop stdin handle
command.stdin.take();
}
_ => {
log::warn!("{peer}: unhandled message");
}
@ -199,45 +306,148 @@ impl<T: Session> SessionSet<T> {
Ok(())
} else if let Some(&key) = self.session_fd_map.get(&fd) {
let session = self.sessions.get_mut(&key).unwrap();
let terminal = session.session.as_mut().unwrap();
let peer = session.stream.remote_address();
log::debug!("poll fd {:?}", fd);
let client = self.clients.get_mut(&key).unwrap();
let peer = client.stream.remote_address();
match terminal.read_output(fd, &mut buffer) {
Ok(0) => {
log::info!("{peer}: session closed");
self.remove(key, poll)
}
Ok(mut len) => {
// Split output into 128-byte chunks
let mut pos = 0;
while len != 0 {
let amount = core::cmp::min(len, 128);
let (mut len, stream_index) = match &mut client.session {
ClientSession::Command(command) => match command.read_output(fd, &mut buffer) {
CommandEvent::Output(index, Ok(len)) => {
if len == 0 {
poll.remove(&fd)?;
self.socket_fd_map.remove(&fd);
if let Err(error) = session
.stream
.write_all(&ServerMessage::Output(&buffer[pos..pos + amount]))
{
log::error!("{peer}: communication error: {error}");
match index {
StreamIndex::Stdout => {
command.stdout = None;
}
StreamIndex::Stderr => {
command.stderr = None;
}
}
if command.is_dead() {
log::info!("{peer}: command stdout/stderr closed");
return self.remove(key, poll);
}
}
(len, index)
},
CommandEvent::Output(index, Err(error)) => {
if error.kind() == io::ErrorKind::WouldBlock {
return Ok(());
}
log::error!("{peer}: {index:?} error: {error}");
return self.remove(key, poll);
}
CommandEvent::Exited(status) => {
log::info!("{peer}: command exited: {:?}", status);
poll.remove(&fd)?;
self.socket_fd_map.remove(&fd);
command.child_pid = None;
if command.is_dead() {
log::info!("{peer}: command stdout/stderr closed");
return self.remove(key, poll);
}
pos += amount;
len -= amount;
return Ok(());
}
Ok(())
}
Err(error) => {
log::error!("{peer}: session read error: {error}");
self.remove(key, poll)
}
},
ClientSession::Terminal(terminal) => match terminal.read_output(fd, &mut buffer) {
Ok(0) => {
poll.remove(&fd)?;
self.socket_fd_map.remove(&fd);
// TODO check for process as well
log::info!("{peer}: terminal closed");
return self.remove(key, poll);
}
Ok(len) => {
(len, StreamIndex::Stdout)
},
Err(error) => {
log::error!("{peer}: terminal error: {error}");
return self.remove(key, poll);
}
},
ClientSession::None => unreachable!(),
};
if len == 0 {
log::info!("{peer}: {stream_index:?} closed");
return Ok(());
// return self.remove(key, poll);
}
// Split output into 128-byte chunks
let mut pos = 0;
while len != 0 {
let amount = core::cmp::min(len, 128);
if let Err(error) = client.stream.write_all(&ServerMessage::Output(
stream_index,
&buffer[pos..pos + amount],
)) {
log::error!("{peer}: communication error: {error}");
return self.remove(key, poll);
}
pos += amount;
len -= amount;
}
log::debug!("Done");
Ok(())
} else {
unreachable!()
}
}
}
impl PendingCommand {
pub fn read_output(
&mut self,
fd: RawFd,
buffer: &mut [u8],
) -> CommandEvent {
if let Some(stdout) = self.stdout.as_mut() && fd == stdout.as_raw_fd() {
log::debug!("poll stdout");
let res = stdout.read(buffer);
log::debug!(">> {:?}", res);
return CommandEvent::Output(StreamIndex::Stdout, res);
}
if let Some(stderr) = self.stderr.as_mut() && fd == stderr.as_raw_fd() {
return CommandEvent::Output(StreamIndex::Stderr, stderr.read(buffer));
}
if let Some(child_pid) = self.child_pid.as_mut() && fd == child_pid.as_raw_fd() {
let status = child_pid.exit_status();
return CommandEvent::Exited(status);
}
unreachable!()
}
pub fn is_dead(&self) -> bool {
self.child_pid.is_none() && self.stdout.is_none() && self.stderr.is_none()
}
}
impl Drop for PendingCommand {
fn drop(&mut self) {
self.child.wait().ok();
}
}
impl<T: Session> ClientSession<T> {
pub fn set_terminal(&mut self, terminal: T) -> &mut T {
*self = Self::Terminal(terminal);
match self {
Self::Terminal(terminal) => terminal,
_ => unsafe { unreachable_unchecked() },
}
}
}
impl<T: Session> Server<T> {
pub fn listen(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> {
let mut poll = Poll::new()?;
@ -246,7 +456,7 @@ impl<T: Session> Server<T> {
Ok(Self {
poll,
socket,
sessions: SessionSet::new(),
clients: ClientSet::new(),
})
}
@ -259,14 +469,14 @@ impl<T: Session> Server<T> {
// Poll server socket
match self.socket.poll()? {
Some(client) => {
self.sessions.add_client(client, &mut self.poll)?;
self.clients.add_client(client, &mut self.poll)?;
}
None => continue,
}
}
_ => {
// Client/Session-related activity
self.sessions.handle_input(fd, &mut self.poll)?;
self.clients.handle_input(fd, &mut self.poll)?;
}
}
}
@ -275,7 +485,7 @@ impl<T: Session> Server<T> {
impl<'s> SessionClient<'s> {
pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
self.send_message(&ServerMessage::Output(data))
self.send_message(&ServerMessage::Output(StreamIndex::Stdout, data))
}
pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> {

View File

@ -74,7 +74,7 @@ pub fn wait_for_pipeline(
}
pub fn create_pipe() -> Result<Pipe, io::Error> {
let (read, write) = pipe::create_pipe_pair()?;
let (read, write) = pipe::create_pipe_pair(false, false)?;
Ok(Pipe { read, write })
}