yggdrasil/kernel/libk/src/vfs/channel.rs

216 lines
5.7 KiB
Rust

use core::{
pin::Pin,
sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll},
};
use alloc::{
boxed::Box,
collections::{BTreeMap, VecDeque},
string::String,
sync::Arc,
};
use futures_util::{task::AtomicWaker, Future};
use libk_util::sync::{IrqSafeSpinlock, LockMethod};
use yggdrasil_abi::{
error::Error,
io::{ChannelPublisherId, MessageDestination},
};
use crate::{
task::sync::Mutex,
vfs::{FileReadiness, FileRef},
};
/// Describes a channel over which messages can be sent to [Subscription]s
pub struct Channel {
/// Next descriptor ID to hand out: [ChannelDescriptor::open] takes the current value and
/// increments it
last_id: AtomicU32,
/// Subscriptions registered on this channel, keyed by the ID of the descriptor that
/// created them
subscriptions: Mutex<BTreeMap<u32, Arc<Subscription>>>,
}
/// Describes the payload carried by a [Message]
pub enum MessagePayload {
/// Payload contains a file reference
File(FileRef),
/// Payload contains owned byte data
Data(Box<[u8]>),
}
/// Describes a message sent over a channel
pub struct Message {
/// Channel descriptor ID from which the message came
pub source: ChannelPublisherId,
/// Data of the message
pub payload: MessagePayload,
}
/// Describes a single subscription to some [Channel]
pub struct Subscription {
/// Messages delivered to this subscription and not yet received: new messages are pushed to
/// the back, receivers pop from the front
queue: Mutex<VecDeque<Arc<Message>>>,
/// Waker registered by a pending receiver/poller and woken whenever a message is pushed
notify: AtomicWaker,
}
/// Describes a pair of a [Channel] descriptor plus an optional [Subscription]
pub struct ChannelDescriptor {
/// ID allocated for this descriptor when it was opened; used as the source of sent messages
/// and as the key of this descriptor's subscription (if any)
id: u32,
/// Channel this descriptor sends messages to
tx: Arc<Channel>,
/// Subscription used for receiving; `None` if the descriptor was opened without subscribing
rx: Option<Arc<Subscription>>,
}
impl ChannelDescriptor {
    /// Opens a descriptor for the channel called `name`, creating the channel if it does not
    /// exist yet.
    ///
    /// When `subscribe` is `true`, a [Subscription] keyed by the new descriptor's ID is
    /// registered on the channel so that this descriptor can also receive messages.
    pub fn open(name: &str, subscribe: bool) -> ChannelDescriptor {
        let tx = Channel::get_or_create(name.into());
        // NOTE The first one to open the channel is guaranteed to get an ID of 0
        let id = tx.last_id.fetch_add(1, Ordering::SeqCst);
        let rx = subscribe.then(|| tx.subscribe(id));
        Self { tx, rx, id }
    }

    /// Receives the next message from the subscription, waiting for one via the `block!`
    /// macro if the queue is empty.
    ///
    /// # Errors
    ///
    /// Returns [Error::InvalidOperation] if the descriptor was opened without a subscription;
    /// otherwise propagates any error produced while waiting for the message.
    pub fn receive_message(&self) -> Result<Arc<Message>, Error> {
        let rx = self.rx.as_ref().ok_or(Error::InvalidOperation)?;
        rx.receive_message_inner()
    }

    /// Asynchronously receives the next message from the subscription.
    ///
    /// # Errors
    ///
    /// Returns [Error::InvalidOperation] if the descriptor was opened without a subscription;
    /// otherwise propagates any error produced while locking the message queue.
    pub async fn receive_message_async(&self) -> Result<Arc<Message>, Error> {
        let rx = self.rx.as_ref().ok_or(Error::InvalidOperation)?;
        rx.receive_message_async().await
    }

    /// Sends a message with the given `payload` to the subscriptions selected by `dst`.
    ///
    /// * [MessageDestination::Specific]: delivered to the subscription with that ID only;
    ///   if no such subscription exists the message is dropped and `Ok(())` is returned.
    /// * [MessageDestination::AllExceptSelf]: delivered to every subscription except the one
    ///   keyed by this descriptor's own ID.
    /// * [MessageDestination::All]: delivered to every subscription, including this
    ///   descriptor's own one (if it has one).
    ///
    /// The message is shared between receivers through an [Arc], the payload is not copied.
    ///
    /// # Errors
    ///
    /// Propagates failures to lock the subscription table or a subscriber's queue. A broadcast
    /// stops at the first failing subscriber, so earlier subscribers may already have
    /// received the message.
    pub fn send_message(
        &self,
        payload: MessagePayload,
        dst: MessageDestination,
    ) -> Result<(), Error> {
        let message = Arc::new(Message {
            // SAFETY: NOTE(review): the contract of `ChannelPublisherId::from_raw` is not
            // visible from this file; `self.id` is the ID allocated for this descriptor in
            // `open` -- confirm that this is all `from_raw` requires.
            source: unsafe { ChannelPublisherId::from_raw(self.id) },
            payload,
        });

        let subscriptions = self.tx.subscriptions.lock()?;

        match dst {
            MessageDestination::Specific(id) => {
                if let Some(sub) = subscriptions.get(&id) {
                    sub.push_message(message)?;
                }
            }
            MessageDestination::AllExceptSelf => {
                for (&id, sub) in subscriptions.iter() {
                    if id != self.id {
                        sub.push_message(message.clone())?;
                    }
                }
            }
            // Previously `todo!()`, which panicked whenever a sender asked for a full
            // broadcast.
            MessageDestination::All => {
                for sub in subscriptions.values() {
                    sub.push_message(message.clone())?;
                }
            }
        }

        Ok(())
    }
}
impl Channel {
    /// Creates an empty channel: no subscriptions, descriptor IDs start at 0.
    fn new() -> Arc<Channel> {
        let channel = Self {
            last_id: AtomicU32::new(0),
            subscriptions: Mutex::new(BTreeMap::new()),
        };
        Arc::new(channel)
    }

    /// Looks up the channel registered under `name` in the global table, inserting a freshly
    /// created one if there is no such entry yet.
    fn get_or_create(name: String) -> Arc<Channel> {
        let mut registry = CHANNELS.lock();
        let channel = registry.entry(name).or_insert_with(Self::new);
        channel.clone()
    }

    /// Registers a new, empty subscription under `id` and returns a handle to it.
    ///
    /// An existing subscription with the same `id` is replaced in the table.
    ///
    /// # Panics
    ///
    /// Panics if the subscription table lock cannot be acquired.
    fn subscribe(&self, id: u32) -> Arc<Subscription> {
        let subscription = Arc::new(Subscription {
            queue: Mutex::new(VecDeque::new()),
            notify: AtomicWaker::new(),
        });
        self.subscriptions
            .lock()
            .unwrap()
            .insert(id, Arc::clone(&subscription));
        subscription
    }
}
impl Subscription {
    /// Returns a future resolving to the next message queued on this subscription.
    ///
    /// Each poll checks the queue, registers the task's waker and then checks the queue once
    /// more, so a message pushed between the first check and the registration is not missed.
    fn receive_message_async(&self) -> impl Future<Output = Result<Arc<Message>, Error>> + '_ {
        struct Receive<'s> {
            subscription: &'s Subscription,
        }

        impl Future for Receive<'_> {
            type Output = Result<Arc<Message>, Error>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let subscription = self.subscription;

                // Fast path: a message is already waiting. The queue lock is released at the
                // end of this scope, before the waker gets registered.
                {
                    let mut queue = subscription.queue.lock()?;
                    if let Some(message) = queue.pop_front() {
                        return Poll::Ready(Ok(message));
                    }
                }

                subscription.notify.register(cx.waker());

                // Re-check: a sender may have pushed (and called `wake`) before the waker
                // above was in place.
                match subscription.queue.lock()?.pop_front() {
                    Some(message) => Poll::Ready(Ok(message)),
                    None => Poll::Pending,
                }
            }
        }

        Receive { subscription: self }
    }

    /// Synchronous counterpart of [Subscription::receive_message_async], driven through the
    /// `block!` macro.
    ///
    /// The `?` propagates a failure of `block!` itself; the receive result is then returned
    /// as-is.
    // NOTE(review): `block!` is defined outside this file -- presumably it suspends the
    // calling thread until the future completes; confirm.
    fn receive_message_inner(&self) -> Result<Arc<Message>, Error> {
        block!(self.receive_message_async().await)?
    }

    /// Appends `msg` to the queue and wakes the registered waiter, if any.
    fn push_message(&self, msg: Arc<Message>) -> Result<(), Error> {
        {
            let mut queue = self.queue.lock()?;
            queue.push_back(msg);
        }
        // The queue lock is already released here, so a woken receiver can take it at once.
        self.notify.wake();
        Ok(())
    }
}
impl FileReadiness for ChannelDescriptor {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let Some(rx) = self.rx.as_ref() else {
return Poll::Ready(Err(Error::InvalidOperation));
};
if !rx.queue.lock()?.is_empty() {
return Poll::Ready(Ok(()));
}
rx.notify.register(cx.waker());
if !rx.queue.lock()?.is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
/// Global table of channels, keyed by channel name. Entries are inserted on demand by
/// `Channel::get_or_create`.
// NOTE(review): nothing in the visible part of this file removes entries from this table or
// from a channel's subscription map -- confirm whether cleanup happens elsewhere.
static CHANNELS: IrqSafeSpinlock<BTreeMap<String, Arc<Channel>>> =
IrqSafeSpinlock::new(BTreeMap::new());