From c5411cd0cd5f1d62860addc0c64b131e52ae2022 Mon Sep 17 00:00:00 2001 From: Mark Poliakov Date: Sun, 7 Jan 2024 03:45:35 +0200 Subject: [PATCH] Add some docs --- src/error.rs | 3 +++ src/main.rs | 39 +++++++++++++++++++++++++++++++-------- src/protocol/convert.rs | 31 ++++++++++++++++++++++--------- src/protocol/mod.rs | 31 +++++++++++++++++++++++++++++-- src/protocol/rfc8489.rs | 2 ++ src/util.rs | 8 ++++++++ 6 files changed, 95 insertions(+), 19 deletions(-) diff --git a/src/error.rs b/src/error.rs index ce64ccb..85311cb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,6 @@ +//! Errors defined by different modules + +/// Common error type for the crate #[derive(Debug, thiserror::Error)] pub enum Error { #[error("I/O error: {0}")] diff --git a/src/main.rs b/src/main.rs index f7ecffe..df8913b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,24 @@ +//! A minimal asynchronous STUN server implementation. +//! +//! The server implements the minimal necessary functionality from +//! [https://datatracker.ietf.org/doc/html/rfc8489](RFC 8489), namely +//! the Binding Request. +//! +//! The server is both IPv6 and IPv4 capable. When listening on IPv6 address, +//! the server is capable of handling IPv6/IPv4 clients. +//! +//! To try out the server, just run: +//! +//! ```bash +//! $ cargo run +//! ``` +#![deny(warnings)] use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use clap::Parser; use protocol::{ convert::{deserialize_message, DeserializeResult, WireSerialize}, - Attribute, Class, Message, Method, Response, + Attribute, Class, IncomingMessage, Method, Response, }; use tokio::{ io::AsyncReadExt, @@ -28,14 +43,15 @@ struct Args { listen_address: Option, } +/// Helper function to report a [ErrorCode::BadRequest] reply async fn send_bad_request( socket: &mut S, source: SocketAddr, resp_buffer: &mut [u8], - message: Option<&Message>, + message: Option<&IncomingMessage>, ) -> Result<(), Error> { // Dummy message in case the message failed to parse at all - let dummy = Message::invalid(); + let dummy = IncomingMessage::invalid(); let message = message.unwrap_or(&dummy); let mut response = Response::new(false); response.add_attribute(Attribute::ErrorCode(ErrorCode::BadRequest)); @@ -44,12 +60,14 @@ async fn send_bad_request( Ok(()) } +/// Handles the [DeserializeResult] returned by [deserialize_message] by properly replying +/// to malformed messages and unrecognized comprehension-required attributes async fn handle_message_parse( socket: &mut S, source: SocketAddr, message: &[u8], resp_buffer: &mut [u8], -) -> Result { +) -> Result { match deserialize_message(message) { DeserializeResult::Ok(message) => Ok(message), DeserializeResult::Error(err) => { @@ -70,6 +88,7 @@ async fn handle_message_parse( } } +/// Handles a single incoming message and generates an appropriate reply async fn handle_message( socket: &mut S, source: SocketAddr, @@ -103,6 +122,8 @@ async fn handle_message( Ok(()) } +/// Handles errors during processing and prints them to log instead of propagating +/// them up the call chain async fn handle_and_log_message( socket: &mut S, source: SocketAddr, @@ -147,11 +168,13 @@ async fn main() -> Result<(), Error> { res = tcp_accept_fut => { let (mut socket, source) = res?; - let Ok(len) = socket.read(&mut buffer).await else { - continue; - }; + tokio::spawn(async move { + let Ok(len) = socket.read(&mut buffer).await else { + return; + }; - handle_and_log_message(&mut socket, source, &buffer[..len]).await; + handle_and_log_message(&mut socket, source, &buffer[..len]).await; + }); } res = udp_receive_fut => { let (len, source) = res?; diff --git a/src/protocol/convert.rs b/src/protocol/convert.rs index 936e9a9..fe18a46 100644 --- a/src/protocol/convert.rs +++ b/src/protocol/convert.rs @@ -1,3 +1,6 @@ +//! Routines for converting between the "wire format" and the "Rust format" +//! representations of RFC8489-defined structures. + use std::{mem::size_of, net::SocketAddr}; use bytemuck::Pod; @@ -12,17 +15,19 @@ use super::{ RawErrorCodeAttribute, RawMessageHeader, RawXorMappedAddressAttributeV4, RawXorMappedAddressAttributeV6, FAMILY_IPV6, MESSAGE_COOKIE, }, - Attribute, Class, Message, Method, Response, MESSAGE_SIZE_LIMIT, + Attribute, Class, IncomingMessage, Method, Response, MESSAGE_SIZE_LIMIT, }; +/// Result-like enum to handle messages with unrecognized comprehension-required attributes pub enum DeserializeResult { - Ok(Message), - UnknownComprehensionAttributes(Message, Vec), + Ok(IncomingMessage), + UnknownComprehensionAttributes(IncomingMessage, Vec), Error(Error), } +/// Helper trait for types sent to the requesting Agent in a response pub trait WireSerialize { - fn wire_serialize(&self, request: &Message, output: &mut [u8]) -> Result; + fn wire_serialize(&self, request: &IncomingMessage, output: &mut [u8]) -> Result; } fn put_struct(buffer: &mut [u8], value: &T) -> usize { @@ -43,7 +48,7 @@ fn get_struct(buffer: &[u8]) -> &T { } impl WireSerialize for Attribute { - fn wire_serialize(&self, request: &Message, output: &mut [u8]) -> Result { + fn wire_serialize(&self, request: &IncomingMessage, output: &mut [u8]) -> Result { match self { Self::XorMappedAddress(SocketAddr::V4(v4)) => { let raw = RawXorMappedAddressAttributeV4::new(*v4).wrap(); @@ -86,7 +91,7 @@ impl WireSerialize for Attribute { } impl WireSerialize for Response { - fn wire_serialize(&self, request: &Message, output: &mut [u8]) -> Result { + fn wire_serialize(&self, request: &IncomingMessage, output: &mut [u8]) -> Result { // Serialize attributes first let mut attr_len = 0; for attr in self.attributes.iter() { @@ -114,6 +119,14 @@ impl WireSerialize for Response { } } +/// Processes a raw incoming message and outputs a [DeserializeResult]: +/// +/// * If the message was received properly and all comprehension-required attributes are +/// recognized, returns [DeserializeResult::Ok]. +/// * If the message was received properly, but there were comprehension-required +/// attributes which could not be recognized, returns a +/// [DeserializeResult::UnknownComprehensionAttributes]. +/// * If the message was malformed, returns a [DeserializeResult::Error]. pub fn deserialize_message(data: &[u8]) -> DeserializeResult { use DeserializeResult as DR; @@ -143,7 +156,7 @@ pub fn deserialize_message(data: &[u8]) -> DeserializeResult { }; let transaction_id = header.transaction_id.map(u32::from_be); let mut unknown_attributes = vec![]; - let mut message = Message { + let mut message = IncomingMessage { method, class, transaction_id, @@ -209,7 +222,7 @@ pub fn deserialize_message(data: &[u8]) -> DeserializeResult { } } -pub fn parse_ty(ty: u16) -> Result<(Method, Class), Error> { +fn parse_ty(ty: u16) -> Result<(Method, Class), Error> { if ty & 0b1100000000000000 != 0 { return Err(Error::MalformedMessage); } @@ -223,7 +236,7 @@ pub fn parse_ty(ty: u16) -> Result<(Method, Class), Error> { Ok((method, class)) } -pub fn serialize_ty(method: Method, class: Class) -> u16 { +fn serialize_ty(method: Method, class: Class) -> u16 { let method = method.repr(); let class = class.repr(); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index b5ba93c..bb44cae 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,3 +1,7 @@ +//! RFC8489 protocol data types and handling functions. +//! +//! The main module contains the Rust-friendly types, while the raw specifics +//! of the actual RFC-specified protocol are defined in the [rfc8489] module. use std::net::SocketAddr; use enum_repr::EnumRepr; @@ -5,8 +9,10 @@ use enum_repr::EnumRepr; pub mod convert; pub mod rfc8489; +/// Message limit imposed on both receive/send functions pub const MESSAGE_SIZE_LIMIT: usize = 512; +/// Request/indication methods defined by the RFC #[EnumRepr(type = "u16")] #[derive(Debug, PartialEq, Clone, Copy)] #[non_exhaustive] @@ -15,6 +21,8 @@ pub enum Method { Binding = 1, } +/// Classes defined by the RFC. Although there are four defined, only two types are +/// used by STUN: request/response transactions and indication messages. #[EnumRepr(type = "u16")] #[derive(Debug, PartialEq, Clone, Copy)] pub enum Class { @@ -24,34 +32,47 @@ pub enum Class { ErrorResponse = 3, } +/// Error codes reported by the server #[derive(Debug, PartialEq, Clone, Copy)] pub enum ErrorCode { UnknownAttribute, BadRequest, } +/// Defines main attributes handled by the server #[derive(Debug)] pub enum Attribute { + /// XOR-MAPPED-ADDRESS attribute, contains the address to be "echoed back" + /// to the requesting Agent XorMappedAddress(SocketAddr), + /// ERROR-CODE, used to report an error ErrorCode(ErrorCode), + /// UNKNOWN-ATTRIBUTES, MUST be present when reporting [ErrorCode::UnknownAttribute] UnknownAttributes(Vec), } +/// Defines a single message received by the server #[derive(Debug)] -pub struct Message { +pub struct IncomingMessage { + /// Method of the message/transaction pub method: Method, + /// Class of the message/transaction pub class: Class, + /// Transaction ID, gets "echoed back" to the requesting Agent pub transaction_id: [u32; 3], + /// Attributes coming from the requesting Agent pub attributes: Vec, } +/// Defines a response generated by the server #[derive(Default, Debug)] pub struct Response { success: bool, attributes: Vec, } -impl Message { +impl IncomingMessage { + /// Dummy message used when a full message could not be received from the requesting Agent pub fn invalid() -> Self { Self { method: Method::Reserved, @@ -63,6 +84,11 @@ impl Message { } impl Response { + /// Creates a new [Response], with its actual [Class] code being based on the `success` + /// parameter: + /// + /// * If `success` is `true`, [Class::SuccessResponse] (2) is used. + /// * Otherwise, [Class::ErrorResponse] (3) is used. pub fn new(success: bool) -> Self { Self { success, @@ -70,6 +96,7 @@ impl Response { } } + /// Pushes an attribute to the response's attribute list pub fn add_attribute(&mut self, attr: Attribute) { self.attributes.push(attr); } diff --git a/src/protocol/rfc8489.rs b/src/protocol/rfc8489.rs index 8b83f45..06ebffd 100644 --- a/src/protocol/rfc8489.rs +++ b/src/protocol/rfc8489.rs @@ -1,3 +1,5 @@ +//! The "wire format" data structures of the RFC8489 protocol + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use bytemuck::{Pod, Zeroable}; diff --git a/src/util.rs b/src/util.rs index 4c8f3ff..c0ba767 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,3 +1,5 @@ +//! Utility functions + use std::{ future::Future, io, @@ -9,6 +11,7 @@ use tokio::{ net::{TcpStream, UdpSocket}, }; +/// Helper trait to allow the crate to operate on both TCP and UDP sockets pub trait AsyncSendTo { fn send_to<'a>( &'a mut self, @@ -38,6 +41,11 @@ impl AsyncSendTo for TcpStream { } } +/// Utility function to map incoming socket addresses to their actual representations. +/// +/// When listening on an IPv6 socket, any incoming IPv4 socket addresses are converted +/// to their IPv4-to-IPv6 mapped forms, this routine takes care of that by mapping them +/// back to IPv4 if needed. pub fn map_sockaddr(source: SocketAddr) -> SocketAddr { match source { SocketAddr::V4(_) => source,