net: stubs for more socket options

This commit is contained in:
Mark Poliakov 2024-11-01 01:33:18 +02:00
parent 0284456ddf
commit e43b7ee44b
3 changed files with 224 additions and 13 deletions

View File

@ -1,7 +1,7 @@
use core::{ use core::{
future::{poll_fn, Future}, future::{poll_fn, Future},
pin::Pin, pin::Pin,
sync::atomic::{AtomicBool, AtomicU32, Ordering}, sync::atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering},
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
}; };
@ -42,6 +42,8 @@ pub struct UdpSocket {
remote: Option<SocketAddr>, remote: Option<SocketAddr>,
broadcast: AtomicBool, broadcast: AtomicBool,
ttl: AtomicU8,
non_blocking: AtomicBool,
// TODO just place packets here for one less copy? // TODO just place packets here for one less copy?
receive_queue: BoundedMpmcQueue<(SocketAddr, Vec<u8>)>, receive_queue: BoundedMpmcQueue<(SocketAddr, Vec<u8>)>,
@ -50,6 +52,10 @@ pub struct UdpSocket {
pub struct TcpSocket { pub struct TcpSocket {
pub(crate) local: SocketAddr, pub(crate) local: SocketAddr,
pub(crate) remote: SocketAddr, pub(crate) remote: SocketAddr,
ttl: AtomicU8,
non_blocking: AtomicBool,
// Listener which accepted the socket // Listener which accepted the socket
listener: Option<Arc<TcpListener>>, listener: Option<Arc<TcpListener>>,
connection: IrqSafeRwLock<TcpConnection>, connection: IrqSafeRwLock<TcpConnection>,
@ -58,6 +64,8 @@ pub struct TcpSocket {
pub struct TcpListener { pub struct TcpListener {
accept: SocketAddr, accept: SocketAddr,
non_blocking: AtomicBool,
// Currently active sockets // Currently active sockets
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>, sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>,
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>, pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>,
@ -66,6 +74,7 @@ pub struct TcpListener {
pub struct RawSocket { pub struct RawSocket {
id: u32, id: u32,
non_blocking: AtomicBool,
bound: IrqSafeSpinlock<Option<u32>>, bound: IrqSafeSpinlock<Option<u32>>,
receive_queue: BoundedMpmcQueue<L2Packet>, receive_queue: BoundedMpmcQueue<L2Packet>,
} }
@ -208,6 +217,8 @@ impl UdpSocket {
log::debug!("UDP socket opened: {}", local); log::debug!("UDP socket opened: {}", local);
Arc::new(UdpSocket { Arc::new(UdpSocket {
local, local,
ttl: AtomicU8::new(64),
non_blocking: AtomicBool::new(false),
remote: None, remote: None,
broadcast: AtomicBool::new(false), broadcast: AtomicBool::new(false),
receive_queue: BoundedMpmcQueue::new(128), receive_queue: BoundedMpmcQueue::new(128),
@ -253,7 +264,7 @@ impl PacketSocket for UdpSocket {
return Err(Error::InvalidArgument); return Err(Error::InvalidArgument);
}; };
// TODO check that destnation family matches self family // TODO check that destnation family matches self family
match (self.broadcast.load(Ordering::Relaxed), destination.ip()) { match (self.broadcast.load(Ordering::Acquire), destination.ip()) {
// SendTo in broadcast? // SendTo in broadcast?
(true, _) => todo!(), (true, _) => todo!(),
(false, _) => { (false, _) => {
@ -262,7 +273,7 @@ impl PacketSocket for UdpSocket {
self.local.port(), self.local.port(),
destination.ip(), destination.ip(),
destination.port(), destination.port(),
64, self.ttl.load(Ordering::Acquire),
data, data,
) )
.await .await
@ -274,7 +285,14 @@ impl PacketSocket for UdpSocket {
} }
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> { fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
let (source, data) = block!(self.receive_queue.pop_front().await)?; let (source, data) = if self.non_blocking.load(Ordering::Acquire) {
self.receive_queue
.try_pop_front()
.ok_or(Error::WouldBlock)?
} else {
block!(self.receive_queue.pop_front().await)?
};
if data.len() > buffer.len() { if data.len() > buffer.len() {
// TODO check how other OSs handle this // TODO check how other OSs handle this
return Err(Error::BufferTooSmall); return Err(Error::BufferTooSmall);
@ -302,9 +320,40 @@ impl Socket for UdpSocket {
match option { match option {
&SocketOption::Broadcast(broadcast) => { &SocketOption::Broadcast(broadcast) => {
log::debug!("{} broadcast: {}", self.local, broadcast); log::debug!("{} broadcast: {}", self.local, broadcast);
self.broadcast.store(broadcast, Ordering::Relaxed); self.broadcast.store(broadcast, Ordering::Release);
Ok(()) Ok(())
} }
&SocketOption::Ttl(ttl) => {
if ttl == 0 || ttl > 255 {
return Err(Error::InvalidArgument);
}
self.ttl.store(ttl as _, Ordering::Release);
Ok(())
}
&SocketOption::NonBlocking(nb) => {
self.non_blocking.store(nb, Ordering::Release);
Ok(())
}
SocketOption::RecvTimeout(_timeout) => {
log::warn!("TODO: UDP recv timeout");
Err(Error::InvalidOperation)
}
SocketOption::SendTimeout(_timeout) => {
log::warn!("TODO: UDP send timeout");
Err(Error::InvalidOperation)
}
SocketOption::MulticastTtlV4(_) => {
log::warn!("TODO: UDP multicast v4 timeout");
Err(Error::InvalidOperation)
}
SocketOption::MulticastLoopV4(_) => {
log::warn!("TODO: UDP multicast loop v4");
Err(Error::InvalidOperation)
}
SocketOption::MulticastLoopV6(_) => {
log::warn!("TODO: UDP multicast loop v6");
Err(Error::InvalidOperation)
}
_ => Err(Error::InvalidOperation), _ => Err(Error::InvalidOperation),
} }
} }
@ -312,7 +361,35 @@ impl Socket for UdpSocket {
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> { fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option { match option {
SocketOption::Broadcast(broadcast) => { SocketOption::Broadcast(broadcast) => {
*broadcast = self.broadcast.load(Ordering::Relaxed); *broadcast = self.broadcast.load(Ordering::Acquire);
Ok(())
}
SocketOption::Ttl(ttl) => {
*ttl = self.ttl.load(Ordering::Acquire) as _;
Ok(())
}
SocketOption::NonBlocking(nb) => {
*nb = self.non_blocking.load(Ordering::Acquire);
Ok(())
}
SocketOption::RecvTimeout(timeout) => {
*timeout = None;
Ok(())
}
SocketOption::SendTimeout(timeout) => {
*timeout = None;
Ok(())
}
SocketOption::MulticastTtlV4(ttl) => {
*ttl = 64;
Ok(())
}
SocketOption::MulticastLoopV4(loop_v4) => {
*loop_v4 = false;
Ok(())
}
SocketOption::MulticastLoopV6(loop_v6) => {
*loop_v6 = false;
Ok(()) Ok(())
} }
_ => Err(Error::InvalidOperation), _ => Err(Error::InvalidOperation),
@ -325,6 +402,7 @@ impl RawSocket {
let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst); let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst);
let socket = Self { let socket = Self {
id, id,
non_blocking: AtomicBool::new(false),
bound: IrqSafeSpinlock::new(None), bound: IrqSafeSpinlock::new(None),
receive_queue: BoundedMpmcQueue::new(256), receive_queue: BoundedMpmcQueue::new(256),
}; };
@ -367,6 +445,10 @@ impl Socket for RawSocket {
*mac = interface.mac; *mac = interface.mac;
Ok(()) Ok(())
} }
SocketOption::NonBlocking(nb) => {
*nb = self.non_blocking.load(Ordering::Acquire);
Ok(())
}
_ => Err(Error::InvalidOperation), _ => Err(Error::InvalidOperation),
} }
} }
@ -393,6 +475,10 @@ impl Socket for RawSocket {
Ok(()) Ok(())
} }
SocketOption::UnbindInterface => todo!(), SocketOption::UnbindInterface => todo!(),
&SocketOption::NonBlocking(nb) => {
self.non_blocking.store(nb, Ordering::Release);
Ok(())
}
_ => Err(Error::InvalidOperation), _ => Err(Error::InvalidOperation),
} }
} }
@ -444,7 +530,13 @@ impl PacketSocket for RawSocket {
} }
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> { fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
let data = block!(self.receive_queue.pop_front().await)?; let data = if self.non_blocking.load(Ordering::Acquire) {
self.receive_queue
.try_pop_front()
.ok_or(Error::WouldBlock)?
} else {
block!(self.receive_queue.pop_front().await)?
};
let full_len = data.data.len(); let full_len = data.data.len();
let len = full_len - data.l2_offset; let len = full_len - data.l2_offset;
if buffer.len() < len { if buffer.len() < len {
@ -484,6 +576,8 @@ impl TcpSocket {
let socket = Self { let socket = Self {
local, local,
remote, remote,
ttl: AtomicU8::new(64),
non_blocking: AtomicBool::new(false),
listener: Some(listener), listener: Some(listener),
connection: IrqSafeRwLock::new(connection), connection: IrqSafeRwLock::new(connection),
}; };
@ -662,6 +756,8 @@ impl TcpSocket {
let socket = Self { let socket = Self {
local, local,
remote, remote,
ttl: AtomicU8::new(64),
non_blocking: AtomicBool::new(false),
listener: None, listener: None,
connection: IrqSafeRwLock::new(connection), connection: IrqSafeRwLock::new(connection),
}; };
@ -716,6 +812,61 @@ impl Socket for TcpSocket {
fn close(&self) -> Result<(), Error> { fn close(&self) -> Result<(), Error> {
block!(self.close_async(true).await)? block!(self.close_async(true).await)?
} }
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Ttl(ttl) => {
if ttl == 0 || ttl > 255 {
return Err(Error::InvalidArgument);
}
self.ttl.store(ttl as _, Ordering::Release);
Ok(())
}
&SocketOption::NonBlocking(nb) => {
self.non_blocking.store(nb, Ordering::Release);
Ok(())
}
SocketOption::RecvTimeout(_timeout) => {
log::warn!("TODO: TCP recv timeout");
Err(Error::InvalidOperation)
}
SocketOption::SendTimeout(_timeout) => {
log::warn!("TODO: TCP send timeout");
Err(Error::InvalidOperation)
}
SocketOption::NoDelay(_) => {
log::warn!("TODO: TCP nodelay");
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Ttl(ttl) => {
*ttl = self.ttl.load(Ordering::Acquire) as _;
Ok(())
}
SocketOption::NonBlocking(nb) => {
*nb = self.non_blocking.load(Ordering::Acquire);
Ok(())
}
SocketOption::RecvTimeout(timeout) => {
*timeout = None;
Ok(())
}
SocketOption::SendTimeout(timeout) => {
*timeout = None;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
*nodelay = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
} }
impl FileReadiness for TcpSocket { impl FileReadiness for TcpSocket {
@ -726,7 +877,16 @@ impl FileReadiness for TcpSocket {
impl ConnectionSocket for TcpSocket { impl ConnectionSocket for TcpSocket {
fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> { fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> {
block!(self.receive_async(buffer).await)? if self.non_blocking.load(Ordering::Acquire) {
match self.connection.write().read_nonblocking(buffer) {
// TODO check if this really means "no data at the moment"
Ok(0) => Err(Error::WouldBlock),
Ok(len) => Ok(len),
Err(error) => Err(error),
}
} else {
block!(self.receive_async(buffer).await)?
}
} }
fn send(&self, data: &[u8]) -> Result<usize, Error> { fn send(&self, data: &[u8]) -> Result<usize, Error> {
@ -739,6 +899,7 @@ impl TcpListener {
TCP_LISTENERS.write().try_insert_with(accept, || { TCP_LISTENERS.write().try_insert_with(accept, || {
let listener = TcpListener { let listener = TcpListener {
accept, accept,
non_blocking: AtomicBool::new(false),
sockets: IrqSafeRwLock::new(BTreeMap::new()), sockets: IrqSafeRwLock::new(BTreeMap::new()),
pending_accept: IrqSafeSpinlock::new(Vec::new()), pending_accept: IrqSafeSpinlock::new(Vec::new()),
accept_notify: QueueWaker::new(), accept_notify: QueueWaker::new(),
@ -810,6 +971,34 @@ impl Socket for TcpListener {
// TODO if clients not closed already, send RST? // TODO if clients not closed already, send RST?
TCP_LISTENERS.write().remove(self.accept) TCP_LISTENERS.write().remove(self.accept)
} }
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::NonBlocking(nb) => {
self.non_blocking.store(nb, Ordering::Release);
Ok(())
}
SocketOption::Ipv6Only(_v6only) => {
log::warn!("TODO: TCP listener IPv6-only");
Err(Error::InvalidOperation)
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::NonBlocking(nb) => {
*nb = self.non_blocking.load(Ordering::Acquire);
Ok(())
}
SocketOption::Ipv6Only(v6only) => {
*v6only = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
} }
impl FileReadiness for TcpListener { impl FileReadiness for TcpListener {
@ -820,7 +1009,11 @@ impl FileReadiness for TcpListener {
impl ListenerSocket for TcpListener { impl ListenerSocket for TcpListener {
fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> { fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = block!(self.accept_async().await)??; let socket = if self.non_blocking.load(Ordering::Acquire) {
todo!()
} else {
block!(self.accept_async().await)??
};
let remote = socket.remote; let remote = socket.remote;
Ok((remote, socket)) Ok((remote, socket))
} }

View File

@ -8,6 +8,8 @@ pub mod netconfig;
pub mod protocols; pub mod protocols;
pub mod types; pub mod types;
use core::time::Duration;
pub use crate::generated::SocketType; pub use crate::generated::SocketType;
pub use types::{ pub use types::{
@ -45,6 +47,23 @@ pub enum SocketOption<'a> {
UnbindInterface, UnbindInterface,
/// (Read-only) Hardware address of the bound interface /// (Read-only) Hardware address of the bound interface
BoundHardwareAddress(MacAddress), BoundHardwareAddress(MacAddress),
/// If set, reception will return [crate::error::Error::WouldBlock] if the socket has
/// no data in its queue/buffer.
NonBlocking(bool),
/// If set, the socket will be restricted to IPv6 only.
Ipv6Only(bool),
/// If not [None], receive operations will have a time limit set before returning an error.
RecvTimeout(Option<Duration>),
/// If not [None], send operations will have a time limit set before returning an error.
SendTimeout(Option<Duration>),
/// (UDP) If set, allows multicast packets to be looped back to local host.
MulticastLoopV4(bool),
/// (UDP) If set, allows multicast packets to be looped back to local host.
MulticastLoopV6(bool),
/// (UDP) Time-to-Live for IPv4 multicast packets.
MulticastTtlV4(u32),
/// (TCP) If set, disables any internal buffering for the socket.
NoDelay(bool),
} }
impl<'a> From<&'a str> for SocketInterfaceQuery<'a> { impl<'a> From<&'a str> for SocketInterfaceQuery<'a> {

View File

@ -4,7 +4,6 @@ use std::{
env, env,
io::{self, stdin, stdout, BufRead, Write}, io::{self, stdin, stdout, BufRead, Write},
os::{ os::{
fd::AsRawFd,
yggdrasil::io::{ yggdrasil::io::{
device::{DeviceRequest, FdDeviceRequest}, device::{DeviceRequest, FdDeviceRequest},
terminal::start_terminal_session, terminal::start_terminal_session,
@ -13,7 +12,7 @@ use std::{
process::{Command, ExitCode}, process::{Command, ExitCode},
}; };
fn login_readline<R: BufRead + AsRawFd>( fn login_readline<R: BufRead>(
reader: &mut R, reader: &mut R,
buf: &mut String, buf: &mut String,
_secret: bool, _secret: bool,
@ -29,7 +28,7 @@ fn login_as(username: &str, _password: &str) -> Result<(), io::Error> {
} }
fn login_attempt(erase: bool) -> Result<(), io::Error> { fn login_attempt(erase: bool) -> Result<(), io::Error> {
let mut stdin = stdin().lock(); let stdin = stdin();
let mut stdout = stdout(); let mut stdout = stdout();
if erase { if erase {
@ -41,7 +40,7 @@ fn login_attempt(erase: bool) -> Result<(), io::Error> {
print!("Username: "); print!("Username: ");
stdout.flush().ok(); stdout.flush().ok();
if login_readline(&mut stdin, &mut username, false)? == 0 { if login_readline(&mut stdin.lock(), &mut username, false)? == 0 {
return Ok(()); return Ok(());
} }