vfs/rsh: better pipes, piped command execution in rsh
This commit is contained in:
parent
a707a6e5f1
commit
2479702baf
@ -17,8 +17,8 @@ use libk_util::sync::IrqSafeSpinlock;
|
||||
use yggdrasil_abi::{
|
||||
error::Error,
|
||||
io::{
|
||||
DeviceRequest, DirectoryEntry, OpenOptions, RawFd, SeekFrom, TerminalOptions, TerminalSize,
|
||||
TimerOptions,
|
||||
DeviceRequest, DirectoryEntry, OpenOptions, PipeOptions, RawFd, SeekFrom, TerminalOptions,
|
||||
TerminalSize, TimerOptions,
|
||||
},
|
||||
};
|
||||
|
||||
@ -71,7 +71,7 @@ pub enum File {
|
||||
Block(BlockFile),
|
||||
Char(CharFile),
|
||||
Socket(SocketWrapper),
|
||||
AnonymousPipe(PipeEnd),
|
||||
AnonymousPipe(PipeEnd, AtomicBool),
|
||||
Poll(FdPoll),
|
||||
Timer(TimerFile),
|
||||
Channel(ChannelDescriptor),
|
||||
@ -103,11 +103,17 @@ pub struct FileSet {
|
||||
|
||||
impl File {
|
||||
/// Constructs a pipe pair, returning its `(read, write)` ends
|
||||
pub fn new_pipe_pair(capacity: usize) -> (Arc<Self>, Arc<Self>) {
|
||||
pub fn new_pipe_pair(capacity: usize, options: PipeOptions) -> (Arc<Self>, Arc<Self>) {
|
||||
let (read, write) = PipeEnd::new_pair(capacity);
|
||||
(
|
||||
Arc::new(Self::AnonymousPipe(read)),
|
||||
Arc::new(Self::AnonymousPipe(write)),
|
||||
Arc::new(Self::AnonymousPipe(
|
||||
read,
|
||||
AtomicBool::new(options.contains(PipeOptions::READ_NONBLOCKING)),
|
||||
)),
|
||||
Arc::new(Self::AnonymousPipe(
|
||||
write,
|
||||
AtomicBool::new(options.contains(PipeOptions::WRITE_NONBLOCKING)),
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
@ -289,6 +295,7 @@ impl File {
|
||||
Self::Socket(socket) => socket.poll_read(cx),
|
||||
Self::Timer(timer) => timer.poll_read(cx),
|
||||
Self::Pid(pid) => pid.poll_read(cx),
|
||||
Self::AnonymousPipe(pipe, _) => pipe.poll_read(cx),
|
||||
// Polling not implemented, return ready immediately (XXX ?)
|
||||
_ => Poll::Ready(Err(Error::NotImplemented)),
|
||||
}
|
||||
@ -365,7 +372,9 @@ impl Read for File {
|
||||
Self::Regular(file) => file.read(buf),
|
||||
Self::Block(file) => file.read(buf),
|
||||
Self::Char(file) => file.read(buf),
|
||||
Self::AnonymousPipe(pipe) => pipe.read(buf),
|
||||
Self::AnonymousPipe(pipe, nonblocking) => {
|
||||
pipe.read(buf, nonblocking.load(Ordering::Acquire))
|
||||
}
|
||||
Self::PtySlave(half) => half.read(buf),
|
||||
Self::PtyMaster(half) => half.read(buf),
|
||||
Self::Timer(timer) => timer.read(buf),
|
||||
@ -388,7 +397,7 @@ impl Write for File {
|
||||
Self::Regular(file) => file.write(buf),
|
||||
Self::Block(file) => file.write(buf),
|
||||
Self::Char(file) => file.write(buf),
|
||||
Self::AnonymousPipe(pipe) => pipe.write(buf),
|
||||
Self::AnonymousPipe(pipe, _) => pipe.write(buf),
|
||||
Self::PtySlave(half) => half.write(buf),
|
||||
Self::PtyMaster(half) => half.write(buf),
|
||||
Self::Timer(timer) => timer.write(buf),
|
||||
@ -447,7 +456,7 @@ impl fmt::Debug for File {
|
||||
.field("write", &file.write)
|
||||
.finish_non_exhaustive(),
|
||||
Self::Directory(_) => f.debug_struct("DirectoryFile").finish_non_exhaustive(),
|
||||
Self::AnonymousPipe(_) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(),
|
||||
Self::AnonymousPipe(_, _) => f.debug_struct("AnonymousPipe").finish_non_exhaustive(),
|
||||
Self::Poll(_) => f.debug_struct("Poll").finish_non_exhaustive(),
|
||||
Self::Channel(_) => f.debug_struct("Channel").finish_non_exhaustive(),
|
||||
Self::SharedMemory(_) => f.debug_struct("SharedMemory").finish_non_exhaustive(),
|
||||
|
@ -1,4 +1,5 @@
|
||||
use core::{
|
||||
future::poll_fn,
|
||||
pin::Pin,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
task::{Context, Poll},
|
||||
@ -6,9 +7,14 @@ use core::{
|
||||
|
||||
use alloc::{sync::Arc, vec, vec::Vec};
|
||||
use futures_util::{task::AtomicWaker, Future};
|
||||
use libk_util::sync::IrqSafeSpinlock;
|
||||
use libk_util::{
|
||||
sync::{IrqSafeSpinlock, IrqSafeSpinlockGuard},
|
||||
waker::QueueWaker,
|
||||
};
|
||||
use yggdrasil_abi::error::Error;
|
||||
|
||||
use crate::vfs::FileReadiness;
|
||||
|
||||
struct PipeInner {
|
||||
data: Vec<u8>,
|
||||
capacity: usize,
|
||||
@ -19,8 +25,8 @@ struct PipeInner {
|
||||
pub struct Pipe {
|
||||
inner: IrqSafeSpinlock<PipeInner>,
|
||||
shutdown: AtomicBool,
|
||||
read_notify: AtomicWaker,
|
||||
write_notify: AtomicWaker,
|
||||
read_notify: QueueWaker,
|
||||
write_notify: QueueWaker,
|
||||
}
|
||||
|
||||
pub enum PipeEnd {
|
||||
@ -82,15 +88,15 @@ impl Pipe {
|
||||
Self {
|
||||
inner: IrqSafeSpinlock::new(PipeInner::new(capacity)),
|
||||
shutdown: AtomicBool::new(false),
|
||||
read_notify: AtomicWaker::new(),
|
||||
write_notify: AtomicWaker::new(),
|
||||
read_notify: QueueWaker::new(),
|
||||
write_notify: QueueWaker::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
self.shutdown.store(true, Ordering::Release);
|
||||
self.read_notify.wake();
|
||||
self.write_notify.wake();
|
||||
self.read_notify.wake_all();
|
||||
self.write_notify.wake_all();
|
||||
}
|
||||
|
||||
pub fn blocking_write(&self, val: u8) -> impl Future<Output = Result<(), Error>> + '_ {
|
||||
@ -105,23 +111,15 @@ impl Pipe {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut lock = self.pipe.inner.lock();
|
||||
|
||||
// Try fast path before acquiring write notify to avoid unnecessary contention
|
||||
if self.pipe.shutdown.load(Ordering::Acquire) {
|
||||
// TODO BrokenPipe
|
||||
return Poll::Ready(Err(Error::ReadOnly));
|
||||
} else if lock.try_write(self.val) {
|
||||
self.pipe.read_notify.wake();
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
self.pipe.write_notify.register(cx.waker());
|
||||
|
||||
if self.pipe.shutdown.load(Ordering::Acquire) {
|
||||
self.pipe.write_notify.remove(cx.waker());
|
||||
Poll::Ready(Err(Error::ReadOnly))
|
||||
} else if lock.try_write(self.val) {
|
||||
self.pipe.read_notify.wake();
|
||||
self.pipe.write_notify.remove(cx.waker());
|
||||
self.pipe.read_notify.wake_one();
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
self.pipe.write_notify.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
@ -130,37 +128,54 @@ impl Pipe {
|
||||
F { pipe: self, val }
|
||||
}
|
||||
|
||||
pub fn blocking_read(&self) -> impl Future<Output = Option<u8>> + '_ {
|
||||
struct F<'a> {
|
||||
pipe: &'a Pipe,
|
||||
}
|
||||
fn poll_read_end(
|
||||
&self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<IrqSafeSpinlockGuard<'_, PipeInner>>> {
|
||||
let lock = self.inner.lock();
|
||||
|
||||
impl<'a> Future for F<'a> {
|
||||
type Output = Option<u8>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut lock = self.pipe.inner.lock();
|
||||
|
||||
if let Some(val) = lock.try_read() {
|
||||
self.pipe.write_notify.wake();
|
||||
return Poll::Ready(Some(val));
|
||||
} else if self.pipe.shutdown.load(Ordering::Acquire) {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
self.pipe.read_notify.register(cx.waker());
|
||||
|
||||
if let Some(val) = lock.try_read() {
|
||||
Poll::Ready(Some(val))
|
||||
} else if self.pipe.shutdown.load(Ordering::Acquire) {
|
||||
if lock.can_read() {
|
||||
self.read_notify.remove(cx.waker());
|
||||
Poll::Ready(Some(lock))
|
||||
} else if self.shutdown.load(Ordering::Acquire) {
|
||||
self.read_notify.remove(cx.waker());
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
self.read_notify.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_async(&self, buf: &mut [u8]) -> Result<usize, Error> {
|
||||
let mut pos = 0;
|
||||
let lock = poll_fn(|cx| self.poll_read_end(cx)).await;
|
||||
match lock {
|
||||
Some(mut lock) => {
|
||||
while pos < buf.len()
|
||||
&& let Some(byte) = lock.try_read()
|
||||
{
|
||||
buf[pos] = byte;
|
||||
pos += 1;
|
||||
}
|
||||
if pos != 0 {
|
||||
self.write_notify.wake_all();
|
||||
}
|
||||
Ok(pos)
|
||||
}
|
||||
None => Ok(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
F { pipe: self }
|
||||
impl FileReadiness for Pipe {
|
||||
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if self.shutdown.load(Ordering::Acquire) || self.inner.lock().can_read() {
|
||||
self.read_notify.remove(cx.waker());
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
self.read_notify.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -173,27 +188,29 @@ impl PipeEnd {
|
||||
(read, write)
|
||||
}
|
||||
|
||||
pub fn read(&self, buf: &mut [u8]) -> Result<usize, Error> {
|
||||
pub fn read(&self, buf: &mut [u8], nonblocking: bool) -> Result<usize, Error> {
|
||||
let PipeEnd::Read(read) = self else {
|
||||
return Err(Error::InvalidOperation);
|
||||
};
|
||||
|
||||
block! {
|
||||
if nonblocking {
|
||||
let mut lock = read.inner.lock();
|
||||
let mut pos = 0;
|
||||
let mut rem = buf.len();
|
||||
|
||||
while rem != 0 {
|
||||
if let Some(val) = read.blocking_read().await {
|
||||
buf[pos] = val;
|
||||
while pos < buf.len()
|
||||
&& let Some(ch) = lock.try_read()
|
||||
{
|
||||
buf[pos] = ch;
|
||||
pos += 1;
|
||||
rem -= 1;
|
||||
}
|
||||
read.write_notify.wake_all();
|
||||
if pos == 0 && !read.shutdown.load(Ordering::Acquire) {
|
||||
Err(Error::WouldBlock)
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(pos)
|
||||
}?
|
||||
}
|
||||
} else {
|
||||
block!(read.read_async(buf).await)?
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&self, buf: &[u8]) -> Result<usize, Error> {
|
||||
@ -216,6 +233,15 @@ impl PipeEnd {
|
||||
}
|
||||
}
|
||||
|
||||
impl FileReadiness for PipeEnd {
|
||||
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
match self {
|
||||
Self::Read(pipe) => pipe.poll_read(cx),
|
||||
Self::Write(_) => Poll::Ready(Err(Error::NotImplemented)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PipeEnd {
|
||||
fn drop(&mut self) {
|
||||
match self {
|
||||
|
@ -1,8 +1,8 @@
|
||||
pub(crate) use abi::{
|
||||
error::Error,
|
||||
io::{
|
||||
DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PollControl, RawFd,
|
||||
TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
|
||||
DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions, PipeOptions, PollControl,
|
||||
RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
|
||||
},
|
||||
mem::MappingSource,
|
||||
net::SocketType,
|
||||
|
@ -4,8 +4,8 @@ use abi::{
|
||||
error::Error,
|
||||
io::{
|
||||
ChannelPublisherId, DeviceRequest, DirectoryEntry, FileAttr, FileMetadataUpdate, FileMode,
|
||||
MessageDestination, OpenOptions, PollControl, RawFd, ReceivedMessageMetadata, SeekFrom,
|
||||
SentMessage, TerminalOptions, TerminalSize, TimerOptions,
|
||||
MessageDestination, OpenOptions, PipeOptions, PollControl, RawFd, ReceivedMessageMetadata,
|
||||
SeekFrom, SentMessage, TerminalOptions, TerminalSize, TimerOptions,
|
||||
},
|
||||
process::ProcessId,
|
||||
};
|
||||
@ -299,12 +299,15 @@ pub(crate) fn create_poll_channel() -> Result<RawFd, Error> {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn create_pipe(ends: &mut [MaybeUninit<RawFd>; 2]) -> Result<(), Error> {
|
||||
pub(crate) fn create_pipe(
|
||||
ends: &mut [MaybeUninit<RawFd>; 2],
|
||||
options: PipeOptions,
|
||||
) -> Result<(), Error> {
|
||||
let thread = Thread::current();
|
||||
let process = thread.process();
|
||||
|
||||
run_with_io(&process, |mut io| {
|
||||
let (read, write) = File::new_pipe_pair(256);
|
||||
let (read, write) = File::new_pipe_pair(256, options);
|
||||
|
||||
let read_fd = io.files.place_file(read, true)?;
|
||||
let write_fd = io.files.place_file(write, true)?;
|
||||
|
@ -49,6 +49,13 @@ bitfield OpenOptions(u32) {
|
||||
CREATE_EXCL: 5,
|
||||
}
|
||||
|
||||
bitfield PipeOptions(u32) {
|
||||
/// If set, read end of the pipe will return WouldBlock if there's no data in the pipe
|
||||
READ_NONBLOCKING: 0,
|
||||
/// If set, write end of the pipe will return WouldBlock if the pipe is full
|
||||
WRITE_NONBLOCKING: 1,
|
||||
}
|
||||
|
||||
enum FileType(u32) {
|
||||
/// Regular file
|
||||
File = 1,
|
||||
|
@ -105,7 +105,7 @@ syscall create_pid(pid: ProcessId) -> Result<RawFd>;
|
||||
syscall create_pty(opts: &TerminalOptions, size: &TerminalSize, fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>;
|
||||
syscall create_shared_memory(size: usize) -> Result<RawFd>;
|
||||
syscall create_poll_channel() -> Result<RawFd>;
|
||||
syscall create_pipe(fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>;
|
||||
syscall create_pipe(fds: &mut [MaybeUninit<RawFd>; 2], options: PipeOptions) -> Result<()>;
|
||||
|
||||
syscall poll_channel_wait(
|
||||
poll_fd: RawFd,
|
||||
|
@ -5,8 +5,8 @@ mod input;
|
||||
mod terminal;
|
||||
|
||||
pub use crate::generated::{
|
||||
DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PollControl,
|
||||
RawFd, TimerOptions, UnmountOptions, UserId,
|
||||
DirectoryEntry, FileAttr, FileMode, FileType, GroupId, MountOptions, OpenOptions, PipeOptions,
|
||||
PollControl, RawFd, TimerOptions, UnmountOptions, UserId,
|
||||
};
|
||||
pub use channel::{ChannelPublisherId, MessageDestination, ReceivedMessageMetadata, SentMessage};
|
||||
pub use device::DeviceRequest;
|
||||
|
@ -23,5 +23,5 @@ pub mod poll {
|
||||
|
||||
pub use abi::io::{
|
||||
DirectoryEntry, FileAttr, FileMetadataUpdate, FileMetadataUpdateMode, FileMode, FileType,
|
||||
OpenOptions, RawFd, SeekFrom, TimerOptions,
|
||||
OpenOptions, PipeOptions, RawFd, SeekFrom, TimerOptions,
|
||||
};
|
||||
|
@ -17,7 +17,8 @@ mod generated {
|
||||
error::Error,
|
||||
io::{
|
||||
ChannelPublisherId, DirectoryEntry, FileAttr, FileMode, MountOptions, OpenOptions,
|
||||
PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
|
||||
PipeOptions, PollControl, RawFd, TerminalOptions, TerminalSize, TimerOptions,
|
||||
UnmountOptions,
|
||||
},
|
||||
mem::MappingSource,
|
||||
net::SocketType,
|
||||
|
@ -1,10 +1,15 @@
|
||||
use std::{
|
||||
io,
|
||||
io::{self, Read, Write},
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
process::Stdio,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crate::sys::{self, Poll as SysPoll, TimerFd as SysTimerFd, PidFd as SysPidFd};
|
||||
use crate::sys::{
|
||||
self, PidFd as SysPidFd, Pipe as SysPipe, Poll as SysPoll, TimerFd as SysTimerFd,
|
||||
};
|
||||
|
||||
use self::sys::PipeImpl;
|
||||
|
||||
#[repr(transparent)]
|
||||
pub struct Poll(sys::PollImpl);
|
||||
@ -15,6 +20,9 @@ pub struct TimerFd(sys::TimerFdImpl);
|
||||
#[repr(transparent)]
|
||||
pub struct PidFd(sys::PidFdImpl);
|
||||
|
||||
#[repr(transparent)]
|
||||
pub struct Pipe(sys::PipeImpl);
|
||||
|
||||
impl Poll {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
sys::PollImpl::new().map(Self)
|
||||
@ -74,3 +82,36 @@ impl AsRawFd for PidFd {
|
||||
self.0.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl Pipe {
|
||||
pub fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
|
||||
let (read, write) = PipeImpl::new(read_nonblocking, write_nonblocking)?;
|
||||
Ok((Self(read), Self(write)))
|
||||
}
|
||||
|
||||
pub fn to_child_stdio(&self) -> Stdio {
|
||||
self.0.to_child_stdio()
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for Pipe {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for Pipe {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.0.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.0.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for Pipe {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.0.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
@ -9,8 +9,9 @@ mod unix;
|
||||
pub(crate) use unix::*;
|
||||
|
||||
use std::{
|
||||
io,
|
||||
io::{self, Read, Write},
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
process::Stdio,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
@ -31,3 +32,8 @@ pub(crate) trait PidFd: Sized + AsRawFd {
|
||||
fn new(pid: u32) -> io::Result<Self>;
|
||||
fn exit_status(&self) -> io::Result<i32>;
|
||||
}
|
||||
|
||||
pub(crate) trait Pipe: Read + Write + AsRawFd + Sized {
|
||||
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)>;
|
||||
fn to_child_stdio(&self) -> Stdio;
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
pub mod poll;
|
||||
pub mod timer;
|
||||
pub mod pid;
|
||||
pub mod pipe;
|
||||
|
||||
pub use poll::PollImpl;
|
||||
pub use timer::TimerFdImpl;
|
||||
pub use pid::PidFdImpl;
|
||||
pub use pipe::PipeImpl;
|
||||
|
@ -10,7 +10,7 @@ pub struct PidFdImpl {
|
||||
impl PidFd for PidFdImpl {
|
||||
fn new(pid: u32) -> io::Result<Self> {
|
||||
let pid = pid as i32;
|
||||
let fd = unsafe { libc::pidfd_open(pid) };
|
||||
let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, 0) } as i32;
|
||||
if fd < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
@ -19,7 +19,7 @@ impl PidFd for PidFdImpl {
|
||||
}
|
||||
|
||||
fn exit_status(&self) -> io::Result<i32> {
|
||||
let status = 0;
|
||||
let mut status = 0;
|
||||
let res = unsafe { libc::waitpid(self.pid, &mut status, 0) };
|
||||
if res < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
|
47
userspace/lib/cross/src/sys/unix/pipe.rs
Normal file
47
userspace/lib/cross/src/sys/unix/pipe.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use std::{io::{self, Read, Write}, os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}};
|
||||
|
||||
use crate::sys::Pipe;
|
||||
|
||||
pub struct PipeImpl {
|
||||
fd: OwnedFd,
|
||||
}
|
||||
|
||||
impl Pipe for PipeImpl {
|
||||
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
|
||||
let mut fds = [0; 2];
|
||||
let res = unsafe { libc::pipe(fds.as_mut_ptr()) };
|
||||
if res < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
let read = unsafe { OwnedFd::from_raw_fd(fds[0]) };
|
||||
let write = unsafe { OwnedFd::from_raw_fd(fds[1]) };
|
||||
|
||||
Ok((Self { fd: read }, Self { fd: write }))
|
||||
}
|
||||
|
||||
fn to_child_stdio(&self) -> std::process::Stdio {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for PipeImpl {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for PipeImpl {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for PipeImpl {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.fd.as_raw_fd()
|
||||
}
|
||||
}
|
@ -1,7 +1,9 @@
|
||||
pub mod poll;
|
||||
pub mod timer;
|
||||
pub mod pid;
|
||||
pub mod pipe;
|
||||
|
||||
pub use poll::PollImpl;
|
||||
pub use timer::TimerFdImpl;
|
||||
pub use pid::PidFdImpl;
|
||||
pub use pipe::PipeImpl;
|
||||
|
50
userspace/lib/cross/src/sys/yggdrasil/pipe.rs
Normal file
50
userspace/lib/cross/src/sys/yggdrasil/pipe.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{self, Read, Write},
|
||||
os::{
|
||||
fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
|
||||
yggdrasil::io::pipe::create_pipe_pair,
|
||||
},
|
||||
process::Stdio,
|
||||
};
|
||||
|
||||
use crate::sys::Pipe;
|
||||
|
||||
pub struct PipeImpl {
|
||||
file: File,
|
||||
}
|
||||
|
||||
impl Pipe for PipeImpl {
|
||||
fn new(read_nonblocking: bool, write_nonblocking: bool) -> io::Result<(Self, Self)> {
|
||||
let (read, write) = create_pipe_pair(read_nonblocking, write_nonblocking)?;
|
||||
let read = unsafe { File::from_raw_fd(read.into_raw_fd()) };
|
||||
let write = unsafe { File::from_raw_fd(write.into_raw_fd()) };
|
||||
Ok((Self { file: read }, Self { file: write }))
|
||||
}
|
||||
|
||||
fn to_child_stdio(&self) -> Stdio {
|
||||
unsafe { Stdio::from_raw_fd(self.as_raw_fd()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for PipeImpl {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.file.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.file.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for PipeImpl {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.file.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for PipeImpl {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.file.as_raw_fd()
|
||||
}
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os))]
|
||||
#![feature(generic_const_exprs, portable_simd, if_let_guard)]
|
||||
#![feature(generic_const_exprs, portable_simd, if_let_guard, let_chains)]
|
||||
#![allow(incomplete_features)]
|
||||
|
||||
pub mod socket;
|
||||
|
@ -2,7 +2,7 @@
|
||||
#![feature(let_chains)]
|
||||
|
||||
use std::{
|
||||
io::{stdout, IsTerminal, Read, Write},
|
||||
io::{stderr, stdout, IsTerminal, Read, Stderr, Write},
|
||||
ops::{Deref, DerefMut},
|
||||
os::fd::AsRawFd,
|
||||
};
|
||||
@ -11,11 +11,11 @@ use clap::Parser;
|
||||
use cross::io::Poll;
|
||||
use rsh::{
|
||||
crypt::{
|
||||
client::{self, ClientSocket},
|
||||
client::{self, ClientSocket, Message},
|
||||
config::{ClientConfig, SimpleClientKeyStore},
|
||||
signature::{SignEd25519, SignatureMethod},
|
||||
},
|
||||
proto::ServerMessage,
|
||||
proto::{ServerMessage, StreamIndex},
|
||||
};
|
||||
|
||||
use std::{
|
||||
@ -70,8 +70,9 @@ enum Input {
|
||||
pub struct Client {
|
||||
poll: Poll,
|
||||
socket: ClientSocket,
|
||||
input: Input,
|
||||
stdin: RawStdin,
|
||||
stdout: Stdout,
|
||||
stderr: Stderr,
|
||||
need_bye: bool,
|
||||
last0: u8,
|
||||
last1: u8,
|
||||
@ -124,30 +125,14 @@ impl Drop for RawStdin {
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn connect(
|
||||
remote: SocketAddr,
|
||||
crypto_config: ClientConfig,
|
||||
command: Vec<String>,
|
||||
) -> Result<Self, Error> {
|
||||
pub fn connect(remote: SocketAddr, crypto_config: ClientConfig) -> Result<Self, Error> {
|
||||
let mut poll = Poll::new()?;
|
||||
let mut socket = ClientSocket::connect(remote, crypto_config)?;
|
||||
let input = match command.is_empty() {
|
||||
true => {
|
||||
let stdin = RawStdin::open()?;
|
||||
poll.add(&stdin)?;
|
||||
Input::Stdin(stdin)
|
||||
},
|
||||
false => {
|
||||
let mut bytes = vec![];
|
||||
for command in command {
|
||||
bytes.extend_from_slice(command.as_bytes());
|
||||
bytes.push(b' ');
|
||||
}
|
||||
Input::Command(bytes)
|
||||
},
|
||||
};
|
||||
let stdout = stdout();
|
||||
let stderr = stderr();
|
||||
|
||||
poll.add(&stdin)?;
|
||||
poll.add(&socket)?;
|
||||
|
||||
let info = terminal_info(&stdout)?;
|
||||
@ -155,8 +140,9 @@ impl Client {
|
||||
Self::handshake(&mut socket, info)?;
|
||||
|
||||
Ok(Self {
|
||||
input,
|
||||
stdin,
|
||||
stdout,
|
||||
stderr,
|
||||
socket,
|
||||
poll,
|
||||
need_bye: false,
|
||||
@ -187,19 +173,20 @@ impl Client {
|
||||
pub fn run(mut self) -> Result<String, Error> {
|
||||
let mut recv_buf = [0; 512];
|
||||
|
||||
if let Input::Command(command) = &self.input {
|
||||
self.socket.write_all(&ClientMessage::Input(&command[..]))?;
|
||||
}
|
||||
|
||||
loop {
|
||||
if let Some(message) = self.socket.read(&mut recv_buf)? {
|
||||
match message {
|
||||
ServerMessage::Bye(reason) => return Ok(reason.into()),
|
||||
ServerMessage::Output(data) => {
|
||||
ServerMessage::Output(StreamIndex::Stdout, data) => {
|
||||
self.stdout.write_all(data).ok();
|
||||
self.stdout.flush().ok();
|
||||
continue;
|
||||
}
|
||||
ServerMessage::Output(StreamIndex::Stderr, data) => {
|
||||
self.stderr.write_all(data).ok();
|
||||
self.stderr.flush().ok();
|
||||
continue;
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
@ -210,8 +197,8 @@ impl Client {
|
||||
if self.socket.poll()? == 0 {
|
||||
return Ok("".into());
|
||||
}
|
||||
} else if let Input::Stdin(stdin) = &mut self.input && stdin.as_raw_fd() == fd {
|
||||
let len = stdin.read(&mut recv_buf)?;
|
||||
} else if self.stdin.as_raw_fd() == fd {
|
||||
let len = self.stdin.read(&mut recv_buf)?;
|
||||
self.update_last(&recv_buf[..len])?;
|
||||
self.socket
|
||||
.write_all(&ClientMessage::Input(&recv_buf[..len]))?;
|
||||
@ -247,27 +234,101 @@ fn terminal_info(stdout: &Stdout) -> Result<TerminalInfo, Error> {
|
||||
})
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<(), Error> {
|
||||
fn run_terminal(remote: SocketAddr, config: ClientConfig) -> Result<(), Error> {
|
||||
let reason = Client::connect(remote, config)?.run()?;
|
||||
if !reason.is_empty() {
|
||||
eprintln!("\nDisconnected: {reason}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_command(
|
||||
remote: SocketAddr,
|
||||
config: ClientConfig,
|
||||
command: Vec<String>,
|
||||
) -> Result<ExitCode, Error> {
|
||||
let mut poll = Poll::new()?;
|
||||
let mut buffer = [0; 512];
|
||||
let mut command_string = String::new();
|
||||
for (i, word) in command.iter().enumerate() {
|
||||
if i != 0 {
|
||||
command_string.push(' ');
|
||||
}
|
||||
command_string.push_str(word);
|
||||
}
|
||||
|
||||
let mut stdin = stdin();
|
||||
let mut stdout = stdout();
|
||||
let mut stderr = stderr();
|
||||
|
||||
let mut socket = ClientSocket::connect(remote, config)?;
|
||||
|
||||
poll.add(&socket)?;
|
||||
poll.add(&stdin)?;
|
||||
|
||||
socket.write_all(&ClientMessage::RunCommand(command_string.as_str()))?;
|
||||
|
||||
loop {
|
||||
let fd = poll.wait(None)?.unwrap();
|
||||
|
||||
match fd {
|
||||
_ if fd == socket.as_raw_fd() => {
|
||||
let message = match socket.poll_read(&mut buffer)? {
|
||||
Message::Data(data) => data,
|
||||
Message::Incomplete => continue,
|
||||
Message::Closed => break
|
||||
};
|
||||
|
||||
match message {
|
||||
ServerMessage::Output(StreamIndex::Stdout, output) => {
|
||||
stdout.write_all(output).ok();
|
||||
stdout.flush().ok();
|
||||
}
|
||||
ServerMessage::Output(StreamIndex::Stderr, output) => {
|
||||
stderr.write_all(output).ok();
|
||||
stderr.flush().ok();
|
||||
}
|
||||
_ => todo!()
|
||||
}
|
||||
}
|
||||
_ if fd == stdin.as_raw_fd() => {
|
||||
let len = stdin.read(&mut buffer)?;
|
||||
if len == 0 {
|
||||
poll.remove(&stdin)?;
|
||||
socket.write_all(&ClientMessage::CloseStdin)?;
|
||||
} else {
|
||||
socket.write_all(&ClientMessage::Input(&buffer[..len]))?;
|
||||
}
|
||||
}
|
||||
_ => unreachable!()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Ok(ExitCode::SUCCESS)
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<ExitCode, Error> {
|
||||
let remote = SocketAddr::new(args.remote, args.port);
|
||||
let ed25519 = SignEd25519::load_signing_key(args.key).unwrap();
|
||||
let key = SignatureMethod::Ed25519(ed25519);
|
||||
let config = ClientConfig::with_default_algorithms(SimpleClientKeyStore::new(key));
|
||||
let reason = Client::connect(remote, config, args.command)?.run()?;
|
||||
if !reason.is_empty() {
|
||||
eprintln!("\nDisconnected: {reason}");
|
||||
if args.command.is_empty() {
|
||||
run_terminal(remote, config).map(|_| ExitCode::SUCCESS)
|
||||
} else {
|
||||
run_command(remote, config, args.command)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> ExitCode {
|
||||
env_logger::init();
|
||||
let args = Args::parse();
|
||||
|
||||
if let Err(error) = run(args) {
|
||||
match run(args) {
|
||||
Ok(status) => status,
|
||||
Err(error) => {
|
||||
eprintln!("Error: {error}");
|
||||
ExitCode::FAILURE
|
||||
} else {
|
||||
ExitCode::SUCCESS
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,13 @@ pub enum ClientMessage<'a> {
|
||||
RunCommand(&'a str),
|
||||
Bye(&'a str),
|
||||
Input(&'a [u8]),
|
||||
CloseStdin,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum StreamIndex {
|
||||
Stdout,
|
||||
Stderr,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -20,7 +27,7 @@ pub enum ServerMessage<'a> {
|
||||
SessionOpen,
|
||||
CommandStatus(i32),
|
||||
Bye(&'a str),
|
||||
Output(&'a [u8]),
|
||||
Output(StreamIndex, &'a [u8]),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@ -164,6 +171,32 @@ impl ClientMessage<'_> {
|
||||
const TAG_RUN_COMMAND: u8 = 0x81;
|
||||
const TAG_BYE: u8 = 0x82;
|
||||
const TAG_INPUT: u8 = 0x90;
|
||||
const TAG_CLOSE_STDIN: u8 = 0x91;
|
||||
}
|
||||
|
||||
impl StreamIndex {
|
||||
const TAG_STDOUT: u8 = 0x01;
|
||||
const TAG_STDERR: u8 = 0x02;
|
||||
}
|
||||
|
||||
impl Encode for StreamIndex {
|
||||
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
|
||||
match self {
|
||||
Self::Stdout => buffer.write(&[Self::TAG_STDOUT]),
|
||||
Self::Stderr => buffer.write(&[Self::TAG_STDERR]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de> for StreamIndex {
|
||||
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
|
||||
let tag = buffer.read_u8()?;
|
||||
match tag {
|
||||
Self::TAG_STDOUT => Ok(Self::Stdout),
|
||||
Self::TAG_STDERR => Ok(Self::Stderr),
|
||||
_ => Err(DecodeError::InvalidMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for TerminalInfo {
|
||||
@ -201,6 +234,7 @@ impl<'a> Encode for ClientMessage<'a> {
|
||||
buffer.write(&[Self::TAG_INPUT])?;
|
||||
buffer.write_variable_bytes(data)
|
||||
}
|
||||
Self::CloseStdin => buffer.write(&[Self::TAG_CLOSE_STDIN])
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -221,6 +255,9 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
|
||||
Self::TAG_INPUT => {
|
||||
buffer.read_variable_bytes().map(Self::Input)
|
||||
}
|
||||
Self::TAG_CLOSE_STDIN => {
|
||||
Ok(Self::CloseStdin)
|
||||
}
|
||||
_ => Err(DecodeError::InvalidMessage)
|
||||
}
|
||||
}
|
||||
@ -244,8 +281,9 @@ impl<'a> Encode for ServerMessage<'a> {
|
||||
buffer.write(&[Self::TAG_BYE])?;
|
||||
buffer.write_str(reason)
|
||||
}
|
||||
Self::Output(data) => {
|
||||
Self::Output(index, data) => {
|
||||
buffer.write(&[Self::TAG_OUTPUT])?;
|
||||
index.encode(buffer)?;
|
||||
buffer.write_variable_bytes(data)
|
||||
}
|
||||
}
|
||||
@ -267,7 +305,9 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
|
||||
buffer.read_str().map(Self::Bye)
|
||||
}
|
||||
Self::TAG_OUTPUT => {
|
||||
buffer.read_variable_bytes().map(Self::Output)
|
||||
let index = StreamIndex::decode(buffer)?;
|
||||
let data = buffer.read_variable_bytes()?;
|
||||
Ok(Self::Output(index, data))
|
||||
},
|
||||
_ => Err(DecodeError::InvalidMessage),
|
||||
}
|
||||
|
@ -1,12 +1,15 @@
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
fmt, io,
|
||||
fmt,
|
||||
hint::unreachable_unchecked,
|
||||
io::{self, Read, Stdout, Write},
|
||||
net::SocketAddr,
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
process::{Child, Command, Stdio},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use cross::io::Poll;
|
||||
use cross::io::{PidFd, Pipe, Poll};
|
||||
|
||||
use crate::{
|
||||
crypt::{
|
||||
@ -14,7 +17,7 @@ use crate::{
|
||||
config::ServerConfig,
|
||||
server::{self, ServerSocket},
|
||||
},
|
||||
proto::{ClientMessage, ServerMessage, TerminalInfo},
|
||||
proto::{ClientMessage, ServerMessage, StreamIndex, TerminalInfo},
|
||||
};
|
||||
|
||||
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
|
||||
@ -52,22 +55,28 @@ pub struct EchoSession {
|
||||
peer: SocketAddr,
|
||||
}
|
||||
|
||||
struct ClientSession<T: Session> {
|
||||
stream: ClientSocket,
|
||||
session: Option<T>,
|
||||
struct PendingCommand {
|
||||
child: Child,
|
||||
child_pid: Option<PidFd>,
|
||||
stdin: Option<Pipe>,
|
||||
stdout: Option<Pipe>,
|
||||
stderr: Option<Pipe>,
|
||||
}
|
||||
|
||||
// enum Event<'b, T: Session> {
|
||||
// NewSession(SocketAddr, TerminalInfo),
|
||||
// SessionInput(u64, SocketAddr, &'b [u8]),
|
||||
// ClientBye(SocketAddr, &'b str),
|
||||
// SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>),
|
||||
// Tick,
|
||||
// }
|
||||
enum ClientSession<T: Session> {
|
||||
None,
|
||||
Terminal(T),
|
||||
Command(PendingCommand),
|
||||
}
|
||||
|
||||
struct SessionSet<T: Session> {
|
||||
struct Client<T: Session> {
|
||||
stream: ClientSocket,
|
||||
session: ClientSession<T>,
|
||||
}
|
||||
|
||||
struct ClientSet<T: Session> {
|
||||
last_session_key: u64,
|
||||
sessions: HashMap<u64, ClientSession<T>>,
|
||||
clients: HashMap<u64, Client<T>>,
|
||||
socket_fd_map: HashMap<RawFd, u64>,
|
||||
session_fd_map: HashMap<RawFd, u64>,
|
||||
}
|
||||
@ -76,29 +85,63 @@ pub struct Server<T: Session> {
|
||||
poll: Poll,
|
||||
socket: ServerSocket,
|
||||
|
||||
sessions: SessionSet<T>,
|
||||
clients: ClientSet<T>,
|
||||
}
|
||||
|
||||
impl<T: Session> SessionSet<T> {
|
||||
enum CommandEvent {
|
||||
Output(StreamIndex, Result<usize, io::Error>),
|
||||
Exited(Result<i32, io::Error>)
|
||||
}
|
||||
|
||||
impl<T: Session> Client<T> {
|
||||
fn make_command(text: &str) -> Result<PendingCommand, Error> {
|
||||
let mut words = text.split(' ');
|
||||
let program = words.next().unwrap();
|
||||
let mut command = Command::new(program);
|
||||
|
||||
let (stdin_read, stdin_write) = Pipe::new(false, false)?;
|
||||
let (stdout_read, stdout_write) = Pipe::new(true, false)?;
|
||||
let (stderr_read, stderr_write) = Pipe::new(true, false)?;
|
||||
|
||||
command
|
||||
.args(words)
|
||||
.stdin(stdin_read.to_child_stdio())
|
||||
.stdout(stdout_write.to_child_stdio())
|
||||
.stderr(stderr_write.to_child_stdio());
|
||||
|
||||
let child = command.spawn()?;
|
||||
let child_pid = PidFd::new(child.id())?;
|
||||
|
||||
Ok(PendingCommand {
|
||||
child,
|
||||
child_pid: Some(child_pid),
|
||||
stdout: Some(stdout_read),
|
||||
stderr: Some(stderr_read),
|
||||
stdin: Some(stdin_write),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Session> ClientSet<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
last_session_key: 1,
|
||||
sessions: HashMap::new(),
|
||||
clients: HashMap::new(),
|
||||
socket_fd_map: HashMap::new(),
|
||||
session_fd_map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_client(&mut self, stream: ClientSocket, poll: &mut Poll) -> Result<(), Error> {
|
||||
let (key, session) = loop {
|
||||
let (key, client) = loop {
|
||||
let key = self.last_session_key;
|
||||
self.last_session_key += 1;
|
||||
|
||||
match self.sessions.entry(key) {
|
||||
match self.clients.entry(key) {
|
||||
Entry::Vacant(entry) => {
|
||||
let session = entry.insert(ClientSession {
|
||||
let session = entry.insert(Client {
|
||||
stream,
|
||||
session: None,
|
||||
session: ClientSession::None,
|
||||
});
|
||||
break (key, session);
|
||||
}
|
||||
@ -106,23 +149,40 @@ impl<T: Session> SessionSet<T> {
|
||||
}
|
||||
};
|
||||
|
||||
poll.add(&session.stream)?;
|
||||
self.socket_fd_map.insert(session.stream.as_raw_fd(), key);
|
||||
poll.add(&client.stream)?;
|
||||
self.socket_fd_map.insert(client.stream.as_raw_fd(), key);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove(&mut self, key: u64, poll: &mut Poll) -> Result<(), Error> {
|
||||
if let Some(session) = self.sessions.remove(&key) {
|
||||
poll.remove(&session.stream)?;
|
||||
self.socket_fd_map.remove(&session.stream.as_raw_fd());
|
||||
if let Some(client) = self.clients.remove(&key) {
|
||||
poll.remove(&client.stream)?;
|
||||
self.socket_fd_map.remove(&client.stream.as_raw_fd());
|
||||
|
||||
if let Some(session) = session.session {
|
||||
for fd in session.event_fds() {
|
||||
match client.session {
|
||||
ClientSession::Terminal(terminal) => {
|
||||
for fd in terminal.event_fds() {
|
||||
poll.remove(fd)?;
|
||||
self.session_fd_map.remove(fd);
|
||||
}
|
||||
}
|
||||
ClientSession::Command(mut command) => {
|
||||
if let Some(stdout) = command.stdout.take() {
|
||||
poll.remove(&stdout)?;
|
||||
self.session_fd_map.remove(&stdout.as_raw_fd());
|
||||
}
|
||||
if let Some(stderr) = command.stderr.take() {
|
||||
poll.remove(&stderr)?;
|
||||
self.session_fd_map.remove(&stderr.as_raw_fd());
|
||||
}
|
||||
if let Some(child_pid) = command.child_pid.take() {
|
||||
poll.remove(&child_pid)?;
|
||||
self.session_fd_map.remove(&child_pid.as_raw_fd());
|
||||
}
|
||||
}
|
||||
ClientSession::None => (),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -131,10 +191,10 @@ impl<T: Session> SessionSet<T> {
|
||||
let mut buffer = [0; 512];
|
||||
|
||||
if let Some(&key) = self.socket_fd_map.get(&fd) {
|
||||
let session = self.sessions.get_mut(&key).unwrap();
|
||||
let peer = session.stream.remote_address();
|
||||
let client = self.clients.get_mut(&key).unwrap();
|
||||
let peer = client.stream.remote_address();
|
||||
|
||||
let mut closed = match session.stream.poll() {
|
||||
let mut closed = match client.stream.poll() {
|
||||
Ok(0) => true,
|
||||
Ok(_) => false,
|
||||
Err(error) => {
|
||||
@ -144,7 +204,7 @@ impl<T: Session> SessionSet<T> {
|
||||
};
|
||||
|
||||
loop {
|
||||
let message = match session.stream.read(&mut buffer) {
|
||||
let message = match client.stream.read(&mut buffer) {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => break,
|
||||
Err(error) => {
|
||||
@ -154,8 +214,33 @@ impl<T: Session> SessionSet<T> {
|
||||
}
|
||||
};
|
||||
|
||||
match (&message, &mut session.session) {
|
||||
(ClientMessage::OpenSession(terminal), None) => {
|
||||
match (&message, &mut client.session) {
|
||||
(ClientMessage::RunCommand(command), ClientSession::None) => {
|
||||
log::info!("{peer}: cmd {command:?}");
|
||||
|
||||
let command = match Client::<T>::make_command(command) {
|
||||
Ok(command) => command,
|
||||
Err(error) => {
|
||||
log::error!("{peer}: command error: {error}");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
};
|
||||
|
||||
let stdout_fd = command.stdout.as_ref().unwrap().as_raw_fd();
|
||||
let stderr_fd = command.stderr.as_ref().unwrap().as_raw_fd();
|
||||
let child_pid_fd = command.child_pid.as_ref().unwrap().as_raw_fd();
|
||||
|
||||
poll.add(&stdout_fd)?;
|
||||
poll.add(&stderr_fd)?;
|
||||
poll.add(&child_pid_fd)?;
|
||||
|
||||
self.session_fd_map.insert(stdout_fd, key);
|
||||
self.session_fd_map.insert(stderr_fd, key);
|
||||
self.session_fd_map.insert(child_pid_fd, key);
|
||||
|
||||
client.session = ClientSession::Command(command);
|
||||
}
|
||||
(ClientMessage::OpenSession(terminal), ClientSession::None) => {
|
||||
log::info!("{peer}: new session");
|
||||
let terminal = match T::open(&peer, &terminal) {
|
||||
Ok(session) => session,
|
||||
@ -165,15 +250,33 @@ impl<T: Session> SessionSet<T> {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let terminal = session.session.insert(terminal);
|
||||
let terminal = client.session.set_terminal(terminal);
|
||||
for fd in terminal.event_fds() {
|
||||
poll.add(fd)?;
|
||||
self.session_fd_map.insert(*fd, key);
|
||||
}
|
||||
session.stream.write_all(&ServerMessage::SessionOpen).ok();
|
||||
|
||||
client.stream.write_all(&ServerMessage::SessionOpen).ok();
|
||||
}
|
||||
(ClientMessage::Input(data), Some(terminal)) => {
|
||||
let client = SessionClient { transport: &mut session.stream };
|
||||
(ClientMessage::Input(data), ClientSession::Command(command)) => {
|
||||
if let Some(stdin) = command.stdin.as_mut() {
|
||||
match stdin.write_all(data) {
|
||||
Ok(()) => {
|
||||
stdin.flush().ok();
|
||||
},
|
||||
Err(error) => {
|
||||
log::error!("{peer}: stdin error: {error}");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::warn!("{peer}: stdin already closed");
|
||||
}
|
||||
}
|
||||
(ClientMessage::Input(data), ClientSession::Terminal(terminal)) => {
|
||||
let client = SessionClient {
|
||||
transport: &mut client.stream,
|
||||
};
|
||||
match terminal.handle_input(data, client) {
|
||||
Ok(false) => (),
|
||||
Ok(true) => {
|
||||
@ -186,6 +289,10 @@ impl<T: Session> SessionSet<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
(ClientMessage::CloseStdin, ClientSession::Command(command)) => {
|
||||
// Drop stdin handle
|
||||
command.stdin.take();
|
||||
}
|
||||
_ => {
|
||||
log::warn!("{peer}: unhandled message");
|
||||
}
|
||||
@ -199,25 +306,89 @@ impl<T: Session> SessionSet<T> {
|
||||
|
||||
Ok(())
|
||||
} else if let Some(&key) = self.session_fd_map.get(&fd) {
|
||||
let session = self.sessions.get_mut(&key).unwrap();
|
||||
let terminal = session.session.as_mut().unwrap();
|
||||
let peer = session.stream.remote_address();
|
||||
log::debug!("poll fd {:?}", fd);
|
||||
let client = self.clients.get_mut(&key).unwrap();
|
||||
let peer = client.stream.remote_address();
|
||||
|
||||
match terminal.read_output(fd, &mut buffer) {
|
||||
Ok(0) => {
|
||||
log::info!("{peer}: session closed");
|
||||
self.remove(key, poll)
|
||||
let (mut len, stream_index) = match &mut client.session {
|
||||
ClientSession::Command(command) => match command.read_output(fd, &mut buffer) {
|
||||
CommandEvent::Output(index, Ok(len)) => {
|
||||
if len == 0 {
|
||||
poll.remove(&fd)?;
|
||||
self.socket_fd_map.remove(&fd);
|
||||
|
||||
match index {
|
||||
StreamIndex::Stdout => {
|
||||
command.stdout = None;
|
||||
}
|
||||
Ok(mut len) => {
|
||||
StreamIndex::Stderr => {
|
||||
command.stderr = None;
|
||||
}
|
||||
}
|
||||
|
||||
if command.is_dead() {
|
||||
log::info!("{peer}: command stdout/stderr closed");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
}
|
||||
(len, index)
|
||||
},
|
||||
CommandEvent::Output(index, Err(error)) => {
|
||||
if error.kind() == io::ErrorKind::WouldBlock {
|
||||
return Ok(());
|
||||
}
|
||||
log::error!("{peer}: {index:?} error: {error}");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
CommandEvent::Exited(status) => {
|
||||
log::info!("{peer}: command exited: {:?}", status);
|
||||
poll.remove(&fd)?;
|
||||
self.socket_fd_map.remove(&fd);
|
||||
command.child_pid = None;
|
||||
|
||||
if command.is_dead() {
|
||||
log::info!("{peer}: command stdout/stderr closed");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
},
|
||||
ClientSession::Terminal(terminal) => match terminal.read_output(fd, &mut buffer) {
|
||||
Ok(0) => {
|
||||
poll.remove(&fd)?;
|
||||
self.socket_fd_map.remove(&fd);
|
||||
|
||||
// TODO check for process as well
|
||||
log::info!("{peer}: terminal closed");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
Ok(len) => {
|
||||
(len, StreamIndex::Stdout)
|
||||
},
|
||||
Err(error) => {
|
||||
log::error!("{peer}: terminal error: {error}");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
},
|
||||
ClientSession::None => unreachable!(),
|
||||
};
|
||||
|
||||
if len == 0 {
|
||||
log::info!("{peer}: {stream_index:?} closed");
|
||||
return Ok(());
|
||||
// return self.remove(key, poll);
|
||||
}
|
||||
|
||||
// Split output into 128-byte chunks
|
||||
let mut pos = 0;
|
||||
while len != 0 {
|
||||
let amount = core::cmp::min(len, 128);
|
||||
|
||||
if let Err(error) = session
|
||||
.stream
|
||||
.write_all(&ServerMessage::Output(&buffer[pos..pos + amount]))
|
||||
{
|
||||
if let Err(error) = client.stream.write_all(&ServerMessage::Output(
|
||||
stream_index,
|
||||
&buffer[pos..pos + amount],
|
||||
)) {
|
||||
log::error!("{peer}: communication error: {error}");
|
||||
return self.remove(key, poll);
|
||||
}
|
||||
@ -225,19 +396,58 @@ impl<T: Session> SessionSet<T> {
|
||||
pos += amount;
|
||||
len -= amount;
|
||||
}
|
||||
|
||||
log::debug!("Done");
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("{peer}: session read error: {error}");
|
||||
self.remove(key, poll)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PendingCommand {
|
||||
pub fn read_output(
|
||||
&mut self,
|
||||
fd: RawFd,
|
||||
buffer: &mut [u8],
|
||||
) -> CommandEvent {
|
||||
if let Some(stdout) = self.stdout.as_mut() && fd == stdout.as_raw_fd() {
|
||||
log::debug!("poll stdout");
|
||||
let res = stdout.read(buffer);
|
||||
log::debug!(">> {:?}", res);
|
||||
return CommandEvent::Output(StreamIndex::Stdout, res);
|
||||
}
|
||||
if let Some(stderr) = self.stderr.as_mut() && fd == stderr.as_raw_fd() {
|
||||
return CommandEvent::Output(StreamIndex::Stderr, stderr.read(buffer));
|
||||
}
|
||||
if let Some(child_pid) = self.child_pid.as_mut() && fd == child_pid.as_raw_fd() {
|
||||
let status = child_pid.exit_status();
|
||||
return CommandEvent::Exited(status);
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
pub fn is_dead(&self) -> bool {
|
||||
self.child_pid.is_none() && self.stdout.is_none() && self.stderr.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PendingCommand {
|
||||
fn drop(&mut self) {
|
||||
self.child.wait().ok();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Session> ClientSession<T> {
|
||||
pub fn set_terminal(&mut self, terminal: T) -> &mut T {
|
||||
*self = Self::Terminal(terminal);
|
||||
match self {
|
||||
Self::Terminal(terminal) => terminal,
|
||||
_ => unsafe { unreachable_unchecked() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Session> Server<T> {
|
||||
pub fn listen(listen_addr: SocketAddr, crypto_config: ServerConfig) -> Result<Self, Error> {
|
||||
let mut poll = Poll::new()?;
|
||||
@ -246,7 +456,7 @@ impl<T: Session> Server<T> {
|
||||
Ok(Self {
|
||||
poll,
|
||||
socket,
|
||||
sessions: SessionSet::new(),
|
||||
clients: ClientSet::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -259,14 +469,14 @@ impl<T: Session> Server<T> {
|
||||
// Poll server socket
|
||||
match self.socket.poll()? {
|
||||
Some(client) => {
|
||||
self.sessions.add_client(client, &mut self.poll)?;
|
||||
self.clients.add_client(client, &mut self.poll)?;
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Client/Session-related activity
|
||||
self.sessions.handle_input(fd, &mut self.poll)?;
|
||||
self.clients.handle_input(fd, &mut self.poll)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -275,7 +485,7 @@ impl<T: Session> Server<T> {
|
||||
|
||||
impl<'s> SessionClient<'s> {
|
||||
pub fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
|
||||
self.send_message(&ServerMessage::Output(data))
|
||||
self.send_message(&ServerMessage::Output(StreamIndex::Stdout, data))
|
||||
}
|
||||
|
||||
pub fn send_message(&mut self, message: &ServerMessage) -> Result<(), Error> {
|
||||
|
@ -74,7 +74,7 @@ pub fn wait_for_pipeline(
|
||||
}
|
||||
|
||||
pub fn create_pipe() -> Result<Pipe, io::Error> {
|
||||
let (read, write) = pipe::create_pipe_pair()?;
|
||||
let (read, write) = pipe::create_pipe_pair(false, false)?;
|
||||
Ok(Pipe { read, write })
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user