From a4e441d236b0b7595ae16591c41fd35075e1a7f6 Mon Sep 17 00:00:00 2001 From: Mark Poliakov Date: Tue, 7 Jan 2025 15:17:17 +0200 Subject: [PATCH] net: move to berkeley-style sockets --- kernel/driver/net/core/src/l3/arp.rs | 20 +- kernel/driver/net/core/src/l4/icmp.rs | 7 +- kernel/driver/net/core/src/l4/tcp.rs | 18 +- kernel/driver/net/core/src/lib.rs | 2 +- kernel/driver/net/core/src/socket/mod.rs | 50 +- kernel/driver/net/core/src/socket/raw.rs | 73 +- kernel/driver/net/core/src/socket/tcp.rs | 495 ------------- .../net/core/src/socket/tcp/listener.rs | 112 +++ kernel/driver/net/core/src/socket/tcp/mod.rs | 270 +++++++ .../driver/net/core/src/socket/tcp/stream.rs | 309 ++++++++ kernel/driver/net/core/src/socket/udp.rs | 152 ++-- kernel/libk/libk-util/src/sync/spin_rwlock.rs | 44 +- kernel/libk/src/vfs/ioctx.rs | 6 +- kernel/libk/src/vfs/mod.rs | 2 +- kernel/libk/src/vfs/socket.rs | 358 ++++----- kernel/src/syscall/imp/mod.rs | 2 +- kernel/src/syscall/imp/sys_net.rs | 255 ++++--- lib/abi/def/yggdrasil.abi | 17 +- lib/abi/src/lib.rs | 1 + lib/abi/src/net/mod.rs | 17 +- lib/abi/src/net/types/ip_addr.rs | 24 + lib/abi/src/net/types/socket_addr.rs | 12 + lib/runtime/src/net.rs | 79 +- lib/runtime/src/sys/mod.rs | 2 +- test.c | 10 +- userspace/lib/ygglibc/build.rs | 8 +- .../src/headers/arpa_inet/cbindgen.toml | 15 + .../lib/ygglibc/src/headers/arpa_inet/mod.rs | 20 + .../lib/ygglibc/src/headers/errno/mod.rs | 1 + .../lib/ygglibc/src/headers/fcntl/mod.rs | 12 +- userspace/lib/ygglibc/src/headers/mod.rs | 5 + .../src/headers/netinet_in/cbindgen.toml | 17 + .../lib/ygglibc/src/headers/netinet_in/mod.rs | 60 ++ .../lib/ygglibc/src/headers/stdio/util.rs | 39 +- .../lib/ygglibc/src/headers/stdlib/io.rs | 19 +- .../lib/ygglibc/src/headers/sys_mman/mod.rs | 16 +- .../src/headers/sys_socket/cbindgen.toml | 14 + .../lib/ygglibc/src/headers/sys_socket/io.rs | 72 ++ .../lib/ygglibc/src/headers/sys_socket/mod.rs | 86 +++ .../ygglibc/src/headers/sys_socket/option.rs | 35 + .../ygglibc/src/headers/sys_socket/socket.rs | 99 +++ .../lib/ygglibc/src/headers/time/timer.rs | 23 +- .../ygglibc/src/headers/wchar/multibyte.rs | 7 +- userspace/lib/ygglibc/src/random.rs | 3 - userspace/netutils/src/dhcp_client.rs | 2 +- userspace/netutils/src/ping.rs | 694 +++++++++--------- 46 files changed, 2267 insertions(+), 1317 deletions(-) delete mode 100644 kernel/driver/net/core/src/socket/tcp.rs create mode 100644 kernel/driver/net/core/src/socket/tcp/listener.rs create mode 100644 kernel/driver/net/core/src/socket/tcp/mod.rs create mode 100644 kernel/driver/net/core/src/socket/tcp/stream.rs create mode 100644 userspace/lib/ygglibc/src/headers/arpa_inet/cbindgen.toml create mode 100644 userspace/lib/ygglibc/src/headers/arpa_inet/mod.rs create mode 100644 userspace/lib/ygglibc/src/headers/netinet_in/cbindgen.toml create mode 100644 userspace/lib/ygglibc/src/headers/netinet_in/mod.rs create mode 100644 userspace/lib/ygglibc/src/headers/sys_socket/cbindgen.toml create mode 100644 userspace/lib/ygglibc/src/headers/sys_socket/io.rs create mode 100644 userspace/lib/ygglibc/src/headers/sys_socket/mod.rs create mode 100644 userspace/lib/ygglibc/src/headers/sys_socket/option.rs create mode 100644 userspace/lib/ygglibc/src/headers/sys_socket/socket.rs diff --git a/kernel/driver/net/core/src/l3/arp.rs b/kernel/driver/net/core/src/l3/arp.rs index 03f04da6..d37ca575 100644 --- a/kernel/driver/net/core/src/l3/arp.rs +++ b/kernel/driver/net/core/src/l3/arp.rs @@ -80,7 +80,10 @@ impl ArpTable { pub fn lookup_cache(interface: u32, address: IpAddr) -> Option { let (address, _) = match address { IpAddr::V4(address) => Self::lookup_cache_v4(interface, address), - IpAddr::V6(_) => todo!(), + IpAddr::V6(v6) => { + log::warn!("TODO: ArpTable v6 lookup: {v6}"); + return None; + } }?; Some(address) } @@ -92,7 +95,10 @@ impl ArpTable { pub fn flush_address(interface: u32, address: IpAddr) -> bool { match address { IpAddr::V4(address) => Self::flush_address_v4(interface, address), - IpAddr::V6(_) => todo!(), + IpAddr::V6(v6) => { + log::warn!("TODO: ArpTable v6 flush: {v6}"); + false + } } } @@ -103,7 +109,10 @@ impl ArpTable { pub fn insert_address(interface: u32, mac: MacAddress, address: IpAddr, owned: bool) { match address { IpAddr::V4(address) => Self::insert_address_v4(interface, mac, address, owned), - IpAddr::V6(_) => todo!(), + IpAddr::V6(v6) => { + log::warn!("TODO: ArpTable v6 insert: {v6}"); + return; + } } ARP_TABLE.notify.wake_all(); } @@ -203,7 +212,10 @@ fn send_request(interface: &NetworkInterface, query_address: IpAddr) -> Result<( match query_address { IpAddr::V4(address) => send_request_v4(interface, address), - IpAddr::V6(_) => todo!(), + IpAddr::V6(v6) => { + log::warn!("TODO: ARP IPv6 query: {v6}"); + Err(Error::NotImplemented) + } } } diff --git a/kernel/driver/net/core/src/l4/icmp.rs b/kernel/driver/net/core/src/l4/icmp.rs index f62ed7df..641e2f20 100644 --- a/kernel/driver/net/core/src/l4/icmp.rs +++ b/kernel/driver/net/core/src/l4/icmp.rs @@ -24,7 +24,7 @@ async fn send_v4_reply( }; if icmp_data.len() % 2 != 0 { - todo!(); + return Err(Error::InvalidArgument); } let l4_bytes = bytemuck::bytes_of(&reply_frame); @@ -77,6 +77,9 @@ async fn handle_v4(source_address: Ipv4Addr, l3_packet: L3Packet) -> Result<(), pub async fn handle(l3_packet: L3Packet) -> Result<(), Error> { match l3_packet.source_address { IpAddr::V4(v4) => handle_v4(v4, l3_packet).await, - IpAddr::V6(_) => todo!(), + IpAddr::V6(v6) => { + log::warn!("TODO: ICMPv6 from {v6}"); + Err(Error::NotImplemented) + } } } diff --git a/kernel/driver/net/core/src/l4/tcp.rs b/kernel/driver/net/core/src/l4/tcp.rs index 51fffd2e..9d70b63f 100644 --- a/kernel/driver/net/core/src/l4/tcp.rs +++ b/kernel/driver/net/core/src/l4/tcp.rs @@ -17,7 +17,7 @@ use yggdrasil_abi::{ use crate::{ l3::{self, L3Packet}, - socket::{TcpListener, TcpSocket}, + socket::tcp::{TcpListener, TcpStream}, util::Assembler, }; @@ -652,12 +652,15 @@ pub async fn handle(packet: L3Packet) -> Result<(), Error> { match tcp_frame.flags { TcpFlags::SYN => { - if let Some(listener) = TcpListener::get(local) { + if let Some(listener) = TcpListener::get(&local) + && listener.is_listening() + { + log::debug!("tcp: create remote stream {remote} -> {local}"); let window_size = u16::from_network_order(tcp_frame.window_size); let tx_seq = 12345; // Create a socket and insert it into the table - TcpSocket::accept_remote( + TcpStream::accept_remote( listener.clone(), local, remote, @@ -703,16 +706,17 @@ pub async fn handle(packet: L3Packet) -> Result<(), Error> { seq, }; - let socket = TcpSocket::get(local, remote).ok_or(Error::DoesNotExist)?; - let mut connection = socket.connection().write(); + let stream = TcpStream::get(local, remote).ok_or(Error::DoesNotExist)?; + let mut connection = stream.connection.write(); + match connection.handle_packet(packet, tcp_data).await? { TcpSocketBehavior::None => (), TcpSocketBehavior::Accept => { - socket.accept(); + stream.accept(); } TcpSocketBehavior::Remove => { drop(connection); - socket.remove_socket()?; + stream.remove_stream()?; } } Ok(()) diff --git a/kernel/driver/net/core/src/lib.rs b/kernel/driver/net/core/src/lib.rs index d86ab9d9..7cd5413d 100644 --- a/kernel/driver/net/core/src/lib.rs +++ b/kernel/driver/net/core/src/lib.rs @@ -1,4 +1,4 @@ -#![feature(map_try_insert)] +#![feature(map_try_insert, let_chains, result_flattening)] #![allow(clippy::type_complexity, clippy::new_without_default)] #![no_std] diff --git a/kernel/driver/net/core/src/socket/mod.rs b/kernel/driver/net/core/src/socket/mod.rs index 932ce250..6d4f08c8 100644 --- a/kernel/driver/net/core/src/socket/mod.rs +++ b/kernel/driver/net/core/src/socket/mod.rs @@ -1,18 +1,20 @@ -use alloc::{collections::BTreeMap, sync::Arc}; -use libk::vfs::Socket; +use alloc::{ + collections::{btree_map::Entry, BTreeMap}, + sync::Arc, +}; use yggdrasil_abi::{ error::Error, - net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, }; pub mod udp; pub use udp::UdpSocket; pub mod tcp; -pub use tcp::{TcpListener, TcpSocket}; +pub use tcp::TcpSocket; pub mod raw; pub use raw::RawSocket; -pub struct SocketTable { +pub struct SocketTable { inner: BTreeMap>, } @@ -71,7 +73,7 @@ impl TwoWaySocketTable { } } -impl SocketTable { +impl SocketTable { pub const fn new() -> Self { Self { inner: BTreeMap::new(), @@ -95,6 +97,29 @@ impl SocketTable { Err(Error::AddrInUse) } + pub fn bind_to_ephemeral_port(&mut self, ip: IpAddr, socket: Arc) -> Result { + for port in 32768..u16::MAX - 1 { + let local = SocketAddr::new(ip, port); + + match self.try_insert(local, socket.clone()) { + Ok(()) => return Ok(port), + Err(Error::AddrInUse) => continue, + Err(error) => return Err(error), + } + } + Err(Error::AddrInUse) + } + + pub fn try_insert(&mut self, address: SocketAddr, socket: Arc) -> Result<(), Error> { + match self.inner.entry(address) { + Entry::Vacant(entry) => { + entry.insert(socket); + Ok(()) + } + Entry::Occupied(_) => Err(Error::AddrInUse), + } + } + pub fn try_insert_with Result, Error>>( &mut self, address: SocketAddr, @@ -124,12 +149,11 @@ impl SocketTable { return Some(socket.clone()); } - match local { - SocketAddr::V4(_v4) => { - let unspec_v4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port()); - self.inner.get(&SocketAddr::V4(unspec_v4)).cloned() - } - SocketAddr::V6(_) => todo!(), - } + let unspec = match local { + SocketAddr::V4(_v4) => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port()).into(), + SocketAddr::V6(_v6) => SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, local.port()).into(), + }; + + self.inner.get(&unspec).cloned() } } diff --git a/kernel/driver/net/core/src/socket/raw.rs b/kernel/driver/net/core/src/socket/raw.rs index 440041c1..424bd343 100644 --- a/kernel/driver/net/core/src/socket/raw.rs +++ b/kernel/driver/net/core/src/socket/raw.rs @@ -2,26 +2,25 @@ use core::{ fmt, sync::atomic::{AtomicU32, Ordering}, task::{Context, Poll}, + time::Duration, }; use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec}; use async_trait::async_trait; use libk::{ error::Error, + task::runtime::maybe_timeout, vfs::{FileReadiness, PacketSocket, Socket}, }; use libk_mm::PageBox; -use libk_util::{ - queue::BoundedMpmcQueue, - sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock}, -}; -use yggdrasil_abi::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketInterfaceQuery, SocketOption}; +use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock}; +use yggdrasil_abi::net::{SocketAddr, SocketInterfaceQuery, SocketOption}; use crate::{ethernet::L2Packet, interface::NetworkInterface}; pub struct RawSocket { id: u32, - bound: IrqSafeSpinlock>, + bound: IrqSafeRwLock>, receive_queue: BoundedMpmcQueue, } @@ -32,17 +31,15 @@ static BOUND_RAW_SOCKETS: IrqSafeRwLock>> = IrqSafeRwLock::new(BTreeMap::new()); impl RawSocket { - pub fn bind() -> Result, Error> { + pub fn new() -> Arc { let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst); - let socket = Self { + let socket = Arc::new(Self { id, - bound: IrqSafeSpinlock::new(None), + bound: IrqSafeRwLock::new(None), receive_queue: BoundedMpmcQueue::new(256), - }; - let socket = Arc::new(socket); + }); RAW_SOCKETS.write().insert(id, socket.clone()); - - Ok(socket) + socket } fn bound_packet_received(&self, packet: L2Packet) { @@ -57,6 +54,7 @@ impl RawSocket { if let Some(ids) = bound_sockets.get(&packet.interface_id) { for id in ids { let socket = raw_sockets.get(id).unwrap(); + log::info!("Packet -> {id}"); socket.bound_packet_received(packet.clone()); } } @@ -80,10 +78,14 @@ impl FileReadiness for RawSocket { } impl Socket for RawSocket { + fn bind(self: Arc, _local: SocketAddr) -> Result<(), Error> { + Err(Error::InvalidOperation) + } + fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { match option { SocketOption::BoundHardwareAddress(mac) => { - let bound = self.bound.lock().ok_or(Error::DoesNotExist)?; + let bound = self.bound.read().ok_or(Error::NotConnected)?; let interface = NetworkInterface::get(bound).unwrap(); *mac = interface.mac; Ok(()) @@ -95,18 +97,16 @@ impl Socket for RawSocket { fn set_option(&self, option: &SocketOption) -> Result<(), Error> { match option { SocketOption::BindInterface(query) => { - let mut bound = self.bound.lock(); + let mut bound = self.bound.write(); if bound.is_some() { return Err(Error::AlreadyExists); } let mut bound_sockets = BOUND_RAW_SOCKETS.write(); - let interface = match *query { SocketInterfaceQuery::ById(id) => NetworkInterface::get(id), SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name), }?; - let list = bound_sockets.entry(interface.id).or_default(); bound.replace(interface.id); list.push(self.id); @@ -118,8 +118,8 @@ impl Socket for RawSocket { } } - fn close(&self) -> Result<(), Error> { - let bound = self.bound.lock().take(); + fn close(self: Arc) -> Result<(), Error> { + let bound = self.bound.write().take(); if let Some(bound) = bound { let mut bound_sockets = BOUND_RAW_SOCKETS.write(); @@ -140,8 +140,8 @@ impl Socket for RawSocket { Ok(()) } - fn local_address(&self) -> SocketAddr { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)) + fn local_address(&self) -> Option { + None } fn remote_address(&self) -> Option { @@ -151,18 +151,27 @@ impl Socket for RawSocket { #[async_trait] impl PacketSocket for RawSocket { - async fn send_to(&self, destination: Option, data: &[u8]) -> Result { + fn connect(self: Arc, _remote: SocketAddr) -> Result<(), Error> { + Err(Error::InvalidOperation) + } + + async fn send_to( + self: Arc, + destination: Option, + data: &[u8], + _timeout: Option, + ) -> Result { self.send_nonblocking(destination, data) } // TODO currently this is still blocking by NIC send code fn send_nonblocking( - &self, + self: Arc, _destination: Option, buffer: &[u8], ) -> Result { // TODO cap by MTU? - let bound = self.bound.lock().ok_or(Error::InvalidOperation)?; + let bound = self.bound.read().ok_or(Error::InvalidOperation)?; let interface = NetworkInterface::get(bound)?; let l2_offset = interface.device.packet_prefix_size(); if buffer.len() > 4096 - l2_offset { @@ -174,12 +183,20 @@ impl PacketSocket for RawSocket { Ok(buffer.len()) } - async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - let packet = self.receive_queue.pop_front().await; + async fn receive_from( + self: Arc, + buffer: &mut [u8], + timeout: Option, + ) -> Result<(usize, SocketAddr), Error> { + let future = self.receive_queue.pop_front(); + let packet = maybe_timeout(future, timeout).await?; Self::packet_to_user(packet, buffer) } - fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + fn receive_nonblocking( + self: Arc, + buffer: &mut [u8], + ) -> Result<(usize, SocketAddr), Error> { let packet = self .receive_queue .try_pop_front() @@ -190,7 +207,7 @@ impl PacketSocket for RawSocket { impl fmt::Debug for RawSocket { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let bound = *self.bound.lock(); + let bound = *self.bound.read(); f.debug_struct("RawSocket") .field("interface", &bound) .finish_non_exhaustive() diff --git a/kernel/driver/net/core/src/socket/tcp.rs b/kernel/driver/net/core/src/socket/tcp.rs deleted file mode 100644 index 2f789120..00000000 --- a/kernel/driver/net/core/src/socket/tcp.rs +++ /dev/null @@ -1,495 +0,0 @@ -use core::{ - fmt, - future::{poll_fn, Future}, - sync::atomic::{AtomicU8, Ordering}, - task::{Context, Poll}, - time::Duration, -}; - -use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec}; -use async_trait::async_trait; -use libk::{ - block, - error::Error, - task::runtime::with_timeout, - time::monotonic_time, - vfs::{ConnectionSocket, FileReadiness, ListenerSocket, Socket}, -}; -use libk_util::{ - sync::{ - spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard}, - IrqSafeSpinlock, IrqSafeSpinlockGuard, - }, - waker::QueueWaker, -}; -use yggdrasil_abi::net::{SocketAddr, SocketOption}; - -use crate::{ - interface::NetworkInterface, - l3::Route, - l4::tcp::{TcpConnection, TcpConnectionState}, -}; - -use super::{SocketTable, TwoWaySocketTable}; - -pub struct TcpSocket { - pub(crate) local: SocketAddr, - pub(crate) remote: SocketAddr, - - ttl: AtomicU8, - - // Listener which accepted the socket - listener: Option>, - connection: IrqSafeRwLock, -} - -pub struct TcpListener { - accept: SocketAddr, - - // Currently active sockets - sockets: IrqSafeRwLock>>, - pending_accept: IrqSafeSpinlock>>, - accept_notify: QueueWaker, -} - -static TCP_SOCKETS: IrqSafeRwLock> = - IrqSafeRwLock::new(TwoWaySocketTable::new()); -static TCP_LISTENERS: IrqSafeRwLock> = - IrqSafeRwLock::new(SocketTable::new()); - -impl TcpSocket { - pub async fn connect( - remote: SocketAddr, - timeout: Option, - ) -> Result<(SocketAddr, Arc), Error> { - let future = Self::connect_async(remote); - match timeout { - Some(timeout) => with_timeout(future, timeout).await?, - None => future.await, - } - } - - pub fn accept_remote( - listener: Arc, - local: SocketAddr, - remote: SocketAddr, - remote_window_size: usize, - tx_seq: u32, - rx_seq: u32, - ) -> Result, Error> { - let mut sockets = TCP_SOCKETS.write(); - sockets.try_insert_with(local, remote, move || { - let connection = TcpConnection::new( - local, - remote, - remote_window_size, - tx_seq, - rx_seq, - TcpConnectionState::SynReceived, - ); - - log::debug!("Accepted TCP socket {} -> {}", local, remote); - - let socket = Self { - local, - remote, - ttl: AtomicU8::new(64), - listener: Some(listener), - connection: IrqSafeRwLock::new(connection), - }; - - Ok(Arc::new(socket)) - }) - } - - pub fn connection(&self) -> &IrqSafeRwLock { - &self.connection - } - - pub(crate) fn accept(self: &Arc) { - if let Some(listener) = self.listener.as_ref() { - listener.accept_socket(self.clone()); - } - } - - pub fn get(local: SocketAddr, remote: SocketAddr) -> Option> { - TCP_SOCKETS.read().get(local, remote) - } - - pub fn receive_async<'a>( - &'a self, - buffer: &'a mut [u8], - ) -> impl Future> + 'a { - poll_fn(|cx| match self.poll_receive(cx) { - Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(buffer)), - Poll::Ready(Err(error)) => Poll::Ready(Err(error)), - Poll::Pending => Poll::Pending, - }) - } - - pub async fn send_async(&self, data: &[u8]) -> Result { - let mut pos = 0; - let mut rem = data.len(); - while rem != 0 { - // TODO check MTU - let amount = rem.min(512); - self.send_segment_async(&data[pos..pos + amount]).await?; - pos += amount; - rem -= amount; - } - Ok(pos) - } - - pub async fn close_async(&self, remove_from_listener: bool) -> Result<(), Error> { - // TODO timeout here - // Already closing - if self.connection.read().is_closing() { - return Ok(()); - } - - // Wait for all sent data to be acknowledged - { - let mut connection = poll_fn(|cx| { - let connection = self.connection.write(); - match connection.poll_send(cx) { - Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)), - Poll::Ready(Err(error)) => Poll::Ready(Err(error)), - Poll::Pending => Poll::Pending, - } - }) - .await?; - - connection.finish().await?; - } - - log::debug!( - "TCP socket closed (FinWait2/Closed): {} <-> {}", - self.local, - self.remote - ); - - // Wait for connection to get closed - poll_fn(|cx| { - let connection = self.connection.read(); - connection.poll_finish(cx) - }) - .await; - - if remove_from_listener { - if let Some(listener) = self.listener.as_ref() { - listener.remove_socket(self.remote); - }; - } - - Ok(()) - } - - pub(crate) fn remove_socket(&self) -> Result<(), Error> { - log::debug!( - "TCP socket closed and removed: {} <-> {}", - self.local, - self.remote - ); - let connection = self.connection.read(); - debug_assert!(connection.is_closed()); - TCP_SOCKETS.write().remove(self.local, self.remote)?; - connection.notify_all(); - Ok(()) - } - - fn poll_receive( - &self, - cx: &mut Context<'_>, - ) -> Poll, Error>> { - let lock = self.connection.write(); - match lock.poll_receive(cx) { - Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)), - Poll::Ready(Err(error)) => Poll::Ready(Err(error)), - Poll::Pending => Poll::Pending, - } - } - - async fn send_segment_async(&self, data: &[u8]) -> Result<(), Error> { - // TODO timeout here - { - let mut connection = poll_fn(|cx| { - let connection = self.connection.write(); - match connection.poll_send(cx) { - Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)), - Poll::Ready(Err(error)) => Poll::Ready(Err(error)), - Poll::Pending => Poll::Pending, - } - }) - .await?; - - connection.transmit(data).await?; - } - - poll_fn(|cx| { - let connection = self.connection.read(); - connection.poll_acknowledge(cx) - }) - .await; - - Ok(()) - } - - async fn connect_async(remote: SocketAddr) -> Result<(SocketAddr, Arc), Error> { - // Lookup route to remote - let (interface_id, _, remote_ip) = - Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?; - let remote = SocketAddr::new(remote_ip, remote.port()); - let interface = NetworkInterface::get(interface_id)?; - let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?; - - let socket = { - let mut sockets = TCP_SOCKETS.write(); - sockets.try_insert_with_ephemeral_port(local_ip, remote, |port| { - let t = monotonic_time(); - let tx_seq = t.as_millis() as u32; - let local = SocketAddr::new(local_ip, port); - let connection = - TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed); - - let socket = Self { - local, - remote, - ttl: AtomicU8::new(64), - listener: None, - connection: IrqSafeRwLock::new(connection), - }; - - Ok(Arc::new(socket)) - })? - }; - - let mut t = 200; - for _ in 0..5 { - let timeout = Duration::from_millis(t); - log::debug!("Try SYN with timeout={:?}", timeout); - match socket.try_connect(timeout).await { - Ok(()) => return Ok((socket.local, socket)), - Err(Error::TimedOut) => (), - Err(error) => return Err(error), - } - t *= 2; - } - - // Couldn't establish - Err(Error::TimedOut) - } - - async fn try_connect(&self, timeout: Duration) -> Result<(), Error> { - { - let mut connection = self.connection.write(); - connection.send_syn().await?; - } - - let fut = poll_fn(|cx| { - let connection = self.connection.read(); - connection.poll_established(cx) - }); - - with_timeout(fut, timeout).await? - } -} - -impl Socket for TcpSocket { - fn local_address(&self) -> SocketAddr { - self.local - } - - fn remote_address(&self) -> Option { - Some(self.remote) - } - - fn close(&self) -> Result<(), Error> { - block!(self.close_async(true).await)? - } - - fn set_option(&self, option: &SocketOption) -> Result<(), Error> { - match option { - &SocketOption::Ttl(ttl) => { - if ttl == 0 || ttl > 255 { - return Err(Error::InvalidArgument); - } - self.ttl.store(ttl as _, Ordering::Release); - Ok(()) - } - SocketOption::NoDelay(_) => { - log::warn!("TODO: TCP nodelay"); - Ok(()) - } - _ => Err(Error::InvalidOperation), - } - } - - fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { - match option { - SocketOption::Ttl(ttl) => { - *ttl = self.ttl.load(Ordering::Acquire) as _; - Ok(()) - } - SocketOption::NoDelay(nodelay) => { - *nodelay = false; - Ok(()) - } - _ => Err(Error::InvalidOperation), - } - } -} - -impl FileReadiness for TcpSocket { - fn poll_read(&self, cx: &mut Context<'_>) -> Poll> { - self.poll_receive(cx).map_ok(|_| ()) - } -} - -#[async_trait] -impl ConnectionSocket for TcpSocket { - async fn receive(&self, buffer: &mut [u8]) -> Result { - self.receive_async(buffer).await - } - - fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result { - match self.connection.write().read_nonblocking(buffer) { - // TODO check if this really means "no data at the moment" - Ok(0) => Err(Error::WouldBlock), - Ok(len) => Ok(len), - Err(error) => Err(error), - } - } - - async fn send(&self, data: &[u8]) -> Result { - self.send_async(data).await - } - - fn send_nonblocking(&self, data: &[u8]) -> Result { - log::warn!("TODO: TCP::send_nonblocking"); - block!(self.send_async(data).await)? - } -} - -impl fmt::Debug for TcpSocket { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TcpSocket") - .field("local", &self.local) - .field("remote", &self.remote) - .finish_non_exhaustive() - } -} - -impl TcpListener { - pub fn bind(accept: SocketAddr) -> Result, Error> { - TCP_LISTENERS.write().try_insert_with(accept, || { - let listener = TcpListener { - accept, - sockets: IrqSafeRwLock::new(BTreeMap::new()), - pending_accept: IrqSafeSpinlock::new(Vec::new()), - accept_notify: QueueWaker::new(), - }; - - log::debug!("TCP Listener opened: {}", accept); - - Ok(Arc::new(listener)) - }) - } - - pub fn get(local: SocketAddr) -> Option> { - TCP_LISTENERS.read().get(&local) - } - - pub fn accept_async(&self) -> impl Future, Error>> + '_ { - poll_fn(|cx| match self.poll_accept(cx) { - Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())), - Poll::Pending => Poll::Pending, - }) - } - - fn accept_socket(&self, socket: Arc) { - log::debug!("{}: accept {}", self.accept, socket.remote); - self.sockets.write().insert(socket.remote, socket.clone()); - self.pending_accept.lock().push(socket); - self.accept_notify.wake_one(); - } - - fn remove_socket(&self, remote: SocketAddr) { - log::debug!("Remove client {}", remote); - self.sockets.write().remove(&remote); - } - - fn poll_accept(&self, cx: &mut Context<'_>) -> Poll>>> { - let lock = self.pending_accept.lock(); - self.accept_notify.register(cx.waker()); - if !lock.is_empty() { - self.accept_notify.remove(cx.waker()); - Poll::Ready(lock) - } else { - Poll::Pending - } - } -} - -impl Socket for TcpListener { - fn local_address(&self) -> SocketAddr { - self.accept - } - - fn remote_address(&self) -> Option { - None - } - - fn close(&self) -> Result<(), Error> { - // TODO if clients not closed already, send RST? - TCP_LISTENERS.write().remove(self.accept) - } - - fn set_option(&self, option: &SocketOption) -> Result<(), Error> { - match option { - SocketOption::Ipv6Only(_v6only) => { - log::warn!("TODO: TCP listener IPv6-only"); - Err(Error::InvalidOperation) - } - _ => Err(Error::InvalidOperation), - } - } - - fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { - match option { - SocketOption::Ipv6Only(v6only) => { - *v6only = false; - Ok(()) - } - _ => Err(Error::InvalidOperation), - } - } -} - -impl FileReadiness for TcpListener { - fn poll_read(&self, cx: &mut Context<'_>) -> Poll> { - self.poll_accept(cx).map(|_| Ok(())) - } -} - -#[async_trait] -impl ListenerSocket for TcpListener { - async fn accept(&self) -> Result<(SocketAddr, Arc), Error> { - let socket = self.accept_async().await?; - let remote = socket.remote; - Ok((remote, socket)) - } - - fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc), Error> { - let socket = self.pending_accept.lock().pop().ok_or(Error::WouldBlock)?; - let remote = socket.remote; - Ok((remote, socket)) - } -} - -impl fmt::Debug for TcpListener { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TcpListener") - .field("local", &self.accept) - .finish_non_exhaustive() - } -} diff --git a/kernel/driver/net/core/src/socket/tcp/listener.rs b/kernel/driver/net/core/src/socket/tcp/listener.rs new file mode 100644 index 00000000..2840101e --- /dev/null +++ b/kernel/driver/net/core/src/socket/tcp/listener.rs @@ -0,0 +1,112 @@ +use core::{ + fmt, + future::poll_fn, + sync::atomic::{AtomicBool, Ordering}, + task::{Context, Poll}, +}; + +use alloc::{collections::btree_map::BTreeMap, sync::Arc, vec::Vec}; +use libk::error::Error; +use libk_util::{ + sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock, IrqSafeSpinlockGuard}, + waker::QueueWaker, +}; +use yggdrasil_abi::net::SocketAddr; + +use crate::socket::SocketTable; + +use super::TcpStream; + +pub struct TcpListener { + pub(super) local: SocketAddr, + pub(super) listening: AtomicBool, + + // Currently active sockets + sockets: IrqSafeRwLock>>, + pending_accept: IrqSafeSpinlock>>, + accept_notify: QueueWaker, +} + +static TCP_LISTENERS: IrqSafeRwLock> = + IrqSafeRwLock::new(SocketTable::new()); + +impl TcpListener { + pub fn bind(local: SocketAddr) -> Result, Error> { + if local.port() == 0 { + // TODO: ephemeral binding + todo!(); + } + + let listener = Arc::new(TcpListener { + local, + listening: AtomicBool::new(false), + + sockets: IrqSafeRwLock::new(BTreeMap::new()), + pending_accept: IrqSafeSpinlock::new(Vec::new()), + accept_notify: QueueWaker::new(), + }); + TCP_LISTENERS.write().try_insert(local, listener.clone())?; + + Ok(listener) + } + + pub fn close(&self) -> Result<(), Error> { + // TODO if clients not closed already, send RST? + TCP_LISTENERS.write().remove(self.local) + } + + pub(super) async fn accept(&self) -> Result, Error> { + poll_fn(|cx| match self.poll_accept(cx) { + Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())), + Poll::Pending => Poll::Pending, + }) + .await + } + + pub(super) fn accept_nonblocking(&self) -> Result, Error> { + self.pending_accept.lock().pop().ok_or(Error::WouldBlock) + } + + pub(super) fn poll_accept( + &self, + cx: &mut Context<'_>, + ) -> Poll>>> { + let lock = self.pending_accept.lock(); + self.accept_notify.register(cx.waker()); + if !lock.is_empty() { + self.accept_notify.remove(cx.waker()); + Poll::Ready(lock) + } else { + Poll::Pending + } + } + + pub(super) fn accept_stream(&self, stream: Arc) { + log::debug!("{}: accept {}", self.local, stream.remote); + self.sockets.write().insert(stream.remote, stream.clone()); + self.pending_accept.lock().push(stream); + self.accept_notify.wake_one(); + } + + pub(super) fn remove_stream(&self, remote: SocketAddr) { + log::debug!("Remove remote stream {}", remote); + self.sockets.write().remove(&remote); + } + + pub fn is_listening(&self) -> bool { + self.listening.load(Ordering::Acquire) + } + + pub fn get(local: &SocketAddr) -> Option> { + TCP_LISTENERS.read().get(local) + } +} + +impl fmt::Debug for TcpListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpListener") + .field("listening", &self.is_listening()) + .field("local", &self.local) + .finish_non_exhaustive() + } +} diff --git a/kernel/driver/net/core/src/socket/tcp/mod.rs b/kernel/driver/net/core/src/socket/tcp/mod.rs new file mode 100644 index 00000000..a4a31f80 --- /dev/null +++ b/kernel/driver/net/core/src/socket/tcp/mod.rs @@ -0,0 +1,270 @@ +use core::{ + fmt, mem, + sync::atomic::Ordering, + task::{Context, Poll}, + time::Duration, +}; + +use alloc::{boxed::Box, sync::Arc}; +use async_trait::async_trait; +use libk::{ + block, + error::Error, + task::runtime::maybe_timeout, + vfs::{ConnectionSocket, FileReadiness, Socket}, +}; +use libk_util::sync::spin_rwlock::IrqSafeRwLock; +use yggdrasil_abi::net::{SocketAddr, SocketOption}; + +mod listener; +mod stream; + +pub use listener::TcpListener; +pub use stream::TcpStream; + +pub struct TcpSocket { + options: IrqSafeRwLock, + inner: IrqSafeRwLock, +} + +pub enum TcpSocketInner { + Empty, + Listener(Arc), + Stream(Arc), +} + +struct TcpSocketOptions { + ttl: u8, + nodelay: bool, + v6_only: bool, +} + +impl Default for TcpSocketOptions { + fn default() -> Self { + Self { + ttl: 64, + nodelay: false, + v6_only: false, + } + } +} + +impl TcpSocket { + pub fn new() -> Arc { + Arc::new(Self { + options: IrqSafeRwLock::new(TcpSocketOptions::default()), + inner: IrqSafeRwLock::new(TcpSocketInner::Empty), + }) + } + + fn as_stream(&self) -> Option> { + if let TcpSocketInner::Stream(stream) = &*self.inner.read() { + Some(stream.clone()) + } else { + None + } + } + + fn as_listener(&self) -> Option> { + if let TcpSocketInner::Listener(listener) = &*self.inner.read() { + Some(listener.clone()) + } else { + None + } + } +} + +impl Socket for TcpSocket { + fn bind(self: Arc, local: SocketAddr) -> Result<(), Error> { + let mut inner = self.inner.write(); + + // Already connected or bound + if !matches!(&*inner, TcpSocketInner::Empty) { + return Err(Error::InvalidOperation); + } + + // Bind a listener socket + let listener = TcpListener::bind(local)?; + *inner = TcpSocketInner::Listener(listener); + + Ok(()) + } + + fn close(self: Arc) -> Result<(), Error> { + let inner = mem::replace(&mut *self.inner.write(), TcpSocketInner::Empty); + + match inner { + TcpSocketInner::Empty => Ok(()), + TcpSocketInner::Stream(socket) => block!(socket.close(true).await)?, + TcpSocketInner::Listener(socket) => socket.close(), + } + } + + fn local_address(&self) -> Option { + match &*self.inner.read() { + TcpSocketInner::Empty => None, + TcpSocketInner::Stream(socket) => Some(socket.local), + TcpSocketInner::Listener(socket) => Some(socket.local), + } + } + + fn remote_address(&self) -> Option { + match &*self.inner.read() { + TcpSocketInner::Empty | TcpSocketInner::Listener(_) => None, + TcpSocketInner::Stream(socket) => Some(socket.remote), + } + } + + fn set_option(&self, option: &SocketOption) -> Result<(), Error> { + match *option { + SocketOption::Ttl(ttl) => { + if !(1..256).contains(&ttl) { + return Err(Error::InvalidArgument); + } + let ttl = ttl as u8; + self.options.write().ttl = ttl; + Ok(()) + } + SocketOption::NoDelay(nodelay) => { + self.options.write().nodelay = nodelay; + Ok(()) + } + SocketOption::Ipv6Only(v6_only) => { + self.options.write().v6_only = v6_only; + Ok(()) + } + _ => Err(Error::InvalidOperation), + } + } + + fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { + match option { + SocketOption::Ttl(ttl) => { + *ttl = self.options.read().ttl as u32; + Ok(()) + } + SocketOption::NoDelay(nodelay) => { + *nodelay = self.options.read().nodelay; + Ok(()) + } + SocketOption::Ipv6Only(v6_only) => { + *v6_only = self.options.read().v6_only; + Ok(()) + } + _ => Err(Error::InvalidOperation), + } + } +} + +#[async_trait] +impl ConnectionSocket for TcpSocket { + async fn connect( + self: Arc, + remote: SocketAddr, + _timeout: Option, + ) -> Result<(), Error> { + let mut inner = self.inner.write(); + + // Already connected or bound + if !matches!(&*inner, TcpSocketInner::Empty) { + return Err(Error::InvalidOperation); + } + + let stream = TcpStream::connect(remote).await?; + *inner = TcpSocketInner::Stream(stream); + + Ok(()) + } + + async fn receive( + &self, + buffer: &mut [u8], + timeout: Option, + ) -> Result<(usize, SocketAddr), Error> { + let stream = self.as_stream().ok_or(Error::NotConnected)?; + let len = stream.receive(buffer, timeout).await?; + Ok((len, stream.remote)) + } + + fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + let stream = self.as_stream().ok_or(Error::NotConnected)?; + let len = stream.receive_nonblocking(buffer)?; + Ok((len, stream.remote)) + } + + async fn send(&self, data: &[u8], timeout: Option) -> Result { + let stream = self.as_stream().ok_or(Error::NotConnected)?; + let future = stream.send(data); + maybe_timeout(future, timeout).await? + } + + fn send_nonblocking(&self, buffer: &[u8]) -> Result { + let stream = self.as_stream().ok_or(Error::NotConnected)?; + let len = stream.send_nonblocking(buffer)?; + Ok(len) + } + + fn listen(self: Arc) -> Result<(), Error> { + let listener = self.as_listener().ok_or(Error::InvalidOperation)?; + match listener + .listening + .compare_exchange(false, true, Ordering::Release, Ordering::Relaxed) + { + Ok(_) => Ok(()), + Err(_) => Err(Error::InvalidOperation), + } + } + + async fn accept(&self) -> Result<(SocketAddr, Arc), Error> { + let listener = self.as_listener().ok_or(Error::InvalidOperation)?; + let stream = listener.accept().await?; + let remote = stream.remote; + let socket = Arc::new(TcpSocket { + options: IrqSafeRwLock::new(TcpSocketOptions::default()), + inner: IrqSafeRwLock::new(TcpSocketInner::Stream(stream)), + }); + Ok((remote, socket)) + } + + fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc), Error> { + let listener = self.as_listener().ok_or(Error::InvalidOperation)?; + let stream = listener.accept_nonblocking()?; + let remote = stream.remote; + let socket = Arc::new(TcpSocket { + options: IrqSafeRwLock::new(TcpSocketOptions::default()), + inner: IrqSafeRwLock::new(TcpSocketInner::Stream(stream)), + }); + Ok((remote, socket)) + } + + async fn shutdown(&self, read: bool, write: bool) -> Result<(), Error> { + log::warn!("TODO: shutdown(read = {read}, write = {write})"); + Ok(()) + } +} + +impl FileReadiness for TcpSocket { + fn poll_read(&self, cx: &mut Context<'_>) -> Poll> { + match &*self.inner.read() { + TcpSocketInner::Empty => Poll::Ready(Ok(())), + TcpSocketInner::Stream(socket) => socket.poll_receive(cx).map_ok(|_| ()), + TcpSocketInner::Listener(socket) => socket.poll_accept(cx).map(|_| Ok(())), + } + } +} + +impl fmt::Debug for TcpSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &*self.inner.read() { + TcpSocketInner::Empty => f.debug_struct("TcpSocket").finish_non_exhaustive(), + TcpSocketInner::Stream(socket) => f + .debug_struct("TcpSocket") + .field("stream", socket) + .finish_non_exhaustive(), + TcpSocketInner::Listener(socket) => f + .debug_struct("TcpSocket") + .field("listener", socket) + .finish_non_exhaustive(), + } + } +} diff --git a/kernel/driver/net/core/src/socket/tcp/stream.rs b/kernel/driver/net/core/src/socket/tcp/stream.rs new file mode 100644 index 00000000..c06e633d --- /dev/null +++ b/kernel/driver/net/core/src/socket/tcp/stream.rs @@ -0,0 +1,309 @@ +use core::{ + fmt, + future::poll_fn, + task::{Context, Poll}, + time::Duration, +}; + +use alloc::sync::Arc; +use libk::{ + error::Error, + task::runtime::{maybe_timeout, with_timeout}, + time::monotonic_time, +}; +use libk_util::sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard}; +use yggdrasil_abi::net::SocketAddr; + +use crate::{ + interface::NetworkInterface, + l3::Route, + l4::tcp::{TcpConnection, TcpConnectionState}, + socket::TwoWaySocketTable, +}; + +use super::TcpListener; + +pub struct TcpStream { + pub(super) local: SocketAddr, + pub(super) remote: SocketAddr, + + // Listener which accepted the socket + listener: Option>, + pub(crate) connection: IrqSafeRwLock, +} + +static TCP_STREAMS: IrqSafeRwLock> = + IrqSafeRwLock::new(TwoWaySocketTable::new()); + +impl TcpStream { + pub async fn connect(remote: SocketAddr) -> Result, Error> { + // Lookup route to remote + let (interface_id, _, remote_ip) = + Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?; + let remote = SocketAddr::new(remote_ip, remote.port()); + let interface = NetworkInterface::get(interface_id)?; + let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?; + + // Create a new TcpStream with an ephemeral port + let stream = { + let mut streams = TCP_STREAMS.write(); + streams.try_insert_with_ephemeral_port(local_ip, remote, |port| { + let t = monotonic_time(); + let tx_seq = t.as_millis() as u32; + let local = SocketAddr::new(local_ip, port); + let connection = + TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed); + + Ok(Arc::new(Self { + local, + remote, + listener: None, + connection: IrqSafeRwLock::new(connection), + })) + })? + }; + + let mut t = 200; + for _ in 0..5 { + let timeout = Duration::from_millis(t); + log::debug!("Try SYN with timeout={:?}", timeout); + match stream.try_connect(timeout).await { + Ok(()) => return Ok(stream), + Err(Error::TimedOut) => (), + Err(error) => return Err(error), + } + t *= 2; + } + + // Couldn't establish + stream.close(false).await.ok(); + + Err(Error::TimedOut) + } + + async fn try_connect(&self, timeout: Duration) -> Result<(), Error> { + { + let mut connection = self.connection.write(); + connection.send_syn().await?; + } + + let fut = poll_fn(|cx| { + let connection = self.connection.read(); + connection.poll_established(cx) + }); + + with_timeout(fut, timeout).await? + } + + pub fn accept(self: &Arc) { + if let Some(listener) = self.listener.as_ref() { + listener.accept_stream(self.clone()); + } + } + + pub fn accept_remote( + listener: Arc, + local: SocketAddr, + remote: SocketAddr, + remote_window_size: usize, + tx_seq: u32, + rx_seq: u32, + ) -> Result, Error> { + let mut streams = TCP_STREAMS.write(); + + streams.try_insert_with(local, remote, move || { + let connection = TcpConnection::new( + local, + remote, + remote_window_size, + tx_seq, + rx_seq, + TcpConnectionState::SynReceived, + ); + + log::debug!("Accepted TCP socket {} -> {}", local, remote); + + let socket = Self { + local, + remote, + listener: Some(listener), + connection: IrqSafeRwLock::new(connection), + }; + + Ok(Arc::new(socket)) + }) + } + + pub async fn close(&self, remove_from_listener: bool) -> Result<(), Error> { + // TODO timeout here + // Already closing + if self.connection.read().is_closing() { + return Ok(()); + } + + // Wait for all sent data to be acknowledged + { + let mut connection = poll_fn(|cx| { + let connection = self.connection.write(); + match connection.poll_send(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)), + Poll::Ready(Err(error)) => Poll::Ready(Err(error)), + Poll::Pending => Poll::Pending, + } + }) + .await?; + + connection.finish().await?; + } + + log::debug!( + "TCP socket closed (FinWait2/Closed): {} <-> {}", + self.local, + self.remote + ); + + // Wait for connection to get closed + poll_fn(|cx| { + let connection = self.connection.read(); + connection.poll_finish(cx) + }) + .await; + + if remove_from_listener { + if let Some(listener) = self.listener.as_ref() { + listener.remove_stream(self.remote); + }; + } + + Ok(()) + } + + pub(crate) fn remove_stream(&self) -> Result<(), Error> { + log::debug!( + "TCP socket closed and removed: {} <-> {}", + self.local, + self.remote + ); + let connection = self.connection.read(); + debug_assert!(connection.is_closed()); + TCP_STREAMS.write().remove(self.local, self.remote)?; + connection.notify_all(); + Ok(()) + } + + pub fn get(local: SocketAddr, remote: SocketAddr) -> Option> { + TCP_STREAMS.read().get(local, remote) + } + + pub(super) fn poll_receive( + &self, + cx: &mut Context<'_>, + ) -> Poll, Error>> { + let lock = self.connection.write(); + match lock.poll_receive(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)), + Poll::Ready(Err(error)) => Poll::Ready(Err(error)), + Poll::Pending => Poll::Pending, + } + } + + pub async fn receive( + &self, + buffer: &mut [u8], + timeout: Option, + ) -> Result { + let future = poll_fn(|cx| match self.poll_receive(cx) { + Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(buffer)), + Poll::Ready(Err(error)) => Poll::Ready(Err(error)), + Poll::Pending => Poll::Pending, + }); + + maybe_timeout(future, timeout).await.flatten() + } + + pub async fn send(&self, data: &[u8]) -> Result { + let mut pos = 0; + let mut rem = data.len(); + while rem != 0 { + // TODO check MTU + let amount = rem.min(512); + self.send_segment(&data[pos..pos + amount]).await?; + pos += amount; + rem -= amount; + } + Ok(pos) + } + + pub fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result { + if buffer.is_empty() { + return Ok(0); + } + let mut lock = self.connection.write(); + match lock.read_nonblocking(buffer) { + Ok(0) => Err(Error::WouldBlock), + res => res, + } + } + + pub fn send_nonblocking(&self, data: &[u8]) -> Result { + if data.is_empty() { + return Ok(0); + } + let mut pos = 0; + let mut rem = data.len(); + let mut sent = false; + while rem != 0 { + // TODO check MTU + let amount = rem.min(512); + match self.send_segment_nonblocking(&data[pos..pos + amount]) { + Ok(()) => sent = true, + Err(Error::WouldBlock) => break, + Err(error) => return Err(error), + } + pos += amount; + rem -= amount; + } + // No data sent and WouldBlock returned + if !sent { + return Err(Error::WouldBlock); + } + Ok(pos) + } + + async fn send_segment(&self, data: &[u8]) -> Result<(), Error> { + { + let mut connection = poll_fn(|cx| { + let connection = self.connection.write(); + match connection.poll_send(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)), + Poll::Ready(Err(error)) => Poll::Ready(Err(error)), + Poll::Pending => Poll::Pending, + } + }) + .await?; + + connection.transmit(data).await?; + } + + poll_fn(|cx| { + let connection = self.connection.read(); + connection.poll_acknowledge(cx) + }) + .await; + + Ok(()) + } + + fn send_segment_nonblocking(&self, _data: &[u8]) -> Result<(), Error> { + todo!() + } +} + +impl fmt::Debug for TcpStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpStream") + .field("local", &self.local) + .field("remote", &self.remote) + .finish_non_exhaustive() + } +} diff --git a/kernel/driver/net/core/src/socket/udp.rs b/kernel/driver/net/core/src/socket/udp.rs index 43427956..7cba6c07 100644 --- a/kernel/driver/net/core/src/socket/udp.rs +++ b/kernel/driver/net/core/src/socket/udp.rs @@ -2,6 +2,7 @@ use core::{ fmt, sync::atomic::{AtomicBool, AtomicU8, Ordering}, task::{Context, Poll}, + time::Duration, }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; @@ -9,18 +10,22 @@ use async_trait::async_trait; use libk::{ block, error::Error, + task::runtime::maybe_timeout, vfs::{FileReadiness, PacketSocket, Socket}, }; -use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock}; -use yggdrasil_abi::net::{SocketAddr, SocketOption}; +use libk_util::{ + queue::BoundedMpmcQueue, + sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard}, +}; +use yggdrasil_abi::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketOption}; use crate::l4; use super::SocketTable; pub struct UdpSocket { - local: SocketAddr, - remote: Option, + local: IrqSafeRwLock>, + remote: IrqSafeRwLock>, broadcast: AtomicBool, ttl: AtomicU8, @@ -32,32 +37,18 @@ pub struct UdpSocket { static UDP_SOCKETS: IrqSafeRwLock> = IrqSafeRwLock::new(SocketTable::new()); impl UdpSocket { - fn create_socket(local: SocketAddr) -> Arc { - log::debug!("UDP socket opened: {}", local); - Arc::new(UdpSocket { - local, - ttl: AtomicU8::new(64), - remote: None, + pub fn new() -> Arc { + Arc::new(Self { + local: IrqSafeRwLock::new(None), + remote: IrqSafeRwLock::new(None), + broadcast: AtomicBool::new(false), + ttl: AtomicU8::new(64), + receive_queue: BoundedMpmcQueue::new(128), }) } - pub fn bind(address: SocketAddr) -> Result, Error> { - let mut sockets = UDP_SOCKETS.write(); - if address.port() == 0 { - sockets.try_insert_with_ephemeral_port(address.ip(), |port| { - Ok(Self::create_socket(SocketAddr::new(address.ip(), port))) - }) - } else { - sockets.try_insert_with(address, move || Ok(Self::create_socket(address))) - } - } - - pub fn connect(&self, _address: SocketAddr) -> Result<(), Error> { - todo!() - } - pub fn get(local: &SocketAddr) -> Option> { UDP_SOCKETS.read().get(local) } @@ -67,6 +58,25 @@ impl UdpSocket { .try_push_back((source, Vec::from(data))) .map_err(|_| Error::QueueFull) } + + // If address is bound, keep it, if not, bind an ephemeral port + pub fn ensure_address(self: &Arc, v6: bool) -> Result { + let local = self.local.read(); + if let Some(address) = *local { + Ok(address.port()) + } else { + let mut local = IrqSafeRwLockReadGuard::upgrade(local); + let ip = match v6 { + true => IpAddr::V4(Ipv4Addr::UNSPECIFIED), + false => IpAddr::V6(Ipv6Addr::UNSPECIFIED), + }; + let port = UDP_SOCKETS + .write() + .bind_to_ephemeral_port(ip, self.clone())?; + *local = Some(SocketAddr::new(ip, port)); + Ok(port) + } + } } impl FileReadiness for UdpSocket { @@ -77,18 +87,35 @@ impl FileReadiness for UdpSocket { #[async_trait] impl PacketSocket for UdpSocket { - async fn send_to(&self, destination: Option, data: &[u8]) -> Result { - let Some(destination) = destination else { - // TODO can still send without setting address if "connected" - return Err(Error::InvalidArgument); - }; + fn connect(self: Arc, remote: SocketAddr) -> Result<(), Error> { + let mut connected = self.remote.write(); + if connected.is_some() { + return Err(Error::InvalidOperation); + } + *connected = Some(remote); + Ok(()) + } + + async fn send_to( + self: Arc, + destination: Option, + data: &[u8], + _timeout: Option, + ) -> Result { + let destination = destination + .or_else(|| self.remote_address()) + .ok_or(Error::NotConnected)?; + + // If socket wasn't bound yet, bind it to an ephemeral port + let port = self.ensure_address(destination.ip().is_ipv6())?; + // TODO check that destnation family matches self family match (self.broadcast.load(Ordering::Acquire), destination.ip()) { - // SendTo in broadcast? - (true, _) => todo!(), + // TODO broadcast + (true, _) => return Err(Error::NotImplemented), (false, _) => { l4::udp::send( - self.local.port(), + port, destination.ip(), destination.port(), self.ttl.load(Ordering::Acquire), @@ -102,16 +129,21 @@ impl PacketSocket for UdpSocket { } fn send_nonblocking( - &self, + self: Arc, destination: Option, buffer: &[u8], ) -> Result { log::warn!("TODO: UDP::send_nonblocking()"); - block!(self.send_to(destination, buffer).await)? + block!(self.send_to(destination, buffer, None).await)? } - async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - let (source, packet) = self.receive_queue.pop_front().await; + async fn receive_from( + self: Arc, + buffer: &mut [u8], + timeout: Option, + ) -> Result<(usize, SocketAddr), Error> { + let future = self.receive_queue.pop_front(); + let (source, packet) = maybe_timeout(future, timeout).await?; if packet.len() > buffer.len() { // TODO check how other OSs handle this @@ -121,7 +153,10 @@ impl PacketSocket for UdpSocket { Ok((packet.len(), source)) } - fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + fn receive_nonblocking( + self: Arc, + buffer: &mut [u8], + ) -> Result<(usize, SocketAddr), Error> { let (source, packet) = self .receive_queue .try_pop_front() @@ -137,23 +172,43 @@ impl PacketSocket for UdpSocket { } impl Socket for UdpSocket { - fn local_address(&self) -> SocketAddr { - self.local + fn bind(self: Arc, local: SocketAddr) -> Result<(), Error> { + let mut bound = self.local.write(); + if bound.is_some() { + return Err(Error::InvalidOperation); + } + if local.port() == 0 { + let port = UDP_SOCKETS + .write() + .bind_to_ephemeral_port(local.ip(), self.clone())?; + *bound = Some(SocketAddr::new(local.ip(), port)); + } else { + UDP_SOCKETS.write().try_insert(local, self.clone())?; + *bound = Some(local); + } + Ok(()) + } + + fn local_address(&self) -> Option { + *self.local.read() } fn remote_address(&self) -> Option { - self.remote + *self.remote.read() } - fn close(&self) -> Result<(), Error> { - log::debug!("UDP socket closed: {}", self.local); - UDP_SOCKETS.write().remove(self.local) + fn close(self: Arc) -> Result<(), Error> { + if let Some(local) = self.local.write().take() { + log::debug!("UDP socket closed: {}", local); + UDP_SOCKETS.write().remove(local)?; + } + Ok(()) } fn set_option(&self, option: &SocketOption) -> Result<(), Error> { match option { &SocketOption::Broadcast(broadcast) => { - log::debug!("{} broadcast: {}", self.local, broadcast); + // log::debug!("{} broadcast: {}", self.local, broadcast); self.broadcast.store(broadcast, Ordering::Release); Ok(()) } @@ -209,9 +264,12 @@ impl Socket for UdpSocket { impl fmt::Debug for UdpSocket { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let local = *self.local.read(); + let remote = *self.remote.read(); + f.debug_struct("UdpSocket") - .field("local", &self.local) - .field("remote", &self.remote) + .field("local", &local) + .field("remote", &remote) .finish_non_exhaustive() } } diff --git a/kernel/libk/libk-util/src/sync/spin_rwlock.rs b/kernel/libk/libk-util/src/sync/spin_rwlock.rs index 23a84c68..e66cabf1 100644 --- a/kernel/libk/libk-util/src/sync/spin_rwlock.rs +++ b/kernel/libk/libk-util/src/sync/spin_rwlock.rs @@ -44,6 +44,15 @@ impl RwLockInner { self.value.fetch_nand(Self::LOCKED_WRITE, Ordering::Release); } + #[inline] + fn upgrade(&self) { + // At least one read lock is held by this task. + // When there's *exactly* one lock (being this task) held, upgrade is possible + while !self.try_upgrade() { + core::hint::spin_loop(); + } + } + #[inline] fn acquire_read_raw(&self) -> usize { let value = self.value.fetch_add(Self::LOCKED_READ, Ordering::Acquire); @@ -77,6 +86,18 @@ impl RwLockInner { .is_ok() } + #[inline] + fn try_upgrade(&self) -> bool { + self.value + .compare_exchange( + Self::LOCKED_READ, + Self::LOCKED_WRITE, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + } + #[inline] fn acquire_read(&self) { while !self.try_acquire_read() { @@ -133,6 +154,11 @@ impl IrqSafeRwLock { self.inner.downgrade_write(); } + #[inline] + unsafe fn upgrade(&self) { + self.inner.upgrade(); + } + unsafe fn release_read(&self) { self.inner.release_read(); } @@ -159,10 +185,26 @@ impl Drop for IrqSafeRwLockReadGuard<'_, T> { } } -impl IrqSafeRwLockReadGuard<'_, T> { +impl<'a, T> IrqSafeRwLockReadGuard<'a, T> { pub fn get(guard: &Self) -> *const T { guard.lock.value.get() } + + pub fn upgrade(guard: IrqSafeRwLockReadGuard<'a, T>) -> IrqSafeRwLockWriteGuard<'a, T> { + let lock = guard.lock; + let irq_guard = IrqGuard::acquire(); + // Read lock still held + core::mem::forget(guard); + + unsafe { + lock.upgrade(); + } + + IrqSafeRwLockWriteGuard { + lock, + _guard: irq_guard, + } + } } impl<'a, T> IrqSafeRwLockWriteGuard<'a, T> { diff --git a/kernel/libk/src/vfs/ioctx.rs b/kernel/libk/src/vfs/ioctx.rs index 77282364..9dc5c7b2 100644 --- a/kernel/libk/src/vfs/ioctx.rs +++ b/kernel/libk/src/vfs/ioctx.rs @@ -372,10 +372,8 @@ impl IoContext { if !flags.contains_any(RemoveFlags::DIRECTORY | RemoveFlags::DIRECTORY_ONLY) { return Err(Error::IsADirectory); } - } else { - if flags.contains_any(RemoveFlags::DIRECTORY_ONLY) { - return Err(Error::NotADirectory); - } + } else if flags.contains_any(RemoveFlags::DIRECTORY_ONLY) { + return Err(Error::NotADirectory); } parent.remove_file(filename, access) diff --git a/kernel/libk/src/vfs/mod.rs b/kernel/libk/src/vfs/mod.rs index 3099128b..732d7f99 100644 --- a/kernel/libk/src/vfs/mod.rs +++ b/kernel/libk/src/vfs/mod.rs @@ -35,7 +35,7 @@ pub use path::{Filename, OwnedFilename}; pub use poll::FdPoll; pub use pty::{PseudoTerminalMaster, PseudoTerminalSlave}; pub use shared_memory::SharedMemory; -pub use socket::{ConnectionSocket, ListenerSocket, PacketSocket, Socket, SocketWrapper}; +pub use socket::{ConnectionSocket, PacketSocket, Socket, SocketWrapper}; pub use terminal::{Terminal, TerminalInput, TerminalOutput}; pub use timer::TimerFile; pub use traits::{FileReadiness, Read, Seek, Write}; diff --git a/kernel/libk/src/vfs/socket.rs b/kernel/libk/src/vfs/socket.rs index afe8d6aa..dcb42164 100644 --- a/kernel/libk/src/vfs/socket.rs +++ b/kernel/libk/src/vfs/socket.rs @@ -9,18 +9,21 @@ use async_trait::async_trait; use libk_util::sync::spin_rwlock::IrqSafeRwLock; use yggdrasil_abi::{ error::Error, - net::{SocketAddr, SocketOption}, + net::{SocketAddr, SocketOption, SocketShutdown}, }; -use crate::{task::runtime::maybe_timeout, vfs::FileReadiness}; +use crate::vfs::FileReadiness; + +use super::{File, FileRef}; enum SocketInner { - Connection(Arc), - Listener(Arc), - Packet(Arc), + Connection(Arc), + // Listener(Arc), + Packet(Arc), } struct InnerOptions { + connect_timeout: Option, recv_timeout: Option, send_timeout: Option, non_blocking: bool, @@ -34,14 +37,16 @@ pub struct SocketWrapper { /// Interface for interacting with network sockets #[allow(unused)] pub trait Socket: FileReadiness + fmt::Debug + Send { + fn bind(self: Arc, local: SocketAddr) -> Result<(), Error>; + /// Socket listen/receive address - fn local_address(&self) -> SocketAddr; + fn local_address(&self) -> Option; /// Socket remote address fn remote_address(&self) -> Option; /// Closes a socket - fn close(&self) -> Result<(), Error>; + fn close(self: Arc) -> Result<(), Error>; /// Updates a socket option fn set_option(&self, option: &SocketOption) -> Result<(), Error> { @@ -59,40 +64,62 @@ pub trait Socket: FileReadiness + fmt::Debug + Send { pub trait PacketSocket: Socket { /// Receives a packet into provided buffer. Will return an error if packet cannot be placed /// within the buffer. - async fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>; + async fn receive_from( + self: Arc, + buffer: &mut [u8], + timeout: Option, + ) -> Result<(usize, SocketAddr), Error>; - fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>; + fn receive_nonblocking( + self: Arc, + buffer: &mut [u8], + ) -> Result<(usize, SocketAddr), Error>; /// Sends provided data to the recepient specified by `destination` - async fn send_to(&self, destination: Option, data: &[u8]) -> Result; + async fn send_to( + self: Arc, + destination: Option, + data: &[u8], + timeout: Option, + ) -> Result; fn send_nonblocking( - &self, + self: Arc, destination: Option, buffer: &[u8], ) -> Result; + + fn connect(self: Arc, remote: SocketAddr) -> Result<(), Error>; } /// Connection-based client socket interface #[async_trait] pub trait ConnectionSocket: Socket { - /// Receives data into provided buffer - async fn receive(&self, buffer: &mut [u8]) -> Result; + async fn connect( + self: Arc, + remote: SocketAddr, + timeout: Option, + ) -> Result<(), Error>; - fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result; + async fn receive( + &self, + buffer: &mut [u8], + timeout: Option, + ) -> Result<(usize, SocketAddr), Error>; - /// Transmits data - async fn send(&self, data: &[u8]) -> Result; + fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error>; + + async fn send(&self, data: &[u8], timeout: Option) -> Result; fn send_nonblocking(&self, buffer: &[u8]) -> Result; -} -/// Connection-based listener socket interface -#[async_trait] -pub trait ListenerSocket: Socket { - /// Blocks the execution until an incoming connection is accepted + + fn listen(self: Arc) -> Result<(), Error>; + async fn accept(&self) -> Result<(SocketAddr, Arc), Error>; fn accept_nonblocking(&self) -> Result<(SocketAddr, Arc), Error>; + + async fn shutdown(&self, read: bool, write: bool) -> Result<(), Error>; } impl SocketWrapper { @@ -110,103 +137,108 @@ impl SocketWrapper { } } - pub fn from_listener(socket: Arc) -> Self { - Self { - inner: SocketInner::Listener(socket), - options: IrqSafeRwLock::new(InnerOptions::default()), + async fn send_inner(&self, data: &[u8], remote: Option) -> Result { + let timeout = self.options.read().send_timeout; + + match &self.inner { + SocketInner::Packet(socket) => socket.clone().send_to(remote, data, timeout).await, + SocketInner::Connection(socket) => socket.send(data, timeout).await, } } - pub fn accept(&self) -> Result<(SocketWrapper, SocketAddr), Error> { - let SocketInner::Listener(socket) = &self.inner else { + async fn receive_inner(&self, data: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + let timeout = self.options.read().recv_timeout; + + match &self.inner { + SocketInner::Packet(socket) => socket.clone().receive_from(data, timeout).await, + SocketInner::Connection(socket) => socket.clone().receive(data, timeout).await, + } + } + + async fn connect_inner(&self, remote: SocketAddr) -> Result<(), Error> { + let timeout = self.options.read().connect_timeout; + + match &self.inner { + SocketInner::Packet(socket) => socket.clone().connect(remote), + SocketInner::Connection(socket) => socket.clone().connect(remote, timeout).await, + } + } + + pub fn connect(&self, remote: SocketAddr) -> Result<(), Error> { + block!(self.connect_inner(remote).await)? + } + + pub fn send_to(&self, data: &[u8], remote: Option) -> Result { + if self.options.read().non_blocking { + match &self.inner { + SocketInner::Packet(socket) => socket.clone().send_nonblocking(remote, data), + SocketInner::Connection(socket) => socket.clone().send_nonblocking(data), + } + } else { + block!(self.send_inner(data, remote).await)? + } + } + + pub fn receive_from(&self, data: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + if self.options.read().non_blocking { + match &self.inner { + SocketInner::Packet(socket) => socket.clone().receive_nonblocking(data), + SocketInner::Connection(socket) => socket.receive_nonblocking(data), + } + } else { + block!(self.receive_inner(data).await)? + } + } + + pub fn bind(&self, local: SocketAddr) -> Result<(), Error> { + match &self.inner { + SocketInner::Packet(socket) => socket.clone().bind(local), + SocketInner::Connection(socket) => socket.clone().bind(local), + } + } + + pub fn listen(&self) -> Result<(), Error> { + match &self.inner { + SocketInner::Packet(_) => Err(Error::InvalidOperation), + SocketInner::Connection(socket) => socket.clone().listen(), + } + } + + pub fn accept(&self) -> Result<(FileRef, SocketAddr), Error> { + let SocketInner::Connection(socket) = &self.inner else { return Err(Error::InvalidOperation); }; - let options = self.options.read(); - let (remote, remote_socket) = match (options.non_blocking, options.recv_timeout) { - (false, timeout) => { - let fut = socket.accept(); - block!(maybe_timeout(fut, timeout).await)??? - } - (true, _) => socket.accept_nonblocking()?, + let (remote, stream) = if self.options.read().non_blocking { + socket.accept_nonblocking()? + } else { + block!(socket.accept().await)?? }; - let remote_socket = Self::from_connection(remote_socket); - Ok((remote_socket, remote)) + let file = File::from_socket(SocketWrapper::from_connection(stream)); + Ok((file, remote)) } - fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - match &self.inner { - SocketInner::Packet(socket) => socket.receive_nonblocking(buffer), - SocketInner::Connection(socket) => { - let remote = socket.remote_address().ok_or(Error::NotConnected)?; - let len = socket.receive_nonblocking(buffer)?; - Ok((len, remote)) - } - SocketInner::Listener(_) => Err(Error::InvalidOperation), - } + pub fn shutdown(&self, how: SocketShutdown) -> Result<(), Error> { + let SocketInner::Connection(socket) = &self.inner else { + return Err(Error::InvalidOperation); + }; + + block!( + socket + .shutdown( + how.contains(SocketShutdown::READ), + how.contains(SocketShutdown::WRITE), + ) + .await + )? } - async fn receive( - &self, - buffer: &mut [u8], - timeout: Option, - ) -> Result<(usize, SocketAddr), Error> { - match &self.inner { - SocketInner::Packet(socket) => { - maybe_timeout(socket.receive_from(buffer), timeout).await? - } - SocketInner::Connection(socket) => { - let remote = socket.remote_address().ok_or(Error::NotConnected)?; - let len = maybe_timeout(socket.receive(buffer), timeout).await??; - Ok((len, remote)) - } - SocketInner::Listener(_) => Err(Error::InvalidOperation), - } - } - - fn send_nonblocking(&self, buffer: &[u8], remote: Option) -> Result { - match &self.inner { - SocketInner::Packet(socket) => socket.send_nonblocking(remote, buffer), - SocketInner::Connection(socket) => socket.send_nonblocking(buffer), - SocketInner::Listener(_) => Err(Error::InvalidOperation), - } - } - - async fn send( - &self, - buffer: &[u8], - remote: Option, - timeout: Option, - ) -> Result { - match &self.inner { - SocketInner::Packet(socket) => { - maybe_timeout(socket.send_to(remote, buffer), timeout).await? - } - SocketInner::Connection(socket) => maybe_timeout(socket.send(buffer), timeout).await?, - SocketInner::Listener(_) => Err(Error::InvalidOperation), - } - } - - pub fn receive_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - let options = self.options.read(); - match (options.non_blocking, options.recv_timeout) { - (false, timeout) => block!(self.receive(buffer, timeout).await)?, - (true, _) => self.receive_nonblocking(buffer), - } - } - - pub fn send_to(&self, buffer: &[u8], remote: Option) -> Result { - let options = self.options.read(); - match (options.non_blocking, options.recv_timeout) { - (false, timeout) => block!(self.send(buffer, remote, timeout).await)?, - (true, _) => self.send_nonblocking(buffer, remote), - } - } -} - -impl Socket for SocketWrapper { - fn set_option(&self, option: &SocketOption) -> Result<(), Error> { + pub fn set_option(&self, option: &SocketOption) -> Result<(), Error> { match option { + SocketOption::NonBlocking(nb) => { + self.options.write().non_blocking = *nb; + return Ok(()); + } SocketOption::RecvTimeout(timeout) => { self.options.write().recv_timeout = *timeout; return Ok(()); @@ -215,22 +247,25 @@ impl Socket for SocketWrapper { self.options.write().send_timeout = *timeout; return Ok(()); } - SocketOption::NonBlocking(nb) => { - self.options.write().non_blocking = *nb; + SocketOption::ConnectTimeout(timeout) => { + self.options.write().connect_timeout = *timeout; return Ok(()); } _ => (), } match &self.inner { - SocketInner::Packet(socket) => socket.set_option(option), - SocketInner::Listener(socket) => socket.set_option(option), SocketInner::Connection(socket) => socket.set_option(option), + SocketInner::Packet(socket) => socket.set_option(option), } } - fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { + pub fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { match option { + SocketOption::NonBlocking(nb) => { + *nb = self.options.read().non_blocking; + return Ok(()); + } SocketOption::RecvTimeout(timeout) => { *timeout = self.options.read().recv_timeout; return Ok(()); @@ -239,41 +274,16 @@ impl Socket for SocketWrapper { *timeout = self.options.read().send_timeout; return Ok(()); } - SocketOption::NonBlocking(nb) => { - *nb = self.options.read().non_blocking; + SocketOption::ConnectTimeout(timeout) => { + *timeout = self.options.read().connect_timeout; return Ok(()); } _ => (), } match &self.inner { - SocketInner::Packet(socket) => socket.get_option(option), - SocketInner::Listener(socket) => socket.get_option(option), SocketInner::Connection(socket) => socket.get_option(option), - } - } - - fn local_address(&self) -> SocketAddr { - match &self.inner { - SocketInner::Packet(socket) => socket.local_address(), - SocketInner::Listener(socket) => socket.local_address(), - SocketInner::Connection(socket) => socket.local_address(), - } - } - - fn remote_address(&self) -> Option { - match &self.inner { - SocketInner::Packet(socket) => socket.remote_address(), - SocketInner::Listener(socket) => socket.remote_address(), - SocketInner::Connection(socket) => socket.remote_address(), - } - } - - fn close(&self) -> Result<(), Error> { - match &self.inner { - SocketInner::Packet(socket) => socket.close(), - SocketInner::Listener(socket) => socket.close(), - SocketInner::Connection(socket) => socket.close(), + SocketInner::Packet(socket) => socket.get_option(option), } } } @@ -281,9 +291,8 @@ impl Socket for SocketWrapper { impl FileReadiness for SocketWrapper { fn poll_read(&self, cx: &mut Context<'_>) -> Poll> { match &self.inner { - SocketInner::Packet(socket) => socket.poll_read(cx), - SocketInner::Listener(socket) => socket.poll_read(cx), SocketInner::Connection(socket) => socket.poll_read(cx), + SocketInner::Packet(socket) => socket.poll_read(cx), } } } @@ -293,7 +302,6 @@ impl fmt::Debug for SocketWrapper { match &self.inner { SocketInner::Packet(socket) => socket.fmt(f), SocketInner::Connection(socket) => socket.fmt(f), - SocketInner::Listener(socket) => socket.fmt(f), } } } @@ -301,9 +309,8 @@ impl fmt::Debug for SocketWrapper { impl Drop for SocketInner { fn drop(&mut self) { let res = match self { - Self::Packet(socket) => socket.close(), - Self::Connection(socket) => socket.close(), - Self::Listener(socket) => socket.close(), + Self::Packet(socket) => socket.clone().close(), + Self::Connection(socket) => socket.clone().close(), }; if let Err(error) = res { log::warn!("Socket close error: {error:?}"); @@ -314,69 +321,10 @@ impl Drop for SocketInner { impl Default for InnerOptions { fn default() -> Self { Self { + connect_timeout: None, recv_timeout: None, send_timeout: None, non_blocking: false, } } } - -// impl From> for SocketWrapper { -// fn from(value: Arc) -> Self { -// Self::Connection(value) -// } -// } -// -// impl From> for SocketWrapper { -// fn from(value: Arc) -> Self { -// Self::Listener(value) -// } -// } -// -// impl From> for SocketWrapper { -// fn from(value: Arc) -> Self { -// Self::Packet(value) -// } -// } - -// impl Deref for PacketSocketWrapper { -// type Target = dyn PacketSocket; -// -// fn deref(&self) -> &Self::Target { -// self.0.as_ref() -// } -// } -// -// impl Drop for PacketSocketWrapper { -// fn drop(&mut self) { -// self.0.close().ok(); -// } -// } -// -// impl Deref for ListenerSocketWrapper { -// type Target = dyn ListenerSocket; -// -// fn deref(&self) -> &Self::Target { -// self.0.as_ref() -// } -// } -// -// impl Drop for ListenerSocketWrapper { -// fn drop(&mut self) { -// self.0.close().ok(); -// } -// } -// -// impl Deref for ConnectionSocketWrapper { -// type Target = dyn ConnectionSocket; -// -// fn deref(&self) -> &Self::Target { -// self.0.as_ref() -// } -// } -// -// impl Drop for ConnectionSocketWrapper { -// fn drop(&mut self) { -// self.0.close().ok(); -// } -// } diff --git a/kernel/src/syscall/imp/mod.rs b/kernel/src/syscall/imp/mod.rs index c2a3b4ca..a1ec0483 100644 --- a/kernel/src/syscall/imp/mod.rs +++ b/kernel/src/syscall/imp/mod.rs @@ -8,7 +8,7 @@ pub(crate) use abi::{ UnmountOptions, }, mem::{MappingFlags, MappingSource}, - net::SocketType, + net::{SocketShutdown, SocketType}, process::{Signal, SignalEntryData, SpawnOptions, WaitFlags}, system::SystemInfo, }; diff --git a/kernel/src/syscall/imp/sys_net.rs b/kernel/src/syscall/imp/sys_net.rs index 5e03dc17..e6a54b0a 100644 --- a/kernel/src/syscall/imp/sys_net.rs +++ b/kernel/src/syscall/imp/sys_net.rs @@ -3,127 +3,218 @@ use core::{mem::MaybeUninit, net::SocketAddr}; use abi::{ error::Error, io::RawFd, - net::{SocketConnect, SocketOption, SocketType}, + net::{SocketOption, SocketShutdown, SocketType}, }; use libk::{ task::thread::Thread, - vfs::{File, Socket, SocketWrapper}, + vfs::{File, FileRef, SocketWrapper}, }; -use ygg_driver_net_core::socket::{RawSocket, TcpListener, TcpSocket, UdpSocket}; +use ygg_driver_net_core::socket::{RawSocket, TcpSocket, UdpSocket}; use crate::syscall::run_with_io; -// Network -pub(crate) fn connect_socket( - connect: &mut SocketConnect, - local_result: &mut MaybeUninit, -) -> Result { +fn get_socket(fd: RawFd) -> Result { let thread = Thread::current(); let process = thread.process(); - run_with_io(&process, |mut io| { - let (local, fd) = match connect { - SocketConnect::Tcp(remote, timeout) => { - let remote = (*remote).into(); - let (local, socket) = libk::block!(TcpSocket::connect(remote, *timeout).await)??; - let file = File::from_socket(SocketWrapper::from_connection(socket)); - let fd = io.files.place_file(file, true)?; - (local.into(), fd) - } - SocketConnect::Udp(_socket, _remote) => { - todo!("UDP socket connect") - } - }; - local_result.write(local); - Ok(fd) - }) + run_with_io(&process, |io| io.files.file(fd).cloned()) } -pub(crate) fn bind_socket(listen: &SocketAddr, ty: SocketType) -> Result { +pub(crate) fn create_socket(ty: SocketType) -> Result { let thread = Thread::current(); let process = thread.process(); run_with_io(&process, |mut io| { - let listen = (*listen).into(); let socket = match ty { - SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::bind()?), - SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::bind(listen)?), - SocketType::TcpStream => SocketWrapper::from_listener(TcpListener::bind(listen)?), + SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::new()), + SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::new()), + SocketType::TcpStream => SocketWrapper::from_connection(TcpSocket::new()), }; let file = File::from_socket(socket); let fd = io.files.place_file(file, true)?; + Ok(fd) }) } -pub(crate) fn accept( - socket_fd: RawFd, - remote_result: &mut MaybeUninit, -) -> Result { +pub(crate) fn bind(sock_fd: RawFd, local: &SocketAddr) -> Result<(), Error> { + let file = get_socket(sock_fd)?; + file.as_socket()?.bind((*local).into()) +} + +pub(crate) fn listen(sock_fd: RawFd) -> Result<(), Error> { + let file = get_socket(sock_fd)?; + file.as_socket()?.listen() +} + +pub(crate) fn connect(sock_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> { + let file = get_socket(sock_fd)?; + file.as_socket()?.connect((*remote).into()) +} + +pub(crate) fn accept(sock_fd: RawFd, remote: &mut MaybeUninit) -> Result { let thread = Thread::current(); let process = thread.process(); run_with_io(&process, |mut io| { - let file = io.files.file(socket_fd)?; - let listener = file.as_socket()?; - let (socket, remote) = listener.accept()?; - remote_result.write(remote.into()); - let fd = io.files.place_file(File::from_socket(socket), true)?; - Ok(fd) + let listener = io.files.file(sock_fd)?; + let listener = listener.as_socket()?; + + let (stream_file, stream_remote) = listener.accept()?; + let stream_fd = io.files.place_file(stream_file, true)?; + + remote.write(stream_remote.into()); + + Ok(stream_fd) }) } +pub(crate) fn shutdown(sock_fd: RawFd, how: SocketShutdown) -> Result<(), Error> { + let file = get_socket(sock_fd)?; + file.as_socket()?.shutdown(how) +} + pub(crate) fn send_to( - socket_fd: RawFd, - buffer: &[u8], - recepient: &Option, + sock_fd: RawFd, + data: &[u8], + remote: &Option, ) -> Result { - let thread = Thread::current(); - let process = thread.process(); - - run_with_io(&process, |io| { - let file = io.files.file(socket_fd)?; - let socket = file.as_socket()?; - let remote = recepient.map(Into::into); - socket.send_to(buffer, remote) - }) + let file = get_socket(sock_fd)?; + file.as_socket()?.send_to(data, remote.map(Into::into)) } pub(crate) fn receive_from( - socket_fd: RawFd, - buffer: &mut [u8], - remote_result: &mut MaybeUninit, + sock_fd: RawFd, + data: &mut [u8], + remote: &mut MaybeUninit, ) -> Result { - let thread = Thread::current(); - let process = thread.process(); - - run_with_io(&process, |io| { - let file = io.files.file(socket_fd)?; - let socket = file.as_socket()?; - let (len, remote) = socket.receive_from(buffer)?; - remote_result.write(remote.into()); - Ok(len) - }) + let file = get_socket(sock_fd)?; + let (len, remote_) = file.as_socket()?.receive_from(data)?; + remote.write(remote_.into()); + Ok(len) } -pub(crate) fn set_socket_option(socket_fd: RawFd, option: &SocketOption) -> Result<(), Error> { - let thread = Thread::current(); - let process = thread.process(); - - run_with_io(&process, |io| { - let file = io.files.file(socket_fd)?; - let socket = file.as_socket()?; - socket.set_option(option) - }) +pub(crate) fn get_socket_option(sock_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> { + let file = get_socket(sock_fd)?; + file.as_socket()?.get_option(option) } -pub(crate) fn get_socket_option(socket_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> { - let thread = Thread::current(); - let process = thread.process(); - - run_with_io(&process, |io| { - let file = io.files.file(socket_fd)?; - let socket = file.as_socket()?; - socket.get_option(option) - }) +pub(crate) fn set_socket_option(sock_fd: RawFd, option: &SocketOption) -> Result<(), Error> { + let file = get_socket(sock_fd)?; + file.as_socket()?.set_option(option) } + +// // Network +// pub(crate) fn connect_socket( +// connect: &mut SocketConnect, +// local_result: &mut MaybeUninit, +// ) -> Result { +// let thread = Thread::current(); +// let process = thread.process(); +// +// run_with_io(&process, |mut io| { +// let (local, fd) = match connect { +// SocketConnect::Tcp(remote, timeout) => { +// let remote = (*remote).into(); +// let (local, socket) = libk::block!(TcpSocket::connect(remote, *timeout).await)??; +// let file = File::from_socket(SocketWrapper::from_connection(socket)); +// let fd = io.files.place_file(file, true)?; +// (local.into(), fd) +// } +// SocketConnect::Udp(_socket, _remote) => { +// todo!("UDP socket connect") +// } +// }; +// local_result.write(local); +// Ok(fd) +// }) +// } +// +// pub(crate) fn bind_socket(listen: &SocketAddr, ty: SocketType) -> Result { +// let thread = Thread::current(); +// let process = thread.process(); +// +// run_with_io(&process, |mut io| { +// let listen = (*listen).into(); +// let socket = match ty { +// SocketType::RawPacket => SocketWrapper::from_packet(RawSocket::bind()?), +// SocketType::UdpPacket => SocketWrapper::from_packet(UdpSocket::bind(listen)?), +// SocketType::TcpStream => SocketWrapper::from_listener(TcpListener::bind(listen)?), +// }; +// let file = File::from_socket(socket); +// let fd = io.files.place_file(file, true)?; +// Ok(fd) +// }) +// } +// +// pub(crate) fn accept( +// socket_fd: RawFd, +// remote_result: &mut MaybeUninit, +// ) -> Result { +// let thread = Thread::current(); +// let process = thread.process(); +// +// run_with_io(&process, |mut io| { +// let file = io.files.file(socket_fd)?; +// let listener = file.as_socket()?; +// let (socket, remote) = listener.accept()?; +// remote_result.write(remote.into()); +// let fd = io.files.place_file(File::from_socket(socket), true)?; +// Ok(fd) +// }) +// } +// +// pub(crate) fn send_to( +// socket_fd: RawFd, +// buffer: &[u8], +// recepient: &Option, +// ) -> Result { +// let thread = Thread::current(); +// let process = thread.process(); +// +// run_with_io(&process, |io| { +// let file = io.files.file(socket_fd)?; +// let socket = file.as_socket()?; +// let remote = recepient.map(Into::into); +// socket.send_to(buffer, remote) +// }) +// } +// +// pub(crate) fn receive_from( +// socket_fd: RawFd, +// buffer: &mut [u8], +// remote_result: &mut MaybeUninit, +// ) -> Result { +// let thread = Thread::current(); +// let process = thread.process(); +// +// run_with_io(&process, |io| { +// let file = io.files.file(socket_fd)?; +// let socket = file.as_socket()?; +// let (len, remote) = socket.receive_from(buffer)?; +// remote_result.write(remote.into()); +// Ok(len) +// }) +// } +// +// pub(crate) fn set_socket_option(socket_fd: RawFd, option: &SocketOption) -> Result<(), Error> { +// let thread = Thread::current(); +// let process = thread.process(); +// +// run_with_io(&process, |io| { +// let file = io.files.file(socket_fd)?; +// let socket = file.as_socket()?; +// socket.set_option(option) +// }) +// } +// +// pub(crate) fn get_socket_option(socket_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> { +// let thread = Thread::current(); +// let process = thread.process(); +// +// run_with_io(&process, |io| { +// let file = io.files.file(socket_fd)?; +// let socket = file.as_socket()?; +// socket.get_option(option) +// }) +// } diff --git a/lib/abi/def/yggdrasil.abi b/lib/abi/def/yggdrasil.abi index 5e4ddda2..3c7253f0 100644 --- a/lib/abi/def/yggdrasil.abi +++ b/lib/abi/def/yggdrasil.abi @@ -53,6 +53,13 @@ enum SocketType(u32) { UdpPacket = 2, } +bitfield SocketShutdown(u32) { + /// Stop reception on a socket + READ: 0, + /// Stop transmission on a socket + WRITE: 1, +} + // abi::mem bitfield MappingFlags(u32) { @@ -161,13 +168,15 @@ syscall receive_message( ) -> Result<()>; // Network -syscall connect_socket(connect: &mut SocketConnect, local: &mut MaybeUninit) -> Result; -syscall bind_socket(address: &SocketAddr, ty: SocketType) -> Result; -syscall accept(sock_fd: RawFd, remote: &mut MaybeUninit) -> Result; +syscall create_socket(ty: SocketType) -> Result; +syscall bind(sock_fd: RawFd, local: &SocketAddr) -> Result<()>; +syscall listen(sock_fd: RawFd) -> Result<()>; +syscall connect(sock_fd: RawFd, remote: &SocketAddr) -> Result<()>; +syscall accept(sock_fd: RawFd, remote: &mut MaybeUninit) -> Result; +syscall shutdown(sock_fd: RawFd, how: SocketShutdown) -> Result<()>; syscall send_to(sock_fd: RawFd, data: &[u8], remote: &Option) -> Result; syscall receive_from(sock_fd: RawFd, data: &mut [u8], remote: &mut MaybeUninit) -> Result; - syscall get_socket_option(sock_fd: RawFd, option: &mut SocketOption<'_>) -> Result<()>; syscall set_socket_option(sock_fd: RawFd, option: &SocketOption<'_>) -> Result<()>; diff --git a/lib/abi/src/lib.rs b/lib/abi/src/lib.rs index ce880a96..99bc24a5 100644 --- a/lib/abi/src/lib.rs +++ b/lib/abi/src/lib.rs @@ -4,6 +4,7 @@ clippy::new_without_default, clippy::should_implement_trait, clippy::module_inception, + clippy::missing_transmute_annotations, incomplete_features, stable_features )] diff --git a/lib/abi/src/net/mod.rs b/lib/abi/src/net/mod.rs index 2e72d77b..f9e1a8cb 100644 --- a/lib/abi/src/net/mod.rs +++ b/lib/abi/src/net/mod.rs @@ -10,8 +10,7 @@ pub mod types; use core::time::Duration; -pub use crate::generated::SocketType; -use crate::io::RawFd; +pub use crate::generated::{SocketShutdown, SocketType}; pub use types::{ ip_addr::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -20,16 +19,6 @@ pub use types::{ MacAddress, }; -/// Describes a socket connect operation -#[derive(Clone, Debug)] -pub enum SocketConnect { - /// Connect a TCP socket with optional timeout. - Tcp(core::net::SocketAddr, Option), - /// "Connect" an UDP socket, this just sets the sender's address in the socket so - /// the caller can then use send(). - Udp(RawFd, core::net::SocketAddr), -} - /// Describes a method to query an interface #[derive(Clone, Debug)] pub enum SocketInterfaceQuery<'a> { @@ -52,6 +41,8 @@ pub enum SocketOption<'a> { UnbindInterface, /// (Read-only) Hardware address of the bound interface BoundHardwareAddress(MacAddress), + /// (Read-only) Local socket address + LocalAddress(Option), /// (Read-only) Remote socket address PeerAddress(Option), /// If set, reception will return [crate::error::Error::WouldBlock] if the socket has @@ -63,6 +54,8 @@ pub enum SocketOption<'a> { RecvTimeout(Option), /// If not [None], send operations will have a time limit set before returning an error. SendTimeout(Option), + /// If not [None], connect() call will timeout after the specified time limit. + ConnectTimeout(Option), /// (UDP) If set, allows multicast packets to be looped back to local host. MulticastLoopV4(bool), /// (UDP) If set, allows multicast packets to be looped back to local host. diff --git a/lib/abi/src/net/types/ip_addr.rs b/lib/abi/src/net/types/ip_addr.rs index 286740bf..af42ab59 100644 --- a/lib/abi/src/net/types/ip_addr.rs +++ b/lib/abi/src/net/types/ip_addr.rs @@ -85,6 +85,30 @@ impl From for core::net::Ipv4Addr { // IPv6 +impl Ipv6Addr { + /// An IPv6 unspecified address `::`. + pub const UNSPECIFIED: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 0); + + /// Constructs a new IPv6 address from its words. + /// + /// The result represents the IP address `a:b:c:d:e:f:g:h`. + #[allow(clippy::too_many_arguments)] + pub const fn new(a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> Self { + let addr16 = [ + a.to_be(), + b.to_be(), + c.to_be(), + d.to_be(), + e.to_be(), + f.to_be(), + g.to_be(), + h.to_be(), + ]; + // SAFETY: `[u16; 8]` is safe to transmute to `[u8; 16]` + Self(unsafe { core::mem::transmute::<_, [u8; 16]>(addr16) }) + } +} + impl FromStr for Ipv6Addr { type Err = Error; diff --git a/lib/abi/src/net/types/socket_addr.rs b/lib/abi/src/net/types/socket_addr.rs index abee3d44..d2e85630 100644 --- a/lib/abi/src/net/types/socket_addr.rs +++ b/lib/abi/src/net/types/socket_addr.rs @@ -175,3 +175,15 @@ impl From for SocketAddr { } } } + +impl From for SocketAddr { + fn from(value: SocketAddrV4) -> Self { + Self::V4(value) + } +} + +impl From for SocketAddr { + fn from(value: SocketAddrV6) -> Self { + Self::V6(value) + } +} diff --git a/lib/runtime/src/net.rs b/lib/runtime/src/net.rs index f56490c1..0bd4609d 100644 --- a/lib/runtime/src/net.rs +++ b/lib/runtime/src/net.rs @@ -1,12 +1,10 @@ //! Network-related functions and types -use core::{mem::MaybeUninit, net::SocketAddr, time::Duration}; +use core::{net::SocketAddr, time::Duration}; -pub use abi::net::{MacAddress, SocketConnect, SocketInterfaceQuery, SocketOption, SocketType}; +pub use abi::net::{MacAddress, SocketInterfaceQuery, SocketOption, SocketShutdown, SocketType}; use abi::{error::Error, io::RawFd}; -use crate::sys; - #[allow(unused_macros)] macro socket_option_variant { ($opt:ident: bool) => { $crate::net::SocketOption::$opt(false) }, @@ -79,23 +77,70 @@ pub mod dns { } } -fn connect_inner(connect: &mut SocketConnect) -> Result<(SocketAddr, RawFd), Error> { - let mut local = MaybeUninit::uninit(); - let fd = unsafe { sys::connect_socket(connect, &mut local) }?; - let local = unsafe { local.assume_init() }; - Ok((local, fd)) +fn bind_inner(fd: RawFd, local: &SocketAddr, listen: bool) -> Result<(), Error> { + unsafe { crate::sys::bind(fd, local) }?; + if listen { + unsafe { crate::sys::listen(fd) }?; + } + Ok(()) +} + +fn connect_inner(fd: RawFd, remote: &SocketAddr, timeout: Option) -> Result<(), Error> { + if timeout.is_some() { + unsafe { crate::sys::set_socket_option(fd, &SocketOption::ConnectTimeout(timeout)) }?; + } + unsafe { crate::sys::connect(fd, remote) }?; + Ok(()) +} + +/// Creates a new socket and binds it to a local address +pub fn create_and_bind(ty: SocketType, local: &SocketAddr, listen: bool) -> Result { + let fd = unsafe { crate::sys::create_socket(ty) }?; + match bind_inner(fd, local, listen) { + Ok(()) => Ok(fd), + Err(error) => { + unsafe { crate::sys::close(fd) }.ok(); + Err(error) + } + } +} + +/// Binds a TCP listener socket to some local address +pub fn bind_tcp(local: &SocketAddr) -> Result { + create_and_bind(SocketType::TcpStream, local, true) +} + +/// Binds a raw socket to some network interface +pub fn bind_raw(iface: SocketInterfaceQuery<'_>) -> Result { + let fd = unsafe { crate::sys::create_socket(SocketType::RawPacket) }?; + let option = SocketOption::BindInterface(iface); + match unsafe { crate::sys::set_socket_option(fd, &option) } { + Ok(()) => Ok(fd), + Err(error) => { + unsafe { crate::sys::close(fd) }.ok(); + Err(error) + } + } +} + +/// Binds an UDP socket to some local address +pub fn bind_udp(local: &SocketAddr) -> Result { + create_and_bind(SocketType::UdpPacket, local, false) } /// Connect to a TCP listener -pub fn connect_tcp( - remote: SocketAddr, - timeout: Option, -) -> Result<(SocketAddr, RawFd), Error> { - connect_inner(&mut SocketConnect::Tcp(remote, timeout)) +pub fn connect_tcp(remote: &SocketAddr, timeout: Option) -> Result { + let fd = unsafe { crate::sys::create_socket(SocketType::TcpStream) }?; + match connect_inner(fd, remote, timeout) { + Ok(()) => Ok(fd), + Err(error) => { + unsafe { crate::sys::close(fd) }.ok(); + Err(error) + } + } } /// "Connect" an UDP socket -pub fn connect_udp(socket_fd: RawFd, remote: SocketAddr) -> Result<(), Error> { - connect_inner(&mut SocketConnect::Udp(socket_fd, remote))?; - Ok(()) +pub fn connect_udp(socket_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> { + connect_inner(socket_fd, remote, None) } diff --git a/lib/runtime/src/sys/mod.rs b/lib/runtime/src/sys/mod.rs index 8f7aa820..58a2128e 100644 --- a/lib/runtime/src/sys/mod.rs +++ b/lib/runtime/src/sys/mod.rs @@ -21,7 +21,7 @@ mod generated { TerminalOptions, TerminalSize, TimerOptions, UnmountOptions, }, mem::{MappingFlags, MappingSource}, - net::SocketType, + net::{SocketShutdown, SocketType}, process::{ ExecveOptions, ProcessGroupId, ProcessId, Signal, SignalEntryData, SpawnOptions, ThreadSpawnOptions, WaitFlags, diff --git a/test.c b/test.c index 7abeba11..7f8bb2a3 100644 --- a/test.c +++ b/test.c @@ -1,8 +1,14 @@ -#include -#include +#include +#include #include #include int main(int argc, const char **argv) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + perror("socket()"); + return EXIT_FAILURE; + } + return 0; } diff --git a/userspace/lib/ygglibc/build.rs b/userspace/lib/ygglibc/build.rs index 99e1ed34..016d47d0 100644 --- a/userspace/lib/ygglibc/build.rs +++ b/userspace/lib/ygglibc/build.rs @@ -24,7 +24,7 @@ fn include_dir(d: &DirEntry) -> bool { && d.path() .iter() .nth(2) - .map_or(false, |c| c.to_str().map_or(false, |x| !x.starts_with("_"))) + .is_some_and(|c| c.to_str().is_some_and(|x| !x.starts_with("_"))) } fn generate_header(config_path: impl AsRef, header_output: impl AsRef) { @@ -67,11 +67,7 @@ fn generate_header(config_path: impl AsRef, header_output: impl AsRef) { let output_dir = output_dir.as_ref(); let mut command = Command::new("clang"); - let arch = if arch == "x86" { - "i686" - } else { - arch - }; + let arch = if arch == "x86" { "i686" } else { arch }; let input_dir = PathBuf::from("crt").join(arch); let crt0_c = input_dir.join("crt0.c"); let crt0_s = input_dir.join("crt0.S"); diff --git a/userspace/lib/ygglibc/src/headers/arpa_inet/cbindgen.toml b/userspace/lib/ygglibc/src/headers/arpa_inet/cbindgen.toml new file mode 100644 index 00000000..0f14959f --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/arpa_inet/cbindgen.toml @@ -0,0 +1,15 @@ +language = "C" +style = "Type" + +sys_includes = [ + "netinet/in.h", + "inttypes.h" +] +no_includes = true + +include_guard = "_ARPA_INET_H" + +usize_type = "size_t" +isize_type = "ssize_t" + +[export] diff --git a/userspace/lib/ygglibc/src/headers/arpa_inet/mod.rs b/userspace/lib/ygglibc/src/headers/arpa_inet/mod.rs new file mode 100644 index 00000000..464646e9 --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/arpa_inet/mod.rs @@ -0,0 +1,20 @@ + +#[no_mangle] +unsafe extern "C" fn htonl(v: u32) -> u32 { + v.to_be() +} + +#[no_mangle] +unsafe extern "C" fn htons(v: u16) -> u16 { + v.to_be() +} + +#[no_mangle] +unsafe extern "C" fn ntohl(v: u32) -> u32 { + u32::from_be(v) +} + +#[no_mangle] +unsafe extern "C" fn ntohs(v: u16) -> u16 { + u16::from_be(v) +} diff --git a/userspace/lib/ygglibc/src/headers/errno/mod.rs b/userspace/lib/ygglibc/src/headers/errno/mod.rs index f28f7c09..f021c725 100644 --- a/userspace/lib/ygglibc/src/headers/errno/mod.rs +++ b/userspace/lib/ygglibc/src/headers/errno/mod.rs @@ -195,6 +195,7 @@ impl From for Errno { Error::DirectoryNotEmpty => Errno::ENOTEMPTY, Error::NotConnected => Errno::ENOTCONN, Error::ProcessNotFound => Errno::ESRCH, + Error::CrossDeviceLink => Errno::EXDEV, } } } diff --git a/userspace/lib/ygglibc/src/headers/fcntl/mod.rs b/userspace/lib/ygglibc/src/headers/fcntl/mod.rs index 70c682d4..afee8bfc 100644 --- a/userspace/lib/ygglibc/src/headers/fcntl/mod.rs +++ b/userspace/lib/ygglibc/src/headers/fcntl/mod.rs @@ -1,6 +1,9 @@ use core::ffi::{c_char, c_int, c_short, VaList}; -use yggdrasil_rt::{io::{AccessMode, FileMode, OpenOptions}, sys as syscall}; +use yggdrasil_rt::{ + io::{AccessMode, FileMode, OpenOptions}, + sys as syscall, +}; use crate::{ error::{CFdResult, CIntCountResult, CIntZeroResult, EResult, ResultExt, TryFromExt}, @@ -125,10 +128,7 @@ fn open_opts(opts: c_int, ap: &mut VaList) -> EResult { } // TODO O_CLOEXEC - if opts - & (O_DSYNC | O_RSYNC | O_SYNC | O_TTY_INIT | O_NONBLOCK | O_NOFOLLOW | O_NOCTTY) - != 0 - { + if opts & (O_DSYNC | O_RSYNC | O_SYNC | O_TTY_INIT | O_NONBLOCK | O_NOFOLLOW | O_NOCTTY) != 0 { todo!(); } @@ -187,7 +187,7 @@ pub(crate) unsafe extern "C" fn faccessat( atfd: c_int, path: *const c_char, mode: c_int, - flags: c_int, + _flags: c_int, ) -> CIntZeroResult { let atfd = util::at_fd(atfd)?; let path = path.ensure_str(); diff --git a/userspace/lib/ygglibc/src/headers/mod.rs b/userspace/lib/ygglibc/src/headers/mod.rs index c2857d03..5ecb3fc8 100644 --- a/userspace/lib/ygglibc/src/headers/mod.rs +++ b/userspace/lib/ygglibc/src/headers/mod.rs @@ -136,6 +136,11 @@ pub mod sys_types; pub mod sys_utsname; pub mod sys_wait; +// Network +pub mod arpa_inet; +pub mod netinet_in; +pub mod sys_socket; + // TODO Generate those as part of dyn-loader (and make dyn-loader a shared library) pub mod link; diff --git a/userspace/lib/ygglibc/src/headers/netinet_in/cbindgen.toml b/userspace/lib/ygglibc/src/headers/netinet_in/cbindgen.toml new file mode 100644 index 00000000..adee9ece --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/netinet_in/cbindgen.toml @@ -0,0 +1,17 @@ +language = "C" +style = "Tag" + +sys_includes = [ + "inttypes.h", + "sys/socket.h", + "arpa/inet.h" +] +no_includes = true + +include_guard = "_NETINET_IN_H" + +usize_type = "size_t" +isize_type = "ssize_t" + +[export] +include = ["sockaddr_in", "sockaddr_in6"] diff --git a/userspace/lib/ygglibc/src/headers/netinet_in/mod.rs b/userspace/lib/ygglibc/src/headers/netinet_in/mod.rs new file mode 100644 index 00000000..6ce23d5b --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/netinet_in/mod.rs @@ -0,0 +1,60 @@ +use core::ffi::c_int; + +use super::sys_socket::sa_family_t; + +pub type in_port_t = u16; +pub type in_addr_t = u32; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct in_addr { + pub s_addr: in_addr_t +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct in6_addr { + pub s6_addr: [u8; 16] +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct sockaddr_in { + pub sin_family: sa_family_t, + pub sin_port: in_port_t, + pub sin_addr: in_addr, +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct sockaddr_in6 { + pub sin6_family: sa_family_t, + pub sin6_port: in_port_t, + pub sin6_flowinfo: u32, + pub sin6_addr: in6_addr, + pub sin6_scope_id: u32 +} + +// TODO more IPv6 + +pub const IPPROTO_ICMP: c_int = 1; +pub const IPPROTO_TCP: c_int = 6; +pub const IPPROTO_UDP: c_int = 17; + +pub const IPPROTO_IP: c_int = 0; +pub const IPPROTO_IPV6: c_int = 41; +pub const IPPROTO_RAW: c_int = 255; + +pub const INADDR_ANY: in_addr_t = 0; +pub const INADDR_BROADCAST: in_addr_t = 0xFFFFFFFF; + +pub const INET_ADDRSTRLEN: usize = 16; +pub const INET6_ADDRSTRLEN: usize = 46; + +pub const IPV6_JOIN_GROUP: c_int = 6001; +pub const IPV6_LEAVE_GROUP: c_int = 6002; +pub const IPV6_MULTICAST_HOPS: c_int = 6003; +pub const IPV6_MULTICAST_IF: c_int = 6004; +pub const IPV6_MULTICAST_LOOP: c_int = 6005; +pub const IPV6_UNICAST_HOPS: c_int = 6006; +pub const IPV6_V6ONLY: c_int = 6007; diff --git a/userspace/lib/ygglibc/src/headers/stdio/util.rs b/userspace/lib/ygglibc/src/headers/stdio/util.rs index da1d2f4a..16983c68 100644 --- a/userspace/lib/ygglibc/src/headers/stdio/util.rs +++ b/userspace/lib/ygglibc/src/headers/stdio/util.rs @@ -1,9 +1,15 @@ use core::ffi::{c_char, c_int}; -use yggdrasil_rt::sys as syscall; +use yggdrasil_rt::{ + io::{RemoveFlags, Rename}, + sys as syscall, +}; use crate::{ - error::{self, CIntZeroResult, CPtrResult, ResultExt}, headers::errno::Errno, io::managed::{stderr, FILE}, util::{PointerExt, PointerStrExt} + error::{self, CIntZeroResult, CPtrResult, ResultExt}, + headers::{errno::Errno, fcntl::AT_FDCWD}, + io::managed::{stderr, FILE}, + util::{self, PointerExt, PointerStrExt}, }; #[no_mangle] @@ -32,21 +38,34 @@ unsafe extern "C" fn ctermid(_buf: *mut c_char) -> *mut c_char { #[no_mangle] unsafe extern "C" fn remove(path: *const c_char) -> CIntZeroResult { let path = path.ensure_str(); - syscall::remove(None, path).e_map_err(Errno::from)?; + syscall::remove(None, path, RemoveFlags::DIRECTORY).e_map_err(Errno::from)?; CIntZeroResult::SUCCESS } #[no_mangle] -unsafe extern "C" fn rename(src: *const c_char, dst: *const c_char) -> c_int { - let src = src.ensure_str(); - let dst = dst.ensure_str(); - yggdrasil_rt::debug_trace!("rename {src:?} -> {dst:?}"); - todo!() +unsafe extern "C" fn rename(src: *const c_char, dst: *const c_char) -> CIntZeroResult { + renameat(AT_FDCWD, src, dst) } #[no_mangle] -unsafe extern "C" fn renameat(_atfd: c_int, _src: *const c_char, _dst: *const c_char) -> c_int { - todo!() +unsafe extern "C" fn renameat( + atfd: c_int, + src: *const c_char, + dst: *const c_char, +) -> CIntZeroResult { + let at = util::at_fd(atfd)?; + let source = src.ensure_str(); + let destination = dst.ensure_str(); + + syscall::rename(&Rename { + source_at: at, + destination_at: at, + source, + destination, + }) + .e_map_err(Errno::from)?; + + CIntZeroResult::SUCCESS } #[no_mangle] diff --git a/userspace/lib/ygglibc/src/headers/stdlib/io.rs b/userspace/lib/ygglibc/src/headers/stdlib/io.rs index 73a91d75..ee1eda5e 100644 --- a/userspace/lib/ygglibc/src/headers/stdlib/io.rs +++ b/userspace/lib/ygglibc/src/headers/stdlib/io.rs @@ -1,6 +1,16 @@ -use core::{ffi::{c_char, c_int}, ptr::NonNull, slice}; +use core::{ + ffi::{c_char, c_int}, + ptr::NonNull, + slice, +}; -use crate::{allocator::c_alloc, error::{self, CPtrResult, CResult}, headers::errno::Errno, io, util::PointerStrExt}; +use crate::{ + allocator::c_alloc, + error::{self, CPtrResult, CResult}, + headers::errno::Errno, + io, + util::PointerStrExt, +}; #[no_mangle] unsafe extern "C" fn grantpt(_fd: c_int) -> c_int { @@ -28,7 +38,10 @@ unsafe extern "C" fn ptsname(_fd: c_int) -> *mut c_char { } #[no_mangle] -unsafe extern "C" fn realpath(path: *const c_char, mut resolved_ptr: *mut c_char) -> CPtrResult { +unsafe extern "C" fn realpath( + path: *const c_char, + resolved_ptr: *mut c_char, +) -> CPtrResult { if path.is_null() { error::errno = Errno::EINVAL; return CPtrResult::ERROR; diff --git a/userspace/lib/ygglibc/src/headers/sys_mman/mod.rs b/userspace/lib/ygglibc/src/headers/sys_mman/mod.rs index dfae8426..3803feee 100644 --- a/userspace/lib/ygglibc/src/headers/sys_mman/mod.rs +++ b/userspace/lib/ygglibc/src/headers/sys_mman/mod.rs @@ -1,8 +1,18 @@ -use core::{ffi::{c_char, c_int, c_void}, num::NonZeroUsize, ptr::{self, NonNull}}; +use core::{ + ffi::{c_char, c_int, c_void}, + ptr::{self, NonNull}, +}; -use yggdrasil_rt::{io::RawFd, mem::{MappingFlags, MappingSource}, sys as syscall}; +use yggdrasil_rt::{ + io::RawFd, + mem::{MappingFlags, MappingSource}, + sys as syscall, +}; -use crate::{error::{self, CPtrResult, EResult, ResultExt, TryFromExt}, headers::errno::Errno}; +use crate::{ + error::{self, EResult, ResultExt, TryFromExt}, + headers::errno::Errno, +}; use super::sys_types::{mode_t, off_t}; diff --git a/userspace/lib/ygglibc/src/headers/sys_socket/cbindgen.toml b/userspace/lib/ygglibc/src/headers/sys_socket/cbindgen.toml new file mode 100644 index 00000000..72a85dd8 --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/sys_socket/cbindgen.toml @@ -0,0 +1,14 @@ +language = "C" +style = "Tag" + +sys_includes = [ + "sys/types.h" +] +no_includes = true + +include_guard = "_SYS_SOCKET_H" + +usize_type = "size_t" +isize_type = "ssize_t" + +[export] diff --git a/userspace/lib/ygglibc/src/headers/sys_socket/io.rs b/userspace/lib/ygglibc/src/headers/sys_socket/io.rs new file mode 100644 index 00000000..c8ff0d86 --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/sys_socket/io.rs @@ -0,0 +1,72 @@ +use core::ffi::{c_int, c_void}; + +use super::{msghdr, sockaddr, socklen_t}; + +#[no_mangle] +unsafe extern "C" fn recv(fd: c_int, buffer: *mut c_void, len: usize, flags: c_int) -> isize { + let _ = fd; + let _ = buffer; + let _ = len; + let _ = flags; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn recvfrom( + fd: c_int, + buffer: *mut c_void, + len: usize, + flags: c_int, + remote: *mut sockaddr, + remote_len: *mut socklen_t, +) -> isize { + let _ = fd; + let _ = buffer; + let _ = len; + let _ = flags; + let _ = remote; + let _ = remote_len; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn recvmsg(fd: c_int, message: *mut msghdr, flags: c_int) -> isize { + let _ = fd; + let _ = message; + let _ = flags; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn send(fd: c_int, data: *const c_void, len: usize) -> isize { + let _ = fd; + let _ = data; + let _ = len; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn sendmsg(fd: c_int, message: *const msghdr, flags: c_int) -> isize { + let _ = fd; + let _ = message; + let _ = flags; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn sendto( + fd: c_int, + data: *const c_void, + len: usize, + flags: c_int, + remote: *const sockaddr, + remote_len: socklen_t, +) -> isize { + let _ = fd; + let _ = data; + let _ = len; + let _ = flags; + let _ = remote; + let _ = remote_len; + todo!() +} diff --git a/userspace/lib/ygglibc/src/headers/sys_socket/mod.rs b/userspace/lib/ygglibc/src/headers/sys_socket/mod.rs new file mode 100644 index 00000000..a7ada45f --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/sys_socket/mod.rs @@ -0,0 +1,86 @@ +use core::ffi::{c_int, c_void}; + +mod io; +mod option; +mod socket; + +pub type socklen_t = usize; +pub type sa_family_t = u16; + +pub const __SS_SIZE: usize = 256; +// __SS_SIZE - sizeof(sa_samily_t) +pub const __SOCKADDR_LEN: usize = __SS_SIZE - 2; +// __SS_SIZE - sizeof(sa_family_t) - sizeof(__ss_aligntype) +pub const __SS_PADSIZE: usize = __SS_SIZE - 6; + +pub type __ss_aligntype = u32; + +#[derive(Clone, Copy, Debug)] +#[repr(C)] +pub struct sockaddr { + pub sa_family: sa_family_t, + pub sa_data: [u8; __SOCKADDR_LEN], +} + +// TODO struct sockaddr_storage + +// TODO struct iovec from sys/uio.h +#[derive(Clone, Copy, Debug)] +#[repr(C)] +pub struct iovec { + __dummy: u32 +} + +#[derive(Clone, Copy, Debug)] +#[repr(C)] +pub struct msghdr { + pub msg_name: *mut c_void, + pub msg_namelen: socklen_t, + pub msg_iov: *mut iovec, + pub msg_iovlen: c_int, + pub msg_control: *mut c_void, + pub msg_controllen: socklen_t, + pub msg_flags: c_int +} + +// TODO struct cmsghdr + +// socket() parameters +pub const SOCK_DGRAM: c_int = 1; +pub const SOCK_RAW: c_int = 2; +pub const SOCK_SEQPACKET: c_int = 3; +pub const SOCK_STREAM: c_int = 4; + +// setsockopt() parameters +pub const SOL_SOCKET: c_int = 1; + +pub const SO_ACCEPTCONN: c_int = 1; +pub const SO_BROADCAST: c_int = 2; +pub const SO_DEBUG: c_int = 3; +pub const SO_DONTROUTE: c_int = 4; +pub const SO_ERROR: c_int = 5; +pub const SO_KEEPALIVE: c_int = 6; +pub const SO_LINGER: c_int = 7; +pub const SO_OOBINLINE: c_int = 8; +pub const SO_RCVBUF: c_int = 9; +pub const SO_RCVLOWAT: c_int = 10; +pub const SO_RCVTIMEO: c_int = 11; +pub const SO_REUSEADDR: c_int = 12; +pub const SO_SNDBUF: c_int = 13; +pub const SO_SNDLOWAT: c_int = 14; +pub const SO_SNDTIMEO: c_int = 15; +pub const SO_TYPE: c_int = 16; + +pub const SOMAXCONN: usize = 64; + +// TODO msg_flags values + +pub const AF_INET: c_int = 1; +pub const AF_INET6: c_int = 2; +pub const AF_UNIX: c_int = 3; +pub const AF_UNSPEC: c_int = 0; + +pub const SHUT_RD: c_int = 1 << 0; +pub const SHUT_WR: c_int = 1 << 1; +pub const SHUT_RDWR: c_int = SHUT_RD | SHUT_WR; + diff --git a/userspace/lib/ygglibc/src/headers/sys_socket/option.rs b/userspace/lib/ygglibc/src/headers/sys_socket/option.rs new file mode 100644 index 00000000..571fec6b --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/sys_socket/option.rs @@ -0,0 +1,35 @@ +use core::ffi::{c_int, c_void}; + +use super::socklen_t; + +#[no_mangle] +unsafe extern "C" fn getsockopt( + fd: c_int, + level: c_int, + name: c_int, + value: *mut c_void, + size: *mut socklen_t, +) -> c_int { + let _ = fd; + let _ = level; + let _ = name; + let _ = value; + let _ = size; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn setsockopt( + fd: c_int, + level: c_int, + name: c_int, + value: *const c_void, + size: socklen_t, +) -> c_int { + let _ = fd; + let _ = level; + let _ = name; + let _ = value; + let _ = size; + todo!() +} diff --git a/userspace/lib/ygglibc/src/headers/sys_socket/socket.rs b/userspace/lib/ygglibc/src/headers/sys_socket/socket.rs new file mode 100644 index 00000000..0e5c6ce8 --- /dev/null +++ b/userspace/lib/ygglibc/src/headers/sys_socket/socket.rs @@ -0,0 +1,99 @@ +use core::ffi::c_int; + +use crate::{ + error::{self, CFdResult, CResult}, + headers::{ + errno::Errno, + sys_socket::{AF_INET, SOCK_DGRAM, SOCK_STREAM}, + }, +}; + +use super::{sockaddr, socklen_t}; + +#[no_mangle] +unsafe extern "C" fn accept(fd: c_int, remote: *mut sockaddr, len: *mut socklen_t) -> c_int { + let _ = fd; + let _ = remote; + let _ = len; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn bind(fd: c_int, local: *const sockaddr, len: socklen_t) -> c_int { + let _ = fd; + let _ = local; + let _ = len; + + todo!() +} + +#[no_mangle] +unsafe extern "C" fn connect(fd: c_int, remote: *const sockaddr, len: socklen_t) -> c_int { + let _ = fd; + let _ = remote; + let _ = len; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn getpeername(fd: c_int, remote: *mut sockaddr, len: *mut socklen_t) -> c_int { + let _ = fd; + let _ = remote; + let _ = len; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn getsockname(fd: c_int, local: *mut sockaddr, len: *mut socklen_t) -> c_int { + let _ = fd; + let _ = local; + let _ = len; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn listen(fd: c_int, backlog: c_int) -> c_int { + let _ = fd; + let _ = backlog; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn shutdown(fd: c_int, how: c_int) -> c_int { + let _ = fd; + let _ = how; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn sockatmark(fd: c_int) -> c_int { + let _ = fd; + todo!() +} + +#[no_mangle] +unsafe extern "C" fn socket(domain: c_int, ty: c_int, proto: c_int) -> CFdResult { + match (domain, ty, proto) { + (AF_INET, SOCK_STREAM, 0) => todo!(), + (AF_INET, SOCK_DGRAM, 0) => todo!(), + (_, _, _) => { + yggdrasil_rt::debug_trace!("Unsupported socket({domain}, {ty}, {proto})"); + error::errno = Errno::ENOTSUPP; + CFdResult::ERROR + } + } +} + +#[no_mangle] +unsafe extern "C" fn socketpair( + domain: c_int, + ty: c_int, + proto: c_int, + sockets: *mut c_int, +) -> c_int { + let _ = domain; + let _ = ty; + let _ = proto; + let _ = sockets; + todo!() +} diff --git a/userspace/lib/ygglibc/src/headers/time/timer.rs b/userspace/lib/ygglibc/src/headers/time/timer.rs index 9fb1220e..efcdf4b5 100644 --- a/userspace/lib/ygglibc/src/headers/time/timer.rs +++ b/userspace/lib/ygglibc/src/headers/time/timer.rs @@ -1,8 +1,14 @@ use core::{ffi::c_int, ptr::NonNull}; -use crate::{error::{self, CIntZeroResult, CResult, EResult, ResultExt}, headers::{ - errno::Errno, sys_time::{__ygg_timespec_t, timespec}, sys_types::{clock_t, clockid_t, pid_t, time_t}, time::{CLOCK_MONOTONIC, CLOCK_REALTIME} -}}; +use crate::{ + error::{CIntZeroResult, EResult, ResultExt}, + headers::{ + errno::Errno, + sys_time::{__ygg_timespec_t, timespec}, + sys_types::{clock_t, clockid_t, pid_t, time_t}, + time::{CLOCK_MONOTONIC, CLOCK_REALTIME}, + }, +}; use yggdrasil_rt::time::{self as rt, ClockType}; @@ -21,7 +27,7 @@ fn clock_type(clock_id: clockid_t) -> EResult { match clock_id { CLOCK_REALTIME => EResult::Ok(ClockType::RealTime), CLOCK_MONOTONIC => EResult::Ok(ClockType::Monotonic), - _ => EResult::Err(Errno::EINVAL) + _ => EResult::Err(Errno::EINVAL), } } @@ -42,14 +48,17 @@ unsafe extern "C" fn clock_getres(_clock_id: clockid_t, _ts: *mut __ygg_timespec } #[no_mangle] -unsafe extern "C" fn clock_gettime(clock_id: clockid_t, ts: *mut __ygg_timespec_t) -> CIntZeroResult { +unsafe extern "C" fn clock_gettime( + clock_id: clockid_t, + ts: *mut __ygg_timespec_t, +) -> CIntZeroResult { let clock = clock_type(clock_id)?; let time = rt::get_clock(clock).e_map_err(Errno::from)?; if let Some(ts) = NonNull::new(ts) { ts.write(timespec { - tv_sec: time_t(time.seconds as _), - tv_nsec: time.nanoseconds as _ + tv_sec: time_t(time.seconds() as _), + tv_nsec: time.subsec_nanos() as _, }); } diff --git a/userspace/lib/ygglibc/src/headers/wchar/multibyte.rs b/userspace/lib/ygglibc/src/headers/wchar/multibyte.rs index 350adb79..d322497d 100644 --- a/userspace/lib/ygglibc/src/headers/wchar/multibyte.rs +++ b/userspace/lib/ygglibc/src/headers/wchar/multibyte.rs @@ -1,7 +1,4 @@ -use core::{ - ffi::{c_char, c_int}, - ptr::NonNull, -}; +use core::{ffi::c_char, ptr::NonNull}; use crate::{ error::{CIntZeroResult, CUsizeResult, OptionExt}, @@ -21,7 +18,7 @@ unsafe extern "C" fn mbrlen(_str: *const c_char, _n: usize, _state: *mut mbstate #[no_mangle] unsafe extern "C" fn wcrtomb(dst: *mut c_char, wc: wchar_t, state: *mut mbstate_t) -> CUsizeResult { - let state = match state.as_mut() { + let _state = match state.as_mut() { Some(state) => state, #[allow(static_mut_refs)] None => &mut GLOBAL, diff --git a/userspace/lib/ygglibc/src/random.rs b/userspace/lib/ygglibc/src/random.rs index 497c1a63..136071d0 100644 --- a/userspace/lib/ygglibc/src/random.rs +++ b/userspace/lib/ygglibc/src/random.rs @@ -1,8 +1,5 @@ use core::cell::RefCell; -use yggdrasil_rt::process::thread_local; - - struct RandomState { xs64: u64 } diff --git a/userspace/netutils/src/dhcp_client.rs b/userspace/netutils/src/dhcp_client.rs index d08fc1fd..71d0c268 100644 --- a/userspace/netutils/src/dhcp_client.rs +++ b/userspace/netutils/src/dhcp_client.rs @@ -2,7 +2,7 @@ use std::os::{ fd::AsRawFd, - yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd}, + yggdrasil::io::{poll::PollChannel, net::raw_socket::RawSocket, timer::TimerFd}, }; use std::{io, mem::size_of, process::ExitCode, time::Duration}; diff --git a/userspace/netutils/src/ping.rs b/userspace/netutils/src/ping.rs index a9e92bf4..2a75a324 100644 --- a/userspace/netutils/src/ping.rs +++ b/userspace/netutils/src/ping.rs @@ -1,347 +1,349 @@ -#![feature(yggdrasil_os, rustc_private)] +fn main() {} -use std::{ - mem::size_of, - os::{ - fd::AsRawFd, - yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd}, - }, - process::ExitCode, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; - -use bytemuck::Zeroable; -use clap::Parser; -use netutils::{netconfig::NetConfig, Error}; -use yggdrasil_abi::net::{ - protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame}, - types::NetValueImpl, - IpAddr, Ipv4Addr, MacAddress, -}; - -#[derive(Parser)] -struct Args { - #[clap( - help = "Time (ms) between a reply is received and the next request is sent", - short, - long, - default_value_t = 1000, - value_parser = valid_interval - )] - inteval: u32, - #[clap( - help = "Time (ms) after which the request is considered unanswered", - short, - long, - default_value_t = 500, - value_parser = valid_timeout, - )] - timeout: u32, - #[clap( - help = "Number of requests to perform", - short, - long, - default_value_t = 10 - )] - count: usize, - #[clap( - help = "Amount of bytes to include as data", - short, - long, - default_value_t = 64, - value_parser = valid_data_size - )] - data_size: usize, - - #[clap(help = "Address to ping")] - address: core::net::IpAddr, -} - -fn valid_interval(s: &str) -> Result { - clap_num::number_range(s, 100, 10000) -} - -fn valid_timeout(s: &str) -> Result { - clap_num::number_range(s, 100, 5000) -} - -fn valid_data_size(s: &str) -> Result { - clap_num::number_range(s, 4, 128) -} - -struct PingRouting { - interface_id: u32, - source_ip: IpAddr, - destination_ip: IpAddr, - source_mac: MacAddress, - gateway_mac: MacAddress, -} - -struct PingStats { - packets_sent: usize, - packets_received: usize, -} - -fn resolve_routing(address: IpAddr) -> Result { - let mut nc = NetConfig::open()?; - let routing = nc.query_route(address)?; - let Some(source) = routing.source else { - todo!(); - }; - let Some(gateway) = routing.gateway else { - todo!(); - }; - - let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?; - - Ok(PingRouting { - interface_id: routing.interface_id, - source_ip: source, - destination_ip: routing.destination, - source_mac: routing.source_mac, - gateway_mac, - }) -} - -fn validate_ping_reply( - packet: &[u8], - local: Ipv4Addr, - remote: Ipv4Addr, - expect_l4_data: &[u8], - expect_id: u16, - expect_seq: u16, -) -> bool { - if packet.len() < size_of::() + size_of::() { - return false; - } - - let l3_offset = size_of::(); - - let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]); - - if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 { - return false; - } - let l3_frame: &Ipv4Frame = - bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::()]); - if l3_frame.protocol != IpProtocol::ICMP - || u32::from_network_order(l3_frame.source_address) != u32::from(remote) - || u32::from_network_order(l3_frame.destination_address) != u32::from(local) - { - return false; - } - let mut ip_checksum = InetChecksum::new(); - ip_checksum.add_value(l3_frame, true); - let ip_checksum = ip_checksum.finish(); - - if ip_checksum != 0 { - eprintln!("IP checksum mismatch: {:#06x}", ip_checksum); - return false; - } - - let l4_offset = l3_offset + l3_frame.header_length(); - let l4_size = l3_frame - .total_length() - .saturating_sub(l3_frame.header_length()); - if packet.len() < l4_offset + size_of::() + expect_l4_data.len() { - return false; - } - let l4_frame: &IcmpV4Frame = - bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::()]); - let l4_data = &packet[l4_offset + size_of::()..l4_offset + l4_size]; - - if l4_frame.ty != 0 || l4_frame.code != 0 { - return false; - } - - let rest = u32::from_network_order(l4_frame.rest); - let reply_id = (rest >> 16) as u16; - let reply_seq = rest as u16; - - if reply_id != expect_id || reply_seq != expect_seq { - eprintln!( - "ICMP seq/id mismatch: sent {}/{}, got {}/{}", - expect_id, expect_seq, reply_id, reply_seq - ); - return false; - } - - let mut icmp_checksum = InetChecksum::new(); - icmp_checksum.add_value(l4_frame, true); - icmp_checksum.add_bytes(l4_data, true); - let icmp_checksum = icmp_checksum.finish(); - - if icmp_checksum != 0 { - eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum); - return false; - } - - l4_data == expect_l4_data -} - -#[allow(clippy::too_many_arguments)] -fn ping_once( - socket: &mut RawSocket, - poll: &mut PollChannel, - timer: &mut TimerFd, - info: &PingRouting, - timeout: Duration, - data_len: usize, - id: u16, - seq: u16, -) -> Result { - let mut buffer = [0; 4096]; - - let source_ip = info.source_ip.into_ipv4().unwrap(); - let destination_ip = info.destination_ip.into_ipv4().unwrap(); - let mut l4_data = Vec::with_capacity(data_len); - - for _ in 0..data_len { - l4_data.push(rand::random()); - } - - let ip_len = (size_of::() + size_of::() + data_len) - .try_into() - .unwrap(); - - let l2_frame = EthernetFrame { - source_mac: info.source_mac, - destination_mac: info.gateway_mac, - ethertype: EtherType::IPV4.to_network_order(), - }; - let mut l3_frame = Ipv4Frame { - source_address: u32::from(source_ip).to_network_order(), - destination_address: u32::from(destination_ip).to_network_order(), - protocol: IpProtocol::ICMP, - version_length: 0x45, - total_length: u16::to_network_order(ip_len), - flags_frag: u16::to_network_order(0x4000), - id: u16::to_network_order(0), - ttl: 255, - ..Ipv4Frame::zeroed() - }; - let mut l4_frame = IcmpV4Frame { - ty: 8, - code: 0, - checksum: u16::to_network_order(0), - rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)), - }; - - let mut ip_checksum = InetChecksum::new(); - ip_checksum.add_value(&l3_frame, true); - l3_frame.header_checksum = ip_checksum.finish().to_network_order(); - - let mut icmp_checksum = InetChecksum::new(); - icmp_checksum.add_value(&l4_frame, true); - icmp_checksum.add_bytes(&l4_data, true); - l4_frame.checksum = icmp_checksum.finish().to_network_order(); - - let mut packet = vec![]; - packet.extend_from_slice(bytemuck::bytes_of(&l2_frame)); - packet.extend_from_slice(bytemuck::bytes_of(&l3_frame)); - packet.extend_from_slice(bytemuck::bytes_of(&l4_frame)); - packet.extend_from_slice(&l4_data); - - timer.start(timeout)?; - socket.send(&packet)?; - - loop { - let (fd, result) = poll.wait(None, true)?.unwrap(); - result?; - - match fd { - fd if fd == socket.as_raw_fd() => { - // TODO - let len = socket.recv(&mut buffer)?; - if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq) - { - return Ok(true); - } - } - fd if fd == timer.as_raw_fd() => { - return Ok(false); - } - _ => unreachable!(), - } - } -} - -fn ping( - address: IpAddr, - times: usize, - data_len: usize, - interval: Duration, - timeout: Duration, -) -> Result { - let routing = resolve_routing(address)?; - - let mut stats = PingStats { - packets_sent: 0, - packets_received: 0, - }; - let mut poll = PollChannel::new()?; - let mut timer = TimerFd::new(false, false)?; - let mut socket = RawSocket::bind(routing.interface_id)?; - - poll.add(timer.as_raw_fd())?; - poll.add(socket.as_raw_fd())?; - - let id = rand::random(); - for i in 0..times { - if INTERRUPTED.load(Ordering::Acquire) { - break; - } - - let result = ping_once( - &mut socket, - &mut poll, - &mut timer, - &routing, - timeout, - data_len, - id, - i as u16, - )?; - stats.packets_sent += 1; - - if result { - stats.packets_received += 1; - println!("[{}/{}] {}: PONG", i + 1, times, address); - } - - std::thread::sleep(interval); - } - - Ok(stats) -} - -static INTERRUPTED: AtomicBool = AtomicBool::new(false); - -fn main() -> ExitCode { - // set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt)); - - let args = Args::parse(); - - let stats = match ping( - args.address.into(), - args.count, - args.data_size, - Duration::from_millis(args.inteval.into()), - Duration::from_millis(args.timeout.into()), - ) { - Ok(stats) => stats, - Err(error) => { - eprintln!("ping: {}", error); - return ExitCode::FAILURE; - } - }; - - let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent; - println!( - "{} sent, {} received, {}% loss", - stats.packets_sent, stats.packets_received, loss - ); - - ExitCode::SUCCESS -} +// #![feature(yggdrasil_os, rustc_private)] +// +// use std::{ +// mem::size_of, +// os::{ +// fd::AsRawFd, +// yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd}, +// }, +// process::ExitCode, +// sync::atomic::{AtomicBool, Ordering}, +// time::Duration, +// }; +// +// use bytemuck::Zeroable; +// use clap::Parser; +// use netutils::{netconfig::NetConfig, Error}; +// use yggdrasil_abi::net::{ +// protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame}, +// types::NetValueImpl, +// IpAddr, Ipv4Addr, MacAddress, +// }; +// +// #[derive(Parser)] +// struct Args { +// #[clap( +// help = "Time (ms) between a reply is received and the next request is sent", +// short, +// long, +// default_value_t = 1000, +// value_parser = valid_interval +// )] +// inteval: u32, +// #[clap( +// help = "Time (ms) after which the request is considered unanswered", +// short, +// long, +// default_value_t = 500, +// value_parser = valid_timeout, +// )] +// timeout: u32, +// #[clap( +// help = "Number of requests to perform", +// short, +// long, +// default_value_t = 10 +// )] +// count: usize, +// #[clap( +// help = "Amount of bytes to include as data", +// short, +// long, +// default_value_t = 64, +// value_parser = valid_data_size +// )] +// data_size: usize, +// +// #[clap(help = "Address to ping")] +// address: core::net::IpAddr, +// } +// +// fn valid_interval(s: &str) -> Result { +// clap_num::number_range(s, 100, 10000) +// } +// +// fn valid_timeout(s: &str) -> Result { +// clap_num::number_range(s, 100, 5000) +// } +// +// fn valid_data_size(s: &str) -> Result { +// clap_num::number_range(s, 4, 128) +// } +// +// struct PingRouting { +// interface_id: u32, +// source_ip: IpAddr, +// destination_ip: IpAddr, +// source_mac: MacAddress, +// gateway_mac: MacAddress, +// } +// +// struct PingStats { +// packets_sent: usize, +// packets_received: usize, +// } +// +// fn resolve_routing(address: IpAddr) -> Result { +// let mut nc = NetConfig::open()?; +// let routing = nc.query_route(address)?; +// let Some(source) = routing.source else { +// todo!(); +// }; +// let Some(gateway) = routing.gateway else { +// todo!(); +// }; +// +// let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?; +// +// Ok(PingRouting { +// interface_id: routing.interface_id, +// source_ip: source, +// destination_ip: routing.destination, +// source_mac: routing.source_mac, +// gateway_mac, +// }) +// } +// +// fn validate_ping_reply( +// packet: &[u8], +// local: Ipv4Addr, +// remote: Ipv4Addr, +// expect_l4_data: &[u8], +// expect_id: u16, +// expect_seq: u16, +// ) -> bool { +// if packet.len() < size_of::() + size_of::() { +// return false; +// } +// +// let l3_offset = size_of::(); +// +// let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]); +// +// if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 { +// return false; +// } +// let l3_frame: &Ipv4Frame = +// bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::()]); +// if l3_frame.protocol != IpProtocol::ICMP +// || u32::from_network_order(l3_frame.source_address) != u32::from(remote) +// || u32::from_network_order(l3_frame.destination_address) != u32::from(local) +// { +// return false; +// } +// let mut ip_checksum = InetChecksum::new(); +// ip_checksum.add_value(l3_frame, true); +// let ip_checksum = ip_checksum.finish(); +// +// if ip_checksum != 0 { +// eprintln!("IP checksum mismatch: {:#06x}", ip_checksum); +// return false; +// } +// +// let l4_offset = l3_offset + l3_frame.header_length(); +// let l4_size = l3_frame +// .total_length() +// .saturating_sub(l3_frame.header_length()); +// if packet.len() < l4_offset + size_of::() + expect_l4_data.len() { +// return false; +// } +// let l4_frame: &IcmpV4Frame = +// bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::()]); +// let l4_data = &packet[l4_offset + size_of::()..l4_offset + l4_size]; +// +// if l4_frame.ty != 0 || l4_frame.code != 0 { +// return false; +// } +// +// let rest = u32::from_network_order(l4_frame.rest); +// let reply_id = (rest >> 16) as u16; +// let reply_seq = rest as u16; +// +// if reply_id != expect_id || reply_seq != expect_seq { +// eprintln!( +// "ICMP seq/id mismatch: sent {}/{}, got {}/{}", +// expect_id, expect_seq, reply_id, reply_seq +// ); +// return false; +// } +// +// let mut icmp_checksum = InetChecksum::new(); +// icmp_checksum.add_value(l4_frame, true); +// icmp_checksum.add_bytes(l4_data, true); +// let icmp_checksum = icmp_checksum.finish(); +// +// if icmp_checksum != 0 { +// eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum); +// return false; +// } +// +// l4_data == expect_l4_data +// } +// +// #[allow(clippy::too_many_arguments)] +// fn ping_once( +// socket: &mut RawSocket, +// poll: &mut PollChannel, +// timer: &mut TimerFd, +// info: &PingRouting, +// timeout: Duration, +// data_len: usize, +// id: u16, +// seq: u16, +// ) -> Result { +// let mut buffer = [0; 4096]; +// +// let source_ip = info.source_ip.into_ipv4().unwrap(); +// let destination_ip = info.destination_ip.into_ipv4().unwrap(); +// let mut l4_data = Vec::with_capacity(data_len); +// +// for _ in 0..data_len { +// l4_data.push(rand::random()); +// } +// +// let ip_len = (size_of::() + size_of::() + data_len) +// .try_into() +// .unwrap(); +// +// let l2_frame = EthernetFrame { +// source_mac: info.source_mac, +// destination_mac: info.gateway_mac, +// ethertype: EtherType::IPV4.to_network_order(), +// }; +// let mut l3_frame = Ipv4Frame { +// source_address: u32::from(source_ip).to_network_order(), +// destination_address: u32::from(destination_ip).to_network_order(), +// protocol: IpProtocol::ICMP, +// version_length: 0x45, +// total_length: u16::to_network_order(ip_len), +// flags_frag: u16::to_network_order(0x4000), +// id: u16::to_network_order(0), +// ttl: 255, +// ..Ipv4Frame::zeroed() +// }; +// let mut l4_frame = IcmpV4Frame { +// ty: 8, +// code: 0, +// checksum: u16::to_network_order(0), +// rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)), +// }; +// +// let mut ip_checksum = InetChecksum::new(); +// ip_checksum.add_value(&l3_frame, true); +// l3_frame.header_checksum = ip_checksum.finish().to_network_order(); +// +// let mut icmp_checksum = InetChecksum::new(); +// icmp_checksum.add_value(&l4_frame, true); +// icmp_checksum.add_bytes(&l4_data, true); +// l4_frame.checksum = icmp_checksum.finish().to_network_order(); +// +// let mut packet = vec![]; +// packet.extend_from_slice(bytemuck::bytes_of(&l2_frame)); +// packet.extend_from_slice(bytemuck::bytes_of(&l3_frame)); +// packet.extend_from_slice(bytemuck::bytes_of(&l4_frame)); +// packet.extend_from_slice(&l4_data); +// +// timer.start(timeout)?; +// socket.send(&packet)?; +// +// loop { +// let (fd, result) = poll.wait(None, true)?.unwrap(); +// result?; +// +// match fd { +// fd if fd == socket.as_raw_fd() => { +// // TODO +// let len = socket.recv(&mut buffer)?; +// if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq) +// { +// return Ok(true); +// } +// } +// fd if fd == timer.as_raw_fd() => { +// return Ok(false); +// } +// _ => unreachable!(), +// } +// } +// } +// +// fn ping( +// address: IpAddr, +// times: usize, +// data_len: usize, +// interval: Duration, +// timeout: Duration, +// ) -> Result { +// let routing = resolve_routing(address)?; +// +// let mut stats = PingStats { +// packets_sent: 0, +// packets_received: 0, +// }; +// let mut poll = PollChannel::new()?; +// let mut timer = TimerFd::new(false, false)?; +// let mut socket = RawSocket::bind(routing.interface_id)?; +// +// poll.add(timer.as_raw_fd())?; +// poll.add(socket.as_raw_fd())?; +// +// let id = rand::random(); +// for i in 0..times { +// if INTERRUPTED.load(Ordering::Acquire) { +// break; +// } +// +// let result = ping_once( +// &mut socket, +// &mut poll, +// &mut timer, +// &routing, +// timeout, +// data_len, +// id, +// i as u16, +// )?; +// stats.packets_sent += 1; +// +// if result { +// stats.packets_received += 1; +// println!("[{}/{}] {}: PONG", i + 1, times, address); +// } +// +// std::thread::sleep(interval); +// } +// +// Ok(stats) +// } +// +// static INTERRUPTED: AtomicBool = AtomicBool::new(false); +// +// fn main() -> ExitCode { +// // set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt)); +// +// let args = Args::parse(); +// +// let stats = match ping( +// args.address.into(), +// args.count, +// args.data_size, +// Duration::from_millis(args.inteval.into()), +// Duration::from_millis(args.timeout.into()), +// ) { +// Ok(stats) => stats, +// Err(error) => { +// eprintln!("ping: {}", error); +// return ExitCode::FAILURE; +// } +// }; +// +// let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent; +// println!( +// "{} sent, {} received, {}% loss", +// stats.packets_sent, stats.packets_received, loss +// ); +// +// ExitCode::SUCCESS +// }