vfs/rsh: better pipes, piped command execution in rsh

This commit is contained in:
Mark Poliakov 2024-11-06 19:40:27 +02:00
parent a707a6e5f1
commit 2479702baf
21 changed files with 707 additions and 202 deletions

View File

@ -17,8 +17,8 @@ use libk_util::sync::IrqSafeSpinlock;
use yggdrasil_abi::{ use yggdrasil_abi::{
error::Error, error::Error,
io::{ io::{
DeviceRequest, DirectoryEntry, OpenOptions, RawFd, SeekFrom, TerminalOptions, TerminalSize, DeviceRequest, DirectoryEntry, OpenOptions, PipeOptions, RawFd, SeekFrom, TerminalOptions,
TimerOptions, TerminalSize, TimerOptions,
}, },
}; };
@ -71,7 +71,7 @@ pub enum File {
Block(BlockFile), Block(BlockFile),
Char(CharFile), Char(CharFile),
Socket(SocketWrapper), Socket(SocketWrapper),
AnonymousPipe(PipeEnd), AnonymousPipe(PipeEnd, AtomicBool),
Poll(FdPoll), Poll(FdPoll),
Timer(TimerFile), Timer(TimerFile),
Channel(ChannelDescriptor), Channel(ChannelDescriptor),
@ -103,11 +103,17 @@ pub struct FileSet {
impl File { impl File {
/// Constructs a pipe pair, returning its `(read, write)` ends /// Constructs a pipe pair, returning its `(read, write)` ends
pub fn new_pipe_pair(capacity: usize) -> (Arc<Self>, Arc<Self>) { pub fn new_pipe_pair(capacity: usize, options: PipeOptions) -> (Arc<Self>, Arc<Self>) {
let (read, write) = PipeEnd::new_pair(capacity); let (read, write) = PipeEnd::new_pair(capacity);
( (
Arc::new(Self::AnonymousPipe(read)), Arc::new(Self::AnonymousPipe(
Arc::new(Self::AnonymousPipe(write)), read,
AtomicBool::new(options.contains(PipeOptions::READ_NONBLOCKING)),
)),
Arc::new(Self::AnonymousPipe(
write,
AtomicBool::new(options.contains(PipeOptions::WRITE_NONBLOCKING)),
)),
) )
} }
@ -289,6 +295,7 @@ impl File {
Self::Socket(socket) => socket.poll_read(cx), Self::Socket(socket) => socket.poll_read(cx),
Self::Timer(timer) => timer.poll_read(cx), Self::Timer(timer) => timer.poll_read(cx),
Self::Pid(pid) => pid.poll_read(cx), Self::Pid(pid) => pid.poll_read(cx),
Self::AnonymousPipe(pipe, _) => pipe.poll_read(cx),
// Polling not implemented, return ready immediately (XXX ?) // Polling not implemented, return ready immediately (XXX ?)
_ => Poll::Ready(Err(Error::NotImplemented)), _ => Poll::Ready(Err(Error::NotImplemented)),
} }
@ -365,7 +372,9 @@ impl Read for File {
Self::Regular(file) => file.read(buf), Self::Regular(file) => file.read(buf),
Self::Block(file) => file.read(buf), Self::Block(file) => file.read(buf),
Self::Char(file) => file.read(buf), Self::Char(file) => file.read(buf),
Self::AnonymousPipe(pipe) => pipe.read(buf), Self::AnonymousPipe(pipe, nonblocking) => {
pipe.read(buf, nonblocking.load(Ordering::Acquire))
}
Self::PtySlave(half) => half.read(buf), Self::PtySlave(half) => half.read(buf),
Self::PtyMaster(half) => half.read(buf), Self::PtyMaster(half) => half.read(buf),
Self::Timer(timer) => timer.read(buf), Self::Timer(timer) => timer.read(buf),
@ -388,7 +397,7 @@ impl Write for File {
Self::Regular(file) => file.write(buf), Self::Regular(file) => file.write(buf),
Self::Block(file) => file.write(buf), Self::Block(file) => file.write(buf),
Self::Char(file) => file.write(buf), Self::Char(file) => file.write(buf),
Self::AnonymousPipe(pipe) => pipe.write(buf), Self::AnonymousPipe(pipe, _) => pipe.write(buf),
Self::PtySlave(half) => half.write(buf), Self::PtySlave(half) => half.write(buf),
Self::PtyMaster(half) => half.write(buf), Self::PtyMaster(half) => half.write(buf),
Self::Timer(timer) => timer.write(buf), Self::Timer(timer) => timer.write(buf),
@ -447,7 +456,7 @@ impl fmt::Debug for File {
.field("write", &file.write) .field("write", &file.write)
.finish_non_exhaustive(), .finish_non_exhaustive(),
Self::Directory(_) => f.debug_struct("DirectoryFile").finish_non_exhaustive(), Self::Directory(_) => f.debug_struct("DirectoryFile").finish_non_exhaustive(),
Self::AnonymousPipe(_) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(), Self::AnonymousPipe(_, _) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(),
Self::Poll(_) => f.debug_struct("Poll").finish_non_exhaustive(), Self::Poll(_) => f.debug_struct("Poll").finish_non_exhaustive(),
Self::Channel(_) => f.debug_struct("Channel").finish_non_exhaustive(), Self::Channel(_) => f.debug_struct("Channel").finish_non_exhaustive(),
Self::SharedMemory(_) => f.debug_struct("SharedMemory").finish_non_exhaustive(), Self::SharedMemory(_) => f.debug_struct("SharedMemory").finish_non_exhaustive(),

View File

@ -1,4 +1,5 @@
use core::{ use core::{
future::poll_fn,
pin::Pin, pin::Pin,
sync::atomic::{AtomicBool, Ordering}, sync::atomic::{AtomicBool, Ordering},
task::{Context, Poll}, task::{Context, Poll},
@ -6,9 +7,14 @@ use core::{
use alloc::{sync::Arc, vec, vec::Vec}; use alloc::{sync::Arc, vec, vec::Vec};
use futures_util::{task::AtomicWaker, Future}; use futures_util::{task::AtomicWaker, Future};
use libk_util::sync::IrqSafeSpinlock; use libk_util::{
sync::{IrqSafeSpinlock, IrqSafeSpinlockGuard},
waker::QueueWaker,
};
use yggdrasil_abi::error::Error; use yggdrasil_abi::error::Error;
use crate::vfs::FileReadiness;
struct PipeInner { struct PipeInner {
data: Vec<u8>, data: Vec<u8>,
capacity: usize, capacity: usize,
@ -19,8 +25,8 @@ struct PipeInner {
pub struct Pipe { pub struct Pipe {
inner: IrqSafeSpinlock<PipeInner>, inner: IrqSafeSpinlock<PipeInner>,
shutdown: AtomicBool, shutdown: AtomicBool,
read_notify: AtomicWaker, read_notify: QueueWaker,
write_notify: AtomicWaker, write_notify: QueueWaker,
} }
pub enum PipeEnd { pub enum PipeEnd {
@ -82,15 +88,15 @@ impl Pipe {
Self { Self {
inner: IrqSafeSpinlock::new(PipeInner::new(capacity)), inner: IrqSafeSpinlock::new(PipeInner::new(capacity)),
shutdown: AtomicBool::new(false), shutdown: AtomicBool::new(false),
read_notify: AtomicWaker::new(), read_notify: QueueWaker::new(),
write_notify: AtomicWaker::new(), write_notify: QueueWaker::new(),
} }
} }
pub fn shutdown(&self) { pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Release); self.shutdown.store(true, Ordering::Release);
self.read_notify.wake(); self.read_notify.wake_all();
self.write_notify.wake(); self.write_notify.wake_all();
} }
pub fn blocking_write(&self, val: u8) -> impl Future<Output = Result<(), Error>> + '_ { pub fn blocking_write(&self, val: u8) -> impl Future<Output = Result<(), Error>> + '_ {
@ -105,23 +111,15 @@ impl Pipe {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut lock = self.pipe.inner.lock(); let mut lock = self.pipe.inner.lock();
// Try fast path before acquiring write notify to avoid unnecessary contention
if self.pipe.shutdown.load(Ordering::Acquire) {
// TODO BrokenPipe
return Poll::Ready(Err(Error::ReadOnly));
} else if lock.try_write(self.val) {
self.pipe.read_notify.wake();
return Poll::Ready(Ok(()));
}
self.pipe.write_notify.register(cx.waker());
if self.pipe.shutdown.load(Ordering::Acquire) { if self.pipe.shutdown.load(Ordering::Acquire) {
self.pipe.write_notify.remove(cx.waker());
Poll::Ready(Err(Error::ReadOnly)) Poll::Ready(Err(Error::ReadOnly))
} else if lock.try_write(self.val) { } else if lock.try_write(self.val) {
self.pipe.read_notify.wake(); self.pipe.write_notify.remove(cx.waker());
self.pipe.read_notify.wake_one();
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} else { } else {
self.pipe.write_notify.register(cx.waker());
Poll::Pending Poll::Pending
} }
} }
@ -130,37 +128,54 @@ impl Pipe {
F { pipe: self, val } F { pipe: self, val }
} }
pub fn blocking_read(&self) -> impl Future<Output = Option<u8>> + '_ { fn poll_read_end(
struct F<'a> { &self,
pipe: &'a Pipe, cx: &mut Context<'_>,
) -> Poll<Option<IrqSafeSpinlockGuard<'_, PipeInner>>> {
let lock = self.inner.lock();
if lock.can_read() {
self.read_notify.remove(cx.waker());
Poll::Ready(Some(lock))
} else if self.shutdown.load(Ordering::Acquire) {
self.read_notify.remove(cx.waker());
Poll::Ready(None)
} else {
self.read_notify.register(cx.waker());
Poll::Pending
} }
}
impl<'a> Future for F<'a> { async fn read_async(&self, buf: &mut [u8]) -> Result<usize, Error> {
type Output = Option<u8>; let mut pos = 0;
let lock = poll_fn(|cx| self.poll_read_end(cx)).await;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { match lock {
let mut lock = self.pipe.inner.lock(); Some(mut lock) => {
while pos < buf.len()
if let Some(val) = lock.try_read() { && let Some(byte) = lock.try_read()
self.pipe.write_notify.wake(); {
return Poll::Ready(Some(val)); buf[pos] = byte;
} else if self.pipe.shutdown.load(Ordering::Acquire) { pos += 1;
return Poll::Ready(None);
} }
if pos != 0 {
self.pipe.read_notify.register(cx.waker()); self.write_notify.wake_all();
if let Some(val) = lock.try_read() {
Poll::Ready(Some(val))
} else if self.pipe.shutdown.load(Ordering::Acquire) {
Poll::Ready(None)
} else {
Poll::Pending
} }
Ok(pos)
} }
None => Ok(0),
} }
}
}
F { pipe: self } impl FileReadiness for Pipe {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.shutdown.load(Ordering::Acquire) || self.inner.lock().can_read() {
self.read_notify.remove(cx.waker());
Poll::Ready(Ok(()))
} else {
self.read_notify.register(cx.waker());
Poll::Pending
}
} }
} }
@ -173,27 +188,29 @@ impl PipeEnd {
(read, write) (read, write)
} }
pub fn read(&self, buf: &mut [u8]) -> Result<usize, Error> { pub fn read(&self, buf: &mut [u8], nonblocking: bool) -> Result<usize, Error> {
let PipeEnd::Read(read) = self else { let PipeEnd::Read(read) = self else {
return Err(Error::InvalidOperation); return Err(Error::InvalidOperation);
}; };
block! { if nonblocking {
let mut lock = read.inner.lock();
let mut pos = 0; let mut pos = 0;
let mut rem = buf.len(); while pos < buf.len()
&& let Some(ch) = lock.try_read()
while rem != 0 { {
if let Some(val) = read.blocking_read().await { buf[pos] = ch;
buf[pos] = val; pos += 1;
pos += 1;
rem -= 1;
} else {
break;
}
} }
read.write_notify.wake_all();
Ok(pos) if pos == 0 && !read.shutdown.load(Ordering::Acquire) {
}? Err(Error::WouldBlock)
} else {
Ok(pos)
}
} else {
block!(read.read_async(buf).await)?
}
} }
pub fn write(&self, buf: &[u8]) -> Result<usize, Error> { pub fn write(&self, buf: &[u8]) -> Result<usize, Error> {
@ -216,6 +233,15 @@ impl PipeEnd {
} }
} }
impl FileReadiness for PipeEnd {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self {
Self::Read(pipe) => pipe.poll_read(cx),
Self::Write(_) => Poll::Ready(Err(Error::NotImplemented)),
}
}
}
impl Drop for PipeEnd { impl Drop for PipeEnd {
fn drop(&mut self) { fn drop(&mut self) {
match self { match self {

View File

@ -1,8 +1,8 @@
pub(crate) use abi::{ pub(crate) use abi::{
error::Error, error::Error,
io::{ io::{
DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PollControl, RawFd, DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PipeOptions, PollControl,
TerminalOptions, TerminalSize, TimerOptions, UnmountOptions, RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
}, },
mem::MappingSource, mem::MappingSource,
net::SocketType, net::SocketType,

View File

@ -4,8 +4,8 @@ use abi::{
error::Error, error::Error,
io::{ io::{
ChannelPublisherId, DeviceRequest, DirectoryEntry, FileAttr, FileMetadataUpdate, FileMode, ChannelPublisherId, DeviceRequest, DirectoryEntry, FileAttr, FileMetadataUpdate, FileMode,
MessageDestination, OpenOptions, PollControl, RawFd, ReceivedMessageMetadata, SeekFrom, MessageDestination, OpenOptions, PipeOptions, PollControl, RawFd, ReceivedMessageMetadata,
SentMessage, TerminalOptions, TerminalSize, TimerOptions, SeekFrom, SentMessage, TerminalOptions, TerminalSize, TimerOptions,
}, },
process::ProcessId, process::ProcessId,
}; };
@ -299,12 +299,15 @@ pub(crate) fn create_poll_channel() -> Result<RawFd, Error> {
}) })
} }
pub(crate) fn create_pipe(ends: &mut [MaybeUninit<RawFd>; 2]) -> Result<(), Error> { pub(crate) fn create_pipe(
ends: &mut [MaybeUninit<RawFd>; 2],
options: PipeOptions,
) -> Result<(), Error> {
let thread = Thread::current(); let thread = Thread::current();
let process = thread.process(); let process = thread.process();
run_with_io(&process, |mut io| { run_with_io(&process, |mut io| {
let (read, write) = File::new_pipe_pair(256); let (read, write) = File::new_pipe_pair(256, options);
let read_fd = io.files.place_file(read, true)?; let read_fd = io.files.place_file(read, true)?;
let write_fd = io.files.place_file(write, true)?; let write_fd = io.files.place_file(write, true)?;

View File

@ -49,6 +49,13 @@ bitfield OpenOptions(u32) {
CREATE_EXCL: 5, CREATE_EXCL: 5,
} }
bitfield PipeOptions(u32) {
/// If set, read end of the pipe will return WouldBlock if there's no data in the pipe
READ_NONBLOCKING: 0,
/// If set, write end of the pipe will return WouldBlock if the pipe is full
WRITE_NONBLOCKING: 1,
}
enum FileType(u32) { enum FileType(u32) {
/// Regular file /// Regular file
File = 1, File = 1,

View File

@ -105,7 +105,7 @@ syscall create_pid(pid: ProcessId) -> Result<RawFd>;
syscall create_pty(opts: &TerminalOptions, size: &TerminalSize, fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>; syscall create_pty(opts: &TerminalOptions, size: &TerminalSize, fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>;
syscall create_shared_memory(size: usize) -> Result<RawFd>; syscall create_shared_memory(size: usize) -> Result<RawFd>;
syscall create_poll_channel() -> Result<RawFd>; syscall create_poll_channel() -> Result<RawFd>;
syscall create_pipe(fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>; syscall create_pipe(fds: &mut [MaybeUninit<RawFd>; 2], options: PipeOptions) -> Result<()>;
syscall poll_channel_wait( syscall poll_channel_wait(
poll_fd: RawFd, poll_fd: RawFd,

View File

@ -5,8 +5,8 @@ mod input;
mod terminal; mod terminal;
pub use crate::generated::{ pub use crate::generated::{
DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PollControl, DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PipeOptions,
RawFd, TimerOptions, UnmountOptions, UserId, PollControl, RawFd, TimerOptions, UnmountOptions, UserId,
}; };
pub use channel::{ChannelPublisherId, MessageDestination, ReceivedMessageMetadata, SentMessage}; pub use channel::{ChannelPublisherId, MessageDestination, ReceivedMessageMetadata, SentMessage};
pub use device::DeviceRequest; pub use device::DeviceRequest;

View File

@ -23,5 +23,5 @@ pub mod poll {
pub use abi::io::{ pub use abi::io::{
DirectoryEntry, FileAttr, FileMetadataUpdate, FileMetadataUpdateMode, FileMode, FileType, DirectoryEntry, FileAttr, FileMetadataUpdate, FileMetadataUpdateMode, FileMode, FileType,
OpenOptions, RawFd, SeekFrom, TimerOptions, OpenOptions, PipeOptions, RawFd, SeekFrom, TimerOptions,
}; };

View File

@ -17,7 +17,8 @@ mod generated {
error::Error, error::Error,
io::{ io::{
ChannelPublisherId, DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, ChannelPublisherId, DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions,
PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions, PipeOptions, PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions,
UnmountOptions,
}, },
mem::MappingSource, mem::MappingSource,
net::SocketType, net::SocketType,

View File

@ -1,10 +1,15 @@
use std::{ use std::{
io, io::{self, Read, Write},
os::fd::{AsRawFd, RawFd}, os::fd::{AsRawFd, RawFd},
process::Stdio,
time::Duration, time::Duration,
}; };
use crate::sys::{self, Poll as SysPoll, TimerFd as SysTimerFd, PidFd as SysPidFd}; use crate::sys::{
self, PidFd as SysPidFd, Pipe as SysPipe, Poll as SysPoll, TimerFd as SysTimerFd,
};
use self::sys::PipeImpl;
#[repr(transparent)] #[repr(transparent)]
pub struct Poll(sys::PollImpl); pub struct Poll(sys::PollImpl);
@ -15,6 +20,9 @@ pub struct TimerFd(sys::TimerFdImpl);
#[repr(transparent)] #[repr(transparent)]
pub struct PidFd(sys::PidFdImpl); pub struct PidFd(sys::PidFdImpl);
#[repr(transparent)]
pub struct Pipe(sys::PipeImpl);
impl Poll { impl Poll {
pub fn new() -> io::Result<Self> { pub fn new() -> io::Result<Self> {
sys::PollImpl::new().map(Self) sys::PollImpl::new().map(Self)
@ -74,3 +82,36 @@ impl AsRawFd for PidFd {
self.0.as_raw_fd() self.0.as_raw_fd()
} }
} }
impl Pipe {
pub fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
let (read, write) = PipeImpl::new(read_nonblocking, write_nonblocking)?;
Ok((Self(read), Self(write)))
}
pub fn to_child_stdio(&self) -> Stdio {
self.0.to_child_stdio()
}
}
impl Read for Pipe {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for Pipe {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl AsRawFd for Pipe {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}

View File

@ -9,8 +9,9 @@ mod unix;
pub(crate) use unix::*; pub(crate) use unix::*;
use std::{ use std::{
io, io::{self, Read, Write},
os::fd::{AsRawFd, RawFd}, os::fd::{AsRawFd, RawFd},
process::Stdio,
time::Duration, time::Duration,
}; };
@ -31,3 +32,8 @@ pub(crate) trait PidFd: Sized + AsRawFd {
fn new(pid: u32) -> io::Result<Self>; fn new(pid: u32) -> io::Result<Self>;
fn exit_status(&self) -> io::Result<i32>; fn exit_status(&self) -> io::Result<i32>;
} }
pub(crate) trait Pipe: Read + Write + AsRawFd + Sized {
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)>;
fn to_child_stdio(&self) -> Stdio;
}

View File

@ -1,7 +1,9 @@
pub mod poll; pub mod poll;
pub mod timer; pub mod timer;
pub mod pid; pub mod pid;
pub mod pipe;
pub use poll::PollImpl; pub use poll::PollImpl;
pub use timer::TimerFdImpl; pub use timer::TimerFdImpl;
pub use pid::PidFdImpl; pub use pid::PidFdImpl;
pub use pipe::PipeImpl;

View File

@ -10,7 +10,7 @@ pub struct PidFdImpl {
impl PidFd for PidFdImpl { impl PidFd for PidFdImpl {
fn new(pid: u32) -> io::Result<Self> { fn new(pid: u32) -> io::Result<Self> {
let pid = pid as i32; let pid = pid as i32;
let fd = unsafe { libc::pidfd_open(pid) }; let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, 0) } as i32;
if fd < 0 { if fd < 0 {
return Err(io::Error::last_os_error()); return Err(io::Error::last_os_error());
} }
@ -19,7 +19,7 @@ impl PidFd for PidFdImpl {
} }
fn exit_status(&self) -> io::Result<i32> { fn exit_status(&self) -> io::Result<i32> {
let status = 0; let mut status = 0;
let res = unsafe { libc::waitpid(self.pid, &mut status, 0) }; let res = unsafe { libc::waitpid(self.pid, &mut status, 0) };
if res < 0 { if res < 0 {
return Err(io::Error::last_os_error()); return Err(io::Error::last_os_error());

View File

@ -0,0 +1,47 @@
use std::{io::{self, Read, Write}, os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}};
use crate::sys::Pipe;
pub struct PipeImpl {
fd: OwnedFd,
}
impl Pipe for PipeImpl {
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
let mut fds = [0; 2];
let res = unsafe { libc::pipe(fds.as_mut_ptr()) };
if res < 0 {
return Err(io::Error::last_os_error());
}
let read = unsafe { OwnedFd::from_raw_fd(fds[0]) };
let write = unsafe { OwnedFd::from_raw_fd(fds[1]) };
Ok((Self { fd: read }, Self { fd: write }))
}
fn to_child_stdio(&self) -> std::process::Stdio {
todo!()
}
}
impl Read for PipeImpl {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
todo!()
}
}
impl Write for PipeImpl {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
todo!()
}
fn flush(&mut self) -> io::Result<()> {
todo!()
}
}
impl AsRawFd for PipeImpl {
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}

View File

@ -1,7 +1,9 @@
pub mod poll; pub mod poll;
pub mod timer; pub mod timer;
pub mod pid; pub mod pid;
pub mod pipe;
pub use poll::PollImpl; pub use poll::PollImpl;
pub use timer::TimerFdImpl; pub use timer::TimerFdImpl;
pub use pid::PidFdImpl; pub use pid::PidFdImpl;
pub use pipe::PipeImpl;

View File

@ -0,0 +1,50 @@
use std::{
fs::File,
io::{self, Read, Write},
os::{
fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
yggdrasil::io::pipe::create_pipe_pair,
},
process::Stdio,
};
use crate::sys::Pipe;
pub struct PipeImpl {
file: File,
}
impl Pipe for PipeImpl {
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
let (read, write) = create_pipe_pair(read_nonblocking, write_nonblocking)?;
let read = unsafe { File::from_raw_fd(read.into_raw_fd()) };
let write = unsafe { File::from_raw_fd(write.into_raw_fd()) };
Ok((Self { file: read }, Self { file: write }))
}
fn to_child_stdio(&self) -> Stdio {
unsafe { Stdio::from_raw_fd(self.as_raw_fd()) }
}
}
impl Write for PipeImpl {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.file.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.file.flush()
}
}
impl Read for PipeImpl {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.file.read(buf)
}
}
impl AsRawFd for PipeImpl {
fn as_raw_fd(&self) -> RawFd {
self.file.as_raw_fd()
}
}

View File

@ -1,5 +1,5 @@
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))] #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
#![feature(generic_const_exprs, portable_simd, if_let_guard)] #![feature(generic_const_exprs, portable_simd, if_let_guard, let_chains)]
#![allow(incomplete_features)] #![allow(incomplete_features)]
pub mod socket; pub mod socket;

View File

@ -2,7 +2,7 @@
#![feature(let_chains)] #![feature(let_chains)]
use std::{ use std::{
io::{stdout, IsTerminal, Read, Write}, io::{stderr, stdout, IsTerminal, Read, Stderr, Write},
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
os::fd::AsRawFd, os::fd::AsRawFd,
}; };
@ -11,11 +11,11 @@ use clap::Parser;
use cross::io::Poll; use cross::io::Poll;
use rsh::{ use rsh::{
crypt::{ crypt::{
client::{self, ClientSocket}, client::{self, ClientSocket, Message},
config::{ClientConfig, SimpleClientKeyStore}, config::{ClientConfig, SimpleClientKeyStore},
signature::{SignEd25519, SignatureMethod}, signature::{SignEd25519, SignatureMethod},
}, },
proto::ServerMessage, proto::{ServerMessage, StreamIndex},
}; };
use std::{ use std::{
@ -70,8 +70,9 @@ enum Input {
pub struct Client { pub struct Client {
poll: Poll, poll: Poll,
socket: ClientSocket, socket: ClientSocket,
input: Input, stdin: RawStdin,
stdout: Stdout, stdout: Stdout,
stderr: Stderr,
need_bye: bool, need_bye: bool,
last0: u8, last0: u8,
last1: u8, last1: u8,
@ -124,30 +125,14 @@ impl Drop for RawStdin {
} }
impl Client { impl Client {
pub fn connect( pub fn connect(remote: SocketAddr, crypto_config: ClientConfig) -> Result<Self, Error> {
remote: SocketAddr,
crypto_config: ClientConfig,
command: Vec<String>,
) -> Result<Self, Error> {
let mut poll = Poll::new()?; let mut poll = Poll::new()?;
let mut socket = ClientSocket::connect(remote, crypto_config)?; let mut socket = ClientSocket::connect(remote, crypto_config)?;
let input = match command.is_empty() { let stdin = RawStdin::open()?;
true => {
let stdin = RawStdin::open()?;
poll.add(&stdin)?;
Input::Stdin(stdin)
},
false => {
let mut bytes = vec![];
for command in command {
bytes.extend_from_slice(command.as_bytes());
bytes.push(b' ');
}
Input::Command(bytes)
},
};
let stdout = stdout(); let stdout = stdout();
let stderr = stderr();
poll.add(&stdin)?;
poll.add(&socket)?; poll.add(&socket)?;
let info = terminal_info(&stdout)?; let info = terminal_info(&stdout)?;
@ -155,8 +140,9 @@ impl Client {
Self::handshake(&mut socket, info)?; Self::handshake(&mut socket, info)?;
Ok(Self { Ok(Self {
input, stdin,
stdout, stdout,
stderr,
socket, socket,
poll, poll,
need_bye: false, need_bye: false,
@ -187,19 +173,20 @@ impl Client {
pub fn run(mut self) -> Result<String, Error> { pub fn run(mut self) -> Result<String, Error> {
let mut recv_buf = [0; 512]; let mut recv_buf = [0; 512];
if let Input::Command(command) = &self.input {
self.socket.write_all(&ClientMessage::Input(&command[..]))?;
}
loop { loop {
if let Some(message) = self.socket.read(&mut recv_buf)? { if let Some(message) = self.socket.read(&mut recv_buf)? {
match message { match message {
ServerMessage::Bye(reason) => return Ok(reason.into()), ServerMessage::Bye(reason) => return Ok(reason.into()),
ServerMessage::Output(data) => { ServerMessage::Output(StreamIndex::Stdout, data) => {
self.stdout.write_all(data).ok(); self.stdout.write_all(data).ok();
self.stdout.flush().ok(); self.stdout.flush().ok();
continue; continue;
} }
ServerMessage::Output(StreamIndex::Stderr, data) => {
self.stderr.write_all(data).ok();
self.stderr.flush().ok();
continue;
}
_ => continue, _ => continue,
} }
} }
@ -210,8 +197,8 @@ impl Client {
if self.socket.poll()? == 0 { if self.socket.poll()? == 0 {
return Ok("".into()); return Ok("".into());
} }
} else if let Input::Stdin(stdin) = &mut self.input && stdin.as_raw_fd() == fd { } else if self.stdin.as_raw_fd() == fd {
let len = stdin.read(&mut recv_buf)?; let len = self.stdin.read(&mut recv_buf)?;
self.update_last(&recv_buf[..len])?; self.update_last(&recv_buf[..len])?;
self.socket self.socket
.write_all(&ClientMessage::Input(&recv_buf[..len]))?; .write_all(&ClientMessage::Input(&recv_buf[..len]))?;
@ -247,27 +234,101 @@ fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
}) })
} }
fn run(args: Args) -> Result<(), Error> { fn run_terminal(remote: SocketAddr, config: ClientConfig) -> Result<(), Error> {
let reason = Client::connect(remote, config)?.run()?;
if !reason.is_empty() {
eprintln!("\nDisconnected: {reason}");
}
Ok(())
}
fn run_command(
remote: SocketAddr,
config: ClientConfig,
command: Vec<String>,
) -> Result<ExitCode, Error> {
let mut poll = Poll::new()?;
let mut buffer = [0; 512];
let mut command_string = String::new();
for (i, word) in command.iter().enumerate() {
if i != 0 {
command_string.push(' ');
}
command_string.push_str(word);
}
let mut stdin = stdin();
let mut stdout = stdout();
let mut stderr = stderr();
let mut socket = ClientSocket::connect(remote, config)?;
poll.add(&socket)?;
poll.add(&stdin)?;
socket.write_all(&ClientMessage::RunCommand(command_string.as_str()))?;
loop {
let fd = poll.wait(None)?.unwrap();
match fd {
_ if fd == socket.as_raw_fd() => {
let message = match socket.poll_read(&mut buffer)? {
Message::Data(data) => data,
Message::Incomplete => continue,
Message::Closed => break
};
match message {
ServerMessage::Output(StreamIndex::Stdout, output) => {
stdout.write_all(output).ok();
stdout.flush().ok();
}
ServerMessage::Output(StreamIndex::Stderr, output) => {
stderr.write_all(output).ok();
stderr.flush().ok();
}
_ => todo!()
}
}
_ if fd == stdin.as_raw_fd() => {
let len = stdin.read(&mut buffer)?;
if len == 0 {
poll.remove(&stdin)?;
socket.write_all(&ClientMessage::CloseStdin)?;
} else {
socket.write_all(&ClientMessage::Input(&buffer[..len]))?;
}
}
_ => unreachable!()
}
}
Ok(ExitCode::SUCCESS)
}
fn run(args: Args) -> Result<ExitCode, Error> {
let remote = SocketAddr::new(args.remote, args.port); let remote = SocketAddr::new(args.remote, args.port);
let ed25519 = SignEd25519::load_signing_key(args.key).unwrap(); let ed25519 = SignEd25519::load_signing_key(args.key).unwrap();
let key = SignatureMethod::Ed25519(ed25519); let key = SignatureMethod::Ed25519(ed25519);
let config = ClientConfig::with_default_algorithms(SimpleClientKeyStore::new(key)); let config = ClientConfig::with_default_algorithms(SimpleClientKeyStore::new(key));
let reason = Client::connect(remote, config, args.command)?.run()?; if args.command.is_empty() {
if !reason.is_empty() { run_terminal(remote, config).map(|_| ExitCode::SUCCESS)
eprintln!("\nDisconnected: {reason}"); } else {
run_command(remote, config, args.command)
} }
Ok(())
} }
fn main() -> ExitCode { fn main() -> ExitCode {
env_logger::init(); env_logger::init();
let args = Args::parse(); let args = Args::parse();
if let Err(error) = run(args) { match run(args) {
eprintln!("Error: {error}"); Ok(status) => status,
ExitCode::FAILURE Err(error) => {
} else { eprintln!("Error: {error}");
ExitCode::SUCCESS ExitCode::FAILURE
}
} }
} }

View File

@ -13,6 +13,13 @@ pub enum ClientMessage<'a> {
RunCommand(&'a str), RunCommand(&'a str),
Bye(&'a str), Bye(&'a str),
Input(&'a [u8]), Input(&'a [u8]),
CloseStdin,
}
#[derive(Debug, Clone, Copy)]
pub enum StreamIndex {
Stdout,
Stderr,
} }
#[derive(Debug)] #[derive(Debug)]
@ -20,7 +27,7 @@ pub enum ServerMessage<'a> {
SessionOpen, SessionOpen,
CommandStatus(i32), CommandStatus(i32),
Bye(&'a str), Bye(&'a str),
Output(&'a [u8]), Output(StreamIndex, &'a [u8]),
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -164,6 +171,32 @@ impl ClientMessage<'_> {
const TAG_RUN_COMMAND: u8 = 0x81; const TAG_RUN_COMMAND: u8 = 0x81;
const TAG_BYE: u8 = 0x82; const TAG_BYE: u8 = 0x82;
const TAG_INPUT: u8 = 0x90; const TAG_INPUT: u8 = 0x90;
const TAG_CLOSE_STDIN: u8 = 0x91;
}
impl StreamIndex {
const TAG_STDOUT: u8 = 0x01;
const TAG_STDERR: u8 = 0x02;
}
impl Encode for StreamIndex {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::Stdout => buffer.write(&[Self::TAG_STDOUT]),
Self::Stderr => buffer.write(&[Self::TAG_STDERR]),
}
}
}
impl<'de> Decode<'de> for StreamIndex {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_STDOUT => Ok(Self::Stdout),
Self::TAG_STDERR => Ok(Self::Stderr),
_ => Err(DecodeError::InvalidMessage)
}
}
} }
impl Encode for TerminalInfo { impl Encode for TerminalInfo {
@ -201,6 +234,7 @@ impl<'a> Encode for ClientMessage<'a> {
buffer.write(&[Self::TAG_INPUT])?; buffer.write(&[Self::TAG_INPUT])?;
buffer.write_variable_bytes(data) buffer.write_variable_bytes(data)
} }
Self::CloseStdin => buffer.write(&[Self::TAG_CLOSE_STDIN])
} }
} }
} }
@ -221,6 +255,9 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
Self::TAG_INPUT => { Self::TAG_INPUT => {
buffer.read_variable_bytes().map(Self::Input) buffer.read_variable_bytes().map(Self::Input)
} }
Self::TAG_CLOSE_STDIN => {
Ok(Self::CloseStdin)
}
_ => Err(DecodeError::InvalidMessage) _ => Err(DecodeError::InvalidMessage)
} }
} }
@ -244,8 +281,9 @@ impl<'a> Encode for ServerMessage<'a> {
buffer.write(&[Self::TAG_BYE])?; buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason) buffer.write_str(reason)
} }
Self::Output(data) => { Self::Output(index, data) => {
buffer.write(&[Self::TAG_OUTPUT])?; buffer.write(&[Self::TAG_OUTPUT])?;
index.encode(buffer)?;
buffer.write_variable_bytes(data) buffer.write_variable_bytes(data)
} }
} }
@ -267,7 +305,9 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
buffer.read_str().map(Self::Bye) buffer.read_str().map(Self::Bye)
} }
Self::TAG_OUTPUT => { Self::TAG_OUTPUT => {
buffer.read_variable_bytes().map(Self::Output) let index = StreamIndex::decode(buffer)?;
let data = buffer.read_variable_bytes()?;
Ok(Self::Output(index, data))
}, },
_ => Err(DecodeError::InvalidMessage), _ => Err(DecodeError::InvalidMessage),
} }

View File

@ -1,12 +1,15 @@
use std::{ use std::{
collections::{hash_map::Entry, HashMap}, collections::{hash_map::Entry, HashMap},
fmt, io, fmt,
hint::unreachable_unchecked,
io::{self, Read, Stdout, Write},
net::SocketAddr, net::SocketAddr,
os::fd::{AsRawFd, RawFd}, os::fd::{AsRawFd, RawFd},
process::{Child, Command, Stdio},
time::Duration, time::Duration,
}; };
use cross::io::Poll; use cross::io::{PidFd, Pipe, Poll};
use crate::{ use crate::{
crypt::{ crypt::{
@ -14,7 +17,7 @@ use crate::{
config::ServerConfig, config::ServerConfig,
server::{self, ServerSocket}, server::{self, ServerSocket},
}, },
proto::{ClientMessage, ServerMessage, TerminalInfo}, proto::{ClientMessage, ServerMessage, StreamIndex, TerminalInfo},
}; };
pub const PING_INTERVAL: Duration = Duration::from_millis(500); pub const PING_INTERVAL: Duration = Duration::from_millis(500);
@ -52,22 +55,28 @@ pub struct EchoSession {
peer: SocketAddr, peer: SocketAddr,
} }
struct ClientSession<T: Session> { struct PendingCommand {
stream: ClientSocket, child: Child,
session: Option<T>, child_pid: Option<PidFd>,
stdin: Option<Pipe>,
stdout: Option<Pipe>,
stderr: Option<Pipe>,
} }
// enum Event<'b, T: Session> { enum ClientSession<T: Session> {
// NewSession(SocketAddr, TerminalInfo), None,
// SessionInput(u64, SocketAddr, &'b [u8]), Terminal(T),
// ClientBye(SocketAddr, &'b str), Command(PendingCommand),
// SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>), }
// Tick,
// }
struct SessionSet<T: Session> { struct Client<T: Session> {
stream: ClientSocket,
session: ClientSession<T>,
}
struct ClientSet<T: Session> {
last_session_key: u64, last_session_key: u64,
sessions: HashMap<u64, ClientSession<T>>, clients: HashMap<u64, Client<T>>,
socket_fd_map: HashMap<RawFd, u64>, socket_fd_map: HashMap<RawFd, u64>,
session_fd_map: HashMap<RawFd, u64>, session_fd_map: HashMap<RawFd, u64>,
} }
@ -76,29 +85,63 @@ pub struct Server<T: Session> {
poll: Poll, poll: Poll,
socket: ServerSocket, socket: ServerSocket,
sessions: SessionSet<T>, clients: ClientSet<T>,
} }
impl<T: Session> SessionSet<T> { enum CommandEvent {
Output(StreamIndex, Result<usize, io::Error>),
Exited(Result<i32, io::Error>)
}
impl<T: Session> Client<T> {
fn make_command(text: &str) -> Result<PendingCommand, Error> {
let mut words = text.split(' ');
let program = words.next().unwrap();
let mut command = Command::new(program);
let (stdin_read, stdin_write) = Pipe::new(false, false)?;
let (stdout_read, stdout_write) = Pipe::new(true, false)?;
let (stderr_read, stderr_write) = Pipe::new(true, false)?;
command
.args(words)
.stdin(stdin_read.to_child_stdio())
.stdout(stdout_write.to_child_stdio())
.stderr(stderr_write.to_child_stdio());
let child = command.spawn()?;
let child_pid = PidFd::new(child.id())?;
Ok(PendingCommand {
child,
child_pid: Some(child_pid),
stdout: Some(stdout_read),
stderr: Some(stderr_read),
stdin: Some(stdin_write),
})
}
}
impl<T: Session> ClientSet<T> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
last_session_key: 1, last_session_key: 1,
sessions: HashMap::new(), clients: HashMap::new(),
socket_fd_map: HashMap::new(), socket_fd_map: HashMap::new(),
session_fd_map: HashMap::new(), session_fd_map: HashMap::new(),
} }
} }
pub fn add_client(&mut self, stream: ClientSocket, poll: &mut Poll) -> Result<(), Error> { pub fn add_client(&mut self, stream: ClientSocket, poll: &mut Poll) -> Result<(), Error> {
let (key, session) = loop { let (key, client) = loop {
let key = self.last_session_key; let key = self.last_session_key;
self.last_session_key += 1; self.last_session_key += 1;
match self.sessions.entry(key) { match self.clients.entry(key) {
Entry::Vacant(entry) => { Entry::Vacant(entry) => {
let session = entry.insert(ClientSession { let session = entry.insert(Client {
stream, stream,
session: None, session: ClientSession::None,
}); });
break (key, session); break (key, session);
} }
@ -106,22 +149,39 @@ impl<T: Session> SessionSet<T> {
} }
}; };
poll.add(&session.stream)?; poll.add(&client.stream)?;
self.socket_fd_map.insert(session.stream.as_raw_fd(), key); self.socket_fd_map.insert(client.stream.as_raw_fd(), key);
Ok(()) Ok(())
} }
fn remove(&mut self, key: u64, poll: &mut Poll) -> Result<(), Error> { fn remove(&mut self, key: u64, poll: &mut Poll) -> Result<(), Error> {
if let Some(session) = self.sessions.remove(&key) { if let Some(client) = self.clients.remove(&key) {
poll.remove(&session.stream)?; poll.remove(&client.stream)?;
self.socket_fd_map.remove(&session.stream.as_raw_fd()); self.socket_fd_map.remove(&client.stream.as_raw_fd());
if let Some(session) = session.session { match client.session {
for fd in session.event_fds() { ClientSession::Terminal(terminal) => {
poll.remove(fd)?; for fd in terminal.event_fds() {
self.session_fd_map.remove(fd); poll.remove(fd)?;
self.session_fd_map.remove(fd);
}
} }
ClientSession::Command(mut command) => {
if let Some(stdout) = command.stdout.take() {
poll.remove(&stdout)?;
self.session_fd_map.remove(&stdout.as_raw_fd());
}
if let Some(stderr) = command.stderr.take() {
poll.remove(&stderr)?;
self.session_fd_map.remove(&stderr.as_raw_fd());
}
if let Some(child_pid) = command.child_pid.take() {
poll.remove(&child_pid)?;
self.session_fd_map.remove(&child_pid.as_raw_fd());
}
}
ClientSession::None => (),
} }
} }
Ok(()) Ok(())
@ -131,10 +191,10 @@ impl<T: Session> SessionSet<T> {
let mut buffer = [0; 512]; let mut buffer = [0; 512];
if let Some(&key) = self.socket_fd_map.get(&fd) { if let Some(&key) = self.socket_fd_map.get(&fd) {
let session = self.sessions.get_mut(&key).unwrap(); let client = self.clients.get_mut(&key).unwrap();
let peer = session.stream.remote_address(); let peer = client.stream.remote_address();
let mut closed = match session.stream.poll() { let mut closed = match client.stream.poll() {
Ok(0) => true, Ok(0) => true,
Ok(_) => false, Ok(_) => false,
Err(error) => { Err(error) => {
@ -144,7 +204,7 @@ impl<T: Session> SessionSet<T> {
}; };
loop { loop {
let message = match session.stream.read(&mut buffer) { let message = match client.stream.read(&mut buffer) {
Ok(Some(message)) => message, Ok(Some(message)) => message,
Ok(None) => break, Ok(None) => break,
Err(error) => { Err(error) => {
@ -154,8 +214,33 @@ impl<T: Session> SessionSet<T> {
} }
}; };
match (&message, &mut session.session) { match (&message, &mut client.session) {
(ClientMessage::OpenSession(terminal), None) => { (ClientMessage::RunCommand(command), ClientSession::None) => {
log::info!("{peer}: cmd {command:?}");
let command = match Client::<T>::make_command(command) {
Ok(command) => command,
Err(error) => {
log::error!("{peer}: command error: {error}");
return self.remove(key, poll);
}
};
let stdout_fd = command.stdout.as_ref().unwrap().as_raw_fd();
let stderr_fd = command.stderr.as_ref().unwrap().as_raw_fd();
let child_pid_fd = command.child_pid.as_ref().unwrap().as_raw_fd();
poll.add(&stdout_fd)?;
poll.add(&stderr_fd)?;
poll.add(&child_pid_fd)?;
self.session_fd_map.insert(stdout_fd, key);
self.session_fd_map.insert(stderr_fd, key);
self.session_fd_map.insert(child_pid_fd, key);
client.session = ClientSession::Command(command);
}
(ClientMessage::OpenSession(terminal), ClientSession::None) => {
log::info!("{peer}: new session"); log::info!("{peer}: new session");
let terminal = match T::open(&peer, &terminal) { let terminal = match T::open(&peer, &terminal) {
Ok(session) => session, Ok(session) => session,
@ -165,15 +250,33 @@ impl<T: Session> SessionSet<T> {
return Ok(()); return Ok(());
} }
}; };
let terminal = session.session.insert(terminal); let terminal = client.session.set_terminal(terminal);
for fd in terminal.event_fds() { for fd in terminal.event_fds() {
poll.add(fd)?; poll.add(fd)?;
self.session_fd_map.insert(*fd, key); self.session_fd_map.insert(*fd, key);
} }
session.stream.write_all(&ServerMessage::SessionOpen).ok();
client.stream.write_all(&ServerMessage::SessionOpen).ok();
} }
(ClientMessage::Input(data), Some(terminal)) => { (ClientMessage::Input(data), ClientSession::Command(command)) => {
let client = SessionClient { transport: &mut session.stream }; if let Some(stdin) = command.stdin.as_mut() {
match stdin.write_all(data) {
Ok(()) => {
stdin.flush().ok();
},
Err(error) => {
log::error!("{peer}: stdin error: {error}");
return self.remove(key, poll);
}
}
} else {
log::warn!("{peer}: stdin already closed");
}
}
(ClientMessage::Input(data), ClientSession::Terminal(terminal)) => {
let client = SessionClient {
transport: &mut client.stream,
};
match terminal.handle_input(data, client) { match terminal.handle_input(data, client) {
Ok(false) => (), Ok(false) => (),
Ok(true) => { Ok(true) => {
@ -186,6 +289,10 @@ impl<T: Session> SessionSet<T> {
} }
} }
} }
(ClientMessage::CloseStdin, ClientSession::Command(command)) => {
// Drop stdin handle
command.stdin.take();
}
_ => { _ => {
log::warn!("{peer}: unhandled message"); log::warn!("{peer}: unhandled message");
} }
@ -199,45 +306,148 @@ impl<T: Session> SessionSet<T> {
Ok(()) Ok(())
} else if let Some(&key) = self.session_fd_map.get(&fd) { } else if let Some(&key) = self.session_fd_map.get(&fd) {
let session = self.sessions.get_mut(&key).unwrap(); log::debug!("poll fd {:?}", fd);
let terminal = session.session.as_mut().unwrap(); let client = self.clients.get_mut(&key).unwrap();
let peer = session.stream.remote_address(); let peer = client.stream.remote_address();
match terminal.read_output(fd, &mut buffer) { let (mut len, stream_index) = match &mut client.session {
Ok(0) => { ClientSession::Command(command) => match command.read_output(fd, &mut buffer) {
log::info!("{peer}: session closed"); CommandEvent::Output(index, Ok(len)) => {
self.remove(key, poll) if len == 0 {
} poll.remove(&fd)?;
Ok(mut len) => { self.socket_fd_map.remove(&fd);
// Split output into 128-byte chunks
let mut pos = 0;
while len != 0 {
let amount = core::cmp::min(len, 128);
if let Err(error) = session match index {
.stream StreamIndex::Stdout => {
.write_all(&ServerMessage::Output(&buffer[pos..pos + amount])) command.stdout = None;
{ }
log::error!("{peer}: communication error: {error}"); StreamIndex::Stderr => {
command.stderr = None;
}
}
if command.is_dead() {
log::info!("{peer}: command stdout/stderr closed");
return self.remove(key, poll);
}
}
(len, index)
},
CommandEvent::Output(index, Err(error)) => {
if error.kind() == io::ErrorKind::WouldBlock {
return Ok(());
}
log::error!("{peer}: {index:?} error: {error}");
return self.remove(key, poll);
}
CommandEvent::Exited(status) => {
log::info!("{peer}: command exited: {:?}", status);
poll.remove(&fd)?;
self.socket_fd_map.remove(&fd);
command.child_pid = None;
if command.is_dead() {
log::info!("{peer}: command stdout/stderr closed");
return self.remove(key, poll); return self.remove(key, poll);
} }
pos += amount; return Ok(());
len -= amount;
} }
Ok(()) },
} ClientSession::Terminal(terminal) => match terminal.read_output(fd, &mut buffer) {
Err(error) => { Ok(0) => {
log::error!("{peer}: session read error: {error}"); poll.remove(&fd)?;
self.remove(key, poll) self.socket_fd_map.remove(&fd);
}
// TODO check for process as well
log::info!("{peer}: terminal closed");
return self.remove(key, poll);
}
Ok(len) => {
(len, StreamIndex::Stdout)
},
Err(error) => {
log::error!("{peer}: terminal error: {error}");
return self.remove(key, poll);
}
},
ClientSession::None => unreachable!(),
};
if len == 0 {
log::info!("{peer}: {stream_index:?} closed");
return Ok(());
// return self.remove(key, poll);
} }
// Split output into 128-byte chunks
let mut pos = 0;
while len != 0 {
let amount = core::cmp::min(len, 128);
if let Err(error) = client.stream.write_all(&ServerMessage::Output(
stream_index,
&buffer[pos..pos + amount],
)) {
log::error!("{peer}: communication error: {error}");
return self.remove(key, poll);
}
pos += amount;
len -= amount;
}
log::debug!("Done");
Ok(())
} else { } else {
unreachable!() unreachable!()
} }
} }
} }
impl PendingCommand {
pub fn read_output(
&mut self,
fd: RawFd,
buffer: &mut [u8],
) -> CommandEvent {
if let Some(stdout) = self.stdout.as_mut() && fd == stdout.as_raw_fd() {
log::debug!("poll stdout");
let res = stdout.read(buffer);
log::debug!(">> {:?}", res);
return CommandEvent::Output(StreamIndex::Stdout, res);
}
if let Some(stderr) = self.stderr.as_mut() && fd == stderr.as_raw_fd() {
return CommandEvent::Output(StreamIndex::Stderr, stderr.read(buffer));
}
if let Some(child_pid) = self.child_pid.as_mut() && fd == child_pid.as_raw_fd() {
let status = child_pid.exit_status();
return CommandEvent::Exited(status);
}
unreachable!()
}
pub fn is_dead(&self) -> bool {
self.child_pid.is_none() && self.stdout.is_none() && self.stderr.is_none()
}
}
impl Drop for PendingCommand {
fn drop(&mut self) {
self.child.wait().ok();
}
}
impl<T: Session> ClientSession<T> {
pub fn set_terminal(&mut self, terminal: T) -> &mut T {
*self = Self::Terminal(terminal);
match self {
Self::Terminal(terminal) => terminal,
_ => unsafe { unreachable_unchecked() },
}
}
}
impl<T: Session> Server<T> { impl<T: Session> Server<T> {
pub fn listen(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> { pub fn listen(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> {
let mut poll = Poll::new()?; let mut poll = Poll::new()?;
@ -246,7 +456,7 @@ impl<T: Session> Server<T> {
Ok(Self { Ok(Self {
poll, poll,
socket, socket,
sessions: SessionSet::new(), clients: ClientSet::new(),
}) })
} }
@ -259,14 +469,14 @@ impl<T: Session> Server<T> {
// Poll server socket // Poll server socket
match self.socket.poll()? { match self.socket.poll()? {
Some(client) => { Some(client) => {
self.sessions.add_client(client, &mut self.poll)?; self.clients.add_client(client, &mut self.poll)?;
} }
None => continue, None => continue,
} }
} }
_ => { _ => {
// Client/Session-related activity // Client/Session-related activity
self.sessions.handle_input(fd, &mut self.poll)?; self.clients.handle_input(fd, &mut self.poll)?;
} }
} }
} }
@ -275,7 +485,7 @@ impl<T: Session> Server<T> {
impl<'s> SessionClient<'s> { impl<'s> SessionClient<'s> {
pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> { pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
self.send_message(&ServerMessage::Output(data)) self.send_message(&ServerMessage::Output(StreamIndex::Stdout, data))
} }
pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> { pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> {

View File

@ -74,7 +74,7 @@ pub fn wait_for_pipeline(
} }
pub fn create_pipe() -> Result<Pipe, io::Error> { pub fn create_pipe() -> Result<Pipe, io::Error> {
let (read, write) = pipe::create_pipe_pair()?; let (read, write) = pipe::create_pipe_pair(false, false)?;
Ok(Pipe { read, write }) Ok(Pipe { read, write })
} }