abi: rework socket option ABI

This commit is contained in:
Mark Poliakov 2025-01-09 19:35:58 +02:00
parent dcf3658bd1
commit ab71cac6fa
37 changed files with 1589 additions and 693 deletions

95
Cargo.lock generated
View File

@ -29,6 +29,14 @@ dependencies = [
"rustc-std-workspace-core",
]
[[package]]
name = "abi-serde"
version = "0.1.0"
dependencies = [
"compiler_builtins",
"rustc-std-workspace-core",
]
[[package]]
name = "accessor"
version = "0.3.3"
@ -186,6 +194,15 @@ dependencies = [
"syn 2.0.87",
]
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
dependencies = [
"critical-section",
]
[[package]]
name = "atomic_enum"
version = "0.3.0"
@ -376,6 +393,12 @@ version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
[[package]]
name = "cobs"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15"
[[package]]
name = "colorchoice"
version = "1.0.3"
@ -410,6 +433,12 @@ version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "critical-section"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
[[package]]
name = "crossbeam-queue"
version = "0.3.11"
@ -616,7 +645,7 @@ dependencies = [
"memoffset 0.5.6",
"num-derive",
"num-traits",
"rustc_version",
"rustc_version 0.2.3",
"static_assertions",
"unsafe_unwrap",
]
@ -763,6 +792,15 @@ dependencies = [
"url",
]
[[package]]
name = "hash32"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
dependencies = [
"byteorder",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
@ -780,6 +818,20 @@ dependencies = [
"foldhash",
]
[[package]]
name = "heapless"
version = "0.7.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
dependencies = [
"atomic-polyfill",
"hash32",
"rustc_version 0.4.1",
"serde",
"spin",
"stable_deref_trait",
]
[[package]]
name = "heck"
version = "0.5.0"
@ -1037,6 +1089,7 @@ name = "libk"
version = "0.1.0"
dependencies = [
"abi-lib",
"abi-serde",
"async-trait",
"atomic_enum",
"bitflags 2.6.0",
@ -1052,6 +1105,7 @@ dependencies = [
"libk-util",
"log",
"lru",
"postcard",
"serde",
"serde_json",
"static_assertions",
@ -1371,6 +1425,17 @@ version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6"
[[package]]
name = "postcard"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "170a2601f67cc9dba8edd8c4870b15f71a6a2dc196daec8c83f72b59dff628a8"
dependencies = [
"cobs",
"heapless",
"serde",
]
[[package]]
name = "ppv-lite86"
version = "0.2.20"
@ -1571,6 +1636,15 @@ dependencies = [
"semver 0.9.0",
]
[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
"semver 1.0.23",
]
[[package]]
name = "rustix"
version = "0.38.38"
@ -1703,6 +1777,15 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "spinning_top"
version = "0.2.5"
@ -1712,6 +1795,12 @@ dependencies = [
"lock_api",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
@ -2395,6 +2484,7 @@ dependencies = [
"libk-mm",
"libk-util",
"log",
"postcard",
"serde",
"serde_json",
"yggdrasil-abi",
@ -2531,6 +2621,7 @@ version = "0.1.0"
dependencies = [
"abi-generator",
"abi-lib",
"abi-serde",
"bytemuck",
"compiler_builtins",
"prettyplease",
@ -2546,6 +2637,7 @@ dependencies = [
"aarch64-cpu",
"abi-generator",
"abi-lib",
"abi-serde",
"acpi",
"acpi-system",
"aml",
@ -2597,6 +2689,7 @@ version = "0.1.0"
dependencies = [
"abi-generator",
"abi-lib",
"abi-serde",
"cc",
"compiler_builtins",
"libm 0.2.8",

View File

@ -17,7 +17,7 @@ members = [
"lib/libyalloc",
"lib/runtime",
"lib/qemu"
]
, "lib/abi-serde"]
[workspace.dependencies]
chrono = { version = "0.4.38", default-features = false, features = ["alloc"] }
@ -46,6 +46,7 @@ yboot-proto.path = "boot/yboot-proto"
# Local libs
abi-lib.path = "lib/abi-lib"
abi-serde.path = "lib/abi-serde"
yggdrasil-abi.path = "lib/abi"
abi-generator.path = "tool/abi-generator"

View File

@ -7,6 +7,7 @@ authors = ["Mark Poliakov <mark@alnyan.me>"]
[dependencies]
abi-lib.workspace = true
abi-serde.workspace = true
yggdrasil-abi.workspace = true
kernel-arch-interface.workspace = true
libk.workspace = true

View File

@ -16,3 +16,4 @@ log.workspace = true
bytemuck.workspace = true
serde_json.workspace = true
serde.workspace = true
postcard = "1.1.1"

View File

@ -14,7 +14,10 @@ use libk::{
};
use libk_mm::PageBox;
use libk_util::{queue::BoundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock};
use yggdrasil_abi::net::{SocketAddr, SocketInterfaceQuery, SocketOption};
use yggdrasil_abi::net::{
options::{self, RawSocketOptionVariant, SocketOption},
SocketAddr, SocketInterfaceQuery,
};
use crate::{ethernet::L2Packet, interface::NetworkInterface};
@ -82,28 +85,35 @@ impl Socket for RawSocket {
Err(Error::InvalidOperation)
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
fn get_option(&self, option: u32, buffer: &mut [u8]) -> Result<usize, Error> {
let option = RawSocketOptionVariant::try_from(option)?;
match option {
SocketOption::BoundHardwareAddress(mac) => {
RawSocketOptionVariant::BoundInterface => {
let bound = self.bound.read().ok_or(Error::NotConnected)?;
let interface = NetworkInterface::get(bound).unwrap();
*mac = interface.mac;
Ok(())
options::BoundInterface::store(&SocketInterfaceQuery::ById(bound), buffer)
}
RawSocketOptionVariant::UnbindInterface => Err(Error::InvalidArgument),
RawSocketOptionVariant::BoundHardwareAddress => {
let bound = self.bound.read().ok_or(Error::NotConnected)?;
let interface = NetworkInterface::get(bound)?;
options::BoundHardwareAddress::store(&interface.mac, buffer)
}
_ => Err(Error::InvalidOperation),
}
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
fn set_option(&self, option: u32, buffer: &[u8]) -> Result<(), Error> {
let option = RawSocketOptionVariant::try_from(option)?;
match option {
SocketOption::BindInterface(query) => {
RawSocketOptionVariant::BoundInterface => {
let query = options::BoundInterface::load(buffer)?;
log::info!("raw: bind interface {query:?}");
let mut bound = self.bound.write();
if bound.is_some() {
return Err(Error::AlreadyExists);
}
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let interface = match *query {
let interface = match query {
SocketInterfaceQuery::ById(id) => NetworkInterface::get(id),
SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name),
}?;
@ -113,8 +123,11 @@ impl Socket for RawSocket {
Ok(())
}
SocketOption::UnbindInterface => todo!(),
_ => Err(Error::InvalidOperation),
RawSocketOptionVariant::UnbindInterface => {
log::warn!("TODO: raw socket interface unbind");
Err(Error::NotImplemented)
}
RawSocketOptionVariant::BoundHardwareAddress => Err(Error::ReadOnly),
}
}

View File

@ -14,7 +14,10 @@ use libk::{
vfs::{ConnectionSocket, FileReadiness, Socket},
};
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::net::{SocketAddr, SocketOption};
use yggdrasil_abi::net::{
options::{self, SocketOption, TcpSocketOptionVariant},
SocketAddr,
};
mod listener;
mod stream;
@ -34,15 +37,14 @@ pub enum TcpSocketInner {
}
struct TcpSocketOptions {
ttl: u8,
nodelay: bool,
v6_only: bool,
}
#[allow(clippy::derivable_impls)]
impl Default for TcpSocketOptions {
fn default() -> Self {
Self {
ttl: 64,
nodelay: false,
v6_only: false,
}
@ -115,43 +117,27 @@ impl Socket for TcpSocket {
}
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match *option {
SocketOption::Ttl(ttl) => {
if !(1..256).contains(&ttl) {
return Err(Error::InvalidArgument);
}
let ttl = ttl as u8;
self.options.write().ttl = ttl;
fn set_option(&self, option: u32, buffer: &[u8]) -> Result<(), Error> {
let option = TcpSocketOptionVariant::try_from(option)?;
let mut options = self.options.write();
match option {
TcpSocketOptionVariant::NoDelay => {
options.nodelay = options::NoDelay::load(buffer)?;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
self.options.write().nodelay = nodelay;
TcpSocketOptionVariant::Ipv6Only => {
options.v6_only = options::Ipv6Only::load(buffer)?;
Ok(())
}
SocketOption::Ipv6Only(v6_only) => {
self.options.write().v6_only = v6_only;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
fn get_option(&self, option: u32, buffer: &mut [u8]) -> Result<usize, Error> {
let option = TcpSocketOptionVariant::try_from(option)?;
let options = self.options.read();
match option {
SocketOption::Ttl(ttl) => {
*ttl = self.options.read().ttl as u32;
Ok(())
}
SocketOption::NoDelay(nodelay) => {
*nodelay = self.options.read().nodelay;
Ok(())
}
SocketOption::Ipv6Only(v6_only) => {
*v6_only = self.options.read().v6_only;
Ok(())
}
_ => Err(Error::InvalidOperation),
TcpSocketOptionVariant::NoDelay => options::NoDelay::store(&options.nodelay, buffer),
TcpSocketOptionVariant::Ipv6Only => options::Ipv6Only::store(&options.v6_only, buffer),
}
}
}

View File

@ -1,6 +1,5 @@
use core::{
fmt,
sync::atomic::{AtomicBool, AtomicU8, Ordering},
task::{Context, Poll},
time::Duration,
};
@ -17,7 +16,10 @@ use libk_util::{
queue::BoundedMpmcQueue,
sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard},
};
use yggdrasil_abi::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketOption};
use yggdrasil_abi::net::{
options::{self, SocketOption, UdpSocketOptionVariant},
IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr,
};
use crate::l4;
@ -27,23 +29,35 @@ pub struct UdpSocket {
local: IrqSafeRwLock<Option<SocketAddr>>,
remote: IrqSafeRwLock<Option<SocketAddr>>,
broadcast: AtomicBool,
ttl: AtomicU8,
options: IrqSafeRwLock<Options>,
// TODO just place packets here for one less copy?
receive_queue: BoundedMpmcQueue<(SocketAddr, Vec<u8>)>,
}
struct Options {
broadcast: bool,
ttl: u8,
}
static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new());
impl Default for Options {
fn default() -> Self {
Self {
broadcast: false,
ttl: 64,
}
}
}
impl UdpSocket {
pub fn new() -> Arc<Self> {
Arc::new(Self {
local: IrqSafeRwLock::new(None),
remote: IrqSafeRwLock::new(None),
broadcast: AtomicBool::new(false),
ttl: AtomicU8::new(64),
options: IrqSafeRwLock::new(Options::default()),
receive_queue: BoundedMpmcQueue::new(128),
})
@ -109,8 +123,10 @@ impl PacketSocket for UdpSocket {
// If socket wasn't bound yet, bind it to an ephemeral port
let port = self.ensure_address(destination.ip().is_ipv6())?;
let options = self.options.read();
// TODO check that destnation family matches self family
match (self.broadcast.load(Ordering::Acquire), destination.ip()) {
match (options.broadcast, destination.ip()) {
// TODO broadcast
(true, _) => return Err(Error::NotImplemented),
(false, _) => {
@ -118,7 +134,7 @@ impl PacketSocket for UdpSocket {
port,
destination.ip(),
destination.port(),
self.ttl.load(Ordering::Acquire),
options.ttl,
data,
)
.await?;
@ -205,59 +221,48 @@ impl Socket for UdpSocket {
Ok(())
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
fn set_option(&self, option: u32, buffer: &[u8]) -> Result<(), Error> {
let option = UdpSocketOptionVariant::try_from(option)?;
let mut options = self.options.write();
match option {
&SocketOption::Broadcast(broadcast) => {
// log::debug!("{} broadcast: {}", self.local, broadcast);
self.broadcast.store(broadcast, Ordering::Release);
Ok(())
}
&SocketOption::Ttl(ttl) => {
if ttl == 0 || ttl > 255 {
UdpSocketOptionVariant::Ttl => {
let ttl = options::Ttl::load(buffer)?;
if !(1..256).contains(&ttl) {
return Err(Error::InvalidArgument);
}
self.ttl.store(ttl as _, Ordering::Release);
options.ttl = ttl as u8;
Ok(())
}
SocketOption::MulticastTtlV4(_) => {
log::warn!("TODO: UDP multicast v4 timeout");
Err(Error::InvalidOperation)
UdpSocketOptionVariant::Broadcast => {
options.broadcast = options::Broadcast::load(buffer)?;
Ok(())
}
SocketOption::MulticastLoopV4(_) => {
log::warn!("TODO: UDP multicast loop v4");
Err(Error::InvalidOperation)
UdpSocketOptionVariant::MulticastTtlV4
| UdpSocketOptionVariant::MulticastLoopV4
| UdpSocketOptionVariant::MulticastLoopV6 => {
log::warn!("TODO: udp multicast not yet implemented");
Err(Error::InvalidArgument)
}
SocketOption::MulticastLoopV6(_) => {
log::warn!("TODO: UDP multicast loop v6");
Err(Error::InvalidOperation)
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
fn get_option(&self, option: u32, buffer: &mut [u8]) -> Result<usize, Error> {
let option = UdpSocketOptionVariant::try_from(option)?;
let options = self.options.read();
match option {
SocketOption::Broadcast(broadcast) => {
*broadcast = self.broadcast.load(Ordering::Acquire);
Ok(())
UdpSocketOptionVariant::Ttl => {
let ttl = options.ttl as u32;
options::Ttl::store(&ttl, buffer)
}
SocketOption::Ttl(ttl) => {
*ttl = self.ttl.load(Ordering::Acquire) as _;
Ok(())
UdpSocketOptionVariant::Broadcast => {
options::Broadcast::store(&options.broadcast, buffer)
}
SocketOption::MulticastTtlV4(ttl) => {
*ttl = 64;
Ok(())
UdpSocketOptionVariant::MulticastLoopV6
| UdpSocketOptionVariant::MulticastLoopV4
| UdpSocketOptionVariant::MulticastTtlV4 => {
log::warn!("TODO: udp multicast not yet implemented");
Err(Error::InvalidArgument)
}
SocketOption::MulticastLoopV4(loop_v4) => {
*loop_v4 = false;
Ok(())
}
SocketOption::MulticastLoopV6(loop_v6) => {
*loop_v6 = false;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}

View File

@ -13,6 +13,7 @@ crate-type = ["rlib", "dylib"]
libk-mm.workspace = true
libk-util.workspace = true
kernel-arch.workspace = true
abi-serde.workspace = true
abi-lib.workspace = true
yggdrasil-abi = { workspace = true, features = ["alloc", "serde"] }
device-api = { workspace = true, features = ["derive"] }
@ -33,6 +34,7 @@ elf.workspace = true
uuid = { version = "1.10.0", features = ["bytemuck"], default-features = false }
lru = "0.12.3"
postcard = "1.1.1"
[dev-dependencies]
tokio = { workspace = true, features = ["rt", "macros"] }

View File

@ -9,7 +9,10 @@ use async_trait::async_trait;
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::{
error::Error,
net::{SocketAddr, SocketOption, SocketShutdown},
net::{
options::{self, SocketOption, SocketOptionVariant},
SocketAddr, SocketShutdown,
},
};
use crate::vfs::FileReadiness;
@ -18,7 +21,6 @@ use super::{File, FileRef};
enum SocketInner {
Connection(Arc<dyn ConnectionSocket + 'static>),
// Listener(Arc<dyn ListenerSocket + Send + 'static>),
Packet(Arc<dyn PacketSocket + 'static>),
}
@ -49,12 +51,12 @@ pub trait Socket: FileReadiness + fmt::Debug + Send {
fn close(self: Arc<Self>) -> Result<(), Error>;
/// Updates a socket option
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
fn set_option(&self, option: u32, buffer: &[u8]) -> Result<(), Error> {
Err(Error::InvalidOperation)
}
/// Gets a socket option
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
fn get_option(&self, option: u32, buffer: &mut [u8]) -> Result<usize, Error> {
Err(Error::InvalidOperation)
}
}
@ -137,6 +139,13 @@ impl SocketWrapper {
}
}
pub fn as_inner(&self) -> &dyn Socket {
match &self.inner {
SocketInner::Packet(socket) => socket.as_ref(),
SocketInner::Connection(socket) => socket.as_ref(),
}
}
async fn send_inner(&self, data: &[u8], remote: Option<SocketAddr>) -> Result<usize, Error> {
let timeout = self.options.read().send_timeout;
@ -233,57 +242,84 @@ impl SocketWrapper {
)?
}
pub fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::NonBlocking(nb) => {
self.options.write().non_blocking = *nb;
return Ok(());
}
SocketOption::RecvTimeout(timeout) => {
self.options.write().recv_timeout = *timeout;
return Ok(());
}
SocketOption::SendTimeout(timeout) => {
self.options.write().send_timeout = *timeout;
return Ok(());
}
SocketOption::ConnectTimeout(timeout) => {
self.options.write().connect_timeout = *timeout;
return Ok(());
}
_ => (),
}
pub fn local_address(&self) -> Option<SocketAddr> {
match &self.inner {
SocketInner::Connection(socket) => socket.set_option(option),
SocketInner::Packet(socket) => socket.set_option(option),
SocketInner::Packet(socket) => socket.local_address(),
SocketInner::Connection(socket) => socket.local_address(),
}
}
pub fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::NonBlocking(nb) => {
*nb = self.options.read().non_blocking;
return Ok(());
}
SocketOption::RecvTimeout(timeout) => {
*timeout = self.options.read().recv_timeout;
return Ok(());
}
SocketOption::SendTimeout(timeout) => {
*timeout = self.options.read().send_timeout;
return Ok(());
}
SocketOption::ConnectTimeout(timeout) => {
*timeout = self.options.read().connect_timeout;
return Ok(());
}
_ => (),
pub fn remote_address(&self) -> Option<SocketAddr> {
match &self.inner {
SocketInner::Packet(socket) => socket.remote_address(),
SocketInner::Connection(socket) => socket.remote_address(),
}
}
pub fn set_option(&self, option: u32, buffer: &[u8]) -> Result<(), Error> {
if let Ok(option) = SocketOptionVariant::try_from(option) {
let mut options = self.options.write();
return match option {
SocketOptionVariant::RecvTimeout => {
options.recv_timeout = options::RecvTimeout::load(buffer)?;
Ok(())
}
SocketOptionVariant::SendTimeout => {
options.send_timeout = options::SendTimeout::load(buffer)?;
Ok(())
}
SocketOptionVariant::ConnectTimeout => {
options.connect_timeout = options::ConnectTimeout::load(buffer)?;
Ok(())
}
SocketOptionVariant::NonBlocking => {
options.non_blocking = options::NonBlocking::load(buffer)?;
Ok(())
}
SocketOptionVariant::LocalAddress | SocketOptionVariant::PeerAddress => {
Err(Error::InvalidArgument)
}
};
}
// Not a socket-level option, pass to the specific protocol level
match &self.inner {
SocketInner::Connection(socket) => socket.get_option(option),
SocketInner::Packet(socket) => socket.get_option(option),
SocketInner::Connection(socket) => socket.set_option(option, buffer),
SocketInner::Packet(socket) => socket.set_option(option, buffer),
}
}
pub fn get_option(&self, option: u32, buffer: &mut [u8]) -> Result<usize, Error> {
if let Ok(option) = SocketOptionVariant::try_from(option) {
let options = self.options.read();
return match option {
SocketOptionVariant::LocalAddress => {
let address = self.local_address().map(Into::into);
options::LocalAddress::store(&address, buffer)
}
SocketOptionVariant::PeerAddress => {
let address = self.local_address().map(Into::into);
options::PeerAddress::store(&address, buffer)
}
SocketOptionVariant::RecvTimeout => {
options::RecvTimeout::store(&options.recv_timeout, buffer)
}
SocketOptionVariant::SendTimeout => {
options::SendTimeout::store(&options.send_timeout, buffer)
}
SocketOptionVariant::ConnectTimeout => {
options::ConnectTimeout::store(&options.connect_timeout, buffer)
}
SocketOptionVariant::NonBlocking => {
options::NonBlocking::store(&options.non_blocking, buffer)
}
};
}
// Not a socket-level option, pass to the specific protocol level
match &self.inner {
SocketInner::Connection(socket) => socket.get_option(option, buffer),
SocketInner::Packet(socket) => socket.get_option(option, buffer),
}
}
}

View File

@ -3,7 +3,7 @@ use core::{mem::MaybeUninit, net::SocketAddr};
use abi::{
error::Error,
io::RawFd,
net::{SocketOption, SocketShutdown, SocketType},
net::{SocketShutdown, SocketType},
};
use libk::{
task::thread::Thread,
@ -94,16 +94,28 @@ pub(crate) fn receive_from(
Ok(len)
}
pub(crate) fn get_socket_option(sock_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> {
pub(crate) fn get_socket_option(
sock_fd: RawFd,
option: u32,
value: &mut [u8],
) -> Result<usize, Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.get_option(option)
file.as_socket()?.get_option(option, value)
}
pub(crate) fn set_socket_option(sock_fd: RawFd, option: &SocketOption) -> Result<(), Error> {
pub(crate) fn set_socket_option(sock_fd: RawFd, option: u32, value: &[u8]) -> Result<(), Error> {
let file = get_socket(sock_fd)?;
file.as_socket()?.set_option(option)
file.as_socket()?.set_option(option, value)
}
// pub(crate) fn get_socket_option(sock_fd: RawFd, option: &mut SocketOption) -> Result<(), Error> {
// file.as_socket()?.get_option(option)
// }
//
// pub(crate) fn set_socket_option(sock_fd: RawFd, option: &SocketOption) -> Result<(), Error> {
// file.as_socket()?.set_option(option)
// }
// // Network
// pub(crate) fn connect_socket(
// connect: &mut SocketConnect,

18
lib/abi-serde/Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "abi-serde"
version = "0.1.0"
edition = "2021"
[dependencies]
compiler_builtins = { version = "0.1", optional = true }
core = { version = "1.0.0", optional = true, package = "rustc-std-workspace-core" }
[features]
default = []
rustc-dep-of-std = [
"core",
"compiler_builtins/rustc-dep-of-std",
]
[lints]
workspace = true

27
lib/abi-serde/src/des.rs Normal file
View File

@ -0,0 +1,27 @@
pub trait Deserializer<'de> {
type Error: DeserializeError;
fn read_bool(&mut self) -> Result<bool, Self::Error>;
fn read_str(&mut self) -> Result<&'de str, Self::Error>;
fn read_bytes(&mut self) -> Result<&'de [u8], Self::Error>;
fn read_i8(&mut self) -> Result<i8, Self::Error>;
fn read_i16(&mut self) -> Result<i16, Self::Error>;
fn read_i32(&mut self) -> Result<i32, Self::Error>;
fn read_i64(&mut self) -> Result<i64, Self::Error>;
fn read_u8(&mut self) -> Result<u8, Self::Error>;
fn read_u16(&mut self) -> Result<u16, Self::Error>;
fn read_u32(&mut self) -> Result<u32, Self::Error>;
fn read_u64(&mut self) -> Result<u64, Self::Error>;
fn read_enum_variant(&mut self) -> Result<u32, Self::Error>;
}
pub trait Deserialize<'de>: Sized + 'de {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error>;
}
pub trait DeserializeError {
const INVALID_ENUM_VARIANT: Self;
const INVALID_ARRAY_LEN: Self;
}

View File

@ -0,0 +1,99 @@
use crate::{
des::{DeserializeError, Deserializer},
ser::Serializer,
Deserialize, Serialize,
};
macro impl_primitive_serde($($ty:ty : [$read:ident, $write:ident]),+) {
$(
impl<'de> $crate::des::Deserialize<'de> for $ty {
fn deserialize<D: $crate::des::Deserializer<'de>>(deserializer: &mut D) -> Result<$ty, D::Error> {
deserializer.$read()
}
}
impl $crate::ser::Serialize for $ty {
fn serialize<S: $crate::ser::Serializer>(
&self,
serializer: &mut S,
) -> Result<(), S::Error> {
serializer.$write(*self)
}
}
)+
}
impl_primitive_serde!(
bool: [read_bool, write_bool],
i8: [read_i8, write_i8],
i16: [read_i16, write_i16],
i32: [read_i32, write_i32],
i64: [read_i64, write_i64],
u8: [read_u8, write_u8],
u16: [read_u16, write_u16],
u32: [read_u32, write_u32],
u64: [read_u64, write_u64]
);
impl Serialize for () {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
let _ = serializer;
Ok(())
}
}
impl<'de> Deserialize<'de> for () {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
let _ = deserializer;
Ok(())
}
}
impl<T: Serialize> Serialize for Option<T> {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
match self {
None => serializer.write_enum_variant(0),
Some(value) => {
serializer.write_enum_variant(1)?;
value.serialize(serializer)
}
}
}
}
impl<'de, T: Deserialize<'de>> Deserialize<'de> for Option<T> {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
let variant = deserializer.read_enum_variant()?;
match variant {
0 => T::deserialize(deserializer).map(Some),
1 => Ok(None),
_ => Err(D::Error::INVALID_ENUM_VARIANT),
}
}
}
impl<T: Serialize, E: Serialize> Serialize for Result<T, E> {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
match self {
Ok(value) => {
serializer.write_enum_variant(0)?;
value.serialize(serializer)
}
Err(error) => {
serializer.write_enum_variant(1)?;
error.serialize(serializer)
}
}
}
}
impl<'de, T: Deserialize<'de>, E: Deserialize<'de>> Deserialize<'de> for Result<T, E> {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
let variant = deserializer.read_enum_variant()?;
match variant {
0 => T::deserialize(deserializer).map(Ok),
1 => E::deserialize(deserializer).map(Err),
_ => Err(D::Error::INVALID_ENUM_VARIANT),
}
}
}

View File

@ -0,0 +1,3 @@
mod base;
mod net;
mod time;

View File

@ -0,0 +1,121 @@
use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use crate::{
des::{DeserializeError, Deserializer},
ser::Serializer,
Deserialize, Serialize,
};
impl<'de> Deserialize<'de> for Ipv4Addr {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
deserializer.read_u32().map(Self::from_bits)
}
}
impl<'de> Deserialize<'de> for SocketAddrV4 {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
let ip = Ipv4Addr::deserialize(deserializer)?;
let port = u16::deserialize(deserializer)?;
Ok(Self::new(ip, port))
}
}
impl Serialize for Ipv4Addr {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
serializer.write_u32(self.to_bits())
}
}
impl Serialize for SocketAddrV4 {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
self.ip().serialize(serializer)?;
self.port().serialize(serializer)?;
Ok(())
}
}
impl<'de> Deserialize<'de> for Ipv6Addr {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
deserializer
.read_bytes()?
.try_into()
.map(Self::from_octets)
.map_err(|_| D::Error::INVALID_ARRAY_LEN)
}
}
impl<'de> Deserialize<'de> for SocketAddrV6 {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
let ip = Ipv6Addr::deserialize(deserializer)?;
let port = u16::deserialize(deserializer)?;
let flowinfo = u32::deserialize(deserializer)?;
let scope_id = u32::deserialize(deserializer)?;
Ok(Self::new(ip, port, flowinfo, scope_id))
}
}
impl Serialize for Ipv6Addr {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
serializer.write_bytes(&self.octets())
}
}
impl Serialize for SocketAddrV6 {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
self.ip().serialize(serializer)?;
self.port().serialize(serializer)?;
self.flowinfo().serialize(serializer)?;
self.scope_id().serialize(serializer)?;
Ok(())
}
}
impl<'de> Deserialize<'de> for IpAddr {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
match deserializer.read_enum_variant()? {
4 => Ipv4Addr::deserialize(deserializer).map(Self::V4),
6 => Ipv6Addr::deserialize(deserializer).map(Self::V6),
_ => Err(D::Error::INVALID_ENUM_VARIANT),
}
}
}
impl<'de> Deserialize<'de> for SocketAddr {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
match deserializer.read_enum_variant()? {
4 => SocketAddrV4::deserialize(deserializer).map(Self::V4),
6 => SocketAddrV6::deserialize(deserializer).map(Self::V6),
_ => Err(D::Error::INVALID_ENUM_VARIANT),
}
}
}
impl Serialize for IpAddr {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
match self {
Self::V4(v4) => {
serializer.write_enum_variant(4)?;
v4.serialize(serializer)
}
Self::V6(v6) => {
serializer.write_enum_variant(6)?;
v6.serialize(serializer)
}
}
}
}
impl Serialize for SocketAddr {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
match self {
Self::V4(v4) => {
serializer.write_enum_variant(4)?;
v4.serialize(serializer)
}
Self::V6(v6) => {
serializer.write_enum_variant(6)?;
v6.serialize(serializer)
}
}
}
}

View File

@ -0,0 +1,18 @@
use core::time::Duration;
use crate::{des::Deserializer, ser::Serializer, Deserialize, Serialize};
impl<'de> Deserialize<'de> for Duration {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
let seconds = deserializer.read_u64()?;
let nanoseconds = deserializer.read_u32()?;
Ok(Self::new(seconds, nanoseconds))
}
}
impl Serialize for Duration {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
serializer.write_u64(self.as_secs())?;
serializer.write_u32(self.subsec_nanos())
}
}

12
lib/abi-serde/src/lib.rs Normal file
View File

@ -0,0 +1,12 @@
#![feature(decl_macro, ip_from)]
#![no_std]
pub mod des;
pub mod ser;
mod impls;
pub mod wire;
pub use des::Deserialize;
pub use ser::Serialize;

24
lib/abi-serde/src/ser.rs Normal file
View File

@ -0,0 +1,24 @@
pub trait Serializer {
type Error;
fn write_str(&mut self, data: &str) -> Result<(), Self::Error>;
fn write_bytes(&mut self, bytes: &[u8]) -> Result<(), Self::Error>;
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error>;
fn write_i8(&mut self, value: i8) -> Result<(), Self::Error>;
fn write_i16(&mut self, value: i16) -> Result<(), Self::Error>;
fn write_i32(&mut self, value: i32) -> Result<(), Self::Error>;
fn write_i64(&mut self, value: i64) -> Result<(), Self::Error>;
fn write_u8(&mut self, value: u8) -> Result<(), Self::Error>;
fn write_u16(&mut self, value: u16) -> Result<(), Self::Error>;
fn write_u32(&mut self, value: u32) -> Result<(), Self::Error>;
fn write_u64(&mut self, value: u64) -> Result<(), Self::Error>;
fn write_enum_variant(&mut self, index: u32) -> Result<(), Self::Error>;
}
pub trait Serialize {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error>;
}

193
lib/abi-serde/src/wire.rs Normal file
View File

@ -0,0 +1,193 @@
use crate::{
des::{DeserializeError, Deserializer},
ser::Serializer,
Deserialize, Serialize,
};
#[derive(Debug)]
pub struct Error;
pub struct WireSerializer<'b> {
buffer: &'b mut [u8],
position: usize,
}
pub struct WireDeserializer<'b> {
buffer: &'b [u8],
position: usize,
}
impl WireSerializer<'_> {
fn write_naked_bytes(&mut self, bytes: &[u8]) -> Result<(), Error> {
if self.position + bytes.len() > self.buffer.len() {
return Err(Error);
}
self.buffer[self.position..self.position + bytes.len()].copy_from_slice(bytes);
self.position += bytes.len();
Ok(())
}
}
impl<'de> WireDeserializer<'de> {
fn read_naked_bytes(&mut self, len: usize) -> Result<&'de [u8], Error> {
if self.position + len > self.buffer.len() {
return Err(Error);
}
let slice = &self.buffer[self.position..self.position + len];
self.position += len;
Ok(slice)
}
}
impl Serializer for WireSerializer<'_> {
type Error = Error;
fn write_i8(&mut self, value: i8) -> Result<(), Self::Error> {
self.write_naked_bytes(&[value as u8])
}
fn write_u8(&mut self, value: u8) -> Result<(), Self::Error> {
self.write_naked_bytes(&[value])
}
fn write_i16(&mut self, value: i16) -> Result<(), Self::Error> {
self.write_naked_bytes(&value.to_ne_bytes())
}
fn write_u16(&mut self, value: u16) -> Result<(), Self::Error> {
self.write_naked_bytes(&value.to_ne_bytes())
}
fn write_i32(&mut self, value: i32) -> Result<(), Self::Error> {
self.write_naked_bytes(&value.to_ne_bytes())
}
fn write_u32(&mut self, value: u32) -> Result<(), Self::Error> {
self.write_naked_bytes(&value.to_ne_bytes())
}
fn write_i64(&mut self, value: i64) -> Result<(), Self::Error> {
self.write_naked_bytes(&value.to_ne_bytes())
}
fn write_u64(&mut self, value: u64) -> Result<(), Self::Error> {
self.write_naked_bytes(&value.to_ne_bytes())
}
fn write_str(&mut self, data: &str) -> Result<(), Self::Error> {
self.write_bytes(data.as_bytes())
}
fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> {
self.write_naked_bytes(&[value as u8])
}
fn write_bytes(&mut self, bytes: &[u8]) -> Result<(), Self::Error> {
if bytes.len() >= u16::MAX as usize {
return Err(Error);
}
self.write_u16(bytes.len() as u16)?;
self.write_naked_bytes(bytes)
}
fn write_enum_variant(&mut self, index: u32) -> Result<(), Self::Error> {
self.write_u32(index)
}
}
impl<'de> Deserializer<'de> for WireDeserializer<'de> {
type Error = Error;
fn read_i8(&mut self) -> Result<i8, Self::Error> {
Ok(self.read_naked_bytes(1)?[0] as i8)
}
fn read_u8(&mut self) -> Result<u8, Self::Error> {
Ok(self.read_naked_bytes(1)?[0])
}
fn read_i16(&mut self) -> Result<i16, Self::Error> {
self.read_naked_bytes(size_of::<i16>())?
.try_into()
.map(i16::from_ne_bytes)
.map_err(|_| Error)
}
fn read_u16(&mut self) -> Result<u16, Self::Error> {
self.read_naked_bytes(size_of::<u16>())?
.try_into()
.map(u16::from_ne_bytes)
.map_err(|_| Error)
}
fn read_i32(&mut self) -> Result<i32, Self::Error> {
self.read_naked_bytes(size_of::<i32>())?
.try_into()
.map(i32::from_ne_bytes)
.map_err(|_| Error)
}
fn read_u32(&mut self) -> Result<u32, Self::Error> {
self.read_naked_bytes(size_of::<u32>())?
.try_into()
.map(u32::from_ne_bytes)
.map_err(|_| Error)
}
fn read_i64(&mut self) -> Result<i64, Self::Error> {
self.read_naked_bytes(size_of::<i64>())?
.try_into()
.map(i64::from_ne_bytes)
.map_err(|_| Error)
}
fn read_u64(&mut self) -> Result<u64, Self::Error> {
self.read_naked_bytes(size_of::<u64>())?
.try_into()
.map(u64::from_ne_bytes)
.map_err(|_| Error)
}
fn read_str(&mut self) -> Result<&'de str, Self::Error> {
let bytes = self.read_bytes()?;
core::str::from_utf8(bytes).map_err(|_| Error)
}
fn read_bytes(&mut self) -> Result<&'de [u8], Self::Error> {
let len = self.read_u16()? as usize;
self.read_naked_bytes(len)
}
fn read_bool(&mut self) -> Result<bool, Self::Error> {
Ok(self.read_u8()? != 0)
}
fn read_enum_variant(&mut self) -> Result<u32, Self::Error> {
self.read_u32()
}
}
impl DeserializeError for Error {
const INVALID_ARRAY_LEN: Self = Self;
const INVALID_ENUM_VARIANT: Self = Self;
}
pub fn to_slice<T: Serialize>(value: &T, buffer: &mut [u8]) -> Result<usize, Error> {
let mut ser = WireSerializer {
buffer,
position: 0,
};
value.serialize(&mut ser)?;
Ok(ser.position)
}
pub fn from_slice<'de, T: Deserialize<'de>>(buffer: &'de [u8]) -> Result<T, Error> {
let mut des = WireDeserializer {
buffer,
position: 0,
};
T::deserialize(&mut des)
}

View File

@ -11,6 +11,8 @@ core = { version = "1.0.0", optional = true, package = "rustc-std-workspace-core
rustc_std_alloc = { version = "1.0.0", optional = true, package = "rustc-std-workspace-alloc" }
compiler_builtins = { version = "0.1", optional = true }
abi-serde = { path = "../abi-serde" }
serde = { version = "1.0.193", features = ["derive"], default-features = false, optional = true }
bytemuck = { version = "1.14.0", features = ["derive"], optional = true }
@ -29,5 +31,6 @@ rustc-dep-of-std = [
"core",
"rustc_std_alloc",
"compiler_builtins/rustc-dep-of-std",
"abi-lib/rustc-dep-of-std"
"abi-lib/rustc-dep-of-std",
"abi-serde/rustc-dep-of-std"
]

View File

@ -177,8 +177,8 @@ syscall accept(sock_fd: RawFd, remote: &mut MaybeUninit<SocketAddr>) -> Result<R
syscall shutdown(sock_fd: RawFd, how: SocketShutdown) -> Result<()>;
syscall send_to(sock_fd: RawFd, data: &[u8], remote: &Option<SocketAddr>) -> Result<usize>;
syscall receive_from(sock_fd: RawFd, data: &mut [u8], remote: &mut MaybeUninit<SocketAddr>) -> Result<usize>;
syscall get_socket_option(sock_fd: RawFd, option: &mut SocketOption<'_>) -> Result<()>;
syscall set_socket_option(sock_fd: RawFd, option: &SocketOption<'_>) -> Result<()>;
syscall get_socket_option(sock_fd: RawFd, option: u32, value: &mut [u8]) -> Result<usize>;
syscall set_socket_option(sock_fd: RawFd, option: u32, value: &[u8]) -> Result<()>;
// C compat
syscall fork() -> Result<ProcessId>;

View File

@ -34,6 +34,12 @@ pub mod error {
pub use crate::generated::Error;
// TODO have syscall-generator implement TryFrom<#repr> for #enum
impl From<abi_serde::wire::Error> for Error {
fn from(_value: abi_serde::wire::Error) -> Self {
Self::InvalidArgument
}
}
}
pub use generated::SyscallFunction;

View File

@ -15,14 +15,14 @@ macro_rules! primitive_enum {
}
impl TryFrom<$repr> for $name {
type Error = ();
type Error = $crate::error::Error;
fn try_from(v: $repr) -> Result<$name, ()> {
fn try_from(v: $repr) -> Result<$name, Self::Error> {
match v {
$(
$discriminant => Ok($name::$variant)
,)+
_ => Err(())
_ => Err($crate::error::Error::InvalidArgument)
}
}
}

View File

@ -1,15 +1,20 @@
//! Defines data types for network operations
use abi_serde::{
des::{DeserializeError, Deserializer},
ser::Serializer,
Deserialize, Serialize,
};
#[cfg(any(feature = "alloc", feature = "rustc_std_alloc"))]
pub mod dns;
#[cfg(feature = "alloc")]
pub mod netconfig;
pub mod options;
pub mod protocols;
pub mod types;
use core::time::Duration;
pub use crate::generated::{SocketShutdown, SocketType};
pub use types::{
@ -20,7 +25,8 @@ pub use types::{
};
/// Describes a method to query an interface
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, PartialEq)]
pub enum SocketInterfaceQuery<'a> {
/// Query by name
ByName(&'a str),
@ -28,44 +34,6 @@ pub enum SocketInterfaceQuery<'a> {
ById(u32),
}
/// Describes a socket operation parameter
#[derive(Clone, Debug)]
pub enum SocketOption<'a> {
/// Whether the socket is broadcast
Broadcast(bool),
/// Time-to-Live parameter
Ttl(u32),
/// Bind the socket to a specific interface
BindInterface(SocketInterfaceQuery<'a>),
/// Unbind the socket from an interface
UnbindInterface,
/// (Read-only) Hardware address of the bound interface
BoundHardwareAddress(MacAddress),
/// (Read-only) Local socket address
LocalAddress(Option<core::net::SocketAddr>),
/// (Read-only) Remote socket address
PeerAddress(Option<core::net::SocketAddr>),
/// If set, reception will return [crate::error::Error::WouldBlock] if the socket has
/// no data in its queue/buffer.
NonBlocking(bool),
/// If set, the socket will be restricted to IPv6 only.
Ipv6Only(bool),
/// If not [None], receive operations will have a time limit set before returning an error.
RecvTimeout(Option<Duration>),
/// If not [None], send operations will have a time limit set before returning an error.
SendTimeout(Option<Duration>),
/// If not [None], connect() call will timeout after the specified time limit.
ConnectTimeout(Option<Duration>),
/// (UDP) If set, allows multicast packets to be looped back to local host.
MulticastLoopV4(bool),
/// (UDP) If set, allows multicast packets to be looped back to local host.
MulticastLoopV6(bool),
/// (UDP) Time-to-Live for IPv4 multicast packets.
MulticastTtlV4(u32),
/// (TCP) If set, disables any internal buffering for the socket.
NoDelay(bool),
}
impl<'a> From<&'a str> for SocketInterfaceQuery<'a> {
fn from(value: &'a str) -> Self {
Self::ByName(value)
@ -77,3 +45,50 @@ impl From<u32> for SocketInterfaceQuery<'_> {
Self::ById(value)
}
}
impl<'de> Deserialize<'de> for SocketInterfaceQuery<'de> {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
match deserializer.read_enum_variant()? {
1 => deserializer.read_u32().map(Self::ById),
2 => deserializer.read_str().map(Self::ByName),
_ => Err(D::Error::INVALID_ENUM_VARIANT),
}
}
}
impl Serialize for SocketInterfaceQuery<'_> {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
match *self {
Self::ById(id) => {
serializer.write_enum_variant(1)?;
serializer.write_u32(id)
}
Self::ByName(name) => {
serializer.write_enum_variant(2)?;
serializer.write_str(name)
}
}
}
}
#[cfg(test)]
mod tests {
use abi_serde::wire;
use super::SocketInterfaceQuery;
#[test]
fn socket_interface_query_serde() {
let mut buffer = [0; 512];
let source = SocketInterfaceQuery::ById(123);
let len = wire::to_slice(&source, &mut buffer).unwrap();
let result: SocketInterfaceQuery<'_> = wire::from_slice(&buffer[..len]).unwrap();
assert_eq!(source, result);
let mut buffer = [0; 512];
let source = SocketInterfaceQuery::ByName("hello");
let len = wire::to_slice(&source, &mut buffer).unwrap();
let result: SocketInterfaceQuery<'_> = wire::from_slice(&buffer[..len]).unwrap();
assert_eq!(source, result);
}
}

129
lib/abi/src/net/options.rs Normal file
View File

@ -0,0 +1,129 @@
//! Socket option definitions
use core::{net::SocketAddr, time::Duration};
use abi_serde::{wire, Deserialize, Serialize};
use crate::error::Error;
use super::{MacAddress, SocketInterfaceQuery};
#[allow(missing_docs)]
pub trait SocketOption<'de> {
type Value: Deserialize<'de> + Serialize;
type Variant: Copy + Into<u32> + TryFrom<u32, Error = Error>;
const VARIANT: Self::Variant;
fn store(value: &Self::Value, buffer: &mut [u8]) -> Result<usize, Error> {
Ok(wire::to_slice(value, buffer)?)
}
fn load(buffer: &'de [u8]) -> Result<Self::Value, Error> {
Ok(wire::from_slice(buffer)?)
}
}
#[allow(missing_docs)]
pub trait SocketOptionSizeHint {
const SIZE_HINT: usize;
}
macro_rules! socket_option_group {
(
$(#[$enum_meta:meta])*
$vis:vis enum $variant_ty:ident<$lifetime:lifetime> {
$(
$(#[$variant_meta:meta])*
$discriminant:literal: $variant:ident $(# $size_hint:literal)? ($value:ty)
),* $(,)?
}
) => {
$crate::primitive_enum! {
$(#[$enum_meta])*
$vis enum $variant_ty: u32 {
$(
$(#[$variant_meta])*
$variant = $discriminant
),*
}
}
$(
$(#[$variant_meta])*
$vis struct $variant;
impl<$lifetime> $crate::net::options::SocketOption<$lifetime> for $variant {
type Variant = $variant_ty;
type Value = $value;
const VARIANT: Self::Variant = $variant_ty::$variant;
}
$(
impl $crate::net::options::SocketOptionSizeHint for $variant {
const SIZE_HINT: usize = $size_hint;
}
)?
)*
};
}
socket_option_group!(
#[doc = "Common socket options"]
pub enum SocketOptionVariant<'de> {
#[doc = "(Read-only) Local address of the socket"]
0x1000: LocalAddress # 32 (Option<SocketAddr>),
#[doc = "(Read-only) Remote address of the socket"]
0x1001: PeerAddress # 32 (Option<SocketAddr>),
#[doc = "If not [None], receive/accept operations time out after a certain amount"]
0x1002: RecvTimeout # 32 (Option<Duration>),
#[doc = "If not [None], send operations time out after a certain amount"]
0x1003: SendTimeout # 32 (Option<Duration>),
#[doc = "If not [None], connect operations time out after a certain amount"]
0x1004: ConnectTimeout # 32 (Option<Duration>),
#[doc = "
If `true`, receive/accept/send operations return an error if no
data/buffer space is immediately available.
"]
0x1005: NonBlocking # 4 (bool),
}
);
socket_option_group!(
#[doc = "UDP socket protocol-level options"]
pub enum UdpSocketOptionVariant<'de> {
#[doc = "UDP unicast packet Time-to-Live"]
0x2FFF: Ttl # 4 (u32),
#[doc = "If `true`, allows broadcast packets to be sent from the socket"]
0x2000: Broadcast # 4 (bool),
#[doc = "If `true`, IPv4 multicast packets loop back to the local host"]
0x2001: MulticastLoopV4 # 4 (bool),
#[doc = "If `true`, IPv6 multicast packets loop back to the local host"]
0x2002: MulticastLoopV6 # 4 (bool),
#[doc = "Time-to-Live for IPv4 multicast packets"]
0x2003: MulticastTtlV4 # 4 (u32),
}
);
socket_option_group!(
#[doc = "TCP socket protocol-level options"]
pub enum TcpSocketOptionVariant<'de> {
#[doc = "If `true`, disables the Nagle's algorithm (which is not yet implemented)"]
0x3000: NoDelay # 4 (bool),
#[doc = "If `true`, only allows IPv6 connections to be made to a IPv6 TCP listener"]
0x3001: Ipv6Only # 4 (bool),
}
);
socket_option_group!(
#[doc = "Raw socket options"]
pub enum RawSocketOptionVariant<'de> {
#[doc = "
If a raw socket is bound to an interface, it will
receive packets from only that interface
"]
0xF000: BoundInterface(SocketInterfaceQuery<'de>),
#[doc = "(Write-only) unbinds the socket from its current interface"]
0xF001: UnbindInterface # 0 (()),
#[doc = "(Read-only) currently bound interface's hardware address"]
0xF002: BoundHardwareAddress # 32 (MacAddress),
}
);

View File

@ -2,6 +2,12 @@
use core::fmt;
use abi_serde::{
des::{DeserializeError, Deserializer},
ser::Serializer,
Deserialize, Serialize,
};
pub mod ip_addr;
pub mod net_value;
pub mod socket_addr;
@ -47,3 +53,19 @@ impl fmt::Display for MacAddress {
Ok(())
}
}
impl<'de> Deserialize<'de> for MacAddress {
fn deserialize<D: Deserializer<'de>>(deserializer: &mut D) -> Result<Self, D::Error> {
deserializer
.read_bytes()?
.try_into()
.map_err(|_| D::Error::INVALID_ARRAY_LEN)
.map(Self)
}
}
impl Serialize for MacAddress {
fn serialize<S: Serializer>(&self, serializer: &mut S) -> Result<(), S::Error> {
serializer.write_bytes(&self.0)
}
}

View File

@ -1,7 +1,5 @@
use core::ffi::CStr;
use alloc::string::String;
use crate::error::Error;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@ -109,9 +107,10 @@ impl Path {
&self.0
}
#[cfg(feature = "alloc")]
#[allow(clippy::inherent_to_string)]
#[inline]
pub fn to_string(&self) -> String {
pub fn to_string(&self) -> alloc::string::String {
self.0.into()
}
}

View File

@ -15,6 +15,7 @@ pub const MILLISECONDS_IN_SECOND: u64 = 1_000;
/// Represents a point of time as measured by some system clock
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(C)]
pub struct SystemTime {
seconds: u64,
nanoseconds: u64,

View File

@ -14,6 +14,7 @@ compiler_builtins = { version = "0.1", optional = true }
libm = { git = "https://git.alnyan.me/yggdrasil/libm.git", optional = true }
abi-lib = { path = "../abi-lib" }
abi-serde = { path = "../abi-serde" }
[build-dependencies]
abi-generator = { path = "../../tool/abi-generator" }
@ -29,5 +30,6 @@ rustc-dep-of-std = [
"compiler_builtins/rustc-dep-of-std",
"yggdrasil-abi/rustc-dep-of-std",
"abi-lib/rustc-dep-of-std",
"abi-serde/rustc-dep-of-std",
"libm/rustc-dep-of-std"
]

View File

@ -5,11 +5,12 @@
f128,
linkage,
naked_functions,
thread_local
thread_local,
generic_const_exprs
)]
#![no_std]
#![warn(missing_docs)]
#![allow(nonstandard_style, clippy::new_without_default)]
#![allow(nonstandard_style, clippy::new_without_default, incomplete_features)]
#[cfg(not(rust_analyzer))]
#[allow(unused_extern_crates)]

View File

@ -1,146 +0,0 @@
//! Network-related functions and types
use core::{net::SocketAddr, time::Duration};
pub use abi::net::{MacAddress, SocketInterfaceQuery, SocketOption, SocketShutdown, SocketType};
use abi::{error::Error, io::RawFd};
#[allow(unused_macros)]
macro socket_option_variant {
($opt:ident: bool) => { $crate::net::SocketOption::$opt(false) },
($opt:ident: Option) => { $crate::net::SocketOption::$opt(None) },
($opt:ident: int) => { $crate::net::SocketOption::$opt(0) }
}
/// Helper macro for getting a [SocketOption] with 1 argument
pub macro get_socket_option1($fd:expr, $opt:ident: $ty:tt) {{
let mut option = socket_option_variant!($opt: $ty);
match unsafe { $crate::sys::get_socket_option($fd, &mut option) } {
Ok(()) => {
let $crate::net::SocketOption::$opt(value) = option else { unreachable!() };
Ok(value)
}
Err(error) => {
Err(error)
}
}
}}
#[cfg(any(feature = "alloc", rust_analyzer))]
pub mod dns {
//! DNS resolver module for `std`
use abi::net::dns::DnsSerialize;
pub use abi::net::dns::{
DnsClass, DnsMessage, DnsMethod, DnsRecordData, DnsReplyCode, DnsType,
};
/// Interface for performing UDP communication
pub trait UdpRequester {
/// Error returned by the requester
type Error;
/// Sends a message through the communication channel
fn send_message(&mut self, message: &[u8]) -> Result<(), Self::Error>;
/// Wait for a message, parses it and runs a function on it to determine whether to return
fn receive_message<F: Fn(&[u8]) -> Option<DnsMessage>>(
&mut self,
map: F,
) -> Result<DnsMessage, Self::Error>;
}
/// Performs a raw DNS query without analyzing the reply
pub fn perform_query<R: UdpRequester>(
requester: &mut R,
name: &str,
ty: DnsType,
xid: u16,
cookie: u64,
) -> Result<DnsMessage, R::Error> {
let mut packet = alloc::vec![];
let query = DnsMessage::query(name, ty, xid, cookie);
query.serialize(&mut packet);
requester.send_message(&packet)?;
let message = requester.receive_message(|data| {
let message = DnsMessage::parse(data)?;
if message.xid != xid || message.method != DnsMethod::REPLY {
return None;
}
Some(message)
})?;
Ok(message)
}
}
fn bind_inner(fd: RawFd, local: &SocketAddr, listen: bool) -> Result<(), Error> {
unsafe { crate::sys::bind(fd, local) }?;
if listen {
unsafe { crate::sys::listen(fd) }?;
}
Ok(())
}
fn connect_inner(fd: RawFd, remote: &SocketAddr, timeout: Option<Duration>) -> Result<(), Error> {
if timeout.is_some() {
unsafe { crate::sys::set_socket_option(fd, &SocketOption::ConnectTimeout(timeout)) }?;
}
unsafe { crate::sys::connect(fd, remote) }?;
Ok(())
}
/// Creates a new socket and binds it to a local address
pub fn create_and_bind(ty: SocketType, local: &SocketAddr, listen: bool) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(ty) }?;
match bind_inner(fd, local, listen) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds a TCP listener socket to some local address
pub fn bind_tcp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::TcpStream, local, true)
}
/// Binds a raw socket to some network interface
pub fn bind_raw(iface: SocketInterfaceQuery<'_>) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(SocketType::RawPacket) }?;
let option = SocketOption::BindInterface(iface);
match unsafe { crate::sys::set_socket_option(fd, &option) } {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds an UDP socket to some local address
pub fn bind_udp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::UdpPacket, local, false)
}
/// Connect to a TCP listener
pub fn connect_tcp(remote: &SocketAddr, timeout: Option<Duration>) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(SocketType::TcpStream) }?;
match connect_inner(fd, remote, timeout) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// "Connect" an UDP socket
pub fn connect_udp(socket_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> {
connect_inner(socket_fd, remote, None)
}

View File

@ -0,0 +1,45 @@
//! DNS resolver module for `std`
use abi::net::dns::DnsSerialize;
pub use abi::net::dns::{DnsClass, DnsMessage, DnsMethod, DnsRecordData, DnsReplyCode, DnsType};
/// Interface for performing UDP communication
pub trait UdpRequester {
/// Error returned by the requester
type Error;
/// Sends a message through the communication channel
fn send_message(&mut self, message: &[u8]) -> Result<(), Self::Error>;
/// Wait for a message, parses it and runs a function on it to determine whether to return
fn receive_message<F: Fn(&[u8]) -> Option<DnsMessage>>(
&mut self,
map: F,
) -> Result<DnsMessage, Self::Error>;
}
/// Performs a raw DNS query without analyzing the reply
pub fn perform_query<R: UdpRequester>(
requester: &mut R,
name: &str,
ty: DnsType,
xid: u16,
cookie: u64,
) -> Result<DnsMessage, R::Error> {
let mut packet = alloc::vec![];
let query = DnsMessage::query(name, ty, xid, cookie);
query.serialize(&mut packet);
requester.send_message(&packet)?;
let message = requester.receive_message(|data| {
let message = DnsMessage::parse(data)?;
if message.xid != xid || message.method != DnsMethod::REPLY {
return None;
}
Some(message)
})?;
Ok(message)
}

View File

@ -0,0 +1,13 @@
//! Network-related functions and types
#[cfg(any(feature = "alloc", rust_analyzer))]
pub mod dns;
pub mod socket;
pub use abi::net::{options, MacAddress, SocketInterfaceQuery, SocketShutdown, SocketType};
pub use socket::{
bind_raw, bind_tcp, bind_udp, connect_tcp, connect_udp, get_socket_option, local_address,
peer_address, set_socket_option, set_socket_option_with,
};

View File

@ -0,0 +1,131 @@
//! Socket management functions
use core::{net::SocketAddr, time::Duration};
use abi::{
error::Error,
io::RawFd,
net::{
options::{self, SocketOption, SocketOptionSizeHint},
SocketInterfaceQuery, SocketType,
},
};
use abi_serde::wire;
/// Short-hand macro for [get_socket_option].
///
/// Automatically sets up the `buffer` argument for socket options which implement
/// [SocketOptionSizeHint].
pub macro get_socket_option($fd:expr, $variant_ty:ty) {{
let mut buffer = [0; <$variant_ty as $crate::net::options::SocketOptionSizeHint>::SIZE_HINT];
$crate::net::get_socket_option::<$variant_ty>($fd, &mut buffer)
}}
/// Retrieves the value of a socket's option
pub fn get_socket_option<'de, T: SocketOption<'de>>(
fd: RawFd,
buffer: &'de mut [u8],
) -> Result<T::Value, Error> {
let len = unsafe { crate::sys::get_socket_option(fd, T::VARIANT.into(), buffer) }?;
Ok(wire::from_slice(&buffer[..len])?)
}
/// Sets the value of a socket's option, using the buffer provided to marshall the payload
pub fn set_socket_option_with<'de, T: SocketOption<'de>>(
fd: RawFd,
buffer: &mut [u8],
value: &T::Value,
) -> Result<(), Error> {
let len = wire::to_slice(value, buffer)?;
unsafe { crate::sys::set_socket_option(fd, T::VARIANT.into(), &buffer[..len]) }
}
/// Sets the value of a socket's option, using a buffer set up automatically based on
/// [SocketOptionSizeHint]. See [set_socket_option_with].
pub fn set_socket_option<'de, T: SocketOption<'de> + SocketOptionSizeHint>(
fd: RawFd,
value: &T::Value,
) -> Result<(), Error>
where
[u8; T::SIZE_HINT]: Sized,
{
let mut buffer = [0; T::SIZE_HINT];
set_socket_option_with::<T>(fd, &mut buffer, value)
}
/// Returns the socket's local address or an error if it isn't bound yet
pub fn local_address(fd: RawFd) -> Result<SocketAddr, Error> {
let mut buffer = [0; options::LocalAddress::SIZE_HINT];
get_socket_option::<options::LocalAddress>(fd, &mut buffer)?.ok_or(Error::InvalidOperation)
}
/// Returns the socket's remote address or an error if the socket is not connected
pub fn peer_address(fd: RawFd) -> Result<SocketAddr, Error> {
let mut buffer = [0; options::PeerAddress::SIZE_HINT];
get_socket_option::<options::PeerAddress>(fd, &mut buffer)?.ok_or(Error::NotConnected)
}
fn bind_inner(fd: RawFd, local: &SocketAddr, listen: bool) -> Result<(), Error> {
unsafe { crate::sys::bind(fd, local) }?;
if listen {
unsafe { crate::sys::listen(fd) }?;
}
Ok(())
}
fn connect_inner(fd: RawFd, remote: &SocketAddr, timeout: Option<Duration>) -> Result<(), Error> {
set_socket_option::<options::ConnectTimeout>(fd, &timeout)?;
unsafe { crate::sys::connect(fd, remote) }?;
Ok(())
}
/// Creates a new socket and binds it to a local address
pub fn create_and_bind(ty: SocketType, local: &SocketAddr, listen: bool) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(ty) }?;
match bind_inner(fd, local, listen) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds a TCP listener socket to some local address
pub fn bind_tcp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::TcpStream, local, true)
}
/// Binds a raw socket to some network interface
pub fn bind_raw(iface: SocketInterfaceQuery<'_>) -> Result<RawFd, Error> {
let mut buffer = [0; 128];
let fd = unsafe { crate::sys::create_socket(SocketType::RawPacket) }?;
match set_socket_option_with::<options::BoundInterface>(fd, &mut buffer, &iface) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// Binds an UDP socket to some local address
pub fn bind_udp(local: &SocketAddr) -> Result<RawFd, Error> {
create_and_bind(SocketType::UdpPacket, local, false)
}
/// Connect to a TCP listener
pub fn connect_tcp(remote: &SocketAddr, timeout: Option<Duration>) -> Result<RawFd, Error> {
let fd = unsafe { crate::sys::create_socket(SocketType::TcpStream) }?;
match connect_inner(fd, remote, timeout) {
Ok(()) => Ok(fd),
Err(error) => {
unsafe { crate::sys::close(fd) }.ok();
Err(error)
}
}
}
/// "Connect" an UDP socket
pub fn connect_udp(socket_fd: RawFd, remote: &SocketAddr) -> Result<(), Error> {
connect_inner(socket_fd, remote, None)
}

6
userspace/Cargo.lock generated
View File

@ -16,6 +16,10 @@ dependencies = [
name = "abi-lib"
version = "0.1.0"
[[package]]
name = "abi-serde"
version = "0.1.0"
[[package]]
name = "aes"
version = "0.8.4"
@ -1841,6 +1845,7 @@ version = "0.1.0"
dependencies = [
"abi-generator",
"abi-lib",
"abi-serde",
"bytemuck",
"prettyplease",
"serde",
@ -1852,6 +1857,7 @@ version = "0.1.0"
dependencies = [
"abi-generator",
"abi-lib",
"abi-serde",
"cc",
"prettyplease",
"yggdrasil-abi",

View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4
[[package]]
name = "abi-generator"
@ -16,6 +16,10 @@ dependencies = [
name = "abi-lib"
version = "0.1.0"
[[package]]
name = "abi-serde"
version = "0.1.0"
[[package]]
name = "cc"
version = "1.0.90"
@ -100,6 +104,7 @@ version = "0.1.0"
dependencies = [
"abi-generator",
"abi-lib",
"abi-serde",
"prettyplease",
]
@ -109,6 +114,7 @@ version = "0.1.0"
dependencies = [
"abi-generator",
"abi-lib",
"abi-serde",
"cc",
"prettyplease",
"yggdrasil-abi",

View File

@ -1,349 +1,347 @@
fn main() {}
#![feature(yggdrasil_os, rustc_private)]
// #![feature(yggdrasil_os, rustc_private)]
//
// use std::{
// mem::size_of,
// os::{
// fd::AsRawFd,
// yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd},
// },
// process::ExitCode,
// sync::atomic::{AtomicBool, Ordering},
// time::Duration,
// };
//
// use bytemuck::Zeroable;
// use clap::Parser;
// use netutils::{netconfig::NetConfig, Error};
// use yggdrasil_abi::net::{
// protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame},
// types::NetValueImpl,
// IpAddr, Ipv4Addr, MacAddress,
// };
//
// #[derive(Parser)]
// struct Args {
// #[clap(
// help = "Time (ms) between a reply is received and the next request is sent",
// short,
// long,
// default_value_t = 1000,
// value_parser = valid_interval
// )]
// inteval: u32,
// #[clap(
// help = "Time (ms) after which the request is considered unanswered",
// short,
// long,
// default_value_t = 500,
// value_parser = valid_timeout,
// )]
// timeout: u32,
// #[clap(
// help = "Number of requests to perform",
// short,
// long,
// default_value_t = 10
// )]
// count: usize,
// #[clap(
// help = "Amount of bytes to include as data",
// short,
// long,
// default_value_t = 64,
// value_parser = valid_data_size
// )]
// data_size: usize,
//
// #[clap(help = "Address to ping")]
// address: core::net::IpAddr,
// }
//
// fn valid_interval(s: &str) -> Result<u32, String> {
// clap_num::number_range(s, 100, 10000)
// }
//
// fn valid_timeout(s: &str) -> Result<u32, String> {
// clap_num::number_range(s, 100, 5000)
// }
//
// fn valid_data_size(s: &str) -> Result<usize, String> {
// clap_num::number_range(s, 4, 128)
// }
//
// struct PingRouting {
// interface_id: u32,
// source_ip: IpAddr,
// destination_ip: IpAddr,
// source_mac: MacAddress,
// gateway_mac: MacAddress,
// }
//
// struct PingStats {
// packets_sent: usize,
// packets_received: usize,
// }
//
// fn resolve_routing(address: IpAddr) -> Result<PingRouting, Error> {
// let mut nc = NetConfig::open()?;
// let routing = nc.query_route(address)?;
// let Some(source) = routing.source else {
// todo!();
// };
// let Some(gateway) = routing.gateway else {
// todo!();
// };
//
// let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?;
//
// Ok(PingRouting {
// interface_id: routing.interface_id,
// source_ip: source,
// destination_ip: routing.destination,
// source_mac: routing.source_mac,
// gateway_mac,
// })
// }
//
// fn validate_ping_reply(
// packet: &[u8],
// local: Ipv4Addr,
// remote: Ipv4Addr,
// expect_l4_data: &[u8],
// expect_id: u16,
// expect_seq: u16,
// ) -> bool {
// if packet.len() < size_of::<EthernetFrame>() + size_of::<Ipv4Frame>() {
// return false;
// }
//
// let l3_offset = size_of::<EthernetFrame>();
//
// let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]);
//
// if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 {
// return false;
// }
// let l3_frame: &Ipv4Frame =
// bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::<Ipv4Frame>()]);
// if l3_frame.protocol != IpProtocol::ICMP
// || u32::from_network_order(l3_frame.source_address) != u32::from(remote)
// || u32::from_network_order(l3_frame.destination_address) != u32::from(local)
// {
// return false;
// }
// let mut ip_checksum = InetChecksum::new();
// ip_checksum.add_value(l3_frame, true);
// let ip_checksum = ip_checksum.finish();
//
// if ip_checksum != 0 {
// eprintln!("IP checksum mismatch: {:#06x}", ip_checksum);
// return false;
// }
//
// let l4_offset = l3_offset + l3_frame.header_length();
// let l4_size = l3_frame
// .total_length()
// .saturating_sub(l3_frame.header_length());
// if packet.len() < l4_offset + size_of::<IcmpV4Frame>() + expect_l4_data.len() {
// return false;
// }
// let l4_frame: &IcmpV4Frame =
// bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::<IcmpV4Frame>()]);
// let l4_data = &packet[l4_offset + size_of::<IcmpV4Frame>()..l4_offset + l4_size];
//
// if l4_frame.ty != 0 || l4_frame.code != 0 {
// return false;
// }
//
// let rest = u32::from_network_order(l4_frame.rest);
// let reply_id = (rest >> 16) as u16;
// let reply_seq = rest as u16;
//
// if reply_id != expect_id || reply_seq != expect_seq {
// eprintln!(
// "ICMP seq/id mismatch: sent {}/{}, got {}/{}",
// expect_id, expect_seq, reply_id, reply_seq
// );
// return false;
// }
//
// let mut icmp_checksum = InetChecksum::new();
// icmp_checksum.add_value(l4_frame, true);
// icmp_checksum.add_bytes(l4_data, true);
// let icmp_checksum = icmp_checksum.finish();
//
// if icmp_checksum != 0 {
// eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum);
// return false;
// }
//
// l4_data == expect_l4_data
// }
//
// #[allow(clippy::too_many_arguments)]
// fn ping_once(
// socket: &mut RawSocket,
// poll: &mut PollChannel,
// timer: &mut TimerFd,
// info: &PingRouting,
// timeout: Duration,
// data_len: usize,
// id: u16,
// seq: u16,
// ) -> Result<bool, Error> {
// let mut buffer = [0; 4096];
//
// let source_ip = info.source_ip.into_ipv4().unwrap();
// let destination_ip = info.destination_ip.into_ipv4().unwrap();
// let mut l4_data = Vec::with_capacity(data_len);
//
// for _ in 0..data_len {
// l4_data.push(rand::random());
// }
//
// let ip_len = (size_of::<Ipv4Frame>() + size_of::<IcmpV4Frame>() + data_len)
// .try_into()
// .unwrap();
//
// let l2_frame = EthernetFrame {
// source_mac: info.source_mac,
// destination_mac: info.gateway_mac,
// ethertype: EtherType::IPV4.to_network_order(),
// };
// let mut l3_frame = Ipv4Frame {
// source_address: u32::from(source_ip).to_network_order(),
// destination_address: u32::from(destination_ip).to_network_order(),
// protocol: IpProtocol::ICMP,
// version_length: 0x45,
// total_length: u16::to_network_order(ip_len),
// flags_frag: u16::to_network_order(0x4000),
// id: u16::to_network_order(0),
// ttl: 255,
// ..Ipv4Frame::zeroed()
// };
// let mut l4_frame = IcmpV4Frame {
// ty: 8,
// code: 0,
// checksum: u16::to_network_order(0),
// rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)),
// };
//
// let mut ip_checksum = InetChecksum::new();
// ip_checksum.add_value(&l3_frame, true);
// l3_frame.header_checksum = ip_checksum.finish().to_network_order();
//
// let mut icmp_checksum = InetChecksum::new();
// icmp_checksum.add_value(&l4_frame, true);
// icmp_checksum.add_bytes(&l4_data, true);
// l4_frame.checksum = icmp_checksum.finish().to_network_order();
//
// let mut packet = vec![];
// packet.extend_from_slice(bytemuck::bytes_of(&l2_frame));
// packet.extend_from_slice(bytemuck::bytes_of(&l3_frame));
// packet.extend_from_slice(bytemuck::bytes_of(&l4_frame));
// packet.extend_from_slice(&l4_data);
//
// timer.start(timeout)?;
// socket.send(&packet)?;
//
// loop {
// let (fd, result) = poll.wait(None, true)?.unwrap();
// result?;
//
// match fd {
// fd if fd == socket.as_raw_fd() => {
// // TODO
// let len = socket.recv(&mut buffer)?;
// if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq)
// {
// return Ok(true);
// }
// }
// fd if fd == timer.as_raw_fd() => {
// return Ok(false);
// }
// _ => unreachable!(),
// }
// }
// }
//
// fn ping(
// address: IpAddr,
// times: usize,
// data_len: usize,
// interval: Duration,
// timeout: Duration,
// ) -> Result<PingStats, Error> {
// let routing = resolve_routing(address)?;
//
// let mut stats = PingStats {
// packets_sent: 0,
// packets_received: 0,
// };
// let mut poll = PollChannel::new()?;
// let mut timer = TimerFd::new(false, false)?;
// let mut socket = RawSocket::bind(routing.interface_id)?;
//
// poll.add(timer.as_raw_fd())?;
// poll.add(socket.as_raw_fd())?;
//
// let id = rand::random();
// for i in 0..times {
// if INTERRUPTED.load(Ordering::Acquire) {
// break;
// }
//
// let result = ping_once(
// &mut socket,
// &mut poll,
// &mut timer,
// &routing,
// timeout,
// data_len,
// id,
// i as u16,
// )?;
// stats.packets_sent += 1;
//
// if result {
// stats.packets_received += 1;
// println!("[{}/{}] {}: PONG", i + 1, times, address);
// }
//
// std::thread::sleep(interval);
// }
//
// Ok(stats)
// }
//
// static INTERRUPTED: AtomicBool = AtomicBool::new(false);
//
// fn main() -> ExitCode {
// // set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt));
//
// let args = Args::parse();
//
// let stats = match ping(
// args.address.into(),
// args.count,
// args.data_size,
// Duration::from_millis(args.inteval.into()),
// Duration::from_millis(args.timeout.into()),
// ) {
// Ok(stats) => stats,
// Err(error) => {
// eprintln!("ping: {}", error);
// return ExitCode::FAILURE;
// }
// };
//
// let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent;
// println!(
// "{} sent, {} received, {}% loss",
// stats.packets_sent, stats.packets_received, loss
// );
//
// ExitCode::SUCCESS
// }
use std::{
mem::size_of,
os::{
fd::AsRawFd,
yggdrasil::io::{poll::PollChannel, net::raw_socket::RawSocket, timer::TimerFd},
},
process::ExitCode,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use bytemuck::Zeroable;
use clap::Parser;
use netutils::{netconfig::NetConfig, Error};
use yggdrasil_abi::net::{
protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame},
types::NetValueImpl,
IpAddr, Ipv4Addr, MacAddress,
};
#[derive(Parser)]
struct Args {
#[clap(
help = "Time (ms) between a reply is received and the next request is sent",
short,
long,
default_value_t = 1000,
value_parser = valid_interval
)]
inteval: u32,
#[clap(
help = "Time (ms) after which the request is considered unanswered",
short,
long,
default_value_t = 500,
value_parser = valid_timeout,
)]
timeout: u32,
#[clap(
help = "Number of requests to perform",
short,
long,
default_value_t = 10
)]
count: usize,
#[clap(
help = "Amount of bytes to include as data",
short,
long,
default_value_t = 64,
value_parser = valid_data_size
)]
data_size: usize,
#[clap(help = "Address to ping")]
address: core::net::IpAddr,
}
fn valid_interval(s: &str) -> Result<u32, String> {
clap_num::number_range(s, 100, 10000)
}
fn valid_timeout(s: &str) -> Result<u32, String> {
clap_num::number_range(s, 100, 5000)
}
fn valid_data_size(s: &str) -> Result<usize, String> {
clap_num::number_range(s, 4, 128)
}
struct PingRouting {
interface_id: u32,
source_ip: IpAddr,
destination_ip: IpAddr,
source_mac: MacAddress,
gateway_mac: MacAddress,
}
struct PingStats {
packets_sent: usize,
packets_received: usize,
}
fn resolve_routing(address: IpAddr) -> Result<PingRouting, Error> {
let mut nc = NetConfig::open()?;
let routing = nc.query_route(address)?;
let Some(source) = routing.source else {
todo!();
};
let Some(gateway) = routing.gateway else {
todo!();
};
let gateway_mac = nc.query_arp(routing.interface_id, gateway, true)?;
Ok(PingRouting {
interface_id: routing.interface_id,
source_ip: source,
destination_ip: routing.destination,
source_mac: routing.source_mac,
gateway_mac,
})
}
fn validate_ping_reply(
packet: &[u8],
local: Ipv4Addr,
remote: Ipv4Addr,
expect_l4_data: &[u8],
expect_id: u16,
expect_seq: u16,
) -> bool {
if packet.len() < size_of::<EthernetFrame>() + size_of::<Ipv4Frame>() {
return false;
}
let l3_offset = size_of::<EthernetFrame>();
let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..l3_offset]);
if EtherType::from_network_order(l2_frame.ethertype) != EtherType::IPV4 {
return false;
}
let l3_frame: &Ipv4Frame =
bytemuck::from_bytes(&packet[l3_offset..l3_offset + size_of::<Ipv4Frame>()]);
if l3_frame.protocol != IpProtocol::ICMP
|| u32::from_network_order(l3_frame.source_address) != u32::from(remote)
|| u32::from_network_order(l3_frame.destination_address) != u32::from(local)
{
return false;
}
let mut ip_checksum = InetChecksum::new();
ip_checksum.add_value(l3_frame, true);
let ip_checksum = ip_checksum.finish();
if ip_checksum != 0 {
eprintln!("IP checksum mismatch: {:#06x}", ip_checksum);
return false;
}
let l4_offset = l3_offset + l3_frame.header_length();
let l4_size = l3_frame
.total_length()
.saturating_sub(l3_frame.header_length());
if packet.len() < l4_offset + size_of::<IcmpV4Frame>() + expect_l4_data.len() {
return false;
}
let l4_frame: &IcmpV4Frame =
bytemuck::from_bytes(&packet[l4_offset..l4_offset + size_of::<IcmpV4Frame>()]);
let l4_data = &packet[l4_offset + size_of::<IcmpV4Frame>()..l4_offset + l4_size];
if l4_frame.ty != 0 || l4_frame.code != 0 {
return false;
}
let rest = u32::from_network_order(l4_frame.rest);
let reply_id = (rest >> 16) as u16;
let reply_seq = rest as u16;
if reply_id != expect_id || reply_seq != expect_seq {
eprintln!(
"ICMP seq/id mismatch: sent {}/{}, got {}/{}",
expect_id, expect_seq, reply_id, reply_seq
);
return false;
}
let mut icmp_checksum = InetChecksum::new();
icmp_checksum.add_value(l4_frame, true);
icmp_checksum.add_bytes(l4_data, true);
let icmp_checksum = icmp_checksum.finish();
if icmp_checksum != 0 {
eprintln!("ICMP checksum mismatch: {:#06x}", icmp_checksum);
return false;
}
l4_data == expect_l4_data
}
#[allow(clippy::too_many_arguments)]
fn ping_once(
socket: &mut RawSocket,
poll: &mut PollChannel,
timer: &mut TimerFd,
info: &PingRouting,
timeout: Duration,
data_len: usize,
id: u16,
seq: u16,
) -> Result<bool, Error> {
let mut buffer = [0; 4096];
let source_ip = info.source_ip.into_ipv4().unwrap();
let destination_ip = info.destination_ip.into_ipv4().unwrap();
let mut l4_data = Vec::with_capacity(data_len);
for _ in 0..data_len {
l4_data.push(rand::random());
}
let ip_len = (size_of::<Ipv4Frame>() + size_of::<IcmpV4Frame>() + data_len)
.try_into()
.unwrap();
let l2_frame = EthernetFrame {
source_mac: info.source_mac,
destination_mac: info.gateway_mac,
ethertype: EtherType::IPV4.to_network_order(),
};
let mut l3_frame = Ipv4Frame {
source_address: u32::from(source_ip).to_network_order(),
destination_address: u32::from(destination_ip).to_network_order(),
protocol: IpProtocol::ICMP,
version_length: 0x45,
total_length: u16::to_network_order(ip_len),
flags_frag: u16::to_network_order(0x4000),
id: u16::to_network_order(0),
ttl: 255,
..Ipv4Frame::zeroed()
};
let mut l4_frame = IcmpV4Frame {
ty: 8,
code: 0,
checksum: u16::to_network_order(0),
rest: u32::to_network_order(((id as u32) << 16) | (seq as u32)),
};
let mut ip_checksum = InetChecksum::new();
ip_checksum.add_value(&l3_frame, true);
l3_frame.header_checksum = ip_checksum.finish().to_network_order();
let mut icmp_checksum = InetChecksum::new();
icmp_checksum.add_value(&l4_frame, true);
icmp_checksum.add_bytes(&l4_data, true);
l4_frame.checksum = icmp_checksum.finish().to_network_order();
let mut packet = vec![];
packet.extend_from_slice(bytemuck::bytes_of(&l2_frame));
packet.extend_from_slice(bytemuck::bytes_of(&l3_frame));
packet.extend_from_slice(bytemuck::bytes_of(&l4_frame));
packet.extend_from_slice(&l4_data);
timer.start(timeout)?;
socket.send(&packet)?;
loop {
let (fd, result) = poll.wait(None, true)?.unwrap();
result?;
match fd {
fd if fd == socket.as_raw_fd() => {
// TODO
let len = socket.recv(&mut buffer)?;
if validate_ping_reply(&buffer[..len], source_ip, destination_ip, &l4_data, id, seq)
{
return Ok(true);
}
}
fd if fd == timer.as_raw_fd() => {
return Ok(false);
}
_ => unreachable!(),
}
}
}
fn ping(
address: IpAddr,
times: usize,
data_len: usize,
interval: Duration,
timeout: Duration,
) -> Result<PingStats, Error> {
let routing = resolve_routing(address)?;
let mut stats = PingStats {
packets_sent: 0,
packets_received: 0,
};
let mut poll = PollChannel::new()?;
let mut timer = TimerFd::new(false, false)?;
let mut socket = RawSocket::bind(routing.interface_id)?;
poll.add(timer.as_raw_fd())?;
poll.add(socket.as_raw_fd())?;
let id = rand::random();
for i in 0..times {
if INTERRUPTED.load(Ordering::Acquire) {
break;
}
let result = ping_once(
&mut socket,
&mut poll,
&mut timer,
&routing,
timeout,
data_len,
id,
i as u16,
)?;
stats.packets_sent += 1;
if result {
stats.packets_received += 1;
println!("[{}/{}] {}: PONG", i + 1, times, address);
}
std::thread::sleep(interval);
}
Ok(stats)
}
static INTERRUPTED: AtomicBool = AtomicBool::new(false);
fn main() -> ExitCode {
// set_signal_handler(Signal::Interrupted, SignalHandler::Function(interrupt));
let args = Args::parse();
let stats = match ping(
args.address.into(),
args.count,
args.data_size,
Duration::from_millis(args.inteval.into()),
Duration::from_millis(args.timeout.into()),
) {
Ok(stats) => stats,
Err(error) => {
eprintln!("ping: {}", error);
return ExitCode::FAILURE;
}
};
let loss = (stats.packets_sent - stats.packets_received) * 100 / stats.packets_sent;
println!(
"{} sent, {} received, {}% loss",
stats.packets_sent, stats.packets_received, loss
);
ExitCode::SUCCESS
}