From 968ce321da2fbbe3eb8508c3efb054133ce419f0 Mon Sep 17 00:00:00 2001 From: Mark Poliakov Date: Mon, 4 Nov 2024 16:29:09 +0200 Subject: [PATCH] netutils: better HTTP library --- userspace/Cargo.lock | 25 ++ userspace/Cargo.toml | 1 + userspace/netutils/Cargo.toml | 2 + userspace/netutils/src/dhcp_client.rs | 2 +- userspace/netutils/src/http.rs | 98 ++---- userspace/netutils/src/lib.rs | 163 +-------- userspace/netutils/src/netconf.rs | 2 +- userspace/netutils/src/netconfig.rs | 91 +++++ userspace/netutils/src/ping.rs | 6 +- userspace/netutils/src/proto/http.rs | 471 ++++++++++++++++++++++++++ userspace/netutils/src/proto/mod.rs | 70 ++++ userspace/rsh/Cargo.toml | 2 +- 12 files changed, 695 insertions(+), 238 deletions(-) create mode 100644 userspace/netutils/src/netconfig.rs create mode 100644 userspace/netutils/src/proto/http.rs create mode 100644 userspace/netutils/src/proto/mod.rs diff --git a/userspace/Cargo.lock b/userspace/Cargo.lock index cf2ddce9..78a0cd19 100644 --- a/userspace/Cargo.lock +++ b/userspace/Cargo.lock @@ -150,6 +150,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" + [[package]] name = "cc" version = "1.1.31" @@ -477,6 +483,12 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -577,6 +589,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "humansize" version = "2.1.3" @@ -761,6 +784,8 @@ dependencies = [ "bytemuck", "clap", "clap-num", + "http", + "log", "rand 0.9.0-alpha.1", "serde", "serde_json", diff --git a/userspace/Cargo.toml b/userspace/Cargo.toml index bdae0958..ea7e93ce 100644 --- a/userspace/Cargo.toml +++ b/userspace/Cargo.toml @@ -22,6 +22,7 @@ members = [ exclude = ["dynload-program", "test-kernel-module"] [workspace.dependencies] +log = "0.4.22" clap = { version = "4.5.20", features = ["std", "derive", "help", "usage"], default-features = false } clap-num = "1.1.1" serde_json = "1.0.132" diff --git a/userspace/netutils/Cargo.toml b/userspace/netutils/Cargo.toml index 15e65554..aa208b2e 100644 --- a/userspace/netutils/Cargo.toml +++ b/userspace/netutils/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +log.workspace = true yggdrasil-abi.workspace = true bytemuck.workspace = true serde_json.workspace = true @@ -14,6 +15,7 @@ clap-num.workspace = true rand.workspace = true url = "2.5.0" +http = "1.1.0" [lib] path = "src/lib.rs" diff --git a/userspace/netutils/src/dhcp_client.rs b/userspace/netutils/src/dhcp_client.rs index 99317fe8..d08fc1fd 100644 --- a/userspace/netutils/src/dhcp_client.rs +++ b/userspace/netutils/src/dhcp_client.rs @@ -7,7 +7,7 @@ use std::os::{ use std::{io, mem::size_of, process::ExitCode, time::Duration}; use bytemuck::{Pod, Zeroable}; -use netutils::{parse_udp_protocol, Error, NetConfig}; +use netutils::{netconfig::NetConfig, proto::parse_udp_protocol, Error}; use yggdrasil_abi::net::protocols::{ EtherType, EthernetFrame, InetChecksum, IpProtocol, Ipv4Frame, UdpFrame, }; diff --git a/userspace/netutils/src/http.rs b/userspace/netutils/src/http.rs index 657386e0..cb883546 100644 --- a/userspace/netutils/src/http.rs +++ b/userspace/netutils/src/http.rs @@ -1,29 +1,36 @@ use std::{ fs::File, - io::{self, stdout, Read, Stdout, Write}, - net::TcpStream, + io::{self, stdout, Stdout, Write}, path::{Path, PathBuf}, process::ExitCode, - str::FromStr, }; use clap::{Parser, Subcommand}; -use url::{Host, Url}; +use http::{Method, Uri}; +use netutils::proto::http::{Bytes, HttpClient, HttpError}; + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("I/O error: {0}")] + IoError(#[from] io::Error), + #[error("HTTP error: {0}")] + HttpError(#[from] HttpError), +} #[derive(Debug, Parser)] struct Arguments { #[clap(short, long)] output: Option, #[clap(subcommand)] - method: Method, + method: RequestMethod, } #[derive(Debug, Subcommand)] -enum Method { +enum RequestMethod { #[clap(arg_required_else_help = true)] Get { #[clap(help = "URL to GET")] - url: Url, + url: Uri, }, } @@ -57,75 +64,18 @@ impl Write for Output { } } -fn receive_header(stream: &mut R) -> Result, io::Error> { - let mut buf = [0]; - let mut line = Vec::new(); +fn get(url: Uri, output: Option) -> Result<(), Error> { + let mut output = Output::open(output)?; + let mut client = HttpClient::default(); + + let mut buffer = [0; 4096]; + let mut response = client.request(Method::GET, url).call::().unwrap(); loop { - stream.read_exact(&mut buf[..1])?; - if buf[0] == b'\r' { - continue; - } - if buf[0] == b'\n' { + let len = response.read(&mut buffer)?; + if len == 0 { break; } - line.push(buf[0]); - } - if line.is_empty() { - return Ok(None); - } - let line = String::from_utf8(line).unwrap(); - Ok(Some(line)) -} - -fn get(url: Url, output: Option) -> Result<(), io::Error> { - let mut output = Output::open(output)?; - - let host = url.host().unwrap(); - - let request = format!( - "GET {} HTTP/1.1\r\nHost: {}\r\nAccept: */*\r\n\r\n", - url.path(), - host - ); - - let port = url.port_or_known_default().unwrap(); - let mut stream = match host { - Host::Domain(hostname) => TcpStream::connect((hostname, port))?, - Host::Ipv4(address) => TcpStream::connect((address, port))?, - Host::Ipv6(address) => TcpStream::connect((address, port))?, - }; - - eprintln!("Connecting to {}:{}...", host, port); - - //let mut stream = TcpStream::connect(remote)?; - let mut buf = [0; 2048]; - - for line in request.split('\n') { - eprintln!("> {}", line.trim()); - } - - stream.write_all(request.as_bytes())?; - - let mut content_length = 0; - - while let Some(header) = receive_header(&mut stream)? { - let Some((key, value)) = header.split_once(':') else { - continue; - }; - let key = key.trim(); - let value = value.trim(); - - if key == "Content-Length" { - content_length = usize::from_str(value).unwrap(); - } - eprintln!("< {}", header); - } - - while content_length != 0 { - let limit = buf.len().min(content_length); - let amount = stream.read(&mut buf[..limit])?; - output.write_all(&buf[..amount])?; - content_length -= amount; + output.write_all(&buffer[..len])?; } Ok(()) @@ -135,7 +85,7 @@ fn main() -> ExitCode { let args = Arguments::parse(); let result = match args.method { - Method::Get { url } => get(url, args.output), + RequestMethod::Get { url } => get(url, args.output), }; match result { diff --git a/userspace/netutils/src/lib.rs b/userspace/netutils/src/lib.rs index 94a43e7c..60e063e9 100644 --- a/userspace/netutils/src/lib.rs +++ b/userspace/netutils/src/lib.rs @@ -1,21 +1,11 @@ -#![feature(yggdrasil_os, rustc_private)] +#![feature(yggdrasil_os, rustc_private, let_chains)] -use std::{ - io, - mem::size_of, - os::yggdrasil::io::message_channel::{ - MessageChannel, MessageChannelReceiver, MessageChannelSender, MessageDestination, - MessageReceiver, MessageSender, - }, -}; +use std::io::{self}; -use serde::Deserialize; -use yggdrasil_abi::net::{ - netconfig::{InterfaceQuery, NetConfigRequest, NetConfigResult, RoutingInfo}, - protocols::{EtherType, EthernetFrame, IpProtocol, Ipv4Frame, UdpFrame}, - types::NetValueImpl, - IpAddr, Ipv4Addr, MacAddress, SocketAddr, SubnetAddr, -}; +pub mod netconfig; +pub mod proto; + +// pub use proto::http::HttpClient; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -28,144 +18,3 @@ pub enum Error { #[error("Timed out")] TimedOut, } - -pub struct NetConfig { - sender: MessageChannelSender, - receiver: MessageChannelReceiver, - buffer: [u8; 4096], -} - -impl NetConfig { - pub fn open() -> Result { - let channel = MessageChannel::open("@kernel-netconf", true)?; - let (sender, receiver) = channel.split(); - Ok(Self { - sender, - receiver, - buffer: [0; 4096], - }) - } - - pub fn send(&mut self, request: &NetConfigRequest) -> Result<(), Error> { - let bytes = serde_json::to_vec(&request)?; - self.sender - .send_message(&bytes, MessageDestination::Specific(0))?; - Ok(()) - } - - pub fn request<'de, T: Deserialize<'de>>( - &'de mut self, - request: &NetConfigRequest, - ) -> Result { - self.send(request)?; - let (_sender, len) = self.receiver.receive_message(&mut self.buffer)?; - let msg: NetConfigResult = serde_json::from_slice(&self.buffer[..len])?; - match msg { - NetConfigResult::Ok(value) => Ok(value), - NetConfigResult::Err(error) => Err(Error::NetConfError(error)), - } - } - - pub fn query_route(&mut self, address: IpAddr) -> Result { - self.request(&NetConfigRequest::QueryRoute(address)) - } - - pub fn query_arp( - &mut self, - interface_id: u32, - address: IpAddr, - perform_query: bool, - ) -> Result { - self.request(&NetConfigRequest::QueryArp( - interface_id, - address, - perform_query, - )) - } - - pub fn set_interface_address>( - &mut self, - interface: Q, - address: IpAddr, - ) -> Result<(), Error> { - self.request(&NetConfigRequest::SetNetworkAddress { - interface: interface.into(), - address, - }) - } - - pub fn add_route>( - &mut self, - interface: Q, - subnet: SubnetAddr, - gateway: Option, - ) -> Result<(), Error> { - self.request(&NetConfigRequest::AddRoute { - interface: interface.into(), - gateway, - subnet, - }) - } -} - -pub fn parse_l2_protocol(packet: &[u8]) -> Option<(EthernetFrame, &[u8])> { - if packet.len() < size_of::() { - return None; - } - - let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..size_of::()]); - let l2_data = &packet[size_of::()..]; - - Some((*l2_frame, l2_data)) -} - -pub fn parse_ip_protocol(packet: &[u8]) -> Option<(IpProtocol, IpAddr, IpAddr, &[u8])> { - let (l2_frame, l2_data) = parse_l2_protocol(packet)?; - - match EtherType::from_network_order(l2_frame.ethertype) { - EtherType::IPV4 if l2_data.len() >= size_of::() => { - let l3_frame: &Ipv4Frame = bytemuck::from_bytes(&l2_data[..size_of::()]); - - let source_addr = IpAddr::V4(Ipv4Addr::from(u32::from_network_order( - l3_frame.source_address, - ))); - let destination_addr = IpAddr::V4(Ipv4Addr::from(u32::from_network_order( - l3_frame.destination_address, - ))); - - Some(( - l3_frame.protocol, - source_addr, - destination_addr, - &l2_data[l3_frame.header_length()..l3_frame.total_length()], - )) - } - _ => None, - } -} - -pub fn parse_udp_protocol(packet: &[u8]) -> Option<(SocketAddr, SocketAddr, &[u8])> { - let (protocol, source_ip, destination_ip, l3_data) = parse_ip_protocol(packet)?; - - if protocol != IpProtocol::UDP || l3_data.len() < size_of::() { - return None; - } - - let l4_frame: &UdpFrame = bytemuck::from_bytes(&l3_data[..size_of::()]); - let l4_data_size = core::cmp::min( - l3_data.len() - size_of::(), - l4_frame.data_length(), - ); - - let source_addr = SocketAddr::new(source_ip, u16::from_network_order(l4_frame.source_port)); - let destination_addr = SocketAddr::new( - destination_ip, - u16::from_network_order(l4_frame.destination_port), - ); - - Some(( - source_addr, - destination_addr, - &l3_data[size_of::()..size_of::() + l4_data_size], - )) -} diff --git a/userspace/netutils/src/netconf.rs b/userspace/netutils/src/netconf.rs index a7999f83..67fa66f7 100644 --- a/userspace/netutils/src/netconf.rs +++ b/userspace/netutils/src/netconf.rs @@ -1,7 +1,7 @@ use std::{net::IpAddr, process::ExitCode, str::FromStr}; use clap::{Args, Parser, Subcommand}; -use netutils::{Error, NetConfig}; +use netutils::{netconfig::NetConfig, Error}; use yggdrasil_abi::net::netconfig::{InterfaceInfo, InterfaceQuery, NetConfigRequest, RouteInfo}; #[derive(Debug, Parser)] diff --git a/userspace/netutils/src/netconfig.rs b/userspace/netutils/src/netconfig.rs new file mode 100644 index 00000000..4caefd63 --- /dev/null +++ b/userspace/netutils/src/netconfig.rs @@ -0,0 +1,91 @@ +use std::os::yggdrasil::io::message_channel::{ + MessageChannel, MessageChannelReceiver, MessageChannelSender, MessageDestination, + MessageReceiver, MessageSender, +}; + +use serde::Deserialize; +use yggdrasil_abi::net::{ + netconfig::{InterfaceQuery, NetConfigRequest, NetConfigResult, RoutingInfo}, + IpAddr, MacAddress, SubnetAddr, +}; + +use crate::Error; + +pub struct NetConfig { + sender: MessageChannelSender, + receiver: MessageChannelReceiver, + buffer: [u8; 4096], +} + +impl NetConfig { + pub fn open() -> Result { + let channel = MessageChannel::open("@kernel-netconf", true)?; + let (sender, receiver) = channel.split(); + Ok(Self { + sender, + receiver, + buffer: [0; 4096], + }) + } + + pub fn send(&mut self, request: &NetConfigRequest) -> Result<(), Error> { + let bytes = serde_json::to_vec(&request)?; + self.sender + .send_message(&bytes, MessageDestination::Specific(0))?; + Ok(()) + } + + pub fn request<'de, T: Deserialize<'de>>( + &'de mut self, + request: &NetConfigRequest, + ) -> Result { + self.send(request)?; + let (_sender, len) = self.receiver.receive_message(&mut self.buffer)?; + let msg: NetConfigResult = serde_json::from_slice(&self.buffer[..len])?; + match msg { + NetConfigResult::Ok(value) => Ok(value), + NetConfigResult::Err(error) => Err(Error::NetConfError(error)), + } + } + + pub fn query_route(&mut self, address: IpAddr) -> Result { + self.request(&NetConfigRequest::QueryRoute(address)) + } + + pub fn query_arp( + &mut self, + interface_id: u32, + address: IpAddr, + perform_query: bool, + ) -> Result { + self.request(&NetConfigRequest::QueryArp( + interface_id, + address, + perform_query, + )) + } + + pub fn set_interface_address>( + &mut self, + interface: Q, + address: IpAddr, + ) -> Result<(), Error> { + self.request(&NetConfigRequest::SetNetworkAddress { + interface: interface.into(), + address, + }) + } + + pub fn add_route>( + &mut self, + interface: Q, + subnet: SubnetAddr, + gateway: Option, + ) -> Result<(), Error> { + self.request(&NetConfigRequest::AddRoute { + interface: interface.into(), + gateway, + subnet, + }) + } +} diff --git a/userspace/netutils/src/ping.rs b/userspace/netutils/src/ping.rs index 07d298e2..a9e92bf4 100644 --- a/userspace/netutils/src/ping.rs +++ b/userspace/netutils/src/ping.rs @@ -4,9 +4,7 @@ use std::{ mem::size_of, os::{ fd::AsRawFd, - yggdrasil::{ - io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd}, - }, + yggdrasil::io::{poll::PollChannel, raw_socket::RawSocket, timer::TimerFd}, }, process::ExitCode, sync::atomic::{AtomicBool, Ordering}, @@ -15,7 +13,7 @@ use std::{ use bytemuck::Zeroable; use clap::Parser; -use netutils::{Error, NetConfig}; +use netutils::{netconfig::NetConfig, Error}; use yggdrasil_abi::net::{ protocols::{EtherType, EthernetFrame, IcmpV4Frame, InetChecksum, IpProtocol, Ipv4Frame}, types::NetValueImpl, diff --git a/userspace/netutils/src/proto/http.rs b/userspace/netutils/src/proto/http.rs new file mode 100644 index 00000000..346f4dde --- /dev/null +++ b/userspace/netutils/src/proto/http.rs @@ -0,0 +1,471 @@ +use std::{ + fmt, + io::{self, Read, Write}, + marker::PhantomData, + net::{SocketAddr, TcpStream, ToSocketAddrs}, + ops::{Deref, DerefMut}, + string::FromUtf8Error, + time::Duration, +}; + +use http::{ + header::{self, InvalidHeaderValue}, + response, + status::InvalidStatusCode, + HeaderName, HeaderValue, Method, Response, StatusCode, Uri, +}; + +pub trait HttpConnection { + type Error: std::error::Error + Send + 'static; + + fn send(&mut self, buffer: &[u8]) -> Result; + fn send_all(&mut self, buffer: &[u8]) -> Result<(), Self::Error>; + fn recv(&mut self, buffer: &mut [u8]) -> Result; + fn recv_exact(&mut self, buffer: &mut [u8]) -> Result<(), Self::Error>; +} +pub trait HttpConnector { + type Connection: HttpConnection; + type Error: std::error::Error + Send + 'static; + + fn connect( + &mut self, + remote: &SocketAddr, + scheme: &str, + server_name: &str, + timeout: Option, + options: HttpConnectionOptions, + ) -> Result; + + fn supports_scheme(&self, scheme: &str) -> bool; +} + +#[derive(Clone)] +pub struct HttpConnectionOptions { + pub read_timeout: Option, + pub write_timeout: Option, +} + +pub struct TcpConnector; + +pub trait HttpBody: Sized { + fn from_response(content_length: Option) -> Option; + fn write(&self, connection: &mut C) -> Result<(), C::Error>; + fn content_length(&self) -> Option; +} + +#[derive(Debug)] +pub struct Bytes { + content_length: Option, +} + +pub struct HttpClient { + connector: C, + options: HttpConnectionOptions, + connect_timeout: Option, +} + +pub struct HttpRequestBuilder<'c, C: HttpConnector, U: TryInto> { + client: &'c mut HttpClient, + url: U, + builder: http::request::Builder, +} + +pub struct HttpResponse { + connection: C, + position: usize, + inner: Response, +} + +#[derive(Debug, thiserror::Error)] +pub enum HttpError { + #[error("Malformed URL")] + MalformedUrl, + #[error("Unsupported URL scheme: {0:?}")] + UnsupportedScheme(String), + #[error("Connection error")] + ConnectionError(#[from] E), + #[error("Unexpected end of connection")] + EndOfConnection, + #[error("Malformed response header")] + MalformedHeader(FromUtf8Error), + #[error("Malformed request header: {0}")] + InvalidHeaderValue(InvalidHeaderValue), + #[error("Request error: {0}")] + Request(http::Error), + #[error("Response error: {0}")] + Response(http::Error), + #[error("Hostname error: {0}")] + Hostname(io::Error), + #[error("Invalid status code: {0}")] + InvalidStatusCode(InvalidStatusCode), + #[error("Invalid status line")] + InvalidStatusLine, + #[error("Could not read response body")] + BodyError, + #[error("Request too large")] + RequestTooLarge, + #[error("Could not connect to {0}")] + CouldNotConnect(Uri), +} + +impl HttpConnector for TcpConnector { + type Connection = TcpStream; + type Error = io::Error; + + fn connect( + &mut self, + remote: &SocketAddr, + _scheme: &str, + _server_name: &str, + timeout: Option, + options: HttpConnectionOptions, + ) -> Result { + let socket = match timeout { + Some(timeout) => TcpStream::connect_timeout(remote, timeout)?, + None => TcpStream::connect(remote)?, + }; + socket.set_read_timeout(options.read_timeout)?; + socket.set_write_timeout(options.write_timeout)?; + Ok(socket) + } + + fn supports_scheme(&self, scheme: &str) -> bool { + scheme.eq_ignore_ascii_case("http") + } +} + +impl HttpConnection for TcpStream { + type Error = io::Error; + + fn send(&mut self, buffer: &[u8]) -> Result { + self.write(buffer) + } + + fn send_all(&mut self, buffer: &[u8]) -> Result<(), Self::Error> { + self.write_all(buffer) + } + + fn recv(&mut self, buffer: &mut [u8]) -> Result { + self.read(buffer) + } + + fn recv_exact(&mut self, buffer: &mut [u8]) -> Result<(), Self::Error> { + self.read_exact(buffer) + } +} + +impl HttpClient { + pub fn new_default(connector: C) -> Self { + Self { + connector, + connect_timeout: Some(Duration::from_secs(3)), + options: Default::default(), + } + } + + pub fn request>( + &mut self, + method: http::Method, + url: U, + ) -> HttpRequestBuilder { + HttpRequestBuilder { + client: self, + url, + builder: http::Request::builder().method(method), + } + } + + fn send>( + &mut self, + url: U, + request: http::request::Builder, + body: T, + ) -> Result, HttpError> { + let url: Uri = url.try_into().map_err(|_| HttpError::MalformedUrl)?; + + let scheme = url.scheme_str().ok_or(HttpError::MalformedUrl)?; + + if !self.connector.supports_scheme(scheme) { + return Err(HttpError::UnsupportedScheme(scheme.into())); + } + + let host = url.host().ok_or(HttpError::MalformedUrl)?; + let port = url.port_u16().unwrap_or(80); + + let request = if let Some(path) = url.path_and_query() { + request.uri(path.as_str()) + } else { + request + }; + let request = if let Some(content_length) = body.content_length() { + request.header(header::CONTENT_LENGTH, content_length) + } else { + request + }; + let request = if let Some(authority) = url.authority() { + request.header( + header::HOST, + HeaderValue::from_str(authority.as_str()).map_err(HttpError::InvalidHeaderValue)?, + ) + } else { + request + }; + + let request = request.body(body).map_err(HttpError::Request)?; + + if request.version() != http::Version::HTTP_11 { + unimplemented!() + } + + for socket_addr in + ToSocketAddrs::to_socket_addrs(&(host, port)).map_err(HttpError::Hostname)? + { + let Ok(mut connection) = self.connector.connect( + &socket_addr, + scheme, + host, + self.connect_timeout, + self.options.clone(), + ) else { + continue; + }; + + send_request_http1(&mut connection, &request)?; + return recv_response_http1(connection); + } + + Err(HttpError::CouldNotConnect(url)) + } +} + +impl From for HttpClient { + fn from(value: C) -> Self { + Self::new_default(value) + } +} + +impl Default for HttpClient { + fn default() -> Self { + Self::new_default(TcpConnector) + } +} + +impl> HttpRequestBuilder<'_, C, U> { + pub fn call(self) -> Result, HttpError> { + self.client.send(self.url, self.builder, ()) + } +} + +impl HttpBody for () { + fn from_response(_content_length: Option) -> Option { + Some(()) + } + + fn content_length(&self) -> Option { + Some(0) + } + + fn write(&self, _connection: &mut C) -> Result<(), C::Error> { + Ok(()) + } +} + +impl HttpBody for Bytes { + fn from_response(content_length: Option) -> Option { + Some(Self { content_length }) + } + + fn content_length(&self) -> Option { + self.content_length + } + + fn write(&self, _connection: &mut C) -> Result<(), C::Error> { + todo!() + } +} + +impl Default for HttpConnectionOptions { + fn default() -> Self { + Self { + read_timeout: Some(Duration::from_secs(3)), + write_timeout: Some(Duration::from_secs(3)), + } + } +} + +impl fmt::Debug for HttpResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl HttpResponse { + pub fn read(&mut self, buffer: &mut [u8]) -> Result> { + let amount = match self.inner.body().content_length { + Some(len) => core::cmp::min(len - self.position, buffer.len()), + None => buffer.len(), + }; + if amount == 0 { + return Ok(0); + } + let len = self.connection.recv(&mut buffer[..amount])?; + self.position += len; + Ok(len) + } +} + +impl Deref for HttpResponse { + type Target = Response; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for HttpResponse { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +struct RequestWriter<'b, C: HttpConnection> { + pos: usize, + buffer: &'b mut [u8], + _pd: PhantomData, +} + +impl<'b, C: HttpConnection> RequestWriter<'b, C> { + pub fn new(buffer: &'b mut [u8]) -> Self { + Self { + buffer, + pos: 0, + _pd: PhantomData, + } + } + + pub fn write(&mut self, bytes: &[u8]) -> Result<(), HttpError> { + if bytes.len() + self.pos > self.buffer.len() { + return Err(HttpError::RequestTooLarge); + } + self.buffer[self.pos..self.pos + bytes.len()].copy_from_slice(bytes); + self.pos += bytes.len(); + Ok(()) + } + + pub fn header( + &mut self, + key: &HeaderName, + value: &HeaderValue, + ) -> Result<(), HttpError> { + self.write(key.as_str().as_bytes())?; + self.write(b":")?; + self.write(value.as_bytes())?; + self.write(b"\r\n")?; + Ok(()) + } + + pub fn request(&mut self, method: &Method, uri: &Uri) -> Result<(), HttpError> { + let line = format!("{} {} HTTP/1.1\r\n", method, uri); + self.write(line.as_bytes()) + } + + pub fn finish(mut self) -> Result<&'b [u8], HttpError> { + self.write(b"\r\n")?; + Ok(&self.buffer[..self.pos]) + } +} + +fn send_request_http1( + connection: &mut C, + request: &http::Request, +) -> Result<(), HttpError> { + let mut buffer = [0; 4096]; + let mut writer = RequestWriter::::new(&mut buffer); + writer.request(request.method(), request.uri())?; + for (name, value) in request.headers() { + writer.header(name, value)?; + } + let preamble = writer.finish()?; + connection.send_all(preamble)?; + + request.body().write(connection)?; + Ok(()) +} + +fn recv_http1_line( + connection: &mut C, + buffer: &mut [u8], +) -> Result> { + let mut pos = 0; + loop { + if pos == buffer.len() { + todo!() + } + connection.recv_exact(&mut buffer[pos..pos + 1])?; + match buffer[pos] { + b'\r' => continue, + b'\n' => break, + _ => pos += 1, + } + } + Ok(pos) +} + +fn recv_response_http1( + mut connection: C, +) -> Result, HttpError> { + let mut buffer = [0; 256]; + let len = recv_http1_line(&mut connection, &mut buffer)?; + + let mut content_length = None; + let mut builder = response::Builder::new(); + let status = parse_status_line_http1::(&buffer[..len])?; + builder = builder.status(status); + + loop { + let len = recv_http1_line(&mut connection, &mut buffer)?; + if len == 0 { + break; + } + let Some(eq) = buffer.iter().position(|q| *q == b':') else { + continue; + }; + let (name, value) = buffer[..len].split_at(eq); + let name = name.trim_ascii_end(); + let value = value[1..].trim_ascii_start(); + let Ok(name) = HeaderName::from_bytes(name) else { + continue; + }; + + if name == header::CONTENT_LENGTH { + if let Some(value) = std::str::from_utf8(value).ok().and_then(|s| s.parse().ok()) { + content_length = Some(value); + } + } + + let Ok(value) = HeaderValue::from_bytes(value) else { + continue; + }; + + builder = builder.header(name, value); + } + + let body = T::from_response(content_length).ok_or(HttpError::BodyError)?; + let inner = builder.body(body).map_err(HttpError::Response)?; + + Ok(HttpResponse { + connection, + inner, + position: 0, + }) +} + +fn parse_status_line_http1( + line: &[u8], +) -> Result> { + let mut it = line.split(|c| *c == b' '); + it.next().ok_or(HttpError::InvalidStatusLine)?; + let status = it.next().ok_or(HttpError::InvalidStatusLine)?; + let status = StatusCode::from_bytes(status).map_err(HttpError::InvalidStatusCode)?; + Ok(status) +} diff --git a/userspace/netutils/src/proto/mod.rs b/userspace/netutils/src/proto/mod.rs new file mode 100644 index 00000000..e0b000f4 --- /dev/null +++ b/userspace/netutils/src/proto/mod.rs @@ -0,0 +1,70 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use yggdrasil_abi::net::{ + protocols::{EtherType, EthernetFrame, IpProtocol, Ipv4Frame, UdpFrame}, + types::NetValueImpl, +}; + +pub mod http; + +pub fn parse_l2_protocol(packet: &[u8]) -> Option<(EthernetFrame, &[u8])> { + if packet.len() < size_of::() { + return None; + } + + let l2_frame: &EthernetFrame = bytemuck::from_bytes(&packet[..size_of::()]); + let l2_data = &packet[size_of::()..]; + + Some((*l2_frame, l2_data)) +} + +pub fn parse_ip_protocol(packet: &[u8]) -> Option<(IpProtocol, IpAddr, IpAddr, &[u8])> { + let (l2_frame, l2_data) = parse_l2_protocol(packet)?; + + match EtherType::from_network_order(l2_frame.ethertype) { + EtherType::IPV4 if l2_data.len() >= size_of::() => { + let l3_frame: &Ipv4Frame = bytemuck::from_bytes(&l2_data[..size_of::()]); + + let source_addr = IpAddr::V4(Ipv4Addr::from(u32::from_network_order( + l3_frame.source_address, + ))); + let destination_addr = IpAddr::V4(Ipv4Addr::from(u32::from_network_order( + l3_frame.destination_address, + ))); + + Some(( + l3_frame.protocol, + source_addr, + destination_addr, + &l2_data[l3_frame.header_length()..l3_frame.total_length()], + )) + } + _ => None, + } +} + +pub fn parse_udp_protocol(packet: &[u8]) -> Option<(SocketAddr, SocketAddr, &[u8])> { + let (protocol, source_ip, destination_ip, l3_data) = parse_ip_protocol(packet)?; + + if protocol != IpProtocol::UDP || l3_data.len() < size_of::() { + return None; + } + + let l4_frame: &UdpFrame = bytemuck::from_bytes(&l3_data[..size_of::()]); + let l4_data_size = core::cmp::min( + l3_data.len() - size_of::(), + l4_frame.data_length(), + ); + + let source_addr = SocketAddr::new(source_ip, u16::from_network_order(l4_frame.source_port)); + let destination_addr = SocketAddr::new( + destination_ip, + u16::from_network_order(l4_frame.destination_port), + ); + + Some(( + source_addr, + destination_addr, + &l3_data[size_of::()..size_of::() + l4_data_size], + )) +} diff --git a/userspace/rsh/Cargo.toml b/userspace/rsh/Cargo.toml index cf0665aa..b76cd113 100644 --- a/userspace/rsh/Cargo.toml +++ b/userspace/rsh/Cargo.toml @@ -17,8 +17,8 @@ bytemuck.workspace = true x25519-dalek.workspace = true ed25519-dalek = { workspace = true, features = ["rand_core", "pem"] } sha2.workspace = true +log.workspace = true rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" } aes = { version = "0.8.4" } -log = "0.4.22" env_logger = "0.11.5"