git-subtree-dir: kernel git-subtree-mainline: 817f71f90f97270dd569fd44246bf74e57636552 git-subtree-split: 7f1f6b73377367db17f98a740316b904c37ce3b1
828 lines
25 KiB
Rust
828 lines
25 KiB
Rust
use core::{
|
|
future::{poll_fn, Future},
|
|
pin::Pin,
|
|
sync::atomic::{AtomicBool, AtomicU32, Ordering},
|
|
task::{Context, Poll},
|
|
time::Duration,
|
|
};
|
|
|
|
use alloc::{collections::BTreeMap, sync::Arc, vec::Vec};
|
|
use libk_device::monotonic_timestamp;
|
|
use libk_mm::PageBox;
|
|
use libk_thread::{
|
|
block,
|
|
runtime::{run_with_timeout, FutureTimeout},
|
|
};
|
|
use libk_util::{
|
|
queue::BoundedMpmcQueue,
|
|
sync::{
|
|
spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard},
|
|
IrqSafeSpinlock, IrqSafeSpinlockGuard,
|
|
},
|
|
waker::QueueWaker,
|
|
};
|
|
use vfs::{ConnectionSocket, FileReadiness, ListenerSocket, PacketSocket, Socket};
|
|
use yggdrasil_abi::{
|
|
error::Error,
|
|
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketInterfaceQuery, SocketOption},
|
|
};
|
|
|
|
use crate::{
|
|
ethernet::L2Packet,
|
|
interface::NetworkInterface,
|
|
l3::Route,
|
|
l4::{
|
|
self,
|
|
tcp::{TcpConnection, TcpConnectionState},
|
|
},
|
|
};
|
|
|
|
pub struct UdpSocket {
|
|
local: SocketAddr,
|
|
remote: Option<SocketAddr>,
|
|
|
|
broadcast: AtomicBool,
|
|
|
|
// TODO just place packets here for one less copy?
|
|
receive_queue: BoundedMpmcQueue<(SocketAddr, Vec<u8>)>,
|
|
}
|
|
|
|
pub struct TcpSocket {
|
|
pub(crate) local: SocketAddr,
|
|
pub(crate) remote: SocketAddr,
|
|
// Listener which accepted the socket
|
|
listener: Option<Arc<TcpListener>>,
|
|
connection: IrqSafeRwLock<TcpConnection>,
|
|
}
|
|
|
|
pub struct TcpListener {
|
|
accept: SocketAddr,
|
|
|
|
// Currently active sockets
|
|
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>,
|
|
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>,
|
|
accept_notify: QueueWaker,
|
|
}
|
|
|
|
pub struct RawSocket {
|
|
id: u32,
|
|
bound: IrqSafeSpinlock<Option<u32>>,
|
|
receive_queue: BoundedMpmcQueue<L2Packet>,
|
|
}
|
|
|
|
pub struct SocketTable<T: Socket> {
|
|
inner: BTreeMap<SocketAddr, Arc<T>>,
|
|
}
|
|
|
|
pub struct TwoWaySocketTable<T> {
|
|
inner: BTreeMap<(SocketAddr, SocketAddr), Arc<T>>,
|
|
}
|
|
|
|
impl<T> TwoWaySocketTable<T> {
|
|
pub const fn new() -> Self {
|
|
Self {
|
|
inner: BTreeMap::new(),
|
|
}
|
|
}
|
|
|
|
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
|
|
&mut self,
|
|
local: SocketAddr,
|
|
remote: SocketAddr,
|
|
with: F,
|
|
) -> Result<Arc<T>, Error> {
|
|
if self.inner.contains_key(&(local, remote)) {
|
|
return Err(Error::AddrInUse);
|
|
}
|
|
let socket = with()?;
|
|
self.inner.insert((local, remote), socket.clone());
|
|
Ok(socket)
|
|
}
|
|
|
|
pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
|
|
&mut self,
|
|
local: IpAddr,
|
|
remote: SocketAddr,
|
|
mut with: F,
|
|
) -> Result<Arc<T>, Error> {
|
|
for port in 32768..u16::MAX - 1 {
|
|
let local = SocketAddr::new(local, port);
|
|
|
|
match self.try_insert_with(local, remote, || with(port)) {
|
|
Ok(socket) => return Ok(socket),
|
|
Err(Error::AddrInUse) => continue,
|
|
Err(error) => return Err(error),
|
|
}
|
|
}
|
|
Err(Error::AddrInUse)
|
|
}
|
|
|
|
pub fn remove(&mut self, local: SocketAddr, remote: SocketAddr) -> Result<(), Error> {
|
|
match self.inner.remove(&(local, remote)) {
|
|
Some(_) => Ok(()),
|
|
None => Err(Error::DoesNotExist),
|
|
}
|
|
}
|
|
|
|
pub fn get(&self, local: SocketAddr, remote: SocketAddr) -> Option<Arc<T>> {
|
|
self.inner.get(&(local, remote)).cloned()
|
|
}
|
|
}
|
|
|
|
impl<T: Socket> SocketTable<T> {
|
|
pub const fn new() -> Self {
|
|
Self {
|
|
inner: BTreeMap::new(),
|
|
}
|
|
}
|
|
|
|
pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
|
|
&mut self,
|
|
local: IpAddr,
|
|
mut with: F,
|
|
) -> Result<Arc<T>, Error> {
|
|
for port in 32768..u16::MAX - 1 {
|
|
let local = SocketAddr::new(local, port);
|
|
|
|
match self.try_insert_with(local, || with(port)) {
|
|
Ok(socket) => return Ok(socket),
|
|
Err(Error::AddrInUse) => continue,
|
|
Err(error) => return Err(error),
|
|
}
|
|
}
|
|
Err(Error::AddrInUse)
|
|
}
|
|
|
|
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
|
|
&mut self,
|
|
address: SocketAddr,
|
|
with: F,
|
|
) -> Result<Arc<T>, Error> {
|
|
if self.inner.contains_key(&address) {
|
|
return Err(Error::AddrInUse);
|
|
}
|
|
let socket = with()?;
|
|
self.inner.insert(address, socket.clone());
|
|
Ok(socket)
|
|
}
|
|
|
|
pub fn remove(&mut self, local: SocketAddr) -> Result<(), Error> {
|
|
match self.inner.remove(&local) {
|
|
Some(_) => Ok(()),
|
|
None => Err(Error::DoesNotExist),
|
|
}
|
|
}
|
|
|
|
pub fn get_exact(&self, local: &SocketAddr) -> Option<Arc<T>> {
|
|
self.inner.get(local).cloned()
|
|
}
|
|
|
|
pub fn get(&self, local: &SocketAddr) -> Option<Arc<T>> {
|
|
if let Some(socket) = self.inner.get(local) {
|
|
return Some(socket.clone());
|
|
}
|
|
|
|
match local {
|
|
SocketAddr::V4(_v4) => {
|
|
let unspec_v4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port());
|
|
self.inner.get(&SocketAddr::V4(unspec_v4)).cloned()
|
|
}
|
|
SocketAddr::V6(_) => todo!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new());
|
|
static TCP_SOCKETS: IrqSafeRwLock<TwoWaySocketTable<TcpSocket>> =
|
|
IrqSafeRwLock::new(TwoWaySocketTable::new());
|
|
static RAW_SOCKET_ID: AtomicU32 = AtomicU32::new(0);
|
|
static RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Arc<RawSocket>>> =
|
|
IrqSafeRwLock::new(BTreeMap::new());
|
|
static BOUND_RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Vec<u32>>> =
|
|
IrqSafeRwLock::new(BTreeMap::new());
|
|
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
|
|
IrqSafeRwLock::new(SocketTable::new());
|
|
|
|
impl UdpSocket {
|
|
fn create_socket(local: SocketAddr) -> Arc<UdpSocket> {
|
|
log::debug!("UDP socket opened: {}", local);
|
|
Arc::new(UdpSocket {
|
|
local,
|
|
remote: None,
|
|
broadcast: AtomicBool::new(false),
|
|
receive_queue: BoundedMpmcQueue::new(128),
|
|
})
|
|
}
|
|
|
|
pub fn bind(address: SocketAddr) -> Result<Arc<UdpSocket>, Error> {
|
|
let mut sockets = UDP_SOCKETS.write();
|
|
if address.port() == 0 {
|
|
sockets.try_insert_with_ephemeral_port(address.ip(), |port| {
|
|
Ok(Self::create_socket(SocketAddr::new(address.ip(), port)))
|
|
})
|
|
} else {
|
|
sockets.try_insert_with(address, move || Ok(Self::create_socket(address)))
|
|
}
|
|
}
|
|
|
|
pub fn connect(&self, _address: SocketAddr) -> Result<(), Error> {
|
|
todo!()
|
|
}
|
|
|
|
pub fn get(local: &SocketAddr) -> Option<Arc<UdpSocket>> {
|
|
UDP_SOCKETS.read().get(local)
|
|
}
|
|
|
|
pub fn packet_received(&self, source: SocketAddr, data: &[u8]) -> Result<(), Error> {
|
|
self.receive_queue
|
|
.try_push_back((source, Vec::from(data)))
|
|
.map_err(|_| Error::QueueFull)
|
|
}
|
|
}
|
|
|
|
impl FileReadiness for UdpSocket {
|
|
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
|
self.receive_queue.poll_not_empty(cx).map(Ok)
|
|
}
|
|
}
|
|
|
|
impl PacketSocket for UdpSocket {
|
|
fn send(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
|
|
let Some(destination) = destination else {
|
|
// TODO can still send without setting address if "connected"
|
|
return Err(Error::InvalidArgument);
|
|
};
|
|
// TODO check that destnation family matches self family
|
|
match (self.broadcast.load(Ordering::Relaxed), destination.ip()) {
|
|
// SendTo in broadcast?
|
|
(true, _) => todo!(),
|
|
(false, _) => {
|
|
block!(
|
|
l4::udp::send(
|
|
self.local.port(),
|
|
destination.ip(),
|
|
destination.port(),
|
|
64,
|
|
data,
|
|
)
|
|
.await
|
|
)??;
|
|
}
|
|
}
|
|
|
|
Ok(data.len())
|
|
}
|
|
|
|
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
|
|
let (source, data) = block!(self.receive_queue.pop_front().await)?;
|
|
if data.len() > buffer.len() {
|
|
// TODO check how other OSs handle this
|
|
return Err(Error::BufferTooSmall);
|
|
}
|
|
buffer[..data.len()].copy_from_slice(&data);
|
|
Ok((source, data.len()))
|
|
}
|
|
}
|
|
|
|
impl Socket for UdpSocket {
|
|
fn local_address(&self) -> SocketAddr {
|
|
self.local
|
|
}
|
|
|
|
fn remote_address(&self) -> Option<SocketAddr> {
|
|
self.remote
|
|
}
|
|
|
|
fn close(&self) -> Result<(), Error> {
|
|
log::debug!("UDP socket closed: {}", self.local);
|
|
UDP_SOCKETS.write().remove(self.local)
|
|
}
|
|
|
|
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
|
|
match option {
|
|
&SocketOption::Broadcast(broadcast) => {
|
|
log::debug!("{} broadcast: {}", self.local, broadcast);
|
|
self.broadcast.store(broadcast, Ordering::Relaxed);
|
|
Ok(())
|
|
}
|
|
_ => Err(Error::InvalidOperation),
|
|
}
|
|
}
|
|
|
|
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
|
|
match option {
|
|
SocketOption::Broadcast(broadcast) => {
|
|
*broadcast = self.broadcast.load(Ordering::Relaxed);
|
|
Ok(())
|
|
}
|
|
_ => Err(Error::InvalidOperation),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl RawSocket {
|
|
pub fn bind() -> Result<Arc<Self>, Error> {
|
|
let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst);
|
|
let socket = Self {
|
|
id,
|
|
bound: IrqSafeSpinlock::new(None),
|
|
receive_queue: BoundedMpmcQueue::new(256),
|
|
};
|
|
let socket = Arc::new(socket);
|
|
RAW_SOCKETS.write().insert(id, socket.clone());
|
|
|
|
Ok(socket)
|
|
}
|
|
|
|
fn bound_packet_received(&self, packet: L2Packet) {
|
|
// TODO do something with the dropped packet?
|
|
self.receive_queue.try_push_back(packet).ok();
|
|
}
|
|
|
|
pub fn packet_received(packet: L2Packet) {
|
|
let bound_sockets = BOUND_RAW_SOCKETS.read();
|
|
let raw_sockets = RAW_SOCKETS.read();
|
|
|
|
if let Some(ids) = bound_sockets.get(&packet.interface_id) {
|
|
for id in ids {
|
|
let socket = raw_sockets.get(id).unwrap();
|
|
socket.bound_packet_received(packet.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl FileReadiness for RawSocket {
|
|
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
|
self.receive_queue.poll_not_empty(cx).map(Ok)
|
|
}
|
|
}
|
|
|
|
impl Socket for RawSocket {
|
|
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
|
|
match option {
|
|
SocketOption::BoundHardwareAddress(mac) => {
|
|
let bound = self.bound.lock().ok_or(Error::DoesNotExist)?;
|
|
let interface = NetworkInterface::get(bound).unwrap();
|
|
*mac = interface.mac;
|
|
Ok(())
|
|
}
|
|
_ => Err(Error::InvalidOperation),
|
|
}
|
|
}
|
|
|
|
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
|
|
match option {
|
|
SocketOption::BindInterface(query) => {
|
|
let mut bound = self.bound.lock();
|
|
if bound.is_some() {
|
|
return Err(Error::AlreadyExists);
|
|
}
|
|
|
|
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
|
|
|
|
let interface = match *query {
|
|
SocketInterfaceQuery::ById(id) => NetworkInterface::get(id),
|
|
SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name),
|
|
}?;
|
|
|
|
let list = bound_sockets.entry(interface.id).or_default();
|
|
bound.replace(interface.id);
|
|
list.push(self.id);
|
|
|
|
Ok(())
|
|
}
|
|
SocketOption::UnbindInterface => todo!(),
|
|
_ => Err(Error::InvalidOperation),
|
|
}
|
|
}
|
|
|
|
fn close(&self) -> Result<(), Error> {
|
|
let bound = self.bound.lock().take();
|
|
|
|
if let Some(bound) = bound {
|
|
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
|
|
let mut clear = false;
|
|
|
|
if let Some(list) = bound_sockets.get_mut(&bound) {
|
|
list.retain(|&item| item != self.id);
|
|
clear = list.is_empty();
|
|
}
|
|
|
|
if clear {
|
|
bound_sockets.remove(&bound);
|
|
}
|
|
}
|
|
|
|
RAW_SOCKETS.write().remove(&self.id).unwrap();
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn local_address(&self) -> SocketAddr {
|
|
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
|
|
}
|
|
|
|
fn remote_address(&self) -> Option<SocketAddr> {
|
|
None
|
|
}
|
|
}
|
|
|
|
impl PacketSocket for RawSocket {
|
|
fn send(&self, _destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
|
|
// TODO cap by MTU?
|
|
let bound = self.bound.lock().ok_or(Error::InvalidOperation)?;
|
|
let interface = NetworkInterface::get(bound).unwrap();
|
|
let l2_offset = interface.device.packet_prefix_size();
|
|
if data.len() > 4096 - l2_offset {
|
|
return Err(Error::InvalidArgument);
|
|
}
|
|
let mut packet = PageBox::new_slice(0, l2_offset + data.len())?;
|
|
packet[l2_offset..l2_offset + data.len()].copy_from_slice(data);
|
|
interface.device.transmit(packet)?;
|
|
Ok(data.len())
|
|
}
|
|
|
|
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
|
|
let data = block!(self.receive_queue.pop_front().await)?;
|
|
let full_len = data.data.len();
|
|
let len = full_len - data.l2_offset;
|
|
if buffer.len() < len {
|
|
return Err(Error::BufferTooSmall);
|
|
}
|
|
buffer[..len].copy_from_slice(&data.data[data.l2_offset..full_len]);
|
|
Ok((SocketAddr::NULL_V4, len))
|
|
}
|
|
}
|
|
|
|
impl TcpSocket {
|
|
pub fn connect(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
|
|
block!(Self::connect_async(remote).await)?
|
|
}
|
|
|
|
pub fn accept_remote(
|
|
listener: Arc<TcpListener>,
|
|
local: SocketAddr,
|
|
remote: SocketAddr,
|
|
remote_window_size: usize,
|
|
tx_seq: u32,
|
|
rx_seq: u32,
|
|
) -> Result<Arc<TcpSocket>, Error> {
|
|
let mut sockets = TCP_SOCKETS.write();
|
|
sockets.try_insert_with(local, remote, move || {
|
|
let connection = TcpConnection::new(
|
|
local,
|
|
remote,
|
|
remote_window_size,
|
|
tx_seq,
|
|
rx_seq,
|
|
TcpConnectionState::SynReceived,
|
|
);
|
|
|
|
log::debug!("Accepted TCP socket {} -> {}", local, remote);
|
|
|
|
let socket = Self {
|
|
local,
|
|
remote,
|
|
listener: Some(listener),
|
|
connection: IrqSafeRwLock::new(connection),
|
|
};
|
|
|
|
Ok(Arc::new(socket))
|
|
})
|
|
}
|
|
|
|
/// Gives access to the lock guarding this socket's connection state.
pub fn connection(&self) -> &IrqSafeRwLock<TcpConnection> {
    let lock = &self.connection;
    lock
}
|
|
|
|
/// Hands the socket over to the listener which created it, making it available to
/// `accept()` callers. Does nothing for sockets without a listener.
pub(crate) fn accept(self: &Arc<Self>) {
    let Some(listener) = self.listener.as_ref() else {
        return;
    };
    listener.accept_socket(Arc::clone(self));
}
|
|
|
|
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<Self>> {
|
|
TCP_SOCKETS.read().get(local, remote)
|
|
}
|
|
|
|
pub fn receive_async<'a>(
|
|
&'a self,
|
|
buffer: &'a mut [u8],
|
|
) -> impl Future<Output = Result<usize, Error>> + 'a {
|
|
// TODO timeout here
|
|
// TODO don't throw ConnectionReset immediately
|
|
struct F<'f> {
|
|
socket: &'f TcpSocket,
|
|
buffer: &'f mut [u8],
|
|
}
|
|
|
|
impl<'f> Future for F<'f> {
|
|
type Output = Result<usize, Error>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
match self.socket.poll_receive(cx) {
|
|
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(self.buffer)),
|
|
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
}
|
|
}
|
|
|
|
F {
|
|
socket: self,
|
|
buffer,
|
|
}
|
|
}
|
|
|
|
/// Sends `data` in segments of at most 512 bytes, awaiting each segment in turn.
///
/// Returns the total number of bytes sent, which is `data.len()` on success.
/// The first failing segment aborts the transfer and its error is returned.
pub async fn send_async(&self, data: &[u8]) -> Result<usize, Error> {
    let mut sent = 0;
    // TODO check MTU
    for segment in data.chunks(512) {
        self.send_segment_async(segment).await?;
        sent += segment.len();
    }
    Ok(sent)
}
|
|
|
|
pub async fn close_async(&self, remove_from_listener: bool) -> Result<(), Error> {
|
|
// TODO timeout here
|
|
// Already closing
|
|
if self.connection.read().is_closing() {
|
|
return Ok(());
|
|
}
|
|
|
|
// Wait for all sent data to be acknowledged
|
|
{
|
|
let mut connection = poll_fn(|cx| {
|
|
let connection = self.connection.write();
|
|
match connection.poll_send(cx) {
|
|
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
|
|
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
})
|
|
.await?;
|
|
|
|
connection.finish().await?;
|
|
}
|
|
|
|
log::debug!(
|
|
"TCP socket closed (FinWait2/Closed): {} <-> {}",
|
|
self.local,
|
|
self.remote
|
|
);
|
|
|
|
// Wait for connection to get closed
|
|
poll_fn(|cx| {
|
|
let connection = self.connection.read();
|
|
connection.poll_finish(cx)
|
|
})
|
|
.await;
|
|
|
|
if remove_from_listener {
|
|
if let Some(listener) = self.listener.as_ref() {
|
|
listener.remove_socket(self.remote);
|
|
};
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub(crate) fn remove_socket(&self) -> Result<(), Error> {
|
|
log::debug!(
|
|
"TCP socket closed and removed: {} <-> {}",
|
|
self.local,
|
|
self.remote
|
|
);
|
|
let connection = self.connection.read();
|
|
debug_assert!(connection.is_closed());
|
|
TCP_SOCKETS.write().remove(self.local, self.remote)?;
|
|
connection.notify_all();
|
|
Ok(())
|
|
}
|
|
|
|
fn poll_receive(
|
|
&self,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
|
|
let lock = self.connection.write();
|
|
match lock.poll_receive(cx) {
|
|
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
|
|
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
}
|
|
|
|
async fn send_segment_async(&self, data: &[u8]) -> Result<(), Error> {
|
|
// TODO timeout here
|
|
{
|
|
let mut connection = poll_fn(|cx| {
|
|
let connection = self.connection.write();
|
|
match connection.poll_send(cx) {
|
|
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
|
|
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
})
|
|
.await?;
|
|
|
|
connection.transmit(data).await?;
|
|
}
|
|
|
|
poll_fn(|cx| {
|
|
let connection = self.connection.read();
|
|
connection.poll_acknowledge(cx)
|
|
})
|
|
.await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn connect_async(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
|
|
// Lookup route to remote
|
|
let (interface_id, _, remote_ip) =
|
|
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
|
|
let remote = SocketAddr::new(remote_ip, remote.port());
|
|
let interface = NetworkInterface::get(interface_id)?;
|
|
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
|
|
|
|
let socket = {
|
|
let mut sockets = TCP_SOCKETS.write();
|
|
sockets.try_insert_with_ephemeral_port(local_ip, remote, |port| {
|
|
let t = monotonic_timestamp()?;
|
|
let tx_seq = t.as_micros() as u32;
|
|
let local = SocketAddr::new(local_ip, port);
|
|
let connection =
|
|
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
|
|
|
|
let socket = Self {
|
|
local,
|
|
remote,
|
|
listener: None,
|
|
connection: IrqSafeRwLock::new(connection),
|
|
};
|
|
|
|
Ok(Arc::new(socket))
|
|
})?
|
|
};
|
|
|
|
let mut t = 200;
|
|
for _ in 0..5 {
|
|
let timeout = Duration::from_millis(t);
|
|
log::debug!("Try SYN with timeout={:?}", timeout);
|
|
match socket.try_connect(timeout).await {
|
|
Ok(()) => return Ok((socket.local, socket)),
|
|
Err(Error::TimedOut) => (),
|
|
Err(error) => return Err(error),
|
|
}
|
|
t *= 2;
|
|
}
|
|
|
|
// Couldn't establish
|
|
Err(Error::TimedOut)
|
|
}
|
|
|
|
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
|
|
{
|
|
let mut connection = self.connection.write();
|
|
connection.send_syn().await?;
|
|
}
|
|
|
|
let fut = poll_fn(|cx| {
|
|
let connection = self.connection.read();
|
|
connection.poll_established(cx)
|
|
});
|
|
|
|
match run_with_timeout(timeout, fut).await {
|
|
FutureTimeout::Ok(value) => value,
|
|
FutureTimeout::Timeout => Err(Error::TimedOut),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Socket for TcpSocket {
|
|
fn local_address(&self) -> SocketAddr {
|
|
self.local
|
|
}
|
|
|
|
fn remote_address(&self) -> Option<SocketAddr> {
|
|
Some(self.remote)
|
|
}
|
|
|
|
fn close(&self) -> Result<(), Error> {
|
|
block!(self.close_async(true).await)?
|
|
}
|
|
}
|
|
|
|
impl FileReadiness for TcpSocket {
|
|
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
|
self.poll_receive(cx).map_ok(|_| ())
|
|
}
|
|
}
|
|
|
|
impl ConnectionSocket for TcpSocket {
|
|
fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> {
|
|
block!(self.receive_async(buffer).await)?
|
|
}
|
|
|
|
fn send(&self, data: &[u8]) -> Result<usize, Error> {
|
|
block!(self.send_async(data).await)?
|
|
}
|
|
}
|
|
|
|
impl TcpListener {
|
|
pub fn bind(accept: SocketAddr) -> Result<Arc<Self>, Error> {
|
|
TCP_LISTENERS.write().try_insert_with(accept, || {
|
|
let listener = TcpListener {
|
|
accept,
|
|
sockets: IrqSafeRwLock::new(BTreeMap::new()),
|
|
pending_accept: IrqSafeSpinlock::new(Vec::new()),
|
|
accept_notify: QueueWaker::new(),
|
|
};
|
|
|
|
log::debug!("TCP Listener opened: {}", accept);
|
|
|
|
Ok(Arc::new(listener))
|
|
})
|
|
}
|
|
|
|
pub fn get(local: SocketAddr) -> Option<Arc<Self>> {
|
|
TCP_LISTENERS.read().get(&local)
|
|
}
|
|
|
|
pub fn accept_async(&self) -> impl Future<Output = Result<Arc<TcpSocket>, Error>> + '_ {
|
|
struct F<'f> {
|
|
listener: &'f TcpListener,
|
|
}
|
|
|
|
impl<'f> Future for F<'f> {
|
|
type Output = Result<Arc<TcpSocket>, Error>;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
match self.listener.poll_accept(cx) {
|
|
Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())),
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
}
|
|
}
|
|
|
|
F { listener: self }
|
|
}
|
|
|
|
fn accept_socket(&self, socket: Arc<TcpSocket>) {
|
|
log::debug!("{}: accept {}", self.accept, socket.remote);
|
|
self.sockets.write().insert(socket.remote, socket.clone());
|
|
self.pending_accept.lock().push(socket);
|
|
self.accept_notify.wake_one();
|
|
}
|
|
|
|
fn remove_socket(&self, remote: SocketAddr) {
|
|
log::debug!("Remove client {}", remote);
|
|
self.sockets.write().remove(&remote);
|
|
}
|
|
|
|
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpSocket>>>> {
|
|
let lock = self.pending_accept.lock();
|
|
self.accept_notify.register(cx.waker());
|
|
if !lock.is_empty() {
|
|
self.accept_notify.remove(cx.waker());
|
|
Poll::Ready(lock)
|
|
} else {
|
|
Poll::Pending
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Socket for TcpListener {
|
|
fn local_address(&self) -> SocketAddr {
|
|
self.accept
|
|
}
|
|
|
|
fn remote_address(&self) -> Option<SocketAddr> {
|
|
None
|
|
}
|
|
|
|
fn close(&self) -> Result<(), Error> {
|
|
// TODO if clients not closed already, send RST?
|
|
TCP_LISTENERS.write().remove(self.accept)
|
|
}
|
|
}
|
|
|
|
impl FileReadiness for TcpListener {
|
|
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
|
self.poll_accept(cx).map(|_| Ok(()))
|
|
}
|
|
}
|
|
|
|
impl ListenerSocket for TcpListener {
|
|
fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
|
|
let socket = block!(self.accept_async().await)??;
|
|
let remote = socket.remote;
|
|
Ok((remote, socket))
|
|
}
|
|
}
|