Add some docs

This commit is contained in:
2024-01-07 03:45:35 +02:00
parent 36502625db
commit c5411cd0cd
6 changed files with 95 additions and 19 deletions
+3
View File
@@ -1,3 +1,6 @@
//! Errors defined by different modules
/// Common error type for the crate
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("I/O error: {0}")]
+31 -8
View File
@@ -1,9 +1,24 @@
//! A minimal asynchronous STUN server implementation.
//!
//! The server implements the minimal necessary functionality from
//! [https://datatracker.ietf.org/doc/html/rfc8489](RFC 8489), namely
//! the Binding Request.
//!
//! The server is both IPv6 and IPv4 capable. When listening on IPv6 address,
//! the server is capable of handling IPv6/IPv4 clients.
//!
//! To try out the server, just run:
//!
//! ```bash
//! $ cargo run
//! ```
#![deny(warnings)]
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use clap::Parser;
use protocol::{
convert::{deserialize_message, DeserializeResult, WireSerialize},
Attribute, Class, Message, Method, Response,
Attribute, Class, IncomingMessage, Method, Response,
};
use tokio::{
io::AsyncReadExt,
@@ -28,14 +43,15 @@ struct Args {
listen_address: Option<SocketAddr>,
}
/// Helper function to report a [ErrorCode::BadRequest] reply
async fn send_bad_request<S: AsyncSendTo>(
socket: &mut S,
source: SocketAddr,
resp_buffer: &mut [u8],
message: Option<&Message>,
message: Option<&IncomingMessage>,
) -> Result<(), Error> {
// Dummy message in case the message failed to parse at all
let dummy = Message::invalid();
let dummy = IncomingMessage::invalid();
let message = message.unwrap_or(&dummy);
let mut response = Response::new(false);
response.add_attribute(Attribute::ErrorCode(ErrorCode::BadRequest));
@@ -44,12 +60,14 @@ async fn send_bad_request<S: AsyncSendTo>(
Ok(())
}
/// Handles the [DeserializeResult] returned by [deserialize_message] by properly replying
/// to malformed messages and unrecognized comprehension-required attributes
async fn handle_message_parse<S: AsyncSendTo>(
socket: &mut S,
source: SocketAddr,
message: &[u8],
resp_buffer: &mut [u8],
) -> Result<Message, Error> {
) -> Result<IncomingMessage, Error> {
match deserialize_message(message) {
DeserializeResult::Ok(message) => Ok(message),
DeserializeResult::Error(err) => {
@@ -70,6 +88,7 @@ async fn handle_message_parse<S: AsyncSendTo>(
}
}
/// Handles a single incoming message and generates an appropriate reply
async fn handle_message<S: AsyncSendTo>(
socket: &mut S,
source: SocketAddr,
@@ -103,6 +122,8 @@ async fn handle_message<S: AsyncSendTo>(
Ok(())
}
/// Handles errors during processing and prints them to log instead of propagating
/// them up the call chain
async fn handle_and_log_message<S: AsyncSendTo>(
socket: &mut S,
source: SocketAddr,
@@ -147,11 +168,13 @@ async fn main() -> Result<(), Error> {
res = tcp_accept_fut => {
let (mut socket, source) = res?;
let Ok(len) = socket.read(&mut buffer).await else {
continue;
};
tokio::spawn(async move {
let Ok(len) = socket.read(&mut buffer).await else {
return;
};
handle_and_log_message(&mut socket, source, &buffer[..len]).await;
handle_and_log_message(&mut socket, source, &buffer[..len]).await;
});
}
res = udp_receive_fut => {
let (len, source) = res?;
+22 -9
View File
@@ -1,3 +1,6 @@
//! Routines for converting between the "wire format" and the "Rust format"
//! representations of RFC8489-defined structures.
use std::{mem::size_of, net::SocketAddr};
use bytemuck::Pod;
@@ -12,17 +15,19 @@ use super::{
RawErrorCodeAttribute, RawMessageHeader, RawXorMappedAddressAttributeV4,
RawXorMappedAddressAttributeV6, FAMILY_IPV6, MESSAGE_COOKIE,
},
Attribute, Class, Message, Method, Response, MESSAGE_SIZE_LIMIT,
Attribute, Class, IncomingMessage, Method, Response, MESSAGE_SIZE_LIMIT,
};
/// Result-like enum to handle messages with unrecognized comprehension-required attributes
pub enum DeserializeResult {
Ok(Message),
UnknownComprehensionAttributes(Message, Vec<u16>),
Ok(IncomingMessage),
UnknownComprehensionAttributes(IncomingMessage, Vec<u16>),
Error(Error),
}
/// Helper trait for types sent to the requesting Agent in a response
pub trait WireSerialize {
fn wire_serialize(&self, request: &Message, output: &mut [u8]) -> Result<usize, Error>;
fn wire_serialize(&self, request: &IncomingMessage, output: &mut [u8]) -> Result<usize, Error>;
}
fn put_struct<T: Pod>(buffer: &mut [u8], value: &T) -> usize {
@@ -43,7 +48,7 @@ fn get_struct<T: Pod>(buffer: &[u8]) -> &T {
}
impl WireSerialize for Attribute {
fn wire_serialize(&self, request: &Message, output: &mut [u8]) -> Result<usize, Error> {
fn wire_serialize(&self, request: &IncomingMessage, output: &mut [u8]) -> Result<usize, Error> {
match self {
Self::XorMappedAddress(SocketAddr::V4(v4)) => {
let raw = RawXorMappedAddressAttributeV4::new(*v4).wrap();
@@ -86,7 +91,7 @@ impl WireSerialize for Attribute {
}
impl WireSerialize for Response {
fn wire_serialize(&self, request: &Message, output: &mut [u8]) -> Result<usize, Error> {
fn wire_serialize(&self, request: &IncomingMessage, output: &mut [u8]) -> Result<usize, Error> {
// Serialize attributes first
let mut attr_len = 0;
for attr in self.attributes.iter() {
@@ -114,6 +119,14 @@ impl WireSerialize for Response {
}
}
/// Processes a raw incoming message and outputs a [DeserializeResult]:
///
/// * If the message was received properly and all comprehension-required attributes are
/// recognized, returns [DeserializeResult::Ok].
/// * If the message was received properly, but there were comprehension-required
/// attributes which could not be recognized, returns a
/// [DeserializeResult::UnknownComprehensionAttributes].
/// * If the message was malformed, returns a [DeserializeResult::Error].
pub fn deserialize_message(data: &[u8]) -> DeserializeResult {
use DeserializeResult as DR;
@@ -143,7 +156,7 @@ pub fn deserialize_message(data: &[u8]) -> DeserializeResult {
};
let transaction_id = header.transaction_id.map(u32::from_be);
let mut unknown_attributes = vec![];
let mut message = Message {
let mut message = IncomingMessage {
method,
class,
transaction_id,
@@ -209,7 +222,7 @@ pub fn deserialize_message(data: &[u8]) -> DeserializeResult {
}
}
pub fn parse_ty(ty: u16) -> Result<(Method, Class), Error> {
fn parse_ty(ty: u16) -> Result<(Method, Class), Error> {
if ty & 0b1100000000000000 != 0 {
return Err(Error::MalformedMessage);
}
@@ -223,7 +236,7 @@ pub fn parse_ty(ty: u16) -> Result<(Method, Class), Error> {
Ok((method, class))
}
pub fn serialize_ty(method: Method, class: Class) -> u16 {
fn serialize_ty(method: Method, class: Class) -> u16 {
let method = method.repr();
let class = class.repr();
+29 -2
View File
@@ -1,3 +1,7 @@
//! RFC8489 protocol data types and handling functions.
//!
//! The main module contains the Rust-friendly types, while the raw specifics
//! of the actual RFC-specified protocol are defined in the [rfc8489] module.
use std::net::SocketAddr;
use enum_repr::EnumRepr;
@@ -5,8 +9,10 @@ use enum_repr::EnumRepr;
pub mod convert;
pub mod rfc8489;
/// Message limit imposed on both receive/send functions
pub const MESSAGE_SIZE_LIMIT: usize = 512;
/// Request/indication methods defined by the RFC
#[EnumRepr(type = "u16")]
#[derive(Debug, PartialEq, Clone, Copy)]
#[non_exhaustive]
@@ -15,6 +21,8 @@ pub enum Method {
Binding = 1,
}
/// Classes defined by the RFC. Although there are four defined, only two types are
/// used by STUN: request/response transactions and indication messages.
#[EnumRepr(type = "u16")]
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Class {
@@ -24,34 +32,47 @@ pub enum Class {
ErrorResponse = 3,
}
/// Error codes reported by the server
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum ErrorCode {
UnknownAttribute,
BadRequest,
}
/// Defines main attributes handled by the server
#[derive(Debug)]
pub enum Attribute {
/// XOR-MAPPED-ADDRESS attribute, contains the address to be "echoed back"
/// to the requesting Agent
XorMappedAddress(SocketAddr),
/// ERROR-CODE, used to report an error
ErrorCode(ErrorCode),
/// UNKNOWN-ATTRIBUTES, MUST be present when reporting [ErrorCode::UnknownAttribute]
UnknownAttributes(Vec<u16>),
}
/// Defines a single message received by the server
#[derive(Debug)]
pub struct Message {
pub struct IncomingMessage {
/// Method of the message/transaction
pub method: Method,
/// Class of the message/transaction
pub class: Class,
/// Transaction ID, gets "echoed back" to the requesting Agent
pub transaction_id: [u32; 3],
/// Attributes coming from the requesting Agent
pub attributes: Vec<Attribute>,
}
/// Defines a response generated by the server
#[derive(Default, Debug)]
pub struct Response {
success: bool,
attributes: Vec<Attribute>,
}
impl Message {
impl IncomingMessage {
/// Dummy message used when a full message could not be received from the requesting Agent
pub fn invalid() -> Self {
Self {
method: Method::Reserved,
@@ -63,6 +84,11 @@ impl Message {
}
impl Response {
/// Creates a new [Response], with its actual [Class] code being based on the `success`
/// parameter:
///
/// * If `success` is `true`, [Class::SuccessResponse] (2) is used.
/// * Otherwise, [Class::ErrorResponse] (3) is used.
pub fn new(success: bool) -> Self {
Self {
success,
@@ -70,6 +96,7 @@ impl Response {
}
}
/// Pushes an attribute to the response's attribute list
pub fn add_attribute(&mut self, attr: Attribute) {
self.attributes.push(attr);
}
+2
View File
@@ -1,3 +1,5 @@
//! The "wire format" data structures of the RFC8489 protocol
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use bytemuck::{Pod, Zeroable};
+8
View File
@@ -1,3 +1,5 @@
//! Utility functions
use std::{
future::Future,
io,
@@ -9,6 +11,7 @@ use tokio::{
net::{TcpStream, UdpSocket},
};
/// Helper trait to allow the crate to operate on both TCP and UDP sockets
pub trait AsyncSendTo {
fn send_to<'a>(
&'a mut self,
@@ -38,6 +41,11 @@ impl AsyncSendTo for TcpStream {
}
}
/// Utility function to map incoming socket addresses to their actual representations.
///
/// When listening on an IPv6 socket, any incoming IPv4 socket addresses are converted
/// to their IPv4-to-IPv6 mapped forms, this routine takes care of that by mapping them
/// back to IPv4 if needed.
pub fn map_sockaddr(source: SocketAddr) -> SocketAddr {
match source {
SocketAddr::V4(_) => source,