diff --git a/kernel/driver/net/core/src/socket.rs b/kernel/driver/net/core/src/socket.rs index d7196bc1..86a9e2f6 100644 --- a/kernel/driver/net/core/src/socket.rs +++ b/kernel/driver/net/core/src/socket.rs @@ -1,7 +1,7 @@ use core::{ future::{poll_fn, Future}, pin::Pin, - sync::atomic::{AtomicBool, AtomicU32, Ordering}, + sync::atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering}, task::{Context, Poll}, time::Duration, }; @@ -42,6 +42,8 @@ pub struct UdpSocket { remote: Option, broadcast: AtomicBool, + ttl: AtomicU8, + non_blocking: AtomicBool, // TODO just place packets here for one less copy? receive_queue: BoundedMpmcQueue<(SocketAddr, Vec)>, @@ -50,6 +52,10 @@ pub struct UdpSocket { pub struct TcpSocket { pub(crate) local: SocketAddr, pub(crate) remote: SocketAddr, + + ttl: AtomicU8, + non_blocking: AtomicBool, + // Listener which accepted the socket listener: Option>, connection: IrqSafeRwLock, @@ -58,6 +64,8 @@ pub struct TcpSocket { pub struct TcpListener { accept: SocketAddr, + non_blocking: AtomicBool, + // Currently active sockets sockets: IrqSafeRwLock>>, pending_accept: IrqSafeSpinlock>>, @@ -66,6 +74,7 @@ pub struct TcpListener { pub struct RawSocket { id: u32, + non_blocking: AtomicBool, bound: IrqSafeSpinlock>, receive_queue: BoundedMpmcQueue, } @@ -208,6 +217,8 @@ impl UdpSocket { log::debug!("UDP socket opened: {}", local); Arc::new(UdpSocket { local, + ttl: AtomicU8::new(64), + non_blocking: AtomicBool::new(false), remote: None, broadcast: AtomicBool::new(false), receive_queue: BoundedMpmcQueue::new(128), @@ -253,7 +264,7 @@ impl PacketSocket for UdpSocket { return Err(Error::InvalidArgument); }; // TODO check that destnation family matches self family - match (self.broadcast.load(Ordering::Relaxed), destination.ip()) { + match (self.broadcast.load(Ordering::Acquire), destination.ip()) { // SendTo in broadcast? (true, _) => todo!(), (false, _) => { @@ -262,7 +273,7 @@ impl PacketSocket for UdpSocket { self.local.port(), destination.ip(), destination.port(), - 64, + self.ttl.load(Ordering::Acquire), data, ) .await @@ -274,7 +285,14 @@ impl PacketSocket for UdpSocket { } fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> { - let (source, data) = block!(self.receive_queue.pop_front().await)?; + let (source, data) = if self.non_blocking.load(Ordering::Acquire) { + self.receive_queue + .try_pop_front() + .ok_or(Error::WouldBlock)? + } else { + block!(self.receive_queue.pop_front().await)? + }; + if data.len() > buffer.len() { // TODO check how other OSs handle this return Err(Error::BufferTooSmall); @@ -302,9 +320,40 @@ impl Socket for UdpSocket { match option { &SocketOption::Broadcast(broadcast) => { log::debug!("{} broadcast: {}", self.local, broadcast); - self.broadcast.store(broadcast, Ordering::Relaxed); + self.broadcast.store(broadcast, Ordering::Release); Ok(()) } + &SocketOption::Ttl(ttl) => { + if ttl == 0 || ttl > 255 { + return Err(Error::InvalidArgument); + } + self.ttl.store(ttl as _, Ordering::Release); + Ok(()) + } + &SocketOption::NonBlocking(nb) => { + self.non_blocking.store(nb, Ordering::Release); + Ok(()) + } + SocketOption::RecvTimeout(_timeout) => { + log::warn!("TODO: UDP recv timeout"); + Err(Error::InvalidOperation) + } + SocketOption::SendTimeout(_timeout) => { + log::warn!("TODO: UDP send timeout"); + Err(Error::InvalidOperation) + } + SocketOption::MulticastTtlV4(_) => { + log::warn!("TODO: UDP multicast v4 timeout"); + Err(Error::InvalidOperation) + } + SocketOption::MulticastLoopV4(_) => { + log::warn!("TODO: UDP multicast loop v4"); + Err(Error::InvalidOperation) + } + SocketOption::MulticastLoopV6(_) => { + log::warn!("TODO: UDP multicast loop v6"); + Err(Error::InvalidOperation) + } _ => Err(Error::InvalidOperation), } } @@ -312,7 +361,35 @@ impl Socket for UdpSocket { fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { match option { SocketOption::Broadcast(broadcast) => { - *broadcast = self.broadcast.load(Ordering::Relaxed); + *broadcast = self.broadcast.load(Ordering::Acquire); + Ok(()) + } + SocketOption::Ttl(ttl) => { + *ttl = self.ttl.load(Ordering::Acquire) as _; + Ok(()) + } + SocketOption::NonBlocking(nb) => { + *nb = self.non_blocking.load(Ordering::Acquire); + Ok(()) + } + SocketOption::RecvTimeout(timeout) => { + *timeout = None; + Ok(()) + } + SocketOption::SendTimeout(timeout) => { + *timeout = None; + Ok(()) + } + SocketOption::MulticastTtlV4(ttl) => { + *ttl = 64; + Ok(()) + } + SocketOption::MulticastLoopV4(loop_v4) => { + *loop_v4 = false; + Ok(()) + } + SocketOption::MulticastLoopV6(loop_v6) => { + *loop_v6 = false; Ok(()) } _ => Err(Error::InvalidOperation), @@ -325,6 +402,7 @@ impl RawSocket { let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst); let socket = Self { id, + non_blocking: AtomicBool::new(false), bound: IrqSafeSpinlock::new(None), receive_queue: BoundedMpmcQueue::new(256), }; @@ -367,6 +445,10 @@ impl Socket for RawSocket { *mac = interface.mac; Ok(()) } + SocketOption::NonBlocking(nb) => { + *nb = self.non_blocking.load(Ordering::Acquire); + Ok(()) + } _ => Err(Error::InvalidOperation), } } @@ -393,6 +475,10 @@ impl Socket for RawSocket { Ok(()) } SocketOption::UnbindInterface => todo!(), + &SocketOption::NonBlocking(nb) => { + self.non_blocking.store(nb, Ordering::Release); + Ok(()) + } _ => Err(Error::InvalidOperation), } } @@ -444,7 +530,13 @@ impl PacketSocket for RawSocket { } fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> { - let data = block!(self.receive_queue.pop_front().await)?; + let data = if self.non_blocking.load(Ordering::Acquire) { + self.receive_queue + .try_pop_front() + .ok_or(Error::WouldBlock)? + } else { + block!(self.receive_queue.pop_front().await)? + }; let full_len = data.data.len(); let len = full_len - data.l2_offset; if buffer.len() < len { @@ -484,6 +576,8 @@ impl TcpSocket { let socket = Self { local, remote, + ttl: AtomicU8::new(64), + non_blocking: AtomicBool::new(false), listener: Some(listener), connection: IrqSafeRwLock::new(connection), }; @@ -662,6 +756,8 @@ impl TcpSocket { let socket = Self { local, remote, + ttl: AtomicU8::new(64), + non_blocking: AtomicBool::new(false), listener: None, connection: IrqSafeRwLock::new(connection), }; @@ -716,6 +812,61 @@ impl Socket for TcpSocket { fn close(&self) -> Result<(), Error> { block!(self.close_async(true).await)? } + + fn set_option(&self, option: &SocketOption) -> Result<(), Error> { + match option { + &SocketOption::Ttl(ttl) => { + if ttl == 0 || ttl > 255 { + return Err(Error::InvalidArgument); + } + self.ttl.store(ttl as _, Ordering::Release); + Ok(()) + } + &SocketOption::NonBlocking(nb) => { + self.non_blocking.store(nb, Ordering::Release); + Ok(()) + } + SocketOption::RecvTimeout(_timeout) => { + log::warn!("TODO: TCP recv timeout"); + Err(Error::InvalidOperation) + } + SocketOption::SendTimeout(_timeout) => { + log::warn!("TODO: TCP send timeout"); + Err(Error::InvalidOperation) + } + SocketOption::NoDelay(_) => { + log::warn!("TODO: TCP nodelay"); + Ok(()) + } + _ => Err(Error::InvalidOperation), + } + } + + fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { + match option { + SocketOption::Ttl(ttl) => { + *ttl = self.ttl.load(Ordering::Acquire) as _; + Ok(()) + } + SocketOption::NonBlocking(nb) => { + *nb = self.non_blocking.load(Ordering::Acquire); + Ok(()) + } + SocketOption::RecvTimeout(timeout) => { + *timeout = None; + Ok(()) + } + SocketOption::SendTimeout(timeout) => { + *timeout = None; + Ok(()) + } + SocketOption::NoDelay(nodelay) => { + *nodelay = false; + Ok(()) + } + _ => Err(Error::InvalidOperation), + } + } } impl FileReadiness for TcpSocket { @@ -726,7 +877,16 @@ impl FileReadiness for TcpSocket { impl ConnectionSocket for TcpSocket { fn receive(&self, buffer: &mut [u8]) -> Result { - block!(self.receive_async(buffer).await)? + if self.non_blocking.load(Ordering::Acquire) { + match self.connection.write().read_nonblocking(buffer) { + // TODO check if this really means "no data at the moment" + Ok(0) => Err(Error::WouldBlock), + Ok(len) => Ok(len), + Err(error) => Err(error), + } + } else { + block!(self.receive_async(buffer).await)? + } } fn send(&self, data: &[u8]) -> Result { @@ -739,6 +899,7 @@ impl TcpListener { TCP_LISTENERS.write().try_insert_with(accept, || { let listener = TcpListener { accept, + non_blocking: AtomicBool::new(false), sockets: IrqSafeRwLock::new(BTreeMap::new()), pending_accept: IrqSafeSpinlock::new(Vec::new()), accept_notify: QueueWaker::new(), @@ -810,6 +971,34 @@ impl Socket for TcpListener { // TODO if clients not closed already, send RST? TCP_LISTENERS.write().remove(self.accept) } + + fn set_option(&self, option: &SocketOption) -> Result<(), Error> { + match option { + &SocketOption::NonBlocking(nb) => { + self.non_blocking.store(nb, Ordering::Release); + Ok(()) + } + SocketOption::Ipv6Only(_v6only) => { + log::warn!("TODO: TCP listener IPv6-only"); + Err(Error::InvalidOperation) + } + _ => Err(Error::InvalidOperation), + } + } + + fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { + match option { + SocketOption::NonBlocking(nb) => { + *nb = self.non_blocking.load(Ordering::Acquire); + Ok(()) + } + SocketOption::Ipv6Only(v6only) => { + *v6only = false; + Ok(()) + } + _ => Err(Error::InvalidOperation), + } + } } impl FileReadiness for TcpListener { @@ -820,7 +1009,11 @@ impl FileReadiness for TcpListener { impl ListenerSocket for TcpListener { fn accept(&self) -> Result<(SocketAddr, Arc), Error> { - let socket = block!(self.accept_async().await)??; + let socket = if self.non_blocking.load(Ordering::Acquire) { + todo!() + } else { + block!(self.accept_async().await)?? + }; let remote = socket.remote; Ok((remote, socket)) } diff --git a/lib/abi/src/net/mod.rs b/lib/abi/src/net/mod.rs index 73b5094c..866f9279 100644 --- a/lib/abi/src/net/mod.rs +++ b/lib/abi/src/net/mod.rs @@ -8,6 +8,8 @@ pub mod netconfig; pub mod protocols; pub mod types; +use core::time::Duration; + pub use crate::generated::SocketType; pub use types::{ @@ -45,6 +47,23 @@ pub enum SocketOption<'a> { UnbindInterface, /// (Read-only) Hardware address of the bound interface BoundHardwareAddress(MacAddress), + /// If set, reception will return [crate::error::Error::WouldBlock] if the socket has + /// no data in its queue/buffer. + NonBlocking(bool), + /// If set, the socket will be restricted to IPv6 only. + Ipv6Only(bool), + /// If not [None], receive operations will have a time limit set before returning an error. + RecvTimeout(Option), + /// If not [None], send operations will have a time limit set before returning an error. + SendTimeout(Option), + /// (UDP) If set, allows multicast packets to be looped back to local host. + MulticastLoopV4(bool), + /// (UDP) If set, allows multicast packets to be looped back to local host. + MulticastLoopV6(bool), + /// (UDP) Time-to-Live for IPv4 multicast packets. + MulticastTtlV4(u32), + /// (TCP) If set, disables any internal buffering for the socket. + NoDelay(bool), } impl<'a> From<&'a str> for SocketInterfaceQuery<'a> { diff --git a/userspace/sysutils/src/login.rs b/userspace/sysutils/src/login.rs index c22f4c30..0b583aab 100644 --- a/userspace/sysutils/src/login.rs +++ b/userspace/sysutils/src/login.rs @@ -4,7 +4,6 @@ use std::{ env, io::{self, stdin, stdout, BufRead, Write}, os::{ - fd::AsRawFd, yggdrasil::io::{ device::{DeviceRequest, FdDeviceRequest}, terminal::start_terminal_session, @@ -13,7 +12,7 @@ use std::{ process::{Command, ExitCode}, }; -fn login_readline( +fn login_readline( reader: &mut R, buf: &mut String, _secret: bool, @@ -29,7 +28,7 @@ fn login_as(username: &str, _password: &str) -> Result<(), io::Error> { } fn login_attempt(erase: bool) -> Result<(), io::Error> { - let mut stdin = stdin().lock(); + let stdin = stdin(); let mut stdout = stdout(); if erase { @@ -41,7 +40,7 @@ fn login_attempt(erase: bool) -> Result<(), io::Error> { print!("Username: "); stdout.flush().ok(); - if login_readline(&mut stdin, &mut username, false)? == 0 { + if login_readline(&mut stdin.lock(), &mut username, false)? == 0 { return Ok(()); }