net: rework socket subsystem

This commit is contained in:
Mark Poliakov 2024-11-04 10:53:51 +02:00
parent 98816e0ebc
commit d1c1360926
11 changed files with 1429 additions and 1187 deletions

1
Cargo.lock generated
View File

@ -2089,6 +2089,7 @@ dependencies = [
name = "ygg_driver_net_core"
version = "0.1.0"
dependencies = [
"async-trait",
"bytemuck",
"kernel-fs",
"libk",

View File

@ -12,6 +12,7 @@ libk.workspace = true
kernel-fs = { path = "../../fs/kernel-fs" }
async-trait.workspace = true
log.workspace = true
bytemuck.workspace = true
serde_json.workspace = true

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,135 @@
use alloc::{collections::BTreeMap, sync::Arc};
use libk::vfs::Socket;
use yggdrasil_abi::{
error::Error,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
};
pub mod udp;
pub use udp::UdpSocket;
pub mod tcp;
pub use tcp::{TcpListener, TcpSocket};
pub mod raw;
pub use raw::RawSocket;
pub struct SocketTable<T: Socket> {
inner: BTreeMap<SocketAddr, Arc<T>>,
}
pub struct TwoWaySocketTable<T> {
inner: BTreeMap<(SocketAddr, SocketAddr), Arc<T>>,
}
impl<T> TwoWaySocketTable<T> {
pub const fn new() -> Self {
Self {
inner: BTreeMap::new(),
}
}
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
&mut self,
local: SocketAddr,
remote: SocketAddr,
with: F,
) -> Result<Arc<T>, Error> {
if self.inner.contains_key(&(local, remote)) {
return Err(Error::AddrInUse);
}
let socket = with()?;
self.inner.insert((local, remote), socket.clone());
Ok(socket)
}
pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
&mut self,
local: IpAddr,
remote: SocketAddr,
mut with: F,
) -> Result<Arc<T>, Error> {
for port in 32768..u16::MAX - 1 {
let local = SocketAddr::new(local, port);
match self.try_insert_with(local, remote, || with(port)) {
Ok(socket) => return Ok(socket),
Err(Error::AddrInUse) => continue,
Err(error) => return Err(error),
}
}
Err(Error::AddrInUse)
}
pub fn remove(&mut self, local: SocketAddr, remote: SocketAddr) -> Result<(), Error> {
match self.inner.remove(&(local, remote)) {
Some(_) => Ok(()),
None => Err(Error::DoesNotExist),
}
}
pub fn get(&self, local: SocketAddr, remote: SocketAddr) -> Option<Arc<T>> {
self.inner.get(&(local, remote)).cloned()
}
}
impl<T: Socket> SocketTable<T> {
pub const fn new() -> Self {
Self {
inner: BTreeMap::new(),
}
}
pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
&mut self,
local: IpAddr,
mut with: F,
) -> Result<Arc<T>, Error> {
for port in 32768..u16::MAX - 1 {
let local = SocketAddr::new(local, port);
match self.try_insert_with(local, || with(port)) {
Ok(socket) => return Ok(socket),
Err(Error::AddrInUse) => continue,
Err(error) => return Err(error),
}
}
Err(Error::AddrInUse)
}
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
&mut self,
address: SocketAddr,
with: F,
) -> Result<Arc<T>, Error> {
if self.inner.contains_key(&address) {
return Err(Error::AddrInUse);
}
let socket = with()?;
self.inner.insert(address, socket.clone());
Ok(socket)
}
pub fn remove(&mut self, local: SocketAddr) -> Result<(), Error> {
match self.inner.remove(&local) {
Some(_) => Ok(()),
None => Err(Error::DoesNotExist),
}
}
pub fn get_exact(&self, local: &SocketAddr) -> Option<Arc<T>> {
self.inner.get(local).cloned()
}
pub fn get(&self, local: &SocketAddr) -> Option<Arc<T>> {
if let Some(socket) = self.inner.get(local) {
return Some(socket.clone());
}
match local {
SocketAddr::V4(_v4) => {
let unspec_v4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port());
self.inner.get(&SocketAddr::V4(unspec_v4)).cloned()
}
SocketAddr::V6(_) => todo!(),
}
}
}

View File

@ -0,0 +1,198 @@
use core::{
fmt,
sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll},
};
use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use async_trait::async_trait;
use libk::{
error::Error,
vfs::{FileReadiness, PacketSocket, Socket},
};
use libk_mm::PageBox;
use libk_util::{
queue::BoundedMpmcQueue,
sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock},
};
use yggdrasil_abi::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketInterfaceQuery, SocketOption};
use crate::{ethernet::L2Packet, interface::NetworkInterface};
pub struct RawSocket {
id: u32,
bound: IrqSafeSpinlock<Option<u32>>,
receive_queue: BoundedMpmcQueue<L2Packet>,
}
static RAW_SOCKET_ID: AtomicU32 = AtomicU32::new(0);
static RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Arc<RawSocket>>> =
IrqSafeRwLock::new(BTreeMap::new());
static BOUND_RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Vec<u32>>> =
IrqSafeRwLock::new(BTreeMap::new());
impl RawSocket {
pub fn bind() -> Result<Arc<Self>, Error> {
let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst);
let socket = Self {
id,
bound: IrqSafeSpinlock::new(None),
receive_queue: BoundedMpmcQueue::new(256),
};
let socket = Arc::new(socket);
RAW_SOCKETS.write().insert(id, socket.clone());
Ok(socket)
}
fn bound_packet_received(&self, packet: L2Packet) {
// TODO do something with the dropped packet?
self.receive_queue.try_push_back(packet).ok();
}
pub fn packet_received(packet: L2Packet) {
let bound_sockets = BOUND_RAW_SOCKETS.read();
let raw_sockets = RAW_SOCKETS.read();
if let Some(ids) = bound_sockets.get(&packet.interface_id) {
for id in ids {
let socket = raw_sockets.get(id).unwrap();
socket.bound_packet_received(packet.clone());
}
}
}
fn packet_to_user(packet: L2Packet, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let full_len = packet.data.len();
let len = full_len - packet.l2_offset;
if buffer.len() < len {
return Err(Error::BufferTooSmall);
}
buffer[..len].copy_from_slice(&packet.data[packet.l2_offset..full_len]);
Ok((len, SocketAddr::NULL_V4))
}
}
impl FileReadiness for RawSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.receive_queue.poll_not_empty(cx).map(Ok)
}
}
impl Socket for RawSocket {
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::BoundHardwareAddress(mac) => {
let bound = self.bound.lock().ok_or(Error::DoesNotExist)?;
let interface = NetworkInterface::get(bound).unwrap();
*mac = interface.mac;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::BindInterface(query) => {
let mut bound = self.bound.lock();
if bound.is_some() {
return Err(Error::AlreadyExists);
}
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let interface = match *query {
SocketInterfaceQuery::ById(id) => NetworkInterface::get(id),
SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name),
}?;
let list = bound_sockets.entry(interface.id).or_default();
bound.replace(interface.id);
list.push(self.id);
Ok(())
}
SocketOption::UnbindInterface => todo!(),
_ => Err(Error::InvalidOperation),
}
}
fn close(&self) -> Result<(), Error> {
let bound = self.bound.lock().take();
if let Some(bound) = bound {
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let mut clear = false;
if let Some(list) = bound_sockets.get_mut(&bound) {
list.retain(|&item| item != self.id);
clear = list.is_empty();
}
if clear {
bound_sockets.remove(&bound);
}
}
RAW_SOCKETS.write().remove(&self.id).unwrap();
Ok(())
}
fn local_address(&self) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
}
fn remote_address(&self) -> Option<SocketAddr> {
None
}
}
#[async_trait]
impl PacketSocket for RawSocket {
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
self.send_nonblocking(destination, data)
}
// TODO currently this is still blocking by NIC send code
fn send_nonblocking(
&self,
_destination: Option<SocketAddr>,
buffer: &[u8],
) -> Result<usize, Error> {
// TODO cap by MTU?
let bound = self.bound.lock().ok_or(Error::InvalidOperation)?;
let interface = NetworkInterface::get(bound)?;
let l2_offset = interface.device.packet_prefix_size();
if buffer.len() > 4096 - l2_offset {
return Err(Error::InvalidArgument);
}
let mut packet = PageBox::new_slice(0, l2_offset + buffer.len())?;
packet[l2_offset..l2_offset + buffer.len()].copy_from_slice(buffer);
interface.device.transmit(packet)?;
Ok(buffer.len())
}
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let packet = self.receive_queue.pop_front().await;
Self::packet_to_user(packet, buffer)
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let packet = self
.receive_queue
.try_pop_front()
.ok_or(Error::WouldBlock)?;
Self::packet_to_user(packet, buffer)
}
}
impl fmt::Debug for RawSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let bound = *self.bound.lock();
f.debug_struct("RawSocket")
.field("interface", &bound)
.finish_non_exhaustive()
}
}

View File

@ -0,0 +1,498 @@
use core::{
fmt,
future::{poll_fn, Future},
sync::atomic::{AtomicU8, Ordering},
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use async_trait::async_trait;
use libk::{
block,
error::Error,
task::runtime::{run_with_timeout, FutureTimeout},
vfs::{ConnectionSocket, FileReadiness, ListenerSocket, Socket},
};
use libk_device::monotonic_timestamp;
use libk_util::{
sync::{
spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard},
IrqSafeSpinlock, IrqSafeSpinlockGuard,
},
waker::QueueWaker,
};
use yggdrasil_abi::net::{SocketAddr, SocketOption};
use crate::{
interface::NetworkInterface,
l3::Route,
l4::tcp::{TcpConnection, TcpConnectionState},
};
use super::{SocketTable, TwoWaySocketTable};
pub struct TcpSocket {
pub(crate) local: SocketAddr,
pub(crate) remote: SocketAddr,
ttl: AtomicU8,
// Listener which accepted the socket
listener: Option<Arc<TcpListener>>,
connection: IrqSafeRwLock<TcpConnection>,
}
pub struct TcpListener {
accept: SocketAddr,
// Currently active sockets
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>,
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>,
accept_notify: QueueWaker,
}
static TCP_SOCKETS: IrqSafeRwLock<TwoWaySocketTable<TcpSocket>> =
IrqSafeRwLock::new(TwoWaySocketTable::new());
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
IrqSafeRwLock::new(SocketTable::new());
impl TcpSocket {
pub async fn connect(
remote: SocketAddr,
timeout: Option<Duration>,
) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
let future = Self::connect_async(remote);
match timeout {
Some(timeout) => run_with_timeout(timeout, future).await.into(),
None => future.await,
}
}
pub fn accept_remote(
listener: Arc<TcpListener>,
local: SocketAddr,
remote: SocketAddr,
remote_window_size: usize,
tx_seq: u32,
rx_seq: u32,
) -> Result<Arc<TcpSocket>, Error> {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with(local, remote, move || {
let connection = TcpConnection::new(
local,
remote,
remote_window_size,
tx_seq,
rx_seq,
TcpConnectionState::SynReceived,
);
log::debug!("Accepted TCP socket {} -> {}", local, remote);
let socket = Self {
local,
remote,
ttl: AtomicU8::new(64),
listener: Some(listener),
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})
}
pub fn connection(&self) -> &IrqSafeRwLock<TcpConnection> {
&self.connection
}
pub(crate) fn accept(self: &Arc<Self>) {
if let Some(listener) = self.listener.as_ref() {
listener.accept_socket(self.clone());
}
}
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<Self>> {
TCP_SOCKETS.read().get(local, remote)
}
pub fn receive_async<'a>(
&'a self,
buffer: &'a mut [u8],
) -> impl Future<Output = Result<usize, Error>> + 'a {
poll_fn(|cx| match self.poll_receive(cx) {
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(buffer)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
})
}
pub async fn send_async(&self, data: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
let mut rem = data.len();
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
self.send_segment_async(&data[pos..pos + amount]).await?;
pos += amount;
rem -= amount;
}
Ok(pos)
}
pub async fn close_async(&self, remove_from_listener: bool) -> Result<(), Error> {
// TODO timeout here
// Already closing
if self.connection.read().is_closing() {
return Ok(());
}
// Wait for all sent data to be acknowledged
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.finish().await?;
}
log::debug!(
"TCP socket closed (FinWait2/Closed): {} <-> {}",
self.local,
self.remote
);
// Wait for connection to get closed
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_finish(cx)
})
.await;
if remove_from_listener {
if let Some(listener) = self.listener.as_ref() {
listener.remove_socket(self.remote);
};
}
Ok(())
}
pub(crate) fn remove_socket(&self) -> Result<(), Error> {
log::debug!(
"TCP socket closed and removed: {} <-> {}",
self.local,
self.remote
);
let connection = self.connection.read();
debug_assert!(connection.is_closed());
TCP_SOCKETS.write().remove(self.local, self.remote)?;
connection.notify_all();
Ok(())
}
fn poll_receive(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
let lock = self.connection.write();
match lock.poll_receive(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
async fn send_segment_async(&self, data: &[u8]) -> Result<(), Error> {
// TODO timeout here
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.transmit(data).await?;
}
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_acknowledge(cx)
})
.await;
Ok(())
}
async fn connect_async(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
// Lookup route to remote
let (interface_id, _, remote_ip) =
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
let remote = SocketAddr::new(remote_ip, remote.port());
let interface = NetworkInterface::get(interface_id)?;
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
let socket = {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with_ephemeral_port(local_ip, remote, |port| {
let t = monotonic_timestamp()?;
let tx_seq = t.as_micros() as u32;
let local = SocketAddr::new(local_ip, port);
let connection =
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
let socket = Self {
local,
remote,
ttl: AtomicU8::new(64),
listener: None,
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})?
};
let mut t = 200;
for _ in 0..5 {
let timeout = Duration::from_millis(t);
log::debug!("Try SYN with timeout={:?}", timeout);
match socket.try_connect(timeout).await {
Ok(()) => return Ok((socket.local, socket)),
Err(Error::TimedOut) => (),
Err(error) => return Err(error),
}
t *= 2;
}
// Couldn't establish
Err(Error::TimedOut)
}
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
{
let mut connection = self.connection.write();
connection.send_syn().await?;
}
let fut = poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_established(cx)
});
match run_with_timeout(timeout, fut).await {
FutureTimeout::Ok(value) => value,
FutureTimeout::Timeout => Err(Error::TimedOut),
}
}
}
impl Socket for TcpSocket {
fn local_address(&self) -> SocketAddr {
self.local
}
fn remote_address(&self) -> Option<SocketAddr> {
Some(self.remote)
}
fn close(&self) -> Result<(), Error> {
block!(self.close_async(true).await)?
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Ttl(ttl) => {
if ttl == 0 || ttl > 255 {
return Err(Error::InvalidArgument);
}
self.ttl.store(ttl as _, Ordering::Release);
Ok(())
}
SocketOption::NoDelay(_) => {
log::warn!("TODO: TCP nodelay");
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ttl(ttl) => {
*ttl = self.ttl.load(Ordering::Acquire) as _;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
*nodelay = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl FileReadiness for TcpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_receive(cx).map_ok(|_| ())
}
}
#[async_trait]
impl ConnectionSocket for TcpSocket {
async fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> {
self.receive_async(buffer).await
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error> {
match self.connection.write().read_nonblocking(buffer) {
// TODO check if this really means "no data at the moment"
Ok(0) => Err(Error::WouldBlock),
Ok(len) => Ok(len),
Err(error) => Err(error),
}
}
async fn send(&self, data: &[u8]) -> Result<usize, Error> {
self.send_async(data).await
}
fn send_nonblocking(&self, data: &[u8]) -> Result<usize, Error> {
log::warn!("TODO: TCP::send_nonblocking");
block!(self.send_async(data).await)?
}
}
impl fmt::Debug for TcpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpSocket")
.field("local", &self.local)
.field("remote", &self.remote)
.finish_non_exhaustive()
}
}
impl TcpListener {
pub fn bind(accept: SocketAddr) -> Result<Arc<Self>, Error> {
TCP_LISTENERS.write().try_insert_with(accept, || {
let listener = TcpListener {
accept,
sockets: IrqSafeRwLock::new(BTreeMap::new()),
pending_accept: IrqSafeSpinlock::new(Vec::new()),
accept_notify: QueueWaker::new(),
};
log::debug!("TCP Listener opened: {}", accept);
Ok(Arc::new(listener))
})
}
pub fn get(local: SocketAddr) -> Option<Arc<Self>> {
TCP_LISTENERS.read().get(&local)
}
pub fn accept_async(&self) -> impl Future<Output = Result<Arc<TcpSocket>, Error>> + '_ {
poll_fn(|cx| match self.poll_accept(cx) {
Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())),
Poll::Pending => Poll::Pending,
})
}
fn accept_socket(&self, socket: Arc<TcpSocket>) {
log::debug!("{}: accept {}", self.accept, socket.remote);
self.sockets.write().insert(socket.remote, socket.clone());
self.pending_accept.lock().push(socket);
self.accept_notify.wake_one();
}
fn remove_socket(&self, remote: SocketAddr) {
log::debug!("Remove client {}", remote);
self.sockets.write().remove(&remote);
}
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpSocket>>>> {
let lock = self.pending_accept.lock();
self.accept_notify.register(cx.waker());
if !lock.is_empty() {
self.accept_notify.remove(cx.waker());
Poll::Ready(lock)
} else {
Poll::Pending
}
}
}
impl Socket for TcpListener {
fn local_address(&self) -> SocketAddr {
self.accept
}
fn remote_address(&self) -> Option<SocketAddr> {
None
}
fn close(&self) -> Result<(), Error> {
// TODO if clients not closed already, send RST?
TCP_LISTENERS.write().remove(self.accept)
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ipv6Only(_v6only) => {
log::warn!("TODO: TCP listener IPv6-only");
Err(Error::InvalidOperation)
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ipv6Only(v6only) => {
*v6only = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl FileReadiness for TcpListener {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_accept(cx).map(|_| Ok(()))
}
}
#[async_trait]
impl ListenerSocket for TcpListener {
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = self.accept_async().await?;
let remote = socket.remote;
Ok((remote, socket))
}
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = self.pending_accept.lock().pop().ok_or(Error::WouldBlock)?;
let remote = socket.remote;
Ok((remote, socket))
}
}
impl fmt::Debug for TcpListener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpListener")
.field("local", &self.accept)
.finish_non_exhaustive()
}
}

View File

@ -0,0 +1,217 @@
use core::{
fmt,
sync::atomic::{AtomicBool, AtomicU8, Ordering},
task::{Context, Poll},
};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use async_trait::async_trait;
use libk::{
block,
error::Error,
vfs::{FileReadiness, PacketSocket, Socket},
};
use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock};
use yggdrasil_abi::net::{SocketAddr, SocketOption};
use crate::l4;
use super::SocketTable;
pub struct UdpSocket {
local: SocketAddr,
remote: Option<SocketAddr>,
broadcast: AtomicBool,
ttl: AtomicU8,
// TODO just place packets here for one less copy?
receive_queue: BoundedMpmcQueue<(SocketAddr, Vec<u8>)>,
}
static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new());
impl UdpSocket {
fn create_socket(local: SocketAddr) -> Arc<UdpSocket> {
log::debug!("UDP socket opened: {}", local);
Arc::new(UdpSocket {
local,
ttl: AtomicU8::new(64),
remote: None,
broadcast: AtomicBool::new(false),
receive_queue: BoundedMpmcQueue::new(128),
})
}
pub fn bind(address: SocketAddr) -> Result<Arc<UdpSocket>, Error> {
let mut sockets = UDP_SOCKETS.write();
if address.port() == 0 {
sockets.try_insert_with_ephemeral_port(address.ip(), |port| {
Ok(Self::create_socket(SocketAddr::new(address.ip(), port)))
})
} else {
sockets.try_insert_with(address, move || Ok(Self::create_socket(address)))
}
}
pub fn connect(&self, _address: SocketAddr) -> Result<(), Error> {
todo!()
}
pub fn get(local: &SocketAddr) -> Option<Arc<UdpSocket>> {
UDP_SOCKETS.read().get(local)
}
pub fn packet_received(&self, source: SocketAddr, data: &[u8]) -> Result<(), Error> {
self.receive_queue
.try_push_back((source, Vec::from(data)))
.map_err(|_| Error::QueueFull)
}
}
impl FileReadiness for UdpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.receive_queue.poll_not_empty(cx).map(Ok)
}
}
#[async_trait]
impl PacketSocket for UdpSocket {
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
let Some(destination) = destination else {
// TODO can still send without setting address if "connected"
return Err(Error::InvalidArgument);
};
// TODO check that destnation family matches self family
match (self.broadcast.load(Ordering::Acquire), destination.ip()) {
// SendTo in broadcast?
(true, _) => todo!(),
(false, _) => {
l4::udp::send(
self.local.port(),
destination.ip(),
destination.port(),
self.ttl.load(Ordering::Acquire),
data,
)
.await?;
}
}
Ok(data.len())
}
fn send_nonblocking(
&self,
destination: Option<SocketAddr>,
buffer: &[u8],
) -> Result<usize, Error> {
log::warn!("TODO: UDP::send_nonblocking()");
block!(self.send_to(destination, buffer).await)?
}
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let (source, packet) = self.receive_queue.pop_front().await;
if packet.len() > buffer.len() {
// TODO check how other OSs handle this
return Err(Error::BufferTooSmall);
}
buffer[..packet.len()].copy_from_slice(&packet);
Ok((packet.len(), source))
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let (source, packet) = self
.receive_queue
.try_pop_front()
.ok_or(Error::WouldBlock)?;
if packet.len() > buffer.len() {
// TODO check how other OSs handle this
return Err(Error::BufferTooSmall);
}
buffer[..packet.len()].copy_from_slice(&packet);
Ok((packet.len(), source))
}
}
impl Socket for UdpSocket {
fn local_address(&self) -> SocketAddr {
self.local
}
fn remote_address(&self) -> Option<SocketAddr> {
self.remote
}
fn close(&self) -> Result<(), Error> {
log::debug!("UDP socket closed: {}", self.local);
UDP_SOCKETS.write().remove(self.local)
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Broadcast(broadcast) => {
log::debug!("{} broadcast: {}", self.local, broadcast);
self.broadcast.store(broadcast, Ordering::Release);
Ok(())
}
&SocketOption::Ttl(ttl) => {
if ttl == 0 || ttl > 255 {
return Err(Error::InvalidArgument);
}
self.ttl.store(ttl as _, Ordering::Release);
Ok(())
}
SocketOption::MulticastTtlV4(_) => {
log::warn!("TODO: UDP multicast v4 timeout");
Err(Error::InvalidOperation)
}
SocketOption::MulticastLoopV4(_) => {
log::warn!("TODO: UDP multicast loop v4");
Err(Error::InvalidOperation)
}
SocketOption::MulticastLoopV6(_) => {
log::warn!("TODO: UDP multicast loop v6");
Err(Error::InvalidOperation)
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Broadcast(broadcast) => {
*broadcast = self.broadcast.load(Ordering::Acquire);
Ok(())
}
SocketOption::Ttl(ttl) => {
*ttl = self.ttl.load(Ordering::Acquire) as _;
Ok(())
}
SocketOption::MulticastTtlV4(ttl) => {
*ttl = 64;
Ok(())
}
SocketOption::MulticastLoopV4(loop_v4) => {
*loop_v4 = false;
Ok(())
}
SocketOption::MulticastLoopV6(loop_v6) => {
*loop_v6 = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl fmt::Debug for UdpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("UdpSocket")
.field("local", &self.local)
.field("remote", &self.remote)
.finish_non_exhaustive()
}
}

View File

@ -20,17 +20,15 @@ use yggdrasil_abi::{
DeviceRequest, DirectoryEntry, OpenOptions, RawFd, SeekFrom, TerminalOptions, TerminalSize,
TimerOptions,
},
net::SocketAddr,
};
use crate::vfs::{
channel::ChannelDescriptor,
device::{BlockDeviceWrapper, CharDeviceWrapper},
node::NodeRef,
socket::{ConnectionSocketWrapper, ListenerSocketWrapper, PacketSocketWrapper},
traits::{Read, Seek, Write},
ConnectionSocket, FdPoll, FileReadiness, ListenerSocket, Node, PacketSocket,
PseudoTerminalMaster, PseudoTerminalSlave, SharedMemory, Socket, TimerFile,
FdPoll, FileReadiness, Node, PseudoTerminalMaster, PseudoTerminalSlave, SharedMemory,
TimerFile,
};
use self::{
@ -40,7 +38,7 @@ use self::{
regular::RegularFile,
};
use super::pty;
use super::{pty, socket::SocketWrapper};
mod device;
mod directory;
@ -69,11 +67,7 @@ pub enum File {
Regular(RegularFile),
Block(BlockFile),
Char(CharFile),
PacketSocket(Arc<PacketSocketWrapper>),
ListenerSocket(Arc<ListenerSocketWrapper>),
StreamSocket(Arc<ConnectionSocketWrapper>),
Socket(SocketWrapper),
AnonymousPipe(PipeEnd),
Poll(FdPoll),
Timer(TimerFile),
@ -160,23 +154,9 @@ impl File {
Arc::new(Self::Timer(TimerFile::new(repeat, blocking)))
}
/// Constructs a [File] from a [PacketSocket]
pub fn from_packet_socket(socket: Arc<dyn PacketSocket>) -> Arc<Self> {
Arc::new(Self::PacketSocket(Arc::new(PacketSocketWrapper(socket))))
}
/// Constructs a [File] from a [ListenerSocket]
pub fn from_listener_socket(socket: Arc<dyn ListenerSocket>) -> Arc<Self> {
Arc::new(Self::ListenerSocket(Arc::new(ListenerSocketWrapper(
socket,
))))
}
/// Constructs a [File] from a [ConnectionSocket]
pub fn from_stream_socket(socket: Arc<dyn ConnectionSocket>) -> Arc<Self> {
Arc::new(Self::StreamSocket(Arc::new(ConnectionSocketWrapper(
socket,
))))
/// Constructs a [File] from a [PacketSocket], [ConnectionSocket] or a [ListenerSocket].
pub fn from_socket<S: Into<SocketWrapper>>(socket: S) -> Arc<Self> {
Arc::new(Self::Socket(socket.into()))
}
pub(crate) fn directory(node: NodeRef, position: DirectoryOpenPosition) -> Arc<Self> {
@ -296,9 +276,7 @@ impl File {
Self::Poll(ch) => ch.poll_read(cx),
Self::PtyMaster(half) => half.half.poll_read(cx),
Self::PtySlave(half) => half.half.poll_read(cx),
Self::PacketSocket(sock) => sock.poll_read(cx),
Self::StreamSocket(sock) => sock.poll_read(cx),
Self::ListenerSocket(sock) => sock.poll_read(cx),
Self::Socket(socket) => socket.poll_read(cx),
Self::Timer(timer) => timer.poll_read(cx),
// Polling not implemented, return ready immediately (XXX ?)
_ => Poll::Ready(Err(Error::NotImplemented)),
@ -335,51 +313,9 @@ impl File {
}
/// Interprets the file as a socket
pub fn as_socket(&self) -> Result<&dyn Socket, Error> {
pub fn as_socket(&self) -> Result<&SocketWrapper, Error> {
match self {
Self::PacketSocket(socket) => Ok(socket.0.as_ref()),
_ => Err(Error::InvalidOperation),
}
}
/// Sends data to a socket
pub fn send_to(&self, buffer: &[u8], recepient: Option<SocketAddr>) -> Result<usize, Error> {
match (self, recepient) {
(Self::PacketSocket(socket), recepient) => socket.send(recepient, buffer),
(Self::StreamSocket(socket), None) => socket.send(buffer),
(_, _) => todo!(),
}
}
/// Receives data from a socket
pub fn receive_from(
&self,
buffer: &mut [u8],
remote: &mut MaybeUninit<SocketAddr>,
) -> Result<usize, Error> {
match self {
Self::PacketSocket(socket) => {
let (addr, len) = socket.receive(buffer)?;
remote.write(addr);
Ok(len)
}
Self::StreamSocket(socket) => {
// Always the same
remote.write(socket.remote_address().unwrap());
socket.receive(buffer)
}
_ => Err(Error::InvalidOperation),
}
}
/// Waits for incoming connection to be accepted by the listener
pub fn accept(&self, remote: &mut MaybeUninit<SocketAddr>) -> Result<FileRef, Error> {
match self {
Self::ListenerSocket(socket) => {
let (address, incoming) = socket.accept()?;
remote.write(address);
Ok(File::from_stream_socket(incoming))
}
Self::Socket(socket) => Ok(socket),
_ => Err(Error::InvalidOperation),
}
}
@ -428,9 +364,7 @@ impl Read for File {
Self::Channel(_) => Err(Error::InvalidOperation),
Self::SharedMemory(_) => Err(Error::InvalidOperation),
// TODO maybe allow reading messages from Packet/Stream sockets?
Self::PacketSocket(_) | Self::ListenerSocket(_) | Self::StreamSocket(_) => {
Err(Error::InvalidOperation)
}
Self::Socket(_) => Err(Error::InvalidOperation),
Self::Directory(_) => Err(Error::IsADirectory),
}
}
@ -452,9 +386,7 @@ impl Write for File {
Self::Channel(_) => Err(Error::InvalidOperation),
Self::SharedMemory(_) => Err(Error::InvalidOperation),
// TODO maybe allow writing messages to Packet/Stream sockets?
Self::PacketSocket(_) | Self::ListenerSocket(_) | Self::StreamSocket(_) => {
Err(Error::InvalidOperation)
}
Self::Socket(_) => Err(Error::InvalidOperation),
Self::Directory(_) => Err(Error::IsADirectory),
}
}
@ -507,20 +439,7 @@ impl fmt::Debug for File {
Self::SharedMemory(_) => f.debug_struct("SharedMemory").finish_non_exhaustive(),
Self::PtySlave(_) => f.debug_struct("PtySlave").finish_non_exhaustive(),
Self::PtyMaster(_) => f.debug_struct("PtyMaster").finish_non_exhaustive(),
Self::PacketSocket(sock) => f
.debug_struct("PacketSocket")
.field("local", &sock.local_address())
.field("remote", &sock.remote_address())
.finish_non_exhaustive(),
Self::StreamSocket(sock) => f
.debug_struct("StreamSocket")
.field("local", &sock.local_address())
.field("remote", &sock.remote_address())
.finish_non_exhaustive(),
Self::ListenerSocket(sock) => f
.debug_struct("ListenerSocket")
.field("local", &sock.local_address())
.finish_non_exhaustive(),
Self::Socket(socket) => fmt::Debug::fmt(socket, f),
Self::Timer(_) => f.debug_struct("Timer").finish_non_exhaustive(),
}
}

View File

@ -30,7 +30,7 @@ pub use node::{
pub use poll::FdPoll;
pub use pty::{PseudoTerminalMaster, PseudoTerminalSlave};
pub use shared_memory::SharedMemory;
pub use socket::{ConnectionSocket, ListenerSocket, PacketSocket, Socket};
pub use socket::{ConnectionSocket, ListenerSocket, PacketSocket, Socket, SocketWrapper};
pub use terminal::{Terminal, TerminalInput, TerminalOutput};
pub use timer::TimerFile;
pub use traits::{FileReadiness, Read, Seek, Write};

View File

@ -1,16 +1,43 @@
use core::ops::Deref;
use core::{
fmt,
future::Future,
task::{Context, Poll},
time::Duration,
};
use alloc::sync::Arc;
use alloc::{boxed::Box, sync::Arc};
use async_trait::async_trait;
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::{
error::Error,
net::{SocketAddr, SocketOption},
};
use crate::vfs::FileReadiness;
use crate::{
task::runtime::{run_with_timeout, FutureTimeout},
vfs::FileReadiness,
};
enum SocketInner {
Connection(Arc<dyn ConnectionSocket + Send + 'static>),
Listener(Arc<dyn ListenerSocket + Send + 'static>),
Packet(Arc<dyn PacketSocket + Send + 'static>),
}
struct InnerOptions {
recv_timeout: Option<Duration>,
send_timeout: Option<Duration>,
non_blocking: bool,
}
pub struct SocketWrapper {
inner: SocketInner,
options: IrqSafeRwLock<InnerOptions>,
}
/// Interface for interacting with network sockets
#[allow(unused)]
pub trait Socket: FileReadiness + Send {
pub trait Socket: FileReadiness + fmt::Debug + Send {
/// Socket listen/receive address
fn local_address(&self) -> SocketAddr;
@ -32,71 +59,342 @@ pub trait Socket: FileReadiness + Send {
}
/// Stateless/packet-based socket interface
#[async_trait]
pub trait PacketSocket: Socket {
/// Receives a packet into provided buffer. Will return an error if packet cannot be placed
/// within the buffer.
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error>;
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>;
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>;
/// Sends provided data to the recepient specified by `destination`
fn send(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error>;
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error>;
fn send_nonblocking(
&self,
destination: Option<SocketAddr>,
buffer: &[u8],
) -> Result<usize, Error>;
}
/// Connection-based client socket interface
#[async_trait]
pub trait ConnectionSocket: Socket {
/// Receives data into provided buffer
fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error>;
async fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error>;
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error>;
/// Transmits data
fn send(&self, data: &[u8]) -> Result<usize, Error>;
async fn send(&self, data: &[u8]) -> Result<usize, Error>;
fn send_nonblocking(&self, buffer: &[u8]) -> Result<usize, Error>;
}
/// Connection-based listener socket interface
#[async_trait]
pub trait ListenerSocket: Socket {
/// Blocks the execution until an incoming connection is accepted
fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>;
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>;
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>;
}
pub struct PacketSocketWrapper(pub Arc<dyn PacketSocket + 'static>);
pub struct ListenerSocketWrapper(pub Arc<dyn ListenerSocket + 'static>);
pub struct ConnectionSocketWrapper(pub Arc<dyn ConnectionSocket + 'static>);
impl SocketWrapper {
pub fn from_connection(socket: Arc<dyn ConnectionSocket + 'static>) -> Self {
Self {
inner: SocketInner::Connection(socket),
options: IrqSafeRwLock::new(InnerOptions::default()),
}
}
impl Deref for PacketSocketWrapper {
type Target = dyn PacketSocket;
pub fn from_packet(socket: Arc<dyn PacketSocket + 'static>) -> Self {
Self {
inner: SocketInner::Packet(socket),
options: IrqSafeRwLock::new(InnerOptions::default()),
}
}
fn deref(&self) -> &Self::Target {
self.0.as_ref()
pub fn from_listener(socket: Arc<dyn ListenerSocket + 'static>) -> Self {
Self {
inner: SocketInner::Listener(socket),
options: IrqSafeRwLock::new(InnerOptions::default()),
}
}
pub fn accept(&self) -> Result<(SocketWrapper, SocketAddr), Error> {
let SocketInner::Listener(socket) = &self.inner else {
return Err(Error::InvalidOperation);
};
let options = self.options.read();
let (remote, remote_socket) = match (options.non_blocking, options.recv_timeout) {
(false, timeout) => {
let fut = socket.accept();
block!(maybe_timeout(timeout, fut).await)??
}
(true, _) => socket.accept_nonblocking()?,
};
let remote_socket = Self::from_connection(remote_socket);
Ok((remote_socket, remote))
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.receive_nonblocking(buffer),
SocketInner::Connection(socket) => {
let remote = socket.remote_address().ok_or(Error::NotConnected)?;
let len = socket.receive_nonblocking(buffer)?;
Ok((len, remote))
}
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
match &self.inner {
SocketInner::Packet(socket) => {
maybe_timeout(timeout, socket.receive_from(buffer)).await
}
SocketInner::Connection(socket) => {
let remote = socket.remote_address().ok_or(Error::NotConnected)?;
let len = maybe_timeout(timeout, socket.receive(buffer)).await?;
Ok((len, remote))
}
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
fn send_nonblocking(&self, buffer: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.send_nonblocking(remote, buffer),
SocketInner::Connection(socket) => socket.send_nonblocking(buffer),
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
async fn send(
&self,
buffer: &[u8],
remote: Option<SocketAddr>,
timeout: Option<Duration>,
) -> Result<usize, Error> {
match &self.inner {
SocketInner::Packet(socket) => {
maybe_timeout(timeout, socket.send_to(remote, buffer)).await
}
SocketInner::Connection(socket) => maybe_timeout(timeout, socket.send(buffer)).await,
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
pub fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let options = self.options.read();
match (options.non_blocking, options.recv_timeout) {
(false, timeout) => block!(self.receive(buffer, timeout).await)?,
(true, _) => self.receive_nonblocking(buffer),
}
}
pub fn send_to(&self, buffer: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
let options = self.options.read();
match (options.non_blocking, options.recv_timeout) {
(false, timeout) => block!(self.send(buffer, remote, timeout).await)?,
(true, _) => self.send_nonblocking(buffer, remote),
}
}
}
impl Drop for PacketSocketWrapper {
impl Socket for SocketWrapper {
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::RecvTimeout(timeout) => {
self.options.write().recv_timeout = *timeout;
return Ok(());
}
SocketOption::SendTimeout(timeout) => {
self.options.write().send_timeout = *timeout;
return Ok(());
}
SocketOption::NonBlocking(nb) => {
self.options.write().non_blocking = *nb;
return Ok(());
}
_ => (),
}
match &self.inner {
SocketInner::Packet(socket) => socket.set_option(option),
SocketInner::Listener(socket) => socket.set_option(option),
SocketInner::Connection(socket) => socket.set_option(option),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::RecvTimeout(timeout) => {
*timeout = self.options.read().recv_timeout;
return Ok(());
}
SocketOption::SendTimeout(timeout) => {
*timeout = self.options.read().send_timeout;
return Ok(());
}
SocketOption::NonBlocking(nb) => {
*nb = self.options.read().non_blocking;
return Ok(());
}
_ => (),
}
match &self.inner {
SocketInner::Packet(socket) => socket.get_option(option),
SocketInner::Listener(socket) => socket.get_option(option),
SocketInner::Connection(socket) => socket.get_option(option),
}
}
fn local_address(&self) -> SocketAddr {
match &self.inner {
SocketInner::Packet(socket) => socket.local_address(),
SocketInner::Listener(socket) => socket.local_address(),
SocketInner::Connection(socket) => socket.local_address(),
}
}
fn remote_address(&self) -> Option<SocketAddr> {
match &self.inner {
SocketInner::Packet(socket) => socket.remote_address(),
SocketInner::Listener(socket) => socket.remote_address(),
SocketInner::Connection(socket) => socket.remote_address(),
}
}
fn close(&self) -> Result<(), Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.close(),
SocketInner::Listener(socket) => socket.close(),
SocketInner::Connection(socket) => socket.close(),
}
}
}
impl FileReadiness for SocketWrapper {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match &self.inner {
SocketInner::Packet(socket) => socket.poll_read(cx),
SocketInner::Listener(socket) => socket.poll_read(cx),
SocketInner::Connection(socket) => socket.poll_read(cx),
}
}
}
impl fmt::Debug for SocketWrapper {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
SocketInner::Packet(socket) => socket.fmt(f),
SocketInner::Connection(socket) => socket.fmt(f),
SocketInner::Listener(socket) => socket.fmt(f),
}
}
}
impl Drop for SocketInner {
fn drop(&mut self) {
self.0.close().ok();
let res = match self {
Self::Packet(socket) => socket.close(),
Self::Connection(socket) => socket.close(),
Self::Listener(socket) => socket.close(),
};
if let Err(error) = res {
log::warn!("Socket close error: {error:?}");
}
}
}
impl Deref for ListenerSocketWrapper {
type Target = dyn ListenerSocket;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
impl Default for InnerOptions {
fn default() -> Self {
Self {
recv_timeout: None,
send_timeout: None,
non_blocking: false,
}
}
}
impl Drop for ListenerSocketWrapper {
fn drop(&mut self) {
self.0.close().ok();
async fn maybe_timeout<R, F: Future<Output = Result<R, Error>> + Send>(
timeout: Option<Duration>,
fut: F,
) -> Result<R, Error> {
if let Some(timeout) = timeout {
match run_with_timeout(timeout, fut).await {
FutureTimeout::Ok(value) => value,
FutureTimeout::Timeout => Err(Error::TimedOut),
}
} else {
fut.await
}
}
impl Deref for ConnectionSocketWrapper {
type Target = dyn ConnectionSocket;
// impl From<Arc<dyn ConnectionSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn ConnectionSocket + 'static>) -> Self {
// Self::Connection(value)
// }
// }
//
// impl From<Arc<dyn ListenerSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn ListenerSocket + 'static>) -> Self {
// Self::Listener(value)
// }
// }
//
// impl From<Arc<dyn PacketSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn PacketSocket + 'static>) -> Self {
// Self::Packet(value)
// }
// }
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl Drop for ConnectionSocketWrapper {
fn drop(&mut self) {
self.0.close().ok();
}
}
// impl Deref for PacketSocketWrapper {
// type Target = dyn PacketSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for PacketSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }
//
// impl Deref for ListenerSocketWrapper {
// type Target = dyn ListenerSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for ListenerSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }
//
// impl Deref for ConnectionSocketWrapper {
// type Target = dyn ConnectionSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for ConnectionSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }

View File

@ -5,7 +5,10 @@ use abi::{
io::RawFd,
net::{SocketConnect, SocketOption, SocketType},
};
use libk::{task::thread::Thread, vfs::File};
use libk::{
task::thread::Thread,
vfs::{File, Socket, SocketWrapper},
};
use ygg_driver_net_core::socket::{RawSocket, TcpListener, TcpSocket, UdpSocket};
use crate::syscall::run_with_io;
@ -20,16 +23,16 @@ pub(crate) fn connect_socket(
run_with_io(&process, |mut io| {
let (local, fd) = match connect {
&mut SocketConnect::Udp(_fd, _remote) => {
todo!("UDP socket connect");
}
&mut SocketConnect::Tcp(remote, timeout) => {
let (local, socket) = TcpSocket::connect(remote.into(), timeout)?;
let fd = io
.files
.place_file(File::from_stream_socket(socket), true)?;
SocketConnect::Tcp(remote, timeout) => {
let remote = (*remote).into();
let (local, socket) = libk::block!(TcpSocket::connect(remote, *timeout).await)??;
let file = File::from_socket(SocketWrapper::from_connection(socket));
let fd = io.files.place_file(file, true)?;
(local.into(), fd)
}
SocketConnect::Udp(_socket, _remote) => {
todo!("UDP socket connect")
}
};
local_result.write(local);
Ok(fd)
@ -41,13 +44,13 @@ pub(crate) fn bind_socket(listen: &SocketAddr, ty: SocketType) -> Result<RawFd,
let process = thread.process();
run_with_io(&process, |mut io| {
let file = match ty {
SocketType::UdpPacket => File::from_packet_socket(UdpSocket::bind((*listen).into())?),
SocketType::RawPacket => File::from_packet_socket(RawSocket::bind()?),
SocketType::TcpStream => {
File::from_listener_socket(TcpListener::bind((*listen).into())?)
}
let listen = (*listen).into();
let socket = match ty {
SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::bind()?),
SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::bind(listen)?),
SocketType::TcpStream => SocketWrapper::from_listener(TcpListener::bind(listen)?),
};
let file = File::from_socket(socket);
let fd = io.files.place_file(file, true)?;
Ok(fd)
})
@ -62,13 +65,11 @@ pub(crate) fn accept(
run_with_io(&process, |mut io| {
let file = io.files.file(socket_fd)?;
let mut remote = MaybeUninit::uninit();
let accepted_file = file.accept(&mut remote)?;
let accepted_fd = io.files.place_file(accepted_file, true)?;
unsafe {
remote_result.write(remote.assume_init().into());
}
Ok(accepted_fd)
let listener = file.as_socket()?;
let (socket, remote) = listener.accept()?;
remote_result.write(remote.into());
let fd = io.files.place_file(File::from_socket(socket), true)?;
Ok(fd)
})
}
@ -82,8 +83,9 @@ pub(crate) fn send_to(
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
file.send_to(buffer, recepient.map(Into::into))
let socket = file.as_socket()?;
let remote = recepient.map(Into::into);
socket.send_to(buffer, remote)
})
}
@ -97,9 +99,9 @@ pub(crate) fn receive_from(
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let mut remote = MaybeUninit::uninit();
let len = file.receive_from(buffer, &mut remote)?;
remote_result.write(unsafe { remote.assume_init() }.into());
let socket = file.as_socket()?;
let (len, remote) = socket.receive_from(buffer)?;
remote_result.write(remote.into());
Ok(len)
})
}