diff --git a/kernel/libk/src/vfs/file/mod.rs b/kernel/libk/src/vfs/file/mod.rs index 5274d4e0..8323c323 100644 --- a/kernel/libk/src/vfs/file/mod.rs +++ b/kernel/libk/src/vfs/file/mod.rs @@ -17,8 +17,8 @@ use libk_util::sync::IrqSafeSpinlock; use yggdrasil_abi::{ error::Error, io::{ - DeviceRequest, DirectoryEntry, OpenOptions, RawFd, SeekFrom, TerminalOptions, TerminalSize, - TimerOptions, + DeviceRequest, DirectoryEntry, OpenOptions, PipeOptions, RawFd, SeekFrom, TerminalOptions, + TerminalSize, TimerOptions, }, }; @@ -71,7 +71,7 @@ pub enum File { Block(BlockFile), Char(CharFile), Socket(SocketWrapper), - AnonymousPipe(PipeEnd), + AnonymousPipe(PipeEnd, AtomicBool), Poll(FdPoll), Timer(TimerFile), Channel(ChannelDescriptor), @@ -103,11 +103,17 @@ pub struct FileSet { impl File { /// Constructs a pipe pair, returning its `(read, write)` ends - pub fn new_pipe_pair(capacity: usize) -> (Arc, Arc) { + pub fn new_pipe_pair(capacity: usize, options: PipeOptions) -> (Arc, Arc) { let (read, write) = PipeEnd::new_pair(capacity); ( - Arc::new(Self::AnonymousPipe(read)), - Arc::new(Self::AnonymousPipe(write)), + Arc::new(Self::AnonymousPipe( + read, + AtomicBool::new(options.contains(PipeOptions::READ_NONBLOCKING)), + )), + Arc::new(Self::AnonymousPipe( + write, + AtomicBool::new(options.contains(PipeOptions::WRITE_NONBLOCKING)), + )), ) } @@ -289,6 +295,7 @@ impl File { Self::Socket(socket) => socket.poll_read(cx), Self::Timer(timer) => timer.poll_read(cx), Self::Pid(pid) => pid.poll_read(cx), + Self::AnonymousPipe(pipe, _) => pipe.poll_read(cx), // Polling not implemented, return ready immediately (XXX ?) _ => Poll::Ready(Err(Error::NotImplemented)), } @@ -365,7 +372,9 @@ impl Read for File { Self::Regular(file) => file.read(buf), Self::Block(file) => file.read(buf), Self::Char(file) => file.read(buf), - Self::AnonymousPipe(pipe) => pipe.read(buf), + Self::AnonymousPipe(pipe, nonblocking) => { + pipe.read(buf, nonblocking.load(Ordering::Acquire)) + } Self::PtySlave(half) => half.read(buf), Self::PtyMaster(half) => half.read(buf), Self::Timer(timer) => timer.read(buf), @@ -388,7 +397,7 @@ impl Write for File { Self::Regular(file) => file.write(buf), Self::Block(file) => file.write(buf), Self::Char(file) => file.write(buf), - Self::AnonymousPipe(pipe) => pipe.write(buf), + Self::AnonymousPipe(pipe, _) => pipe.write(buf), Self::PtySlave(half) => half.write(buf), Self::PtyMaster(half) => half.write(buf), Self::Timer(timer) => timer.write(buf), @@ -447,7 +456,7 @@ impl fmt::Debug for File { .field("write", &file.write) .finish_non_exhaustive(), Self::Directory(_) => f.debug_struct("DirectoryFile").finish_non_exhaustive(), - Self::AnonymousPipe(_) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(), + Self::AnonymousPipe(_, _) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(), Self::Poll(_) => f.debug_struct("Poll").finish_non_exhaustive(), Self::Channel(_) => f.debug_struct("Channel").finish_non_exhaustive(), Self::SharedMemory(_) => f.debug_struct("SharedMemory").finish_non_exhaustive(), diff --git a/kernel/libk/src/vfs/file/pipe.rs b/kernel/libk/src/vfs/file/pipe.rs index 9743f638..bcff53bd 100644 --- a/kernel/libk/src/vfs/file/pipe.rs +++ b/kernel/libk/src/vfs/file/pipe.rs @@ -1,4 +1,5 @@ use core::{ + future::poll_fn, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::{Context, Poll}, @@ -6,9 +7,14 @@ use core::{ use alloc::{sync::Arc, vec, vec::Vec}; use futures_util::{task::AtomicWaker, Future}; -use libk_util::sync::IrqSafeSpinlock; +use libk_util::{ + sync::{IrqSafeSpinlock, IrqSafeSpinlockGuard}, + waker::QueueWaker, +}; use yggdrasil_abi::error::Error; +use crate::vfs::FileReadiness; + struct PipeInner { data: Vec, capacity: usize, @@ -19,8 +25,8 @@ struct PipeInner { pub struct Pipe { inner: IrqSafeSpinlock, shutdown: AtomicBool, - read_notify: AtomicWaker, - write_notify: AtomicWaker, + read_notify: QueueWaker, + write_notify: QueueWaker, } pub enum PipeEnd { @@ -82,15 +88,15 @@ impl Pipe { Self { inner: IrqSafeSpinlock::new(PipeInner::new(capacity)), shutdown: AtomicBool::new(false), - read_notify: AtomicWaker::new(), - write_notify: AtomicWaker::new(), + read_notify: QueueWaker::new(), + write_notify: QueueWaker::new(), } } pub fn shutdown(&self) { self.shutdown.store(true, Ordering::Release); - self.read_notify.wake(); - self.write_notify.wake(); + self.read_notify.wake_all(); + self.write_notify.wake_all(); } pub fn blocking_write(&self, val: u8) -> impl Future> + '_ { @@ -105,23 +111,15 @@ impl Pipe { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut lock = self.pipe.inner.lock(); - // Try fast path before acquiring write notify to avoid unnecessary contention - if self.pipe.shutdown.load(Ordering::Acquire) { - // TODO BrokenPipe - return Poll::Ready(Err(Error::ReadOnly)); - } else if lock.try_write(self.val) { - self.pipe.read_notify.wake(); - return Poll::Ready(Ok(())); - } - - self.pipe.write_notify.register(cx.waker()); - if self.pipe.shutdown.load(Ordering::Acquire) { + self.pipe.write_notify.remove(cx.waker()); Poll::Ready(Err(Error::ReadOnly)) } else if lock.try_write(self.val) { - self.pipe.read_notify.wake(); + self.pipe.write_notify.remove(cx.waker()); + self.pipe.read_notify.wake_one(); Poll::Ready(Ok(())) } else { + self.pipe.write_notify.register(cx.waker()); Poll::Pending } } @@ -130,37 +128,54 @@ impl Pipe { F { pipe: self, val } } - pub fn blocking_read(&self) -> impl Future> + '_ { - struct F<'a> { - pipe: &'a Pipe, + fn poll_read_end( + &self, + cx: &mut Context<'_>, + ) -> Poll>> { + let lock = self.inner.lock(); + + if lock.can_read() { + self.read_notify.remove(cx.waker()); + Poll::Ready(Some(lock)) + } else if self.shutdown.load(Ordering::Acquire) { + self.read_notify.remove(cx.waker()); + Poll::Ready(None) + } else { + self.read_notify.register(cx.waker()); + Poll::Pending } + } - impl<'a> Future for F<'a> { - type Output = Option; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut lock = self.pipe.inner.lock(); - - if let Some(val) = lock.try_read() { - self.pipe.write_notify.wake(); - return Poll::Ready(Some(val)); - } else if self.pipe.shutdown.load(Ordering::Acquire) { - return Poll::Ready(None); + async fn read_async(&self, buf: &mut [u8]) -> Result { + let mut pos = 0; + let lock = poll_fn(|cx| self.poll_read_end(cx)).await; + match lock { + Some(mut lock) => { + while pos < buf.len() + && let Some(byte) = lock.try_read() + { + buf[pos] = byte; + pos += 1; } - - self.pipe.read_notify.register(cx.waker()); - - if let Some(val) = lock.try_read() { - Poll::Ready(Some(val)) - } else if self.pipe.shutdown.load(Ordering::Acquire) { - Poll::Ready(None) - } else { - Poll::Pending + if pos != 0 { + self.write_notify.wake_all(); } + Ok(pos) } + None => Ok(0), } + } +} - F { pipe: self } +impl FileReadiness for Pipe { + fn poll_read(&self, cx: &mut Context<'_>) -> Poll> { + if self.shutdown.load(Ordering::Acquire) || self.inner.lock().can_read() { + self.read_notify.remove(cx.waker()); + Poll::Ready(Ok(())) + } else { + self.read_notify.register(cx.waker()); + Poll::Pending + } } } @@ -173,27 +188,29 @@ impl PipeEnd { (read, write) } - pub fn read(&self, buf: &mut [u8]) -> Result { + pub fn read(&self, buf: &mut [u8], nonblocking: bool) -> Result { let PipeEnd::Read(read) = self else { return Err(Error::InvalidOperation); }; - block! { + if nonblocking { + let mut lock = read.inner.lock(); let mut pos = 0; - let mut rem = buf.len(); - - while rem != 0 { - if let Some(val) = read.blocking_read().await { - buf[pos] = val; - pos += 1; - rem -= 1; - } else { - break; - } + while pos < buf.len() + && let Some(ch) = lock.try_read() + { + buf[pos] = ch; + pos += 1; } - - Ok(pos) - }? + read.write_notify.wake_all(); + if pos == 0 && !read.shutdown.load(Ordering::Acquire) { + Err(Error::WouldBlock) + } else { + Ok(pos) + } + } else { + block!(read.read_async(buf).await)? + } } pub fn write(&self, buf: &[u8]) -> Result { @@ -216,6 +233,15 @@ impl PipeEnd { } } +impl FileReadiness for PipeEnd { + fn poll_read(&self, cx: &mut Context<'_>) -> Poll> { + match self { + Self::Read(pipe) => pipe.poll_read(cx), + Self::Write(_) => Poll::Ready(Err(Error::NotImplemented)), + } + } +} + impl Drop for PipeEnd { fn drop(&mut self) { match self { diff --git a/kernel/src/syscall/imp/mod.rs b/kernel/src/syscall/imp/mod.rs index 22bfbd27..459a98ac 100644 --- a/kernel/src/syscall/imp/mod.rs +++ b/kernel/src/syscall/imp/mod.rs @@ -1,8 +1,8 @@ pub(crate) use abi::{ error::Error, io::{ - DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PollControl, RawFd, - TerminalOptions, TerminalSize, TimerOptions, UnmountOptions, + DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PipeOptions, PollControl, + RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions, }, mem::MappingSource, net::SocketType, diff --git a/kernel/src/syscall/imp/sys_io.rs b/kernel/src/syscall/imp/sys_io.rs index 5648ad91..cc513578 100644 --- a/kernel/src/syscall/imp/sys_io.rs +++ b/kernel/src/syscall/imp/sys_io.rs @@ -4,8 +4,8 @@ use abi::{ error::Error, io::{ ChannelPublisherId, DeviceRequest, DirectoryEntry, FileAttr, FileMetadataUpdate, FileMode, - MessageDestination, OpenOptions, PollControl, RawFd, ReceivedMessageMetadata, SeekFrom, - SentMessage, TerminalOptions, TerminalSize, TimerOptions, + MessageDestination, OpenOptions, PipeOptions, PollControl, RawFd, ReceivedMessageMetadata, + SeekFrom, SentMessage, TerminalOptions, TerminalSize, TimerOptions, }, process::ProcessId, }; @@ -299,12 +299,15 @@ pub(crate) fn create_poll_channel() -> Result { }) } -pub(crate) fn create_pipe(ends: &mut [MaybeUninit; 2]) -> Result<(), Error> { +pub(crate) fn create_pipe( + ends: &mut [MaybeUninit; 2], + options: PipeOptions, +) -> Result<(), Error> { let thread = Thread::current(); let process = thread.process(); run_with_io(&process, |mut io| { - let (read, write) = File::new_pipe_pair(256); + let (read, write) = File::new_pipe_pair(256, options); let read_fd = io.files.place_file(read, true)?; let write_fd = io.files.place_file(write, true)?; diff --git a/lib/abi/def/io.abi b/lib/abi/def/io.abi index ac963b80..1ca533e7 100644 --- a/lib/abi/def/io.abi +++ b/lib/abi/def/io.abi @@ -49,6 +49,13 @@ bitfield OpenOptions(u32) { CREATE_EXCL: 5, } +bitfield PipeOptions(u32) { + /// If set, read end of the pipe will return WouldBlock if there's no data in the pipe + READ_NONBLOCKING: 0, + /// If set, write end of the pipe will return WouldBlock if the pipe is full + WRITE_NONBLOCKING: 1, +} + enum FileType(u32) { /// Regular file File = 1, diff --git a/lib/abi/def/yggdrasil.abi b/lib/abi/def/yggdrasil.abi index 13ec757a..c4d97ebb 100644 --- a/lib/abi/def/yggdrasil.abi +++ b/lib/abi/def/yggdrasil.abi @@ -105,7 +105,7 @@ syscall create_pid(pid: ProcessId) -> Result; syscall create_pty(opts: &TerminalOptions, size: &TerminalSize, fds: &mut [MaybeUninit; 2]) -> Result<()>; syscall create_shared_memory(size: usize) -> Result; syscall create_poll_channel() -> Result; -syscall create_pipe(fds: &mut [MaybeUninit; 2]) -> Result<()>; +syscall create_pipe(fds: &mut [MaybeUninit; 2], options: PipeOptions) -> Result<()>; syscall poll_channel_wait( poll_fd: RawFd, diff --git a/lib/abi/src/io/mod.rs b/lib/abi/src/io/mod.rs index 65f605b7..bfe31cd9 100644 --- a/lib/abi/src/io/mod.rs +++ b/lib/abi/src/io/mod.rs @@ -5,8 +5,8 @@ mod input; mod terminal; pub use crate::generated::{ - DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PollControl, - RawFd, TimerOptions, UnmountOptions, UserId, + DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PipeOptions, + PollControl, RawFd, TimerOptions, UnmountOptions, UserId, }; pub use channel::{ChannelPublisherId, MessageDestination, ReceivedMessageMetadata, SentMessage}; pub use device::DeviceRequest; diff --git a/lib/runtime/src/io.rs b/lib/runtime/src/io.rs index 8b7d63f6..50eea59c 100644 --- a/lib/runtime/src/io.rs +++ b/lib/runtime/src/io.rs @@ -23,5 +23,5 @@ pub mod poll { pub use abi::io::{ DirectoryEntry, FileAttr, FileMetadataUpdate, FileMetadataUpdateMode, FileMode, FileType, - OpenOptions, RawFd, SeekFrom, TimerOptions, + OpenOptions, PipeOptions, RawFd, SeekFrom, TimerOptions, }; diff --git a/lib/runtime/src/sys/mod.rs b/lib/runtime/src/sys/mod.rs index 551591cc..23b0d9ba 100644 --- a/lib/runtime/src/sys/mod.rs +++ b/lib/runtime/src/sys/mod.rs @@ -17,7 +17,8 @@ mod generated { error::Error, io::{ ChannelPublisherId, DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, - PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions, + PipeOptions, PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions, + UnmountOptions, }, mem::MappingSource, net::SocketType, diff --git a/userspace/lib/cross/src/io.rs b/userspace/lib/cross/src/io.rs index 755fba60..ab2336b7 100644 --- a/userspace/lib/cross/src/io.rs +++ b/userspace/lib/cross/src/io.rs @@ -1,10 +1,15 @@ use std::{ - io, + io::{self, Read, Write}, os::fd::{AsRawFd, RawFd}, + process::Stdio, time::Duration, }; -use crate::sys::{self, Poll as SysPoll, TimerFd as SysTimerFd, PidFd as SysPidFd}; +use crate::sys::{ + self, PidFd as SysPidFd, Pipe as SysPipe, Poll as SysPoll, TimerFd as SysTimerFd, +}; + +use self::sys::PipeImpl; #[repr(transparent)] pub struct Poll(sys::PollImpl); @@ -15,6 +20,9 @@ pub struct TimerFd(sys::TimerFdImpl); #[repr(transparent)] pub struct PidFd(sys::PidFdImpl); +#[repr(transparent)] +pub struct Pipe(sys::PipeImpl); + impl Poll { pub fn new() -> io::Result { sys::PollImpl::new().map(Self) @@ -74,3 +82,36 @@ impl AsRawFd for PidFd { self.0.as_raw_fd() } } + +impl Pipe { + pub fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> { + let (read, write) = PipeImpl::new(read_nonblocking, write_nonblocking)?; + Ok((Self(read), Self(write))) + } + + pub fn to_child_stdio(&self) -> Stdio { + self.0.to_child_stdio() + } +} + +impl Read for Pipe { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl Write for Pipe { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +impl AsRawFd for Pipe { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} diff --git a/userspace/lib/cross/src/sys/mod.rs b/userspace/lib/cross/src/sys/mod.rs index e64a31e4..7f1ad175 100644 --- a/userspace/lib/cross/src/sys/mod.rs +++ b/userspace/lib/cross/src/sys/mod.rs @@ -9,8 +9,9 @@ mod unix; pub(crate) use unix::*; use std::{ - io, + io::{self, Read, Write}, os::fd::{AsRawFd, RawFd}, + process::Stdio, time::Duration, }; @@ -31,3 +32,8 @@ pub(crate) trait PidFd: Sized + AsRawFd { fn new(pid: u32) -> io::Result; fn exit_status(&self) -> io::Result; } + +pub(crate) trait Pipe: Read + Write + AsRawFd + Sized { + fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)>; + fn to_child_stdio(&self) -> Stdio; +} diff --git a/userspace/lib/cross/src/sys/unix/mod.rs b/userspace/lib/cross/src/sys/unix/mod.rs index a1837498..5770805a 100644 --- a/userspace/lib/cross/src/sys/unix/mod.rs +++ b/userspace/lib/cross/src/sys/unix/mod.rs @@ -1,7 +1,9 @@ pub mod poll; pub mod timer; pub mod pid; +pub mod pipe; pub use poll::PollImpl; pub use timer::TimerFdImpl; pub use pid::PidFdImpl; +pub use pipe::PipeImpl; diff --git a/userspace/lib/cross/src/sys/unix/pid.rs b/userspace/lib/cross/src/sys/unix/pid.rs index f45c4621..ebed1cf2 100644 --- a/userspace/lib/cross/src/sys/unix/pid.rs +++ b/userspace/lib/cross/src/sys/unix/pid.rs @@ -10,7 +10,7 @@ pub struct PidFdImpl { impl PidFd for PidFdImpl { fn new(pid: u32) -> io::Result { let pid = pid as i32; - let fd = unsafe { libc::pidfd_open(pid) }; + let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, 0) } as i32; if fd < 0 { return Err(io::Error::last_os_error()); } @@ -19,7 +19,7 @@ impl PidFd for PidFdImpl { } fn exit_status(&self) -> io::Result { - let status = 0; + let mut status = 0; let res = unsafe { libc::waitpid(self.pid, &mut status, 0) }; if res < 0 { return Err(io::Error::last_os_error()); diff --git a/userspace/lib/cross/src/sys/unix/pipe.rs b/userspace/lib/cross/src/sys/unix/pipe.rs new file mode 100644 index 00000000..1f25fc8a --- /dev/null +++ b/userspace/lib/cross/src/sys/unix/pipe.rs @@ -0,0 +1,47 @@ +use std::{io::{self, Read, Write}, os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}}; + +use crate::sys::Pipe; + +pub struct PipeImpl { + fd: OwnedFd, +} + +impl Pipe for PipeImpl { + fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> { + let mut fds = [0; 2]; + let res = unsafe { libc::pipe(fds.as_mut_ptr()) }; + if res < 0 { + return Err(io::Error::last_os_error()); + } + let read = unsafe { OwnedFd::from_raw_fd(fds[0]) }; + let write = unsafe { OwnedFd::from_raw_fd(fds[1]) }; + + Ok((Self { fd: read }, Self { fd: write })) + } + + fn to_child_stdio(&self) -> std::process::Stdio { + todo!() + } +} + +impl Read for PipeImpl { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + todo!() + } +} + +impl Write for PipeImpl { + fn write(&mut self, buf: &[u8]) -> io::Result { + todo!() + } + + fn flush(&mut self) -> io::Result<()> { + todo!() + } +} + +impl AsRawFd for PipeImpl { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} diff --git a/userspace/lib/cross/src/sys/yggdrasil/mod.rs b/userspace/lib/cross/src/sys/yggdrasil/mod.rs index a1837498..5770805a 100644 --- a/userspace/lib/cross/src/sys/yggdrasil/mod.rs +++ b/userspace/lib/cross/src/sys/yggdrasil/mod.rs @@ -1,7 +1,9 @@ pub mod poll; pub mod timer; pub mod pid; +pub mod pipe; pub use poll::PollImpl; pub use timer::TimerFdImpl; pub use pid::PidFdImpl; +pub use pipe::PipeImpl; diff --git a/userspace/lib/cross/src/sys/yggdrasil/pipe.rs b/userspace/lib/cross/src/sys/yggdrasil/pipe.rs new file mode 100644 index 00000000..e12b0b80 --- /dev/null +++ b/userspace/lib/cross/src/sys/yggdrasil/pipe.rs @@ -0,0 +1,50 @@ +use std::{ + fs::File, + io::{self, Read, Write}, + os::{ + fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, + yggdrasil::io::pipe::create_pipe_pair, + }, + process::Stdio, +}; + +use crate::sys::Pipe; + +pub struct PipeImpl { + file: File, +} + +impl Pipe for PipeImpl { + fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> { + let (read, write) = create_pipe_pair(read_nonblocking, write_nonblocking)?; + let read = unsafe { File::from_raw_fd(read.into_raw_fd()) }; + let write = unsafe { File::from_raw_fd(write.into_raw_fd()) }; + Ok((Self { file: read }, Self { file: write })) + } + + fn to_child_stdio(&self) -> Stdio { + unsafe { Stdio::from_raw_fd(self.as_raw_fd()) } + } +} + +impl Write for PipeImpl { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.file.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.file.flush() + } +} + +impl Read for PipeImpl { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.file.read(buf) + } +} + +impl AsRawFd for PipeImpl { + fn as_raw_fd(&self) -> RawFd { + self.file.as_raw_fd() + } +} diff --git a/userspace/rsh/src/lib.rs b/userspace/rsh/src/lib.rs index df9fbe9a..0045088b 100644 --- a/userspace/rsh/src/lib.rs +++ b/userspace/rsh/src/lib.rs @@ -1,5 +1,5 @@ #![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))] -#![feature(generic_const_exprs, portable_simd, if_let_guard)] +#![feature(generic_const_exprs, portable_simd, if_let_guard, let_chains)] #![allow(incomplete_features)] pub mod socket; diff --git a/userspace/rsh/src/main.rs b/userspace/rsh/src/main.rs index 0be0a09d..41c73d01 100644 --- a/userspace/rsh/src/main.rs +++ b/userspace/rsh/src/main.rs @@ -2,7 +2,7 @@ #![feature(let_chains)] use std::{ - io::{stdout, IsTerminal, Read, Write}, + io::{stderr, stdout, IsTerminal, Read, Stderr, Write}, ops::{Deref, DerefMut}, os::fd::AsRawFd, }; @@ -11,11 +11,11 @@ use clap::Parser; use cross::io::Poll; use rsh::{ crypt::{ - client::{self, ClientSocket}, + client::{self, ClientSocket, Message}, config::{ClientConfig, SimpleClientKeyStore}, signature::{SignEd25519, SignatureMethod}, }, - proto::ServerMessage, + proto::{ServerMessage, StreamIndex}, }; use std::{ @@ -70,8 +70,9 @@ enum Input { pub struct Client { poll: Poll, socket: ClientSocket, - input: Input, + stdin: RawStdin, stdout: Stdout, + stderr: Stderr, need_bye: bool, last0: u8, last1: u8, @@ -124,30 +125,14 @@ impl Drop for RawStdin { } impl Client { - pub fn connect( - remote: SocketAddr, - crypto_config: ClientConfig, - command: Vec, - ) -> Result { + pub fn connect(remote: SocketAddr, crypto_config: ClientConfig) -> Result { let mut poll = Poll::new()?; let mut socket = ClientSocket::connect(remote, crypto_config)?; - let input = match command.is_empty() { - true => { - let stdin = RawStdin::open()?; - poll.add(&stdin)?; - Input::Stdin(stdin) - }, - false => { - let mut bytes = vec![]; - for command in command { - bytes.extend_from_slice(command.as_bytes()); - bytes.push(b' '); - } - Input::Command(bytes) - }, - }; + let stdin = RawStdin::open()?; let stdout = stdout(); + let stderr = stderr(); + poll.add(&stdin)?; poll.add(&socket)?; let info = terminal_info(&stdout)?; @@ -155,8 +140,9 @@ impl Client { Self::handshake(&mut socket, info)?; Ok(Self { - input, + stdin, stdout, + stderr, socket, poll, need_bye: false, @@ -187,19 +173,20 @@ impl Client { pub fn run(mut self) -> Result { let mut recv_buf = [0; 512]; - if let Input::Command(command) = &self.input { - self.socket.write_all(&ClientMessage::Input(&command[..]))?; - } - loop { if let Some(message) = self.socket.read(&mut recv_buf)? { match message { ServerMessage::Bye(reason) => return Ok(reason.into()), - ServerMessage::Output(data) => { + ServerMessage::Output(StreamIndex::Stdout, data) => { self.stdout.write_all(data).ok(); self.stdout.flush().ok(); continue; } + ServerMessage::Output(StreamIndex::Stderr, data) => { + self.stderr.write_all(data).ok(); + self.stderr.flush().ok(); + continue; + } _ => continue, } } @@ -210,8 +197,8 @@ impl Client { if self.socket.poll()? == 0 { return Ok("".into()); } - } else if let Input::Stdin(stdin) = &mut self.input && stdin.as_raw_fd() == fd { - let len = stdin.read(&mut recv_buf)?; + } else if self.stdin.as_raw_fd() == fd { + let len = self.stdin.read(&mut recv_buf)?; self.update_last(&recv_buf[..len])?; self.socket .write_all(&ClientMessage::Input(&recv_buf[..len]))?; @@ -247,27 +234,101 @@ fn terminal_info(stdout: &Stdout) -> Result { }) } -fn run(args: Args) -> Result<(), Error> { +fn run_terminal(remote: SocketAddr, config: ClientConfig) -> Result<(), Error> { + let reason = Client::connect(remote, config)?.run()?; + if !reason.is_empty() { + eprintln!("\nDisconnected: {reason}"); + } + Ok(()) +} + +fn run_command( + remote: SocketAddr, + config: ClientConfig, + command: Vec, +) -> Result { + let mut poll = Poll::new()?; + let mut buffer = [0; 512]; + let mut command_string = String::new(); + for (i, word) in command.iter().enumerate() { + if i != 0 { + command_string.push(' '); + } + command_string.push_str(word); + } + + let mut stdin = stdin(); + let mut stdout = stdout(); + let mut stderr = stderr(); + + let mut socket = ClientSocket::connect(remote, config)?; + + poll.add(&socket)?; + poll.add(&stdin)?; + + socket.write_all(&ClientMessage::RunCommand(command_string.as_str()))?; + + loop { + let fd = poll.wait(None)?.unwrap(); + + match fd { + _ if fd == socket.as_raw_fd() => { + let message = match socket.poll_read(&mut buffer)? { + Message::Data(data) => data, + Message::Incomplete => continue, + Message::Closed => break + }; + + match message { + ServerMessage::Output(StreamIndex::Stdout, output) => { + stdout.write_all(output).ok(); + stdout.flush().ok(); + } + ServerMessage::Output(StreamIndex::Stderr, output) => { + stderr.write_all(output).ok(); + stderr.flush().ok(); + } + _ => todo!() + } + } + _ if fd == stdin.as_raw_fd() => { + let len = stdin.read(&mut buffer)?; + if len == 0 { + poll.remove(&stdin)?; + socket.write_all(&ClientMessage::CloseStdin)?; + } else { + socket.write_all(&ClientMessage::Input(&buffer[..len]))?; + } + } + _ => unreachable!() + } + + } + + Ok(ExitCode::SUCCESS) +} + +fn run(args: Args) -> Result { let remote = SocketAddr::new(args.remote, args.port); let ed25519 = SignEd25519::load_signing_key(args.key).unwrap(); let key = SignatureMethod::Ed25519(ed25519); let config = ClientConfig::with_default_algorithms(SimpleClientKeyStore::new(key)); - let reason = Client::connect(remote, config, args.command)?.run()?; - if !reason.is_empty() { - eprintln!("\nDisconnected: {reason}"); + if args.command.is_empty() { + run_terminal(remote, config).map(|_| ExitCode::SUCCESS) + } else { + run_command(remote, config, args.command) } - - Ok(()) } fn main() -> ExitCode { env_logger::init(); let args = Args::parse(); - if let Err(error) = run(args) { - eprintln!("Error: {error}"); - ExitCode::FAILURE - } else { - ExitCode::SUCCESS + match run(args) { + Ok(status) => status, + Err(error) => { + eprintln!("Error: {error}"); + ExitCode::FAILURE + } } } diff --git a/userspace/rsh/src/proto.rs b/userspace/rsh/src/proto.rs index eb1e9aa6..eaf017d9 100644 --- a/userspace/rsh/src/proto.rs +++ b/userspace/rsh/src/proto.rs @@ -13,6 +13,13 @@ pub enum ClientMessage<'a> { RunCommand(&'a str), Bye(&'a str), Input(&'a [u8]), + CloseStdin, +} + +#[derive(Debug, Clone, Copy)] +pub enum StreamIndex { + Stdout, + Stderr, } #[derive(Debug)] @@ -20,7 +27,7 @@ pub enum ServerMessage<'a> { SessionOpen, CommandStatus(i32), Bye(&'a str), - Output(&'a [u8]), + Output(StreamIndex, &'a [u8]), } #[derive(Debug, thiserror::Error)] @@ -164,6 +171,32 @@ impl ClientMessage<'_> { const TAG_RUN_COMMAND: u8 = 0x81; const TAG_BYE: u8 = 0x82; const TAG_INPUT: u8 = 0x90; + const TAG_CLOSE_STDIN: u8 = 0x91; +} + +impl StreamIndex { + const TAG_STDOUT: u8 = 0x01; + const TAG_STDERR: u8 = 0x02; +} + +impl Encode for StreamIndex { + fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> { + match self { + Self::Stdout => buffer.write(&[Self::TAG_STDOUT]), + Self::Stderr => buffer.write(&[Self::TAG_STDERR]), + } + } +} + +impl<'de> Decode<'de> for StreamIndex { + fn decode(buffer: &mut Decoder<'de>) -> Result { + let tag = buffer.read_u8()?; + match tag { + Self::TAG_STDOUT => Ok(Self::Stdout), + Self::TAG_STDERR => Ok(Self::Stderr), + _ => Err(DecodeError::InvalidMessage) + } + } } impl Encode for TerminalInfo { @@ -201,6 +234,7 @@ impl<'a> Encode for ClientMessage<'a> { buffer.write(&[Self::TAG_INPUT])?; buffer.write_variable_bytes(data) } + Self::CloseStdin => buffer.write(&[Self::TAG_CLOSE_STDIN]) } } } @@ -221,6 +255,9 @@ impl<'de> Decode<'de> for ClientMessage<'de> { Self::TAG_INPUT => { buffer.read_variable_bytes().map(Self::Input) } + Self::TAG_CLOSE_STDIN => { + Ok(Self::CloseStdin) + } _ => Err(DecodeError::InvalidMessage) } } @@ -244,8 +281,9 @@ impl<'a> Encode for ServerMessage<'a> { buffer.write(&[Self::TAG_BYE])?; buffer.write_str(reason) } - Self::Output(data) => { + Self::Output(index, data) => { buffer.write(&[Self::TAG_OUTPUT])?; + index.encode(buffer)?; buffer.write_variable_bytes(data) } } @@ -267,7 +305,9 @@ impl<'de> Decode<'de> for ServerMessage<'de> { buffer.read_str().map(Self::Bye) } Self::TAG_OUTPUT => { - buffer.read_variable_bytes().map(Self::Output) + let index = StreamIndex::decode(buffer)?; + let data = buffer.read_variable_bytes()?; + Ok(Self::Output(index, data)) }, _ => Err(DecodeError::InvalidMessage), } diff --git a/userspace/rsh/src/server.rs b/userspace/rsh/src/server.rs index 36a6a680..d21440ce 100644 --- a/userspace/rsh/src/server.rs +++ b/userspace/rsh/src/server.rs @@ -1,12 +1,15 @@ use std::{ collections::{hash_map::Entry, HashMap}, - fmt, io, + fmt, + hint::unreachable_unchecked, + io::{self, Read, Stdout, Write}, net::SocketAddr, os::fd::{AsRawFd, RawFd}, + process::{Child, Command, Stdio}, time::Duration, }; -use cross::io::Poll; +use cross::io::{PidFd, Pipe, Poll}; use crate::{ crypt::{ @@ -14,7 +17,7 @@ use crate::{ config::ServerConfig, server::{self, ServerSocket}, }, - proto::{ClientMessage, ServerMessage, TerminalInfo}, + proto::{ClientMessage, ServerMessage, StreamIndex, TerminalInfo}, }; pub const PING_INTERVAL: Duration = Duration::from_millis(500); @@ -52,22 +55,28 @@ pub struct EchoSession { peer: SocketAddr, } -struct ClientSession { - stream: ClientSocket, - session: Option, +struct PendingCommand { + child: Child, + child_pid: Option, + stdin: Option, + stdout: Option, + stderr: Option, } -// enum Event<'b, T: Session> { -// NewSession(SocketAddr, TerminalInfo), -// SessionInput(u64, SocketAddr, &'b [u8]), -// ClientBye(SocketAddr, &'b str), -// SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>), -// Tick, -// } +enum ClientSession { + None, + Terminal(T), + Command(PendingCommand), +} -struct SessionSet { +struct Client { + stream: ClientSocket, + session: ClientSession, +} + +struct ClientSet { last_session_key: u64, - sessions: HashMap>, + clients: HashMap>, socket_fd_map: HashMap, session_fd_map: HashMap, } @@ -76,29 +85,63 @@ pub struct Server { poll: Poll, socket: ServerSocket, - sessions: SessionSet, + clients: ClientSet, } -impl SessionSet { +enum CommandEvent { + Output(StreamIndex, Result), + Exited(Result) +} + +impl Client { + fn make_command(text: &str) -> Result { + let mut words = text.split(' '); + let program = words.next().unwrap(); + let mut command = Command::new(program); + + let (stdin_read, stdin_write) = Pipe::new(false, false)?; + let (stdout_read, stdout_write) = Pipe::new(true, false)?; + let (stderr_read, stderr_write) = Pipe::new(true, false)?; + + command + .args(words) + .stdin(stdin_read.to_child_stdio()) + .stdout(stdout_write.to_child_stdio()) + .stderr(stderr_write.to_child_stdio()); + + let child = command.spawn()?; + let child_pid = PidFd::new(child.id())?; + + Ok(PendingCommand { + child, + child_pid: Some(child_pid), + stdout: Some(stdout_read), + stderr: Some(stderr_read), + stdin: Some(stdin_write), + }) + } +} + +impl ClientSet { pub fn new() -> Self { Self { last_session_key: 1, - sessions: HashMap::new(), + clients: HashMap::new(), socket_fd_map: HashMap::new(), session_fd_map: HashMap::new(), } } pub fn add_client(&mut self, stream: ClientSocket, poll: &mut Poll) -> Result<(), Error> { - let (key, session) = loop { + let (key, client) = loop { let key = self.last_session_key; self.last_session_key += 1; - match self.sessions.entry(key) { + match self.clients.entry(key) { Entry::Vacant(entry) => { - let session = entry.insert(ClientSession { + let session = entry.insert(Client { stream, - session: None, + session: ClientSession::None, }); break (key, session); } @@ -106,22 +149,39 @@ impl SessionSet { } }; - poll.add(&session.stream)?; - self.socket_fd_map.insert(session.stream.as_raw_fd(), key); + poll.add(&client.stream)?; + self.socket_fd_map.insert(client.stream.as_raw_fd(), key); Ok(()) } fn remove(&mut self, key: u64, poll: &mut Poll) -> Result<(), Error> { - if let Some(session) = self.sessions.remove(&key) { - poll.remove(&session.stream)?; - self.socket_fd_map.remove(&session.stream.as_raw_fd()); + if let Some(client) = self.clients.remove(&key) { + poll.remove(&client.stream)?; + self.socket_fd_map.remove(&client.stream.as_raw_fd()); - if let Some(session) = session.session { - for fd in session.event_fds() { - poll.remove(fd)?; - self.session_fd_map.remove(fd); + match client.session { + ClientSession::Terminal(terminal) => { + for fd in terminal.event_fds() { + poll.remove(fd)?; + self.session_fd_map.remove(fd); + } } + ClientSession::Command(mut command) => { + if let Some(stdout) = command.stdout.take() { + poll.remove(&stdout)?; + self.session_fd_map.remove(&stdout.as_raw_fd()); + } + if let Some(stderr) = command.stderr.take() { + poll.remove(&stderr)?; + self.session_fd_map.remove(&stderr.as_raw_fd()); + } + if let Some(child_pid) = command.child_pid.take() { + poll.remove(&child_pid)?; + self.session_fd_map.remove(&child_pid.as_raw_fd()); + } + } + ClientSession::None => (), } } Ok(()) @@ -131,10 +191,10 @@ impl SessionSet { let mut buffer = [0; 512]; if let Some(&key) = self.socket_fd_map.get(&fd) { - let session = self.sessions.get_mut(&key).unwrap(); - let peer = session.stream.remote_address(); + let client = self.clients.get_mut(&key).unwrap(); + let peer = client.stream.remote_address(); - let mut closed = match session.stream.poll() { + let mut closed = match client.stream.poll() { Ok(0) => true, Ok(_) => false, Err(error) => { @@ -144,7 +204,7 @@ impl SessionSet { }; loop { - let message = match session.stream.read(&mut buffer) { + let message = match client.stream.read(&mut buffer) { Ok(Some(message)) => message, Ok(None) => break, Err(error) => { @@ -154,8 +214,33 @@ impl SessionSet { } }; - match (&message, &mut session.session) { - (ClientMessage::OpenSession(terminal), None) => { + match (&message, &mut client.session) { + (ClientMessage::RunCommand(command), ClientSession::None) => { + log::info!("{peer}: cmd {command:?}"); + + let command = match Client::::make_command(command) { + Ok(command) => command, + Err(error) => { + log::error!("{peer}: command error: {error}"); + return self.remove(key, poll); + } + }; + + let stdout_fd = command.stdout.as_ref().unwrap().as_raw_fd(); + let stderr_fd = command.stderr.as_ref().unwrap().as_raw_fd(); + let child_pid_fd = command.child_pid.as_ref().unwrap().as_raw_fd(); + + poll.add(&stdout_fd)?; + poll.add(&stderr_fd)?; + poll.add(&child_pid_fd)?; + + self.session_fd_map.insert(stdout_fd, key); + self.session_fd_map.insert(stderr_fd, key); + self.session_fd_map.insert(child_pid_fd, key); + + client.session = ClientSession::Command(command); + } + (ClientMessage::OpenSession(terminal), ClientSession::None) => { log::info!("{peer}: new session"); let terminal = match T::open(&peer, &terminal) { Ok(session) => session, @@ -165,15 +250,33 @@ impl SessionSet { return Ok(()); } }; - let terminal = session.session.insert(terminal); + let terminal = client.session.set_terminal(terminal); for fd in terminal.event_fds() { poll.add(fd)?; self.session_fd_map.insert(*fd, key); } - session.stream.write_all(&ServerMessage::SessionOpen).ok(); + + client.stream.write_all(&ServerMessage::SessionOpen).ok(); } - (ClientMessage::Input(data), Some(terminal)) => { - let client = SessionClient { transport: &mut session.stream }; + (ClientMessage::Input(data), ClientSession::Command(command)) => { + if let Some(stdin) = command.stdin.as_mut() { + match stdin.write_all(data) { + Ok(()) => { + stdin.flush().ok(); + }, + Err(error) => { + log::error!("{peer}: stdin error: {error}"); + return self.remove(key, poll); + } + } + } else { + log::warn!("{peer}: stdin already closed"); + } + } + (ClientMessage::Input(data), ClientSession::Terminal(terminal)) => { + let client = SessionClient { + transport: &mut client.stream, + }; match terminal.handle_input(data, client) { Ok(false) => (), Ok(true) => { @@ -186,6 +289,10 @@ impl SessionSet { } } } + (ClientMessage::CloseStdin, ClientSession::Command(command)) => { + // Drop stdin handle + command.stdin.take(); + } _ => { log::warn!("{peer}: unhandled message"); } @@ -199,45 +306,148 @@ impl SessionSet { Ok(()) } else if let Some(&key) = self.session_fd_map.get(&fd) { - let session = self.sessions.get_mut(&key).unwrap(); - let terminal = session.session.as_mut().unwrap(); - let peer = session.stream.remote_address(); + log::debug!("poll fd {:?}", fd); + let client = self.clients.get_mut(&key).unwrap(); + let peer = client.stream.remote_address(); - match terminal.read_output(fd, &mut buffer) { - Ok(0) => { - log::info!("{peer}: session closed"); - self.remove(key, poll) - } - Ok(mut len) => { - // Split output into 128-byte chunks - let mut pos = 0; - while len != 0 { - let amount = core::cmp::min(len, 128); + let (mut len, stream_index) = match &mut client.session { + ClientSession::Command(command) => match command.read_output(fd, &mut buffer) { + CommandEvent::Output(index, Ok(len)) => { + if len == 0 { + poll.remove(&fd)?; + self.socket_fd_map.remove(&fd); - if let Err(error) = session - .stream - .write_all(&ServerMessage::Output(&buffer[pos..pos + amount])) - { - log::error!("{peer}: communication error: {error}"); + match index { + StreamIndex::Stdout => { + command.stdout = None; + } + StreamIndex::Stderr => { + command.stderr = None; + } + } + + if command.is_dead() { + log::info!("{peer}: command stdout/stderr closed"); + return self.remove(key, poll); + } + } + (len, index) + }, + CommandEvent::Output(index, Err(error)) => { + if error.kind() == io::ErrorKind::WouldBlock { + return Ok(()); + } + log::error!("{peer}: {index:?} error: {error}"); + return self.remove(key, poll); + } + CommandEvent::Exited(status) => { + log::info!("{peer}: command exited: {:?}", status); + poll.remove(&fd)?; + self.socket_fd_map.remove(&fd); + command.child_pid = None; + + if command.is_dead() { + log::info!("{peer}: command stdout/stderr closed"); return self.remove(key, poll); } - pos += amount; - len -= amount; + return Ok(()); } - Ok(()) - } - Err(error) => { - log::error!("{peer}: session read error: {error}"); - self.remove(key, poll) - } + }, + ClientSession::Terminal(terminal) => match terminal.read_output(fd, &mut buffer) { + Ok(0) => { + poll.remove(&fd)?; + self.socket_fd_map.remove(&fd); + + // TODO check for process as well + log::info!("{peer}: terminal closed"); + return self.remove(key, poll); + } + Ok(len) => { + (len, StreamIndex::Stdout) + }, + Err(error) => { + log::error!("{peer}: terminal error: {error}"); + return self.remove(key, poll); + } + }, + ClientSession::None => unreachable!(), + }; + + if len == 0 { + log::info!("{peer}: {stream_index:?} closed"); + return Ok(()); + // return self.remove(key, poll); } + + // Split output into 128-byte chunks + let mut pos = 0; + while len != 0 { + let amount = core::cmp::min(len, 128); + + if let Err(error) = client.stream.write_all(&ServerMessage::Output( + stream_index, + &buffer[pos..pos + amount], + )) { + log::error!("{peer}: communication error: {error}"); + return self.remove(key, poll); + } + + pos += amount; + len -= amount; + } + + log::debug!("Done"); + Ok(()) } else { unreachable!() } } } +impl PendingCommand { + pub fn read_output( + &mut self, + fd: RawFd, + buffer: &mut [u8], + ) -> CommandEvent { + if let Some(stdout) = self.stdout.as_mut() && fd == stdout.as_raw_fd() { + log::debug!("poll stdout"); + let res = stdout.read(buffer); + log::debug!(">> {:?}", res); + return CommandEvent::Output(StreamIndex::Stdout, res); + } + if let Some(stderr) = self.stderr.as_mut() && fd == stderr.as_raw_fd() { + return CommandEvent::Output(StreamIndex::Stderr, stderr.read(buffer)); + } + if let Some(child_pid) = self.child_pid.as_mut() && fd == child_pid.as_raw_fd() { + let status = child_pid.exit_status(); + return CommandEvent::Exited(status); + } + unreachable!() + } + + pub fn is_dead(&self) -> bool { + self.child_pid.is_none() && self.stdout.is_none() && self.stderr.is_none() + } +} + +impl Drop for PendingCommand { + fn drop(&mut self) { + self.child.wait().ok(); + } +} + +impl ClientSession { + pub fn set_terminal(&mut self, terminal: T) -> &mut T { + *self = Self::Terminal(terminal); + match self { + Self::Terminal(terminal) => terminal, + _ => unsafe { unreachable_unchecked() }, + } + } +} + impl Server { pub fn listen(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result { let mut poll = Poll::new()?; @@ -246,7 +456,7 @@ impl Server { Ok(Self { poll, socket, - sessions: SessionSet::new(), + clients: ClientSet::new(), }) } @@ -259,14 +469,14 @@ impl Server { // Poll server socket match self.socket.poll()? { Some(client) => { - self.sessions.add_client(client, &mut self.poll)?; + self.clients.add_client(client, &mut self.poll)?; } None => continue, } } _ => { // Client/Session-related activity - self.sessions.handle_input(fd, &mut self.poll)?; + self.clients.handle_input(fd, &mut self.poll)?; } } } @@ -275,7 +485,7 @@ impl Server { impl<'s> SessionClient<'s> { pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> { - self.send_message(&ServerMessage::Output(data)) + self.send_message(&ServerMessage::Output(StreamIndex::Stdout, data)) } pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> { diff --git a/userspace/shell/src/sys/yggdrasil.rs b/userspace/shell/src/sys/yggdrasil.rs index fc339a80..f7508b7b 100644 --- a/userspace/shell/src/sys/yggdrasil.rs +++ b/userspace/shell/src/sys/yggdrasil.rs @@ -74,7 +74,7 @@ pub fn wait_for_pipeline( } pub fn create_pipe() -> Result { - let (read, write) = pipe::create_pipe_pair()?; + let (read, write) = pipe::create_pipe_pair(false, false)?; Ok(Pipe { read, write }) }