vfs: add pidfd

This commit is contained in:
Mark Poliakov 2024-11-06 10:45:21 +02:00
parent 67cf3673ca
commit b668add453
19 changed files with 277 additions and 53 deletions

View File

@ -202,7 +202,7 @@ impl TcpConnection {
pub fn read_nonblocking(&mut self, buffer: &mut [u8]) -> Result<usize, Error> {
let amount = self.rx_buffer.read(buffer);
if amount == 0 && self.state != TcpConnectionState::Established {
if amount == 0 && self.state != TcpConnectionState::Established && !self.is_closing() {
// TODO ConnectionAborted?
return Err(Error::ConnectionReset);
}

View File

@ -1,7 +1,7 @@
use core::{
future::poll_fn,
sync::atomic::{AtomicBool, Ordering},
task::Poll,
task::{Context, Poll},
};
use crate::{sync::spin_rwlock::IrqSafeRwLock, waker::QueueWaker};
@ -64,18 +64,19 @@ impl<T> OneTimeEvent<T> {
*self.value.read()
}
pub fn poll(&self, cx: &mut Context<'_>) -> Poll<()> {
self.notify.register(cx.waker());
if self.is_signalled() {
self.notify.remove(cx.waker());
Poll::Ready(())
} else {
Poll::Pending
}
}
/// Waits for event to happen without returning anything
pub async fn wait(&self) {
poll_fn(|cx| {
self.notify.register(cx.waker());
if self.is_signalled() {
self.notify.remove(cx.waker());
Poll::Ready(())
} else {
Poll::Pending
}
})
.await
poll_fn(|cx| self.poll(cx)).await
}
/// Waits for event to happen and copies its data as a return value

View File

@ -1,4 +1,7 @@
use core::sync::atomic::{AtomicU32, Ordering};
use core::{
sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll},
};
use abi_lib::SyscallRegister;
use alloc::{
@ -32,7 +35,7 @@ use crate::{
types::{AllocateProcessId, ProcessTlsInfo},
TaskContextImpl, ThreadId,
},
vfs::{FileSet, IoContext, NodeRef},
vfs::{FileReadiness, FileSet, IoContext, NodeRef},
};
pub trait ForkFrame = kernel_arch::task::ForkFrame<KernelTableManagerImpl, GlobalPhysicalAllocator>;
@ -412,6 +415,12 @@ impl Process {
}
}
impl FileReadiness for Process {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.exit.poll(cx).map(Ok)
}
}
impl ProcessInner {
pub fn new(
id: ProcessId,

View File

@ -22,13 +22,16 @@ use yggdrasil_abi::{
},
};
use crate::vfs::{
channel::ChannelDescriptor,
device::{BlockDeviceWrapper, CharDeviceWrapper},
node::NodeRef,
traits::{Read, Seek, Write},
FdPoll, FileReadiness, Node, PseudoTerminalMaster, PseudoTerminalSlave, SharedMemory,
TimerFile,
use crate::{
task::process::Process,
vfs::{
channel::ChannelDescriptor,
device::{BlockDeviceWrapper, CharDeviceWrapper},
node::NodeRef,
traits::{Read, Seek, Write},
FdPoll, FileReadiness, Node, PseudoTerminalMaster, PseudoTerminalSlave, SharedMemory,
TimerFile,
},
};
use self::{
@ -38,7 +41,7 @@ use self::{
regular::RegularFile,
};
use super::{pty, socket::SocketWrapper};
use super::{pid::PidFile, pty, socket::SocketWrapper};
mod device;
mod directory;
@ -75,6 +78,7 @@ pub enum File {
SharedMemory(Arc<SharedMemory>),
PtySlave(TerminalHalfWrapper<PseudoTerminalSlave>),
PtyMaster(TerminalHalfWrapper<PseudoTerminalMaster>),
Pid(PidFile),
}
#[async_trait]
@ -154,6 +158,11 @@ impl File {
Arc::new(Self::Timer(TimerFile::new(repeat, blocking)))
}
/// Creates a new [PidFile]-backed file
pub fn new_pid(process: &Arc<Process>) -> FileRef {
Arc::new(Self::Pid(PidFile::new(process)))
}
/// Constructs a [File] from a [PacketSocket], [ConnectionSocket] or a [ListenerSocket].
pub fn from_socket<S: Into<SocketWrapper>>(socket: S) -> Arc<Self> {
Arc::new(Self::Socket(socket.into()))
@ -237,6 +246,7 @@ impl File {
Self::Block(_) => todo!(),
Self::Regular(file) => Ok(Arc::new(Self::Regular(file.clone()))),
Self::SharedMemory(shm) => Ok(Arc::new(Self::SharedMemory(shm.clone()))),
Self::Pid(pid) => Ok(Arc::new(Self::Pid(pid.clone()))),
Self::PtySlave(half) => Ok(Arc::new(Self::PtySlave(half.clone()))),
Self::PtyMaster(half) => Ok(Arc::new(Self::PtyMaster(half.clone()))),
@ -278,6 +288,7 @@ impl File {
Self::PtySlave(half) => half.half.poll_read(cx),
Self::Socket(socket) => socket.poll_read(cx),
Self::Timer(timer) => timer.poll_read(cx),
Self::Pid(pid) => pid.poll_read(cx),
// Polling not implemented, return ready immediately (XXX ?)
_ => Poll::Ready(Err(Error::NotImplemented)),
}
@ -358,6 +369,7 @@ impl Read for File {
Self::PtySlave(half) => half.read(buf),
Self::PtyMaster(half) => half.read(buf),
Self::Timer(timer) => timer.read(buf),
Self::Pid(pid) => pid.read(buf),
// TODO maybe allow reading FDs from poll channels as if they were regular streams?
Self::Poll(_) => Err(Error::InvalidOperation),
// TODO maybe allow reading messages from Channels?
@ -380,6 +392,8 @@ impl Write for File {
Self::PtySlave(half) => half.write(buf),
Self::PtyMaster(half) => half.write(buf),
Self::Timer(timer) => timer.write(buf),
// TODO allow sending signals via writes to PID FDs?
Self::Pid(_) => Err(Error::InvalidOperation),
// TODO maybe allow adding FDs to poll channels this way
Self::Poll(_) => Err(Error::InvalidOperation),
// TODO maybe allow writing messages to Channels?
@ -440,6 +454,7 @@ impl fmt::Debug for File {
Self::PtySlave(_) => f.debug_struct("PtySlave").finish_non_exhaustive(),
Self::PtyMaster(_) => f.debug_struct("PtyMaster").finish_non_exhaustive(),
Self::Socket(socket) => fmt::Debug::fmt(socket, f),
Self::Pid(pid) => fmt::Debug::fmt(pid, f),
Self::Timer(_) => f.debug_struct("Timer").finish_non_exhaustive(),
}
}

View File

@ -12,6 +12,7 @@ pub(crate) mod file;
pub(crate) mod ioctx;
pub(crate) mod node;
pub(crate) mod path;
pub(crate) mod pid;
pub(crate) mod poll;
pub(crate) mod shared_memory;
pub(crate) mod socket;

View File

@ -0,0 +1,58 @@
use alloc::sync::{Arc, Weak};
use core::{
fmt,
task::{Context, Poll},
};
use libk_util::io::Read;
use yggdrasil_abi::{error::Error, process::ExitCode};
use crate::task::process::Process;
use super::FileReadiness;
#[derive(Clone)]
pub struct PidFile {
process: Weak<Process>,
}
impl PidFile {
pub fn new(process: &Arc<Process>) -> Self {
Self {
process: Arc::downgrade(process),
}
}
}
impl Read for PidFile {
fn read(&self, buf: &mut [u8]) -> Result<usize, Error> {
if buf.len() < size_of::<i32>() {
return Err(Error::BufferTooSmall);
}
let process = self.process.upgrade().ok_or(Error::DoesNotExist)?;
let exit = block!(process.wait_for_exit().await)?;
match exit {
ExitCode::Exited(code) => buf[..size_of::<i32>()].copy_from_slice(&code.to_le_bytes()),
_ => todo!(),
}
Ok(size_of::<i32>())
}
}
impl FileReadiness for PidFile {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.process.upgrade() {
Some(process) => process.poll_read(cx),
None => Poll::Ready(Err(Error::DoesNotExist)),
}
}
}
impl fmt::Debug for PidFile {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.process.upgrade() {
Some(process) => f.debug_struct("PidFile").field("pid", &process.id).finish(),
None => f.debug_struct("PidFile").field("pid", &"<none>").finish(),
}
}
}

View File

@ -7,11 +7,12 @@ use abi::{
MessageDestination, OpenOptions, PollControl, RawFd, ReceivedMessageMetadata, SeekFrom,
SentMessage, TerminalOptions, TerminalSize, TimerOptions,
},
process::ProcessId,
};
use alloc::boxed::Box;
use libk::{
block,
task::thread::Thread,
task::{process::Process, thread::Thread},
vfs::{self, File, MessagePayload, Read, Seek, Write},
};
@ -242,6 +243,18 @@ pub(crate) fn create_timer(options: TimerOptions) -> Result<RawFd, Error> {
})
}
pub(crate) fn create_pid(pid: ProcessId) -> Result<RawFd, Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |mut io| {
let process = Process::get(pid).ok_or(Error::DoesNotExist)?;
let file = File::new_pid(&process);
let fd = io.files.place_file(file, true)?;
Ok(fd)
})
}
pub(crate) fn create_pty(
options: &TerminalOptions,
size: &TerminalSize,

View File

@ -101,6 +101,7 @@ syscall device_request(fd: RawFd, req: &mut DeviceRequest) -> Result<()>;
// Misc I/O
syscall open_channel(name: &str, subscribe: bool) -> Result<RawFd>;
syscall create_timer(options: TimerOptions) -> Result<RawFd>;
syscall create_pid(pid: ProcessId) -> Result<RawFd>;
syscall create_pty(opts: &TerminalOptions, size: &TerminalSize, fds: &mut [MaybeUninit<RawFd>; 2]) -> Result<()>;
syscall create_shared_memory(size: usize) -> Result<RawFd>;
syscall create_poll_channel() -> Result<RawFd>;

View File

@ -4,13 +4,17 @@ use std::{
time::Duration,
};
use crate::sys::{self, Poll as SysPoll, TimerFd as SysTimerFd};
use crate::sys::{self, Poll as SysPoll, TimerFd as SysTimerFd, PidFd as SysPidFd};
#[repr(transparent)]
pub struct Poll(sys::PollImpl);
#[repr(transparent)]
pub struct TimerFd(sys::TimerFdImpl);
#[repr(transparent)]
pub struct PidFd(sys::PidFdImpl);
impl Poll {
pub fn new() -> io::Result<Self> {
sys::PollImpl::new().map(Self)
@ -48,3 +52,19 @@ impl AsRawFd for TimerFd {
self.0.as_raw_fd()
}
}
impl PidFd {
pub fn new(pid: u32) -> io::Result<Self> {
sys::PidFdImpl::new(pid).map(Self)
}
pub fn exit_status(&self) -> io::Result<i32> {
self.0.exit_status()
}
}
impl AsRawFd for PidFd {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}

View File

@ -26,3 +26,8 @@ pub(crate) trait TimerFd: Sized + AsRawFd {
fn start(&mut self, timeout: Duration) -> io::Result<()>;
fn is_expired(&mut self) -> io::Result<bool>;
}
pub(crate) trait PidFd: Sized + AsRawFd {
fn new(pid: u32) -> io::Result<Self>;
fn exit_status(&self) -> io::Result<i32>;
}

View File

@ -1,5 +1,7 @@
pub mod poll;
pub mod timer;
pub mod pid;
pub use poll::PollImpl;
pub use timer::TimerFdImpl;
pub use pid::PidFdImpl;

View File

@ -0,0 +1,35 @@
use std::{io, os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}};
use crate::sys::PidFd;
pub struct PidFdImpl {
fd: OwnedFd,
pid: i32,
}
impl PidFd for PidFdImpl {
fn new(pid: u32) -> io::Result<Self> {
let pid = pid as i32;
let fd = unsafe { libc::pidfd_open(pid) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
let fd = unsafe { OwnedFd::from_raw_fd(fd) };
Ok(Self { fd, pid })
}
fn exit_status(&self) -> io::Result<i32> {
let status = 0;
let res = unsafe { libc::waitpid(self.pid, &mut status, 0) };
if res < 0 {
return Err(io::Error::last_os_error());
}
Ok(status)
}
}
impl AsRawFd for PidFdImpl {
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}

View File

@ -1,5 +1,7 @@
pub mod poll;
pub mod timer;
pub mod pid;
pub use poll::PollImpl;
pub use timer::TimerFdImpl;
pub use pid::PidFdImpl;

View File

@ -0,0 +1,23 @@
use std::{io, os::{fd::{AsRawFd, RawFd}, yggdrasil::io::pid::{PidFd as YggPidFd, ProcessId}}};
use crate::sys::PidFd;
pub struct PidFdImpl(YggPidFd);
impl PidFd for PidFdImpl {
fn new(pid: u32) -> io::Result<Self> {
let pid = unsafe { ProcessId::from_raw(pid) };
YggPidFd::new(pid).map(Self)
}
fn exit_status(&self) -> io::Result<i32> {
self.0.status()
}
}
impl AsRawFd for PidFdImpl {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}

0
userspace/rsh/log Normal file
View File

View File

@ -113,11 +113,15 @@ impl Client {
self.need_bye = false;
break Ok(Some(Event::Disconnected(reason)));
}
ServerMessage::CommandStatus(status) => {
log::info!("Command finished: {status}");
break Ok(None);
}
ServerMessage::Output(data) => {
break Ok(Some(Event::Data(data)));
}
// Ignore this one
ServerMessage::Hello => break Ok(None),
ServerMessage::SessionOpen => break Ok(None),
}
}
fd if fd == self.stdin.as_raw_fd() => {
@ -191,14 +195,14 @@ impl Client {
timeout: Duration,
) -> Result<(), Error> {
let mut buffer = [0; 512];
socket.send(&mut buffer, &ClientMessage::Hello(terminal))?;
socket.send(&mut buffer, &ClientMessage::OpenSession(terminal))?;
if poll.wait(Some(timeout))?.is_none() {
return Err(Error::Timeout);
};
let (message, _) = socket.recv_from(&mut buffer)?;
match message {
ServerMessage::Hello => Ok(()),
ServerMessage::SessionOpen => Ok(()),
ServerMessage::Bye(reason) => Err(Error::Disconnected(reason.into())),
_ => Err(Error::Disconnected("Invalid message received".into())),
}

View File

@ -9,14 +9,16 @@ pub struct ServerMessageProxy;
#[derive(Debug)]
pub enum ClientMessage<'a> {
Hello(TerminalInfo),
OpenSession(TerminalInfo),
RunCommand(&'a str),
Bye(&'a str),
Input(&'a [u8]),
}
#[derive(Debug)]
pub enum ServerMessage<'a> {
Hello,
SessionOpen,
CommandStatus(i32),
Bye(&'a str),
Output(&'a [u8]),
}
@ -158,8 +160,9 @@ impl MessageProxy for ServerMessageProxy {
}
impl ClientMessage<'_> {
const TAG_HELLO: u8 = 0x80;
const TAG_BYE: u8 = 0x81;
const TAG_OPEN_SESSION: u8 = 0x80;
const TAG_RUN_COMMAND: u8 = 0x81;
const TAG_BYE: u8 = 0x82;
const TAG_INPUT: u8 = 0x90;
}
@ -182,10 +185,14 @@ impl<'de> Decode<'de> for TerminalInfo {
impl<'a> Encode for ClientMessage<'a> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::Hello(info) => {
buffer.write(&[Self::TAG_HELLO])?;
Self::OpenSession(info) => {
buffer.write(&[Self::TAG_OPEN_SESSION])?;
info.encode(buffer)
}
Self::RunCommand(command) => {
buffer.write(&[Self::TAG_RUN_COMMAND])?;
buffer.write_str(command)
}
Self::Bye(reason) => {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
@ -202,8 +209,11 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_HELLO => {
TerminalInfo::decode(buffer).map(Self::Hello)
Self::TAG_OPEN_SESSION => {
TerminalInfo::decode(buffer).map(Self::OpenSession)
}
Self::TAG_RUN_COMMAND => {
buffer.read_str().map(Self::RunCommand)
}
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
@ -217,15 +227,19 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
}
impl ServerMessage<'_> {
const TAG_HELLO: u8 = 0x10;
const TAG_BYE: u8 = 0x11;
const TAG_SESSION_OPEN: u8 = 0x10;
const TAG_COMMAND_STATUS: u8 = 0x11;
const TAG_BYE: u8 = 0x12;
const TAG_OUTPUT: u8 = 0x20;
}
impl<'a> Encode for ServerMessage<'a> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::Hello => buffer.write(&[Self::TAG_HELLO]),
Self::SessionOpen => buffer.write(&[Self::TAG_SESSION_OPEN]),
Self::CommandStatus(status) => {
buffer.write(&status.to_le_bytes())
}
Self::Bye(reason) => {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
@ -242,7 +256,13 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_HELLO => Ok(Self::Hello),
Self::TAG_SESSION_OPEN => Ok(Self::SessionOpen),
Self::TAG_COMMAND_STATUS => {
let mut status = [0; size_of::<u32>()];
let bytes = buffer.read_bytes(size_of::<i32>())?;
status.copy_from_slice(bytes);
Ok(Self::CommandStatus(i32::from_le_bytes(status)))
}
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
}

View File

@ -1,15 +1,12 @@
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os, rustc_private))]
#![feature(if_let_guard)]
use std::{
collections::HashSet,
net::SocketAddr,
path::PathBuf,
process::ExitCode,
str::FromStr,
collections::HashSet, net::SocketAddr, path::PathBuf, process::ExitCode, str::FromStr,
time::Duration,
};
use clap::Parser;
use cross::io::PidFd;
use rsh::{
crypt::{server::ServerConfig, SimpleServerKeyStore},
server::Server,
@ -34,9 +31,10 @@ struct Args {
#[cfg(target_os = "yggdrasil")]
pub struct YggdrasilSession {
pty_master: std::fs::File,
fds: [std::os::fd::RawFd; 1],
fds: [std::os::fd::RawFd; 2],
remote: SocketAddr,
shell: std::process::Child,
pidfd: PidFd,
}
#[cfg(target_os = "yggdrasil")]
@ -78,11 +76,13 @@ impl rsh::server::Session for YggdrasilSession {
.gain_terminal(0)
.spawn()?
};
let pidfd = PidFd::new(shell.id())?;
let fds = [pty_master.as_raw_fd()];
let fds = [pty_master.as_raw_fd(), pidfd.as_raw_fd()];
Ok(Self {
pty_master,
pidfd,
shell,
remote,
fds,
@ -114,8 +114,15 @@ impl rsh::server::Session for YggdrasilSession {
buffer: &mut [u8],
) -> Result<usize, Self::Error> {
use std::io::Read;
assert_eq!(fd, self.fds[0]);
self.pty_master.read(buffer)
if fd == self.fds[0] {
self.pty_master.read(buffer)
} else if fd == self.fds[1] {
let status = self.pidfd.exit_status()?;
log::info!("Shell exited with status: {status}");
Ok(0)
} else {
unreachable!()
}
}
fn event_fds(&self) -> &[std::os::fd::RawFd] {

View File

@ -49,7 +49,7 @@ enum SessionEvent<'b, T: Session> {
}
enum Event<'b, T: Session> {
NewClient(SocketAddr, TerminalInfo),
NewSession(SocketAddr, TerminalInfo),
SessionInput(u64, SocketAddr, &'b [u8]),
ClientBye(SocketAddr, &'b str),
SessionEvent(RawFd, SocketAddr, SessionEvent<'b, T>),
@ -120,10 +120,14 @@ impl<T: Session> Server<UdpSocket, T> {
};
let event = match message {
ClientMessage::Hello(terminal)
ClientMessage::OpenSession(terminal)
if self.peer_to_session.get(&remote).is_none() =>
{
Event::NewClient(remote, terminal)
Event::NewSession(remote, terminal)
}
ClientMessage::RunCommand(command) => {
log::info!("TODO: RunCommand");
return Ok(None);
}
ClientMessage::Bye(reason) => Event::ClientBye(remote, reason),
ClientMessage::Input(data)
@ -199,13 +203,17 @@ impl<T: Session> Server<UdpSocket, T> {
log::debug!("Client {remote} disconnected: {reason}");
self.remove_session_by_remote(remote)?;
}
Event::NewClient(remote, terminal) => {
Event::NewSession(remote, terminal) => {
log::debug!("New client: {remote}");
match T::open(&remote, &terminal) {
Ok(session) => {
self.register_session(remote, session)?;
self.socket
.send_message_to(&remote, &mut send_buf, &ServerMessage::Hello)
.send_message_to(
&remote,
&mut send_buf,
&ServerMessage::SessionOpen,
)
.ok();
}
Err(err) => {