net: move to berkeley-style sockets

This commit is contained in:
Mark Poliakov 2025-01-07 15:17:17 +02:00
parent f1256e262b
commit a4e441d236
46 changed files with 2267 additions and 1317 deletions

View File

@ -80,7 +80,10 @@ impl ArpTable {
pub fn lookup_cache(interface: u32, address: IpAddr) -> Option<MacAddress> { pub fn lookup_cache(interface: u32, address: IpAddr) -> Option<MacAddress> {
let (address, _) = match address { let (address, _) = match address {
IpAddr::V4(address) => Self::lookup_cache_v4(interface, address), IpAddr::V4(address) => Self::lookup_cache_v4(interface, address),
IpAddr::V6(_) => todo!(), IpAddr::V6(v6) => {
log::warn!("TODO: ArpTable v6 lookup: {v6}");
return None;
}
}?; }?;
Some(address) Some(address)
} }
@ -92,7 +95,10 @@ impl ArpTable {
pub fn flush_address(interface: u32, address: IpAddr) -> bool { pub fn flush_address(interface: u32, address: IpAddr) -> bool {
match address { match address {
IpAddr::V4(address) => Self::flush_address_v4(interface, address), IpAddr::V4(address) => Self::flush_address_v4(interface, address),
IpAddr::V6(_) => todo!(), IpAddr::V6(v6) => {
log::warn!("TODO: ArpTable v6 flush: {v6}");
false
}
} }
} }
@ -103,7 +109,10 @@ impl ArpTable {
pub fn insert_address(interface: u32, mac: MacAddress, address: IpAddr, owned: bool) { pub fn insert_address(interface: u32, mac: MacAddress, address: IpAddr, owned: bool) {
match address { match address {
IpAddr::V4(address) => Self::insert_address_v4(interface, mac, address, owned), IpAddr::V4(address) => Self::insert_address_v4(interface, mac, address, owned),
IpAddr::V6(_) => todo!(), IpAddr::V6(v6) => {
log::warn!("TODO: ArpTable v6 insert: {v6}");
return;
}
} }
ARP_TABLE.notify.wake_all(); ARP_TABLE.notify.wake_all();
} }
@ -203,7 +212,10 @@ fn send_request(interface: &NetworkInterface, query_address: IpAddr) -> Result<(
match query_address { match query_address {
IpAddr::V4(address) => send_request_v4(interface, address), IpAddr::V4(address) => send_request_v4(interface, address),
IpAddr::V6(_) => todo!(), IpAddr::V6(v6) => {
log::warn!("TODO: ARP IPv6 query: {v6}");
Err(Error::NotImplemented)
}
} }
} }

View File

@ -24,7 +24,7 @@ async fn send_v4_reply(
}; };
if icmp_data.len() % 2 != 0 { if icmp_data.len() % 2 != 0 {
todo!(); return Err(Error::InvalidArgument);
} }
let l4_bytes = bytemuck::bytes_of(&reply_frame); let l4_bytes = bytemuck::bytes_of(&reply_frame);
@ -77,6 +77,9 @@ async fn handle_v4(source_address: Ipv4Addr, l3_packet: L3Packet) -> Result<(),
pub async fn handle(l3_packet: L3Packet) -> Result<(), Error> { pub async fn handle(l3_packet: L3Packet) -> Result<(), Error> {
match l3_packet.source_address { match l3_packet.source_address {
IpAddr::V4(v4) => handle_v4(v4, l3_packet).await, IpAddr::V4(v4) => handle_v4(v4, l3_packet).await,
IpAddr::V6(_) => todo!(), IpAddr::V6(v6) => {
log::warn!("TODO: ICMPv6 from {v6}");
Err(Error::NotImplemented)
}
} }
} }

View File

@ -17,7 +17,7 @@ use yggdrasil_abi::{
use crate::{ use crate::{
l3::{self, L3Packet}, l3::{self, L3Packet},
socket::{TcpListener, TcpSocket}, socket::tcp::{TcpListener, TcpStream},
util::Assembler, util::Assembler,
}; };
@ -652,12 +652,15 @@ pub async fn handle(packet: L3Packet) -> Result<(), Error> {
match tcp_frame.flags { match tcp_frame.flags {
TcpFlags::SYN => { TcpFlags::SYN => {
if let Some(listener) = TcpListener::get(local) { if let Some(listener) = TcpListener::get(&local)
&& listener.is_listening()
{
log::debug!("tcp: create remote stream {remote} -> {local}");
let window_size = u16::from_network_order(tcp_frame.window_size); let window_size = u16::from_network_order(tcp_frame.window_size);
let tx_seq = 12345; let tx_seq = 12345;
// Create a socket and insert it into the table // Create a socket and insert it into the table
TcpSocket::accept_remote( TcpStream::accept_remote(
listener.clone(), listener.clone(),
local, local,
remote, remote,
@ -703,16 +706,17 @@ pub async fn handle(packet: L3Packet) -> Result<(), Error> {
seq, seq,
}; };
let socket = TcpSocket::get(local, remote).ok_or(Error::DoesNotExist)?; let stream = TcpStream::get(local, remote).ok_or(Error::DoesNotExist)?;
let mut connection = socket.connection().write(); let mut connection = stream.connection.write();
match connection.handle_packet(packet, tcp_data).await? { match connection.handle_packet(packet, tcp_data).await? {
TcpSocketBehavior::None => (), TcpSocketBehavior::None => (),
TcpSocketBehavior::Accept => { TcpSocketBehavior::Accept => {
socket.accept(); stream.accept();
} }
TcpSocketBehavior::Remove => { TcpSocketBehavior::Remove => {
drop(connection); drop(connection);
socket.remove_socket()?; stream.remove_stream()?;
} }
} }
Ok(()) Ok(())

View File

@ -1,4 +1,4 @@
#![feature(map_try_insert)] #![feature(map_try_insert, let_chains, result_flattening)]
#![allow(clippy::type_complexity, clippy::new_without_default)] #![allow(clippy::type_complexity, clippy::new_without_default)]
#![no_std] #![no_std]

View File

@ -1,18 +1,20 @@
use alloc::{collections::BTreeMap, sync::Arc}; use alloc::{
use libk::vfs::Socket; collections::{btree_map::Entry, BTreeMap},
sync::Arc,
};
use yggdrasil_abi::{ use yggdrasil_abi::{
error::Error, error::Error,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
}; };
pub mod udp; pub mod udp;
pub use udp::UdpSocket; pub use udp::UdpSocket;
pub mod tcp; pub mod tcp;
pub use tcp::{TcpListener, TcpSocket}; pub use tcp::TcpSocket;
pub mod raw; pub mod raw;
pub use raw::RawSocket; pub use raw::RawSocket;
pub struct SocketTable<T: Socket> { pub struct SocketTable<T> {
inner: BTreeMap<SocketAddr, Arc<T>>, inner: BTreeMap<SocketAddr, Arc<T>>,
} }
@ -71,7 +73,7 @@ impl<T> TwoWaySocketTable<T> {
} }
} }
impl<T: Socket> SocketTable<T> { impl<T> SocketTable<T> {
pub const fn new() -> Self { pub const fn new() -> Self {
Self { Self {
inner: BTreeMap::new(), inner: BTreeMap::new(),
@ -95,6 +97,29 @@ impl<T: Socket> SocketTable<T> {
Err(Error::AddrInUse) Err(Error::AddrInUse)
} }
pub fn bind_to_ephemeral_port(&mut self, ip: IpAddr, socket: Arc<T>) -> Result<u16, Error> {
for port in 32768..u16::MAX - 1 {
let local = SocketAddr::new(ip, port);
match self.try_insert(local, socket.clone()) {
Ok(()) => return Ok(port),
Err(Error::AddrInUse) => continue,
Err(error) => return Err(error),
}
}
Err(Error::AddrInUse)
}
pub fn try_insert(&mut self, address: SocketAddr, socket: Arc<T>) -> Result<(), Error> {
match self.inner.entry(address) {
Entry::Vacant(entry) => {
entry.insert(socket);
Ok(())
}
Entry::Occupied(_) => Err(Error::AddrInUse),
}
}
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>( pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
&mut self, &mut self,
address: SocketAddr, address: SocketAddr,
@ -124,12 +149,11 @@ impl<T: Socket> SocketTable<T> {
return Some(socket.clone()); return Some(socket.clone());
} }
match local { let unspec = match local {
SocketAddr::V4(_v4) => { SocketAddr::V4(_v4) => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port()).into(),
let unspec_v4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port()); SocketAddr::V6(_v6) => SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, local.port()).into(),
self.inner.get(&SocketAddr::V4(unspec_v4)).cloned() };
}
SocketAddr::V6(_) => todo!(), self.inner.get(&unspec).cloned()
}
} }
} }

View File

@ -2,26 +2,25 @@ use core::{
fmt, fmt,
sync::atomic::{AtomicU32, Ordering}, sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll}, task::{Context, Poll},
time::Duration,
}; };
use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec}; use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use async_trait::async_trait; use async_trait::async_trait;
use libk::{ use libk::{
error::Error, error::Error,
task::runtime::maybe_timeout,
vfs::{FileReadiness, PacketSocket, Socket}, vfs::{FileReadiness, PacketSocket, Socket},
}; };
use libk_mm::PageBox; use libk_mm::PageBox;
use libk_util::{ use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock};
queue::BoundedMpmcQueue, use yggdrasil_abi::net::{SocketAddr, SocketInterfaceQuery, SocketOption};
sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock},
};
use yggdrasil_abi::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketInterfaceQuery, SocketOption};
use crate::{ethernet::L2Packet, interface::NetworkInterface}; use crate::{ethernet::L2Packet, interface::NetworkInterface};
pub struct RawSocket { pub struct RawSocket {
id: u32, id: u32,
bound: IrqSafeSpinlock<Option<u32>>, bound: IrqSafeRwLock<Option<u32>>,
receive_queue: BoundedMpmcQueue<L2Packet>, receive_queue: BoundedMpmcQueue<L2Packet>,
} }
@ -32,17 +31,15 @@ static BOUND_RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Vec<u32>>> =
IrqSafeRwLock::new(BTreeMap::new()); IrqSafeRwLock::new(BTreeMap::new());
impl RawSocket { impl RawSocket {
pub fn bind() -> Result<Arc<Self>, Error> { pub fn new() -> Arc<Self> {
let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst); let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst);
let socket = Self { let socket = Arc::new(Self {
id, id,
bound: IrqSafeSpinlock::new(None), bound: IrqSafeRwLock::new(None),
receive_queue: BoundedMpmcQueue::new(256), receive_queue: BoundedMpmcQueue::new(256),
}; });
let socket = Arc::new(socket);
RAW_SOCKETS.write().insert(id, socket.clone()); RAW_SOCKETS.write().insert(id, socket.clone());
socket
Ok(socket)
} }
fn bound_packet_received(&self, packet: L2Packet) { fn bound_packet_received(&self, packet: L2Packet) {
@ -57,6 +54,7 @@ impl RawSocket {
if let Some(ids) = bound_sockets.get(&packet.interface_id) { if let Some(ids) = bound_sockets.get(&packet.interface_id) {
for id in ids { for id in ids {
let socket = raw_sockets.get(id).unwrap(); let socket = raw_sockets.get(id).unwrap();
log::info!("Packet -> {id}");
socket.bound_packet_received(packet.clone()); socket.bound_packet_received(packet.clone());
} }
} }
@ -80,10 +78,14 @@ impl FileReadiness for RawSocket {
} }
impl Socket for RawSocket { impl Socket for RawSocket {
fn bind(self: Arc<Self>, _local: SocketAddr) -> Result<(), Error> {
Err(Error::InvalidOperation)
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option { match option {
SocketOption::BoundHardwareAddress(mac) => { SocketOption::BoundHardwareAddress(mac) => {
let bound = self.bound.lock().ok_or(Error::DoesNotExist)?; let bound = self.bound.read().ok_or(Error::NotConnected)?;
let interface = NetworkInterface::get(bound).unwrap(); let interface = NetworkInterface::get(bound).unwrap();
*mac = interface.mac; *mac = interface.mac;
Ok(()) Ok(())
@ -95,18 +97,16 @@ impl Socket for RawSocket {
fn set_option(&self, option: &SocketOption) -> Result<(), Error> { fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option { match option {
SocketOption::BindInterface(query) => { SocketOption::BindInterface(query) => {
let mut bound = self.bound.lock(); let mut bound = self.bound.write();
if bound.is_some() { if bound.is_some() {
return Err(Error::AlreadyExists); return Err(Error::AlreadyExists);
} }
let mut bound_sockets = BOUND_RAW_SOCKETS.write(); let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let interface = match *query { let interface = match *query {
SocketInterfaceQuery::ById(id) => NetworkInterface::get(id), SocketInterfaceQuery::ById(id) => NetworkInterface::get(id),
SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name), SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name),
}?; }?;
let list = bound_sockets.entry(interface.id).or_default(); let list = bound_sockets.entry(interface.id).or_default();
bound.replace(interface.id); bound.replace(interface.id);
list.push(self.id); list.push(self.id);
@ -118,8 +118,8 @@ impl Socket for RawSocket {
} }
} }
fn close(&self) -> Result<(), Error> { fn close(self: Arc<Self>) -> Result<(), Error> {
let bound = self.bound.lock().take(); let bound = self.bound.write().take();
if let Some(bound) = bound { if let Some(bound) = bound {
let mut bound_sockets = BOUND_RAW_SOCKETS.write(); let mut bound_sockets = BOUND_RAW_SOCKETS.write();
@ -140,8 +140,8 @@ impl Socket for RawSocket {
Ok(()) Ok(())
} }
fn local_address(&self) -> SocketAddr { fn local_address(&self) -> Option<SocketAddr> {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)) None
} }
fn remote_address(&self) -> Option<SocketAddr> { fn remote_address(&self) -> Option<SocketAddr> {
@ -151,18 +151,27 @@ impl Socket for RawSocket {
#[async_trait] #[async_trait]
impl PacketSocket for RawSocket { impl PacketSocket for RawSocket {
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> { fn connect(self: Arc<Self>, _remote: SocketAddr) -> Result<(), Error> {
Err(Error::InvalidOperation)
}
async fn send_to(
self: Arc<Self>,
destination: Option<SocketAddr>,
data: &[u8],
_timeout: Option<Duration>,
) -> Result<usize, Error> {
self.send_nonblocking(destination, data) self.send_nonblocking(destination, data)
} }
// TODO currently this is still blocking by NIC send code // TODO currently this is still blocking by NIC send code
fn send_nonblocking( fn send_nonblocking(
&self, self: Arc<Self>,
_destination: Option<SocketAddr>, _destination: Option<SocketAddr>,
buffer: &[u8], buffer: &[u8],
) -> Result<usize, Error> { ) -> Result<usize, Error> {
// TODO cap by MTU? // TODO cap by MTU?
let bound = self.bound.lock().ok_or(Error::InvalidOperation)?; let bound = self.bound.read().ok_or(Error::InvalidOperation)?;
let interface = NetworkInterface::get(bound)?; let interface = NetworkInterface::get(bound)?;
let l2_offset = interface.device.packet_prefix_size(); let l2_offset = interface.device.packet_prefix_size();
if buffer.len() > 4096 - l2_offset { if buffer.len() > 4096 - l2_offset {
@ -174,12 +183,20 @@ impl PacketSocket for RawSocket {
Ok(buffer.len()) Ok(buffer.len())
} }
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { async fn receive_from(
let packet = self.receive_queue.pop_front().await; self: Arc<Self>,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
let future = self.receive_queue.pop_front();
let packet = maybe_timeout(future, timeout).await?;
Self::packet_to_user(packet, buffer) Self::packet_to_user(packet, buffer)
} }
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { fn receive_nonblocking(
self: Arc<Self>,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr), Error> {
let packet = self let packet = self
.receive_queue .receive_queue
.try_pop_front() .try_pop_front()
@ -190,7 +207,7 @@ impl PacketSocket for RawSocket {
impl fmt::Debug for RawSocket { impl fmt::Debug for RawSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let bound = *self.bound.lock(); let bound = *self.bound.read();
f.debug_struct("RawSocket") f.debug_struct("RawSocket")
.field("interface", &bound) .field("interface", &bound)
.finish_non_exhaustive() .finish_non_exhaustive()

View File

@ -1,495 +0,0 @@
use core::{
fmt,
future::{poll_fn, Future},
sync::atomic::{AtomicU8, Ordering},
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use async_trait::async_trait;
use libk::{
block,
error::Error,
task::runtime::with_timeout,
time::monotonic_time,
vfs::{ConnectionSocket, FileReadiness, ListenerSocket, Socket},
};
use libk_util::{
sync::{
spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard},
IrqSafeSpinlock, IrqSafeSpinlockGuard,
},
waker::QueueWaker,
};
use yggdrasil_abi::net::{SocketAddr, SocketOption};
use crate::{
interface::NetworkInterface,
l3::Route,
l4::tcp::{TcpConnection, TcpConnectionState},
};
use super::{SocketTable, TwoWaySocketTable};
pub struct TcpSocket {
pub(crate) local: SocketAddr,
pub(crate) remote: SocketAddr,
ttl: AtomicU8,
// Listener which accepted the socket
listener: Option<Arc<TcpListener>>,
connection: IrqSafeRwLock<TcpConnection>,
}
pub struct TcpListener {
accept: SocketAddr,
// Currently active sockets
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>,
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>,
accept_notify: QueueWaker,
}
static TCP_SOCKETS: IrqSafeRwLock<TwoWaySocketTable<TcpSocket>> =
IrqSafeRwLock::new(TwoWaySocketTable::new());
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
IrqSafeRwLock::new(SocketTable::new());
impl TcpSocket {
pub async fn connect(
remote: SocketAddr,
timeout: Option<Duration>,
) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
let future = Self::connect_async(remote);
match timeout {
Some(timeout) => with_timeout(future, timeout).await?,
None => future.await,
}
}
pub fn accept_remote(
listener: Arc<TcpListener>,
local: SocketAddr,
remote: SocketAddr,
remote_window_size: usize,
tx_seq: u32,
rx_seq: u32,
) -> Result<Arc<TcpSocket>, Error> {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with(local, remote, move || {
let connection = TcpConnection::new(
local,
remote,
remote_window_size,
tx_seq,
rx_seq,
TcpConnectionState::SynReceived,
);
log::debug!("Accepted TCP socket {} -> {}", local, remote);
let socket = Self {
local,
remote,
ttl: AtomicU8::new(64),
listener: Some(listener),
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})
}
pub fn connection(&self) -> &IrqSafeRwLock<TcpConnection> {
&self.connection
}
pub(crate) fn accept(self: &Arc<Self>) {
if let Some(listener) = self.listener.as_ref() {
listener.accept_socket(self.clone());
}
}
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<Self>> {
TCP_SOCKETS.read().get(local, remote)
}
pub fn receive_async<'a>(
&'a self,
buffer: &'a mut [u8],
) -> impl Future<Output = Result<usize, Error>> + 'a {
poll_fn(|cx| match self.poll_receive(cx) {
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(buffer)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
})
}
pub async fn send_async(&self, data: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
let mut rem = data.len();
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
self.send_segment_async(&data[pos..pos + amount]).await?;
pos += amount;
rem -= amount;
}
Ok(pos)
}
pub async fn close_async(&self, remove_from_listener: bool) -> Result<(), Error> {
// TODO timeout here
// Already closing
if self.connection.read().is_closing() {
return Ok(());
}
// Wait for all sent data to be acknowledged
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.finish().await?;
}
log::debug!(
"TCP socket closed (FinWait2/Closed): {} <-> {}",
self.local,
self.remote
);
// Wait for connection to get closed
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_finish(cx)
})
.await;
if remove_from_listener {
if let Some(listener) = self.listener.as_ref() {
listener.remove_socket(self.remote);
};
}
Ok(())
}
pub(crate) fn remove_socket(&self) -> Result<(), Error> {
log::debug!(
"TCP socket closed and removed: {} <-> {}",
self.local,
self.remote
);
let connection = self.connection.read();
debug_assert!(connection.is_closed());
TCP_SOCKETS.write().remove(self.local, self.remote)?;
connection.notify_all();
Ok(())
}
fn poll_receive(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
let lock = self.connection.write();
match lock.poll_receive(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
async fn send_segment_async(&self, data: &[u8]) -> Result<(), Error> {
// TODO timeout here
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.transmit(data).await?;
}
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_acknowledge(cx)
})
.await;
Ok(())
}
async fn connect_async(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
// Lookup route to remote
let (interface_id, _, remote_ip) =
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
let remote = SocketAddr::new(remote_ip, remote.port());
let interface = NetworkInterface::get(interface_id)?;
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
let socket = {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with_ephemeral_port(local_ip, remote, |port| {
let t = monotonic_time();
let tx_seq = t.as_millis() as u32;
let local = SocketAddr::new(local_ip, port);
let connection =
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
let socket = Self {
local,
remote,
ttl: AtomicU8::new(64),
listener: None,
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})?
};
let mut t = 200;
for _ in 0..5 {
let timeout = Duration::from_millis(t);
log::debug!("Try SYN with timeout={:?}", timeout);
match socket.try_connect(timeout).await {
Ok(()) => return Ok((socket.local, socket)),
Err(Error::TimedOut) => (),
Err(error) => return Err(error),
}
t *= 2;
}
// Couldn't establish
Err(Error::TimedOut)
}
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
{
let mut connection = self.connection.write();
connection.send_syn().await?;
}
let fut = poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_established(cx)
});
with_timeout(fut, timeout).await?
}
}
impl Socket for TcpSocket {
fn local_address(&self) -> SocketAddr {
self.local
}
fn remote_address(&self) -> Option<SocketAddr> {
Some(self.remote)
}
fn close(&self) -> Result<(), Error> {
block!(self.close_async(true).await)?
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Ttl(ttl) => {
if ttl == 0 || ttl > 255 {
return Err(Error::InvalidArgument);
}
self.ttl.store(ttl as _, Ordering::Release);
Ok(())
}
SocketOption::NoDelay(_) => {
log::warn!("TODO: TCP nodelay");
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ttl(ttl) => {
*ttl = self.ttl.load(Ordering::Acquire) as _;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
*nodelay = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl FileReadiness for TcpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_receive(cx).map_ok(|_| ())
}
}
#[async_trait]
impl ConnectionSocket for TcpSocket {
async fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> {
self.receive_async(buffer).await
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error> {
match self.connection.write().read_nonblocking(buffer) {
// TODO check if this really means "no data at the moment"
Ok(0) => Err(Error::WouldBlock),
Ok(len) => Ok(len),
Err(error) => Err(error),
}
}
async fn send(&self, data: &[u8]) -> Result<usize, Error> {
self.send_async(data).await
}
fn send_nonblocking(&self, data: &[u8]) -> Result<usize, Error> {
log::warn!("TODO: TCP::send_nonblocking");
block!(self.send_async(data).await)?
}
}
impl fmt::Debug for TcpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpSocket")
.field("local", &self.local)
.field("remote", &self.remote)
.finish_non_exhaustive()
}
}
impl TcpListener {
pub fn bind(accept: SocketAddr) -> Result<Arc<Self>, Error> {
TCP_LISTENERS.write().try_insert_with(accept, || {
let listener = TcpListener {
accept,
sockets: IrqSafeRwLock::new(BTreeMap::new()),
pending_accept: IrqSafeSpinlock::new(Vec::new()),
accept_notify: QueueWaker::new(),
};
log::debug!("TCP Listener opened: {}", accept);
Ok(Arc::new(listener))
})
}
pub fn get(local: SocketAddr) -> Option<Arc<Self>> {
TCP_LISTENERS.read().get(&local)
}
pub fn accept_async(&self) -> impl Future<Output = Result<Arc<TcpSocket>, Error>> + '_ {
poll_fn(|cx| match self.poll_accept(cx) {
Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())),
Poll::Pending => Poll::Pending,
})
}
fn accept_socket(&self, socket: Arc<TcpSocket>) {
log::debug!("{}: accept {}", self.accept, socket.remote);
self.sockets.write().insert(socket.remote, socket.clone());
self.pending_accept.lock().push(socket);
self.accept_notify.wake_one();
}
fn remove_socket(&self, remote: SocketAddr) {
log::debug!("Remove client {}", remote);
self.sockets.write().remove(&remote);
}
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpSocket>>>> {
let lock = self.pending_accept.lock();
self.accept_notify.register(cx.waker());
if !lock.is_empty() {
self.accept_notify.remove(cx.waker());
Poll::Ready(lock)
} else {
Poll::Pending
}
}
}
impl Socket for TcpListener {
fn local_address(&self) -> SocketAddr {
self.accept
}
fn remote_address(&self) -> Option<SocketAddr> {
None
}
fn close(&self) -> Result<(), Error> {
// TODO if clients not closed already, send RST?
TCP_LISTENERS.write().remove(self.accept)
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ipv6Only(_v6only) => {
log::warn!("TODO: TCP listener IPv6-only");
Err(Error::InvalidOperation)
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ipv6Only(v6only) => {
*v6only = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl FileReadiness for TcpListener {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_accept(cx).map(|_| Ok(()))
}
}
#[async_trait]
impl ListenerSocket for TcpListener {
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = self.accept_async().await?;
let remote = socket.remote;
Ok((remote, socket))
}
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = self.pending_accept.lock().pop().ok_or(Error::WouldBlock)?;
let remote = socket.remote;
Ok((remote, socket))
}
}
impl fmt::Debug for TcpListener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpListener")
.field("local", &self.accept)
.finish_non_exhaustive()
}
}

View File

@ -0,0 +1,112 @@
use core::{
fmt,
future::poll_fn,
sync::atomic::{AtomicBool, Ordering},
task::{Context, Poll},
};
use alloc::{collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use libk::error::Error;
use libk_util::{
sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock, IrqSafeSpinlockGuard},
waker::QueueWaker,
};
use yggdrasil_abi::net::SocketAddr;
use crate::socket::SocketTable;
use super::TcpStream;
pub struct TcpListener {
pub(super) local: SocketAddr,
pub(super) listening: AtomicBool,
// Currently active sockets
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpStream>>>,
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpStream>>>,
accept_notify: QueueWaker,
}
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
IrqSafeRwLock::new(SocketTable::new());
impl TcpListener {
pub fn bind(local: SocketAddr) -> Result<Arc<Self>, Error> {
if local.port() == 0 {
// TODO: ephemeral binding
todo!();
}
let listener = Arc::new(TcpListener {
local,
listening: AtomicBool::new(false),
sockets: IrqSafeRwLock::new(BTreeMap::new()),
pending_accept: IrqSafeSpinlock::new(Vec::new()),
accept_notify: QueueWaker::new(),
});
TCP_LISTENERS.write().try_insert(local, listener.clone())?;
Ok(listener)
}
pub fn close(&self) -> Result<(), Error> {
// TODO if clients not closed already, send RST?
TCP_LISTENERS.write().remove(self.local)
}
pub(super) async fn accept(&self) -> Result<Arc<TcpStream>, Error> {
poll_fn(|cx| match self.poll_accept(cx) {
Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())),
Poll::Pending => Poll::Pending,
})
.await
}
pub(super) fn accept_nonblocking(&self) -> Result<Arc<TcpStream>, Error> {
self.pending_accept.lock().pop().ok_or(Error::WouldBlock)
}
pub(super) fn poll_accept(
&self,
cx: &mut Context<'_>,
) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpStream>>>> {
let lock = self.pending_accept.lock();
self.accept_notify.register(cx.waker());
if !lock.is_empty() {
self.accept_notify.remove(cx.waker());
Poll::Ready(lock)
} else {
Poll::Pending
}
}
pub(super) fn accept_stream(&self, stream: Arc<TcpStream>) {
log::debug!("{}: accept {}", self.local, stream.remote);
self.sockets.write().insert(stream.remote, stream.clone());
self.pending_accept.lock().push(stream);
self.accept_notify.wake_one();
}
pub(super) fn remove_stream(&self, remote: SocketAddr) {
log::debug!("Remove remote stream {}", remote);
self.sockets.write().remove(&remote);
}
pub fn is_listening(&self) -> bool {
self.listening.load(Ordering::Acquire)
}
pub fn get(local: &SocketAddr) -> Option<Arc<TcpListener>> {
TCP_LISTENERS.read().get(local)
}
}
impl fmt::Debug for TcpListener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpListener")
.field("listening", &self.is_listening())
.field("local", &self.local)
.finish_non_exhaustive()
}
}

View File

@ -0,0 +1,270 @@
use core::{
fmt, mem,
sync::atomic::Ordering,
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, sync::Arc};
use async_trait::async_trait;
use libk::{
block,
error::Error,
task::runtime::maybe_timeout,
vfs::{ConnectionSocket, FileReadiness, Socket},
};
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::net::{SocketAddr, SocketOption};
mod listener;
mod stream;
pub use listener::TcpListener;
pub use stream::TcpStream;
pub struct TcpSocket {
options: IrqSafeRwLock<TcpSocketOptions>,
inner: IrqSafeRwLock<TcpSocketInner>,
}
pub enum TcpSocketInner {
Empty,
Listener(Arc<TcpListener>),
Stream(Arc<TcpStream>),
}
struct TcpSocketOptions {
ttl: u8,
nodelay: bool,
v6_only: bool,
}
impl Default for TcpSocketOptions {
fn default() -> Self {
Self {
ttl: 64,
nodelay: false,
v6_only: false,
}
}
}
impl TcpSocket {
pub fn new() -> Arc<Self> {
Arc::new(Self {
options: IrqSafeRwLock::new(TcpSocketOptions::default()),
inner: IrqSafeRwLock::new(TcpSocketInner::Empty),
})
}
fn as_stream(&self) -> Option<Arc<TcpStream>> {
if let TcpSocketInner::Stream(stream) = &*self.inner.read() {
Some(stream.clone())
} else {
None
}
}
fn as_listener(&self) -> Option<Arc<TcpListener>> {
if let TcpSocketInner::Listener(listener) = &*self.inner.read() {
Some(listener.clone())
} else {
None
}
}
}
impl Socket for TcpSocket {
fn bind(self: Arc<Self>, local: SocketAddr) -> Result<(), Error> {
let mut inner = self.inner.write();
// Already connected or bound
if !matches!(&*inner, TcpSocketInner::Empty) {
return Err(Error::InvalidOperation);
}
// Bind a listener socket
let listener = TcpListener::bind(local)?;
*inner = TcpSocketInner::Listener(listener);
Ok(())
}
fn close(self: Arc<Self>) -> Result<(), Error> {
let inner = mem::replace(&mut *self.inner.write(), TcpSocketInner::Empty);
match inner {
TcpSocketInner::Empty => Ok(()),
TcpSocketInner::Stream(socket) => block!(socket.close(true).await)?,
TcpSocketInner::Listener(socket) => socket.close(),
}
}
fn local_address(&self) -> Option<SocketAddr> {
match &*self.inner.read() {
TcpSocketInner::Empty => None,
TcpSocketInner::Stream(socket) => Some(socket.local),
TcpSocketInner::Listener(socket) => Some(socket.local),
}
}
fn remote_address(&self) -> Option<SocketAddr> {
match &*self.inner.read() {
TcpSocketInner::Empty | TcpSocketInner::Listener(_) => None,
TcpSocketInner::Stream(socket) => Some(socket.remote),
}
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match *option {
SocketOption::Ttl(ttl) => {
if !(1..256).contains(&ttl) {
return Err(Error::InvalidArgument);
}
let ttl = ttl as u8;
self.options.write().ttl = ttl;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
self.options.write().nodelay = nodelay;
Ok(())
}
SocketOption::Ipv6Only(v6_only) => {
self.options.write().v6_only = v6_only;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ttl(ttl) => {
*ttl = self.options.read().ttl as u32;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
*nodelay = self.options.read().nodelay;
Ok(())
}
SocketOption::Ipv6Only(v6_only) => {
*v6_only = self.options.read().v6_only;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
#[async_trait]
impl ConnectionSocket for TcpSocket {
async fn connect(
self: Arc<Self>,
remote: SocketAddr,
_timeout: Option<Duration>,
) -> Result<(), Error> {
let mut inner = self.inner.write();
// Already connected or bound
if !matches!(&*inner, TcpSocketInner::Empty) {
return Err(Error::InvalidOperation);
}
let stream = TcpStream::connect(remote).await?;
*inner = TcpSocketInner::Stream(stream);
Ok(())
}
async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let len = stream.receive(buffer, timeout).await?;
Ok((len, stream.remote))
}
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let len = stream.receive_nonblocking(buffer)?;
Ok((len, stream.remote))
}
async fn send(&self, data: &[u8], timeout: Option<Duration>) -> Result<usize, Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let future = stream.send(data);
maybe_timeout(future, timeout).await?
}
fn send_nonblocking(&self, buffer: &[u8]) -> Result<usize, Error> {
let stream = self.as_stream().ok_or(Error::NotConnected)?;
let len = stream.send_nonblocking(buffer)?;
Ok(len)
}
fn listen(self: Arc<Self>) -> Result<(), Error> {
let listener = self.as_listener().ok_or(Error::InvalidOperation)?;
match listener
.listening
.compare_exchange(false, true, Ordering::Release, Ordering::Relaxed)
{
Ok(_) => Ok(()),
Err(_) => Err(Error::InvalidOperation),
}
}
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let listener = self.as_listener().ok_or(Error::InvalidOperation)?;
let stream = listener.accept().await?;
let remote = stream.remote;
let socket = Arc::new(TcpSocket {
options: IrqSafeRwLock::new(TcpSocketOptions::default()),
inner: IrqSafeRwLock::new(TcpSocketInner::Stream(stream)),
});
Ok((remote, socket))
}
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let listener = self.as_listener().ok_or(Error::InvalidOperation)?;
let stream = listener.accept_nonblocking()?;
let remote = stream.remote;
let socket = Arc::new(TcpSocket {
options: IrqSafeRwLock::new(TcpSocketOptions::default()),
inner: IrqSafeRwLock::new(TcpSocketInner::Stream(stream)),
});
Ok((remote, socket))
}
async fn shutdown(&self, read: bool, write: bool) -> Result<(), Error> {
log::warn!("TODO: shutdown(read = {read}, write = {write})");
Ok(())
}
}
impl FileReadiness for TcpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match &*self.inner.read() {
TcpSocketInner::Empty => Poll::Ready(Ok(())),
TcpSocketInner::Stream(socket) => socket.poll_receive(cx).map_ok(|_| ()),
TcpSocketInner::Listener(socket) => socket.poll_accept(cx).map(|_| Ok(())),
}
}
}
impl fmt::Debug for TcpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &*self.inner.read() {
TcpSocketInner::Empty => f.debug_struct("TcpSocket").finish_non_exhaustive(),
TcpSocketInner::Stream(socket) => f
.debug_struct("TcpSocket")
.field("stream", socket)
.finish_non_exhaustive(),
TcpSocketInner::Listener(socket) => f
.debug_struct("TcpSocket")
.field("listener", socket)
.finish_non_exhaustive(),
}
}
}

View File

@ -0,0 +1,309 @@
use core::{
fmt,
future::poll_fn,
task::{Context, Poll},
time::Duration,
};
use alloc::sync::Arc;
use libk::{
error::Error,
task::runtime::{maybe_timeout, with_timeout},
time::monotonic_time,
};
use libk_util::sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard};
use yggdrasil_abi::net::SocketAddr;
use crate::{
interface::NetworkInterface,
l3::Route,
l4::tcp::{TcpConnection, TcpConnectionState},
socket::TwoWaySocketTable,
};
use super::TcpListener;
pub struct TcpStream {
pub(super) local: SocketAddr,
pub(super) remote: SocketAddr,
// Listener which accepted the socket
listener: Option<Arc<TcpListener>>,
pub(crate) connection: IrqSafeRwLock<TcpConnection>,
}
static TCP_STREAMS: IrqSafeRwLock<TwoWaySocketTable<TcpStream>> =
IrqSafeRwLock::new(TwoWaySocketTable::new());
impl TcpStream {
pub async fn connect(remote: SocketAddr) -> Result<Arc<Self>, Error> {
// Lookup route to remote
let (interface_id, _, remote_ip) =
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
let remote = SocketAddr::new(remote_ip, remote.port());
let interface = NetworkInterface::get(interface_id)?;
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
// Create a new TcpStream with an ephemeral port
let stream = {
let mut streams = TCP_STREAMS.write();
streams.try_insert_with_ephemeral_port(local_ip, remote, |port| {
let t = monotonic_time();
let tx_seq = t.as_millis() as u32;
let local = SocketAddr::new(local_ip, port);
let connection =
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
Ok(Arc::new(Self {
local,
remote,
listener: None,
connection: IrqSafeRwLock::new(connection),
}))
})?
};
let mut t = 200;
for _ in 0..5 {
let timeout = Duration::from_millis(t);
log::debug!("Try SYN with timeout={:?}", timeout);
match stream.try_connect(timeout).await {
Ok(()) => return Ok(stream),
Err(Error::TimedOut) => (),
Err(error) => return Err(error),
}
t *= 2;
}
// Couldn't establish
stream.close(false).await.ok();
Err(Error::TimedOut)
}
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
{
let mut connection = self.connection.write();
connection.send_syn().await?;
}
let fut = poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_established(cx)
});
with_timeout(fut, timeout).await?
}
pub fn accept(self: &Arc<Self>) {
if let Some(listener) = self.listener.as_ref() {
listener.accept_stream(self.clone());
}
}
pub fn accept_remote(
listener: Arc<TcpListener>,
local: SocketAddr,
remote: SocketAddr,
remote_window_size: usize,
tx_seq: u32,
rx_seq: u32,
) -> Result<Arc<TcpStream>, Error> {
let mut streams = TCP_STREAMS.write();
streams.try_insert_with(local, remote, move || {
let connection = TcpConnection::new(
local,
remote,
remote_window_size,
tx_seq,
rx_seq,
TcpConnectionState::SynReceived,
);
log::debug!("Accepted TCP socket {} -> {}", local, remote);
let socket = Self {
local,
remote,
listener: Some(listener),
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})
}
pub async fn close(&self, remove_from_listener: bool) -> Result<(), Error> {
// TODO timeout here
// Already closing
if self.connection.read().is_closing() {
return Ok(());
}
// Wait for all sent data to be acknowledged
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.finish().await?;
}
log::debug!(
"TCP socket closed (FinWait2/Closed): {} <-> {}",
self.local,
self.remote
);
// Wait for connection to get closed
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_finish(cx)
})
.await;
if remove_from_listener {
if let Some(listener) = self.listener.as_ref() {
listener.remove_stream(self.remote);
};
}
Ok(())
}
pub(crate) fn remove_stream(&self) -> Result<(), Error> {
log::debug!(
"TCP socket closed and removed: {} <-> {}",
self.local,
self.remote
);
let connection = self.connection.read();
debug_assert!(connection.is_closed());
TCP_STREAMS.write().remove(self.local, self.remote)?;
connection.notify_all();
Ok(())
}
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<TcpStream>> {
TCP_STREAMS.read().get(local, remote)
}
pub(super) fn poll_receive(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
let lock = self.connection.write();
match lock.poll_receive(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
pub async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<usize, Error> {
let future = poll_fn(|cx| match self.poll_receive(cx) {
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(buffer)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
});
maybe_timeout(future, timeout).await.flatten()
}
pub async fn send(&self, data: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
let mut rem = data.len();
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
self.send_segment(&data[pos..pos + amount]).await?;
pos += amount;
rem -= amount;
}
Ok(pos)
}
pub fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error> {
if buffer.is_empty() {
return Ok(0);
}
let mut lock = self.connection.write();
match lock.read_nonblocking(buffer) {
Ok(0) => Err(Error::WouldBlock),
res => res,
}
}
pub fn send_nonblocking(&self, data: &[u8]) -> Result<usize, Error> {
if data.is_empty() {
return Ok(0);
}
let mut pos = 0;
let mut rem = data.len();
let mut sent = false;
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
match self.send_segment_nonblocking(&data[pos..pos + amount]) {
Ok(()) => sent = true,
Err(Error::WouldBlock) => break,
Err(error) => return Err(error),
}
pos += amount;
rem -= amount;
}
// No data sent and WouldBlock returned
if !sent {
return Err(Error::WouldBlock);
}
Ok(pos)
}
async fn send_segment(&self, data: &[u8]) -> Result<(), Error> {
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.transmit(data).await?;
}
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_acknowledge(cx)
})
.await;
Ok(())
}
fn send_segment_nonblocking(&self, _data: &[u8]) -> Result<(), Error> {
todo!()
}
}
impl fmt::Debug for TcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpStream")
.field("local", &self.local)
.field("remote", &self.remote)
.finish_non_exhaustive()
}
}

View File

@ -2,6 +2,7 @@ use core::{
fmt, fmt,
sync::atomic::{AtomicBool, AtomicU8, Ordering}, sync::atomic::{AtomicBool, AtomicU8, Ordering},
task::{Context, Poll}, task::{Context, Poll},
time::Duration,
}; };
use alloc::{boxed::Box, sync::Arc, vec::Vec}; use alloc::{boxed::Box, sync::Arc, vec::Vec};
@ -9,18 +10,22 @@ use async_trait::async_trait;
use libk::{ use libk::{
block, block,
error::Error, error::Error,
task::runtime::maybe_timeout,
vfs::{FileReadiness, PacketSocket, Socket}, vfs::{FileReadiness, PacketSocket, Socket},
}; };
use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock}; use libk_util::{
use yggdrasil_abi::net::{SocketAddr, SocketOption}; queue::BoundedMpmcQueue,
sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard},
};
use yggdrasil_abi::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketOption};
use crate::l4; use crate::l4;
use super::SocketTable; use super::SocketTable;
pub struct UdpSocket { pub struct UdpSocket {
local: SocketAddr, local: IrqSafeRwLock<Option<SocketAddr>>,
remote: Option<SocketAddr>, remote: IrqSafeRwLock<Option<SocketAddr>>,
broadcast: AtomicBool, broadcast: AtomicBool,
ttl: AtomicU8, ttl: AtomicU8,
@ -32,32 +37,18 @@ pub struct UdpSocket {
static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new()); static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new());
impl UdpSocket { impl UdpSocket {
fn create_socket(local: SocketAddr) -> Arc<UdpSocket> { pub fn new() -> Arc<Self> {
log::debug!("UDP socket opened: {}", local); Arc::new(Self {
Arc::new(UdpSocket { local: IrqSafeRwLock::new(None),
local, remote: IrqSafeRwLock::new(None),
ttl: AtomicU8::new(64),
remote: None,
broadcast: AtomicBool::new(false), broadcast: AtomicBool::new(false),
ttl: AtomicU8::new(64),
receive_queue: BoundedMpmcQueue::new(128), receive_queue: BoundedMpmcQueue::new(128),
}) })
} }
pub fn bind(address: SocketAddr) -> Result<Arc<UdpSocket>, Error> {
let mut sockets = UDP_SOCKETS.write();
if address.port() == 0 {
sockets.try_insert_with_ephemeral_port(address.ip(), |port| {
Ok(Self::create_socket(SocketAddr::new(address.ip(), port)))
})
} else {
sockets.try_insert_with(address, move || Ok(Self::create_socket(address)))
}
}
pub fn connect(&self, _address: SocketAddr) -> Result<(), Error> {
todo!()
}
pub fn get(local: &SocketAddr) -> Option<Arc<UdpSocket>> { pub fn get(local: &SocketAddr) -> Option<Arc<UdpSocket>> {
UDP_SOCKETS.read().get(local) UDP_SOCKETS.read().get(local)
} }
@ -67,6 +58,25 @@ impl UdpSocket {
.try_push_back((source, Vec::from(data))) .try_push_back((source, Vec::from(data)))
.map_err(|_| Error::QueueFull) .map_err(|_| Error::QueueFull)
} }
// If address is bound, keep it, if not, bind an ephemeral port
pub fn ensure_address(self: &Arc<Self>, v6: bool) -> Result<u16, Error> {
let local = self.local.read();
if let Some(address) = *local {
Ok(address.port())
} else {
let mut local = IrqSafeRwLockReadGuard::upgrade(local);
let ip = match v6 {
true => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
false => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
};
let port = UDP_SOCKETS
.write()
.bind_to_ephemeral_port(ip, self.clone())?;
*local = Some(SocketAddr::new(ip, port));
Ok(port)
}
}
} }
impl FileReadiness for UdpSocket { impl FileReadiness for UdpSocket {
@ -77,18 +87,35 @@ impl FileReadiness for UdpSocket {
#[async_trait] #[async_trait]
impl PacketSocket for UdpSocket { impl PacketSocket for UdpSocket {
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> { fn connect(self: Arc<Self>, remote: SocketAddr) -> Result<(), Error> {
let Some(destination) = destination else { let mut connected = self.remote.write();
// TODO can still send without setting address if "connected" if connected.is_some() {
return Err(Error::InvalidArgument); return Err(Error::InvalidOperation);
}; }
*connected = Some(remote);
Ok(())
}
async fn send_to(
self: Arc<Self>,
destination: Option<SocketAddr>,
data: &[u8],
_timeout: Option<Duration>,
) -> Result<usize, Error> {
let destination = destination
.or_else(|| self.remote_address())
.ok_or(Error::NotConnected)?;
// If socket wasn't bound yet, bind it to an ephemeral port
let port = self.ensure_address(destination.ip().is_ipv6())?;
// TODO check that destnation family matches self family // TODO check that destnation family matches self family
match (self.broadcast.load(Ordering::Acquire), destination.ip()) { match (self.broadcast.load(Ordering::Acquire), destination.ip()) {
// SendTo in broadcast? // TODO broadcast
(true, _) => todo!(), (true, _) => return Err(Error::NotImplemented),
(false, _) => { (false, _) => {
l4::udp::send( l4::udp::send(
self.local.port(), port,
destination.ip(), destination.ip(),
destination.port(), destination.port(),
self.ttl.load(Ordering::Acquire), self.ttl.load(Ordering::Acquire),
@ -102,16 +129,21 @@ impl PacketSocket for UdpSocket {
} }
fn send_nonblocking( fn send_nonblocking(
&self, self: Arc<Self>,
destination: Option<SocketAddr>, destination: Option<SocketAddr>,
buffer: &[u8], buffer: &[u8],
) -> Result<usize, Error> { ) -> Result<usize, Error> {
log::warn!("TODO: UDP::send_nonblocking()"); log::warn!("TODO: UDP::send_nonblocking()");
block!(self.send_to(destination, buffer).await)? block!(self.send_to(destination, buffer, None).await)?
} }
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { async fn receive_from(
let (source, packet) = self.receive_queue.pop_front().await; self: Arc<Self>,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
let future = self.receive_queue.pop_front();
let (source, packet) = maybe_timeout(future, timeout).await?;
if packet.len() > buffer.len() { if packet.len() > buffer.len() {
// TODO check how other OSs handle this // TODO check how other OSs handle this
@ -121,7 +153,10 @@ impl PacketSocket for UdpSocket {
Ok((packet.len(), source)) Ok((packet.len(), source))
} }
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { fn receive_nonblocking(
self: Arc<Self>,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr), Error> {
let (source, packet) = self let (source, packet) = self
.receive_queue .receive_queue
.try_pop_front() .try_pop_front()
@ -137,23 +172,43 @@ impl PacketSocket for UdpSocket {
} }
impl Socket for UdpSocket { impl Socket for UdpSocket {
fn local_address(&self) -> SocketAddr { fn bind(self: Arc<Self>, local: SocketAddr) -> Result<(), Error> {
self.local let mut bound = self.local.write();
if bound.is_some() {
return Err(Error::InvalidOperation);
}
if local.port() == 0 {
let port = UDP_SOCKETS
.write()
.bind_to_ephemeral_port(local.ip(), self.clone())?;
*bound = Some(SocketAddr::new(local.ip(), port));
} else {
UDP_SOCKETS.write().try_insert(local, self.clone())?;
*bound = Some(local);
}
Ok(())
}
fn local_address(&self) -> Option<SocketAddr> {
*self.local.read()
} }
fn remote_address(&self) -> Option<SocketAddr> { fn remote_address(&self) -> Option<SocketAddr> {
self.remote *self.remote.read()
} }
fn close(&self) -> Result<(), Error> { fn close(self: Arc<Self>) -> Result<(), Error> {
log::debug!("UDP socket closed: {}", self.local); if let Some(local) = self.local.write().take() {
UDP_SOCKETS.write().remove(self.local) log::debug!("UDP socket closed: {}", local);
UDP_SOCKETS.write().remove(local)?;
}
Ok(())
} }
fn set_option(&self, option: &SocketOption) -> Result<(), Error> { fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option { match option {
&SocketOption::Broadcast(broadcast) => { &SocketOption::Broadcast(broadcast) => {
log::debug!("{} broadcast: {}", self.local, broadcast); // log::debug!("{} broadcast: {}", self.local, broadcast);
self.broadcast.store(broadcast, Ordering::Release); self.broadcast.store(broadcast, Ordering::Release);
Ok(()) Ok(())
} }
@ -209,9 +264,12 @@ impl Socket for UdpSocket {
impl fmt::Debug for UdpSocket { impl fmt::Debug for UdpSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let local = *self.local.read();
let remote = *self.remote.read();
f.debug_struct("UdpSocket") f.debug_struct("UdpSocket")
.field("local", &self.local) .field("local", &local)
.field("remote", &self.remote) .field("remote", &remote)
.finish_non_exhaustive() .finish_non_exhaustive()
} }
} }

View File

@ -44,6 +44,15 @@ impl RwLockInner {
self.value.fetch_nand(Self::LOCKED_WRITE, Ordering::Release); self.value.fetch_nand(Self::LOCKED_WRITE, Ordering::Release);
} }
#[inline]
fn upgrade(&self) {
// At least one read lock is held by this task.
// When there's *exactly* one lock (being this task) held, upgrade is possible
while !self.try_upgrade() {
core::hint::spin_loop();
}
}
#[inline] #[inline]
fn acquire_read_raw(&self) -> usize { fn acquire_read_raw(&self) -> usize {
let value = self.value.fetch_add(Self::LOCKED_READ, Ordering::Acquire); let value = self.value.fetch_add(Self::LOCKED_READ, Ordering::Acquire);
@ -77,6 +86,18 @@ impl RwLockInner {
.is_ok() .is_ok()
} }
#[inline]
fn try_upgrade(&self) -> bool {
self.value
.compare_exchange(
Self::LOCKED_READ,
Self::LOCKED_WRITE,
Ordering::Acquire,
Ordering::Relaxed,
)
.is_ok()
}
#[inline] #[inline]
fn acquire_read(&self) { fn acquire_read(&self) {
while !self.try_acquire_read() { while !self.try_acquire_read() {
@ -133,6 +154,11 @@ impl<T> IrqSafeRwLock<T> {
self.inner.downgrade_write(); self.inner.downgrade_write();
} }
#[inline]
unsafe fn upgrade(&self) {
self.inner.upgrade();
}
unsafe fn release_read(&self) { unsafe fn release_read(&self) {
self.inner.release_read(); self.inner.release_read();
} }
@ -159,10 +185,26 @@ impl<T> Drop for IrqSafeRwLockReadGuard<'_, T> {
} }
} }
impl<T> IrqSafeRwLockReadGuard<'_, T> { impl<'a, T> IrqSafeRwLockReadGuard<'a, T> {
pub fn get(guard: &Self) -> *const T { pub fn get(guard: &Self) -> *const T {
guard.lock.value.get() guard.lock.value.get()
} }
pub fn upgrade(guard: IrqSafeRwLockReadGuard<'a, T>) -> IrqSafeRwLockWriteGuard<'a, T> {
let lock = guard.lock;
let irq_guard = IrqGuard::acquire();
// Read lock still held
core::mem::forget(guard);
unsafe {
lock.upgrade();
}
IrqSafeRwLockWriteGuard {
lock,
_guard: irq_guard,
}
}
} }
impl<'a, T> IrqSafeRwLockWriteGuard<'a, T> { impl<'a, T> IrqSafeRwLockWriteGuard<'a, T> {

View File

@ -372,10 +372,8 @@ impl IoContext {
if !flags.contains_any(RemoveFlags::DIRECTORY | RemoveFlags::DIRECTORY_ONLY) { if !flags.contains_any(RemoveFlags::DIRECTORY | RemoveFlags::DIRECTORY_ONLY) {
return Err(Error::IsADirectory); return Err(Error::IsADirectory);
} }
} else { } else if flags.contains_any(RemoveFlags::DIRECTORY_ONLY) {
if flags.contains_any(RemoveFlags::DIRECTORY_ONLY) { return Err(Error::NotADirectory);
return Err(Error::NotADirectory);
}
} }
parent.remove_file(filename, access) parent.remove_file(filename, access)

View File

@ -35,7 +35,7 @@ pub use path::{Filename, OwnedFilename};
pub use poll::FdPoll; pub use poll::FdPoll;
pub use pty::{PseudoTerminalMaster, PseudoTerminalSlave}; pub use pty::{PseudoTerminalMaster, PseudoTerminalSlave};
pub use shared_memory::SharedMemory; pub use shared_memory::SharedMemory;
pub use socket::{ConnectionSocket, ListenerSocket, PacketSocket, Socket, SocketWrapper}; pub use socket::{ConnectionSocket, PacketSocket, Socket, SocketWrapper};
pub use terminal::{Terminal, TerminalInput, TerminalOutput}; pub use terminal::{Terminal, TerminalInput, TerminalOutput};
pub use timer::TimerFile; pub use timer::TimerFile;
pub use traits::{FileReadiness, Read, Seek, Write}; pub use traits::{FileReadiness, Read, Seek, Write};

View File

@ -9,18 +9,21 @@ use async_trait::async_trait;
use libk_util::sync::spin_rwlock::IrqSafeRwLock; use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::{ use yggdrasil_abi::{
error::Error, error::Error,
net::{SocketAddr, SocketOption}, net::{SocketAddr, SocketOption, SocketShutdown},
}; };
use crate::{task::runtime::maybe_timeout, vfs::FileReadiness}; use crate::vfs::FileReadiness;
use super::{File, FileRef};
enum SocketInner { enum SocketInner {
Connection(Arc<dyn ConnectionSocket + Send + 'static>), Connection(Arc<dyn ConnectionSocket + 'static>),
Listener(Arc<dyn ListenerSocket + Send + 'static>), // Listener(Arc<dyn ListenerSocket + Send + 'static>),
Packet(Arc<dyn PacketSocket + Send + 'static>), Packet(Arc<dyn PacketSocket + 'static>),
} }
struct InnerOptions { struct InnerOptions {
connect_timeout: Option<Duration>,
recv_timeout: Option<Duration>, recv_timeout: Option<Duration>,
send_timeout: Option<Duration>, send_timeout: Option<Duration>,
non_blocking: bool, non_blocking: bool,
@ -34,14 +37,16 @@ pub struct SocketWrapper {
/// Interface for interacting with network sockets /// Interface for interacting with network sockets
#[allow(unused)] #[allow(unused)]
pub trait Socket: FileReadiness + fmt::Debug + Send { pub trait Socket: FileReadiness + fmt::Debug + Send {
fn bind(self: Arc<Self>, local: SocketAddr) -> Result<(), Error>;
/// Socket listen/receive address /// Socket listen/receive address
fn local_address(&self) -> SocketAddr; fn local_address(&self) -> Option<SocketAddr>;
/// Socket remote address /// Socket remote address
fn remote_address(&self) -> Option<SocketAddr>; fn remote_address(&self) -> Option<SocketAddr>;
/// Closes a socket /// Closes a socket
fn close(&self) -> Result<(), Error>; fn close(self: Arc<Self>) -> Result<(), Error>;
/// Updates a socket option /// Updates a socket option
fn set_option(&self, option: &SocketOption) -> Result<(), Error> { fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
@ -59,40 +64,62 @@ pub trait Socket: FileReadiness + fmt::Debug + Send {
pub trait PacketSocket: Socket { pub trait PacketSocket: Socket {
/// Receives a packet into provided buffer. Will return an error if packet cannot be placed /// Receives a packet into provided buffer. Will return an error if packet cannot be placed
/// within the buffer. /// within the buffer.
async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>; async fn receive_from(
self: Arc<Self>,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error>;
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>; fn receive_nonblocking(
self: Arc<Self>,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr), Error>;
/// Sends provided data to the recepient specified by `destination` /// Sends provided data to the recepient specified by `destination`
async fn send_to(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error>; async fn send_to(
self: Arc<Self>,
destination: Option<SocketAddr>,
data: &[u8],
timeout: Option<Duration>,
) -> Result<usize, Error>;
fn send_nonblocking( fn send_nonblocking(
&self, self: Arc<Self>,
destination: Option<SocketAddr>, destination: Option<SocketAddr>,
buffer: &[u8], buffer: &[u8],
) -> Result<usize, Error>; ) -> Result<usize, Error>;
fn connect(self: Arc<Self>, remote: SocketAddr) -> Result<(), Error>;
} }
/// Connection-based client socket interface /// Connection-based client socket interface
#[async_trait] #[async_trait]
pub trait ConnectionSocket: Socket { pub trait ConnectionSocket: Socket {
/// Receives data into provided buffer async fn connect(
async fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error>; self: Arc<Self>,
remote: SocketAddr,
timeout: Option<Duration>,
) -> Result<(), Error>;
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<usize, Error>; async fn receive(
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error>;
/// Transmits data fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>;
async fn send(&self, data: &[u8]) -> Result<usize, Error>;
async fn send(&self, data: &[u8], timeout: Option<Duration>) -> Result<usize, Error>;
fn send_nonblocking(&self, buffer: &[u8]) -> Result<usize, Error>; fn send_nonblocking(&self, buffer: &[u8]) -> Result<usize, Error>;
}
/// Connection-based listener socket interface fn listen(self: Arc<Self>) -> Result<(), Error>;
#[async_trait]
pub trait ListenerSocket: Socket {
/// Blocks the execution until an incoming connection is accepted
async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>; async fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>;
fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>; fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error>;
async fn shutdown(&self, read: bool, write: bool) -> Result<(), Error>;
} }
impl SocketWrapper { impl SocketWrapper {
@ -110,103 +137,108 @@ impl SocketWrapper {
} }
} }
pub fn from_listener(socket: Arc<dyn ListenerSocket + 'static>) -> Self { async fn send_inner(&self, data: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
Self { let timeout = self.options.read().send_timeout;
inner: SocketInner::Listener(socket),
options: IrqSafeRwLock::new(InnerOptions::default()), match &self.inner {
SocketInner::Packet(socket) => socket.clone().send_to(remote, data, timeout).await,
SocketInner::Connection(socket) => socket.send(data, timeout).await,
} }
} }
pub fn accept(&self) -> Result<(SocketWrapper, SocketAddr), Error> { async fn receive_inner(&self, data: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let SocketInner::Listener(socket) = &self.inner else { let timeout = self.options.read().recv_timeout;
match &self.inner {
SocketInner::Packet(socket) => socket.clone().receive_from(data, timeout).await,
SocketInner::Connection(socket) => socket.clone().receive(data, timeout).await,
}
}
async fn connect_inner(&self, remote: SocketAddr) -> Result<(), Error> {
let timeout = self.options.read().connect_timeout;
match &self.inner {
SocketInner::Packet(socket) => socket.clone().connect(remote),
SocketInner::Connection(socket) => socket.clone().connect(remote, timeout).await,
}
}
pub fn connect(&self, remote: SocketAddr) -> Result<(), Error> {
block!(self.connect_inner(remote).await)?
}
pub fn send_to(&self, data: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
if self.options.read().non_blocking {
match &self.inner {
SocketInner::Packet(socket) => socket.clone().send_nonblocking(remote, data),
SocketInner::Connection(socket) => socket.clone().send_nonblocking(data),
}
} else {
block!(self.send_inner(data, remote).await)?
}
}
pub fn receive_from(&self, data: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
if self.options.read().non_blocking {
match &self.inner {
SocketInner::Packet(socket) => socket.clone().receive_nonblocking(data),
SocketInner::Connection(socket) => socket.receive_nonblocking(data),
}
} else {
block!(self.receive_inner(data).await)?
}
}
pub fn bind(&self, local: SocketAddr) -> Result<(), Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.clone().bind(local),
SocketInner::Connection(socket) => socket.clone().bind(local),
}
}
pub fn listen(&self) -> Result<(), Error> {
match &self.inner {
SocketInner::Packet(_) => Err(Error::InvalidOperation),
SocketInner::Connection(socket) => socket.clone().listen(),
}
}
pub fn accept(&self) -> Result<(FileRef, SocketAddr), Error> {
let SocketInner::Connection(socket) = &self.inner else {
return Err(Error::InvalidOperation); return Err(Error::InvalidOperation);
}; };
let options = self.options.read(); let (remote, stream) = if self.options.read().non_blocking {
let (remote, remote_socket) = match (options.non_blocking, options.recv_timeout) { socket.accept_nonblocking()?
(false, timeout) => { } else {
let fut = socket.accept(); block!(socket.accept().await)??
block!(maybe_timeout(fut, timeout).await)???
}
(true, _) => socket.accept_nonblocking()?,
}; };
let remote_socket = Self::from_connection(remote_socket); let file = File::from_socket(SocketWrapper::from_connection(stream));
Ok((remote_socket, remote)) Ok((file, remote))
} }
fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { pub fn shutdown(&self, how: SocketShutdown) -> Result<(), Error> {
match &self.inner { let SocketInner::Connection(socket) = &self.inner else {
SocketInner::Packet(socket) => socket.receive_nonblocking(buffer), return Err(Error::InvalidOperation);
SocketInner::Connection(socket) => { };
let remote = socket.remote_address().ok_or(Error::NotConnected)?;
let len = socket.receive_nonblocking(buffer)?; block!(
Ok((len, remote)) socket
} .shutdown(
SocketInner::Listener(_) => Err(Error::InvalidOperation), how.contains(SocketShutdown::READ),
} how.contains(SocketShutdown::WRITE),
)
.await
)?
} }
async fn receive( pub fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
&self,
buffer: &mut [u8],
timeout: Option<Duration>,
) -> Result<(usize, SocketAddr), Error> {
match &self.inner {
SocketInner::Packet(socket) => {
maybe_timeout(socket.receive_from(buffer), timeout).await?
}
SocketInner::Connection(socket) => {
let remote = socket.remote_address().ok_or(Error::NotConnected)?;
let len = maybe_timeout(socket.receive(buffer), timeout).await??;
Ok((len, remote))
}
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
fn send_nonblocking(&self, buffer: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.send_nonblocking(remote, buffer),
SocketInner::Connection(socket) => socket.send_nonblocking(buffer),
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
async fn send(
&self,
buffer: &[u8],
remote: Option<SocketAddr>,
timeout: Option<Duration>,
) -> Result<usize, Error> {
match &self.inner {
SocketInner::Packet(socket) => {
maybe_timeout(socket.send_to(remote, buffer), timeout).await?
}
SocketInner::Connection(socket) => maybe_timeout(socket.send(buffer), timeout).await?,
SocketInner::Listener(_) => Err(Error::InvalidOperation),
}
}
pub fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
let options = self.options.read();
match (options.non_blocking, options.recv_timeout) {
(false, timeout) => block!(self.receive(buffer, timeout).await)?,
(true, _) => self.receive_nonblocking(buffer),
}
}
pub fn send_to(&self, buffer: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
let options = self.options.read();
match (options.non_blocking, options.recv_timeout) {
(false, timeout) => block!(self.send(buffer, remote, timeout).await)?,
(true, _) => self.send_nonblocking(buffer, remote),
}
}
}
impl Socket for SocketWrapper {
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option { match option {
SocketOption::NonBlocking(nb) => {
self.options.write().non_blocking = *nb;
return Ok(());
}
SocketOption::RecvTimeout(timeout) => { SocketOption::RecvTimeout(timeout) => {
self.options.write().recv_timeout = *timeout; self.options.write().recv_timeout = *timeout;
return Ok(()); return Ok(());
@ -215,22 +247,25 @@ impl Socket for SocketWrapper {
self.options.write().send_timeout = *timeout; self.options.write().send_timeout = *timeout;
return Ok(()); return Ok(());
} }
SocketOption::NonBlocking(nb) => { SocketOption::ConnectTimeout(timeout) => {
self.options.write().non_blocking = *nb; self.options.write().connect_timeout = *timeout;
return Ok(()); return Ok(());
} }
_ => (), _ => (),
} }
match &self.inner { match &self.inner {
SocketInner::Packet(socket) => socket.set_option(option),
SocketInner::Listener(socket) => socket.set_option(option),
SocketInner::Connection(socket) => socket.set_option(option), SocketInner::Connection(socket) => socket.set_option(option),
SocketInner::Packet(socket) => socket.set_option(option),
} }
} }
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { pub fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option { match option {
SocketOption::NonBlocking(nb) => {
*nb = self.options.read().non_blocking;
return Ok(());
}
SocketOption::RecvTimeout(timeout) => { SocketOption::RecvTimeout(timeout) => {
*timeout = self.options.read().recv_timeout; *timeout = self.options.read().recv_timeout;
return Ok(()); return Ok(());
@ -239,41 +274,16 @@ impl Socket for SocketWrapper {
*timeout = self.options.read().send_timeout; *timeout = self.options.read().send_timeout;
return Ok(()); return Ok(());
} }
SocketOption::NonBlocking(nb) => { SocketOption::ConnectTimeout(timeout) => {
*nb = self.options.read().non_blocking; *timeout = self.options.read().connect_timeout;
return Ok(()); return Ok(());
} }
_ => (), _ => (),
} }
match &self.inner { match &self.inner {
SocketInner::Packet(socket) => socket.get_option(option),
SocketInner::Listener(socket) => socket.get_option(option),
SocketInner::Connection(socket) => socket.get_option(option), SocketInner::Connection(socket) => socket.get_option(option),
} SocketInner::Packet(socket) => socket.get_option(option),
}
fn local_address(&self) -> SocketAddr {
match &self.inner {
SocketInner::Packet(socket) => socket.local_address(),
SocketInner::Listener(socket) => socket.local_address(),
SocketInner::Connection(socket) => socket.local_address(),
}
}
fn remote_address(&self) -> Option<SocketAddr> {
match &self.inner {
SocketInner::Packet(socket) => socket.remote_address(),
SocketInner::Listener(socket) => socket.remote_address(),
SocketInner::Connection(socket) => socket.remote_address(),
}
}
fn close(&self) -> Result<(), Error> {
match &self.inner {
SocketInner::Packet(socket) => socket.close(),
SocketInner::Listener(socket) => socket.close(),
SocketInner::Connection(socket) => socket.close(),
} }
} }
} }
@ -281,9 +291,8 @@ impl Socket for SocketWrapper {
impl FileReadiness for SocketWrapper { impl FileReadiness for SocketWrapper {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match &self.inner { match &self.inner {
SocketInner::Packet(socket) => socket.poll_read(cx),
SocketInner::Listener(socket) => socket.poll_read(cx),
SocketInner::Connection(socket) => socket.poll_read(cx), SocketInner::Connection(socket) => socket.poll_read(cx),
SocketInner::Packet(socket) => socket.poll_read(cx),
} }
} }
} }
@ -293,7 +302,6 @@ impl fmt::Debug for SocketWrapper {
match &self.inner { match &self.inner {
SocketInner::Packet(socket) => socket.fmt(f), SocketInner::Packet(socket) => socket.fmt(f),
SocketInner::Connection(socket) => socket.fmt(f), SocketInner::Connection(socket) => socket.fmt(f),
SocketInner::Listener(socket) => socket.fmt(f),
} }
} }
} }
@ -301,9 +309,8 @@ impl fmt::Debug for SocketWrapper {
impl Drop for SocketInner { impl Drop for SocketInner {
fn drop(&mut self) { fn drop(&mut self) {
let res = match self { let res = match self {
Self::Packet(socket) => socket.close(), Self::Packet(socket) => socket.clone().close(),
Self::Connection(socket) => socket.close(), Self::Connection(socket) => socket.clone().close(),
Self::Listener(socket) => socket.close(),
}; };
if let Err(error) = res { if let Err(error) = res {
log::warn!("Socket close error: {error:?}"); log::warn!("Socket close error: {error:?}");
@ -314,69 +321,10 @@ impl Drop for SocketInner {
impl Default for InnerOptions { impl Default for InnerOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
connect_timeout: None,
recv_timeout: None, recv_timeout: None,
send_timeout: None, send_timeout: None,
non_blocking: false, non_blocking: false,
} }
} }
} }
// impl From<Arc<dyn ConnectionSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn ConnectionSocket + 'static>) -> Self {
// Self::Connection(value)
// }
// }
//
// impl From<Arc<dyn ListenerSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn ListenerSocket + 'static>) -> Self {
// Self::Listener(value)
// }
// }
//
// impl From<Arc<dyn PacketSocket + 'static>> for SocketWrapper {
// fn from(value: Arc<dyn PacketSocket + 'static>) -> Self {
// Self::Packet(value)
// }
// }
// impl Deref for PacketSocketWrapper {
// type Target = dyn PacketSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for PacketSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }
//
// impl Deref for ListenerSocketWrapper {
// type Target = dyn ListenerSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for ListenerSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }
//
// impl Deref for ConnectionSocketWrapper {
// type Target = dyn ConnectionSocket;
//
// fn deref(&self) -> &Self::Target {
// self.0.as_ref()
// }
// }
//
// impl Drop for ConnectionSocketWrapper {
// fn drop(&mut self) {
// self.0.close().ok();
// }
// }

View File

@ -8,7 +8,7 @@ pub(crate) use abi::{
UnmountOptions, UnmountOptions,
}, },
mem::{MappingFlags, MappingSource}, mem::{MappingFlags, MappingSource},
net::SocketType, net::{SocketShutdown, SocketType},
process::{Signal, SignalEntryData, SpawnOptions, WaitFlags}, process::{Signal, SignalEntryData, SpawnOptions, WaitFlags},
system::SystemInfo, system::SystemInfo,
}; };

View File

@ -3,127 +3,218 @@ use core::{mem::MaybeUninit, net::SocketAddr};
use abi::{ use abi::{
error::Error, error::Error,
io::RawFd, io::RawFd,
net::{SocketConnect, SocketOption, SocketType}, net::{SocketOption, SocketShutdown, SocketType},
}; };
use libk::{ use libk::{
task::thread::Thread, task::thread::Thread,
vfs::{File, Socket, SocketWrapper}, vfs::{File, FileRef, SocketWrapper},
}; };
use ygg_driver_net_core::socket::{RawSocket, TcpListener, TcpSocket, UdpSocket}; use ygg_driver_net_core::socket::{RawSocket, TcpSocket, UdpSocket};
use crate::syscall::run_with_io; use crate::syscall::run_with_io;
// Network fn get_socket(fd: RawFd) -> Result<FileRef, Error> {
pub(crate) fn connect_socket(
connect: &mut SocketConnect,
local_result: &mut MaybeUninit<SocketAddr>,
) -> Result<RawFd, Error> {
let thread = Thread::current(); let thread = Thread::current();
let process = thread.process(); let process = thread.process();
run_with_io(&process, |mut io| { run_with_io(&process, |io| io.files.file(fd).cloned())
let (local, fd) = match connect {
SocketConnect::Tcp(remote, timeout) => {
let remote = (*remote).into();
let (local, socket) = libk::block!(TcpSocket::connect(remote, *timeout).await)??;
let file = File::from_socket(SocketWrapper::from_connection(socket));
let fd = io.files.place_file(file, true)?;
(local.into(), fd)
}
SocketConnect::Udp(_socket, _remote) => {
todo!("UDP socket connect")
}
};
local_result.write(local);
Ok(fd)
})
} }
pub(crate) fn bind_socket(listen: &SocketAddr, ty: SocketType) -> Result<RawFd, Error> { pub(crate) fn create_socket(ty: SocketType) -> Result<RawFd, Error> {
let thread = Thread::current(); let thread = Thread::current();
let process = thread.process(); let process = thread.process();
run_with_io(&process, |mut io| { run_with_io(&process, |mut io| {
let listen = (*listen).into();
let socket = match ty { let socket = match ty {
SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::bind()?), SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::new()),
SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::bind(listen)?), SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::new()),
SocketType::TcpStream => SocketWrapper::from_listener(TcpListener::bind(listen)?), SocketType::TcpStream => SocketWrapper::from_connection(TcpSocket::new()),
}; };
let file = File::from_socket(socket); let file = File::from_socket(socket);
let fd = io.files.place_file(file, true)?; let fd = io.files.place_file(file, true)?;
Ok(fd) Ok(fd)
}) })
} }
pub(crate) fn accept( pub(crate) fn bind(sock_fd: RawFd, local: &SocketAddr) -> Result<(), Error> {
socket_fd: RawFd, let file = get_socket(sock_fd)?;
remote_result: &mut MaybeUninit<SocketAddr>, file.as_socket()?.bind((*local).into())
) -> Result<RawFd, Error> { }
pub(crate) fn listen(sock_fd: RawFd) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.listen()
}
pub(crate) fn connect(sock_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.connect((*remote).into())
}
pub(crate) fn accept(sock_fd: RawFd, remote: &mut MaybeUninit<SocketAddr>) -> Result<RawFd, Error> {
let thread = Thread::current(); let thread = Thread::current();
let process = thread.process(); let process = thread.process();
run_with_io(&process, |mut io| { run_with_io(&process, |mut io| {
let file = io.files.file(socket_fd)?; let listener = io.files.file(sock_fd)?;
let listener = file.as_socket()?; let listener = listener.as_socket()?;
let (socket, remote) = listener.accept()?;
remote_result.write(remote.into()); let (stream_file, stream_remote) = listener.accept()?;
let fd = io.files.place_file(File::from_socket(socket), true)?; let stream_fd = io.files.place_file(stream_file, true)?;
Ok(fd)
remote.write(stream_remote.into());
Ok(stream_fd)
}) })
} }
pub(crate) fn shutdown(sock_fd: RawFd, how: SocketShutdown) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.shutdown(how)
}
pub(crate) fn send_to( pub(crate) fn send_to(
socket_fd: RawFd, sock_fd: RawFd,
buffer: &[u8], data: &[u8],
recepient: &Option<SocketAddr>, remote: &Option<SocketAddr>,
) -> Result<usize, Error> { ) -> Result<usize, Error> {
let thread = Thread::current(); let file = get_socket(sock_fd)?;
let process = thread.process(); file.as_socket()?.send_to(data, remote.map(Into::into))
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
let remote = recepient.map(Into::into);
socket.send_to(buffer, remote)
})
} }
pub(crate) fn receive_from( pub(crate) fn receive_from(
socket_fd: RawFd, sock_fd: RawFd,
buffer: &mut [u8], data: &mut [u8],
remote_result: &mut MaybeUninit<SocketAddr>, remote: &mut MaybeUninit<SocketAddr>,
) -> Result<usize, Error> { ) -> Result<usize, Error> {
let thread = Thread::current(); let file = get_socket(sock_fd)?;
let process = thread.process(); let (len, remote_) = file.as_socket()?.receive_from(data)?;
remote.write(remote_.into());
run_with_io(&process, |io| { Ok(len)
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
let (len, remote) = socket.receive_from(buffer)?;
remote_result.write(remote.into());
Ok(len)
})
} }
pub(crate) fn set_socket_option(socket_fd: RawFd, option: &SocketOption) -> Result<(), Error> { pub(crate) fn get_socket_option(sock_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> {
let thread = Thread::current(); let file = get_socket(sock_fd)?;
let process = thread.process(); file.as_socket()?.get_option(option)
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
socket.set_option(option)
})
} }
pub(crate) fn get_socket_option(socket_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> { pub(crate) fn set_socket_option(sock_fd: RawFd, option: &SocketOption) -> Result<(), Error> {
let thread = Thread::current(); let file = get_socket(sock_fd)?;
let process = thread.process(); file.as_socket()?.set_option(option)
run_with_io(&process, |io| {
let file = io.files.file(socket_fd)?;
let socket = file.as_socket()?;
socket.get_option(option)
})
} }
// // Network
// pub(crate) fn connect_socket(
// connect: &mut SocketConnect,
// local_result: &mut MaybeUninit<SocketAddr>,
// ) -> Result<RawFd, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |mut io| {
// let (local, fd) = match connect {
// SocketConnect::Tcp(remote, timeout) => {
// let remote = (*remote).into();
// let (local, socket) = libk::block!(TcpSocket::connect(remote, *timeout).await)??;
// let file = File::from_socket(SocketWrapper::from_connection(socket));
// let fd = io.files.place_file(file, true)?;
// (local.into(), fd)
// }
// SocketConnect::Udp(_socket, _remote) => {
// todo!("UDP socket connect")
// }
// };
// local_result.write(local);
// Ok(fd)
// })
// }
//
// pub(crate) fn bind_socket(listen: &SocketAddr, ty: SocketType) -> Result<RawFd, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |mut io| {
// let listen = (*listen).into();
// let socket = match ty {
// SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::bind()?),
// SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::bind(listen)?),
// SocketType::TcpStream => SocketWrapper::from_listener(TcpListener::bind(listen)?),
// };
// let file = File::from_socket(socket);
// let fd = io.files.place_file(file, true)?;
// Ok(fd)
// })
// }
//
// pub(crate) fn accept(
// socket_fd: RawFd,
// remote_result: &mut MaybeUninit<SocketAddr>,
// ) -> Result<RawFd, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |mut io| {
// let file = io.files.file(socket_fd)?;
// let listener = file.as_socket()?;
// let (socket, remote) = listener.accept()?;
// remote_result.write(remote.into());
// let fd = io.files.place_file(File::from_socket(socket), true)?;
// Ok(fd)
// })
// }
//
// pub(crate) fn send_to(
// socket_fd: RawFd,
// buffer: &[u8],
// recepient: &Option<SocketAddr>,
// ) -> Result<usize, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// let remote = recepient.map(Into::into);
// socket.send_to(buffer, remote)
// })
// }
//
// pub(crate) fn receive_from(
// socket_fd: RawFd,
// buffer: &mut [u8],
// remote_result: &mut MaybeUninit<SocketAddr>,
// ) -> Result<usize, Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// let (len, remote) = socket.receive_from(buffer)?;
// remote_result.write(remote.into());
// Ok(len)
// })
// }
//
// pub(crate) fn set_socket_option(socket_fd: RawFd, option: &SocketOption) -> Result<(), Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// socket.set_option(option)
// })
// }
//
// pub(crate) fn get_socket_option(socket_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> {
// let thread = Thread::current();
// let process = thread.process();
//
// run_with_io(&process, |io| {
// let file = io.files.file(socket_fd)?;
// let socket = file.as_socket()?;
// socket.get_option(option)
// })
// }

View File

@ -53,6 +53,13 @@ enum SocketType(u32) {
UdpPacket = 2, UdpPacket = 2,
} }
bitfield SocketShutdown(u32) {
/// Stop reception on a socket
READ: 0,
/// Stop transmission on a socket
WRITE: 1,
}
// abi::mem // abi::mem
bitfield MappingFlags(u32) { bitfield MappingFlags(u32) {
@ -161,13 +168,15 @@ syscall receive_message(
) -> Result<()>; ) -> Result<()>;
// Network // Network
syscall connect_socket(connect: &mut SocketConnect, local: &mut MaybeUninit<SocketAddr>) -> Result<RawFd>;
syscall bind_socket(address: &SocketAddr, ty: SocketType) -> Result<RawFd>;
syscall accept(sock_fd: RawFd, remote: &mut MaybeUninit<SocketAddr>) -> Result<RawFd>;
syscall create_socket(ty: SocketType) -> Result<RawFd>;
syscall bind(sock_fd: RawFd, local: &SocketAddr) -> Result<()>;
syscall listen(sock_fd: RawFd) -> Result<()>;
syscall connect(sock_fd: RawFd, remote: &SocketAddr) -> Result<()>;
syscall accept(sock_fd: RawFd, remote: &mut MaybeUninit<SocketAddr>) -> Result<RawFd>;
syscall shutdown(sock_fd: RawFd, how: SocketShutdown) -> Result<()>;
syscall send_to(sock_fd: RawFd, data: &[u8], remote: &Option<SocketAddr>) -> Result<usize>; syscall send_to(sock_fd: RawFd, data: &[u8], remote: &Option<SocketAddr>) -> Result<usize>;
syscall receive_from(sock_fd: RawFd, data: &mut [u8], remote: &mut MaybeUninit<SocketAddr>) -> Result<usize>; syscall receive_from(sock_fd: RawFd, data: &mut [u8], remote: &mut MaybeUninit<SocketAddr>) -> Result<usize>;
syscall get_socket_option(sock_fd: RawFd, option: &mut SocketOption<'_>) -> Result<()>; syscall get_socket_option(sock_fd: RawFd, option: &mut SocketOption<'_>) -> Result<()>;
syscall set_socket_option(sock_fd: RawFd, option: &SocketOption<'_>) -> Result<()>; syscall set_socket_option(sock_fd: RawFd, option: &SocketOption<'_>) -> Result<()>;

View File

@ -4,6 +4,7 @@
clippy::new_without_default, clippy::new_without_default,
clippy::should_implement_trait, clippy::should_implement_trait,
clippy::module_inception, clippy::module_inception,
clippy::missing_transmute_annotations,
incomplete_features, incomplete_features,
stable_features stable_features
)] )]

View File

@ -10,8 +10,7 @@ pub mod types;
use core::time::Duration; use core::time::Duration;
pub use crate::generated::SocketType; pub use crate::generated::{SocketShutdown, SocketType};
use crate::io::RawFd;
pub use types::{ pub use types::{
ip_addr::{IpAddr, Ipv4Addr, Ipv6Addr}, ip_addr::{IpAddr, Ipv4Addr, Ipv6Addr},
@ -20,16 +19,6 @@ pub use types::{
MacAddress, MacAddress,
}; };
/// Describes a socket connect operation
#[derive(Clone, Debug)]
pub enum SocketConnect {
/// Connect a TCP socket with optional timeout.
Tcp(core::net::SocketAddr, Option<Duration>),
/// "Connect" an UDP socket, this just sets the sender's address in the socket so
/// the caller can then use send().
Udp(RawFd, core::net::SocketAddr),
}
/// Describes a method to query an interface /// Describes a method to query an interface
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum SocketInterfaceQuery<'a> { pub enum SocketInterfaceQuery<'a> {
@ -52,6 +41,8 @@ pub enum SocketOption<'a> {
UnbindInterface, UnbindInterface,
/// (Read-only) Hardware address of the bound interface /// (Read-only) Hardware address of the bound interface
BoundHardwareAddress(MacAddress), BoundHardwareAddress(MacAddress),
/// (Read-only) Local socket address
LocalAddress(Option<core::net::SocketAddr>),
/// (Read-only) Remote socket address /// (Read-only) Remote socket address
PeerAddress(Option<core::net::SocketAddr>), PeerAddress(Option<core::net::SocketAddr>),
/// If set, reception will return [crate::error::Error::WouldBlock] if the socket has /// If set, reception will return [crate::error::Error::WouldBlock] if the socket has
@ -63,6 +54,8 @@ pub enum SocketOption<'a> {
RecvTimeout(Option<Duration>), RecvTimeout(Option<Duration>),
/// If not [None], send operations will have a time limit set before returning an error. /// If not [None], send operations will have a time limit set before returning an error.
SendTimeout(Option<Duration>), SendTimeout(Option<Duration>),
/// If not [None], connect() call will timeout after the specified time limit.
ConnectTimeout(Option<Duration>),
/// (UDP) If set, allows multicast packets to be looped back to local host. /// (UDP) If set, allows multicast packets to be looped back to local host.
MulticastLoopV4(bool), MulticastLoopV4(bool),
/// (UDP) If set, allows multicast packets to be looped back to local host. /// (UDP) If set, allows multicast packets to be looped back to local host.

View File

@ -85,6 +85,30 @@ impl From<Ipv4Addr> for core::net::Ipv4Addr {
// IPv6 // IPv6
impl Ipv6Addr {
/// An IPv6 unspecified address `::`.
pub const UNSPECIFIED: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 0);
/// Constructs a new IPv6 address from its words.
///
/// The result represents the IP address `a:b:c:d:e:f:g:h`.
#[allow(clippy::too_many_arguments)]
pub const fn new(a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> Self {
let addr16 = [
a.to_be(),
b.to_be(),
c.to_be(),
d.to_be(),
e.to_be(),
f.to_be(),
g.to_be(),
h.to_be(),
];
// SAFETY: `[u16; 8]` is safe to transmute to `[u8; 16]`
Self(unsafe { core::mem::transmute::<_, [u8; 16]>(addr16) })
}
}
impl FromStr for Ipv6Addr { impl FromStr for Ipv6Addr {
type Err = Error; type Err = Error;

View File

@ -175,3 +175,15 @@ impl From<core::net::SocketAddr> for SocketAddr {
} }
} }
} }
impl From<SocketAddrV4> for SocketAddr {
fn from(value: SocketAddrV4) -> Self {
Self::V4(value)
}
}
impl From<SocketAddrV6> for SocketAddr {
fn from(value: SocketAddrV6) -> Self {
Self::V6(value)
}
}

View File

@ -1,12 +1,10 @@
//! Network-related functions and types //! Network-related functions and types
use core::{mem::MaybeUninit, net::SocketAddr, time::Duration}; use core::{net::SocketAddr, time::Duration};
pub use abi::net::{MacAddress, SocketConnect, SocketInterfaceQuery, SocketOption, SocketType}; pub use abi::net::{MacAddress, SocketInterfaceQuery, SocketOption, SocketShutdown, SocketType};
use abi::{error::Error, io::RawFd}; use abi::{error::Error, io::RawFd};
use crate::sys;
#[allow(unused_macros)] #[allow(unused_macros)]
macro socket_option_variant { macro socket_option_variant {
($opt:ident: bool) => { $crate::net::SocketOption::$opt(false) }, ($opt:ident: bool) => { $crate::net::SocketOption::$opt(false) },
@ -79,23 +77,70 @@ pub mod dns {
} }
} }
fn connect_inner(connect: &mut SocketConnect) -> Result<(SocketAddr, RawFd), Error> { fn bind_inner(fd: RawFd, local: &SocketAddr, listen: bool) -> Result<(), Error> {
let mut local = MaybeUninit::uninit(); unsafe { crate::sys::bind(fd, local) }?;
let fd = unsafe { sys::connect_socket(connect, &mut local) }?; if listen {
let local = unsafe { local.assume_init() }; unsafe { crate::sys::listen(fd) }?;
Ok((local, fd)) }
Ok(())
}
fn connect_inner(fd: RawFd, remote: &SocketAddr, timeout: Option<Duration>) -> Result<(), Error> {
if timeout.is_some() {
unsafe { crate::sys::set_socket_option(fd, &SocketOption::ConnectTimeout(timeout)) }?;
}
unsafe { crate::sys::connect(fd, remote) }?;
Ok(())
}
/// Creates a new socket and binds it to a local address
pub fn create_and_bind(ty: SocketType, local: &SocketAddr, listen: bool) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(ty) }?;
match bind_inner(fd, local, listen) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds a TCP listener socket to some local address
pub fn bind_tcp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::TcpStream, local, true)
}
/// Binds a raw socket to some network interface
pub fn bind_raw(iface: SocketInterfaceQuery<'_>) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(SocketType::RawPacket) }?;
let option = SocketOption::BindInterface(iface);
match unsafe { crate::sys::set_socket_option(fd, &option) } {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds an UDP socket to some local address
pub fn bind_udp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::UdpPacket, local, false)
} }
/// Connect to a TCP listener /// Connect to a TCP listener
pub fn connect_tcp( pub fn connect_tcp(remote: &SocketAddr, timeout: Option<Duration>) -> Result<RawFd, Error> {
remote: SocketAddr, let fd = unsafe { crate::sys::create_socket(SocketType::TcpStream) }?;
timeout: Option<Duration>, match connect_inner(fd, remote, timeout) {
) -> Result<(SocketAddr, RawFd), Error> { Ok(()) => Ok(fd),
connect_inner(&mut SocketConnect::Tcp(remote, timeout)) Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
} }
/// "Connect" an UDP socket /// "Connect" an UDP socket
pub fn connect_udp(socket_fd: RawFd, remote: SocketAddr) -> Result<(), Error> { pub fn connect_udp(socket_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> {
connect_inner(&mut SocketConnect::Udp(socket_fd, remote))?; connect_inner(socket_fd, remote, None)
Ok(())
} }

View File

@ -21,7 +21,7 @@ mod generated {
TerminalOptions, TerminalSize, TimerOptions, UnmountOptions, TerminalOptions, TerminalSize, TimerOptions, UnmountOptions,
}, },
mem::{MappingFlags, MappingSource}, mem::{MappingFlags, MappingSource},
net::SocketType, net::{SocketShutdown, SocketType},
process::{ process::{
ExecveOptions, ProcessGroupId, ProcessId, Signal, SignalEntryData, SpawnOptions, ExecveOptions, ProcessGroupId, ProcessId, Signal, SignalEntryData, SpawnOptions,
ThreadSpawnOptions, WaitFlags, ThreadSpawnOptions, WaitFlags,

10
test.c
View File

@ -1,8 +1,14 @@
#include <pthread.h> #include <netinet/in.h>
#include <unistd.h> #include <stdlib.h>
#include <assert.h> #include <assert.h>
#include <stdio.h> #include <stdio.h>
int main(int argc, const char **argv) { int main(int argc, const char **argv) {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
perror("socket()");
return EXIT_FAILURE;
}
return 0; return 0;
} }

View File

@ -24,7 +24,7 @@ fn include_dir(d: &DirEntry) -> bool {
&& d.path() && d.path()
.iter() .iter()
.nth(2) .nth(2)
.map_or(false, |c| c.to_str().map_or(false, |x| !x.starts_with("_"))) .is_some_and(|c| c.to_str().is_some_and(|x| !x.starts_with("_")))
} }
fn generate_header(config_path: impl AsRef<Path>, header_output: impl AsRef<Path>) { fn generate_header(config_path: impl AsRef<Path>, header_output: impl AsRef<Path>) {
@ -67,11 +67,7 @@ fn generate_header(config_path: impl AsRef<Path>, header_output: impl AsRef<Path
fn compile_crt0(arch: &str, output_dir: impl AsRef<Path>) { fn compile_crt0(arch: &str, output_dir: impl AsRef<Path>) {
let output_dir = output_dir.as_ref(); let output_dir = output_dir.as_ref();
let mut command = Command::new("clang"); let mut command = Command::new("clang");
let arch = if arch == "x86" { let arch = if arch == "x86" { "i686" } else { arch };
"i686"
} else {
arch
};
let input_dir = PathBuf::from("crt").join(arch); let input_dir = PathBuf::from("crt").join(arch);
let crt0_c = input_dir.join("crt0.c"); let crt0_c = input_dir.join("crt0.c");
let crt0_s = input_dir.join("crt0.S"); let crt0_s = input_dir.join("crt0.S");

View File

@ -0,0 +1,15 @@
language = "C"
style = "Type"
sys_includes = [
"netinet/in.h",
"inttypes.h"
]
no_includes = true
include_guard = "_ARPA_INET_H"
usize_type = "size_t"
isize_type = "ssize_t"
[export]

View File

@ -0,0 +1,20 @@
#[no_mangle]
unsafe extern "C" fn htonl(v: u32) -> u32 {
v.to_be()
}
#[no_mangle]
unsafe extern "C" fn htons(v: u16) -> u16 {
v.to_be()
}
#[no_mangle]
unsafe extern "C" fn ntohl(v: u32) -> u32 {
u32::from_be(v)
}
#[no_mangle]
unsafe extern "C" fn ntohs(v: u16) -> u16 {
u16::from_be(v)
}

View File

@ -195,6 +195,7 @@ impl From<yggdrasil_rt::Error> for Errno {
Error::DirectoryNotEmpty => Errno::ENOTEMPTY, Error::DirectoryNotEmpty => Errno::ENOTEMPTY,
Error::NotConnected => Errno::ENOTCONN, Error::NotConnected => Errno::ENOTCONN,
Error::ProcessNotFound => Errno::ESRCH, Error::ProcessNotFound => Errno::ESRCH,
Error::CrossDeviceLink => Errno::EXDEV,
} }
} }
} }

View File

@ -1,6 +1,9 @@
use core::ffi::{c_char, c_int, c_short, VaList}; use core::ffi::{c_char, c_int, c_short, VaList};
use yggdrasil_rt::{io::{AccessMode, FileMode, OpenOptions}, sys as syscall}; use yggdrasil_rt::{
io::{AccessMode, FileMode, OpenOptions},
sys as syscall,
};
use crate::{ use crate::{
error::{CFdResult, CIntCountResult, CIntZeroResult, EResult, ResultExt, TryFromExt}, error::{CFdResult, CIntCountResult, CIntZeroResult, EResult, ResultExt, TryFromExt},
@ -125,10 +128,7 @@ fn open_opts(opts: c_int, ap: &mut VaList) -> EResult<OpenMode> {
} }
// TODO O_CLOEXEC // TODO O_CLOEXEC
if opts if opts & (O_DSYNC | O_RSYNC | O_SYNC | O_TTY_INIT | O_NONBLOCK | O_NOFOLLOW | O_NOCTTY) != 0 {
& (O_DSYNC | O_RSYNC | O_SYNC | O_TTY_INIT | O_NONBLOCK | O_NOFOLLOW | O_NOCTTY)
!= 0
{
todo!(); todo!();
} }
@ -187,7 +187,7 @@ pub(crate) unsafe extern "C" fn faccessat(
atfd: c_int, atfd: c_int,
path: *const c_char, path: *const c_char,
mode: c_int, mode: c_int,
flags: c_int, _flags: c_int,
) -> CIntZeroResult { ) -> CIntZeroResult {
let atfd = util::at_fd(atfd)?; let atfd = util::at_fd(atfd)?;
let path = path.ensure_str(); let path = path.ensure_str();

View File

@ -136,6 +136,11 @@ pub mod sys_types;
pub mod sys_utsname; pub mod sys_utsname;
pub mod sys_wait; pub mod sys_wait;
// Network
pub mod arpa_inet;
pub mod netinet_in;
pub mod sys_socket;
// TODO Generate those as part of dyn-loader (and make dyn-loader a shared library) // TODO Generate those as part of dyn-loader (and make dyn-loader a shared library)
pub mod link; pub mod link;

View File

@ -0,0 +1,17 @@
language = "C"
style = "Tag"
sys_includes = [
"inttypes.h",
"sys/socket.h",
"arpa/inet.h"
]
no_includes = true
include_guard = "_NETINET_IN_H"
usize_type = "size_t"
isize_type = "ssize_t"
[export]
include = ["sockaddr_in", "sockaddr_in6"]

View File

@ -0,0 +1,60 @@
use core::ffi::c_int;
use super::sys_socket::sa_family_t;
pub type in_port_t = u16;
pub type in_addr_t = u32;
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct in_addr {
pub s_addr: in_addr_t
}
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct in6_addr {
pub s6_addr: [u8; 16]
}
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct sockaddr_in {
pub sin_family: sa_family_t,
pub sin_port: in_port_t,
pub sin_addr: in_addr,
}
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct sockaddr_in6 {
pub sin6_family: sa_family_t,
pub sin6_port: in_port_t,
pub sin6_flowinfo: u32,
pub sin6_addr: in6_addr,
pub sin6_scope_id: u32
}
// TODO more IPv6
pub const IPPROTO_ICMP: c_int = 1;
pub const IPPROTO_TCP: c_int = 6;
pub const IPPROTO_UDP: c_int = 17;
pub const IPPROTO_IP: c_int = 0;
pub const IPPROTO_IPV6: c_int = 41;
pub const IPPROTO_RAW: c_int = 255;
pub const INADDR_ANY: in_addr_t = 0;
pub const INADDR_BROADCAST: in_addr_t = 0xFFFFFFFF;
pub const INET_ADDRSTRLEN: usize = 16;
pub const INET6_ADDRSTRLEN: usize = 46;
pub const IPV6_JOIN_GROUP: c_int = 6001;
pub const IPV6_LEAVE_GROUP: c_int = 6002;
pub const IPV6_MULTICAST_HOPS: c_int = 6003;
pub const IPV6_MULTICAST_IF: c_int = 6004;
pub const IPV6_MULTICAST_LOOP: c_int = 6005;
pub const IPV6_UNICAST_HOPS: c_int = 6006;
pub const IPV6_V6ONLY: c_int = 6007;

View File

@ -1,9 +1,15 @@
use core::ffi::{c_char, c_int}; use core::ffi::{c_char, c_int};
use yggdrasil_rt::sys as syscall; use yggdrasil_rt::{
io::{RemoveFlags, Rename},
sys as syscall,
};
use crate::{ use crate::{
error::{self, CIntZeroResult, CPtrResult, ResultExt}, headers::errno::Errno, io::managed::{stderr, FILE}, util::{PointerExt, PointerStrExt} error::{self, CIntZeroResult, CPtrResult, ResultExt},
headers::{errno::Errno, fcntl::AT_FDCWD},
io::managed::{stderr, FILE},
util::{self, PointerExt, PointerStrExt},
}; };
#[no_mangle] #[no_mangle]
@ -32,21 +38,34 @@ unsafe extern "C" fn ctermid(_buf: *mut c_char) -> *mut c_char {
#[no_mangle] #[no_mangle]
unsafe extern "C" fn remove(path: *const c_char) -> CIntZeroResult { unsafe extern "C" fn remove(path: *const c_char) -> CIntZeroResult {
let path = path.ensure_str(); let path = path.ensure_str();
syscall::remove(None, path).e_map_err(Errno::from)?; syscall::remove(None, path, RemoveFlags::DIRECTORY).e_map_err(Errno::from)?;
CIntZeroResult::SUCCESS CIntZeroResult::SUCCESS
} }
#[no_mangle] #[no_mangle]
unsafe extern "C" fn rename(src: *const c_char, dst: *const c_char) -> c_int { unsafe extern "C" fn rename(src: *const c_char, dst: *const c_char) -> CIntZeroResult {
let src = src.ensure_str(); renameat(AT_FDCWD, src, dst)
let dst = dst.ensure_str();
yggdrasil_rt::debug_trace!("rename {src:?} -> {dst:?}");
todo!()
} }
#[no_mangle] #[no_mangle]
unsafe extern "C" fn renameat(_atfd: c_int, _src: *const c_char, _dst: *const c_char) -> c_int { unsafe extern "C" fn renameat(
todo!() atfd: c_int,
src: *const c_char,
dst: *const c_char,
) -> CIntZeroResult {
let at = util::at_fd(atfd)?;
let source = src.ensure_str();
let destination = dst.ensure_str();
syscall::rename(&Rename {
source_at: at,
destination_at: at,
source,
destination,
})
.e_map_err(Errno::from)?;
CIntZeroResult::SUCCESS
} }
#[no_mangle] #[no_mangle]

View File

@ -1,6 +1,16 @@
use core::{ffi::{c_char, c_int}, ptr::NonNull, slice}; use core::{
ffi::{c_char, c_int},
ptr::NonNull,
slice,
};
use crate::{allocator::c_alloc, error::{self, CPtrResult, CResult}, headers::errno::Errno, io, util::PointerStrExt}; use crate::{
allocator::c_alloc,
error::{self, CPtrResult, CResult},
headers::errno::Errno,
io,
util::PointerStrExt,
};
#[no_mangle] #[no_mangle]
unsafe extern "C" fn grantpt(_fd: c_int) -> c_int { unsafe extern "C" fn grantpt(_fd: c_int) -> c_int {
@ -28,7 +38,10 @@ unsafe extern "C" fn ptsname(_fd: c_int) -> *mut c_char {
} }
#[no_mangle] #[no_mangle]
unsafe extern "C" fn realpath(path: *const c_char, mut resolved_ptr: *mut c_char) -> CPtrResult<c_char> { unsafe extern "C" fn realpath(
path: *const c_char,
resolved_ptr: *mut c_char,
) -> CPtrResult<c_char> {
if path.is_null() { if path.is_null() {
error::errno = Errno::EINVAL; error::errno = Errno::EINVAL;
return CPtrResult::ERROR; return CPtrResult::ERROR;

View File

@ -1,8 +1,18 @@
use core::{ffi::{c_char, c_int, c_void}, num::NonZeroUsize, ptr::{self, NonNull}}; use core::{
ffi::{c_char, c_int, c_void},
ptr::{self, NonNull},
};
use yggdrasil_rt::{io::RawFd, mem::{MappingFlags, MappingSource}, sys as syscall}; use yggdrasil_rt::{
io::RawFd,
mem::{MappingFlags, MappingSource},
sys as syscall,
};
use crate::{error::{self, CPtrResult, EResult, ResultExt, TryFromExt}, headers::errno::Errno}; use crate::{
error::{self, EResult, ResultExt, TryFromExt},
headers::errno::Errno,
};
use super::sys_types::{mode_t, off_t}; use super::sys_types::{mode_t, off_t};

View File

@ -0,0 +1,14 @@
language = "C"
style = "Tag"
sys_includes = [
"sys/types.h"
]
no_includes = true
include_guard = "_SYS_SOCKET_H"
usize_type = "size_t"
isize_type = "ssize_t"
[export]

View File

@ -0,0 +1,72 @@
use core::ffi::{c_int, c_void};
use super::{msghdr, sockaddr, socklen_t};
#[no_mangle]
unsafe extern "C" fn recv(fd: c_int, buffer: *mut c_void, len: usize, flags: c_int) -> isize {
let _ = fd;
let _ = buffer;
let _ = len;
let _ = flags;
todo!()
}
#[no_mangle]
unsafe extern "C" fn recvfrom(
fd: c_int,
buffer: *mut c_void,
len: usize,
flags: c_int,
remote: *mut sockaddr,
remote_len: *mut socklen_t,
) -> isize {
let _ = fd;
let _ = buffer;
let _ = len;
let _ = flags;
let _ = remote;
let _ = remote_len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn recvmsg(fd: c_int, message: *mut msghdr, flags: c_int) -> isize {
let _ = fd;
let _ = message;
let _ = flags;
todo!()
}
#[no_mangle]
unsafe extern "C" fn send(fd: c_int, data: *const c_void, len: usize) -> isize {
let _ = fd;
let _ = data;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn sendmsg(fd: c_int, message: *const msghdr, flags: c_int) -> isize {
let _ = fd;
let _ = message;
let _ = flags;
todo!()
}
#[no_mangle]
unsafe extern "C" fn sendto(
fd: c_int,
data: *const c_void,
len: usize,
flags: c_int,
remote: *const sockaddr,
remote_len: socklen_t,
) -> isize {
let _ = fd;
let _ = data;
let _ = len;
let _ = flags;
let _ = remote;
let _ = remote_len;
todo!()
}

View File

@ -0,0 +1,86 @@
use core::ffi::{c_int, c_void};
mod io;
mod option;
mod socket;
pub type socklen_t = usize;
pub type sa_family_t = u16;
pub const __SS_SIZE: usize = 256;
// __SS_SIZE - sizeof(sa_samily_t)
pub const __SOCKADDR_LEN: usize = __SS_SIZE - 2;
// __SS_SIZE - sizeof(sa_family_t) - sizeof(__ss_aligntype)
pub const __SS_PADSIZE: usize = __SS_SIZE - 6;
pub type __ss_aligntype = u32;
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct sockaddr {
pub sa_family: sa_family_t,
pub sa_data: [u8; __SOCKADDR_LEN],
}
// TODO struct sockaddr_storage
// TODO struct iovec from sys/uio.h
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct iovec {
__dummy: u32
}
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct msghdr {
pub msg_name: *mut c_void,
pub msg_namelen: socklen_t,
pub msg_iov: *mut iovec,
pub msg_iovlen: c_int,
pub msg_control: *mut c_void,
pub msg_controllen: socklen_t,
pub msg_flags: c_int
}
// TODO struct cmsghdr
// socket() parameters
pub const SOCK_DGRAM: c_int = 1;
pub const SOCK_RAW: c_int = 2;
pub const SOCK_SEQPACKET: c_int = 3;
pub const SOCK_STREAM: c_int = 4;
// setsockopt() parameters
pub const SOL_SOCKET: c_int = 1;
pub const SO_ACCEPTCONN: c_int = 1;
pub const SO_BROADCAST: c_int = 2;
pub const SO_DEBUG: c_int = 3;
pub const SO_DONTROUTE: c_int = 4;
pub const SO_ERROR: c_int = 5;
pub const SO_KEEPALIVE: c_int = 6;
pub const SO_LINGER: c_int = 7;
pub const SO_OOBINLINE: c_int = 8;
pub const SO_RCVBUF: c_int = 9;
pub const SO_RCVLOWAT: c_int = 10;
pub const SO_RCVTIMEO: c_int = 11;
pub const SO_REUSEADDR: c_int = 12;
pub const SO_SNDBUF: c_int = 13;
pub const SO_SNDLOWAT: c_int = 14;
pub const SO_SNDTIMEO: c_int = 15;
pub const SO_TYPE: c_int = 16;
pub const SOMAXCONN: usize = 64;
// TODO msg_flags values
pub const AF_INET: c_int = 1;
pub const AF_INET6: c_int = 2;
pub const AF_UNIX: c_int = 3;
pub const AF_UNSPEC: c_int = 0;
pub const SHUT_RD: c_int = 1 << 0;
pub const SHUT_WR: c_int = 1 << 1;
pub const SHUT_RDWR: c_int = SHUT_RD | SHUT_WR;

View File

@ -0,0 +1,35 @@
use core::ffi::{c_int, c_void};
use super::socklen_t;
#[no_mangle]
unsafe extern "C" fn getsockopt(
fd: c_int,
level: c_int,
name: c_int,
value: *mut c_void,
size: *mut socklen_t,
) -> c_int {
let _ = fd;
let _ = level;
let _ = name;
let _ = value;
let _ = size;
todo!()
}
#[no_mangle]
unsafe extern "C" fn setsockopt(
fd: c_int,
level: c_int,
name: c_int,
value: *const c_void,
size: socklen_t,
) -> c_int {
let _ = fd;
let _ = level;
let _ = name;
let _ = value;
let _ = size;
todo!()
}

View File

@ -0,0 +1,99 @@
use core::ffi::c_int;
use crate::{
error::{self, CFdResult, CResult},
headers::{
errno::Errno,
sys_socket::{AF_INET, SOCK_DGRAM, SOCK_STREAM},
},
};
use super::{sockaddr, socklen_t};
#[no_mangle]
unsafe extern "C" fn accept(fd: c_int, remote: *mut sockaddr, len: *mut socklen_t) -> c_int {
let _ = fd;
let _ = remote;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn bind(fd: c_int, local: *const sockaddr, len: socklen_t) -> c_int {
let _ = fd;
let _ = local;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn connect(fd: c_int, remote: *const sockaddr, len: socklen_t) -> c_int {
let _ = fd;
let _ = remote;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn getpeername(fd: c_int, remote: *mut sockaddr, len: *mut socklen_t) -> c_int {
let _ = fd;
let _ = remote;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn getsockname(fd: c_int, local: *mut sockaddr, len: *mut socklen_t) -> c_int {
let _ = fd;
let _ = local;
let _ = len;
todo!()
}
#[no_mangle]
unsafe extern "C" fn listen(fd: c_int, backlog: c_int) -> c_int {
let _ = fd;
let _ = backlog;
todo!()
}
#[no_mangle]
unsafe extern "C" fn shutdown(fd: c_int, how: c_int) -> c_int {
let _ = fd;
let _ = how;
todo!()
}
#[no_mangle]
unsafe extern "C" fn sockatmark(fd: c_int) -> c_int {
let _ = fd;
todo!()
}
#[no_mangle]
unsafe extern "C" fn socket(domain: c_int, ty: c_int, proto: c_int) -> CFdResult {
match (domain, ty, proto) {
(AF_INET, SOCK_STREAM, 0) => todo!(),
(AF_INET, SOCK_DGRAM, 0) => todo!(),
(_, _, _) => {
yggdrasil_rt::debug_trace!("Unsupported socket({domain}, {ty}, {proto})");
error::errno = Errno::ENOTSUPP;
CFdResult::ERROR
}
}
}
#[no_mangle]
unsafe extern "C" fn socketpair(
domain: c_int,
ty: c_int,
proto: c_int,
sockets: *mut c_int,
) -> c_int {
let _ = domain;
let _ = ty;
let _ = proto;
let _ = sockets;
todo!()
}

View File

@ -1,8 +1,14 @@
use core::{ffi::c_int, ptr::NonNull}; use core::{ffi::c_int, ptr::NonNull};
use crate::{error::{self, CIntZeroResult, CResult, EResult, ResultExt}, headers::{ use crate::{
errno::Errno, sys_time::{__ygg_timespec_t, timespec}, sys_types::{clock_t, clockid_t, pid_t, time_t}, time::{CLOCK_MONOTONIC, CLOCK_REALTIME} error::{CIntZeroResult, EResult, ResultExt},
}}; headers::{
errno::Errno,
sys_time::{__ygg_timespec_t, timespec},
sys_types::{clock_t, clockid_t, pid_t, time_t},
time::{CLOCK_MONOTONIC, CLOCK_REALTIME},
},
};
use yggdrasil_rt::time::{self as rt, ClockType}; use yggdrasil_rt::time::{self as rt, ClockType};
@ -21,7 +27,7 @@ fn clock_type(clock_id: clockid_t) -> EResult<ClockType> {
match clock_id { match clock_id {
CLOCK_REALTIME => EResult::Ok(ClockType::RealTime), CLOCK_REALTIME => EResult::Ok(ClockType::RealTime),
CLOCK_MONOTONIC => EResult::Ok(ClockType::Monotonic), CLOCK_MONOTONIC => EResult::Ok(ClockType::Monotonic),
_ => EResult::Err(Errno::EINVAL) _ => EResult::Err(Errno::EINVAL),
} }
} }
@ -42,14 +48,17 @@ unsafe extern "C" fn clock_getres(_clock_id: clockid_t, _ts: *mut __ygg_timespec
} }
#[no_mangle] #[no_mangle]
unsafe extern "C" fn clock_gettime(clock_id: clockid_t, ts: *mut __ygg_timespec_t) -> CIntZeroResult { unsafe extern "C" fn clock_gettime(
clock_id: clockid_t,
ts: *mut __ygg_timespec_t,
) -> CIntZeroResult {
let clock = clock_type(clock_id)?; let clock = clock_type(clock_id)?;
let time = rt::get_clock(clock).e_map_err(Errno::from)?; let time = rt::get_clock(clock).e_map_err(Errno::from)?;
if let Some(ts) = NonNull::new(ts) { if let Some(ts) = NonNull::new(ts) {
ts.write(timespec { ts.write(timespec {
tv_sec: time_t(time.seconds as _), tv_sec: time_t(time.seconds() as _),
tv_nsec: time.nanoseconds as _ tv_nsec: time.subsec_nanos() as _,
}); });
} }

View File

@ -1,7 +1,4 @@
use core::{ use core::{ffi::c_char, ptr::NonNull};
ffi::{c_char, c_int},
ptr::NonNull,
};
use crate::{ use crate::{
error::{CIntZeroResult, CUsizeResult, OptionExt}, error::{CIntZeroResult, CUsizeResult, OptionExt},
@ -21,7 +18,7 @@ unsafe extern "C" fn mbrlen(_str: *const c_char, _n: usize, _state: *mut mbstate
#[no_mangle] #[no_mangle]
unsafe extern "C" fn wcrtomb(dst: *mut c_char, wc: wchar_t, state: *mut mbstate_t) -> CUsizeResult { unsafe extern "C" fn wcrtomb(dst: *mut c_char, wc: wchar_t, state: *mut mbstate_t) -> CUsizeResult {
let state = match state.as_mut() { let _state = match state.as_mut() {
Some(state) => state, Some(state) => state,
#[allow(static_mut_refs)] #[allow(static_mut_refs)]
None => &mut GLOBAL, None => &mut GLOBAL,

View File

@ -1,8 +1,5 @@
use core::cell::RefCell; use core::cell::RefCell;
use yggdrasil_rt::process::thread_local;
struct RandomState { struct RandomState {
xs64: u64 xs64: u64
} }

View File

@ -2,7 +2,7 @@
use std::os::{ use std::os::{
fd::AsRawFd, fd::AsRawFd,
yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd}, yggdrasil::io::{poll::PollChannel, net::raw_socket::RawSocket, timer::TimerFd},
}; };
use std::{io, mem::size_of, process::ExitCode, time::Duration}; use std::{io, mem::size_of, process::ExitCode, time::Duration};

View File

@ -1,347 +1,349 @@
#![feature(yggdrasil_os, rustc_private)] fn main() {}
use std::{ // #![feature(yggdrasil_os, rustc_private)]
mem::size_of, //
os::{ // use std::{
fd::AsRawFd, // mem::size_of,
yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd}, // os::{
}, // fd::AsRawFd,
process::ExitCode, // yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd},
sync::atomic::{AtomicBool, Ordering}, // },
time::Duration, // process::ExitCode,
}; // sync::atomic::{AtomicBool, Ordering},
// time::Duration,
use bytemuck::Zeroable; // };
use clap::Parser; //
use netutils::{netconfig::NetConfig, Error}; // use bytemuck::Zeroable;
use yggdrasil_abi::net::{ // use clap::Parser;
protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame}, // use netutils::{netconfig::NetConfig, Error};
types::NetValueImpl, // use yggdrasil_abi::net::{
IpAddr, Ipv4Addr, MacAddress, // protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame},
}; // types::NetValueImpl,
// IpAddr, Ipv4Addr, MacAddress,
#[derive(Parser)] // };
struct Args { //
#[clap( // #[derive(Parser)]
help = "Time (ms) between a reply is received and the next request is sent", // struct Args {
short, // #[clap(
long, // help = "Time (ms) between a reply is received and the next request is sent",
default_value_t = 1000, // short,
value_parser = valid_interval // long,
)] // default_value_t = 1000,
inteval: u32, // value_parser = valid_interval
#[clap( // )]
help = "Time (ms) after which the request is considered unanswered", // inteval: u32,
short, // #[clap(
long, // help = "Time (ms) after which the request is considered unanswered",
default_value_t = 500, // short,
value_parser = valid_timeout, // long,
)] // default_value_t = 500,
timeout: u32, // value_parser = valid_timeout,
#[clap( // )]
help = "Number of requests to perform", // timeout: u32,
short, // #[clap(
long, // help = "Number of requests to perform",
default_value_t = 10 // short,
)] // long,
count: usize, // default_value_t = 10
#[clap( // )]
help = "Amount of bytes to include as data", // count: usize,
short, // #[clap(
long, // help = "Amount of bytes to include as data",
default_value_t = 64, // short,
value_parser = valid_data_size // long,
)] // default_value_t = 64,
data_size: usize, // value_parser = valid_data_size
// )]
#[clap(help = "Address to ping")] // data_size: usize,
address: core::net::IpAddr, //
} // #[clap(help = "Address to ping")]
// address: core::net::IpAddr,
fn valid_interval(s: &str) -> Result<u32, String> { // }
clap_num::number_range(s, 100, 10000) //
} // fn valid_interval(s: &str) -> Result<u32, String> {
// clap_num::number_range(s, 100, 10000)
fn valid_timeout(s: &str) -> Result<u32, String> { // }
clap_num::number_range(s, 100, 5000) //
} // fn valid_timeout(s: &str) -> Result<u32, String> {
// clap_num::number_range(s, 100, 5000)
fn valid_data_size(s: &str) -> Result<usize, String> { // }
clap_num::number_range(s, 4, 128) //
} // fn valid_data_size(s: &str) -> Result<usize, String> {
// clap_num::number_range(s, 4, 128)
struct PingRouting { // }
interface_id: u32, //
source_ip: IpAddr, // struct PingRouting {
destination_ip: IpAddr, // interface_id: u32,
source_mac: MacAddress, // source_ip: IpAddr,
gateway_mac: MacAddress, // destination_ip: IpAddr,
} // source_mac: MacAddress,
// gateway_mac: MacAddress,
struct PingStats { // }
packets_sent: usize, //
packets_received: usize, // struct PingStats {
} // packets_sent: usize,
// packets_received: usize,
fn resolve_routing(address: IpAddr) -> Result<PingRouting, Error> { // }
let mut nc = NetConfig::open()?; //
let routing = nc.query_route(address)?; // fn resolve_routing(address: IpAddr) -> Result<PingRouting, Error> {
let Some(source) = routing.source else { // let mut nc = NetConfig::open()?;
todo!(); // let routing = nc.query_route(address)?;
}; // let Some(source) = routing.source else {
let Some(gateway) = routing.gateway else { // todo!();
todo!(); // };
}; // let Some(gateway) = routing.gateway else {
// todo!();
let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?; // };
//
Ok(PingRouting { // let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?;
interface_id: routing.interface_id, //
source_ip: source, // Ok(PingRouting {
destination_ip: routing.destination, // interface_id: routing.interface_id,
source_mac: routing.source_mac, // source_ip: source,
gateway_mac, // destination_ip: routing.destination,
}) // source_mac: routing.source_mac,
} // gateway_mac,
// })
fn validate_ping_reply( // }
packet: &[u8], //
local: Ipv4Addr, // fn validate_ping_reply(
remote: Ipv4Addr, // packet: &[u8],
expect_l4_data: &[u8], // local: Ipv4Addr,
expect_id: u16, // remote: Ipv4Addr,
expect_seq: u16, // expect_l4_data: &[u8],
) -> bool { // expect_id: u16,
if packet.len() < size_of::<EthernetFrame>() + size_of::<Ipv4Frame>() { // expect_seq: u16,
return false; // ) -> bool {
} // if packet.len() < size_of::<EthernetFrame>() + size_of::<Ipv4Frame>() {
// return false;
let l3_offset = size_of::<EthernetFrame>(); // }
//
let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]); // let l3_offset = size_of::<EthernetFrame>();
//
if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 { // let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]);
return false; //
} // if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 {
let l3_frame: &Ipv4Frame = // return false;
bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::<Ipv4Frame>()]); // }
if l3_frame.protocol != IpProtocol::ICMP // let l3_frame: &Ipv4Frame =
|| u32::from_network_order(l3_frame.source_address) != u32::from(remote) // bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::<Ipv4Frame>()]);
|| u32::from_network_order(l3_frame.destination_address) != u32::from(local) // if l3_frame.protocol != IpProtocol::ICMP
{ // || u32::from_network_order(l3_frame.source_address) != u32::from(remote)
return false; // || u32::from_network_order(l3_frame.destination_address) != u32::from(local)
} // {
let mut ip_checksum = InetChecksum::new(); // return false;
ip_checksum.add_value(l3_frame, true); // }
let ip_checksum = ip_checksum.finish(); // let mut ip_checksum = InetChecksum::new();
// ip_checksum.add_value(l3_frame, true);
if ip_checksum != 0 { // let ip_checksum = ip_checksum.finish();
eprintln!("IP checksum mismatch: {:#06x}", ip_checksum); //
return false; // if ip_checksum != 0 {
} // eprintln!("IP checksum mismatch: {:#06x}", ip_checksum);
// return false;
let l4_offset = l3_offset + l3_frame.header_length(); // }
let l4_size = l3_frame //
.total_length() // let l4_offset = l3_offset + l3_frame.header_length();
.saturating_sub(l3_frame.header_length()); // let l4_size = l3_frame
if packet.len() < l4_offset + size_of::<IcmpV4Frame>() + expect_l4_data.len() { // .total_length()
return false; // .saturating_sub(l3_frame.header_length());
} // if packet.len() < l4_offset + size_of::<IcmpV4Frame>() + expect_l4_data.len() {
let l4_frame: &IcmpV4Frame = // return false;
bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::<IcmpV4Frame>()]); // }
let l4_data = &packet[l4_offset + size_of::<IcmpV4Frame>()..l4_offset + l4_size]; // let l4_frame: &IcmpV4Frame =
// bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::<IcmpV4Frame>()]);
if l4_frame.ty != 0 || l4_frame.code != 0 { // let l4_data = &packet[l4_offset + size_of::<IcmpV4Frame>()..l4_offset + l4_size];
return false; //
} // if l4_frame.ty != 0 || l4_frame.code != 0 {
// return false;
let rest = u32::from_network_order(l4_frame.rest); // }
let reply_id = (rest >> 16) as u16; //
let reply_seq = rest as u16; // let rest = u32::from_network_order(l4_frame.rest);
// let reply_id = (rest >> 16) as u16;
if reply_id != expect_id || reply_seq != expect_seq { // let reply_seq = rest as u16;
eprintln!( //
"ICMP seq/id mismatch: sent {}/{}, got {}/{}", // if reply_id != expect_id || reply_seq != expect_seq {
expect_id, expect_seq, reply_id, reply_seq // eprintln!(
); // "ICMP seq/id mismatch: sent {}/{}, got {}/{}",
return false; // expect_id, expect_seq, reply_id, reply_seq
} // );
// return false;
let mut icmp_checksum = InetChecksum::new(); // }
icmp_checksum.add_value(l4_frame, true); //
icmp_checksum.add_bytes(l4_data, true); // let mut icmp_checksum = InetChecksum::new();
let icmp_checksum = icmp_checksum.finish(); // icmp_checksum.add_value(l4_frame, true);
// icmp_checksum.add_bytes(l4_data, true);
if icmp_checksum != 0 { // let icmp_checksum = icmp_checksum.finish();
eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum); //
return false; // if icmp_checksum != 0 {
} // eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum);
// return false;
l4_data == expect_l4_data // }
} //
// l4_data == expect_l4_data
#[allow(clippy::too_many_arguments)] // }
fn ping_once( //
socket: &mut RawSocket, // #[allow(clippy::too_many_arguments)]
poll: &mut PollChannel, // fn ping_once(
timer: &mut TimerFd, // socket: &mut RawSocket,
info: &PingRouting, // poll: &mut PollChannel,
timeout: Duration, // timer: &mut TimerFd,
data_len: usize, // info: &PingRouting,
id: u16, // timeout: Duration,
seq: u16, // data_len: usize,
) -> Result<bool, Error> { // id: u16,
let mut buffer = [0; 4096]; // seq: u16,
// ) -> Result<bool, Error> {
let source_ip = info.source_ip.into_ipv4().unwrap(); // let mut buffer = [0; 4096];
let destination_ip = info.destination_ip.into_ipv4().unwrap(); //
let mut l4_data = Vec::with_capacity(data_len); // let source_ip = info.source_ip.into_ipv4().unwrap();
// let destination_ip = info.destination_ip.into_ipv4().unwrap();
for _ in 0..data_len { // let mut l4_data = Vec::with_capacity(data_len);
l4_data.push(rand::random()); //
} // for _ in 0..data_len {
// l4_data.push(rand::random());
let ip_len = (size_of::<Ipv4Frame>() + size_of::<IcmpV4Frame>() + data_len) // }
.try_into() //
.unwrap(); // let ip_len = (size_of::<Ipv4Frame>() + size_of::<IcmpV4Frame>() + data_len)
// .try_into()
let l2_frame = EthernetFrame { // .unwrap();
source_mac: info.source_mac, //
destination_mac: info.gateway_mac, // let l2_frame = EthernetFrame {
ethertype: EtherType::IPV4.to_network_order(), // source_mac: info.source_mac,
}; // destination_mac: info.gateway_mac,
let mut l3_frame = Ipv4Frame { // ethertype: EtherType::IPV4.to_network_order(),
source_address: u32::from(source_ip).to_network_order(), // };
destination_address: u32::from(destination_ip).to_network_order(), // let mut l3_frame = Ipv4Frame {
protocol: IpProtocol::ICMP, // source_address: u32::from(source_ip).to_network_order(),
version_length: 0x45, // destination_address: u32::from(destination_ip).to_network_order(),
total_length: u16::to_network_order(ip_len), // protocol: IpProtocol::ICMP,
flags_frag: u16::to_network_order(0x4000), // version_length: 0x45,
id: u16::to_network_order(0), // total_length: u16::to_network_order(ip_len),
ttl: 255, // flags_frag: u16::to_network_order(0x4000),
..Ipv4Frame::zeroed() // id: u16::to_network_order(0),
}; // ttl: 255,
let mut l4_frame = IcmpV4Frame { // ..Ipv4Frame::zeroed()
ty: 8, // };
code: 0, // let mut l4_frame = IcmpV4Frame {
checksum: u16::to_network_order(0), // ty: 8,
rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)), // code: 0,
}; // checksum: u16::to_network_order(0),
// rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)),
let mut ip_checksum = InetChecksum::new(); // };
ip_checksum.add_value(&l3_frame, true); //
l3_frame.header_checksum = ip_checksum.finish().to_network_order(); // let mut ip_checksum = InetChecksum::new();
// ip_checksum.add_value(&l3_frame, true);
let mut icmp_checksum = InetChecksum::new(); // l3_frame.header_checksum = ip_checksum.finish().to_network_order();
icmp_checksum.add_value(&l4_frame, true); //
icmp_checksum.add_bytes(&l4_data, true); // let mut icmp_checksum = InetChecksum::new();
l4_frame.checksum = icmp_checksum.finish().to_network_order(); // icmp_checksum.add_value(&l4_frame, true);
// icmp_checksum.add_bytes(&l4_data, true);
let mut packet = vec![]; // l4_frame.checksum = icmp_checksum.finish().to_network_order();
packet.extend_from_slice(bytemuck::bytes_of(&l2_frame)); //
packet.extend_from_slice(bytemuck::bytes_of(&l3_frame)); // let mut packet = vec![];
packet.extend_from_slice(bytemuck::bytes_of(&l4_frame)); // packet.extend_from_slice(bytemuck::bytes_of(&l2_frame));
packet.extend_from_slice(&l4_data); // packet.extend_from_slice(bytemuck::bytes_of(&l3_frame));
// packet.extend_from_slice(bytemuck::bytes_of(&l4_frame));
timer.start(timeout)?; // packet.extend_from_slice(&l4_data);
socket.send(&packet)?; //
// timer.start(timeout)?;
loop { // socket.send(&packet)?;
let (fd, result) = poll.wait(None, true)?.unwrap(); //
result?; // loop {
// let (fd, result) = poll.wait(None, true)?.unwrap();
match fd { // result?;
fd if fd == socket.as_raw_fd() => { //
// TODO // match fd {
let len = socket.recv(&mut buffer)?; // fd if fd == socket.as_raw_fd() => {
if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq) // // TODO
{ // let len = socket.recv(&mut buffer)?;
return Ok(true); // if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq)
} // {
} // return Ok(true);
fd if fd == timer.as_raw_fd() => { // }
return Ok(false); // }
} // fd if fd == timer.as_raw_fd() => {
_ => unreachable!(), // return Ok(false);
} // }
} // _ => unreachable!(),
} // }
// }
fn ping( // }
address: IpAddr, //
times: usize, // fn ping(
data_len: usize, // address: IpAddr,
interval: Duration, // times: usize,
timeout: Duration, // data_len: usize,
) -> Result<PingStats, Error> { // interval: Duration,
let routing = resolve_routing(address)?; // timeout: Duration,
// ) -> Result<PingStats, Error> {
let mut stats = PingStats { // let routing = resolve_routing(address)?;
packets_sent: 0, //
packets_received: 0, // let mut stats = PingStats {
}; // packets_sent: 0,
let mut poll = PollChannel::new()?; // packets_received: 0,
let mut timer = TimerFd::new(false, false)?; // };
let mut socket = RawSocket::bind(routing.interface_id)?; // let mut poll = PollChannel::new()?;
// let mut timer = TimerFd::new(false, false)?;
poll.add(timer.as_raw_fd())?; // let mut socket = RawSocket::bind(routing.interface_id)?;
poll.add(socket.as_raw_fd())?; //
// poll.add(timer.as_raw_fd())?;
let id = rand::random(); // poll.add(socket.as_raw_fd())?;
for i in 0..times { //
if INTERRUPTED.load(Ordering::Acquire) { // let id = rand::random();
break; // for i in 0..times {
} // if INTERRUPTED.load(Ordering::Acquire) {
// break;
let result = ping_once( // }
&mut socket, //
&mut poll, // let result = ping_once(
&mut timer, // &mut socket,
&routing, // &mut poll,
timeout, // &mut timer,
data_len, // &routing,
id, // timeout,
i as u16, // data_len,
)?; // id,
stats.packets_sent += 1; // i as u16,
// )?;
if result { // stats.packets_sent += 1;
stats.packets_received += 1; //
println!("[{}/{}] {}: PONG", i + 1, times, address); // if result {
} // stats.packets_received += 1;
// println!("[{}/{}] {}: PONG", i + 1, times, address);
std::thread::sleep(interval); // }
} //
// std::thread::sleep(interval);
Ok(stats) // }
} //
// Ok(stats)
static INTERRUPTED: AtomicBool = AtomicBool::new(false); // }
//
fn main() -> ExitCode { // static INTERRUPTED: AtomicBool = AtomicBool::new(false);
// set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt)); //
// fn main() -> ExitCode {
let args = Args::parse(); // // set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt));
//
let stats = match ping( // let args = Args::parse();
args.address.into(), //
args.count, // let stats = match ping(
args.data_size, // args.address.into(),
Duration::from_millis(args.inteval.into()), // args.count,
Duration::from_millis(args.timeout.into()), // args.data_size,
) { // Duration::from_millis(args.inteval.into()),
Ok(stats) => stats, // Duration::from_millis(args.timeout.into()),
Err(error) => { // ) {
eprintln!("ping: {}", error); // Ok(stats) => stats,
return ExitCode::FAILURE; // Err(error) => {
} // eprintln!("ping: {}", error);
}; // return ExitCode::FAILURE;
// }
let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent; // };
println!( //
"{} sent, {} received, {}% loss", // let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent;
stats.packets_sent, stats.packets_received, loss // println!(
); // "{} sent, {} received, {}% loss",
// stats.packets_sent, stats.packets_received, loss
ExitCode::SUCCESS // );
} //
// ExitCode::SUCCESS
// }