Mark Poliakov 18fa8b954a Add 'kernel/' from commit '7f1f6b73377367db17f98a740316b904c37ce3b1'
git-subtree-dir: kernel
git-subtree-mainline: 817f71f90f97270dd569fd44246bf74e57636552
git-subtree-split: 7f1f6b73377367db17f98a740316b904c37ce3b1
2024-03-12 15:52:48 +02:00

828 lines
25 KiB
Rust

use core::{
future::{poll_fn, Future},
pin::Pin,
sync::atomic::{AtomicBool, AtomicU32, Ordering},
task::{Context, Poll},
time::Duration,
};
use alloc::{collections::BTreeMap, sync::Arc, vec::Vec};
use libk_device::monotonic_timestamp;
use libk_mm::PageBox;
use libk_thread::{
block,
runtime::{run_with_timeout, FutureTimeout},
};
use libk_util::{
queue::BoundedMpmcQueue,
sync::{
spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard},
IrqSafeSpinlock, IrqSafeSpinlockGuard,
},
waker::QueueWaker,
};
use vfs::{ConnectionSocket, FileReadiness, ListenerSocket, PacketSocket, Socket};
use yggdrasil_abi::{
error::Error,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketInterfaceQuery, SocketOption},
};
use crate::{
ethernet::L2Packet,
interface::NetworkInterface,
l3::Route,
l4::{
self,
tcp::{TcpConnection, TcpConnectionState},
},
};
/// A UDP socket bound to a local address.
///
/// Incoming datagrams are buffered in a bounded queue until read.
pub struct UdpSocket {
    /// Address the socket is bound to; also its key in the global UDP socket table.
    local: SocketAddr,
    /// Peer address; `UdpSocket::create_socket` always initializes this to `None`.
    remote: Option<SocketAddr>,
    /// Current value of the `SocketOption::Broadcast` option.
    broadcast: AtomicBool,
    // TODO just place packets here for one less copy?
    /// Received datagrams as `(source address, payload)` pairs.
    receive_queue: BoundedMpmcQueue<(SocketAddr, Vec<u8>)>,
}
/// One end of a TCP connection, identified by its `(local, remote)` address pair.
pub struct TcpSocket {
    /// Local endpoint of the connection.
    pub(crate) local: SocketAddr,
    /// Remote endpoint of the connection.
    pub(crate) remote: SocketAddr,
    /// Listener which accepted the socket; `None` for outgoing connections made via `connect`.
    listener: Option<Arc<TcpListener>>,
    /// Protocol state of the connection, shared with the packet-processing code.
    connection: IrqSafeRwLock<TcpConnection>,
}
/// A listening TCP socket which hands out connections arriving on its address.
pub struct TcpListener {
    /// Address the listener accepts connections on.
    accept: SocketAddr,
    /// Currently active sockets, keyed by their remote address.
    sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>,
    /// Connections handed to the listener which have not been picked up by `accept` yet.
    pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>,
    /// Wakes tasks waiting for a connection to be pushed into `pending_accept`.
    accept_notify: QueueWaker,
}
/// A raw L2 socket.
///
/// Once bound to a network interface, it receives a copy of every packet arriving on that
/// interface.
pub struct RawSocket {
    /// Unique socket ID (allocated from `RAW_SOCKET_ID`), used as key in `RAW_SOCKETS`.
    id: u32,
    /// ID of the interface this socket is bound to, if any.
    bound: IrqSafeSpinlock<Option<u32>>,
    /// Packets received on the bound interface, waiting to be read.
    receive_queue: BoundedMpmcQueue<L2Packet>,
}
pub struct SocketTable<T: Socket> {
inner: BTreeMap<SocketAddr, Arc<T>>,
}
/// Table of sockets keyed by their `(local, remote)` address pair.
///
/// Used for TCP connection sockets, where several sockets may share one local address.
pub struct TwoWaySocketTable<T> {
    inner: BTreeMap<(SocketAddr, SocketAddr), Arc<T>>,
}
impl<T> TwoWaySocketTable<T> {
pub const fn new() -> Self {
Self {
inner: BTreeMap::new(),
}
}
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
&mut self,
local: SocketAddr,
remote: SocketAddr,
with: F,
) -> Result<Arc<T>, Error> {
if self.inner.contains_key(&(local, remote)) {
return Err(Error::AddrInUse);
}
let socket = with()?;
self.inner.insert((local, remote), socket.clone());
Ok(socket)
}
pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
&mut self,
local: IpAddr,
remote: SocketAddr,
mut with: F,
) -> Result<Arc<T>, Error> {
for port in 32768..u16::MAX - 1 {
let local = SocketAddr::new(local, port);
match self.try_insert_with(local, remote, || with(port)) {
Ok(socket) => return Ok(socket),
Err(Error::AddrInUse) => continue,
Err(error) => return Err(error),
}
}
Err(Error::AddrInUse)
}
pub fn remove(&mut self, local: SocketAddr, remote: SocketAddr) -> Result<(), Error> {
match self.inner.remove(&(local, remote)) {
Some(_) => Ok(()),
None => Err(Error::DoesNotExist),
}
}
pub fn get(&self, local: SocketAddr, remote: SocketAddr) -> Option<Arc<T>> {
self.inner.get(&(local, remote)).cloned()
}
}
impl<T: Socket> SocketTable<T> {
    /// Creates an empty table.
    pub const fn new() -> Self {
        Self {
            inner: BTreeMap::new(),
        }
    }

    /// Inserts a socket on the first free ephemeral port of `local`.
    ///
    /// Ports `32768..=65535` are tried in ascending order; `with` receives the chosen port.
    ///
    /// # Errors
    ///
    /// [`Error::AddrInUse`] if every ephemeral port is taken, or any error returned by `with`.
    pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
        &mut self,
        local: IpAddr,
        mut with: F,
    ) -> Result<Arc<T>, Error> {
        // Inclusive upper bound so the last ports of the range (65534, 65535) are usable too
        for port in 32768..=u16::MAX {
            let local = SocketAddr::new(local, port);
            match self.try_insert_with(local, || with(port)) {
                Ok(socket) => return Ok(socket),
                Err(Error::AddrInUse) => continue,
                Err(error) => return Err(error),
            }
        }
        Err(Error::AddrInUse)
    }

    /// Inserts the socket produced by `with` under `address`.
    ///
    /// `with` is only invoked if the address is still free.
    ///
    /// # Errors
    ///
    /// [`Error::AddrInUse`] if the address is already taken, or any error returned by `with`.
    pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
        &mut self,
        address: SocketAddr,
        with: F,
    ) -> Result<Arc<T>, Error> {
        if self.inner.contains_key(&address) {
            return Err(Error::AddrInUse);
        }
        let socket = with()?;
        self.inner.insert(address, socket.clone());
        Ok(socket)
    }

    /// Removes the socket registered under `local`.
    ///
    /// # Errors
    ///
    /// [`Error::DoesNotExist`] if no such socket is registered.
    pub fn remove(&mut self, local: SocketAddr) -> Result<(), Error> {
        match self.inner.remove(&local) {
            Some(_) => Ok(()),
            None => Err(Error::DoesNotExist),
        }
    }

    /// Returns the socket registered under exactly `local`, without any wildcard fallback.
    pub fn get_exact(&self, local: &SocketAddr) -> Option<Arc<T>> {
        self.inner.get(local).cloned()
    }

    /// Returns the socket responsible for `local`.
    ///
    /// An exact match wins. For IPv4, a socket bound to `0.0.0.0` on the same port is used as a
    /// fallback. IPv6 currently only supports exact matches.
    pub fn get(&self, local: &SocketAddr) -> Option<Arc<T>> {
        if let Some(socket) = self.inner.get(local) {
            return Some(socket.clone());
        }
        match local {
            SocketAddr::V4(_v4) => {
                let unspec_v4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port());
                self.inner.get(&SocketAddr::V4(unspec_v4)).cloned()
            }
            // TODO wildcard (`::`) fallback for IPv6. Until then report "no socket" instead of
            //      panicking the kernel on a lookup miss.
            SocketAddr::V6(_) => None,
        }
    }
}
/// Bound UDP sockets, keyed by local address.
static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new());
/// TCP connection sockets, keyed by `(local, remote)` address pair.
static TCP_SOCKETS: IrqSafeRwLock<TwoWaySocketTable<TcpSocket>> =
    IrqSafeRwLock::new(TwoWaySocketTable::new());
/// Counter used to hand out unique raw socket IDs.
static RAW_SOCKET_ID: AtomicU32 = AtomicU32::new(0);
/// All open raw sockets, keyed by socket ID.
static RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Arc<RawSocket>>> =
    IrqSafeRwLock::new(BTreeMap::new());
/// Interface ID -> IDs of the raw sockets bound to that interface.
static BOUND_RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Vec<u32>>> =
    IrqSafeRwLock::new(BTreeMap::new());
/// TCP listeners, keyed by the address they accept connections on.
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
    IrqSafeRwLock::new(SocketTable::new());
impl UdpSocket {
fn create_socket(local: SocketAddr) -> Arc<UdpSocket> {
log::debug!("UDP socket opened: {}", local);
Arc::new(UdpSocket {
local,
remote: None,
broadcast: AtomicBool::new(false),
receive_queue: BoundedMpmcQueue::new(128),
})
}
pub fn bind(address: SocketAddr) -> Result<Arc<UdpSocket>, Error> {
let mut sockets = UDP_SOCKETS.write();
if address.port() == 0 {
sockets.try_insert_with_ephemeral_port(address.ip(), |port| {
Ok(Self::create_socket(SocketAddr::new(address.ip(), port)))
})
} else {
sockets.try_insert_with(address, move || Ok(Self::create_socket(address)))
}
}
pub fn connect(&self, _address: SocketAddr) -> Result<(), Error> {
todo!()
}
pub fn get(local: &SocketAddr) -> Option<Arc<UdpSocket>> {
UDP_SOCKETS.read().get(local)
}
pub fn packet_received(&self, source: SocketAddr, data: &[u8]) -> Result<(), Error> {
self.receive_queue
.try_push_back((source, Vec::from(data)))
.map_err(|_| Error::QueueFull)
}
}
impl FileReadiness for UdpSocket {
    /// The socket becomes readable as soon as at least one datagram is queued.
    fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        match self.receive_queue.poll_not_empty(cx) {
            Poll::Ready(()) => Poll::Ready(Ok(())),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl PacketSocket for UdpSocket {
fn send(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
let Some(destination) = destination else {
// TODO can still send without setting address if "connected"
return Err(Error::InvalidArgument);
};
// TODO check that destnation family matches self family
match (self.broadcast.load(Ordering::Relaxed), destination.ip()) {
// SendTo in broadcast?
(true, _) => todo!(),
(false, _) => {
block!(
l4::udp::send(
self.local.port(),
destination.ip(),
destination.port(),
64,
data,
)
.await
)??;
}
}
Ok(data.len())
}
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
let (source, data) = block!(self.receive_queue.pop_front().await)?;
if data.len() > buffer.len() {
// TODO check how other OSs handle this
return Err(Error::BufferTooSmall);
}
buffer[..data.len()].copy_from_slice(&data);
Ok((source, data.len()))
}
}
impl Socket for UdpSocket {
fn local_address(&self) -> SocketAddr {
self.local
}
fn remote_address(&self) -> Option<SocketAddr> {
self.remote
}
fn close(&self) -> Result<(), Error> {
log::debug!("UDP socket closed: {}", self.local);
UDP_SOCKETS.write().remove(self.local)
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Broadcast(broadcast) => {
log::debug!("{} broadcast: {}", self.local, broadcast);
self.broadcast.store(broadcast, Ordering::Relaxed);
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Broadcast(broadcast) => {
*broadcast = self.broadcast.load(Ordering::Relaxed);
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
impl RawSocket {
pub fn bind() -> Result<Arc<Self>, Error> {
let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst);
let socket = Self {
id,
bound: IrqSafeSpinlock::new(None),
receive_queue: BoundedMpmcQueue::new(256),
};
let socket = Arc::new(socket);
RAW_SOCKETS.write().insert(id, socket.clone());
Ok(socket)
}
fn bound_packet_received(&self, packet: L2Packet) {
// TODO do something with the dropped packet?
self.receive_queue.try_push_back(packet).ok();
}
pub fn packet_received(packet: L2Packet) {
let bound_sockets = BOUND_RAW_SOCKETS.read();
let raw_sockets = RAW_SOCKETS.read();
if let Some(ids) = bound_sockets.get(&packet.interface_id) {
for id in ids {
let socket = raw_sockets.get(id).unwrap();
socket.bound_packet_received(packet.clone());
}
}
}
}
impl FileReadiness for RawSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.receive_queue.poll_not_empty(cx).map(Ok)
}
}
impl Socket for RawSocket {
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::BoundHardwareAddress(mac) => {
let bound = self.bound.lock().ok_or(Error::DoesNotExist)?;
let interface = NetworkInterface::get(bound).unwrap();
*mac = interface.mac;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::BindInterface(query) => {
let mut bound = self.bound.lock();
if bound.is_some() {
return Err(Error::AlreadyExists);
}
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let interface = match *query {
SocketInterfaceQuery::ById(id) => NetworkInterface::get(id),
SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name),
}?;
let list = bound_sockets.entry(interface.id).or_default();
bound.replace(interface.id);
list.push(self.id);
Ok(())
}
SocketOption::UnbindInterface => todo!(),
_ => Err(Error::InvalidOperation),
}
}
fn close(&self) -> Result<(), Error> {
let bound = self.bound.lock().take();
if let Some(bound) = bound {
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let mut clear = false;
if let Some(list) = bound_sockets.get_mut(&bound) {
list.retain(|&item| item != self.id);
clear = list.is_empty();
}
if clear {
bound_sockets.remove(&bound);
}
}
RAW_SOCKETS.write().remove(&self.id).unwrap();
Ok(())
}
fn local_address(&self) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
}
fn remote_address(&self) -> Option<SocketAddr> {
None
}
}
impl PacketSocket for RawSocket {
    /// Transmits `data` unchanged on the bound interface, after reserving room for the
    /// device's L2 prefix.
    ///
    /// `_destination` is ignored; `data` is expected to already contain the full frame.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidOperation`] if the socket is not bound.
    /// - The lookup error if the bound interface no longer exists.
    /// - [`Error::InvalidArgument`] if prefix + data exceed 4096 bytes.
    /// - Allocation/transmit errors.
    fn send(&self, _destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
        // TODO cap by MTU?
        let bound = self.bound.lock().ok_or(Error::InvalidOperation)?;
        let interface = NetworkInterface::get(bound)?;
        let l2_offset = interface.device.packet_prefix_size();
        // NOTE(review): 4096 is presumably one page (the PageBox buffer) -- confirm.
        // `saturating_sub` keeps an oversized device prefix from underflowing the limit.
        if data.len() > 4096usize.saturating_sub(l2_offset) {
            return Err(Error::InvalidArgument);
        }
        let mut packet = PageBox::new_slice(0, l2_offset + data.len())?;
        packet[l2_offset..l2_offset + data.len()].copy_from_slice(data);
        interface.device.transmit(packet)?;
        Ok(data.len())
    }

    /// Blocks until a packet is queued and copies everything after its L2 offset into
    /// `buffer`.
    ///
    /// Returns a null IPv4 address (raw packets carry no socket address) and the copied
    /// length.
    ///
    /// # Errors
    ///
    /// [`Error::BufferTooSmall`] if the payload does not fit; the packet is already dequeued
    /// and lost.
    fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
        let data = block!(self.receive_queue.pop_front().await)?;
        let full_len = data.data.len();
        let len = full_len - data.l2_offset;
        if buffer.len() < len {
            return Err(Error::BufferTooSmall);
        }
        buffer[..len].copy_from_slice(&data.data[data.l2_offset..full_len]);
        Ok((SocketAddr::NULL_V4, len))
    }
}
impl TcpSocket {
pub fn connect(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
block!(Self::connect_async(remote).await)?
}
pub fn accept_remote(
listener: Arc<TcpListener>,
local: SocketAddr,
remote: SocketAddr,
remote_window_size: usize,
tx_seq: u32,
rx_seq: u32,
) -> Result<Arc<TcpSocket>, Error> {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with(local, remote, move || {
let connection = TcpConnection::new(
local,
remote,
remote_window_size,
tx_seq,
rx_seq,
TcpConnectionState::SynReceived,
);
log::debug!("Accepted TCP socket {} -> {}", local, remote);
let socket = Self {
local,
remote,
listener: Some(listener),
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})
}
pub fn connection(&self) -> &IrqSafeRwLock<TcpConnection> {
&self.connection
}
pub(crate) fn accept(self: &Arc<Self>) {
if let Some(listener) = self.listener.as_ref() {
listener.accept_socket(self.clone());
}
}
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<Self>> {
TCP_SOCKETS.read().get(local, remote)
}
pub fn receive_async<'a>(
&'a self,
buffer: &'a mut [u8],
) -> impl Future<Output = Result<usize, Error>> + 'a {
// TODO timeout here
// TODO don't throw ConnectionReset immediately
struct F<'f> {
socket: &'f TcpSocket,
buffer: &'f mut [u8],
}
impl<'f> Future for F<'f> {
type Output = Result<usize, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.socket.poll_receive(cx) {
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(self.buffer)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
}
F {
socket: self,
buffer,
}
}
pub async fn send_async(&self, data: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
let mut rem = data.len();
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
self.send_segment_async(&data[pos..pos + amount]).await?;
pos += amount;
rem -= amount;
}
Ok(pos)
}
pub async fn close_async(&self, remove_from_listener: bool) -> Result<(), Error> {
// TODO timeout here
// Already closing
if self.connection.read().is_closing() {
return Ok(());
}
// Wait for all sent data to be acknowledged
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.finish().await?;
}
log::debug!(
"TCP socket closed (FinWait2/Closed): {} <-> {}",
self.local,
self.remote
);
// Wait for connection to get closed
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_finish(cx)
})
.await;
if remove_from_listener {
if let Some(listener) = self.listener.as_ref() {
listener.remove_socket(self.remote);
};
}
Ok(())
}
pub(crate) fn remove_socket(&self) -> Result<(), Error> {
log::debug!(
"TCP socket closed and removed: {} <-> {}",
self.local,
self.remote
);
let connection = self.connection.read();
debug_assert!(connection.is_closed());
TCP_SOCKETS.write().remove(self.local, self.remote)?;
connection.notify_all();
Ok(())
}
fn poll_receive(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
let lock = self.connection.write();
match lock.poll_receive(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
async fn send_segment_async(&self, data: &[u8]) -> Result<(), Error> {
// TODO timeout here
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.transmit(data).await?;
}
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_acknowledge(cx)
})
.await;
Ok(())
}
async fn connect_async(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
// Lookup route to remote
let (interface_id, _, remote_ip) =
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
let remote = SocketAddr::new(remote_ip, remote.port());
let interface = NetworkInterface::get(interface_id)?;
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
let socket = {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with_ephemeral_port(local_ip, remote, |port| {
let t = monotonic_timestamp()?;
let tx_seq = t.as_micros() as u32;
let local = SocketAddr::new(local_ip, port);
let connection =
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
let socket = Self {
local,
remote,
listener: None,
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})?
};
let mut t = 200;
for _ in 0..5 {
let timeout = Duration::from_millis(t);
log::debug!("Try SYN with timeout={:?}", timeout);
match socket.try_connect(timeout).await {
Ok(()) => return Ok((socket.local, socket)),
Err(Error::TimedOut) => (),
Err(error) => return Err(error),
}
t *= 2;
}
// Couldn't establish
Err(Error::TimedOut)
}
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
{
let mut connection = self.connection.write();
connection.send_syn().await?;
}
let fut = poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_established(cx)
});
match run_with_timeout(timeout, fut).await {
FutureTimeout::Ok(value) => value,
FutureTimeout::Timeout => Err(Error::TimedOut),
}
}
}
impl Socket for TcpSocket {
    /// Local endpoint of the connection.
    fn local_address(&self) -> SocketAddr {
        self.local
    }

    /// Remote endpoint of the connection; always present for TCP sockets.
    fn remote_address(&self) -> Option<SocketAddr> {
        let peer = self.remote;
        Some(peer)
    }

    /// Blocks until the connection is closed, also detaching it from its listener.
    fn close(&self) -> Result<(), Error> {
        let outcome = block!(self.close_async(true).await)?;
        outcome
    }
}
impl FileReadiness for TcpSocket {
    /// Readable when the connection reports receive readiness (or an error).
    ///
    /// The connection lock obtained while polling is released immediately.
    fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        match self.poll_receive(cx) {
            Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl ConnectionSocket for TcpSocket {
    /// Blocking wrapper around `receive_async`.
    fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> {
        let received = block!(self.receive_async(buffer).await)?;
        received
    }

    /// Blocking wrapper around `send_async`.
    fn send(&self, data: &[u8]) -> Result<usize, Error> {
        let sent = block!(self.send_async(data).await)?;
        sent
    }
}
impl TcpListener {
    /// Creates a listener accepting connections on `accept` and registers it globally.
    ///
    /// # Errors
    ///
    /// [`Error::AddrInUse`] if a listener is already registered for exactly this address.
    pub fn bind(accept: SocketAddr) -> Result<Arc<Self>, Error> {
        TCP_LISTENERS.write().try_insert_with(accept, || {
            let listener = TcpListener {
                accept,
                sockets: IrqSafeRwLock::new(BTreeMap::new()),
                pending_accept: IrqSafeSpinlock::new(Vec::new()),
                accept_notify: QueueWaker::new(),
            };

            log::debug!("TCP Listener opened: {}", accept);

            Ok(Arc::new(listener))
        })
    }

    /// Returns the listener responsible for `local`.
    ///
    /// An exact match wins; otherwise the lookup falls back to a listener on `0.0.0.0` with
    /// the same port.
    pub fn get(local: SocketAddr) -> Option<Arc<Self>> {
        TCP_LISTENERS.read().get(&local)
    }

    /// Resolves to the oldest pending connection once one is available.
    pub fn accept_async(&self) -> impl Future<Output = Result<Arc<TcpSocket>, Error>> + '_ {
        struct F<'f> {
            listener: &'f TcpListener,
        }

        impl<'f> Future for F<'f> {
            type Output = Result<Arc<TcpSocket>, Error>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                match self.listener.poll_accept(cx) {
                    // `poll_accept` only returns Ready with a non-empty queue, so index 0
                    // exists. Take the front entry so connections are accepted in arrival
                    // (FIFO) order.
                    Poll::Ready(mut pending) => Poll::Ready(Ok(pending.remove(0))),
                    Poll::Pending => Poll::Pending,
                }
            }
        }

        F { listener: self }
    }

    /// Records `socket` as an active client and queues it for `accept`, waking one waiter.
    fn accept_socket(&self, socket: Arc<TcpSocket>) {
        log::debug!("{}: accept {}", self.accept, socket.remote);
        self.sockets.write().insert(socket.remote, socket.clone());
        self.pending_accept.lock().push(socket);
        self.accept_notify.wake_one();
    }

    /// Forgets the active client connected from `remote`.
    fn remove_socket(&self, remote: SocketAddr) {
        log::debug!("Remove client {}", remote);
        self.sockets.write().remove(&remote);
    }

    /// Returns the locked pending-connection queue if it is non-empty.
    ///
    /// Otherwise registers `cx`'s waker to be woken by `accept_socket` and returns `Pending`.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpSocket>>>> {
        let lock = self.pending_accept.lock();
        // Register before checking so a push that happens right after cannot be missed
        self.accept_notify.register(cx.waker());
        if !lock.is_empty() {
            self.accept_notify.remove(cx.waker());
            Poll::Ready(lock)
        } else {
            Poll::Pending
        }
    }
}
impl Socket for TcpListener {
    /// Address the listener accepts connections on.
    fn local_address(&self) -> SocketAddr {
        self.accept
    }

    /// Listeners have no peer.
    fn remote_address(&self) -> Option<SocketAddr> {
        None
    }

    /// Unregisters the listener from the global listener table.
    fn close(&self) -> Result<(), Error> {
        // TODO if clients not closed already, send RST?
        let mut listeners = TCP_LISTENERS.write();
        listeners.remove(self.accept)
    }
}
impl FileReadiness for TcpListener {
    /// Readable when at least one connection is waiting to be accepted.
    fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if self.poll_accept(cx).is_pending() {
            Poll::Pending
        } else {
            Poll::Ready(Ok(()))
        }
    }
}
impl ListenerSocket for TcpListener {
fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = block!(self.accept_async().await)??;
let remote = socket.remote;
Ok((remote, socket))
}
}