net: move to berkeley-style sockets

This commit is contained in:
Mark Poliakov 2025-01-07 15:17:17 +02:00
parent f1256e262b
commit a4e441d236
46 changed files with 2267 additions and 1317 deletions

View File

@ -80,7 +80,10 @@ impl ArpTable {
pub fn lookup_cache(interface: u32, address: IpAddr) -> Option<MacAddress> {
let (address, _) = match address {
IpAddr::V4(address) => Self::lookup_cache_v4(interface, address),
IpAddr::V6(_) => todo!(),
IpAddr::V6(v6) => {
log::warn!("TODO: ArpTable v6 lookup: {v6}");
return None;
}
}?;
Some(address)
}
@ -92,7 +95,10 @@ impl ArpTable {
pub fn flush_address(interface: u32, address: IpAddr) -> bool {
match address {
IpAddr::V4(address) => Self::flush_address_v4(interface, address),
IpAddr::V6(_) => todo!(),
IpAddr::V6(v6) => {
log::warn!("TODO: ArpTable v6 flush: {v6}");
false
}
}
}
@ -103,7 +109,10 @@ impl ArpTable {
pub fn insert_address(interface: u32, mac: MacAddress, address: IpAddr, owned: bool) {
match address {
IpAddr::V4(address) => Self::insert_address_v4(interface, mac, address, owned),
IpAddr::V6(_) => todo!(),
IpAddr::V6(v6) => {
log::warn!("TODO: ArpTable v6 insert: {v6}");
return;
}
}
ARP_TABLE.notify.wake_all();
}
@ -203,7 +212,10 @@ fn send_request(interface: &NetworkInterface, query_address: IpAddr) -> Result<(
match query_address {
IpAddr::V4(address) => send_request_v4(interface, address),
IpAddr::V6(_) => todo!(),
IpAddr::V6(v6) => {
log::warn!("TODO: ARP IPv6 query: {v6}");
Err(Error::NotImplemented)
}
}
}

View File

@ -24,7 +24,7 @@ async fn send_v4_reply(
};
if icmp_data.len() % 2 != 0 {
todo!();
return Err(Error::InvalidArgument);
}
let l4_bytes = bytemuck::bytes_of(&reply_frame);
@ -77,6 +77,9 @@ async fn handle_v4(source_address: Ipv4Addr, l3_packet: L3Packet) -> Result<(),
pub async fn handle(l3_packet: L3Packet) -> Result<(), Error> {
match l3_packet.source_address {
IpAddr::V4(v4) => handle_v4(v4, l3_packet).await,
IpAddr::V6(_) => todo!(),
IpAddr::V6(v6) => {
log::warn!("TODO: ICMPv6 from {v6}");
Err(Error::NotImplemented)
}
}
}

View File

@ -17,7 +17,7 @@ use yggdrasil_abi::{
use crate::{
l3::{self, L3Packet},
socket::{TcpListener, TcpSocket},
socket::tcp::{TcpListener, TcpStream},
util::Assembler,
};
@ -652,12 +652,15 @@ pub async fn handle(packet: L3Packet) -> Result<(), Error> {
match tcp_frame.flags {
TcpFlags::SYN => {
if let Some(listener) = TcpListener::get(local) {
if let Some(listener) = TcpListener::get(&local)
&& listener.is_listening()
{
log::debug!("tcp: create remote stream {remote} -> {local}");
let window_size = u16::from_network_order(tcp_frame.window_size);
let tx_seq = 12345;
// Create a socket and insert it into the table
TcpSocket::accept_remote(
TcpStream::accept_remote(
listener.clone(),
local,
remote,
@ -703,16 +706,17 @@ pub async fn handle(packet: L3Packet) -> Result<(), Error> {
seq,
};
let socket = TcpSocket::get(local, remote).ok_or(Error::DoesNotExist)?;
let mut connection = socket.connection().write();
let stream = TcpStream::get(local, remote).ok_or(Error::DoesNotExist)?;
let mut connection = stream.connection.write();
match connection.handle_packet(packet, tcp_data).await? {
TcpSocketBehavior::None => (),
TcpSocketBehavior::Accept => {
socket.accept();
stream.accept();
}
TcpSocketBehavior::Remove => {
drop(connection);
socket.remove_socket()?;
stream.remove_stream()?;
}
}
Ok(())

View File

@ -1,4 +1,4 @@
#![feature(map_try_insert)]
#![feature(map_try_insert, let_chains, result_flattening)]
#![allow(clippy::type_complexity, clippy::new_without_default)]
#![no_std]

View File

@ -1,18 +1,20 @@
use alloc::{collections::BTreeMap, sync::Arc};
use libk::vfs::Socket;
use alloc::{
collections::{btree_map::Entry, BTreeMap},
sync::Arc,
};
use yggdrasil_abi::{
error::Error,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
};
pub mod udp;
pub use udp::UdpSocket;
pub mod tcp;
pub use tcp::{TcpListener, TcpSocket};
pub use tcp::TcpSocket;
pub mod raw;
pub use raw::RawSocket;
pub struct SocketTable<T: Socket> {
pub struct SocketTable<T> {
inner: BTreeMap<SocketAddr, Arc<T>>,
}
@ -71,7 +73,7 @@ impl<T> TwoWaySocketTable<T> {
}
}
impl<T: Socket> SocketTable<T> {
impl<T> SocketTable<T> {
pub const fn new() -> Self {
Self {
inner: BTreeMap::new(),
@ -95,6 +97,29 @@ impl<T: Socket> SocketTable<T> {
Err(Error::AddrInUse)
}
pub fn bind_to_ephemeral_port(&mut self, ip: IpAddr, socket: Arc<T>) -> Result<u16, Error> {
for port in 32768..u16::MAX - 1 {
let local = SocketAddr::new(ip, port);
match self.try_insert(local, socket.clone()) {
Ok(()) => return Ok(port),
Err(Error::AddrInUse) => continue,
Err(error) => return Err(error),
}
}
Err(Error::AddrInUse)
}
pub fn try_insert(&mut self, address: SocketAddr, socket: Arc<T>) -> Result<(), Error> {
match self.inner.entry(address) {
Entry::Vacant(entry) => {
entry.insert(socket);
Ok(())
}
Entry::Occupied(_) => Err(Error::AddrInUse),
}
}
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
&mut self,
address: SocketAddr,
@ -124,12 +149,11 @@ impl<T: Socket> SocketTable<T> {
return Some(socket.clone());
}
match local {
SocketAddr::V4(_v4) => {
let unspec_v4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port());
self.inner.get(&SocketAddr::V4(unspec_v4)).cloned()
}
SocketAddr::V6(_) => todo!(),
}
let unspec = match local {
SocketAddr::V4(_v4) => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port()).into(),
SocketAddr::V6(_v6) => SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, local.port()).into(),
};
self.inner.get(&unspec).cloned()
}
}

View File

@ -2,26 +2,25 @@ use core::{
fmt,
sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use async_trait::async_trait;
use libk::{
error::Error,
task::runtime::maybe_timeout,
vfs::{FileReadiness, PacketSocket, Socket},
};
use libk_mm::PageBox;
use libk_util::{
queue::BoundedMpmcQueue,
sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock},
};
use yggdrasil_abi::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketInterfaceQuery, SocketOption};
use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock};
use yggdrasil_abi::net::{SocketAddr, SocketInterfaceQuery, SocketOption};
use crate::{ethernet::L2Packet, interface::NetworkInterface};
pub struct RawSocket {
id: u32,
bound: IrqSafeSpinlock<Option<u32>>,
bound: IrqSafeRwLock<Option<u32>>,
receive_queue: BoundedMpmcQueue<L2Packet>,
}
@ -32,17 +31,15 @@ static BOUND_RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Vec<u32>>> =
IrqSafeRwLock::new(BTreeMap::new());
impl RawSocket {
pub fn bind() -> Result<Arc<Self>, Error> {
pub fn new() -> Arc<Self> {
let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst);
let socket = Self {
let socket = Arc::new(Self {
id,
bound: IrqSafeSpinlock::new(None),
bound: IrqSafeRwLock::new(None),
receive_queue: BoundedMpmcQueue::new(256),
};
let socket = Arc::new(socket);
});
RAW_SOCKETS.write().insert(id, socket.clone());
Ok(socket)
socket
}
fn bound_packet_received(&self, packet: L2Packet) {
@ -57,6 +54,7 @@ impl RawSocket {
if let Some(ids) = bound_sockets.get(&packet.interface_id) {
for id in ids {
let socket = raw_sockets.get(id).unwrap();
log::info!("Packet -> {id}");
socket.bound_packet_received(packet.clone());
}
}
@ -80,10 +78,14 @@ impl FileReadiness for RawSocket {
}
impl Socket for RawSocket {
fn bind(self: Arc<Self>, _local: SocketAddr) -> Result<(), Error> {
Err(Error::InvalidOperation)
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::BoundHardwareAddress(mac) => {
let bound = self.bound.lock().ok_or(Error::DoesNotExist)?;
let bound = self.bound.read().ok_or(Error::NotConnected)?;
let interface = NetworkInterface::get(bound).unwrap();
*mac = interface.mac;
Ok(())
@ -95,18 +97,16 @@ impl Socket for RawSocket {
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::BindInterface(query) => {
let mut bound = self.bound.lock();
let mut bound = self.bound.write();
if bound.is_some() {
return Err(Error::AlreadyExists);
}
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let interface = match *query {
SocketInterfaceQuery::ById(id) => NetworkInterface::get(id),
SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name),
}?;
let list = bound_sockets.entry(interface.id).or_default();
bound.replace(interface.id);
list.push(self.id);
@ -118,8 +118,8 @@ impl Socket for RawSocket {
}
}
fn close(&self) -> Result<(), Error> {
let bound = self.bound.lock().take();
fn close(self: Arc<Self>) -> Result<(), Error> {
let bound = self.bound.write().take();
if let Some(bound) = bound {
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
@ -140,8 +140,8 @@ impl Socket for RawSocket {
Ok(())
}
fn local_address(&self) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
fn local_address(&self) -> Option<SocketAddr> {
None
}
fn remote_address(&self) -> Option<SocketAddr> {
@ -151,18 +151,27 @@ impl Socket for RawSocket {
#[async_trait]
impl PacketSocket for RawSocket {
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
fn connect(self: Arc<Self>, _remote: SocketAddr) -> Result<(), Error> {
Err(Error::InvalidOperation)
}
async fn send_to(
self: Arc<Self>,
destination: Option<SocketAddr>,
data: &[u8],
_timeout: Option<Duration>,
) -> Result<usize, Error> {
self.send_nonblocking(destination, data)
}
// TODO currently this is still blocking by NIC send code
fn send_nonblocking(
&self,
self: Arc<Self>,
_destination: Option<SocketAddr>,
buffer: &[u8],
) -> Result<usize, Error> {
// TODO cap by MTU?
let bound = self.bound.lock().ok_or(Error::InvalidOperation)?;
let bound = self.bound.read().ok_or(Error::InvalidOperation)?;
let interface = NetworkInterface::get(bound)?;
let l2_offset = interface.device.packet_prefix_size();
if buffer.len() > 4096 - l2_offset {
@ -174,12 +183,20 @@ impl PacketSocket for RawSocket {
Ok(buffer.len())
}
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let packet = self.receive_queue.pop_front().await;
async fn receive_from(
self: Arc<Self>,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
let future = self.receive_queue.pop_front();
let packet = maybe_timeout(future, timeout).await?;
Self::packet_to_user(packet, buffer)
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
fn receive_nonblocking(
self: Arc<Self>,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr), Error> {
let packet = self
.receive_queue
.try_pop_front()
@ -190,7 +207,7 @@ impl PacketSocket for RawSocket {
impl fmt::Debug for RawSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let bound = *self.bound.lock();
let bound = *self.bound.read();
f.debug_struct("RawSocket")
.field("interface", &bound)
.finish_non_exhaustive()

View File

@ -1,495 +0,0 @@
use core::{
fmt,
future::{poll_fn, Future},
sync::atomic::{AtomicU8, Ordering},
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use async_trait::async_trait;
use libk::{
block,
error::Error,
task::runtime::with_timeout,
time::monotonic_time,
vfs::{ConnectionSocket, FileReadiness, ListenerSocket, Socket},
};
use libk_util::{
sync::{
spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard},
IrqSafeSpinlock, IrqSafeSpinlockGuard,
},
waker::QueueWaker,
};
use yggdrasil_abi::net::{SocketAddr, SocketOption};
use crate::{
interface::NetworkInterface,
l3::Route,
l4::tcp::{TcpConnection, TcpConnectionState},
};
use super::{SocketTable, TwoWaySocketTable};
pub struct TcpSocket {
pub(crate) local: SocketAddr,
pub(crate) remote: SocketAddr,
ttl: AtomicU8,
// Listener which accepted the socket
listener: Option<Arc<TcpListener>>,
connection: IrqSafeRwLock<TcpConnection>,
}
pub struct TcpListener {
accept: SocketAddr,
// Currently active sockets
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>,
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>,
accept_notify: QueueWaker,
}
static TCP_SOCKETS: IrqSafeRwLock<TwoWaySocketTable<TcpSocket>> =
IrqSafeRwLock::new(TwoWaySocketTable::new());
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
IrqSafeRwLock::new(SocketTable::new());
impl TcpSocket {
pub async fn connect(
remote: SocketAddr,
timeout: Option<Duration>,
) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
let future = Self::connect_async(remote);
match timeout {
Some(timeout) => with_timeout(future, timeout).await?,
None => future.await,
}
}
pub fn accept_remote(
listener: Arc<TcpListener>,
local: SocketAddr,
remote: SocketAddr,
remote_window_size: usize,
tx_seq: u32,
rx_seq: u32,
) -> Result<Arc<TcpSocket>, Error> {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with(local, remote, move || {
let connection = TcpConnection::new(
local,
remote,
remote_window_size,
tx_seq,
rx_seq,
TcpConnectionState::SynReceived,
);
log::debug!("Accepted TCP socket {} -> {}", local, remote);
let socket = Self {
local,
remote,
ttl: AtomicU8::new(64),
listener: Some(listener),
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})
}
pub fn connection(&self) -> &IrqSafeRwLock<TcpConnection> {
&self.connection
}
pub(crate) fn accept(self: &Arc<Self>) {
if let Some(listener) = self.listener.as_ref() {
listener.accept_socket(self.clone());
}
}
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<Self>> {
TCP_SOCKETS.read().get(local, remote)
}
pub fn receive_async<'a>(
&'a self,
buffer: &'a mut [u8],
) -> impl Future<Output = Result<usize, Error>> + 'a {
poll_fn(|cx| match self.poll_receive(cx) {
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(buffer)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
})
}
pub async fn send_async(&self, data: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
let mut rem = data.len();
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
self.send_segment_async(&data[pos..pos + amount]).await?;
pos += amount;
rem -= amount;
}
Ok(pos)
}
pub async fn close_async(&self, remove_from_listener: bool) -> Result<(), Error> {
// TODO timeout here
// Already closing
if self.connection.read().is_closing() {
return Ok(());
}
// Wait for all sent data to be acknowledged
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.finish().await?;
}
log::debug!(
"TCP socket closed (FinWait2/Closed): {} <-> {}",
self.local,
self.remote
);
// Wait for connection to get closed
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_finish(cx)
})
.await;
if remove_from_listener {
if let Some(listener) = self.listener.as_ref() {
listener.remove_socket(self.remote);
};
}
Ok(())
}
pub(crate) fn remove_socket(&self) -> Result<(), Error> {
log::debug!(
"TCP socket closed and removed: {} <-> {}",
self.local,
self.remote
);
let connection = self.connection.read();
debug_assert!(connection.is_closed());
TCP_SOCKETS.write().remove(self.local, self.remote)?;
connection.notify_all();
Ok(())
}
fn poll_receive(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
let lock = self.connection.write();
match lock.poll_receive(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
async fn send_segment_async(&self, data: &[u8]) -> Result<(), Error> {
// TODO timeout here
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.transmit(data).await?;
}
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_acknowledge(cx)
})
.await;
Ok(())
}
async fn connect_async(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
// Lookup route to remote
let (interface_id, _, remote_ip) =
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
let remote = SocketAddr::new(remote_ip, remote.port());
let interface = NetworkInterface::get(interface_id)?;
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
let socket = {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with_ephemeral_port(local_ip, remote, |port| {
let t = monotonic_time();
let tx_seq = t.as_millis() as u32;
let local = SocketAddr::new(local_ip, port);
let connection =
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
let socket = Self {
local,
remote,
ttl: AtomicU8::new(64),
listener: None,
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})?
};
let mut t = 200;
for _ in 0..5 {
let timeout = Duration::from_millis(t);
log::debug!("Try SYN with timeout={:?}", timeout);
match socket.try_connect(timeout).await {
Ok(()) => return Ok((socket.local, socket)),
Err(Error::TimedOut) => (),
Err(error) => return Err(error),
}
t *= 2;
}
// Couldn't establish
Err(Error::TimedOut)
}
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
{
let mut connection = self.connection.write();
connection.send_syn().await?;
}
let fut = poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_established(cx)
});
with_timeout(fut, timeout).await?
}
}
impl Socket for TcpSocket {
fn local_address(&self) -> SocketAddr {
self.local
}
fn remote_address(&self) -> Option<SocketAddr> {
Some(self.remote)
}
fn close(&self) -> Result<(), Error> {
block!(self.close_async(true).await)?
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Ttl(ttl) => {
if ttl == 0 || ttl > 255 {
return Err(Error::InvalidArgument);
}
self.ttl.store(ttl as _, Ordering::Release);
Ok(())
}
SocketOption::NoDelay(_) => {
log::warn!("TODO: TCP nodelay");
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ttl(ttl) => {
*ttl = self.ttl.load(Ordering::Acquire) as _;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
*nodelay = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl FileReadiness for TcpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_receive(cx).map_ok(|_| ())
}
}
#[async_trait]
impl ConnectionSocket for TcpSocket {
async fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> {
self.receive_async(buffer).await
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error> {
match self.connection.write().read_nonblocking(buffer) {
// TODO check if this really means "no data at the moment"
Ok(0) => Err(Error::WouldBlock),
Ok(len) => Ok(len),
Err(error) => Err(error),
}
}
async fn send(&self, data: &[u8]) -> Result<usize, Error> {
self.send_async(data).await
}
fn send_nonblocking(&self, data: &[u8]) -> Result<usize, Error> {
log::warn!("TODO: TCP::send_nonblocking");
block!(self.send_async(data).await)?
}
}
impl fmt::Debug for TcpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpSocket")
.field("local", &self.local)
.field("remote", &self.remote)
.finish_non_exhaustive()
}
}
impl TcpListener {
pub fn bind(accept: SocketAddr) -> Result<Arc<Self>, Error> {
TCP_LISTENERS.write().try_insert_with(accept, || {
let listener = TcpListener {
accept,
sockets: IrqSafeRwLock::new(BTreeMap::new()),
pending_accept: IrqSafeSpinlock::new(Vec::new()),
accept_notify: QueueWaker::new(),
};
log::debug!("TCP Listener opened: {}", accept);
Ok(Arc::new(listener))
})
}
pub fn get(local: SocketAddr) -> Option<Arc<Self>> {
TCP_LISTENERS.read().get(&local)
}
pub fn accept_async(&self) -> impl Future<Output = Result<Arc<TcpSocket>, Error>> + '_ {
poll_fn(|cx| match self.poll_accept(cx) {
Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())),
Poll::Pending => Poll::Pending,
})
}
fn accept_socket(&self, socket: Arc<TcpSocket>) {
log::debug!("{}: accept {}", self.accept, socket.remote);
self.sockets.write().insert(socket.remote, socket.clone());
self.pending_accept.lock().push(socket);
self.accept_notify.wake_one();
}
fn remove_socket(&self, remote: SocketAddr) {
log::debug!("Remove client {}", remote);
self.sockets.write().remove(&remote);
}
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpSocket>>>> {
let lock = self.pending_accept.lock();
self.accept_notify.register(cx.waker());
if !lock.is_empty() {
self.accept_notify.remove(cx.waker());
Poll::Ready(lock)
} else {
Poll::Pending
}
}
}
impl Socket for TcpListener {
fn local_address(&self) -> SocketAddr {
self.accept
}
fn remote_address(&self) -> Option<SocketAddr> {
None
}
fn close(&self) -> Result<(), Error> {
// TODO if clients not closed already, send RST?
TCP_LISTENERS.write().remove(self.accept)
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ipv6Only(_v6only) => {
log::warn!("TODO: TCP listener IPv6-only");
Err(Error::InvalidOperation)
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ipv6Only(v6only) => {
*v6only = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl FileReadiness for TcpListener {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_accept(cx).map(|_| Ok(()))
}
}
#[async_trait]
impl ListenerSocket for TcpListener {
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = self.accept_async().await?;
let remote = socket.remote;
Ok((remote, socket))
}
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = self.pending_accept.lock().pop().ok_or(Error::WouldBlock)?;
let remote = socket.remote;
Ok((remote, socket))
}
}
impl fmt::Debug for TcpListener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpListener")
.field("local", &self.accept)
.finish_non_exhaustive()
}
}

View File

@ -0,0 +1,112 @@
use core::{
fmt,
future::poll_fn,
sync::atomic::{AtomicBool, Ordering},
task::{Context, Poll},
};
use alloc::{collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use libk::error::Error;
use libk_util::{
sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock, IrqSafeSpinlockGuard},
waker::QueueWaker,
};
use yggdrasil_abi::net::SocketAddr;
use crate::socket::SocketTable;
use super::TcpStream;
pub struct TcpListener {
pub(super) local: SocketAddr,
pub(super) listening: AtomicBool,
// Currently active sockets
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpStream>>>,
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpStream>>>,
accept_notify: QueueWaker,
}
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
IrqSafeRwLock::new(SocketTable::new());
impl TcpListener {
pub fn bind(local: SocketAddr) -> Result<Arc<Self>, Error> {
if local.port() == 0 {
// TODO: ephemeral binding
todo!();
}
let listener = Arc::new(TcpListener {
local,
listening: AtomicBool::new(false),
sockets: IrqSafeRwLock::new(BTreeMap::new()),
pending_accept: IrqSafeSpinlock::new(Vec::new()),
accept_notify: QueueWaker::new(),
});
TCP_LISTENERS.write().try_insert(local, listener.clone())?;
Ok(listener)
}
pub fn close(&self) -> Result<(), Error> {
// TODO if clients not closed already, send RST?
TCP_LISTENERS.write().remove(self.local)
}
pub(super) async fn accept(&self) -> Result<Arc<TcpStream>, Error> {
poll_fn(|cx| match self.poll_accept(cx) {
Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())),
Poll::Pending => Poll::Pending,
})
.await
}
pub(super) fn accept_nonblocking(&self) -> Result<Arc<TcpStream>, Error> {
self.pending_accept.lock().pop().ok_or(Error::WouldBlock)
}
pub(super) fn poll_accept(
&self,
cx: &mut Context<'_>,
) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpStream>>>> {
let lock = self.pending_accept.lock();
self.accept_notify.register(cx.waker());
if !lock.is_empty() {
self.accept_notify.remove(cx.waker());
Poll::Ready(lock)
} else {
Poll::Pending
}
}
pub(super) fn accept_stream(&self, stream: Arc<TcpStream>) {
log::debug!("{}: accept {}", self.local, stream.remote);
self.sockets.write().insert(stream.remote, stream.clone());
self.pending_accept.lock().push(stream);
self.accept_notify.wake_one();
}
pub(super) fn remove_stream(&self, remote: SocketAddr) {
log::debug!("Remove remote stream {}", remote);
self.sockets.write().remove(&remote);
}
pub fn is_listening(&self) -> bool {
self.listening.load(Ordering::Acquire)
}
pub fn get(local: &SocketAddr) -> Option<Arc<TcpListener>> {
TCP_LISTENERS.read().get(local)
}
}
impl fmt::Debug for TcpListener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpListener")
.field("listening", &self.is_listening())
.field("local", &self.local)
.finish_non_exhaustive()
}
}

View File

@ -0,0 +1,270 @@
use core::{
fmt, mem,
sync::atomic::Ordering,
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, sync::Arc};
use async_trait::async_trait;
use libk::{
block,
error::Error,
task::runtime::maybe_timeout,
vfs::{ConnectionSocket, FileReadiness, Socket},
};
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::net::{SocketAddr, SocketOption};
mod listener;
mod stream;
pub use listener::TcpListener;
pub use stream::TcpStream;
pub struct TcpSocket {
options: IrqSafeRwLock<TcpSocketOptions>,
inner: IrqSafeRwLock<TcpSocketInner>,
}
pub enum TcpSocketInner {
Empty,
Listener(Arc<TcpListener>),
Stream(Arc<TcpStream>),
}
struct TcpSocketOptions {
ttl: u8,
nodelay: bool,
v6_only: bool,
}
impl Default for TcpSocketOptions {
fn default() -> Self {
Self {
ttl: 64,
nodelay: false,
v6_only: false,
}
}
}
impl TcpSocket {
pub fn new() -> Arc<Self> {
Arc::new(Self {
options: IrqSafeRwLock::new(TcpSocketOptions::default()),
inner: IrqSafeRwLock::new(TcpSocketInner::Empty),
})
}
fn as_stream(&self) -> Option<Arc<TcpStream>> {
if let TcpSocketInner::Stream(stream) = &*self.inner.read() {
Some(stream.clone())
} else {
None
}
}
fn as_listener(&self) -> Option<Arc<TcpListener>> {
if let TcpSocketInner::Listener(listener) = &*self.inner.read() {
Some(listener.clone())
} else {
None
}
}
}
impl Socket for TcpSocket {
fn bind(self: Arc<Self>, local: SocketAddr) -> Result<(), Error> {
let mut inner = self.inner.write();
// Already connected or bound
if !matches!(&*inner, TcpSocketInner::Empty) {
return Err(Error::InvalidOperation);
}
// Bind a listener socket
let listener = TcpListener::bind(local)?;
*inner = TcpSocketInner::Listener(listener);
Ok(())
}
fn close(self: Arc<Self>) -> Result<(), Error> {
let inner = mem::replace(&mut *self.inner.write(), TcpSocketInner::Empty);
match inner {
TcpSocketInner::Empty => Ok(()),
TcpSocketInner::Stream(socket) => block!(socket.close(true).await)?,
TcpSocketInner::Listener(socket) => socket.close(),
}
}
fn local_address(&self) -> Option<SocketAddr> {
match &*self.inner.read() {
TcpSocketInner::Empty => None,
TcpSocketInner::Stream(socket) => Some(socket.local),
TcpSocketInner::Listener(socket) => Some(socket.local),
}
}
fn remote_address(&self) -> Option<SocketAddr> {
match &*self.inner.read() {
TcpSocketInner::Empty | TcpSocketInner::Listener(_) => None,
TcpSocketInner::Stream(socket) => Some(socket.remote),
}
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match *option {
SocketOption::Ttl(ttl) => {
if !(1..256).contains(&ttl) {
return Err(Error::InvalidArgument);
}
let ttl = ttl as u8;
self.options.write().ttl = ttl;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
self.options.write().nodelay = nodelay;
Ok(())
}
SocketOption::Ipv6Only(v6_only) => {
self.options.write().v6_only = v6_only;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ttl(ttl) => {
*ttl = self.options.read().ttl as u32;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
*nodelay = self.options.read().nodelay;
Ok(())
}
SocketOption::Ipv6Only(v6_only) => {
*v6_only = self.options.read().v6_only;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
#[async_trait]
impl ConnectionSocket for TcpSocket {
async fn connect(
self: Arc<Self>,
remote: SocketAddr,
_timeout: Option<Duration>,
) -> Result<(), Error> {
let mut inner = self.inner.write();
// Already connected or bound
if !matches!(&*inner, TcpSocketInner::Empty) {
return Err(Error::InvalidOperation);
}
let stream = TcpStream::connect(remote).await?;
*inner = TcpSocketInner::Stream(stream);
Ok(())
}
async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let len = stream.receive(buffer, timeout).await?;
Ok((len, stream.remote))
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let len = stream.receive_nonblocking(buffer)?;
Ok((len, stream.remote))
}
async fn send(&self, data: &[u8], timeout: Option<Duration>) -> Result<usize, Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let future = stream.send(data);
maybe_timeout(future, timeout).await?
}
fn send_nonblocking(&self, buffer: &[u8]) -> Result<usize, Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let len = stream.send_nonblocking(buffer)?;
Ok(len)
}
fn listen(self: Arc<Self>) -> Result<(), Error> {
let listener = self.as_listener().ok_or(Error::InvalidOperation)?;
match listener
.listening
.compare_exchange(false, true, Ordering::Release, Ordering::Relaxed)
{
Ok(_) => Ok(()),
Err(_) => Err(Error::InvalidOperation),
}
}
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let listener = self.as_listener().ok_or(Error::InvalidOperation)?;
let stream = listener.accept().await?;
let remote = stream.remote;
let socket = Arc::new(TcpSocket {
options: IrqSafeRwLock::new(TcpSocketOptions::default()),
inner: IrqSafeRwLock::new(TcpSocketInner::Stream(stream)),
});
Ok((remote, socket))
}
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let listener = self.as_listener().ok_or(Error::InvalidOperation)?;
let stream = listener.accept_nonblocking()?;
let remote = stream.remote;
let socket = Arc::new(TcpSocket {
options: IrqSafeRwLock::new(TcpSocketOptions::default()),
inner: IrqSafeRwLock::new(TcpSocketInner::Stream(stream)),
});
Ok((remote, socket))
}
async fn shutdown(&self, read: bool, write: bool) -> Result<(), Error> {
log::warn!("TODO: shutdown(read = {read}, write = {write})");
Ok(())
}
}
impl FileReadiness for TcpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match &*self.inner.read() {
TcpSocketInner::Empty => Poll::Ready(Ok(())),
TcpSocketInner::Stream(socket) => socket.poll_receive(cx).map_ok(|_| ()),
TcpSocketInner::Listener(socket) => socket.poll_accept(cx).map(|_| Ok(())),
}
}
}
impl fmt::Debug for TcpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &*self.inner.read() {
TcpSocketInner::Empty => f.debug_struct("TcpSocket").finish_non_exhaustive(),
TcpSocketInner::Stream(socket) => f
.debug_struct("TcpSocket")
.field("stream", socket)
.finish_non_exhaustive(),
TcpSocketInner::Listener(socket) => f
.debug_struct("TcpSocket")
.field("listener", socket)
.finish_non_exhaustive(),
}
}
}

View File

@ -0,0 +1,309 @@
use core::{
fmt,
future::poll_fn,
task::{Context, Poll},
time::Duration,
};
use alloc::sync::Arc;
use libk::{
error::Error,
task::runtime::{maybe_timeout, with_timeout},
time::monotonic_time,
};
use libk_util::sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard};
use yggdrasil_abi::net::SocketAddr;
use crate::{
interface::NetworkInterface,
l3::Route,
l4::tcp::{TcpConnection, TcpConnectionState},
socket::TwoWaySocketTable,
};
use super::TcpListener;
pub struct TcpStream {
pub(super) local: SocketAddr,
pub(super) remote: SocketAddr,
// Listener which accepted the socket
listener: Option<Arc<TcpListener>>,
pub(crate) connection: IrqSafeRwLock<TcpConnection>,
}
static TCP_STREAMS: IrqSafeRwLock<TwoWaySocketTable<TcpStream>> =
IrqSafeRwLock::new(TwoWaySocketTable::new());
impl TcpStream {
pub async fn connect(remote: SocketAddr) -> Result<Arc<Self>, Error> {
// Lookup route to remote
let (interface_id, _, remote_ip) =
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
let remote = SocketAddr::new(remote_ip, remote.port());
let interface = NetworkInterface::get(interface_id)?;
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
// Create a new TcpStream with an ephemeral port
let stream = {
let mut streams = TCP_STREAMS.write();
streams.try_insert_with_ephemeral_port(local_ip, remote, |port| {
let t = monotonic_time();
let tx_seq = t.as_millis() as u32;
let local = SocketAddr::new(local_ip, port);
let connection =
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
Ok(Arc::new(Self {
local,
remote,
listener: None,
connection: IrqSafeRwLock::new(connection),
}))
})?
};
let mut t = 200;
for _ in 0..5 {
let timeout = Duration::from_millis(t);
log::debug!("Try SYN with timeout={:?}", timeout);
match stream.try_connect(timeout).await {
Ok(()) => return Ok(stream),
Err(Error::TimedOut) => (),
Err(error) => return Err(error),
}
t *= 2;
}
// Couldn't establish
stream.close(false).await.ok();
Err(Error::TimedOut)
}
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
{
let mut connection = self.connection.write();
connection.send_syn().await?;
}
let fut = poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_established(cx)
});
with_timeout(fut, timeout).await?
}
pub fn accept(self: &Arc<Self>) {
if let Some(listener) = self.listener.as_ref() {
listener.accept_stream(self.clone());
}
}
pub fn accept_remote(
listener: Arc<TcpListener>,
local: SocketAddr,
remote: SocketAddr,
remote_window_size: usize,
tx_seq: u32,
rx_seq: u32,
) -> Result<Arc<TcpStream>, Error> {
let mut streams = TCP_STREAMS.write();
streams.try_insert_with(local, remote, move || {
let connection = TcpConnection::new(
local,
remote,
remote_window_size,
tx_seq,
rx_seq,
TcpConnectionState::SynReceived,
);
log::debug!("Accepted TCP socket {} -> {}", local, remote);
let socket = Self {
local,
remote,
listener: Some(listener),
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})
}
pub async fn close(&self, remove_from_listener: bool) -> Result<(), Error> {
// TODO timeout here
// Already closing
if self.connection.read().is_closing() {
return Ok(());
}
// Wait for all sent data to be acknowledged
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.finish().await?;
}
log::debug!(
"TCP socket closed (FinWait2/Closed): {} <-> {}",
self.local,
self.remote
);
// Wait for connection to get closed
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_finish(cx)
})
.await;
if remove_from_listener {
if let Some(listener) = self.listener.as_ref() {
listener.remove_stream(self.remote);
};
}
Ok(())
}
pub(crate) fn remove_stream(&self) -> Result<(), Error> {
log::debug!(
"TCP socket closed and removed: {} <-> {}",
self.local,
self.remote
);
let connection = self.connection.read();
debug_assert!(connection.is_closed());
TCP_STREAMS.write().remove(self.local, self.remote)?;
connection.notify_all();
Ok(())
}
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<TcpStream>> {
TCP_STREAMS.read().get(local, remote)
}
pub(super) fn poll_receive(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
let lock = self.connection.write();
match lock.poll_receive(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
pub async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<usize, Error> {
let future = poll_fn(|cx| match self.poll_receive(cx) {
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(buffer)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
});
maybe_timeout(future, timeout).await.flatten()
}
pub async fn send(&self, data: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
let mut rem = data.len();
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
self.send_segment(&data[pos..pos + amount]).await?;
pos += amount;
rem -= amount;
}
Ok(pos)
}
pub fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error> {
if buffer.is_empty() {
return Ok(0);
}
let mut lock = self.connection.write();
match lock.read_nonblocking(buffer) {
Ok(0) => Err(Error::WouldBlock),
res => res,
}
}
pub fn send_nonblocking(&self, data: &[u8]) -> Result<usize, Error> {
if data.is_empty() {
return Ok(0);
}
let mut pos = 0;
let mut rem = data.len();
let mut sent = false;
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
match self.send_segment_nonblocking(&data[pos..pos + amount]) {
Ok(()) => sent = true,
Err(Error::WouldBlock) => break,
Err(error) => return Err(error),
}
pos += amount;
rem -= amount;
}
// No data sent and WouldBlock returned
if !sent {
return Err(Error::WouldBlock);
}
Ok(pos)
}
async fn send_segment(&self, data: &[u8]) -> Result<(), Error> {
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.transmit(data).await?;
}
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_acknowledge(cx)
})
.await;
Ok(())
}
fn send_segment_nonblocking(&self, _data: &[u8]) -> Result<(), Error> {
todo!()
}
}
impl fmt::Debug for TcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpStream")
.field("local", &self.local)
.field("remote", &self.remote)
.finish_non_exhaustive()
}
}

View File

@ -2,6 +2,7 @@ use core::{
fmt,
sync::atomic::{AtomicBool, AtomicU8, Ordering},
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
@ -9,18 +10,22 @@ use async_trait::async_trait;
use libk::{
block,
error::Error,
task::runtime::maybe_timeout,
vfs::{FileReadiness, PacketSocket, Socket},
};
use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock};
use yggdrasil_abi::net::{SocketAddr, SocketOption};
use libk_util::{
queue::BoundedMpmcQueue,
sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard},
};
use yggdrasil_abi::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketOption};
use crate::l4;
use super::SocketTable;
pub struct UdpSocket {
local: SocketAddr,
remote: Option<SocketAddr>,
local: IrqSafeRwLock<Option<SocketAddr>>,
remote: IrqSafeRwLock<Option<SocketAddr>>,
broadcast: AtomicBool,
ttl: AtomicU8,
@ -32,32 +37,18 @@ pub struct UdpSocket {
static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new());
impl UdpSocket {
fn create_socket(local: SocketAddr) -> Arc<UdpSocket> {
log::debug!("UDP socket opened: {}", local);
Arc::new(UdpSocket {
local,
ttl: AtomicU8::new(64),
remote: None,
pub fn new() -> Arc<Self> {
Arc::new(Self {
local: IrqSafeRwLock::new(None),
remote: IrqSafeRwLock::new(None),
broadcast: AtomicBool::new(false),
ttl: AtomicU8::new(64),
receive_queue: BoundedMpmcQueue::new(128),
})
}
pub fn bind(address: SocketAddr) -> Result<Arc<UdpSocket>, Error> {
let mut sockets = UDP_SOCKETS.write();
if address.port() == 0 {
sockets.try_insert_with_ephemeral_port(address.ip(), |port| {
Ok(Self::create_socket(SocketAddr::new(address.ip(), port)))
})
} else {
sockets.try_insert_with(address, move || Ok(Self::create_socket(address)))
}
}
pub fn connect(&self, _address: SocketAddr) -> Result<(), Error> {
todo!()
}
pub fn get(local: &SocketAddr) -> Option<Arc<UdpSocket>> {
UDP_SOCKETS.read().get(local)
}
@ -67,6 +58,25 @@ impl UdpSocket {
.try_push_back((source, Vec::from(data)))
.map_err(|_| Error::QueueFull)
}
// If address is bound, keep it, if not, bind an ephemeral port
pub fn ensure_address(self: &Arc<Self>, v6: bool) -> Result<u16, Error> {
let local = self.local.read();
if let Some(address) = *local {
Ok(address.port())
} else {
let mut local = IrqSafeRwLockReadGuard::upgrade(local);
let ip = match v6 {
true => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
false => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
};
let port = UDP_SOCKETS
.write()
.bind_to_ephemeral_port(ip, self.clone())?;
*local = Some(SocketAddr::new(ip, port));
Ok(port)
}
}
}
impl FileReadiness for UdpSocket {
@ -77,18 +87,35 @@ impl FileReadiness for UdpSocket {
#[async_trait]
impl PacketSocket for UdpSocket {
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
let Some(destination) = destination else {
// TODO can still send without setting address if "connected"
return Err(Error::InvalidArgument);
};
fn connect(self: Arc<Self>, remote: SocketAddr) -> Result<(), Error> {
let mut connected = self.remote.write();
if connected.is_some() {
return Err(Error::InvalidOperation);
}
*connected = Some(remote);
Ok(())
}
async fn send_to(
self: Arc<Self>,
destination: Option<SocketAddr>,
data: &[u8],
_timeout: Option<Duration>,
) -> Result<usize, Error> {
let destination = destination
.or_else(|| self.remote_address())
.ok_or(Error::NotConnected)?;
// If socket wasn't bound yet, bind it to an ephemeral port
let port = self.ensure_address(destination.ip().is_ipv6())?;
// TODO check that destnation family matches self family
match (self.broadcast.load(Ordering::Acquire), destination.ip()) {
// SendTo in broadcast?
(true, _) => todo!(),
// TODO broadcast
(true, _) => return Err(Error::NotImplemented),
(false, _) => {
l4::udp::send(
self.local.port(),
port,
destination.ip(),
destination.port(),
self.ttl.load(Ordering::Acquire),
@ -102,16 +129,21 @@ impl PacketSocket for UdpSocket {
}
fn send_nonblocking(
&self,
self: Arc<Self>,
destination: Option<SocketAddr>,
buffer: &[u8],
) -> Result<usize, Error> {
log::warn!("TODO: UDP::send_nonblocking()");
block!(self.send_to(destination, buffer).await)?
block!(self.send_to(destination, buffer, None).await)?
}
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let (source, packet) = self.receive_queue.pop_front().await;
async fn receive_from(
self: Arc<Self>,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
let future = self.receive_queue.pop_front();
let (source, packet) = maybe_timeout(future, timeout).await?;
if packet.len() > buffer.len() {
// TODO check how other OSs handle this
@ -121,7 +153,10 @@ impl PacketSocket for UdpSocket {
Ok((packet.len(), source))
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
fn receive_nonblocking(
self: Arc<Self>,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr), Error> {
let (source, packet) = self
.receive_queue
.try_pop_front()
@ -137,23 +172,43 @@ impl PacketSocket for UdpSocket {
}
impl Socket for UdpSocket {
fn local_address(&self) -> SocketAddr {
self.local
fn bind(self: Arc<Self>, local: SocketAddr) -> Result<(), Error> {
let mut bound = self.local.write();
if bound.is_some() {
return Err(Error::InvalidOperation);
}
if local.port() == 0 {
let port = UDP_SOCKETS
.write()
.bind_to_ephemeral_port(local.ip(), self.clone())?;
*bound = Some(SocketAddr::new(local.ip(), port));
} else {
UDP_SOCKETS.write().try_insert(local, self.clone())?;
*bound = Some(local);
}
Ok(())
}
fn local_address(&self) -> Option<SocketAddr> {
*self.local.read()
}
fn remote_address(&self) -> Option<SocketAddr> {
self.remote
*self.remote.read()
}
fn close(&self) -> Result<(), Error> {
log::debug!("UDP socket closed: {}", self.local);
UDP_SOCKETS.write().remove(self.local)
fn close(self: Arc<Self>) -> Result<(), Error> {
if let Some(local) = self.local.write().take() {
log::debug!("UDP socket closed: {}", local);
UDP_SOCKETS.write().remove(local)?;
}
Ok(())
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Broadcast(broadcast) => {
log::debug!("{} broadcast: {}", self.local, broadcast);
// log::debug!("{} broadcast: {}", self.local, broadcast);
self.broadcast.store(broadcast, Ordering::Release);
Ok(())
}
@ -209,9 +264,12 @@ impl Socket for UdpSocket {
impl fmt::Debug for UdpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let local = *self.local.read();
let remote = *self.remote.read();
f.debug_struct("UdpSocket")
.field("local", &self.local)
.field("remote", &self.remote)
.field("local", &local)
.field("remote", &remote)
.finish_non_exhaustive()
}
}

View File

@ -44,6 +44,15 @@ impl RwLockInner {
self.value.fetch_nand(Self::LOCKED_WRITE, Ordering::Release);
}
#[inline]
fn upgrade(&self) {
// At least one read lock is held by this task.
// When there's *exactly* one lock (being this task) held, upgrade is possible
while !self.try_upgrade() {
core::hint::spin_loop();
}
}
#[inline]
fn acquire_read_raw(&self) -> usize {
let value = self.value.fetch_add(Self::LOCKED_READ, Ordering::Acquire);
@ -77,6 +86,18 @@ impl RwLockInner {
.is_ok()
}
#[inline]
fn try_upgrade(&self) -> bool {
self.value
.compare_exchange(
Self::LOCKED_READ,
Self::LOCKED_WRITE,
Ordering::Acquire,
Ordering::Relaxed,
)
.is_ok()
}
#[inline]
fn acquire_read(&self) {
while !self.try_acquire_read() {
@ -133,6 +154,11 @@ impl<T> IrqSafeRwLock<T> {
self.inner.downgrade_write();
}
#[inline]
unsafe fn upgrade(&self) {
self.inner.upgrade();
}
unsafe fn release_read(&self) {
self.inner.release_read();
}
@ -159,10 +185,26 @@ impl<T> Drop for IrqSafeRwLockReadGuard<'_, T> {
}
}
impl<T> IrqSafeRwLockReadGuard<'_, T> {
impl<'a, T> IrqSafeRwLockReadGuard<'a, T> {
pub fn get(guard: &Self) -> *const T {
guard.lock.value.get()
}
pub fn upgrade(guard: IrqSafeRwLockReadGuard<'a, T>) -> IrqSafeRwLockWriteGuard<'a, T> {
let lock = guard.lock;
let irq_guard = IrqGuard::acquire();
// Read lock still held
core::mem::forget(guard);
unsafe {
lock.upgrade();
}
IrqSafeRwLockWriteGuard {
lock,
_guard: irq_guard,
}
}
}
impl<'a, T> IrqSafeRwLockWriteGuard<'a, T> {

View File

@ -372,11 +372,9 @@ impl IoContext {
if !flags.contains_any(RemoveFlags::DIRECTORY | RemoveFlags::DIRECTORY_ONLY) {
return Err(Error::IsADirectory);
}
} else {
if flags.contains_any(RemoveFlags::DIRECTORY_ONLY) {
} else if flags.contains_any(RemoveFlags::DIRECTORY_ONLY) {
return Err(Error::NotADirectory);
}
}
parent.remove_file(filename, access)
}

View File

@ -35,7 +35,7 @@ pub use path::{Filename, OwnedFilename};
pub use poll::FdPoll;
pub use pty::{PseudoTerminalMaster, PseudoTerminalSlave};
pub use shared_memory::SharedMemory;
pub use socket::{ConnectionSocket, ListenerSocket, PacketSocket, Socket, SocketWrapper};
pub use socket::{ConnectionSocket, PacketSocket, Socket, SocketWrapper};
pub use terminal::{Terminal, TerminalInput, TerminalOutput};
pub use timer::TimerFile;
pub use traits::{FileReadiness, Read, Seek, Write};

View File

@ -9,18 +9,21 @@ use async_trait::async_trait;
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::{
error::Error,
net::{SocketAddr, SocketOption},
net::{SocketAddr, SocketOption, SocketShutdown},
};
use crate::{task::runtime::maybe_timeout, vfs::FileReadiness};
use crate::vfs::FileReadiness;
use super::{File, FileRef};
enum SocketInner {
Connection(Arc<dyn ConnectionSocket + Send + 'static>),
Listener(Arc<dyn ListenerSocket + Send + 'static>),
Packet(Arc<dyn PacketSocket + Send + 'static>),
Connection(Arc<dyn ConnectionSocket + 'static>),
// Listener(Arc<dyn ListenerSocket + Send + 'static>),
Packet(Arc<dyn PacketSocket + 'static>),
}
struct InnerOptions {
connect_timeout: Option<Duration>,
recv_timeout: Option<Duration>,
send_timeout: Option<Duration>,
non_blocking: bool,
@ -34,14 +37,16 @@ pub struct SocketWrapper {
/// Interface for interacting with network sockets
#[allow(unused)]
pub trait Socket: FileReadiness + fmt::Debug + Send {
fn bind(self: Arc<Self>, local: SocketAddr) -> Result<(), Error>;
/// Socket listen/receive address
fn local_address(&self) -> SocketAddr;
fn local_address(&self) -> Option<SocketAddr>;
/// Socket remote address
fn remote_address(&self) -> Option<SocketAddr>;
/// Closes a socket
fn close(&self) -> Result<(), Error>;
fn close(self: Arc<Self>) -> Result<(), Error>;
/// Updates a socket option
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
@ -59,40 +64,62 @@ pub trait Socket: FileReadiness + fmt::Debug + Send {
pub trait PacketSocket: Socket {
/// Receives a packet into provided buffer. Will return an error if packet cannot be placed
/// within the buffer.
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>;
async fn receive_from(
self: Arc<Self>,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error>;
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>;
fn receive_nonblocking(
self: Arc<Self>,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr), Error>;
/// Sends provided data to the recepient specified by `destination`
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error>;
async fn send_to(
self: Arc<Self>,
destination: Option<SocketAddr>,
data: &[u8],
timeout: Option<Duration>,
) -> Result<usize, Error>;
fn send_nonblocking(
&self,
self: Arc<Self>,
destination: Option<SocketAddr>,
buffer: &[u8],
) -> Result<usize, Error>;
fn connect(self: Arc<Self>, remote: SocketAddr) -> Result<(), Error>;
}
/// Connection-based client socket interface
#[async_trait]
pub trait ConnectionSocket: Socket {
/// Receives data into provided buffer
async fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error>;
async fn connect(
self: Arc<Self>,
remote: SocketAddr,
timeout: Option<Duration>,
) -> Result<(), Error>;
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error>;
async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error>;
/// Transmits data
async fn send(&self, data: &[u8]) -> Result<usize, Error>;
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>;
async fn send(&self, data: &[u8], timeout: Option<Duration>) -> Result<usize, Error>;
fn send_nonblocking(&self, buffer: &[u8]) -> Result<usize, Error>;
}
/// Connection-based listener socket interface
#[async_trait]
pub trait ListenerSocket: Socket {
/// Blocks the execution until an incoming connection is accepted
fn listen(self: Arc<Self>) -> Result<(), Error>;
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>;
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>;
async fn shutdown(&self, read: bool, write: bool) -> Result<(), Error>;
}
impl SocketWrapper {
@ -110,103 +137,108 @@ impl SocketWrapper {
}
}
pub fn from_listener(socket: Arc<dyn ListenerSocket + 'static>) -> Self {
Self {
inner: SocketInner::Listener(socket),
options: IrqSafeRwLock::new(InnerOptions::default()),
async fn send_inner(&self, data: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
let timeout = self.options.read().send_timeout;
match &self.inner {
SocketInner::Packet(socket) => socket.clone().send_to(remote, data, timeout).await,
SocketInner::Connection(socket) => socket.send(data, timeout).await,
}
}
pub fn accept(&self) -> Result<(SocketWrapper, SocketAddr), Error> {
let SocketInner::Listener(socket) = &self.inner else {
async fn receive_inner(&self, data: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let timeout = self.options.read().recv_timeout;
match &self.inner {
SocketInner::Packet(socket) => socket.clone().receive_from(data, timeout).await,
SocketInner::Connection(socket) => socket.clone().receive(data, timeout).await,
}
}
async fn connect_inner(&self, remote: SocketAddr) -> Result<(), Error> {
let timeout = self.options.read().connect_timeout;
match &self.inner {
SocketInner::Packet(socket) => socket.clone().connect(remote),
SocketInner::Connection(socket) => socket.clone().connect(remote, timeout).await,
}
}
pub fn connect(&self, remote: SocketAddr) -> Result<(), Error> {
block!(self.connect_inner(remote).await)?
}
pub fn send_to(&self, data: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
if self.options.read().non_blocking {
match &self.inner {
SocketInner::Packet(socket) => socket.clone().send_nonblocking(remote, data),
SocketInner::Connection(socket) => socket.clone().send_nonblocking(data),
}
} else {
block!(self.send_inner(data, remote).await)?
}
}
pub fn receive_from(&self, data: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
if self.options.read().non_blocking {
match &self.inner {
SocketInner::Packet(socket) => socket.clone().receive_nonblocking(data),
SocketInner::Connection(socket) => socket.receive_nonblocking(data),
}
} else {
block!(self.receive_inner(data).await)?
}
}
pub fn bind(&self, local: SocketAddr) -> Result<(), Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.clone().bind(local),
SocketInner::Connection(socket) => socket.clone().bind(local),
}
}
pub fn listen(&self) -> Result<(), Error> {
match &self.inner {
SocketInner::Packet(_) => Err(Error::InvalidOperation),
SocketInner::Connection(socket) => socket.clone().listen(),
}
}
pub fn accept(&self) -> Result<(FileRef, SocketAddr), Error> {
let SocketInner::Connection(socket) = &self.inner else {
return Err(Error::InvalidOperation);
};
let options = self.options.read();
let (remote, remote_socket) = match (options.non_blocking, options.recv_timeout) {
(false, timeout) => {
let fut = socket.accept();
block!(maybe_timeout(fut, timeout).await)???
}
(true, _) => socket.accept_nonblocking()?,
let (remote, stream) = if self.options.read().non_blocking {
socket.accept_nonblocking()?
} else {
block!(socket.accept().await)??
};
let remote_socket = Self::from_connection(remote_socket);
Ok((remote_socket, remote))
let file = File::from_socket(SocketWrapper::from_connection(stream));
Ok((file, remote))
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.receive_nonblocking(buffer),
SocketInner::Connection(socket) => {
let remote = socket.remote_address().ok_or(Error::NotConnected)?;
let len = socket.receive_nonblocking(buffer)?;
Ok((len, remote))
}
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
pub fn shutdown(&self, how: SocketShutdown) -> Result<(), Error> {
let SocketInner::Connection(socket) = &self.inner else {
return Err(Error::InvalidOperation);
};
block!(
socket
.shutdown(
how.contains(SocketShutdown::READ),
how.contains(SocketShutdown::WRITE),
)
.await
)?
}
async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
match &self.inner {
SocketInner::Packet(socket) => {
maybe_timeout(socket.receive_from(buffer), timeout).await?
}
SocketInner::Connection(socket) => {
let remote = socket.remote_address().ok_or(Error::NotConnected)?;
let len = maybe_timeout(socket.receive(buffer), timeout).await??;
Ok((len, remote))
}
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
fn send_nonblocking(&self, buffer: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.send_nonblocking(remote, buffer),
SocketInner::Connection(socket) => socket.send_nonblocking(buffer),
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
async fn send(
&self,
buffer: &[u8],
remote: Option<SocketAddr>,
timeout: Option<Duration>,
) -> Result<usize, Error> {
match &self.inner {
SocketInner::Packet(socket) => {
maybe_timeout(socket.send_to(remote, buffer), timeout).await?
}
SocketInner::Connection(socket) => maybe_timeout(socket.send(buffer), timeout).await?,
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
pub fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let options = self.options.read();
match (options.non_blocking, options.recv_timeout) {
(false, timeout) => block!(self.receive(buffer, timeout).await)?,
(true, _) => self.receive_nonblocking(buffer),
}
}
pub fn send_to(&self, buffer: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
let options = self.options.read();
match (options.non_blocking, options.recv_timeout) {
(false, timeout) => block!(self.send(buffer, remote, timeout).await)?,
(true, _) => self.send_nonblocking(buffer, remote),
}
}
}
impl Socket for SocketWrapper {
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
pub fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::NonBlocking(nb) => {
self.options.write().non_blocking = *nb;
return Ok(());
}
SocketOption::RecvTimeout(timeout) => {
self.options.write().recv_timeout = *timeout;
return Ok(());
@ -215,22 +247,25 @@ impl Socket for SocketWrapper {
self.options.write().send_timeout = *timeout;
return Ok(());
}
SocketOption::NonBlocking(nb) => {
self.options.write().non_blocking = *nb;
SocketOption::ConnectTimeout(timeout) => {
self.options.write().connect_timeout = *timeout;
return Ok(());
}
_ => (),
}
match &self.inner {
SocketInner::Packet(socket) => socket.set_option(option),
SocketInner::Listener(socket) => socket.set_option(option),
SocketInner::Connection(socket) => socket.set_option(option),
SocketInner::Packet(socket) => socket.set_option(option),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
pub fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::NonBlocking(nb) => {
*nb = self.options.read().non_blocking;
return Ok(());
}
SocketOption::RecvTimeout(timeout) => {
*timeout = self.options.read().recv_timeout;
return Ok(());
@ -239,41 +274,16 @@ impl Socket for SocketWrapper {
*timeout = self.options.read().send_timeout;
return Ok(());
}
SocketOption::NonBlocking(nb) => {
*nb = self.options.read().non_blocking;
SocketOption::ConnectTimeout(timeout) => {
*timeout = self.options.read().connect_timeout;
return Ok(());
}
_ => (),
}
match &self.inner {
SocketInner::Packet(socket) => socket.get_option(option),
SocketInner::Listener(socket) => socket.get_option(option),
SocketInner::Connection(socket) => socket.get_option(option),
}
}
fn local_address(&self) -> SocketAddr {
match &self.inner {
SocketInner::Packet(socket) => socket.local_address(),
SocketInner::Listener(socket) => socket.local_address(),
SocketInner::Connection(socket) => socket.local_address(),
}
}
fn remote_address(&self) -> Option<SocketAddr> {
match &self.inner {
SocketInner::Packet(socket) => socket.remote_address(),
SocketInner::Listener(socket) => socket.remote_address(),
SocketInner::Connection(socket) => socket.remote_address(),
}
}
fn close(&self) -> Result<(), Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.close(),
SocketInner::Listener(socket) => socket.close(),
SocketInner::Connection(socket) => socket.close(),
SocketInner::Packet(socket) => socket.get_option(option),
}
}
}
@ -281,9 +291,8 @@ impl Socket for SocketWrapper {
impl FileReadiness for SocketWrapper {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match &self.inner {
SocketInner::Packet(socket) => socket.poll_read(cx),
SocketInner::Listener(socket) => socket.poll_read(cx),
SocketInner::Connection(socket) => socket.poll_read(cx),
SocketInner::Packet(socket) => socket.poll_read(cx),
}
}
}
@ -293,7 +302,6 @@ impl fmt::Debug for SocketWrapper {
match &self.inner {
SocketInner::Packet(socket) => socket.fmt(f),
SocketInner::Connection(socket) => socket.fmt(f),
SocketInner::Listener(socket) => socket.fmt(f),
}
}
}
@ -301,9 +309,8 @@ impl fmt::Debug for SocketWrapper {
impl Drop for SocketInner {
fn drop(&mut self) {
let res = match self {
Self::Packet(socket) => socket.close(),
Self::Connection(socket) => socket.close(),
Self::Listener(socket) => socket.close(),
Self::Packet(socket) => socket.clone().close(),
Self::Connection(socket) => socket.clone().close(),
};
if let Err(error) = res {
log::warn!("Socket close error: {error:?}");
@ -314,69 +321,10 @@ impl Drop for SocketInner {
impl Default for InnerOptions {
fn default() -> Self {
Self {
connect_timeout: None,
recv_timeout: None,
send_timeout: None,
non_blocking: false,
}
}
}
// impl From<Arc<dyn ConnectionSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn ConnectionSocket + 'static>) -> Self {
// Self::Connection(value)
// }
// }
//
// impl From<Arc<dyn ListenerSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn ListenerSocket + 'static>) -> Self {
// Self::Listener(value)
// }
// }
//
// impl From<Arc<dyn PacketSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn PacketSocket + 'static>) -> Self {
// Self::Packet(value)
// }
// }
// impl Deref for PacketSocketWrapper {
// type Target = dyn PacketSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for PacketSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }
//
// impl Deref for ListenerSocketWrapper {
// type Target = dyn ListenerSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for ListenerSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }
//
// impl Deref for ConnectionSocketWrapper {
// type Target = dyn ConnectionSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for ConnectionSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }

View File

@ -8,7 +8,7 @@ pub(crate) use abi::{
UnmountOptions,
},
mem::{MappingFlags, MappingSource},
net::SocketType,
net::{SocketShutdown, SocketType},
process::{Signal, SignalEntryData, SpawnOptions, WaitFlags},
system::SystemInfo,
};

View File

@ -3,127 +3,218 @@ use core::{mem::MaybeUninit, net::SocketAddr};
use abi::{
error::Error,
io::RawFd,
net::{SocketConnect, SocketOption, SocketType},
net::{SocketOption, SocketShutdown, SocketType},
};
use libk::{
task::thread::Thread,
vfs::{File, Socket, SocketWrapper},
vfs::{File, FileRef, SocketWrapper},
};
use ygg_driver_net_core::socket::{RawSocket, TcpListener, TcpSocket, UdpSocket};
use ygg_driver_net_core::socket::{RawSocket, TcpSocket, UdpSocket};
use crate::syscall::run_with_io;
// Network
pub(crate) fn connect_socket(
connect: &mut SocketConnect,
local_result: &mut MaybeUninit<SocketAddr>,
) -> Result<RawFd, Error> {
fn get_socket(fd: RawFd) -> Result<FileRef, Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |io| io.files.file(fd).cloned())
}
pub(crate) fn create_socket(ty: SocketType) -> Result<RawFd, Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |mut io| {
let (local, fd) = match connect {
SocketConnect::Tcp(remote, timeout) => {
let remote = (*remote).into();
let (local, socket) = libk::block!(TcpSocket::connect(remote, *timeout).await)??;
let file = File::from_socket(SocketWrapper::from_connection(socket));
let fd = io.files.place_file(file, true)?;
(local.into(), fd)
}
SocketConnect::Udp(_socket, _remote) => {
todo!("UDP socket connect")
}
};
local_result.write(local);
Ok(fd)
})
}
pub(crate) fn bind_socket(listen: &SocketAddr, ty: SocketType) -> Result<RawFd, Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |mut io| {
let listen = (*listen).into();
let socket = match ty {
SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::bind()?),
SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::bind(listen)?),
SocketType::TcpStream => SocketWrapper::from_listener(TcpListener::bind(listen)?),
SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::new()),
SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::new()),
SocketType::TcpStream => SocketWrapper::from_connection(TcpSocket::new()),
};
let file = File::from_socket(socket);
let fd = io.files.place_file(file, true)?;
Ok(fd)
})
}
pub(crate) fn accept(
socket_fd: RawFd,
remote_result: &mut MaybeUninit<SocketAddr>,
) -> Result<RawFd, Error> {
pub(crate) fn bind(sock_fd: RawFd, local: &SocketAddr) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.bind((*local).into())
}
pub(crate) fn listen(sock_fd: RawFd) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.listen()
}
pub(crate) fn connect(sock_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.connect((*remote).into())
}
pub(crate) fn accept(sock_fd: RawFd, remote: &mut MaybeUninit<SocketAddr>) -> Result<RawFd, Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |mut io| {
let file = io.files.file(socket_fd)?;
let listener = file.as_socket()?;
let (socket, remote) = listener.accept()?;
remote_result.write(remote.into());
let fd = io.files.place_file(File::from_socket(socket), true)?;
Ok(fd)
let listener = io.files.file(sock_fd)?;
let listener = listener.as_socket()?;
let (stream_file, stream_remote) = listener.accept()?;
let stream_fd = io.files.place_file(stream_file, true)?;
remote.write(stream_remote.into());
Ok(stream_fd)
})
}
pub(crate) fn shutdown(sock_fd: RawFd, how: SocketShutdown) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.shutdown(how)
}
pub(crate) fn send_to(
socket_fd: RawFd,
buffer: &[u8],
recepient: &Option<SocketAddr>,
sock_fd: RawFd,
data: &[u8],
remote: &Option<SocketAddr>,
) -> Result<usize, Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
let remote = recepient.map(Into::into);
socket.send_to(buffer, remote)
})
let file = get_socket(sock_fd)?;
file.as_socket()?.send_to(data, remote.map(Into::into))
}
pub(crate) fn receive_from(
socket_fd: RawFd,
buffer: &mut [u8],
remote_result: &mut MaybeUninit<SocketAddr>,
sock_fd: RawFd,
data: &mut [u8],
remote: &mut MaybeUninit<SocketAddr>,
) -> Result<usize, Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
let (len, remote) = socket.receive_from(buffer)?;
remote_result.write(remote.into());
let file = get_socket(sock_fd)?;
let (len, remote_) = file.as_socket()?.receive_from(data)?;
remote.write(remote_.into());
Ok(len)
})
}
pub(crate) fn set_socket_option(socket_fd: RawFd, option: &SocketOption) -> Result<(), Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
socket.set_option(option)
})
pub(crate) fn get_socket_option(sock_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.get_option(option)
}
pub(crate) fn get_socket_option(socket_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> {
let thread = Thread::current();
let process = thread.process();
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
socket.get_option(option)
})
pub(crate) fn set_socket_option(sock_fd: RawFd, option: &SocketOption) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.set_option(option)
}
// // Network
// pub(crate) fn connect_socket(
// connect: &mut SocketConnect,
// local_result: &mut MaybeUninit<SocketAddr>,
// ) -> Result<RawFd, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |mut io| {
// let (local, fd) = match connect {
// SocketConnect::Tcp(remote, timeout) => {
// let remote = (*remote).into();
// let (local, socket) = libk::block!(TcpSocket::connect(remote, *timeout).await)??;
// let file = File::from_socket(SocketWrapper::from_connection(socket));
// let fd = io.files.place_file(file, true)?;
// (local.into(), fd)
// }
// SocketConnect::Udp(_socket, _remote) => {
// todo!("UDP socket connect")
// }
// };
// local_result.write(local);
// Ok(fd)
// })
// }
//
// pub(crate) fn bind_socket(listen: &SocketAddr, ty: SocketType) -> Result<RawFd, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |mut io| {
// let listen = (*listen).into();
// let socket = match ty {
// SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::bind()?),
// SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::bind(listen)?),
// SocketType::TcpStream => SocketWrapper::from_listener(TcpListener::bind(listen)?),
// };
// let file = File::from_socket(socket);
// let fd = io.files.place_file(file, true)?;
// Ok(fd)
// })
// }
//
// pub(crate) fn accept(
// socket_fd: RawFd,
// remote_result: &mut MaybeUninit<SocketAddr>,
// ) -> Result<RawFd, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |mut io| {
// let file = io.files.file(socket_fd)?;
// let listener = file.as_socket()?;
// let (socket, remote) = listener.accept()?;
// remote_result.write(remote.into());
// let fd = io.files.place_file(File::from_socket(socket), true)?;
// Ok(fd)
// })
// }
//
// pub(crate) fn send_to(
// socket_fd: RawFd,
// buffer: &[u8],
// recepient: &Option<SocketAddr>,
// ) -> Result<usize, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// let remote = recepient.map(Into::into);
// socket.send_to(buffer, remote)
// })
// }
//
// pub(crate) fn receive_from(
// socket_fd: RawFd,
// buffer: &mut [u8],
// remote_result: &mut MaybeUninit<SocketAddr>,
// ) -> Result<usize, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// let (len, remote) = socket.receive_from(buffer)?;
// remote_result.write(remote.into());
// Ok(len)
// })
// }
//
// pub(crate) fn set_socket_option(socket_fd: RawFd, option: &SocketOption) -> Result<(), Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// socket.set_option(option)
// })
// }
//
// pub(crate) fn get_socket_option(socket_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// socket.get_option(option)
// })
// }

View File

@ -53,6 +53,13 @@ enum SocketType(u32) {
UdpPacket = 2,
}
bitfield SocketShutdown(u32) {
/// Stop reception on a socket
READ: 0,
/// Stop transmission on a socket
WRITE: 1,
}
// abi::mem
bitfield MappingFlags(u32) {
@ -161,13 +168,15 @@ syscall receive_message(
) -> Result<()>;
// Network
syscall connect_socket(connect: &mut SocketConnect, local: &mut MaybeUninit<SocketAddr>) -> Result<RawFd>;
syscall bind_socket(address: &SocketAddr, ty: SocketType) -> Result<RawFd>;
syscall accept(sock_fd: RawFd, remote: &mut MaybeUninit<SocketAddr>) -> Result<RawFd>;
syscall create_socket(ty: SocketType) -> Result<RawFd>;
syscall bind(sock_fd: RawFd, local: &SocketAddr) -> Result<()>;
syscall listen(sock_fd: RawFd) -> Result<()>;
syscall connect(sock_fd: RawFd, remote: &SocketAddr) -> Result<()>;
syscall accept(sock_fd: RawFd, remote: &mut MaybeUninit<SocketAddr>) -> Result<RawFd>;
syscall shutdown(sock_fd: RawFd, how: SocketShutdown) -> Result<()>;
syscall send_to(sock_fd: RawFd, data: &[u8], remote: &Option<SocketAddr>) -> Result<usize>;
syscall receive_from(sock_fd: RawFd, data: &mut [u8], remote: &mut MaybeUninit<SocketAddr>) -> Result<usize>;
syscall get_socket_option(sock_fd: RawFd, option: &mut SocketOption<'_>) -> Result<()>;
syscall set_socket_option(sock_fd: RawFd, option: &SocketOption<'_>) -> Result<()>;

View File

@ -4,6 +4,7 @@
clippy::new_without_default,
clippy::should_implement_trait,
clippy::module_inception,
clippy::missing_transmute_annotations,
incomplete_features,
stable_features
)]

View File

@ -10,8 +10,7 @@ pub mod types;
use core::time::Duration;
pub use crate::generated::SocketType;
use crate::io::RawFd;
pub use crate::generated::{SocketShutdown, SocketType};
pub use types::{
ip_addr::{IpAddr, Ipv4Addr, Ipv6Addr},
@ -20,16 +19,6 @@ pub use types::{
MacAddress,
};
/// Describes a socket connect operation
#[derive(Clone, Debug)]
pub enum SocketConnect {
/// Connect a TCP socket with optional timeout.
Tcp(core::net::SocketAddr, Option<Duration>),
/// "Connect" an UDP socket, this just sets the sender's address in the socket so
/// the caller can then use send().
Udp(RawFd, core::net::SocketAddr),
}
/// Describes a method to query an interface
#[derive(Clone, Debug)]
pub enum SocketInterfaceQuery<'a> {
@ -52,6 +41,8 @@ pub enum SocketOption<'a> {
UnbindInterface,
/// (Read-only) Hardware address of the bound interface
BoundHardwareAddress(MacAddress),
/// (Read-only) Local socket address
LocalAddress(Option<core::net::SocketAddr>),
/// (Read-only) Remote socket address
PeerAddress(Option<core::net::SocketAddr>),
/// If set, reception will return [crate::error::Error::WouldBlock] if the socket has
@ -63,6 +54,8 @@ pub enum SocketOption<'a> {
RecvTimeout(Option<Duration>),
/// If not [None], send operations will have a time limit set before returning an error.
SendTimeout(Option<Duration>),
/// If not [None], connect() call will timeout after the specified time limit.
ConnectTimeout(Option<Duration>),
/// (UDP) If set, allows multicast packets to be looped back to local host.
MulticastLoopV4(bool),
/// (UDP) If set, allows multicast packets to be looped back to local host.

View File

@ -85,6 +85,30 @@ impl From<Ipv4Addr> for core::net::Ipv4Addr {
// IPv6
impl Ipv6Addr {
/// An IPv6 unspecified address `::`.
pub const UNSPECIFIED: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 0);
/// Constructs a new IPv6 address from its words.
///
/// The result represents the IP address `a:b:c:d:e:f:g:h`.
#[allow(clippy::too_many_arguments)]
pub const fn new(a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> Self {
let addr16 = [
a.to_be(),
b.to_be(),
c.to_be(),
d.to_be(),
e.to_be(),
f.to_be(),
g.to_be(),
h.to_be(),
];
// SAFETY: `[u16; 8]` is safe to transmute to `[u8; 16]`
Self(unsafe { core::mem::transmute::<_, [u8; 16]>(addr16) })
}
}
impl FromStr for Ipv6Addr {
type Err = Error;

View File

@ -175,3 +175,15 @@ impl From<core::net::SocketAddr> for SocketAddr {
}
}
}
impl From<SocketAddrV4> for SocketAddr {
fn from(value: SocketAddrV4) -> Self {
Self::V4(value)
}
}
impl From<SocketAddrV6> for SocketAddr {
fn from(value: SocketAddrV6) -> Self {
Self::V6(value)
}
}

View File

@ -1,12 +1,10 @@
//! Network-related functions and types
use core::{mem::MaybeUninit, net::SocketAddr, time::Duration};
use core::{net::SocketAddr, time::Duration};
pub use abi::net::{MacAddress, SocketConnect, SocketInterfaceQuery, SocketOption, SocketType};
pub use abi::net::{MacAddress, SocketInterfaceQuery, SocketOption, SocketShutdown, SocketType};
use abi::{error::Error, io::RawFd};
use crate::sys;
#[allow(unused_macros)]
macro socket_option_variant {
($opt:ident: bool) => { $crate::net::SocketOption::$opt(false) },
@ -79,23 +77,70 @@ pub mod dns {
}
}
fn connect_inner(connect: &mut SocketConnect) -> Result<(SocketAddr, RawFd), Error> {
let mut local = MaybeUninit::uninit();
let fd = unsafe { sys::connect_socket(connect, &mut local) }?;
let local = unsafe { local.assume_init() };
Ok((local, fd))
fn bind_inner(fd: RawFd, local: &SocketAddr, listen: bool) -> Result<(), Error> {
unsafe { crate::sys::bind(fd, local) }?;
if listen {
unsafe { crate::sys::listen(fd) }?;
}
Ok(())
}
fn connect_inner(fd: RawFd, remote: &SocketAddr, timeout: Option<Duration>) -> Result<(), Error> {
if timeout.is_some() {
unsafe { crate::sys::set_socket_option(fd, &SocketOption::ConnectTimeout(timeout)) }?;
}
unsafe { crate::sys::connect(fd, remote) }?;
Ok(())
}
/// Creates a new socket and binds it to a local address
pub fn create_and_bind(ty: SocketType, local: &SocketAddr, listen: bool) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(ty) }?;
match bind_inner(fd, local, listen) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds a TCP listener socket to some local address
pub fn bind_tcp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::TcpStream, local, true)
}
/// Binds a raw socket to some network interface
pub fn bind_raw(iface: SocketInterfaceQuery<'_>) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(SocketType::RawPacket) }?;
let option = SocketOption::BindInterface(iface);
match unsafe { crate::sys::set_socket_option(fd, &option) } {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds an UDP socket to some local address
pub fn bind_udp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::UdpPacket, local, false)
}
/// Connect to a TCP listener
pub fn connect_tcp(
remote: SocketAddr,
timeout: Option<Duration>,
) -> Result<(SocketAddr, RawFd), Error> {
connect_inner(&mut SocketConnect::Tcp(remote, timeout))
pub fn connect_tcp(remote: &SocketAddr, timeout: Option<Duration>) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(SocketType::TcpStream) }?;
match connect_inner(fd, remote, timeout) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// "Connect" an UDP socket
pub fn connect_udp(socket_fd: RawFd, remote: SocketAddr) -> Result<(), Error> {
connect_inner(&mut SocketConnect::Udp(socket_fd, remote))?;
Ok(())
pub fn connect_udp(socket_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> {
connect_inner(socket_fd, remote, None)
}

View File

@ -21,7 +21,7 @@ mod generated {
TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
},
mem::{MappingFlags, MappingSource},
net::SocketType,
net::{SocketShutdown, SocketType},
process::{
ExecveOptions, ProcessGroupId, ProcessId, Signal, SignalEntryData, SpawnOptions,
ThreadSpawnOptions, WaitFlags,

10
test.c
View File

@ -1,8 +1,14 @@
#include <pthread.h>
#include <unistd.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <assert.h>
#include <stdio.h>
int main(int argc, const char **argv) {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
perror("socket()");
return EXIT_FAILURE;
}
return 0;
}

View File

@ -24,7 +24,7 @@ fn include_dir(d: &DirEntry) -> bool {
&& d.path()
.iter()
.nth(2)
.map_or(false, |c| c.to_str().map_or(false, |x| !x.starts_with("_")))
.is_some_and(|c| c.to_str().is_some_and(|x| !x.starts_with("_")))
}
fn generate_header(config_path: impl AsRef<Path>, header_output: impl AsRef<Path>) {
@ -67,11 +67,7 @@ fn generate_header(config_path: impl AsRef<Path>, header_output: impl AsRef<Path
fn compile_crt0(arch: &str, output_dir: impl AsRef<Path>) {
let output_dir = output_dir.as_ref();
let mut command = Command::new("clang");
let arch = if arch == "x86" {
"i686"
} else {
arch
};
let arch = if arch == "x86" { "i686" } else { arch };
let input_dir = PathBuf::from("crt").join(arch);
let crt0_c = input_dir.join("crt0.c");
let crt0_s = input_dir.join("crt0.S");

View File

@ -0,0 +1,15 @@
language = "C"
style = "Type"
sys_includes = [
"netinet/in.h",
"inttypes.h"
]
no_includes = true
include_guard = "_ARPA_INET_H"
usize_type = "size_t"
isize_type = "ssize_t"
[export]

View File

@ -0,0 +1,20 @@
#[no_mangle]
unsafe extern "C" fn htonl(v: u32) -> u32 {
v.to_be()
}
#[no_mangle]
unsafe extern "C" fn htons(v: u16) -> u16 {
v.to_be()
}
#[no_mangle]
unsafe extern "C" fn ntohl(v: u32) -> u32 {
u32::from_be(v)
}
#[no_mangle]
unsafe extern "C" fn ntohs(v: u16) -> u16 {
u16::from_be(v)
}

View File

@ -195,6 +195,7 @@ impl From<yggdrasil_rt::Error> for Errno {
Error::DirectoryNotEmpty => Errno::ENOTEMPTY,
Error::NotConnected => Errno::ENOTCONN,
Error::ProcessNotFound => Errno::ESRCH,
Error::CrossDeviceLink => Errno::EXDEV,
}
}
}

View File

@ -1,6 +1,9 @@
use core::ffi::{c_char, c_int, c_short, VaList};
use yggdrasil_rt::{io::{AccessMode, FileMode, OpenOptions}, sys as syscall};
use yggdrasil_rt::{
io::{AccessMode, FileMode, OpenOptions},
sys as syscall,
};
use crate::{
error::{CFdResult, CIntCountResult, CIntZeroResult, EResult, ResultExt, TryFromExt},
@ -125,10 +128,7 @@ fn open_opts(opts: c_int, ap: &mut VaList) -> EResult<OpenMode> {
}
// TODO O_CLOEXEC
if opts
& (O_DSYNC | O_RSYNC | O_SYNC | O_TTY_INIT | O_NONBLOCK | O_NOFOLLOW | O_NOCTTY)
!= 0
{
if opts & (O_DSYNC | O_RSYNC | O_SYNC | O_TTY_INIT | O_NONBLOCK | O_NOFOLLOW | O_NOCTTY) != 0 {
todo!();
}
@ -187,7 +187,7 @@ pub(crate) unsafe extern "C" fn faccessat(
atfd: c_int,
path: *const c_char,
mode: c_int,
flags: c_int,
_flags: c_int,
) -> CIntZeroResult {
let atfd = util::at_fd(atfd)?;
let path = path.ensure_str();

View File

@ -136,6 +136,11 @@ pub mod sys_types;
pub mod sys_utsname;
pub mod sys_wait;
// Network
pub mod arpa_inet;
pub mod netinet_in;
pub mod sys_socket;
// TODO Generate those as part of dyn-loader (and make dyn-loader a shared library)
pub mod link;

View File

@ -0,0 +1,17 @@
language = "C"
style = "Tag"
sys_includes = [
"inttypes.h",
"sys/socket.h",
"arpa/inet.h"
]
no_includes = true
include_guard = "_NETINET_IN_H"
usize_type = "size_t"
isize_type = "ssize_t"
[export]
include = ["sockaddr_in", "sockaddr_in6"]

View File

@ -0,0 +1,60 @@
use core::ffi::c_int;
use super::sys_socket::sa_family_t;
pub type in_port_t = u16;
pub type in_addr_t = u32;
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct in_addr {
pub s_addr: in_addr_t
}
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct in6_addr {
pub s6_addr: [u8; 16]
}
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct sockaddr_in {
pub sin_family: sa_family_t,
pub sin_port: in_port_t,
pub sin_addr: in_addr,
}
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct sockaddr_in6 {
pub sin6_family: sa_family_t,
pub sin6_port: in_port_t,
pub sin6_flowinfo: u32,
pub sin6_addr: in6_addr,
pub sin6_scope_id: u32
}
// TODO more IPv6
pub const IPPROTO_ICMP: c_int = 1;
pub const IPPROTO_TCP: c_int = 6;
pub const IPPROTO_UDP: c_int = 17;
pub const IPPROTO_IP: c_int = 0;
pub const IPPROTO_IPV6: c_int = 41;
pub const IPPROTO_RAW: c_int = 255;
pub const INADDR_ANY: in_addr_t = 0;
pub const INADDR_BROADCAST: in_addr_t = 0xFFFFFFFF;
pub const INET_ADDRSTRLEN: usize = 16;
pub const INET6_ADDRSTRLEN: usize = 46;
pub const IPV6_JOIN_GROUP: c_int = 6001;
pub const IPV6_LEAVE_GROUP: c_int = 6002;
pub const IPV6_MULTICAST_HOPS: c_int = 6003;
pub const IPV6_MULTICAST_IF: c_int = 6004;
pub const IPV6_MULTICAST_LOOP: c_int = 6005;
pub const IPV6_UNICAST_HOPS: c_int = 6006;
pub const IPV6_V6ONLY: c_int = 6007;

View File

@ -1,9 +1,15 @@
use core::ffi::{c_char, c_int};
use yggdrasil_rt::sys as syscall;
use yggdrasil_rt::{
io::{RemoveFlags, Rename},
sys as syscall,
};
use crate::{
error::{self, CIntZeroResult, CPtrResult, ResultExt}, headers::errno::Errno, io::managed::{stderr, FILE}, util::{PointerExt, PointerStrExt}
error::{self, CIntZeroResult, CPtrResult, ResultExt},
headers::{errno::Errno, fcntl::AT_FDCWD},
io::managed::{stderr, FILE},
util::{self, PointerExt, PointerStrExt},
};
#[no_mangle]
@ -32,21 +38,34 @@ unsafe extern "C" fn ctermid(_buf: *mut c_char) -> *mut c_char {
#[no_mangle]
unsafe extern "C" fn remove(path: *const c_char) -> CIntZeroResult {
let path = path.ensure_str();
syscall::remove(None, path).e_map_err(Errno::from)?;
syscall::remove(None, path, RemoveFlags::DIRECTORY).e_map_err(Errno::from)?;
CIntZeroResult::SUCCESS
}
#[no_mangle]
unsafe extern "C" fn rename(src: *const c_char, dst: *const c_char) -> c_int {
let src = src.ensure_str();
let dst = dst.ensure_str();
yggdrasil_rt::debug_trace!("rename {src:?} -> {dst:?}");
todo!()
unsafe extern "C" fn rename(src: *const c_char, dst: *const c_char) -> CIntZeroResult {
renameat(AT_FDCWD, src, dst)
}
#[no_mangle]
unsafe extern "C" fn renameat(_atfd: c_int, _src: *const c_char, _dst: *const c_char) -> c_int {
todo!()
unsafe extern "C" fn renameat(
atfd: c_int,
src: *const c_char,
dst: *const c_char,
) -> CIntZeroResult {
let at = util::at_fd(atfd)?;
let source = src.ensure_str();
let destination = dst.ensure_str();
syscall::rename(&Rename {
source_at: at,
destination_at: at,
source,
destination,
})
.e_map_err(Errno::from)?;
CIntZeroResult::SUCCESS
}
#[no_mangle]

View File

@ -1,6 +1,16 @@
use core::{ffi::{c_char, c_int}, ptr::NonNull, slice};
use core::{
ffi::{c_char, c_int},
ptr::NonNull,
slice,
};
use crate::{allocator::c_alloc, error::{self, CPtrResult, CResult}, headers::errno::Errno, io, util::PointerStrExt};
use crate::{
allocator::c_alloc,
error::{self, CPtrResult, CResult},
headers::errno::Errno,
io,
util::PointerStrExt,
};
#[no_mangle]
unsafe extern "C" fn grantpt(_fd: c_int) -> c_int {
@ -28,7 +38,10 @@ unsafe extern "C" fn ptsname(_fd: c_int) -> *mut c_char {
}
#[no_mangle]
unsafe extern "C" fn realpath(path: *const c_char, mut resolved_ptr: *mut c_char) -> CPtrResult<c_char> {
unsafe extern "C" fn realpath(
path: *const c_char,
resolved_ptr: *mut c_char,
) -> CPtrResult<c_char> {
if path.is_null() {
error::errno = Errno::EINVAL;
return CPtrResult::ERROR;

View File

@ -1,8 +1,18 @@
use core::{ffi::{c_char, c_int, c_void}, num::NonZeroUsize, ptr::{self, NonNull}};
use core::{
ffi::{c_char, c_int, c_void},
ptr::{self, NonNull},
};
use yggdrasil_rt::{io::RawFd, mem::{MappingFlags, MappingSource}, sys as syscall};
use yggdrasil_rt::{
io::RawFd,
mem::{MappingFlags, MappingSource},
sys as syscall,
};
use crate::{error::{self, CPtrResult, EResult, ResultExt, TryFromExt}, headers::errno::Errno};
use crate::{
error::{self, EResult, ResultExt, TryFromExt},
headers::errno::Errno,
};
use super::sys_types::{mode_t, off_t};

View File

@ -0,0 +1,14 @@
language = "C"
style = "Tag"
sys_includes = [
"sys/types.h"
]
no_includes = true
include_guard = "_SYS_SOCKET_H"
usize_type = "size_t"
isize_type = "ssize_t"
[export]

View File

@ -0,0 +1,72 @@
use core::ffi::{c_int, c_void};
use super::{msghdr, sockaddr, socklen_t};
#[no_mangle]
unsafe extern "C" fn recv(fd: c_int, buffer: *mut c_void, len: usize, flags: c_int) -> isize {
let _ = fd;
let _ = buffer;
let _ = len;
let _ = flags;
todo!()
}
#[no_mangle]
unsafe extern "C" fn recvfrom(
fd: c_int,
buffer: *mut c_void,
len: usize,
flags: c_int,
remote: *mut sockaddr,
remote_len: *mut socklen_t,
) -> isize {
let _ = fd;
let _ = buffer;
let _ = len;
let _ = flags;
let _ = remote;
let _ = remote_len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn recvmsg(fd: c_int, message: *mut msghdr, flags: c_int) -> isize {
let _ = fd;
let _ = message;
let _ = flags;
todo!()
}
#[no_mangle]
unsafe extern "C" fn send(fd: c_int, data: *const c_void, len: usize) -> isize {
let _ = fd;
let _ = data;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn sendmsg(fd: c_int, message: *const msghdr, flags: c_int) -> isize {
let _ = fd;
let _ = message;
let _ = flags;
todo!()
}
#[no_mangle]
unsafe extern "C" fn sendto(
fd: c_int,
data: *const c_void,
len: usize,
flags: c_int,
remote: *const sockaddr,
remote_len: socklen_t,
) -> isize {
let _ = fd;
let _ = data;
let _ = len;
let _ = flags;
let _ = remote;
let _ = remote_len;
todo!()
}

View File

@ -0,0 +1,86 @@
use core::ffi::{c_int, c_void};
mod io;
mod option;
mod socket;
pub type socklen_t = usize;
pub type sa_family_t = u16;
pub const __SS_SIZE: usize = 256;
// __SS_SIZE - sizeof(sa_samily_t)
pub const __SOCKADDR_LEN: usize = __SS_SIZE - 2;
// __SS_SIZE - sizeof(sa_family_t) - sizeof(__ss_aligntype)
pub const __SS_PADSIZE: usize = __SS_SIZE - 6;
pub type __ss_aligntype = u32;
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct sockaddr {
pub sa_family: sa_family_t,
pub sa_data: [u8; __SOCKADDR_LEN],
}
// TODO struct sockaddr_storage
// TODO struct iovec from sys/uio.h
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct iovec {
__dummy: u32
}
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct msghdr {
pub msg_name: *mut c_void,
pub msg_namelen: socklen_t,
pub msg_iov: *mut iovec,
pub msg_iovlen: c_int,
pub msg_control: *mut c_void,
pub msg_controllen: socklen_t,
pub msg_flags: c_int
}
// TODO struct cmsghdr
// socket() parameters
pub const SOCK_DGRAM: c_int = 1;
pub const SOCK_RAW: c_int = 2;
pub const SOCK_SEQPACKET: c_int = 3;
pub const SOCK_STREAM: c_int = 4;
// setsockopt() parameters
pub const SOL_SOCKET: c_int = 1;
pub const SO_ACCEPTCONN: c_int = 1;
pub const SO_BROADCAST: c_int = 2;
pub const SO_DEBUG: c_int = 3;
pub const SO_DONTROUTE: c_int = 4;
pub const SO_ERROR: c_int = 5;
pub const SO_KEEPALIVE: c_int = 6;
pub const SO_LINGER: c_int = 7;
pub const SO_OOBINLINE: c_int = 8;
pub const SO_RCVBUF: c_int = 9;
pub const SO_RCVLOWAT: c_int = 10;
pub const SO_RCVTIMEO: c_int = 11;
pub const SO_REUSEADDR: c_int = 12;
pub const SO_SNDBUF: c_int = 13;
pub const SO_SNDLOWAT: c_int = 14;
pub const SO_SNDTIMEO: c_int = 15;
pub const SO_TYPE: c_int = 16;
pub const SOMAXCONN: usize = 64;
// TODO msg_flags values
pub const AF_INET: c_int = 1;
pub const AF_INET6: c_int = 2;
pub const AF_UNIX: c_int = 3;
pub const AF_UNSPEC: c_int = 0;
pub const SHUT_RD: c_int = 1 << 0;
pub const SHUT_WR: c_int = 1 << 1;
pub const SHUT_RDWR: c_int = SHUT_RD | SHUT_WR;

View File

@ -0,0 +1,35 @@
use core::ffi::{c_int, c_void};
use super::socklen_t;
#[no_mangle]
unsafe extern "C" fn getsockopt(
fd: c_int,
level: c_int,
name: c_int,
value: *mut c_void,
size: *mut socklen_t,
) -> c_int {
let _ = fd;
let _ = level;
let _ = name;
let _ = value;
let _ = size;
todo!()
}
#[no_mangle]
unsafe extern "C" fn setsockopt(
fd: c_int,
level: c_int,
name: c_int,
value: *const c_void,
size: socklen_t,
) -> c_int {
let _ = fd;
let _ = level;
let _ = name;
let _ = value;
let _ = size;
todo!()
}

View File

@ -0,0 +1,99 @@
use core::ffi::c_int;
use crate::{
error::{self, CFdResult, CResult},
headers::{
errno::Errno,
sys_socket::{AF_INET, SOCK_DGRAM, SOCK_STREAM},
},
};
use super::{sockaddr, socklen_t};
#[no_mangle]
unsafe extern "C" fn accept(fd: c_int, remote: *mut sockaddr, len: *mut socklen_t) -> c_int {
let _ = fd;
let _ = remote;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn bind(fd: c_int, local: *const sockaddr, len: socklen_t) -> c_int {
let _ = fd;
let _ = local;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn connect(fd: c_int, remote: *const sockaddr, len: socklen_t) -> c_int {
let _ = fd;
let _ = remote;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn getpeername(fd: c_int, remote: *mut sockaddr, len: *mut socklen_t) -> c_int {
let _ = fd;
let _ = remote;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn getsockname(fd: c_int, local: *mut sockaddr, len: *mut socklen_t) -> c_int {
let _ = fd;
let _ = local;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn listen(fd: c_int, backlog: c_int) -> c_int {
let _ = fd;
let _ = backlog;
todo!()
}
#[no_mangle]
unsafe extern "C" fn shutdown(fd: c_int, how: c_int) -> c_int {
let _ = fd;
let _ = how;
todo!()
}
#[no_mangle]
unsafe extern "C" fn sockatmark(fd: c_int) -> c_int {
let _ = fd;
todo!()
}
#[no_mangle]
unsafe extern "C" fn socket(domain: c_int, ty: c_int, proto: c_int) -> CFdResult {
match (domain, ty, proto) {
(AF_INET, SOCK_STREAM, 0) => todo!(),
(AF_INET, SOCK_DGRAM, 0) => todo!(),
(_, _, _) => {
yggdrasil_rt::debug_trace!("Unsupported socket({domain}, {ty}, {proto})");
error::errno = Errno::ENOTSUPP;
CFdResult::ERROR
}
}
}
#[no_mangle]
unsafe extern "C" fn socketpair(
domain: c_int,
ty: c_int,
proto: c_int,
sockets: *mut c_int,
) -> c_int {
let _ = domain;
let _ = ty;
let _ = proto;
let _ = sockets;
todo!()
}

View File

@ -1,8 +1,14 @@
use core::{ffi::c_int, ptr::NonNull};
use crate::{error::{self, CIntZeroResult, CResult, EResult, ResultExt}, headers::{
errno::Errno, sys_time::{__ygg_timespec_t, timespec}, sys_types::{clock_t, clockid_t, pid_t, time_t}, time::{CLOCK_MONOTONIC, CLOCK_REALTIME}
}};
use crate::{
error::{CIntZeroResult, EResult, ResultExt},
headers::{
errno::Errno,
sys_time::{__ygg_timespec_t, timespec},
sys_types::{clock_t, clockid_t, pid_t, time_t},
time::{CLOCK_MONOTONIC, CLOCK_REALTIME},
},
};
use yggdrasil_rt::time::{self as rt, ClockType};
@ -21,7 +27,7 @@ fn clock_type(clock_id: clockid_t) -> EResult<ClockType> {
match clock_id {
CLOCK_REALTIME => EResult::Ok(ClockType::RealTime),
CLOCK_MONOTONIC => EResult::Ok(ClockType::Monotonic),
_ => EResult::Err(Errno::EINVAL)
_ => EResult::Err(Errno::EINVAL),
}
}
@ -42,14 +48,17 @@ unsafe extern "C" fn clock_getres(_clock_id: clockid_t, _ts: *mut __ygg_timespec
}
#[no_mangle]
unsafe extern "C" fn clock_gettime(clock_id: clockid_t, ts: *mut __ygg_timespec_t) -> CIntZeroResult {
unsafe extern "C" fn clock_gettime(
clock_id: clockid_t,
ts: *mut __ygg_timespec_t,
) -> CIntZeroResult {
let clock = clock_type(clock_id)?;
let time = rt::get_clock(clock).e_map_err(Errno::from)?;
if let Some(ts) = NonNull::new(ts) {
ts.write(timespec {
tv_sec: time_t(time.seconds as _),
tv_nsec: time.nanoseconds as _
tv_sec: time_t(time.seconds() as _),
tv_nsec: time.subsec_nanos() as _,
});
}

View File

@ -1,7 +1,4 @@
use core::{
ffi::{c_char, c_int},
ptr::NonNull,
};
use core::{ffi::c_char, ptr::NonNull};
use crate::{
error::{CIntZeroResult, CUsizeResult, OptionExt},
@ -21,7 +18,7 @@ unsafe extern "C" fn mbrlen(_str: *const c_char, _n: usize, _state: *mut mbstate
#[no_mangle]
unsafe extern "C" fn wcrtomb(dst: *mut c_char, wc: wchar_t, state: *mut mbstate_t) -> CUsizeResult {
let state = match state.as_mut() {
let _state = match state.as_mut() {
Some(state) => state,
#[allow(static_mut_refs)]
None => &mut GLOBAL,

View File

@ -1,8 +1,5 @@
use core::cell::RefCell;
use yggdrasil_rt::process::thread_local;
struct RandomState {
xs64: u64
}

View File

@ -2,7 +2,7 @@
use std::os::{
fd::AsRawFd,
yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd},
yggdrasil::io::{poll::PollChannel, net::raw_socket::RawSocket, timer::TimerFd},
};
use std::{io, mem::size_of, process::ExitCode, time::Duration};

View File

@ -1,347 +1,349 @@
#![feature(yggdrasil_os, rustc_private)]
fn main() {}
use std::{
mem::size_of,
os::{
fd::AsRawFd,
yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd},
},
process::ExitCode,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use bytemuck::Zeroable;
use clap::Parser;
use netutils::{netconfig::NetConfig, Error};
use yggdrasil_abi::net::{
protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame},
types::NetValueImpl,
IpAddr, Ipv4Addr, MacAddress,
};
#[derive(Parser)]
struct Args {
#[clap(
help = "Time (ms) between a reply is received and the next request is sent",
short,
long,
default_value_t = 1000,
value_parser = valid_interval
)]
inteval: u32,
#[clap(
help = "Time (ms) after which the request is considered unanswered",
short,
long,
default_value_t = 500,
value_parser = valid_timeout,
)]
timeout: u32,
#[clap(
help = "Number of requests to perform",
short,
long,
default_value_t = 10
)]
count: usize,
#[clap(
help = "Amount of bytes to include as data",
short,
long,
default_value_t = 64,
value_parser = valid_data_size
)]
data_size: usize,
#[clap(help = "Address to ping")]
address: core::net::IpAddr,
}
fn valid_interval(s: &str) -> Result<u32, String> {
clap_num::number_range(s, 100, 10000)
}
fn valid_timeout(s: &str) -> Result<u32, String> {
clap_num::number_range(s, 100, 5000)
}
fn valid_data_size(s: &str) -> Result<usize, String> {
clap_num::number_range(s, 4, 128)
}
struct PingRouting {
interface_id: u32,
source_ip: IpAddr,
destination_ip: IpAddr,
source_mac: MacAddress,
gateway_mac: MacAddress,
}
struct PingStats {
packets_sent: usize,
packets_received: usize,
}
fn resolve_routing(address: IpAddr) -> Result<PingRouting, Error> {
let mut nc = NetConfig::open()?;
let routing = nc.query_route(address)?;
let Some(source) = routing.source else {
todo!();
};
let Some(gateway) = routing.gateway else {
todo!();
};
let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?;
Ok(PingRouting {
interface_id: routing.interface_id,
source_ip: source,
destination_ip: routing.destination,
source_mac: routing.source_mac,
gateway_mac,
})
}
fn validate_ping_reply(
packet: &[u8],
local: Ipv4Addr,
remote: Ipv4Addr,
expect_l4_data: &[u8],
expect_id: u16,
expect_seq: u16,
) -> bool {
if packet.len() < size_of::<EthernetFrame>() + size_of::<Ipv4Frame>() {
return false;
}
let l3_offset = size_of::<EthernetFrame>();
let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]);
if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 {
return false;
}
let l3_frame: &Ipv4Frame =
bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::<Ipv4Frame>()]);
if l3_frame.protocol != IpProtocol::ICMP
|| u32::from_network_order(l3_frame.source_address) != u32::from(remote)
|| u32::from_network_order(l3_frame.destination_address) != u32::from(local)
{
return false;
}
let mut ip_checksum = InetChecksum::new();
ip_checksum.add_value(l3_frame, true);
let ip_checksum = ip_checksum.finish();
if ip_checksum != 0 {
eprintln!("IP checksum mismatch: {:#06x}", ip_checksum);
return false;
}
let l4_offset = l3_offset + l3_frame.header_length();
let l4_size = l3_frame
.total_length()
.saturating_sub(l3_frame.header_length());
if packet.len() < l4_offset + size_of::<IcmpV4Frame>() + expect_l4_data.len() {
return false;
}
let l4_frame: &IcmpV4Frame =
bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::<IcmpV4Frame>()]);
let l4_data = &packet[l4_offset + size_of::<IcmpV4Frame>()..l4_offset + l4_size];
if l4_frame.ty != 0 || l4_frame.code != 0 {
return false;
}
let rest = u32::from_network_order(l4_frame.rest);
let reply_id = (rest >> 16) as u16;
let reply_seq = rest as u16;
if reply_id != expect_id || reply_seq != expect_seq {
eprintln!(
"ICMP seq/id mismatch: sent {}/{}, got {}/{}",
expect_id, expect_seq, reply_id, reply_seq
);
return false;
}
let mut icmp_checksum = InetChecksum::new();
icmp_checksum.add_value(l4_frame, true);
icmp_checksum.add_bytes(l4_data, true);
let icmp_checksum = icmp_checksum.finish();
if icmp_checksum != 0 {
eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum);
return false;
}
l4_data == expect_l4_data
}
#[allow(clippy::too_many_arguments)]
fn ping_once(
socket: &mut RawSocket,
poll: &mut PollChannel,
timer: &mut TimerFd,
info: &PingRouting,
timeout: Duration,
data_len: usize,
id: u16,
seq: u16,
) -> Result<bool, Error> {
let mut buffer = [0; 4096];
let source_ip = info.source_ip.into_ipv4().unwrap();
let destination_ip = info.destination_ip.into_ipv4().unwrap();
let mut l4_data = Vec::with_capacity(data_len);
for _ in 0..data_len {
l4_data.push(rand::random());
}
let ip_len = (size_of::<Ipv4Frame>() + size_of::<IcmpV4Frame>() + data_len)
.try_into()
.unwrap();
let l2_frame = EthernetFrame {
source_mac: info.source_mac,
destination_mac: info.gateway_mac,
ethertype: EtherType::IPV4.to_network_order(),
};
let mut l3_frame = Ipv4Frame {
source_address: u32::from(source_ip).to_network_order(),
destination_address: u32::from(destination_ip).to_network_order(),
protocol: IpProtocol::ICMP,
version_length: 0x45,
total_length: u16::to_network_order(ip_len),
flags_frag: u16::to_network_order(0x4000),
id: u16::to_network_order(0),
ttl: 255,
..Ipv4Frame::zeroed()
};
let mut l4_frame = IcmpV4Frame {
ty: 8,
code: 0,
checksum: u16::to_network_order(0),
rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)),
};
let mut ip_checksum = InetChecksum::new();
ip_checksum.add_value(&l3_frame, true);
l3_frame.header_checksum = ip_checksum.finish().to_network_order();
let mut icmp_checksum = InetChecksum::new();
icmp_checksum.add_value(&l4_frame, true);
icmp_checksum.add_bytes(&l4_data, true);
l4_frame.checksum = icmp_checksum.finish().to_network_order();
let mut packet = vec![];
packet.extend_from_slice(bytemuck::bytes_of(&l2_frame));
packet.extend_from_slice(bytemuck::bytes_of(&l3_frame));
packet.extend_from_slice(bytemuck::bytes_of(&l4_frame));
packet.extend_from_slice(&l4_data);
timer.start(timeout)?;
socket.send(&packet)?;
loop {
let (fd, result) = poll.wait(None, true)?.unwrap();
result?;
match fd {
fd if fd == socket.as_raw_fd() => {
// TODO
let len = socket.recv(&mut buffer)?;
if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq)
{
return Ok(true);
}
}
fd if fd == timer.as_raw_fd() => {
return Ok(false);
}
_ => unreachable!(),
}
}
}
fn ping(
address: IpAddr,
times: usize,
data_len: usize,
interval: Duration,
timeout: Duration,
) -> Result<PingStats, Error> {
let routing = resolve_routing(address)?;
let mut stats = PingStats {
packets_sent: 0,
packets_received: 0,
};
let mut poll = PollChannel::new()?;
let mut timer = TimerFd::new(false, false)?;
let mut socket = RawSocket::bind(routing.interface_id)?;
poll.add(timer.as_raw_fd())?;
poll.add(socket.as_raw_fd())?;
let id = rand::random();
for i in 0..times {
if INTERRUPTED.load(Ordering::Acquire) {
break;
}
let result = ping_once(
&mut socket,
&mut poll,
&mut timer,
&routing,
timeout,
data_len,
id,
i as u16,
)?;
stats.packets_sent += 1;
if result {
stats.packets_received += 1;
println!("[{}/{}] {}: PONG", i + 1, times, address);
}
std::thread::sleep(interval);
}
Ok(stats)
}
static INTERRUPTED: AtomicBool = AtomicBool::new(false);
fn main() -> ExitCode {
// set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt));
let args = Args::parse();
let stats = match ping(
args.address.into(),
args.count,
args.data_size,
Duration::from_millis(args.inteval.into()),
Duration::from_millis(args.timeout.into()),
) {
Ok(stats) => stats,
Err(error) => {
eprintln!("ping: {}", error);
return ExitCode::FAILURE;
}
};
let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent;
println!(
"{} sent, {} received, {}% loss",
stats.packets_sent, stats.packets_received, loss
);
ExitCode::SUCCESS
}
// #![feature(yggdrasil_os, rustc_private)]
//
// use std::{
// mem::size_of,
// os::{
// fd::AsRawFd,
// yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd},
// },
// process::ExitCode,
// sync::atomic::{AtomicBool, Ordering},
// time::Duration,
// };
//
// use bytemuck::Zeroable;
// use clap::Parser;
// use netutils::{netconfig::NetConfig, Error};
// use yggdrasil_abi::net::{
// protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame},
// types::NetValueImpl,
// IpAddr, Ipv4Addr, MacAddress,
// };
//
// #[derive(Parser)]
// struct Args {
// #[clap(
// help = "Time (ms) between a reply is received and the next request is sent",
// short,
// long,
// default_value_t = 1000,
// value_parser = valid_interval
// )]
// inteval: u32,
// #[clap(
// help = "Time (ms) after which the request is considered unanswered",
// short,
// long,
// default_value_t = 500,
// value_parser = valid_timeout,
// )]
// timeout: u32,
// #[clap(
// help = "Number of requests to perform",
// short,
// long,
// default_value_t = 10
// )]
// count: usize,
// #[clap(
// help = "Amount of bytes to include as data",
// short,
// long,
// default_value_t = 64,
// value_parser = valid_data_size
// )]
// data_size: usize,
//
// #[clap(help = "Address to ping")]
// address: core::net::IpAddr,
// }
//
// fn valid_interval(s: &str) -> Result<u32, String> {
// clap_num::number_range(s, 100, 10000)
// }
//
// fn valid_timeout(s: &str) -> Result<u32, String> {
// clap_num::number_range(s, 100, 5000)
// }
//
// fn valid_data_size(s: &str) -> Result<usize, String> {
// clap_num::number_range(s, 4, 128)
// }
//
// struct PingRouting {
// interface_id: u32,
// source_ip: IpAddr,
// destination_ip: IpAddr,
// source_mac: MacAddress,
// gateway_mac: MacAddress,
// }
//
// struct PingStats {
// packets_sent: usize,
// packets_received: usize,
// }
//
// fn resolve_routing(address: IpAddr) -> Result<PingRouting, Error> {
// let mut nc = NetConfig::open()?;
// let routing = nc.query_route(address)?;
// let Some(source) = routing.source else {
// todo!();
// };
// let Some(gateway) = routing.gateway else {
// todo!();
// };
//
// let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?;
//
// Ok(PingRouting {
// interface_id: routing.interface_id,
// source_ip: source,
// destination_ip: routing.destination,
// source_mac: routing.source_mac,
// gateway_mac,
// })
// }
//
// fn validate_ping_reply(
// packet: &[u8],
// local: Ipv4Addr,
// remote: Ipv4Addr,
// expect_l4_data: &[u8],
// expect_id: u16,
// expect_seq: u16,
// ) -> bool {
// if packet.len() < size_of::<EthernetFrame>() + size_of::<Ipv4Frame>() {
// return false;
// }
//
// let l3_offset = size_of::<EthernetFrame>();
//
// let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]);
//
// if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 {
// return false;
// }
// let l3_frame: &Ipv4Frame =
// bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::<Ipv4Frame>()]);
// if l3_frame.protocol != IpProtocol::ICMP
// || u32::from_network_order(l3_frame.source_address) != u32::from(remote)
// || u32::from_network_order(l3_frame.destination_address) != u32::from(local)
// {
// return false;
// }
// let mut ip_checksum = InetChecksum::new();
// ip_checksum.add_value(l3_frame, true);
// let ip_checksum = ip_checksum.finish();
//
// if ip_checksum != 0 {
// eprintln!("IP checksum mismatch: {:#06x}", ip_checksum);
// return false;
// }
//
// let l4_offset = l3_offset + l3_frame.header_length();
// let l4_size = l3_frame
// .total_length()
// .saturating_sub(l3_frame.header_length());
// if packet.len() < l4_offset + size_of::<IcmpV4Frame>() + expect_l4_data.len() {
// return false;
// }
// let l4_frame: &IcmpV4Frame =
// bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::<IcmpV4Frame>()]);
// let l4_data = &packet[l4_offset + size_of::<IcmpV4Frame>()..l4_offset + l4_size];
//
// if l4_frame.ty != 0 || l4_frame.code != 0 {
// return false;
// }
//
// let rest = u32::from_network_order(l4_frame.rest);
// let reply_id = (rest >> 16) as u16;
// let reply_seq = rest as u16;
//
// if reply_id != expect_id || reply_seq != expect_seq {
// eprintln!(
// "ICMP seq/id mismatch: sent {}/{}, got {}/{}",
// expect_id, expect_seq, reply_id, reply_seq
// );
// return false;
// }
//
// let mut icmp_checksum = InetChecksum::new();
// icmp_checksum.add_value(l4_frame, true);
// icmp_checksum.add_bytes(l4_data, true);
// let icmp_checksum = icmp_checksum.finish();
//
// if icmp_checksum != 0 {
// eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum);
// return false;
// }
//
// l4_data == expect_l4_data
// }
//
// #[allow(clippy::too_many_arguments)]
// fn ping_once(
// socket: &mut RawSocket,
// poll: &mut PollChannel,
// timer: &mut TimerFd,
// info: &PingRouting,
// timeout: Duration,
// data_len: usize,
// id: u16,
// seq: u16,
// ) -> Result<bool, Error> {
// let mut buffer = [0; 4096];
//
// let source_ip = info.source_ip.into_ipv4().unwrap();
// let destination_ip = info.destination_ip.into_ipv4().unwrap();
// let mut l4_data = Vec::with_capacity(data_len);
//
// for _ in 0..data_len {
// l4_data.push(rand::random());
// }
//
// let ip_len = (size_of::<Ipv4Frame>() + size_of::<IcmpV4Frame>() + data_len)
// .try_into()
// .unwrap();
//
// let l2_frame = EthernetFrame {
// source_mac: info.source_mac,
// destination_mac: info.gateway_mac,
// ethertype: EtherType::IPV4.to_network_order(),
// };
// let mut l3_frame = Ipv4Frame {
// source_address: u32::from(source_ip).to_network_order(),
// destination_address: u32::from(destination_ip).to_network_order(),
// protocol: IpProtocol::ICMP,
// version_length: 0x45,
// total_length: u16::to_network_order(ip_len),
// flags_frag: u16::to_network_order(0x4000),
// id: u16::to_network_order(0),
// ttl: 255,
// ..Ipv4Frame::zeroed()
// };
// let mut l4_frame = IcmpV4Frame {
// ty: 8,
// code: 0,
// checksum: u16::to_network_order(0),
// rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)),
// };
//
// let mut ip_checksum = InetChecksum::new();
// ip_checksum.add_value(&l3_frame, true);
// l3_frame.header_checksum = ip_checksum.finish().to_network_order();
//
// let mut icmp_checksum = InetChecksum::new();
// icmp_checksum.add_value(&l4_frame, true);
// icmp_checksum.add_bytes(&l4_data, true);
// l4_frame.checksum = icmp_checksum.finish().to_network_order();
//
// let mut packet = vec![];
// packet.extend_from_slice(bytemuck::bytes_of(&l2_frame));
// packet.extend_from_slice(bytemuck::bytes_of(&l3_frame));
// packet.extend_from_slice(bytemuck::bytes_of(&l4_frame));
// packet.extend_from_slice(&l4_data);
//
// timer.start(timeout)?;
// socket.send(&packet)?;
//
// loop {
// let (fd, result) = poll.wait(None, true)?.unwrap();
// result?;
//
// match fd {
// fd if fd == socket.as_raw_fd() => {
// // TODO
// let len = socket.recv(&mut buffer)?;
// if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq)
// {
// return Ok(true);
// }
// }
// fd if fd == timer.as_raw_fd() => {
// return Ok(false);
// }
// _ => unreachable!(),
// }
// }
// }
//
// fn ping(
// address: IpAddr,
// times: usize,
// data_len: usize,
// interval: Duration,
// timeout: Duration,
// ) -> Result<PingStats, Error> {
// let routing = resolve_routing(address)?;
//
// let mut stats = PingStats {
// packets_sent: 0,
// packets_received: 0,
// };
// let mut poll = PollChannel::new()?;
// let mut timer = TimerFd::new(false, false)?;
// let mut socket = RawSocket::bind(routing.interface_id)?;
//
// poll.add(timer.as_raw_fd())?;
// poll.add(socket.as_raw_fd())?;
//
// let id = rand::random();
// for i in 0..times {
// if INTERRUPTED.load(Ordering::Acquire) {
// break;
// }
//
// let result = ping_once(
// &mut socket,
// &mut poll,
// &mut timer,
// &routing,
// timeout,
// data_len,
// id,
// i as u16,
// )?;
// stats.packets_sent += 1;
//
// if result {
// stats.packets_received += 1;
// println!("[{}/{}] {}: PONG", i + 1, times, address);
// }
//
// std::thread::sleep(interval);
// }
//
// Ok(stats)
// }
//
// static INTERRUPTED: AtomicBool = AtomicBool::new(false);
//
// fn main() -> ExitCode {
// // set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt));
//
// let args = Args::parse();
//
// let stats = match ping(
// args.address.into(),
// args.count,
// args.data_size,
// Duration::from_millis(args.inteval.into()),
// Duration::from_millis(args.timeout.into()),
// ) {
// Ok(stats) => stats,
// Err(error) => {
// eprintln!("ping: {}", error);
// return ExitCode::FAILURE;
// }
// };
//
// let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent;
// println!(
// "{} sent, {} received, {}% loss",
// stats.packets_sent, stats.packets_received, loss
// );
//
// ExitCode::SUCCESS
// }