Add 'kernel/' from commit '7f1f6b73377367db17f98a740316b904c37ce3b1'

git-subtree-dir: kernel
git-subtree-mainline: 817f71f90f97270dd569fd44246bf74e57636552
git-subtree-split: 7f1f6b73377367db17f98a740316b904c37ce3b1
commit 18fa8b954a by Mark Poliakov <mark@alnyan.me>, 2024-03-12 15:52:48 +02:00
291 changed files with 42349 additions and 0 deletions

@@ -0,0 +1,51 @@
name: Kernel tests
run_name: Kernel tests
on: [pull_request]
jobs:
  Test-x86_64-Build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout kernel sources
        uses: actions/checkout@v3
      - name: Install build dependencies
        run: |
          apt update && apt install -y nasm gcc
      - name: Install nightly Rust toolchain
        run: |
          curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y --default-toolchain nightly
          source "$HOME/.cargo/env"
          rustup component add rust-src --toolchain nightly-x86_64-unknown-linux-gnu
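          # rust-src is required because -Z build-std below rebuilds core/alloc/compiler_builtins from source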
      - name: Update dependencies
        run: |
          source "$HOME/.cargo/env"
          cd ${{ gitea.workspace }}
          cargo update yggdrasil-abi elf
      - name: Build x86-64
        run: |
          source "$HOME/.cargo/env"
          cd ${{ gitea.workspace }}
          cargo build -Z build-std=core,alloc,compiler_builtins --target=etc/x86_64-unknown-none.json
  Test-aarch64-Build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout kernel sources
        uses: actions/checkout@v3
      - name: Install build dependencies
        run: |
          apt update && apt install -y nasm gcc
      - name: Install nightly Rust toolchain
        run: |
          curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y --default-toolchain nightly
          source "$HOME/.cargo/env"
          rustup component add rust-src --toolchain nightly-x86_64-unknown-linux-gnu
      - name: Update dependencies
        run: |
          source "$HOME/.cargo/env"
          cd ${{ gitea.workspace }}
          cargo update yggdrasil-abi elf
      - name: Build aarch64
        run: |
          source "$HOME/.cargo/env"
          cd ${{ gitea.workspace }}
          cargo build -Z build-std=core,alloc,compiler_builtins --target=etc/aarch64-unknown-qemu.json

kernel/.gitignore vendored Normal file
@@ -0,0 +1 @@
/target

kernel/Cargo.toml Normal file
@@ -0,0 +1,82 @@
[package]
name = "yggdrasil-kernel"
version = "0.1.0"
edition = "2021"
build = "build.rs"
authors = ["Mark Poliakov <mark@alnyan.me>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[profile.dev.package.tock-registers]
opt-level = 3
[dependencies]
abi-lib = { git = "https://git.alnyan.me/yggdrasil/abi-generator.git" }
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
vfs = { path = "lib/vfs" }
device-api = { path = "lib/device-api", features = ["derive"] }
libk = { path = "libk" }
libk-util = { path = "libk/libk-util" }
libk-mm = { path = "libk/libk-mm" }
libk-thread = { path = "libk/libk-thread" }
libk-device = { path = "libk/libk-device" }
memtables = { path = "lib/memtables" }
vmalloc = { path = "lib/vmalloc" }
device-api-macros = { path = "lib/device-api/macros" }
kernel-arch = { path = "arch" }
# Drivers
ygg_driver_pci = { path = "driver/bus/pci" }
ygg_driver_usb = { path = "driver/bus/usb" }
ygg_driver_block = { path = "driver/block/core" }
ygg_driver_net_core = { path = "driver/net/core" }
ygg_driver_net_loopback = { path = "driver/net/loopback" }
ygg_driver_virtio_net = { path = "driver/virtio/net", features = ["pci"] }
ygg_driver_ahci = { path = "driver/block/ahci" }
ygg_driver_usb_xhci = { path = "driver/usb/xhci" }
ygg_driver_input = { path = "driver/input" }
kernel-fs = { path = "driver/fs/kernel-fs" }
memfs = { path = "driver/fs/memfs" }
atomic_enum = "0.2.0"
bitflags = "2.3.3"
linked_list_allocator = "0.10.5"
spinning_top = "0.2.5"
static_assertions = "1.1.0"
tock-registers = "0.8.1"
cfg-if = "1.0.0"
git-version = "0.3.5"
log = "0.4.20"
futures-util = { version = "0.3.28", default-features = false, features = ["alloc", "async-await"] }
crossbeam-queue = { version = "0.3.8", default-features = false, features = ["alloc"] }
bytemuck = { version = "1.14.0", features = ["derive"] }
[dependencies.elf]
version = "0.7.2"
git = "https://git.alnyan.me/yggdrasil/yggdrasil-elf.git"
default-features = false
features = ["no_std_stream"]
[target.'cfg(target_arch = "aarch64")'.dependencies]
aarch64-cpu = "9.3.1"
device-tree = { path = "lib/device-tree" }
kernel-arch-aarch64 = { path = "arch/aarch64" }
[target.'cfg(target_arch = "x86_64")'.dependencies]
yboot-proto = { git = "https://git.alnyan.me/yggdrasil/yboot-proto.git" }
aml = { git = "https://github.com/alnyan/acpi.git", branch = "acpi-system" }
acpi_lib = { git = "https://github.com/alnyan/acpi.git", package = "acpi", branch = "acpi-system" }
acpi-system = { git = "https://github.com/alnyan/acpi-system.git" }
ygg_driver_nvme = { path = "driver/block/nvme" }
kernel-arch-x86_64 = { path = "arch/x86_64" }
[build-dependencies]
prettyplease = "0.2.15"
yggdrasil-abi-def = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi-def.git" }
abi-generator = { git = "https://git.alnyan.me/yggdrasil/abi-generator.git" }
[features]
default = ["fb_console"]
fb_console = []

kernel/README.md Normal file
@@ -0,0 +1,87 @@
yggdrasil-kernel
================
Rust Unix-like operating system kernel.
See also:
* [ABI for kernel-user communication](https://git.alnyan.me/yggdrasil-abi)
* [Rust fork to use with the kernel](https://git.alnyan.me/yggdrasil/yggdrasil-rust)
* [Userspace programs](https://git.alnyan.me/yggdrasil/yggdrasil-userspace)
* [yboot — x86-64 UEFI bootloader](https://git.alnyan.me/yggdrasil/yboot)
Main features
-------------
* Architecture support: [aarch64](/src/arch/aarch64) and [x86_64](/src/arch/x86_64)
* Kernel/userspace preemptive multithreading
* Kernel-space multitasking with `async`/`await` runtime
* Symmetric Multiprocessing
* Unix-like virtual filesystem: files, directories, block/char devices, symlinks, mounts
* In-memory read-write filesystem for tar-based initrd
* sysfs/devfs
* Binary formats: ELF + `#!/...` shebangs
* Rust-style interfaces for most subsystems: memory management, devices, etc.
aarch64-specific:
* PSCI for SMP start-up and power control
* PL011 serial port
* ARM generic timer as system/monotonic timer
* GICv2 IRQ controller
x86_64-specific:
* UEFI boot through [yboot](https://git.alnyan.me/yggdrasil/yboot) (no plans for legacy boot)
* PCIe, with plans to extend to aarch64 as well
* NVMe drive support (read/write)
* AHCI SATA drive support (read/write)
* I/O and Local APIC IRQ controllers
* PS/2 keyboard
* i8253-based timer (had to revert to it after problems with HPET on real hardware)
* COM ports
* ACPI, [work in progress](https://github.com/rust-osdev/acpi), mostly broken on real hardware
  * ACPI shutdown
  * PCI IRQ pin routing
  * Events like power button, etc.
* Fancy framebuffer console
Userspace features:
* Sanitized system calls better suited for Rust
* Userspace threads
* Synchronization primitives through futex-like interface
* Unix-like signals and exceptions
General plans (in no particular order)
--------------------------------------
* Better unification of architecture code
* `async` for VFS (?)
* PCIe NVMe block device
* PCIe SATA block device
* PCIe XHCI USB devices
* Better algorithms for memory management
Navigation
----------
* `src/arch` — architecture-specific code
* `src/device` — device driver implementations
* `bus` — bus devices like USB, PCIe etc.
* `display` — everything related to graphic displays
* `power` — power and reset controllers
* `serial` — serial transceiver drivers
* `devtree.rs` — stuff related to ARM DeviceTree
* `tty.rs` — Unix-style terminal driver implementation
* `src/fs` — in-kernel filesystems (sysfs/devfs)
* `src/mem` — memory management
* `src/proc` — process information management
* `src/syscall` — system call handling
* `src/task` — kernel and userspace tasks, processes and threads
* `src/util` — utilities used within the kernel
* `src/init.rs` — kernel init thread impl.

kernel/arch/Cargo.toml Normal file
@@ -0,0 +1,20 @@
[package]
name = "kernel-arch"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[target.'cfg(all(target_os = "none", target_arch = "x86_64"))'.dependencies]
kernel-arch-x86_64 = { path = "x86_64" }
[target.'cfg(all(target_os = "none", target_arch = "aarch64"))'.dependencies]
kernel-arch-aarch64 = { path = "aarch64" }
[target.'cfg(not(target_os = "none"))'.dependencies]
kernel-arch-hosted = { path = "hosted" }
[dependencies]
kernel-arch-interface = { path = "interface" }
cfg-if = "1.0.0"

kernel/arch/aarch64/Cargo.toml Normal file
@@ -0,0 +1,18 @@
[package]
name = "kernel-arch-aarch64"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
kernel-arch-interface = { path = "../interface" }
libk-mm-interface = { path = "../../libk/libk-mm/interface" }
memtables = { path = "../../lib/memtables" }
device-api = { path = "../../lib/device-api", features = ["derive"] }
bitflags = "2.3.3"
static_assertions = "1.1.0"
aarch64-cpu = "9.3.1"
tock-registers = "0.8.1"

kernel/arch/aarch64/src/context.S Normal file
@@ -0,0 +1,123 @@
.global __aarch64_enter_task
.global __aarch64_switch_task
.global __aarch64_switch_task_and_drop
.global __aarch64_task_enter_kernel
.global __aarch64_task_enter_user
.section .text
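// Switch-frame layout ({context_size} = 14 * 8 bytes):
//   sp + 0x00..0x60: x19..x30 (callee-saved GPRs, six 16-byte pairs)
//   sp + 0x60:       tpidr_el0, ttbr0_el1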
.macro SAVE_TASK_STATE
sub sp, sp, #{context_size}
stp x19, x20, [sp, #16 * 0]
stp x21, x22, [sp, #16 * 1]
stp x23, x24, [sp, #16 * 2]
stp x25, x26, [sp, #16 * 3]
stp x27, x28, [sp, #16 * 4]
stp x29, x30, [sp, #16 * 5]
mrs x19, tpidr_el0
mrs x20, ttbr0_el1
stp x19, x20, [sp, #16 * 6]
.endm
.macro LOAD_TASK_STATE
// x19 == tpidr_el0, x20 == ttbr0_el1
ldp x19, x20, [sp, #16 * 6]
msr tpidr_el0, x19
msr ttbr0_el1, x20
ldp x19, x20, [sp, #16 * 0]
ldp x21, x22, [sp, #16 * 1]
ldp x23, x24, [sp, #16 * 2]
ldp x25, x26, [sp, #16 * 3]
ldp x27, x28, [sp, #16 * 4]
ldp x29, x30, [sp, #16 * 5]
add sp, sp, #{context_size}
.endm
__aarch64_task_enter_kernel:
# EL1h, IRQs unmasked
mov x0, #5
msr spsr_el1, x0
# x0 == argument, x1 == entry point
ldp x0, x1, [sp, #0]
msr elr_el1, x1
add sp, sp, #16
eret
__aarch64_task_enter_user:
// x0 == sp, x1 == ignored
ldp x0, x1, [sp, #16 * 0]
msr sp_el0, x0
# EL0t, IRQs unmasked
msr spsr_el1, xzr
// x0 == arg, x1 == entry
ldp x0, x1, [sp, #16 * 1]
msr elr_el1, x1
add sp, sp, #32
// Zero the registers
mov x1, xzr
mov x2, xzr
mov x3, xzr
mov x4, xzr
mov x5, xzr
mov x6, xzr
mov x7, xzr
mov x8, xzr
mov x9, xzr
mov x10, xzr
mov x11, xzr
mov x12, xzr
mov x13, xzr
mov x14, xzr
mov x15, xzr
mov x16, xzr
mov x17, xzr
mov x18, xzr
mov lr, xzr
dmb ish
isb sy
eret
__aarch64_switch_task:
SAVE_TASK_STATE
mov x19, sp
str x19, [x1]
ldr x0, [x0]
mov sp, x0
LOAD_TASK_STATE
ret
// x0 -- destination context
// x1 -- source (dropped) thread
__aarch64_switch_task_and_drop:
ldr x0, [x0]
mov sp, x0
mov x0, x1
bl __arch_drop_thread
LOAD_TASK_STATE
ret
__aarch64_enter_task:
ldr x0, [x0]
mov sp, x0
LOAD_TASK_STATE
ret

kernel/arch/aarch64/src/context.rs Normal file
@@ -0,0 +1,242 @@
//! AArch64-specific task context implementation
use core::{arch::global_asm, cell::UnsafeCell, fmt, marker::PhantomData};
use kernel_arch_interface::{
mem::{KernelTableManager, PhysicalMemoryAllocator},
task::{StackBuilder, TaskContext, TaskFrame},
};
use libk_mm_interface::address::PhysicalAddress;
use yggdrasil_abi::{arch::SavedFrame, error::Error};
/// Struct for register values saved when taking an exception
#[repr(C)]
pub struct ExceptionFrame {
/// General-purpose registers
pub r: [u64; 32],
/// SPSR_EL1, userspace flags register
pub spsr_el1: u64,
/// ELR_EL1, userspace program counter
pub elr_el1: u64,
/// SP_EL0, userspace stack pointer
pub sp_el0: u64,
_x: u64,
// ...
}
#[repr(C, align(0x10))]
struct TaskContextInner {
// 0x00
sp: usize,
}
/// AArch64 implementation of a task context
#[allow(unused)]
pub struct TaskContextImpl<
K: KernelTableManager,
PA: PhysicalMemoryAllocator<Address = PhysicalAddress>,
> {
inner: UnsafeCell<TaskContextInner>,
stack_base_phys: PhysicalAddress,
stack_size: usize,
_alloc: PhantomData<PA>,
_table_manager: PhantomData<K>,
}
const COMMON_CONTEXT_SIZE: usize = 8 * 14;
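// 14 u64 slots: x19..x30 (12 callee-saved GPRs) + tpidr_el0 + ttbr0_el1,
// matching the SAVE_TASK_STATE/LOAD_TASK_STATE layout in context.S.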
impl TaskFrame for ExceptionFrame {
fn store(&self) -> SavedFrame {
SavedFrame {
gp_regs: self.r,
spsr_el1: self.spsr_el1,
elr_el1: self.elr_el1,
sp_el0: self.sp_el0,
}
}
fn restore(&mut self, saved: &SavedFrame) {
self.r = saved.gp_regs;
self.spsr_el1 = saved.spsr_el1;
self.elr_el1 = saved.elr_el1;
self.sp_el0 = saved.sp_el0;
}
fn argument(&self) -> u64 {
self.r[0]
}
fn user_ip(&self) -> usize {
self.elr_el1 as _
}
fn user_sp(&self) -> usize {
self.sp_el0 as _
}
fn set_argument(&mut self, value: u64) {
self.r[0] = value;
}
fn set_return_value(&mut self, value: u64) {
self.r[0] = value;
}
fn set_user_ip(&mut self, value: usize) {
self.elr_el1 = value as _;
}
fn set_user_sp(&mut self, value: usize) {
self.sp_el0 = value as _;
}
}
impl fmt::Debug for ExceptionFrame {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for i in (0..32).step_by(2) {
write!(
f,
"x{:<2} = {:#020x}\tx{:<2} = {:#020x}",
i,
self.r[i],
i + 1,
self.r[i + 1]
)?;
if i != 30 {
f.write_str("\n")?;
}
}
Ok(())
}
}
unsafe impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>> Sync
for TaskContextImpl<K, PA>
{
}
impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>>
TaskContext<K, PA> for TaskContextImpl<K, PA>
{
const USER_STACK_EXTRA_ALIGN: usize = 0;
const SIGNAL_STACK_EXTRA_ALIGN: usize = 0;
fn kernel(entry: extern "C" fn(usize) -> !, arg: usize) -> Result<Self, Error> {
const KERNEL_TASK_PAGES: usize = 8;
let stack_base_phys = PA::allocate_contiguous_pages(KERNEL_TASK_PAGES)?;
let stack_base = stack_base_phys.raw_virtualize::<K>();
let mut stack = StackBuilder::new(stack_base, KERNEL_TASK_PAGES * 0x1000);
// Entry and argument
stack.push(entry as _);
stack.push(arg);
setup_common_context(&mut stack, __aarch64_task_enter_kernel as _, 0, 0);
let sp = stack.build();
// TODO stack is leaked
Ok(Self {
inner: UnsafeCell::new(TaskContextInner { sp }),
stack_base_phys,
stack_size: KERNEL_TASK_PAGES * 0x1000,
_alloc: PhantomData,
_table_manager: PhantomData,
})
}
fn user(
entry: usize,
arg: usize,
ttbr0: u64,
user_stack_sp: usize,
tpidr_el0: usize,
) -> Result<Self, Error> {
const USER_TASK_PAGES: usize = 16;
let stack_base_phys = PA::allocate_contiguous_pages(USER_TASK_PAGES)?;
let stack_base = stack_base_phys.raw_virtualize::<K>();
let mut stack = StackBuilder::new(stack_base, USER_TASK_PAGES * 0x1000);
stack.push(entry as _);
stack.push(arg);
stack.push(0);
stack.push(user_stack_sp);
setup_common_context(
&mut stack,
__aarch64_task_enter_user as _,
ttbr0,
tpidr_el0 as _,
);
let sp = stack.build();
Ok(Self {
inner: UnsafeCell::new(TaskContextInner { sp }),
stack_base_phys,
stack_size: USER_TASK_PAGES * 0x1000,
_alloc: PhantomData,
_table_manager: PhantomData,
})
}
unsafe fn enter(&self) -> ! {
__aarch64_enter_task(self.inner.get())
}
unsafe fn switch(&self, from: &Self) {
__aarch64_switch_task(self.inner.get(), from.inner.get())
}
unsafe fn switch_and_drop(&self, thread: *const ()) {
__aarch64_switch_task_and_drop(self.inner.get(), thread);
}
}
impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>> Drop
for TaskContextImpl<K, PA>
{
fn drop(&mut self) {
assert_eq!(self.stack_size % 0x1000, 0);
for offset in (0..self.stack_size).step_by(0x1000) {
unsafe {
PA::free_page(self.stack_base_phys.add(offset));
}
}
}
}
fn setup_common_context(builder: &mut StackBuilder, entry: usize, ttbr0: u64, tpidr_el0: u64) {
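    // Pushes go downward in memory, mirroring LOAD_TASK_STATE: the first push
    // (ttbr0_el1) lands at the highest address, x19 at the final stack pointer.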
builder.push(ttbr0 as _); // ttbr0_el1
builder.push(tpidr_el0 as _); // tpidr_el0
builder.push(entry); // x30/lr
builder.push(0); // x29
builder.push(0); // x28
builder.push(0); // x27
builder.push(0); // x26
builder.push(0); // x25
builder.push(0); // x24
builder.push(0); // x23
builder.push(0); // x22
builder.push(0); // x21
builder.push(0); // x20
builder.push(0); // x19
}
extern "C" {
fn __aarch64_enter_task(to: *mut TaskContextInner) -> !;
fn __aarch64_switch_task(to: *mut TaskContextInner, from: *mut TaskContextInner);
fn __aarch64_switch_task_and_drop(to: *mut TaskContextInner, thread: *const ()) -> !;
fn __aarch64_task_enter_kernel();
fn __aarch64_task_enter_user();
}
global_asm!(include_str!("context.S"), context_size = const COMMON_CONTEXT_SIZE);

kernel/arch/aarch64/src/lib.rs Normal file
@@ -0,0 +1,119 @@
#![no_std]
#![feature(
effects,
strict_provenance,
asm_const,
naked_functions,
trait_upcasting
)]
extern crate alloc;
use core::sync::atomic::{AtomicUsize, Ordering};
use aarch64_cpu::registers::{DAIF, MPIDR_EL1, TPIDR_EL1};
use alloc::{boxed::Box, vec::Vec};
use device_api::interrupt::{LocalInterruptController, MessageInterruptController};
use kernel_arch_interface::{
cpu::{CpuImpl, IpiQueue},
task::Scheduler,
util::OneTimeInit,
Architecture,
};
use tock_registers::interfaces::{ReadWriteable, Readable, Writeable};
pub mod context;
pub mod mem;
pub use context::TaskContextImpl;
pub use mem::{process::ProcessAddressSpaceImpl, KernelTableManagerImpl};
pub struct ArchitectureImpl;
pub trait GicInterface: LocalInterruptController {}
pub struct PerCpuData {
pub gic: OneTimeInit<&'static dyn GicInterface>,
}
static IPI_QUEUES: OneTimeInit<Vec<IpiQueue<ArchitectureImpl>>> = OneTimeInit::new();
pub static CPU_COUNT: AtomicUsize = AtomicUsize::new(1);
#[naked]
extern "C" fn idle_task(_: usize) -> ! {
unsafe {
core::arch::asm!("1: nop; b 1b", options(noreturn));
}
}
impl ArchitectureImpl {
pub fn local_cpu_data() -> Option<&'static mut PerCpuData> {
unsafe { (Self::local_cpu() as *mut PerCpuData).as_mut() }
}
}
impl Architecture for ArchitectureImpl {
type PerCpuData = PerCpuData;
fn cpu_index<S: Scheduler + 'static>() -> u32 {
(MPIDR_EL1.get() & 0xFF) as u32
}
fn interrupt_mask() -> bool {
DAIF.read(DAIF::I) != 0
}
unsafe fn set_interrupt_mask(mask: bool) -> bool {
let old = Self::interrupt_mask();
if mask {
DAIF.modify(DAIF::I::SET);
} else {
DAIF.modify(DAIF::I::CLEAR);
}
old
}
fn wait_for_interrupt() {
aarch64_cpu::asm::wfi();
}
unsafe fn set_local_cpu(cpu: *mut ()) {
TPIDR_EL1.set(cpu as _);
}
unsafe fn init_local_cpu<S: Scheduler + 'static>(id: Option<u32>, data: Self::PerCpuData) {
assert!(
id.is_none(),
"AArch64 uses MPIDR_EL1 instead of manual ID set"
);
let id = (MPIDR_EL1.get() & 0xFF) as u32;
let cpu = Box::leak(Box::new(CpuImpl::<Self, S>::new(id, data)));
cpu.set_local();
}
fn local_cpu() -> *mut () {
TPIDR_EL1.get() as _
}
unsafe fn init_ipi_queues(queues: Vec<IpiQueue<Self>>) {
IPI_QUEUES.init(queues);
}
fn idle_task() -> extern "C" fn(usize) -> ! {
idle_task
}
fn cpu_count() -> usize {
CPU_COUNT.load(Ordering::Acquire)
}
fn local_interrupt_controller() -> &'static dyn LocalInterruptController {
let local = Self::local_cpu_data().unwrap();
*local.gic.get()
}
fn message_interrupt_controller() -> &'static dyn MessageInterruptController {
todo!()
}
}

kernel/arch/aarch64/src/mem/mod.rs Normal file
@@ -0,0 +1,408 @@
use core::{
alloc::Layout,
ops::{Deref, DerefMut},
ptr::addr_of,
sync::atomic::AtomicUsize,
sync::atomic::Ordering,
};
use aarch64_cpu::registers::{TTBR0_EL1, TTBR1_EL1};
use kernel_arch_interface::{
mem::{DeviceMemoryAttributes, KernelTableManager, RawDeviceMemoryMapping},
KERNEL_VIRT_OFFSET,
};
use libk_mm_interface::{
address::{FromRaw, PhysicalAddress},
table::{EntryLevel, EntryLevelExt},
KernelImageObject,
};
use memtables::aarch64::{FixedTables, KERNEL_L3_COUNT};
use static_assertions::const_assert_eq;
use tock_registers::interfaces::Writeable;
use yggdrasil_abi::error::Error;
use self::table::{PageAttributes, PageEntry, PageTable, L1, L2, L3};
pub mod process;
pub mod table;
#[derive(Debug)]
pub struct KernelTableManagerImpl;
// TODO eliminate this requirement by using precomputed indices
const MAPPING_OFFSET: usize = KERNEL_VIRT_OFFSET;
const KERNEL_PHYS_BASE: usize = 0x40080000;
// Precomputed mappings
const KERNEL_L1_INDEX: usize = (KERNEL_VIRT_OFFSET + KERNEL_PHYS_BASE).page_index::<L1>();
const KERNEL_START_L2_INDEX: usize = (KERNEL_VIRT_OFFSET + KERNEL_PHYS_BASE).page_index::<L2>();
const KERNEL_END_L2_INDEX: usize = KERNEL_START_L2_INDEX + KERNEL_L3_COUNT;
// Must not be zero, should be at 4MiB
const_assert_eq!(KERNEL_START_L2_INDEX, 0);
// From static mapping
const_assert_eq!(KERNEL_L1_INDEX, 1);
// Runtime mappings
// 2MiB max
const EARLY_MAPPING_L2I: usize = KERNEL_END_L2_INDEX + 1;
// 1GiB max
const HEAP_MAPPING_L1I: usize = KERNEL_L1_INDEX + 1;
// 1GiB max
const DEVICE_MAPPING_L1I: usize = KERNEL_L1_INDEX + 2;
const DEVICE_MAPPING_L3_COUNT: usize = 4;
// 16GiB max
const RAM_MAPPING_START_L1I: usize = KERNEL_L1_INDEX + 3;
pub const RAM_MAPPING_L1_COUNT: usize = 16;
// 2MiB for early mappings
const EARLY_MAPPING_OFFSET: usize =
MAPPING_OFFSET | (KERNEL_L1_INDEX * L1::SIZE) | (EARLY_MAPPING_L2I * L2::SIZE);
static mut EARLY_MAPPING_L3: PageTable<L3> = PageTable::zeroed();
// 1GiB for heap mapping
pub const HEAP_MAPPING_OFFSET: usize = MAPPING_OFFSET | (HEAP_MAPPING_L1I * L1::SIZE);
pub static mut HEAP_MAPPING_L2: PageTable<L2> = PageTable::zeroed();
// 1GiB for device MMIO mapping
const DEVICE_MAPPING_OFFSET: usize = MAPPING_OFFSET | (DEVICE_MAPPING_L1I * L1::SIZE);
static mut DEVICE_MAPPING_L2: PageTable<L2> = PageTable::zeroed();
static mut DEVICE_MAPPING_L3S: [PageTable<L3>; DEVICE_MAPPING_L3_COUNT] =
[PageTable::zeroed(); DEVICE_MAPPING_L3_COUNT];
// 16GiB for RAM mapping
pub const RAM_MAPPING_OFFSET: usize = MAPPING_OFFSET | (RAM_MAPPING_START_L1I * L1::SIZE);
pub static MEMORY_LIMIT: AtomicUsize = AtomicUsize::new(0);
#[link_section = ".data.tables"]
pub static mut KERNEL_TABLES: KernelImageObject<FixedTables> =
unsafe { KernelImageObject::new(FixedTables::zeroed()) };
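// RAM is linearly mapped at RAM_MAPPING_OFFSET, so for physical memory below
// MEMORY_LIMIT the phys <-> virt translation is a constant-offset add/subtract.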
impl KernelTableManager for KernelTableManagerImpl {
fn virtualize(address: u64) -> usize {
let address = address as usize;
if address < MEMORY_LIMIT.load(Ordering::Acquire) {
address + RAM_MAPPING_OFFSET
} else {
panic!("Invalid physical address: {:#x}", address);
}
}
fn physicalize(address: usize) -> u64 {
if address < RAM_MAPPING_OFFSET
|| address - RAM_MAPPING_OFFSET >= MEMORY_LIMIT.load(Ordering::Acquire)
{
panic!("Not a virtualized physical address: {:#x}", address);
}
(address - RAM_MAPPING_OFFSET) as _
}
unsafe fn map_device_pages(
base: u64,
count: usize,
attrs: DeviceMemoryAttributes,
) -> Result<RawDeviceMemoryMapping<Self>, Error> {
map_device_memory(PhysicalAddress::from_raw(base), count, attrs)
}
unsafe fn unmap_device_pages(mapping: &RawDeviceMemoryMapping<Self>) {
unmap_device_memory(mapping)
}
}
/// Memory mapping which may be used for performing early kernel initialization
pub struct EarlyMapping<'a, T: ?Sized> {
value: &'a mut T,
page_count: usize,
}
impl<'a, T: Sized> EarlyMapping<'a, T> {
pub unsafe fn map_slice(
physical: PhysicalAddress,
len: usize,
) -> Result<EarlyMapping<'a, [T]>, Error> {
let layout = Layout::array::<T>(len).unwrap();
let aligned = physical.page_align_down::<L3>();
let offset = physical.page_offset::<L3>();
let page_count = (offset + layout.size() + L3::SIZE - 1) / L3::SIZE;
let virt = map_early_pages(aligned, page_count)?;
let value = core::slice::from_raw_parts_mut((virt + offset) as *mut T, len);
Ok(EarlyMapping { value, page_count })
}
}
impl<'a, T: ?Sized> Deref for EarlyMapping<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.value
}
}
impl<'a, T: ?Sized> DerefMut for EarlyMapping<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.value
}
}
impl<'a, T: ?Sized> Drop for EarlyMapping<'a, T> {
fn drop(&mut self) {
let address = (self.value as *mut T).addr() & !(L3::SIZE - 1);
for i in 0..self.page_count {
let page = address + i * L3::SIZE;
unsafe {
unmap_early_page(page);
}
}
}
}
fn kernel_table_flags() -> PageAttributes {
PageAttributes::TABLE
| PageAttributes::ACCESS
| PageAttributes::SH_INNER
| PageAttributes::PAGE_ATTR_NORMAL
| PageAttributes::PRESENT
}
fn ram_block_flags() -> PageAttributes {
// TODO UXN, PXN
PageAttributes::BLOCK
| PageAttributes::ACCESS
| PageAttributes::SH_INNER
| PageAttributes::PAGE_ATTR_NORMAL
| PageAttributes::PRESENT
}
// Early mappings
unsafe fn map_early_pages(physical: PhysicalAddress, count: usize) -> Result<usize, Error> {
for l3i in 0..512 {
let mut taken = false;
for i in 0..count {
if EARLY_MAPPING_L3[i + l3i].is_present() {
taken = true;
break;
}
}
if taken {
continue;
}
for i in 0..count {
let page = physical.add(i * L3::SIZE);
// TODO NX, NC
EARLY_MAPPING_L3[i + l3i] = PageEntry::normal_page(page, PageAttributes::empty());
}
return Ok(EARLY_MAPPING_OFFSET + l3i * L3::SIZE);
}
Err(Error::OutOfMemory)
}
unsafe fn unmap_early_page(address: usize) {
if !(EARLY_MAPPING_OFFSET..EARLY_MAPPING_OFFSET + L2::SIZE).contains(&address) {
panic!("Tried to unmap invalid early mapping: {:#x}", address);
}
let l3i = (address - EARLY_MAPPING_OFFSET).page_index::<L3>();
assert!(EARLY_MAPPING_L3[l3i].is_present());
EARLY_MAPPING_L3[l3i] = PageEntry::INVALID;
// TODO invalidate tlb
}
pub unsafe fn map_ram_l1(index: usize) {
if index >= RAM_MAPPING_L1_COUNT {
todo!()
}
assert_eq!(KERNEL_TABLES.l1.data[index + RAM_MAPPING_START_L1I], 0);
KERNEL_TABLES.l1.data[index + RAM_MAPPING_START_L1I] =
((index * L1::SIZE) as u64) | ram_block_flags().bits();
}
pub unsafe fn map_heap_l2(index: usize, page: PhysicalAddress) {
if index >= 512 {
todo!()
}
assert!(!HEAP_MAPPING_L2[index].is_present());
// TODO UXN, PXN
HEAP_MAPPING_L2[index] = PageEntry::normal_block(page, PageAttributes::empty());
}
// Device mappings
unsafe fn map_device_memory_l3(
base: PhysicalAddress,
count: usize,
_attrs: DeviceMemoryAttributes,
) -> Result<usize, Error> {
// TODO don't map pages if already mapped
'l0: for i in 0..DEVICE_MAPPING_L3_COUNT * 512 {
for j in 0..count {
let l2i = (i + j) / 512;
let l3i = (i + j) % 512;
if DEVICE_MAPPING_L3S[l2i][l3i].is_present() {
continue 'l0;
}
}
for j in 0..count {
let l2i = (i + j) / 512;
let l3i = (i + j) % 512;
// TODO NX, NC
DEVICE_MAPPING_L3S[l2i][l3i] = PageEntry::device_page(base.add(j * L3::SIZE));
}
return Ok(DEVICE_MAPPING_OFFSET + i * L3::SIZE);
}
Err(Error::OutOfMemory)
}
unsafe fn map_device_memory_l2(
base: PhysicalAddress,
count: usize,
_attrs: DeviceMemoryAttributes,
) -> Result<usize, Error> {
'l0: for i in DEVICE_MAPPING_L3_COUNT..512 {
for j in 0..count {
if DEVICE_MAPPING_L2[i + j].is_present() {
continue 'l0;
}
}
for j in 0..count {
DEVICE_MAPPING_L2[i + j] = PageEntry::<L2>::device_block(base.add(j * L2::SIZE));
}
// log::debug!(
// "map l2s: base={:#x}, count={} -> {:#x}",
// base,
// count,
// DEVICE_MAPPING_OFFSET + i * L2::SIZE
// );
return Ok(DEVICE_MAPPING_OFFSET + i * L2::SIZE);
}
Err(Error::OutOfMemory)
}
pub unsafe fn map_device_memory(
base: PhysicalAddress,
size: usize,
attrs: DeviceMemoryAttributes,
) -> Result<RawDeviceMemoryMapping<KernelTableManagerImpl>, Error> {
// debugln!("Map {}B @ {:#x}", size, base);
let l3_aligned = base.page_align_down::<L3>();
let l3_offset = base.page_offset::<L3>();
let page_count = (l3_offset + size).page_count::<L3>();
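    // 256 L3 pages = 1MiB with 4KiB pages; anything larger uses 2MiB L2 blocks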
if page_count > 256 {
// Large mapping, use L2 mapping instead
let l2_aligned = base.page_align_down::<L2>();
let l2_offset = base.page_offset::<L2>();
let page_count = (l2_offset + size).page_count::<L2>();
let base_address = map_device_memory_l2(l2_aligned, page_count, attrs)?;
let address = base_address + l2_offset;
Ok(RawDeviceMemoryMapping::from_raw_parts(
address,
base_address,
page_count,
L2::SIZE,
))
} else {
// Just map the pages directly
let base_address = map_device_memory_l3(l3_aligned, page_count, attrs)?;
let address = base_address + l3_offset;
Ok(RawDeviceMemoryMapping::from_raw_parts(
address,
base_address,
page_count,
L3::SIZE,
))
}
}
pub unsafe fn unmap_device_memory(map: &RawDeviceMemoryMapping<KernelTableManagerImpl>) {
// debugln!(
// "Unmap {}B @ {:#x}",
// map.page_count * map.page_size,
// map.base_address
// );
match map.page_size {
L3::SIZE => {
for i in 0..map.page_count {
let page = map.base_address + i * L3::SIZE;
let l2i = page.page_index::<L2>();
let l3i = page.page_index::<L3>();
assert!(DEVICE_MAPPING_L3S[l2i][l3i].is_present());
DEVICE_MAPPING_L3S[l2i][l3i] = PageEntry::INVALID;
tlb_flush_vaae1(page);
}
}
L2::SIZE => todo!(),
_ => unimplemented!(),
}
}
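/// Invalidates the TLB entry for the given virtual address across all ASIDs
/// (`tlbi vaae1`); the operand carries VA[55:12], hence the shift by 12.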
#[inline]
pub fn tlb_flush_vaae1(mut page: usize) {
page >>= 12;
unsafe {
core::arch::asm!("tlbi vaae1, {page}", page = in(reg) page);
}
}
/// (BSP-early init) loads precomputed kernel mapping tables for the kernel to jump to "higher-half"
///
/// # Safety
///
/// Unsafe, must only be called by BSP during its early init while still in "lower-half"
pub unsafe fn load_fixed_tables() {
let ttbr0 = KERNEL_TABLES.l1.data.as_ptr() as u64;
TTBR0_EL1.set(ttbr0);
TTBR1_EL1.set(ttbr0);
}
/// Sets up additional translation tables for kernel usage
///
/// # Safety
///
/// Unsafe, must only be called by BSP during its early init, must already be in "higher-half"
pub unsafe fn init_fixed_tables() {
// TODO this could be built in compile-time too?
let early_mapping_l3_phys = addr_of!(EARLY_MAPPING_L3) as usize - KERNEL_VIRT_OFFSET;
let device_mapping_l2_phys = addr_of!(DEVICE_MAPPING_L2) as usize - KERNEL_VIRT_OFFSET;
let heap_mapping_l2_phys = addr_of!(HEAP_MAPPING_L2) as usize - KERNEL_VIRT_OFFSET;
for i in 0..DEVICE_MAPPING_L3_COUNT {
let device_mapping_l3_phys = PhysicalAddress::from_raw(
&DEVICE_MAPPING_L3S[i] as *const _ as usize - KERNEL_VIRT_OFFSET,
);
DEVICE_MAPPING_L2[i] = PageEntry::table(device_mapping_l3_phys, PageAttributes::empty());
}
assert_eq!(KERNEL_TABLES.l2.data[EARLY_MAPPING_L2I], 0);
KERNEL_TABLES.l2.data[EARLY_MAPPING_L2I] =
(early_mapping_l3_phys as u64) | kernel_table_flags().bits();
assert_eq!(KERNEL_TABLES.l1.data[HEAP_MAPPING_L1I], 0);
KERNEL_TABLES.l1.data[HEAP_MAPPING_L1I] =
(heap_mapping_l2_phys as u64) | kernel_table_flags().bits();
assert_eq!(KERNEL_TABLES.l1.data[DEVICE_MAPPING_L1I], 0);
KERNEL_TABLES.l1.data[DEVICE_MAPPING_L1I] =
(device_mapping_l2_phys as u64) | kernel_table_flags().bits();
}

kernel/arch/aarch64/src/mem/process.rs Normal file
@@ -0,0 +1,156 @@
//! AArch64-specific process address space management
use core::{
marker::PhantomData,
sync::atomic::{AtomicU8, Ordering},
};
use libk_mm_interface::{
address::{AsPhysicalAddress, PhysicalAddress},
pointer::PhysicalRefMut,
process::ProcessAddressSpaceManager,
table::{
EntryLevel, EntryLevelDrop, EntryLevelExt, MapAttributes, NextPageTable, TableAllocator,
},
};
use yggdrasil_abi::error::Error;
use crate::{mem::table::PageEntry, KernelTableManagerImpl};
use super::{
table::{PageTable, L1, L2, L3},
tlb_flush_vaae1,
};
/// AArch64 implementation of a process address space table
#[repr(C)]
pub struct ProcessAddressSpaceImpl<TA: TableAllocator> {
l1: PhysicalRefMut<'static, PageTable<L1>, KernelTableManagerImpl>,
asid: u8,
_alloc: PhantomData<TA>,
}
impl<TA: TableAllocator> ProcessAddressSpaceManager<TA> for ProcessAddressSpaceImpl<TA> {
const LOWER_LIMIT_PFN: usize = 8;
// 16GiB VM limit
const UPPER_LIMIT_PFN: usize = (16 << 30) / L3::SIZE;
fn new() -> Result<Self, Error> {
static LAST_ASID: AtomicU8 = AtomicU8::new(1);
let asid = LAST_ASID.fetch_add(1, Ordering::AcqRel);
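        // NOTE: the u8 counter wraps after 256 address spaces, so ASIDs get
        // reused; stale TLB entries must be flushed before an ASID is recycled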
let mut l1 = unsafe {
PhysicalRefMut::<'static, PageTable<L1>, KernelTableManagerImpl>::map(
TA::allocate_page_table()?,
)
};
for i in 0..512 {
l1[i] = PageEntry::INVALID;
}
Ok(Self {
l1,
asid,
_alloc: PhantomData,
})
}
fn translate(&self, address: usize) -> Result<(PhysicalAddress, MapAttributes), Error> {
self.read_l3_entry(address).ok_or(Error::DoesNotExist)
}
unsafe fn map_page(
&mut self,
address: usize,
physical: PhysicalAddress,
flags: MapAttributes,
) -> Result<(), Error> {
self.write_l3_entry(
address,
PageEntry::normal_page(physical, flags.into()),
false,
)
}
unsafe fn unmap_page(&mut self, address: usize) -> Result<PhysicalAddress, Error> {
self.pop_l3_entry(address)
}
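    /// Returns the TTBR0_EL1 image for this address space: the L1 table's
    /// physical address with the 8-bit ASID placed into bits [63:48].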
fn as_address_with_asid(&self) -> u64 {
unsafe { u64::from(self.l1.as_physical_address()) | ((self.asid as u64) << 48) }
}
unsafe fn clear(&mut self) {
self.l1
.drop_range::<TA>(0..((Self::UPPER_LIMIT_PFN * L3::SIZE).page_index::<L1>()));
}
}
impl<TA: TableAllocator> ProcessAddressSpaceImpl<TA> {
// Write a single 4KiB entry
fn write_l3_entry(
&mut self,
virt: usize,
entry: PageEntry<L3>,
overwrite: bool,
) -> Result<(), Error> {
let l1i = virt.page_index::<L1>();
let l2i = virt.page_index::<L2>();
let l3i = virt.page_index::<L3>();
let mut l2 = self.l1.get_mut_or_alloc::<TA>(l1i)?;
let mut l3 = l2.get_mut_or_alloc::<TA>(l2i)?;
if l3[l3i].is_present() && !overwrite {
todo!();
}
l3[l3i] = entry;
tlb_flush_vaae1(virt);
Ok(())
}
fn pop_l3_entry(&mut self, virt: usize) -> Result<PhysicalAddress, Error> {
let l1i = virt.page_index::<L1>();
let l2i = virt.page_index::<L2>();
let l3i = virt.page_index::<L3>();
// TODO somehow drop tables if they're known to be empty?
let mut l2 = self.l1.get_mut(l1i).ok_or(Error::DoesNotExist)?;
let mut l3 = l2.get_mut(l2i).ok_or(Error::DoesNotExist)?;
let page = l3[l3i].as_page().ok_or(Error::DoesNotExist)?;
l3[l3i] = PageEntry::INVALID;
tlb_flush_vaae1(virt);
Ok(page)
}
fn read_l3_entry(&self, virt: usize) -> Option<(PhysicalAddress, MapAttributes)> {
let l1i = virt.page_index::<L1>();
let l2i = virt.page_index::<L2>();
let l3i = virt.page_index::<L3>();
let l2 = self.l1.get(l1i)?;
let l3 = l2.get(l2i)?;
let page = l3[l3i].as_page()?;
Some((page, l3[l3i].attributes().into()))
}
}
impl<TA: TableAllocator> Drop for ProcessAddressSpaceImpl<TA> {
fn drop(&mut self) {
// SAFETY: with safe usage of the ProcessAddressSpaceImpl, clearing and dropping
// is safe, no one refers to the memory
unsafe {
self.clear();
let l1_phys = self.l1.as_physical_address();
TA::free_page_table(l1_phys);
}
}
}

kernel/arch/aarch64/src/mem/table.rs Normal file
@@ -0,0 +1,342 @@
use core::{
marker::PhantomData,
ops::{Index, IndexMut, Range},
};
use bitflags::bitflags;
use libk_mm_interface::{
address::{AsPhysicalAddress, FromRaw, IntoRaw, PhysicalAddress},
pointer::{PhysicalRef, PhysicalRefMut},
table::{
EntryLevel, EntryLevelDrop, MapAttributes, NextPageTable, NonTerminalEntryLevel,
TableAllocator,
},
};
use yggdrasil_abi::error::Error;
use crate::KernelTableManagerImpl;
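// VMSAv8-64 stage 1 descriptor layout: bit 0 = valid, bit 1 = table (L1/L2) or
// page (L3) vs. block, AttrIndx at bits 4:2, AP[2:1] at bits 7:6, SH[1:0] at
// bits 9:8, AF at bit 10, nG at bit 11, PXN/UXN at bits 53/54.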
bitflags! {
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct PageAttributes: u64 {
const PRESENT = 1 << 0;
const TABLE = 1 << 1;
const PAGE = 1 << 1;
const BLOCK = 0 << 1;
const ACCESS = 1 << 10;
const AP_KERNEL_READWRITE = 0 << 6;
const AP_BOTH_READWRITE = 1 << 6;
const AP_KERNEL_READONLY = 2 << 6;
const AP_BOTH_READONLY = 3 << 6;
const AP_ACCESS_MASK = 3 << 6;
const SH_OUTER = 2 << 8;
const SH_INNER = 3 << 8;
const PAGE_ATTR_NORMAL = 0 << 2;
const PAGE_ATTR_DEVICE = 1 << 2;
const NON_GLOBAL = 1 << 11;
const PXN = 1 << 53;
const UXN = 1 << 54;
}
}
#[derive(Clone, Copy)]
#[repr(C, align(0x1000))]
pub struct PageTable<L: EntryLevel> {
entries: [PageEntry<L>; 512],
}
#[derive(Clone, Copy)]
pub struct PageEntry<L: EntryLevel>(u64, PhantomData<L>);
#[derive(Clone, Copy)]
pub struct L1;
#[derive(Clone, Copy)]
pub struct L2;
#[derive(Clone, Copy)]
pub struct L3;
impl NonTerminalEntryLevel for L1 {
type NextLevel = L2;
}
impl NonTerminalEntryLevel for L2 {
type NextLevel = L3;
}
impl EntryLevel for L1 {
const SHIFT: usize = 30;
}
impl EntryLevel for L2 {
const SHIFT: usize = 21;
}
impl EntryLevel for L3 {
const SHIFT: usize = 12;
}
impl<L: EntryLevel> PageTable<L> {
pub const fn zeroed() -> Self {
Self {
entries: [PageEntry::INVALID; 512],
}
}
pub fn new_zeroed<'a, TA: TableAllocator>(
) -> Result<PhysicalRefMut<'a, Self, KernelTableManagerImpl>, Error> {
let physical = TA::allocate_page_table()?;
let mut table =
unsafe { PhysicalRefMut::<'a, Self, KernelTableManagerImpl>::map(physical) };
for i in 0..512 {
table[i] = PageEntry::INVALID;
}
Ok(table)
}
}
impl<L: EntryLevel> PageEntry<L> {
pub const INVALID: Self = Self(0, PhantomData);
pub const fn is_present(self) -> bool {
self.0 & PageAttributes::PRESENT.bits() != 0
}
pub fn attributes(self) -> PageAttributes {
PageAttributes::from_bits_retain(self.0)
}
}
impl<L: NonTerminalEntryLevel + 'static> NextPageTable for PageTable<L> {
type NextLevel = PageTable<L::NextLevel>;
type TableRef = PhysicalRef<'static, PageTable<L::NextLevel>, KernelTableManagerImpl>;
type TableRefMut = PhysicalRefMut<'static, PageTable<L::NextLevel>, KernelTableManagerImpl>;
fn get(&self, index: usize) -> Option<Self::TableRef> {
self[index]
.as_table()
.map(|phys| unsafe { PhysicalRef::map(phys) })
}
fn get_mut(&mut self, index: usize) -> Option<Self::TableRefMut> {
self[index]
.as_table()
.map(|phys| unsafe { PhysicalRefMut::map(phys) })
}
fn get_mut_or_alloc<TA: TableAllocator>(
&mut self,
index: usize,
) -> Result<Self::TableRefMut, Error> {
let entry = self[index];
if let Some(table) = entry.as_table() {
Ok(unsafe { PhysicalRefMut::map(table) })
} else {
let table = PageTable::new_zeroed::<TA>()?;
self[index] = PageEntry::<L>::table(
unsafe { table.as_physical_address() },
PageAttributes::empty(),
);
Ok(table)
}
}
}
impl EntryLevelDrop for PageTable<L3> {
const FULL_RANGE: Range<usize> = 0..512;
// Do nothing
unsafe fn drop_range<TA: TableAllocator>(&mut self, _range: Range<usize>) {}
}
impl<L: NonTerminalEntryLevel + 'static> EntryLevelDrop for PageTable<L>
where
PageTable<L::NextLevel>: EntryLevelDrop,
{
const FULL_RANGE: Range<usize> = 0..512;
unsafe fn drop_range<TA: TableAllocator>(&mut self, range: Range<usize>) {
for index in range {
let entry = self[index];
if let Some(table) = entry.as_table() {
let mut table_ref: PhysicalRefMut<PageTable<L::NextLevel>, KernelTableManagerImpl> =
PhysicalRefMut::map(table);
table_ref.drop_all::<TA>();
// Drop the table
drop(table_ref);
TA::free_page_table(table);
} else if entry.is_present() {
// Memory must've been cleared beforehand, so no non-table entries must be present
panic!(
"Expected a table containing only tables, got table[{}] = {:#x?}",
index, entry.0
);
}
self[index] = PageEntry::INVALID;
}
}
}
impl<L: NonTerminalEntryLevel> PageEntry<L> {
pub fn table(phys: PhysicalAddress, attrs: PageAttributes) -> Self {
Self(
IntoRaw::<u64>::into_raw(phys)
| (PageAttributes::TABLE | PageAttributes::PRESENT | attrs).bits(),
PhantomData,
)
}
pub fn normal_block(phys: PhysicalAddress, attrs: PageAttributes) -> Self {
Self(
IntoRaw::<u64>::into_raw(phys)
| (PageAttributes::BLOCK
| PageAttributes::PRESENT
| PageAttributes::ACCESS
| PageAttributes::SH_INNER
| PageAttributes::PAGE_ATTR_NORMAL
| attrs)
.bits(),
PhantomData,
)
}
pub fn device_block(phys: PhysicalAddress) -> Self {
Self(
IntoRaw::<u64>::into_raw(phys)
| (PageAttributes::BLOCK
| PageAttributes::PRESENT
| PageAttributes::ACCESS
| PageAttributes::SH_OUTER
| PageAttributes::PAGE_ATTR_DEVICE
| PageAttributes::UXN
| PageAttributes::PXN)
.bits(),
PhantomData,
)
}
/// Returns the physical address of the table this entry refers to, or `None` if
/// the entry is not a table
pub fn as_table(self) -> Option<PhysicalAddress> {
if self.0 & PageAttributes::PRESENT.bits() != 0
&& self.0 & PageAttributes::BLOCK.bits() == 0
{
Some(PhysicalAddress::from_raw(self.0 & !0xFFF))
} else {
None
}
}
}
impl PageEntry<L3> {
pub fn normal_page(phys: PhysicalAddress, attrs: PageAttributes) -> Self {
Self(
IntoRaw::<u64>::into_raw(phys)
| (PageAttributes::PAGE
| PageAttributes::PRESENT
| PageAttributes::ACCESS
| PageAttributes::SH_INNER
| PageAttributes::PAGE_ATTR_NORMAL
| attrs)
.bits(),
PhantomData,
)
}
pub fn device_page(phys: PhysicalAddress) -> Self {
Self(
IntoRaw::<u64>::into_raw(phys)
| (PageAttributes::PAGE
| PageAttributes::PRESENT
| PageAttributes::ACCESS
| PageAttributes::SH_OUTER
| PageAttributes::PAGE_ATTR_DEVICE
| PageAttributes::UXN
| PageAttributes::PXN)
.bits(),
PhantomData,
)
}
pub fn as_page(&self) -> Option<PhysicalAddress> {
let mask = (PageAttributes::PRESENT | PageAttributes::PAGE).bits();
if self.0 & mask == mask {
Some(PhysicalAddress::from_raw(self.0 & !0xFFF))
} else {
None
}
}
}
impl<L: EntryLevel> Index<usize> for PageTable<L> {
type Output = PageEntry<L>;
fn index(&self, index: usize) -> &Self::Output {
&self.entries[index]
}
}
impl<L: EntryLevel> IndexMut<usize> for PageTable<L> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.entries[index]
}
}
impl From<MapAttributes> for PageAttributes {
fn from(value: MapAttributes) -> Self {
let mut out = PageAttributes::empty();
// TODO kernel cannot write el0 readonly pages
if value.contains(MapAttributes::USER_WRITE) {
// Read/write
out |= PageAttributes::AP_BOTH_READWRITE;
} else if value.contains(MapAttributes::USER_READ) {
// Read only
out |= PageAttributes::AP_BOTH_READONLY;
} else {
// No read/write
out |= PageAttributes::AP_KERNEL_READONLY;
}
if value.contains(MapAttributes::NON_GLOBAL) {
out |= PageAttributes::NON_GLOBAL;
}
out
}
}
impl From<PageAttributes> for MapAttributes {
fn from(value: PageAttributes) -> Self {
let mut out = MapAttributes::empty();
out |= match value.intersection(PageAttributes::AP_ACCESS_MASK) {
PageAttributes::AP_BOTH_READWRITE => {
MapAttributes::USER_WRITE | MapAttributes::USER_READ
}
PageAttributes::AP_BOTH_READONLY => MapAttributes::USER_READ,
PageAttributes::AP_KERNEL_READONLY => MapAttributes::empty(),
PageAttributes::AP_KERNEL_READWRITE => panic!("This variant cannot be constructed"),
_ => unreachable!(),
};
if value.contains(PageAttributes::NON_GLOBAL) {
out |= MapAttributes::NON_GLOBAL;
}
out
}
}

kernel/arch/hosted/Cargo.toml Normal file
@@ -0,0 +1,9 @@
[package]
name = "kernel-arch-hosted"
version = "0.1.0"
edition = "2021"
[dependencies]
kernel-arch-interface = { path = "../interface" }
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
libk-mm-interface = { path = "../../libk/libk-mm/interface" }

kernel/arch/hosted/src/lib.rs Normal file
@@ -0,0 +1,176 @@
#![feature(never_type)]
use std::{
marker::PhantomData,
sync::atomic::{AtomicBool, Ordering},
};
use kernel_arch_interface::{
cpu::IpiQueue,
mem::{
DeviceMemoryAttributes, KernelTableManager, PhysicalMemoryAllocator, RawDeviceMemoryMapping,
},
task::{Scheduler, TaskContext},
Architecture,
};
use libk_mm_interface::{
address::PhysicalAddress,
process::ProcessAddressSpaceManager,
table::{MapAttributes, TableAllocator},
};
use yggdrasil_abi::{error::Error, process::Signal};
pub struct ArchitectureImpl;
#[derive(Debug)]
pub struct KernelTableManagerImpl;
pub struct ProcessAddressSpaceImpl<TA: TableAllocator>(!, PhantomData<TA>);
pub struct TaskContextImpl<K: KernelTableManager, PA: PhysicalMemoryAllocator>(
!,
PhantomData<(K, PA)>,
);
static DUMMY_INTERRUPT_MASK: AtomicBool = AtomicBool::new(true);
impl Architecture for ArchitectureImpl {
type PerCpuData = ();
fn local_cpu() -> *mut Self::PerCpuData {
unimplemented!()
}
unsafe fn set_local_cpu(_cpu: *mut Self::PerCpuData) {
unimplemented!()
}
unsafe fn init_local_cpu<S: Scheduler + 'static>(_id: Option<u32>, _data: Self::PerCpuData) {
unimplemented!()
}
unsafe fn init_ipi_queues(_queues: Vec<IpiQueue<Self>>) {
unimplemented!()
}
fn idle_task() -> extern "C" fn(usize) -> ! {
unimplemented!()
}
fn cpu_count() -> usize {
unimplemented!()
}
fn cpu_index<S: Scheduler + 'static>() -> u32 {
unimplemented!()
}
unsafe fn set_interrupt_mask(mask: bool) -> bool {
DUMMY_INTERRUPT_MASK.swap(mask, Ordering::Acquire)
}
fn interrupt_mask() -> bool {
unimplemented!()
}
fn wait_for_interrupt() {
unimplemented!()
}
}
impl KernelTableManager for KernelTableManagerImpl {
fn virtualize(_phys: u64) -> usize {
unimplemented!()
}
fn physicalize(_virt: usize) -> u64 {
unimplemented!()
}
unsafe fn map_device_pages(
_base: u64,
_count: usize,
_attrs: DeviceMemoryAttributes,
) -> Result<RawDeviceMemoryMapping<Self>, Error> {
unimplemented!()
}
unsafe fn unmap_device_pages(_mapping: &RawDeviceMemoryMapping<Self>) {
unimplemented!()
}
}
impl<TA: TableAllocator> ProcessAddressSpaceManager<TA> for ProcessAddressSpaceImpl<TA> {
const LOWER_LIMIT_PFN: usize = 16;
const UPPER_LIMIT_PFN: usize = 1024;
fn new() -> Result<Self, Error> {
unimplemented!()
}
unsafe fn clear(&mut self) {
unimplemented!()
}
unsafe fn map_page(
&mut self,
_address: usize,
_physical: PhysicalAddress,
_flags: MapAttributes,
) -> Result<(), Error> {
unimplemented!()
}
unsafe fn unmap_page(&mut self, _address: usize) -> Result<PhysicalAddress, Error> {
unimplemented!()
}
fn translate(&self, _address: usize) -> Result<(PhysicalAddress, MapAttributes), Error> {
unimplemented!()
}
fn as_address_with_asid(&self) -> u64 {
unimplemented!()
}
}
impl<K: KernelTableManager, PA: PhysicalMemoryAllocator> TaskContext<K, PA>
for TaskContextImpl<K, PA>
{
const USER_STACK_EXTRA_ALIGN: usize = 0;
const SIGNAL_STACK_EXTRA_ALIGN: usize = 0;
unsafe fn enter(&self) -> ! {
unimplemented!()
}
unsafe fn switch(&self, _from: &Self) {
unimplemented!()
}
unsafe fn switch_and_drop(&self, _thread: *const ()) {
unimplemented!()
}
fn user(
_entry: usize,
_arg: usize,
_cr3: u64,
_user_stack_sp: usize,
_tls_address: usize,
) -> Result<Self, Error> {
unimplemented!()
}
fn kernel(_entry: extern "C" fn(usize) -> !, _arg: usize) -> Result<Self, Error> {
unimplemented!()
}
fn kernel_closure<F: FnOnce() -> ! + Send + 'static>(_f: F) -> Result<Self, Error> {
unimplemented!()
}
}
#[no_mangle]
extern "Rust" fn __signal_process_group(_group_id: u32, _signal: Signal) {
unimplemented!()
}

kernel/arch/interface/Cargo.toml Normal file
@@ -0,0 +1,10 @@
[package]
name = "kernel-arch-interface"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
device-api = { path = "../../lib/device-api", features = ["derive"] }

kernel/arch/interface/src/cpu.rs Normal file
@@ -0,0 +1,151 @@
use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
use alloc::vec::Vec;
use device_api::interrupt::IpiMessage;
use crate::{
guard::IrqGuard, sync::IrqSafeSpinlock, task::Scheduler, util::OneTimeInit, Architecture,
};
#[repr(C, align(0x10))]
pub struct CpuImpl<A: Architecture, S: Scheduler + 'static> {
inner: A::PerCpuData,
scheduler: OneTimeInit<&'static S>,
id: u32,
current_thread_id: Option<S::ThreadId>,
_pd: PhantomData<A>,
}
pub struct LocalCpuImpl<'a, A: Architecture, S: Scheduler + 'static> {
cpu: &'a mut CpuImpl<A, S>,
guard: IrqGuard<A>,
}
pub struct IpiQueue<A: Architecture> {
data: IrqSafeSpinlock<A, Option<IpiMessage>>,
}
impl<A: Architecture, S: Scheduler + 'static> CpuImpl<A, S> {
pub fn new(id: u32, inner: A::PerCpuData) -> Self {
Self {
inner,
scheduler: OneTimeInit::new(),
id,
current_thread_id: None,
_pd: PhantomData,
}
}
pub fn init_ipi_queues(cpu_count: usize) {
let queues = Vec::from_iter((0..cpu_count).map(|_| IpiQueue::new()));
unsafe { A::init_ipi_queues(queues) }
}
pub fn set_current_thread_id(&mut self, id: Option<S::ThreadId>) {
self.current_thread_id = id;
}
pub fn current_thread_id(&self) -> Option<S::ThreadId> {
self.current_thread_id
}
pub fn set_scheduler(&mut self, sched: &'static S) {
self.scheduler.init(sched);
}
pub fn try_get_scheduler(&self) -> Option<&'static S> {
self.scheduler.try_get().copied()
}
pub fn scheduler(&self) -> &'static S {
self.scheduler.get()
}
pub unsafe fn set_local(&'static mut self) {
A::set_local_cpu(self as *mut _ as *mut _)
}
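    /// Returns a handle to the local CPU's data, masking IRQs for the handle's
    /// lifetime so the access cannot be interrupted by a handler on this CPU.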
pub fn try_local<'a>() -> Option<LocalCpuImpl<'a, A, S>> {
let guard = IrqGuard::acquire();
let cpu = A::local_cpu() as *mut Self;
unsafe { cpu.as_mut().map(|cpu| LocalCpuImpl { cpu, guard }) }
}
pub fn local<'a>() -> LocalCpuImpl<'a, A, S> {
Self::try_local().expect("Local CPU not initialized")
}
pub fn id(&self) -> u32 {
self.id
}
pub fn push_ipi_queue(_cpu_id: u32, _msg: IpiMessage) {
// XXX
todo!()
}
pub fn get_ipi(&self) -> Option<IpiMessage> {
// XXX
todo!()
}
}
impl<A: Architecture, S: Scheduler> Deref for CpuImpl<A, S> {
type Target = A::PerCpuData;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<A: Architecture, S: Scheduler> DerefMut for CpuImpl<A, S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<'a, A: Architecture, S: Scheduler + 'static> LocalCpuImpl<'a, A, S> {
pub fn into_guard(self) -> IrqGuard<A> {
self.guard
}
}
impl<'a, A: Architecture, S: Scheduler> Deref for LocalCpuImpl<'a, A, S> {
type Target = CpuImpl<A, S>;
fn deref(&self) -> &Self::Target {
self.cpu
}
}
impl<'a, A: Architecture, S: Scheduler> DerefMut for LocalCpuImpl<'a, A, S> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.cpu
}
}
impl<A: Architecture> IpiQueue<A> {
pub const fn new() -> Self {
Self {
data: IrqSafeSpinlock::<A, _>::new(None),
}
}
pub fn push(&self, msg: IpiMessage) {
let mut lock = self.data.lock();
assert!(lock.is_none());
lock.replace(msg);
}
pub fn pop(&self) -> Option<IpiMessage> {
let mut lock = self.data.lock();
lock.take()
}
}

kernel/arch/interface/src/guard.rs Normal file
@@ -0,0 +1,24 @@
use core::marker::PhantomData;
use crate::Architecture;
/// Token type used to prevent IRQs from firing during some critical section. Normal IRQ operation
/// (if enabled before) is resumed when [IrqGuard]'s lifetime is over.
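///
/// A minimal usage sketch (with `A` standing in for a concrete [Architecture] impl):
/// ```ignore
/// let guard = IrqGuard::<A>::acquire(); // IRQs masked, previous state saved
/// // ... critical section ...
/// drop(guard); // previous IRQ state restored
/// ```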
pub struct IrqGuard<A: Architecture>(bool, PhantomData<A>);
// IrqGuard impls
impl<A: Architecture> IrqGuard<A> {
/// Saves the current IRQ state and masks them
pub fn acquire() -> Self {
let mask = unsafe { A::set_interrupt_mask(true) };
Self(mask, PhantomData)
}
}
impl<A: Architecture> Drop for IrqGuard<A> {
fn drop(&mut self) {
unsafe {
A::set_interrupt_mask(self.0);
}
}
}

kernel/arch/interface/src/lib.rs Normal file
@@ -0,0 +1,47 @@
#![no_std]
#![feature(step_trait, effects, const_trait_impl, never_type)]
use alloc::vec::Vec;
use cpu::IpiQueue;
use device_api::interrupt::{LocalInterruptController, MessageInterruptController};
use task::Scheduler;
extern crate alloc;
pub mod cpu;
pub mod guard;
pub mod mem;
pub mod sync;
pub mod task;
pub mod util;
pub const KERNEL_VIRT_OFFSET: usize = 0xFFFFFF8000000000;
pub trait Architecture: Sized {
type PerCpuData;
// Cpu management
unsafe fn set_local_cpu(cpu: *mut ());
fn local_cpu() -> *mut ();
unsafe fn init_ipi_queues(queues: Vec<IpiQueue<Self>>);
unsafe fn init_local_cpu<S: Scheduler + 'static>(id: Option<u32>, data: Self::PerCpuData);
fn idle_task() -> extern "C" fn(usize) -> !;
fn cpu_count() -> usize;
fn cpu_index<S: Scheduler + 'static>() -> u32;
// Interrupt management
fn interrupt_mask() -> bool;
unsafe fn set_interrupt_mask(mask: bool) -> bool;
fn wait_for_interrupt();
// Architectural devices
fn local_interrupt_controller() -> &'static dyn LocalInterruptController {
unimplemented!()
}
fn message_interrupt_controller() -> &'static dyn MessageInterruptController {
unimplemented!()
}
}

kernel/arch/interface/src/mem/mod.rs Normal file
@@ -0,0 +1,120 @@
use core::{fmt, marker::PhantomData, mem::size_of, ptr::NonNull};
use yggdrasil_abi::error::Error;
pub mod address;
pub mod table;
pub trait PhysicalMemoryAllocator {
type Address;
fn allocate_page() -> Result<Self::Address, Error>;
fn allocate_contiguous_pages(count: usize) -> Result<Self::Address, Error>;
unsafe fn free_page(page: Self::Address);
}
#[derive(Debug, Default, Clone, Copy)]
pub enum DeviceMemoryCaching {
#[default]
None,
Cacheable,
}
#[derive(Default, Debug, Clone, Copy)]
pub struct DeviceMemoryAttributes {
pub caching: DeviceMemoryCaching,
}
/// Describes a single device memory mapping
#[derive(Debug)]
pub struct RawDeviceMemoryMapping<A: KernelTableManager> {
/// Virtual address of the mapped object
pub address: usize,
/// Base address of the mapping start
pub base_address: usize,
/// Page size used for the mapping
pub page_size: usize,
/// Number of pages used to map the object
pub page_count: usize,
_manager: PhantomData<A>,
}
pub trait KernelTableManager: Sized + fmt::Debug {
fn virtualize(phys: u64) -> usize;
fn physicalize(virt: usize) -> u64;
unsafe fn map_device_pages(
base: u64,
count: usize,
attrs: DeviceMemoryAttributes,
) -> Result<RawDeviceMemoryMapping<Self>, Error>;
unsafe fn unmap_device_pages(mapping: &RawDeviceMemoryMapping<Self>);
}
impl<A: KernelTableManager> RawDeviceMemoryMapping<A> {
/// Maps a region of physical memory as device memory of given size.
///
/// # Safety
///
/// The caller must ensure proper access synchronization, as well as the address' origin.
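    ///
    /// A rough sketch (the physical address and size here are made up):
    /// ```ignore
    /// let mapping = unsafe {
    ///     RawDeviceMemoryMapping::<A>::map(0x0900_0000, 0x1000, DeviceMemoryAttributes::default())?
    /// };
    /// let regs: NonNull<u32> = unsafe { mapping.as_non_null() };
    /// ```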
#[inline]
pub unsafe fn map(
base: u64,
size: usize,
attrs: DeviceMemoryAttributes,
) -> Result<Self, Error> {
A::map_device_pages(base, size, attrs)
}
/// Consumes the device mapping, leaking its address without deallocating the translation
/// mapping itself
pub fn leak(self) -> usize {
let address = self.address;
core::mem::forget(self);
address
}
pub fn into_raw_parts(self) -> (usize, usize, usize, usize) {
let address = self.address;
let base_address = self.base_address;
let page_count = self.page_count;
let page_size = self.page_size;
core::mem::forget(self);
(address, base_address, page_count, page_size)
}
pub unsafe fn from_raw_parts(
address: usize,
base_address: usize,
page_count: usize,
page_size: usize,
) -> Self {
Self {
address,
base_address,
page_count,
page_size,
_manager: PhantomData,
}
}
/// "Casts" the mapping to a specific type T and returns a [NonNull] pointer to it
pub unsafe fn as_non_null<T>(&self) -> NonNull<T> {
if self.page_size * self.page_count < size_of::<T>() {
panic!();
}
NonNull::new_unchecked(self.address as *mut T)
}
}
impl<A: KernelTableManager> Drop for RawDeviceMemoryMapping<A> {
fn drop(&mut self) {
unsafe {
A::unmap_device_pages(self);
}
}
}

kernel/arch/interface/src/sync.rs Normal file
@@ -0,0 +1,155 @@
use core::{
cell::UnsafeCell,
marker::PhantomData,
mem,
ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, Ordering},
};
use crate::{guard::IrqGuard, Architecture};
struct SpinlockInner<A: Architecture, T> {
value: UnsafeCell<T>,
state: AtomicBool,
_pd: PhantomData<A>,
}
struct SpinlockInnerGuard<'a, A: Architecture, T> {
lock: &'a SpinlockInner<A, T>,
}
/// Spinlock implementation which prevents interrupts to avoid deadlocks when an interrupt handler
/// tries to acquire a lock taken before the IRQ fired.
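///
/// A minimal usage sketch (with `A` standing in for a concrete [Architecture] impl):
/// ```ignore
/// let lock = IrqSafeSpinlock::<A, u32>::new(0);
/// let mut value = lock.lock(); // IRQs masked while the guard lives
/// *value += 1;
/// drop(value); // inner lock released first, then IRQ state restored
/// ```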
pub struct IrqSafeSpinlock<A: Architecture, T> {
inner: SpinlockInner<A, T>,
}
/// Token type allowing safe access to the underlying data of the [IrqSafeSpinlock]. Resumes normal
/// IRQ operation (if enabled before acquiring) when the lifetime is over.
pub struct IrqSafeSpinlockGuard<'a, A: Architecture, T> {
// Must come first to ensure the lock is dropped first and only then IRQs are re-enabled
inner: SpinlockInnerGuard<'a, A, T>,
_irq: IrqGuard<A>,
}
// Spinlock impls
impl<A: Architecture, T> SpinlockInner<A, T> {
const fn new(value: T) -> Self {
Self {
value: UnsafeCell::new(value),
state: AtomicBool::new(false),
_pd: PhantomData,
}
}
fn lock(&self) -> SpinlockInnerGuard<A, T> {
// Loop until the lock can be acquired
// if LOCK_HACK.load(Ordering::Acquire) {
// return SpinlockInnerGuard { lock: self };
// }
while self
.state
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
core::hint::spin_loop();
}
SpinlockInnerGuard { lock: self }
}
}
impl<'a, A: Architecture, T> Deref for SpinlockInnerGuard<'a, A, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.value.get() }
}
}
impl<'a, A: Architecture, T> DerefMut for SpinlockInnerGuard<'a, A, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.lock.value.get() }
}
}
impl<'a, A: Architecture, T> Drop for SpinlockInnerGuard<'a, A, T> {
fn drop(&mut self) {
// if !LOCK_HACK.load(Ordering::Acquire) {
self.lock
.state
.compare_exchange(true, false, Ordering::Release, Ordering::Relaxed)
.unwrap();
// }
}
}
unsafe impl<A: Architecture, T> Sync for SpinlockInner<A, T> {}
unsafe impl<A: Architecture, T> Send for SpinlockInner<A, T> {}
// IrqSafeSpinlock impls
impl<A: Architecture, T> IrqSafeSpinlock<A, T> {
/// Wraps the value in a spinlock primitive
pub const fn new(value: T) -> Self {
Self {
inner: SpinlockInner::new(value),
}
}
#[inline]
pub fn replace(&self, value: T) -> T {
let mut lock = self.lock();
mem::replace(&mut lock, value)
}
/// Attempts to acquire a lock. IRQs will be disabled until the lock is released.
pub fn lock(&self) -> IrqSafeSpinlockGuard<A, T> {
// Disable IRQs to avoid IRQ handler trying to acquire the same lock
let irq_guard = IrqGuard::acquire();
// Acquire the inner lock
let inner = self.inner.lock();
IrqSafeSpinlockGuard {
inner,
_irq: irq_guard,
}
}
/// Returns an unsafe reference to the inner value.
///
/// # Safety
///
/// Unsafe: explicitly ignores proper access sharing.
#[allow(clippy::mut_from_ref)]
pub unsafe fn grab(&self) -> &mut T {
unsafe { &mut *self.inner.value.get() }
}
}
impl<A: Architecture, T: Clone> IrqSafeSpinlock<A, T> {
pub fn get_cloned(&self) -> T {
self.lock().clone()
}
}
impl<A: Architecture, T: Clone> Clone for IrqSafeSpinlock<A, T> {
fn clone(&self) -> Self {
let inner = self.lock();
IrqSafeSpinlock::new(inner.clone())
}
}
impl<'a, A: Architecture, T> Deref for IrqSafeSpinlockGuard<'a, A, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.inner.deref()
}
}
impl<'a, A: Architecture, T> DerefMut for IrqSafeSpinlockGuard<'a, A, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner.deref_mut()
}
}

kernel/arch/interface/src/task.rs Normal file
@@ -0,0 +1,183 @@
use core::fmt;
use alloc::boxed::Box;
use yggdrasil_abi::{arch::SavedFrame, error::Error, process::ExitCode};
use crate::mem::{KernelTableManager, PhysicalMemoryAllocator};
/// Interface to a per-CPU run queue/scheduler implementation
pub trait Scheduler {
/// Handle by which the scheduler refers to threads in its queue
type ThreadId: Copy;
/// Returns the scheduler instance of the CPU with the given index
fn for_cpu(index: usize) -> &'static Self;
/// Returns a scheduler instance for the given CPU affinity mask
fn for_affinity_mask(mask: u64) -> &'static Self;
/// Returns the scheduler of the current CPU
fn local() -> &'static Self;
/// Returns `true` if this scheduler belongs to the current CPU
fn is_local(&self) -> bool;
/// Adds a thread to this scheduler's run queue
fn push(&self, task: Self::ThreadId);
/// Selects a new thread from the queue and performs a context switch if necessary.
///
/// # Safety
///
/// Only meant to be called from within the timer handler or the thread impl.
unsafe fn yield_cpu(&self) -> bool;
}
/// Conversion trait to allow multiple kernel closure return types
pub trait Termination {
/// Converts the closure return type into [ExitCode]
fn into_exit_code(self) -> ExitCode;
}
/// Interface for task state save/restore mechanisms
pub trait TaskFrame {
/// Creates a "snapshot" of an exception/syscall frame
fn store(&self) -> SavedFrame;
/// Restores the exception/syscall frame from its saved state
fn restore(&mut self, saved: &SavedFrame);
/// Replaces the return value in the frame (or does nothing if the frame is not part of a
/// syscall/signal handler)
fn set_return_value(&mut self, value: u64);
/// Replaces the userspace stack pointer in the frame
fn set_user_sp(&mut self, value: usize);
/// Replaces the userspace instruction pointer in the frame
fn set_user_ip(&mut self, value: usize);
/// Replaces the argument in the frame
fn set_argument(&mut self, value: u64);
/// Returns the argument (if any) of the frame being processed
fn argument(&self) -> u64;
/// Returns the userspace stack pointer
fn user_sp(&self) -> usize;
/// Returns the userspace instruction pointer
fn user_ip(&self) -> usize;
}
/// Interface for performing context fork operations
pub trait ForkFrame<K: KernelTableManager, PA: PhysicalMemoryAllocator>: Sized {
type Context: TaskContext<K, PA>;
/// Constructs a "forked" task context by copying the registers from this one and supplying a
/// new address space to it.
///
/// # Safety
///
/// Unsafe: accepts raw frames and address space address.
unsafe fn fork(&self, address_space: u64) -> Result<Self::Context, Error>;
/// Replaces the return value inside the frame with a new one
fn set_return_value(&mut self, value: u64);
}
/// Platform-specific task context implementation
pub trait TaskContext<K: KernelTableManager, PA: PhysicalMemoryAllocator>: Sized {
/// Number of bytes to offset the signal stack pointer by
const SIGNAL_STACK_EXTRA_ALIGN: usize;
/// Number of bytes to offset the user stack pointer by
const USER_STACK_EXTRA_ALIGN: usize;
/// Constructs a kernel-space task context
fn kernel(entry: extern "C" fn(usize) -> !, arg: usize) -> Result<Self, Error>;
/// Constructs a user thread context. The caller is responsible for allocating the userspace
/// stack and setting up a valid address space for the context.
fn user(
entry: usize,
arg: usize,
cr3: u64,
user_stack_sp: usize,
tls_address: usize,
) -> Result<Self, Error>;
/// Performs an entry into a context.
///
/// # Safety
///
/// Only meant to be called from the scheduler code.
unsafe fn enter(&self) -> !;
/// Performs a context switch between two contexts.
///
/// # Safety
///
/// Only meant to be called from the scheduler code.
unsafe fn switch(&self, from: &Self);
/// Performs a context switch and drops the source thread.
///
/// # Safety
///
/// Only meant to be called from the scheduler code after the `thread` has terminated.
unsafe fn switch_and_drop(&self, thread: *const ());
// XXX
/// Constructs a task context which safely wraps and executes a kernel-space closure
fn kernel_closure<F: FnOnce() -> ! + Send + 'static>(f: F) -> Result<Self, Error> {
extern "C" fn closure_wrapper<F: FnOnce() -> ! + Send + 'static>(closure_addr: usize) -> ! {
let closure = unsafe { Box::from_raw(closure_addr as *mut F) };
closure()
}
let closure = Box::new(f);
Self::kernel(closure_wrapper::<F>, Box::into_raw(closure) as usize)
}
}
/// Helper for laying out initial task stack images; the stack grows downward from `base + size`
pub struct StackBuilder {
base: usize,
sp: usize,
}
impl StackBuilder {
/// Creates a builder for a stack occupying `base..base + size`
pub fn new(base: usize, size: usize) -> Self {
Self {
base,
sp: base + size,
}
}
/// Pushes a value onto the stack, panicking if the stack is full
pub fn push(&mut self, value: usize) {
if self.sp == self.base {
panic!("StackBuilder: stack overflow");
}
self.sp -= 8;
unsafe {
(self.sp as *mut usize).write_volatile(value);
}
}
/// Finishes the layout and returns the resulting stack pointer
pub fn build(self) -> usize {
self.sp
}
}
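A sketch of how the builder lays out a downward-growing stack image; the buffer and values are hypothetical, chosen so the raw writes land in valid memory:

    fn layout_sketch() {
        let mut buf = [0usize; 4];
        let base = buf.as_mut_ptr() as usize;
        let mut stack = StackBuilder::new(base, 4 * 8);
        stack.push(0x1111); // written at base + 24 (e.g. an entry point)
        stack.push(0x2222); // written at base + 16 (e.g. its argument)
        assert_eq!(stack.build(), base + 2 * 8);
    }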
impl<T, E: fmt::Debug> Termination for Result<T, E> {
fn into_exit_code(self) -> ExitCode {
match self {
Ok(_) => ExitCode::SUCCESS,
Err(_err) => {
// XXX
// log::warn!("Kernel thread failed: {:?}", err);
ExitCode::Exited(1)
}
}
}
}
impl Termination for ExitCode {
fn into_exit_code(self) -> ExitCode {
self
}
}
impl Termination for () {
fn into_exit_code(self) -> ExitCode {
ExitCode::SUCCESS
}
}

@@ -0,0 +1,125 @@
use core::{
cell::UnsafeCell,
mem::MaybeUninit,
panic,
sync::atomic::{AtomicUsize, Ordering},
};
/// Wrapper struct to ensure a value can only be initialized once and used only after that
#[repr(C)]
pub struct OneTimeInit<T> {
value: UnsafeCell<MaybeUninit<T>>,
state: AtomicUsize,
}
unsafe impl<T> Sync for OneTimeInit<T> {}
unsafe impl<T> Send for OneTimeInit<T> {}
impl<T> OneTimeInit<T> {
const STATE_UNINITIALIZED: usize = 0;
const STATE_INITIALIZING: usize = 1;
const STATE_INITIALIZED: usize = 2;
/// Wraps the value in a [OneTimeInit]
pub const fn new() -> Self {
Self {
value: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicUsize::new(Self::STATE_UNINITIALIZED),
}
}
/// Returns `true` if the value has already been initialized
#[inline]
pub fn is_initialized(&self) -> bool {
self.state.load(Ordering::Acquire) == Self::STATE_INITIALIZED
}
/// Initializes the value by calling `f`, returning a reference to it, or [None] if the
/// value has already been (or is concurrently being) initialized
pub fn try_init_with<F: FnOnce() -> T>(&self, f: F) -> Option<&T> {
if self
.state
.compare_exchange(
Self::STATE_UNINITIALIZED,
Self::STATE_INITIALIZING,
Ordering::Release,
Ordering::Relaxed,
)
.is_err()
{
// Already initialized
return None;
}
let value = unsafe { (*self.value.get()).write(f()) };
self.state
.compare_exchange(
Self::STATE_INITIALIZING,
Self::STATE_INITIALIZED,
Ordering::Release,
Ordering::Relaxed,
)
.unwrap();
Some(value)
}
/// Sets the underlying value of the [OneTimeInit]. If already initialized, panics.
#[track_caller]
pub fn init(&self, value: T) -> &T {
// Transition to "initializing" state
if self
.state
.compare_exchange(
Self::STATE_UNINITIALIZED,
Self::STATE_INITIALIZING,
Ordering::Release,
Ordering::Relaxed,
)
.is_err()
{
panic!(
"{:?}: Double initialization of OneTimeInit<T>",
panic::Location::caller()
);
}
let value = unsafe { (*self.value.get()).write(value) };
// Transition to "initialized" state. This must not fail
self.state
.compare_exchange(
Self::STATE_INITIALIZING,
Self::STATE_INITIALIZED,
Ordering::Release,
Ordering::Relaxed,
)
.unwrap();
value
}
/// Returns an immutable reference to the underlying value and panics if it hasn't yet been
/// initialized
#[track_caller]
pub fn get(&self) -> &T {
// TODO check for INITIALIZING state and wait until it becomes INITIALIZED?
if !self.is_initialized() {
panic!(
"{:?}: Attempt to dereference an uninitialized value",
panic::Location::caller()
);
}
unsafe { (*self.value.get()).assume_init_ref() }
}
/// Returns an immutable reference to the underlying value and [None] if the value hasn't yet
/// been initialized
pub fn try_get(&self) -> Option<&T> {
if self.is_initialized() {
Some(self.get())
} else {
None
}
}
}
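Typical usage, sketched with an arbitrary `u32` payload:

    static VALUE: OneTimeInit<u32> = OneTimeInit::new();

    fn setup() {
        assert!(VALUE.try_get().is_none());
        VALUE.init(42); // a second init() here would panic
        assert_eq!(*VALUE.get(), 42);
        // Unlike init(), losing the race is reported, not fatal:
        assert!(VALUE.try_init_with(|| 123).is_none());
    }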

kernel/arch/src/lib.rs
@@ -0,0 +1,39 @@
#![no_std]
use cfg_if::cfg_if;
/// Returns an absolute address to the given symbol
#[macro_export]
macro_rules! absolute_address {
($sym:expr) => {{
let mut _x: usize;
#[cfg(target_arch = "aarch64")]
unsafe {
core::arch::asm!("ldr {0}, ={1}", out(reg) _x, sym $sym);
}
#[cfg(target_arch = "x86_64")]
unsafe {
core::arch::asm!("movabsq ${1}, {0}", out(reg) _x, sym $sym, options(att_syntax));
}
_x
}};
}
cfg_if! {
if #[cfg(any(test, not(target_os = "none")))] {
extern crate kernel_arch_hosted as imp;
} else if #[cfg(target_arch = "aarch64")] {
extern crate kernel_arch_aarch64 as imp;
} else if #[cfg(target_arch = "x86_64")] {
extern crate kernel_arch_x86_64 as imp;
} else {
compile_error!("Unsupported architecture");
}
}
pub use imp::{ArchitectureImpl, KernelTableManagerImpl, ProcessAddressSpaceImpl, TaskContextImpl};
pub use kernel_arch_interface::{guard, mem, sync, task, util, Architecture};
pub type CpuImpl<S> = kernel_arch_interface::cpu::CpuImpl<ArchitectureImpl, S>;
pub type LocalCpuImpl<'a, S> = kernel_arch_interface::cpu::LocalCpuImpl<'a, ArchitectureImpl, S>;
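An illustrative use of the `absolute_address!` macro defined above; `__kernel_end` is a hypothetical linker-provided symbol. The `ldr =`/`movabsq` forms yield the symbol's link-time (virtual) address regardless of which mapping the code currently runs under:

    extern "C" {
        static __kernel_end: u8;
    }

    fn kernel_end_address() -> usize {
        absolute_address!(__kernel_end)
    }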

@@ -0,0 +1,17 @@
[package]
name = "kernel-arch-x86_64"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
kernel-arch-interface = { path = "../interface" }
libk-mm-interface = { path = "../../libk/libk-mm/interface" }
memtables = { path = "../../lib/memtables" }
device-api = { path = "../../lib/device-api", features = ["derive"] }
bitflags = "2.3.3"
static_assertions = "1.1.0"
tock-registers = "0.8.1"

@@ -0,0 +1,186 @@
// vi: set ft=asm :
.set MSR_IA32_FS_BASE, 0xC0000100
.macro SAVE_TASK_STATE
sub ${context_size}, %rsp
mov %rbx, 0(%rsp)
mov %r12, 8(%rsp)
mov %r13, 16(%rsp)
mov %r14, 24(%rsp)
mov %r15, 32(%rsp)
// Store FS_BASE
mov $MSR_IA32_FS_BASE, %ecx
rdmsr
mov %edx, %ecx
shl $32, %rcx
or %rax, %rcx
mov %rcx, 40(%rsp)
// TODO save %fs
mov %rbp, 48(%rsp)
mov %cr3, %rbx
mov %rbx, 56(%rsp)
.endm
.macro LOAD_TASK_STATE
mov 56(%rsp), %rbx
mov %rbx, %cr3
mov 0(%rsp), %rbx
mov 8(%rsp), %r12
mov 16(%rsp), %r13
mov 24(%rsp), %r14
mov 32(%rsp), %r15
// Load FS_BASE
// edx:eax = fs_base
mov 40(%rsp), %rdx
mov %edx, %eax
shr $32, %rdx
mov $MSR_IA32_FS_BASE, %ecx
wrmsr
// mov 40(%rsp), %fs
mov 48(%rsp), %rbp
add ${context_size}, %rsp
.endm
.global __x86_64_task_enter_user
.global __x86_64_task_enter_kernel
.global __x86_64_task_enter_from_fork
.global __x86_64_enter_task
.global __x86_64_switch_task
.global __x86_64_switch_and_drop
.section .text
__x86_64_task_enter_from_fork:
xorq %rax, %rax
xorq %rcx, %rcx
xorq %r11, %r11
popq %rdi
popq %rsi
popq %rdx
popq %r10
popq %r8
popq %r9
swapgs
iretq
__x86_64_task_enter_user:
// User stack pointer
popq %rcx
// Argument
popq %rdi
// Entry address
popq %rax
// SS:RSP
pushq $0x1B
pushq %rcx
// RFLAGS
pushq $0x200
// CS:RIP
pushq $0x23
pushq %rax
swapgs
iretq
__x86_64_task_enter_kernel:
// Argument
popq %rdi
// Entry address
popq %rax
// Alignment word + fake return address to terminate "call chain"
pushq $0
// Enable IRQ in RFLAGS
pushfq
popq %rdx
or $(1 << 9), %rdx
mov %rsp, %rcx
// SS:RSP
pushq $0x10
pushq %rcx
// RFLAGS
pushq %rdx
// CS:RIP
pushq $0x08
pushq %rax
iretq
// %rsi - from struct ptr, %rdi - to struct ptr
__x86_64_switch_task:
SAVE_TASK_STATE
mov %rsp, 0(%rsi)
// TSS.RSP0
mov 8(%rdi), %rax
// Kernel stack
mov 0(%rdi), %rdi
mov %rdi, %rsp
// Update TSS.RSP0
mov %gs:(8), %rdi
mov %rax, 4(%rdi)
LOAD_TASK_STATE
ret
__x86_64_switch_and_drop:
// TSS.RSP0
mov 8(%rdi), %rax
// Kernel stack
mov 0(%rdi), %rdi
mov %rdi, %rsp
// Update TSS.RSP0
mov %gs:(8), %rdi
mov %rax, 4(%rdi)
mov %rsi, %rdi
call __arch_drop_thread
LOAD_TASK_STATE
ret
// %rdi - to struct ptr
__x86_64_enter_task:
// TSS.RSP0
mov 8(%rdi), %rax
// Kernel stack
mov 0(%rdi), %rdi
mov %rdi, %rsp
// Update TSS.RSP0
mov %gs:(8), %rdi
mov %rax, 4(%rdi)
LOAD_TASK_STATE
ret
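For reference, the 64-byte region that SAVE_TASK_STATE carves out below %rsp could be described by the following (hypothetical) Rust mirror, matching COMMON_CONTEXT_SIZE on the Rust side:

    #[repr(C)]
    struct SavedKernelState {
        rbx: u64,     // 0(%rsp)
        r12: u64,     // 8(%rsp)
        r13: u64,     // 16(%rsp)
        r14: u64,     // 24(%rsp)
        r15: u64,     // 32(%rsp)
        fs_base: u64, // 40(%rsp), IA32_FS_BASE
        rbp: u64,     // 48(%rsp)
        cr3: u64,     // 56(%rsp)
    }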

@@ -0,0 +1,525 @@
use core::{arch::global_asm, cell::UnsafeCell, marker::PhantomData};
use kernel_arch_interface::{
mem::{KernelTableManager, PhysicalMemoryAllocator},
task::{ForkFrame, StackBuilder, TaskContext, TaskFrame},
};
use libk_mm_interface::address::{AsPhysicalAddress, IntoRaw, PhysicalAddress};
use yggdrasil_abi::{arch::SavedFrame, error::Error};
use crate::{mem::KERNEL_TABLES, registers::FpuContext};
/// Frame saved onto the stack when taking an IRQ
#[derive(Debug)]
#[repr(C)]
pub struct IrqFrame {
pub rax: u64,
pub rcx: u64,
pub rdx: u64,
pub rbx: u64,
pub rsi: u64,
pub rdi: u64,
pub rbp: u64,
pub r8: u64,
pub r9: u64,
pub r10: u64,
pub r11: u64,
pub r12: u64,
pub r13: u64,
pub r14: u64,
pub r15: u64,
pub rip: u64,
pub cs: u64,
pub rflags: u64,
pub rsp: u64,
pub ss: u64,
}
/// Set of registers saved when taking an exception/interrupt
#[derive(Debug)]
#[repr(C)]
pub struct ExceptionFrame {
pub rax: u64,
pub rcx: u64,
pub rdx: u64,
pub rbx: u64,
pub rsi: u64,
pub rdi: u64,
pub rbp: u64,
pub r8: u64,
pub r9: u64,
pub r10: u64,
pub r11: u64,
pub r12: u64,
pub r13: u64,
pub r14: u64,
pub r15: u64,
pub exc_number: u64,
pub exc_code: u64,
pub rip: u64,
pub cs: u64,
pub rflags: u64,
pub rsp: u64,
pub ss: u64,
}
/// Set of registers saved when taking a syscall instruction
#[derive(Debug)]
#[repr(C)]
pub struct SyscallFrame {
pub rax: u64,
pub args: [u64; 6],
pub rcx: u64,
pub r11: u64,
pub user_ip: u64,
pub user_sp: u64,
pub user_flags: u64,
pub rbx: u64,
pub rbp: u64,
pub r12: u64,
pub r13: u64,
pub r14: u64,
pub r15: u64,
}
#[repr(C, align(0x10))]
struct Inner {
// 0x00
sp: usize,
// 0x08
tss_rsp0: usize,
}
/// x86-64 implementation of a task context
#[allow(dead_code)]
pub struct TaskContextImpl<
K: KernelTableManager,
PA: PhysicalMemoryAllocator<Address = PhysicalAddress>,
> {
inner: UnsafeCell<Inner>,
fpu_context: UnsafeCell<FpuContext>,
stack_base_phys: PhysicalAddress,
stack_size: usize,
_alloc: PhantomData<PA>,
_table_manager: PhantomData<K>,
}
// 8 registers + return address (which is not included)
const COMMON_CONTEXT_SIZE: usize = 8 * 8;
impl TaskFrame for IrqFrame {
fn store(&self) -> SavedFrame {
SavedFrame {
rax: self.rax,
rcx: self.rcx,
rdx: self.rdx,
rbx: self.rbx,
rsi: self.rsi,
rdi: self.rdi,
rbp: self.rbp,
r8: self.r8,
r9: self.r9,
r10: self.r10,
r11: self.r11,
r12: self.r12,
r13: self.r13,
r14: self.r14,
r15: self.r15,
user_ip: self.rip,
user_sp: self.rsp,
rflags: self.rflags,
}
}
fn restore(&mut self, _saved: &SavedFrame) {
todo!()
}
fn argument(&self) -> u64 {
self.rdi as _
}
fn user_ip(&self) -> usize {
self.rip as _
}
fn user_sp(&self) -> usize {
self.rsp as _
}
fn set_argument(&mut self, value: u64) {
self.rdi = value;
}
fn set_return_value(&mut self, value: u64) {
self.rax = value;
}
fn set_user_ip(&mut self, value: usize) {
self.rip = value as _;
}
fn set_user_sp(&mut self, value: usize) {
self.rsp = value as _;
}
}
impl TaskFrame for ExceptionFrame {
fn store(&self) -> SavedFrame {
SavedFrame {
rax: self.rax,
rcx: self.rcx,
rdx: self.rdx,
rbx: self.rbx,
rsi: self.rsi,
rdi: self.rdi,
rbp: self.rbp,
r8: self.r8,
r9: self.r9,
r10: self.r10,
r11: self.r11,
r12: self.r12,
r13: self.r13,
r14: self.r14,
r15: self.r15,
user_ip: self.rip,
user_sp: self.rsp,
rflags: self.rflags,
}
}
fn restore(&mut self, _saved: &SavedFrame) {
todo!()
}
fn argument(&self) -> u64 {
0
}
fn user_sp(&self) -> usize {
self.rsp as _
}
fn user_ip(&self) -> usize {
self.rip as _
}
fn set_user_sp(&mut self, value: usize) {
self.rsp = value as _;
}
fn set_user_ip(&mut self, value: usize) {
self.rip = value as _;
}
fn set_return_value(&mut self, _value: u64) {
// Not in syscall, do not overwrite
}
fn set_argument(&mut self, value: u64) {
self.rdi = value;
}
}
impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>> ForkFrame<K, PA>
for SyscallFrame
{
type Context = TaskContextImpl<K, PA>;
unsafe fn fork(&self, address_space: u64) -> Result<TaskContextImpl<K, PA>, Error> {
TaskContextImpl::from_syscall_frame(self, address_space)
}
fn set_return_value(&mut self, value: u64) {
self.rax = value;
}
}
impl TaskFrame for SyscallFrame {
fn store(&self) -> SavedFrame {
SavedFrame {
rax: self.rax,
rcx: self.rcx,
rdx: self.args[2],
rbx: self.rbx,
rsi: self.args[1],
rdi: self.args[0],
rbp: self.rbp,
r8: self.args[4],
r9: self.args[5],
r10: self.args[3],
r11: self.r11,
r12: self.r12,
r13: self.r13,
r14: self.r14,
r15: self.r15,
user_ip: self.user_ip,
user_sp: self.user_sp,
rflags: self.user_flags,
}
}
fn restore(&mut self, saved: &SavedFrame) {
self.rax = saved.rax;
self.args[0] = saved.rdi;
self.args[1] = saved.rsi;
self.args[2] = saved.rdx;
self.args[3] = saved.r10;
self.args[4] = saved.r8;
self.args[5] = saved.r9;
self.rcx = saved.rcx;
self.r11 = saved.r11;
self.user_ip = saved.user_ip;
self.user_sp = saved.user_sp;
self.user_flags = saved.rflags;
self.rbx = saved.rbx;
self.rbp = saved.rbp;
self.r12 = saved.r12;
self.r13 = saved.r13;
self.r14 = saved.r14;
self.r15 = saved.r15;
}
fn argument(&self) -> u64 {
self.args[0]
}
fn user_sp(&self) -> usize {
self.user_sp as _
}
fn user_ip(&self) -> usize {
self.user_ip as _
}
fn set_user_sp(&mut self, value: usize) {
self.user_sp = value as _;
}
fn set_user_ip(&mut self, value: usize) {
self.user_ip = value as _;
}
fn set_return_value(&mut self, value: u64) {
self.rax = value;
}
fn set_argument(&mut self, value: u64) {
self.args[0] = value;
}
}
impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>>
TaskContextImpl<K, PA>
{
/// Constructs a new task context from a "forked" syscall frame
pub(super) unsafe fn from_syscall_frame(frame: &SyscallFrame, cr3: u64) -> Result<Self, Error> {
const USER_TASK_PAGES: usize = 8;
let stack_base_phys = PA::allocate_contiguous_pages(USER_TASK_PAGES)?;
let stack_base = stack_base_phys.raw_virtualize::<K>();
let mut stack = StackBuilder::new(stack_base, USER_TASK_PAGES * 0x1000);
// iretq frame
stack.push(0x1B);
stack.push(frame.user_sp as _);
stack.push(0x200);
stack.push(0x23);
stack.push(frame.user_ip as _);
stack.push(frame.args[5] as _); // r9
stack.push(frame.args[4] as _); // r8
stack.push(frame.args[3] as _); // r10
stack.push(frame.args[2] as _); // rdx
stack.push(frame.args[1] as _); // rsi
stack.push(frame.args[0] as _); // rdi
// callee-saved registers
stack.push(__x86_64_task_enter_from_fork as _);
stack.push(cr3 as _);
stack.push(frame.rbp as _);
stack.push(0x12345678); // XXX TODO: fs_base from SyscallFrame
stack.push(frame.r15 as _);
stack.push(frame.r14 as _);
stack.push(frame.r13 as _);
stack.push(frame.r12 as _);
stack.push(frame.rbx as _);
let sp = stack.build();
let rsp0 = stack_base + USER_TASK_PAGES * 0x1000;
Ok(Self {
inner: UnsafeCell::new(Inner { sp, tss_rsp0: rsp0 }),
fpu_context: UnsafeCell::new(FpuContext::new()),
stack_base_phys,
stack_size: USER_TASK_PAGES * 0x1000,
_alloc: PhantomData,
_table_manager: PhantomData,
})
}
}
unsafe impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>> Sync
for TaskContextImpl<K, PA>
{
}
impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>>
TaskContext<K, PA> for TaskContextImpl<K, PA>
{
const SIGNAL_STACK_EXTRA_ALIGN: usize = 8;
const USER_STACK_EXTRA_ALIGN: usize = 8;
fn kernel(entry: extern "C" fn(usize) -> !, arg: usize) -> Result<Self, Error> {
const KERNEL_TASK_PAGES: usize = 32;
let stack_base_phys = PA::allocate_contiguous_pages(KERNEL_TASK_PAGES)?;
let stack_base = stack_base_phys.raw_virtualize::<K>();
let mut stack = StackBuilder::new(stack_base, KERNEL_TASK_PAGES * 0x1000);
// Entry and argument
stack.push(entry as _);
stack.push(arg);
// XXX
setup_common_context(
&mut stack,
__x86_64_task_enter_kernel as _,
unsafe { KERNEL_TABLES.as_physical_address().into_raw() },
0,
);
let sp = stack.build();
// TODO stack is leaked
Ok(Self {
inner: UnsafeCell::new(Inner { sp, tss_rsp0: 0 }),
fpu_context: UnsafeCell::new(FpuContext::new()),
stack_base_phys,
stack_size: KERNEL_TASK_PAGES * 0x1000,
_alloc: PhantomData,
_table_manager: PhantomData,
})
}
fn user(
entry: usize,
arg: usize,
cr3: u64,
user_stack_sp: usize,
fs_base: usize,
) -> Result<Self, Error> {
const USER_TASK_PAGES: usize = 8;
let stack_base_phys = PA::allocate_contiguous_pages(USER_TASK_PAGES)?;
let stack_base = stack_base_phys.raw_virtualize::<K>();
let mut stack = StackBuilder::new(stack_base, USER_TASK_PAGES * 0x1000);
stack.push(entry as _);
stack.push(arg);
stack.push(user_stack_sp);
setup_common_context(&mut stack, __x86_64_task_enter_user as _, cr3, fs_base);
let sp = stack.build();
let rsp0 = stack_base + USER_TASK_PAGES * 0x1000;
Ok(Self {
inner: UnsafeCell::new(Inner { sp, tss_rsp0: rsp0 }),
fpu_context: UnsafeCell::new(FpuContext::new()),
stack_base_phys,
stack_size: USER_TASK_PAGES * 0x1000,
_alloc: PhantomData,
_table_manager: PhantomData,
})
}
unsafe fn enter(&self) -> ! {
FpuContext::restore(self.fpu_context.get());
__x86_64_enter_task(self.inner.get())
}
unsafe fn switch(&self, from: &Self) {
let dst = self.inner.get();
let src = from.inner.get();
if dst != src {
// Save the old context
FpuContext::save(from.fpu_context.get());
// Load next context
FpuContext::restore(self.fpu_context.get());
__x86_64_switch_task(dst, src);
}
}
unsafe fn switch_and_drop(&self, thread: *const ()) {
let dst = self.inner.get();
FpuContext::restore(self.fpu_context.get());
__x86_64_switch_and_drop(dst, thread)
}
}
impl<K: KernelTableManager, PA: PhysicalMemoryAllocator<Address = PhysicalAddress>> Drop
for TaskContextImpl<K, PA>
{
fn drop(&mut self) {
assert_eq!(self.stack_size % 0x1000, 0);
for offset in (0..self.stack_size).step_by(0x1000) {
unsafe {
PA::free_page(self.stack_base_phys.add(offset));
}
}
}
}
fn setup_common_context(builder: &mut StackBuilder, entry: usize, cr3: u64, fs_base: usize) {
builder.push(entry);
builder.push(cr3 as _);
builder.push(0); // %rbp
builder.push(fs_base); // %fs_base
builder.push(0); // %r15
builder.push(0); // %r14
builder.push(0); // %r13
builder.push(0); // %r12
builder.push(0); // %rbx
}
extern "C" {
fn __x86_64_task_enter_kernel();
fn __x86_64_task_enter_user();
fn __x86_64_task_enter_from_fork();
fn __x86_64_enter_task(to: *mut Inner) -> !;
fn __x86_64_switch_task(to: *mut Inner, from: *mut Inner);
fn __x86_64_switch_and_drop(to: *mut Inner, from: *const ());
}
global_asm!(
include_str!("context.S"),
context_size = const COMMON_CONTEXT_SIZE,
options(att_syntax)
);
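A sketch of constructing a kernel-mode context with this implementation; `Alloc` is a stand-in for the kernel's `PhysicalMemoryAllocator` choice, and actually entering the context is reserved for the scheduler:

    // `Alloc` is hypothetical: any PhysicalMemoryAllocator<Address = PhysicalAddress>
    fn spawn_worker() -> Result<TaskContextImpl<KernelTableManagerImpl, Alloc>, Error> {
        extern "C" fn worker(arg: usize) -> ! {
            let _ = arg;
            loop {}
        }
        // Allocates a 32-page kernel stack seeded for __x86_64_task_enter_kernel
        TaskContextImpl::kernel(worker, 123)
        // the scheduler would later do: unsafe { ctx.enter() }
    }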

@@ -0,0 +1,172 @@
#![no_std]
#![feature(
effects,
strict_provenance,
asm_const,
naked_functions,
trait_upcasting
)]
extern crate alloc;
use core::{
ops::DerefMut,
sync::atomic::{AtomicUsize, Ordering},
};
use alloc::vec::Vec;
use device_api::interrupt::{LocalInterruptController, MessageInterruptController};
use kernel_arch_interface::{
cpu::{CpuImpl, IpiQueue},
task::Scheduler,
util::OneTimeInit,
Architecture,
};
use libk_mm_interface::address::PhysicalAddress;
use registers::MSR_IA32_KERNEL_GS_BASE;
use tock_registers::interfaces::Writeable;
pub mod context;
pub mod mem;
pub mod registers;
pub use context::TaskContextImpl;
pub use mem::{process::ProcessAddressSpaceImpl, KernelTableManagerImpl};
pub struct ArchitectureImpl;
pub const KERNEL_VIRT_OFFSET: usize = 0xFFFFFF8000000000;
pub trait LocalApicInterface: LocalInterruptController + MessageInterruptController {
/// Performs an application processor startup sequence.
///
/// # Safety
///
/// Unsafe: only meant to be called by the BSP during SMP init.
unsafe fn wakeup_cpu(&self, apic_id: u32, bootstrap_code: PhysicalAddress);
/// Signals local APIC that we've handled the IRQ
fn clear_interrupt(&self);
}
#[repr(C, align(0x10))]
pub struct PerCpuData {
// 0x00
pub this: *mut Self,
// 0x08, used in assembly
pub tss_address: usize,
// 0x10, used in assembly
pub tmp_address: usize,
pub local_apic: &'static dyn LocalApicInterface,
}
impl PerCpuData {
pub fn local_apic(&self) -> &'static dyn LocalApicInterface {
self.local_apic
}
}
static IPI_QUEUES: OneTimeInit<Vec<IpiQueue<ArchitectureImpl>>> = OneTimeInit::new();
pub static CPU_COUNT: AtomicUsize = AtomicUsize::new(1);
#[naked]
extern "C" fn idle_task(_: usize) -> ! {
unsafe {
core::arch::asm!(
r#"
1:
nop
jmp 1b
"#,
options(noreturn, att_syntax)
);
}
}
impl ArchitectureImpl {
fn local_cpu_data() -> Option<&'static mut PerCpuData> {
unsafe { (Self::local_cpu() as *mut PerCpuData).as_mut() }
}
}
impl Architecture for ArchitectureImpl {
type PerCpuData = PerCpuData;
unsafe fn set_local_cpu(cpu: *mut ()) {
MSR_IA32_KERNEL_GS_BASE.set(cpu as u64);
core::arch::asm!("wbinvd; swapgs");
}
fn local_cpu() -> *mut () {
let mut addr: u64;
unsafe {
core::arch::asm!("movq %gs:(0), {0}", out(reg) addr, options(att_syntax));
}
addr as _
}
unsafe fn init_ipi_queues(queues: Vec<IpiQueue<Self>>) {
IPI_QUEUES.init(queues);
}
unsafe fn init_local_cpu<S: Scheduler + 'static>(id: Option<u32>, data: Self::PerCpuData) {
use alloc::boxed::Box;
let cpu = Box::leak(Box::new(CpuImpl::<Self, S>::new(
id.expect("x86_64 requires the CPU ID to be set manually"),
data,
)));
cpu.this = cpu.deref_mut();
cpu.set_local();
}
fn idle_task() -> extern "C" fn(usize) -> ! {
idle_task
}
fn cpu_count() -> usize {
CPU_COUNT.load(Ordering::Acquire)
}
fn cpu_index<S: Scheduler + 'static>() -> u32 {
CpuImpl::<Self, S>::local().id()
}
fn interrupt_mask() -> bool {
let mut flags: u64;
unsafe {
core::arch::asm!("pushfq; pop {0}", out(reg) flags, options(att_syntax));
}
// If IF is zero, interrupts are disabled (masked)
flags & (1 << 9) == 0
}
unsafe fn set_interrupt_mask(mask: bool) -> bool {
let old = Self::interrupt_mask();
if mask {
core::arch::asm!("cli");
} else {
core::arch::asm!("sti");
}
old
}
#[inline]
fn wait_for_interrupt() {
unsafe {
core::arch::asm!("hlt");
}
}
fn local_interrupt_controller() -> &'static dyn LocalInterruptController {
let local = Self::local_cpu_data().unwrap();
local.local_apic
}
fn message_interrupt_controller() -> &'static dyn MessageInterruptController {
let local = Self::local_cpu_data().unwrap();
local.local_apic
}
}
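The mask/restore pair above is the raw primitive that IrqGuard-style wrappers build on; a minimal sketch:

    fn with_irqs_masked<T>(f: impl FnOnce() -> T) -> T {
        let was_masked = unsafe { ArchitectureImpl::set_interrupt_mask(true) };
        let result = f();
        if !was_masked {
            // Only unmask if IRQs were enabled when we got here
            unsafe { ArchitectureImpl::set_interrupt_mask(false) };
        }
        result
    }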

@@ -0,0 +1,405 @@
use core::{
alloc::Layout,
ops::{Deref, DerefMut},
ptr::addr_of,
sync::atomic::{AtomicUsize, Ordering},
};
use kernel_arch_interface::mem::{
DeviceMemoryAttributes, KernelTableManager, RawDeviceMemoryMapping,
};
use libk_mm_interface::{
address::{FromRaw, PhysicalAddress},
table::{EntryLevel, EntryLevelExt},
KernelImageObject,
};
use memtables::x86_64::FixedTables;
use static_assertions::{const_assert_eq, const_assert_ne};
use yggdrasil_abi::error::Error;
use crate::{registers::CR3, KERNEL_VIRT_OFFSET};
use self::table::{PageAttributes, PageEntry, PageTable, L0, L1, L2, L3};
pub mod process;
pub mod table;
#[derive(Debug)]
pub struct KernelTableManagerImpl;
const CANONICAL_ADDRESS_MASK: usize = 0xFFFF000000000000;
const KERNEL_PHYS_BASE: usize = 0x200000;
// Mapped at compile time
const KERNEL_MAPPING_BASE: usize = KERNEL_VIRT_OFFSET + KERNEL_PHYS_BASE;
const KERNEL_L0_INDEX: usize = KERNEL_MAPPING_BASE.page_index::<L0>();
const KERNEL_L1_INDEX: usize = KERNEL_MAPPING_BASE.page_index::<L1>();
const KERNEL_START_L2_INDEX: usize = KERNEL_MAPPING_BASE.page_index::<L2>();
// Must not be zero, should be at 4MiB
const_assert_ne!(KERNEL_START_L2_INDEX, 0);
// From static mapping
const_assert_eq!(KERNEL_L0_INDEX, 511);
const_assert_eq!(KERNEL_L1_INDEX, 0);
// Mapped at boot
const EARLY_MAPPING_L2I: usize = KERNEL_START_L2_INDEX - 1;
const HEAP_MAPPING_L1I: usize = KERNEL_L1_INDEX + 1;
const DEVICE_MAPPING_L1I: usize = KERNEL_L1_INDEX + 2;
const RAM_MAPPING_L0I: usize = KERNEL_L0_INDEX - 1;
const DEVICE_MAPPING_L3_COUNT: usize = 4;
#[link_section = ".data.tables"]
pub static mut KERNEL_TABLES: KernelImageObject<FixedTables> =
unsafe { KernelImageObject::new(FixedTables::zeroed()) };
// 2MiB for early mappings
const EARLY_MAPPING_OFFSET: usize = CANONICAL_ADDRESS_MASK
| (KERNEL_L0_INDEX * L0::SIZE)
| (KERNEL_L1_INDEX * L1::SIZE)
| (EARLY_MAPPING_L2I * L2::SIZE);
static mut EARLY_MAPPING_L3: PageTable<L3> = PageTable::zeroed();
// 1GiB for heap mapping
pub const HEAP_MAPPING_OFFSET: usize =
CANONICAL_ADDRESS_MASK | (KERNEL_L0_INDEX * L0::SIZE) | (HEAP_MAPPING_L1I * L1::SIZE);
pub(super) static mut HEAP_MAPPING_L2: PageTable<L2> = PageTable::zeroed();
// 1GiB for device MMIO mapping
const DEVICE_MAPPING_OFFSET: usize =
CANONICAL_ADDRESS_MASK | (KERNEL_L0_INDEX * L0::SIZE) | (DEVICE_MAPPING_L1I * L1::SIZE);
static mut DEVICE_MAPPING_L2: PageTable<L2> = PageTable::zeroed();
static mut DEVICE_MAPPING_L3S: [PageTable<L3>; DEVICE_MAPPING_L3_COUNT] =
[PageTable::zeroed(); DEVICE_MAPPING_L3_COUNT];
// 512GiB for whole RAM mapping
pub const RAM_MAPPING_OFFSET: usize = CANONICAL_ADDRESS_MASK | (RAM_MAPPING_L0I * L0::SIZE);
pub static MEMORY_LIMIT: AtomicUsize = AtomicUsize::new(0);
pub static mut RAM_MAPPING_L1: PageTable<L1> = PageTable::zeroed();
impl KernelTableManager for KernelTableManagerImpl {
fn virtualize(address: u64) -> usize {
let address = address as usize;
if address < MEMORY_LIMIT.load(Ordering::Acquire) {
address + RAM_MAPPING_OFFSET
} else {
panic!("Invalid physical address: {:#x}", address);
}
}
fn physicalize(address: usize) -> u64 {
if address < RAM_MAPPING_OFFSET
|| address - RAM_MAPPING_OFFSET >= MEMORY_LIMIT.load(Ordering::Acquire)
{
panic!("Not a virtualized physical address: {:#x}", address);
}
(address - RAM_MAPPING_OFFSET) as _
}
unsafe fn map_device_pages(
base: u64,
count: usize,
attrs: DeviceMemoryAttributes,
) -> Result<RawDeviceMemoryMapping<Self>, Error> {
map_device_memory(PhysicalAddress::from_raw(base), count, attrs)
}
unsafe fn unmap_device_pages(mapping: &RawDeviceMemoryMapping<Self>) {
unmap_device_memory(mapping)
}
}
// Early mappings
unsafe fn map_early_pages(physical: PhysicalAddress, count: usize) -> Result<usize, Error> {
for l3i in 0..512 {
// Avoid indexing past the end of the table for multi-page requests
if l3i + count > 512 {
break;
}
let mut taken = false;
for i in 0..count {
if EARLY_MAPPING_L3[i + l3i].is_present() {
taken = true;
break;
}
}
if taken {
continue;
}
for i in 0..count {
// TODO NX, NC
EARLY_MAPPING_L3[i + l3i] =
PageEntry::page(physical.add(i * L3::SIZE), PageAttributes::WRITABLE);
}
return Ok(EARLY_MAPPING_OFFSET + l3i * L3::SIZE);
}
Err(Error::OutOfMemory)
}
unsafe fn unmap_early_page(address: usize) {
if !(EARLY_MAPPING_OFFSET..EARLY_MAPPING_OFFSET + L2::SIZE).contains(&address) {
panic!("Tried to unmap invalid early mapping: {:#x}", address);
}
let l3i = (address - EARLY_MAPPING_OFFSET).page_index::<L3>();
assert!(EARLY_MAPPING_L3[l3i].is_present());
EARLY_MAPPING_L3[l3i] = PageEntry::INVALID;
}
// Device mappings
unsafe fn map_device_memory_l3(
base: PhysicalAddress,
count: usize,
_attrs: DeviceMemoryAttributes,
) -> Result<usize, Error> {
// TODO don't map pages if already mapped
'l0: for i in 0..DEVICE_MAPPING_L3_COUNT * 512 {
// Avoid indexing past the last L3 table for multi-page requests
if i + count > DEVICE_MAPPING_L3_COUNT * 512 {
break;
}
for j in 0..count {
let l2i = (i + j) / 512;
let l3i = (i + j) % 512;
if DEVICE_MAPPING_L3S[l2i][l3i].is_present() {
continue 'l0;
}
}
for j in 0..count {
let l2i = (i + j) / 512;
let l3i = (i + j) % 512;
// TODO NX, NC
DEVICE_MAPPING_L3S[l2i][l3i] =
PageEntry::page(base.add(j * L3::SIZE), PageAttributes::WRITABLE);
}
return Ok(DEVICE_MAPPING_OFFSET + i * L3::SIZE);
}
Err(Error::OutOfMemory)
}
unsafe fn map_device_memory_l2(
base: PhysicalAddress,
count: usize,
_attrs: DeviceMemoryAttributes,
) -> Result<usize, Error> {
'l0: for i in DEVICE_MAPPING_L3_COUNT..512 {
// Avoid indexing past the end of the L2 table for multi-page requests
if i + count > 512 {
break;
}
for j in 0..count {
if DEVICE_MAPPING_L2[i + j].is_present() {
continue 'l0;
}
}
for j in 0..count {
DEVICE_MAPPING_L2[i + j] =
PageEntry::<L2>::block(base.add(j * L2::SIZE), PageAttributes::WRITABLE);
}
// debugln!(
// "map l2s: base={:#x}, count={} -> {:#x}",
// base,
// count,
// DEVICE_MAPPING_OFFSET + i * L2::SIZE
// );
return Ok(DEVICE_MAPPING_OFFSET + i * L2::SIZE);
}
Err(Error::OutOfMemory)
}
unsafe fn map_device_memory(
base: PhysicalAddress,
size: usize,
attrs: DeviceMemoryAttributes,
) -> Result<RawDeviceMemoryMapping<KernelTableManagerImpl>, Error> {
// debugln!("Map {}B @ {:#x}", size, base);
let l3_aligned = base.page_align_down::<L3>();
let l3_offset = base.page_offset::<L3>();
let page_count = (l3_offset + size).page_count::<L3>();
if page_count > 256 {
// Large mapping, use L2 mapping instead
let l2_aligned = base.page_align_down::<L2>();
let l2_offset = base.page_offset::<L2>();
let page_count = (l2_offset + size).page_count::<L2>();
let base_address = map_device_memory_l2(l2_aligned, page_count, attrs)?;
let address = base_address + l2_offset;
Ok(RawDeviceMemoryMapping::from_raw_parts(
address,
base_address,
page_count,
L2::SIZE,
))
} else {
// Just map the pages directly
let base_address = map_device_memory_l3(l3_aligned, page_count, attrs)?;
let address = base_address + l3_offset;
Ok(RawDeviceMemoryMapping::from_raw_parts(
address,
base_address,
page_count,
L3::SIZE,
))
}
}
unsafe fn unmap_device_memory(map: &RawDeviceMemoryMapping<KernelTableManagerImpl>) {
// debugln!(
// "Unmap {}B @ {:#x}",
// map.page_count * map.page_size,
// map.base_address
// );
match map.page_size {
L3::SIZE => {
for i in 0..map.page_count {
let page = map.base_address + i * L3::SIZE;
let l2i = page.page_index::<L2>();
let l3i = page.page_index::<L3>();
assert!(DEVICE_MAPPING_L3S[l2i][l3i].is_present());
DEVICE_MAPPING_L3S[l2i][l3i] = PageEntry::INVALID;
flush_tlb_entry(page);
}
}
L2::SIZE => todo!(),
_ => unimplemented!(),
}
}
pub unsafe fn map_heap_block(index: usize, page: PhysicalAddress) {
if !page.is_page_aligned_for::<L2>() {
panic!("Attempted to map a misaligned 2MiB page");
}
assert!(index < 512);
if HEAP_MAPPING_L2[index].is_present() {
panic!("Page is already mappged: {:#x}", page);
}
// TODO NX
HEAP_MAPPING_L2[index] = PageEntry::<L2>::block(page, PageAttributes::WRITABLE);
}
/// Memory mapping which may be used for performing early kernel initialization
pub struct EarlyMapping<'a, T: ?Sized> {
value: &'a mut T,
page_count: usize,
}
impl<'a, T: Sized> EarlyMapping<'a, T> {
pub unsafe fn map(physical: PhysicalAddress) -> Result<EarlyMapping<'a, T>, Error> {
let layout = Layout::new::<T>();
let aligned = physical.page_align_down::<L3>();
let offset = physical.page_offset::<L3>();
let page_count = (offset + layout.size() + L3::SIZE - 1) / L3::SIZE;
let virt = map_early_pages(aligned, page_count)?;
let value = &mut *((virt + offset) as *mut T);
Ok(EarlyMapping { value, page_count })
}
pub unsafe fn map_slice(
physical: PhysicalAddress,
len: usize,
) -> Result<EarlyMapping<'a, [T]>, Error> {
let layout = Layout::array::<T>(len).unwrap();
let aligned = physical.page_align_down::<L3>();
let offset = physical.page_offset::<L3>();
let page_count = (offset + layout.size() + L3::SIZE - 1) / L3::SIZE;
let virt = map_early_pages(aligned, page_count)?;
let value = core::slice::from_raw_parts_mut((virt + offset) as *mut T, len);
Ok(EarlyMapping { value, page_count })
}
}
impl<'a, T: ?Sized> Deref for EarlyMapping<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.value
}
}
impl<'a, T: ?Sized> DerefMut for EarlyMapping<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.value
}
}
impl<'a, T: ?Sized> Drop for EarlyMapping<'a, T> {
fn drop(&mut self) {
let address = (self.value as *mut T).addr() & !(L3::SIZE - 1);
for i in 0..self.page_count {
let page = address + i * L3::SIZE;
unsafe {
unmap_early_page(page);
}
}
}
}
pub fn clone_kernel_tables(dst: &mut PageTable<L0>) {
unsafe {
dst[KERNEL_L0_INDEX] = PageEntry::from_raw(KERNEL_TABLES.l0.data[KERNEL_L0_INDEX]);
dst[RAM_MAPPING_L0I] = PageEntry::from_raw(KERNEL_TABLES.l0.data[RAM_MAPPING_L0I]);
}
}
/// Sets up the following memory map:
/// ...: KERNEL_TABLES.l0:
/// * 0xFFFFFF0000000000 .. 0xFFFFFF8000000000 : RAM_MAPPING_L1
/// * 0xFFFFFF8000000000 .. ... : KERNEL_TABLES.kernel_l1:
/// * 0xFFFFFF8000000000 .. 0xFFFFFF8040000000 : KERNEL_TABLES.kernel_l2
/// * 0xFFFFFF8000000000 .. 0xFFFFFF8000200000 : ---
/// * 0xFFFFFF8000200000 .. 0xFFFFFF8000400000 : EARLY_MAPPING_L3
/// * 0xFFFFFF8000400000 .. ... : KERNEL_TABLES.kernel_l3s
/// * 0xFFFFFF8040000000 .. 0xFFFFFF8080000000 : HEAP_MAPPING_L2
/// * 0xFFFFFF8080000000 .. 0xFFFFFF8100000000 : DEVICE_MAPPING_L2
/// * 0xFFFFFF8080000000 .. 0xFFFFFF8080800000 : DEVICE_MAPPING_L3S
/// * 0xFFFFFF8080800000 .. 0xFFFFFF8100000000 : ...
pub unsafe fn init_fixed_tables() {
// TODO this could be built at compile time too?
let early_mapping_l3_phys = addr_of!(EARLY_MAPPING_L3) as usize - KERNEL_VIRT_OFFSET;
let device_mapping_l2_phys = addr_of!(DEVICE_MAPPING_L2) as usize - KERNEL_VIRT_OFFSET;
let heap_mapping_l2_phys = addr_of!(HEAP_MAPPING_L2) as usize - KERNEL_VIRT_OFFSET;
let ram_mapping_l1_phys = addr_of!(RAM_MAPPING_L1) as usize - KERNEL_VIRT_OFFSET;
for i in 0..DEVICE_MAPPING_L3_COUNT {
let device_mapping_l3_phys = PhysicalAddress::from_raw(
&DEVICE_MAPPING_L3S[i] as *const _ as usize - KERNEL_VIRT_OFFSET,
);
DEVICE_MAPPING_L2[i] = PageEntry::table(device_mapping_l3_phys, PageAttributes::WRITABLE);
}
assert_eq!(KERNEL_TABLES.kernel_l2.data[EARLY_MAPPING_L2I], 0);
KERNEL_TABLES.kernel_l2.data[EARLY_MAPPING_L2I] = (early_mapping_l3_phys as u64)
| (PageAttributes::WRITABLE | PageAttributes::PRESENT).bits();
assert_eq!(KERNEL_TABLES.kernel_l1.data[HEAP_MAPPING_L1I], 0);
KERNEL_TABLES.kernel_l1.data[HEAP_MAPPING_L1I] =
(heap_mapping_l2_phys as u64) | (PageAttributes::WRITABLE | PageAttributes::PRESENT).bits();
assert_eq!(KERNEL_TABLES.kernel_l1.data[DEVICE_MAPPING_L1I], 0);
KERNEL_TABLES.kernel_l1.data[DEVICE_MAPPING_L1I] = (device_mapping_l2_phys as u64)
| (PageAttributes::WRITABLE | PageAttributes::PRESENT).bits();
assert_eq!(KERNEL_TABLES.l0.data[RAM_MAPPING_L0I], 0);
KERNEL_TABLES.l0.data[RAM_MAPPING_L0I] =
(ram_mapping_l1_phys as u64) | (PageAttributes::WRITABLE | PageAttributes::PRESENT).bits();
// TODO ENABLE EFER.NXE
let cr3 = &KERNEL_TABLES.l0 as *const _ as usize - KERNEL_VIRT_OFFSET;
CR3.set_address(cr3);
}
#[inline]
pub unsafe fn flush_tlb_entry(address: usize) {
core::arch::asm!("invlpg ({0})", in(reg) address, options(att_syntax));
}
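A round-trip sketch of the linear RAM mapping; the physical address is hypothetical and must lie below MEMORY_LIMIT:

    fn round_trip() {
        let phys: u64 = 0x10_0000; // hypothetical RAM address
        let virt = KernelTableManagerImpl::virtualize(phys); // phys + RAM_MAPPING_OFFSET
        assert_eq!(KernelTableManagerImpl::physicalize(virt), phys);
    }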

@@ -0,0 +1,161 @@
//! x86-64-specific process address space management functions
use core::marker::PhantomData;
use libk_mm_interface::{
address::{AsPhysicalAddress, IntoRaw, PhysicalAddress},
pointer::PhysicalRefMut,
process::ProcessAddressSpaceManager,
table::{
EntryLevel, EntryLevelDrop, EntryLevelExt, MapAttributes, NextPageTable, TableAllocator,
},
};
use yggdrasil_abi::error::Error;
use crate::KernelTableManagerImpl;
use super::{
clone_kernel_tables, flush_tlb_entry,
table::{PageEntry, PageTable, L0, L1, L2, L3},
};
/// Represents a process or kernel address space. Because x86-64 has no analogue of AArch64's
/// split TTBR0/TTBR1 translation table bases, all address spaces are initially cloned from the
/// kernel space.
#[repr(C)]
pub struct ProcessAddressSpaceImpl<TA: TableAllocator> {
l0: PhysicalRefMut<'static, PageTable<L0>, KernelTableManagerImpl>,
_alloc: PhantomData<TA>,
}
impl<TA: TableAllocator> ProcessAddressSpaceManager<TA> for ProcessAddressSpaceImpl<TA> {
// Userspace mappings start at 8GiB...
const LOWER_LIMIT_PFN: usize = (8 << 30) / L3::SIZE;
// ...and are limited to 16GiB of virtual space
const UPPER_LIMIT_PFN: usize = (16 << 30) / L3::SIZE;
fn new() -> Result<Self, Error> {
let mut l0 = unsafe {
PhysicalRefMut::<'static, PageTable<L0>, KernelTableManagerImpl>::map(
TA::allocate_page_table()?,
)
};
for i in 0..512 {
l0[i] = PageEntry::INVALID;
}
clone_kernel_tables(&mut l0);
Ok(Self {
l0,
_alloc: PhantomData,
})
}
#[inline]
unsafe fn map_page(
&mut self,
address: usize,
physical: PhysicalAddress,
flags: MapAttributes,
) -> Result<(), Error> {
self.write_l3_entry(address, PageEntry::page(physical, flags.into()), false)
}
unsafe fn unmap_page(&mut self, address: usize) -> Result<PhysicalAddress, Error> {
self.pop_l3_entry(address)
}
#[inline]
fn translate(&self, address: usize) -> Result<(PhysicalAddress, MapAttributes), Error> {
self.read_l3_entry(address)
.ok_or(Error::InvalidMemoryOperation)
}
fn as_address_with_asid(&self) -> u64 {
// TODO x86-64 PCID/ASID?
unsafe { self.l0.as_physical_address().into_raw() }
}
unsafe fn clear(&mut self) {
self.l0
.drop_range::<TA>(0..((Self::UPPER_LIMIT_PFN * L3::SIZE).page_index::<L1>()));
}
}
impl<TA: TableAllocator> ProcessAddressSpaceImpl<TA> {
// Write a single 4KiB entry
fn write_l3_entry(
&mut self,
virt: usize,
entry: PageEntry<L3>,
overwrite: bool,
) -> Result<(), Error> {
let l0i = virt.page_index::<L0>();
let l1i = virt.page_index::<L1>();
let l2i = virt.page_index::<L2>();
let l3i = virt.page_index::<L3>();
let mut l1 = self.l0.get_mut_or_alloc::<TA>(l0i)?;
let mut l2 = l1.get_mut_or_alloc::<TA>(l1i)?;
let mut l3 = l2.get_mut_or_alloc::<TA>(l2i)?;
if l3[l3i].is_present() && !overwrite {
todo!();
}
l3[l3i] = entry;
unsafe {
flush_tlb_entry(virt);
}
Ok(())
}
fn pop_l3_entry(&mut self, virt: usize) -> Result<PhysicalAddress, Error> {
let l0i = virt.page_index::<L0>();
let l1i = virt.page_index::<L1>();
let l2i = virt.page_index::<L2>();
let l3i = virt.page_index::<L3>();
// TODO somehow drop tables if they're known to be empty?
let mut l1 = self.l0.get_mut(l0i).ok_or(Error::DoesNotExist)?;
let mut l2 = l1.get_mut(l1i).ok_or(Error::DoesNotExist)?;
let mut l3 = l2.get_mut(l2i).ok_or(Error::DoesNotExist)?;
let page = l3[l3i].as_page().ok_or(Error::DoesNotExist)?;
l3[l3i] = PageEntry::INVALID;
unsafe {
flush_tlb_entry(virt);
}
Ok(page)
}
fn read_l3_entry(&self, virt: usize) -> Option<(PhysicalAddress, MapAttributes)> {
let l0i = virt.page_index::<L0>();
let l1i = virt.page_index::<L1>();
let l2i = virt.page_index::<L2>();
let l3i = virt.page_index::<L3>();
let l1 = self.l0.get(l0i)?;
let l2 = l1.get(l1i)?;
let l3 = l2.get(l2i)?;
let page = l3[l3i].as_page()?;
Some((page, l3[l3i].attributes().into()))
}
}
impl<TA: TableAllocator> Drop for ProcessAddressSpaceImpl<TA> {
fn drop(&mut self) {
// SAFETY: with safe usage of the ProcessAddressSpaceImpl, clearing and dropping
// is safe, no one refers to the memory
unsafe {
self.clear();
let l0_phys = self.l0.as_physical_address();
TA::free_page_table(l0_phys);
}
}
}
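A lifecycle sketch; `TA` is some `TableAllocator`, `phys_page` a hypothetical 4KiB frame, and the virtual address sits at the start of the 8..16GiB window defined by the limits above:

    fn exercise<TA: TableAllocator>(phys_page: PhysicalAddress) -> Result<(), Error> {
        let mut space = ProcessAddressSpaceImpl::<TA>::new()?;
        let virt = 8usize << 30; // == LOWER_LIMIT_PFN * L3::SIZE
        unsafe { space.map_page(virt, phys_page, MapAttributes::USER_READ)? };
        let (_translated, _attrs) = space.translate(virt)?;
        let _freed = unsafe { space.unmap_page(virt)? };
        Ok(())
    }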

@@ -0,0 +1,335 @@
//! x86-64-specific memory translation table management interfaces and functions
use core::{
marker::PhantomData,
ops::{Index, IndexMut, Range},
};
use bitflags::bitflags;
use libk_mm_interface::{
address::{AsPhysicalAddress, FromRaw, PhysicalAddress},
pointer::{PhysicalRef, PhysicalRefMut},
table::{
EntryLevel, EntryLevelDrop, MapAttributes, NextPageTable, NonTerminalEntryLevel,
TableAllocator,
},
};
use yggdrasil_abi::error::Error;
use crate::KernelTableManagerImpl;
bitflags! {
/// Describes how each page table entry is mapped
pub struct PageAttributes: u64 {
/// When set, the mapping is considered valid and pointing somewhere
const PRESENT = 1 << 0;
/// For tables, allows writes to further translation levels, for pages/blocks, allows
/// writes to the region covered by the entry
const WRITABLE = 1 << 1;
/// When set for L2 entries, the mapping specifies a 2MiB page instead of a page table
/// reference
const BLOCK = 1 << 7;
/// For tables, allows user access to further translation levels, for pages/blocks, allows
/// user access to the region covered by the entry
const USER = 1 << 2;
}
}
/// Represents a single virtual address space mapping depending on its translation level
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct PageEntry<L: EntryLevel>(u64, PhantomData<L>);
/// Table describing a single level of address translation
#[derive(Clone, Copy)]
#[repr(C, align(0x1000))]
pub struct PageTable<L: EntryLevel> {
data: [PageEntry<L>; 512],
}
/// Translation level 0 (PML4): Entry is 512GiB table
#[derive(Clone, Copy, Debug)]
pub struct L0;
/// Translation level 1 (PDPT): Entry is 1GiB table
#[derive(Clone, Copy, Debug)]
pub struct L1;
/// Translation level 2 (Page directory): Entry is 2MiB block/table
#[derive(Clone, Copy, Debug)]
pub struct L2;
/// Translation level 3 (Page table): Entry is 4KiB page
#[derive(Clone, Copy, Debug)]
pub struct L3;
impl NonTerminalEntryLevel for L0 {
type NextLevel = L1;
}
impl NonTerminalEntryLevel for L1 {
type NextLevel = L2;
}
impl NonTerminalEntryLevel for L2 {
type NextLevel = L3;
}
impl EntryLevel for L0 {
const SHIFT: usize = 39;
}
impl EntryLevel for L1 {
const SHIFT: usize = 30;
}
impl EntryLevel for L2 {
const SHIFT: usize = 21;
}
impl EntryLevel for L3 {
const SHIFT: usize = 12;
}
impl PageEntry<L3> {
/// Constructs a mapping which points to a 4KiB page
pub fn page(phys: PhysicalAddress, attrs: PageAttributes) -> Self {
Self(
u64::from(phys) | (attrs | PageAttributes::PRESENT | PageAttributes::USER).bits(),
PhantomData,
)
}
/// Returns the physical address of the page this entry refers to, returning None if it does
/// not
pub fn as_page(self) -> Option<PhysicalAddress> {
if self.0 & PageAttributes::PRESENT.bits() != 0 {
Some(PhysicalAddress::from_raw(self.0 & !0xFFF))
} else {
None
}
}
}
impl PageEntry<L2> {
/// Constructs a mapping which points to a 2MiB block
pub fn block(phys: PhysicalAddress, attrs: PageAttributes) -> Self {
Self(
u64::from(phys) | (attrs | PageAttributes::PRESENT | PageAttributes::BLOCK).bits(),
PhantomData,
)
}
}
impl PageEntry<L1> {
/// Constructs a mapping which points to a 1GiB block
pub fn block(phys: PhysicalAddress, attrs: PageAttributes) -> Self {
Self(
u64::from(phys) | (attrs | PageAttributes::PRESENT | PageAttributes::BLOCK).bits(),
PhantomData,
)
}
}
impl<L: NonTerminalEntryLevel> PageEntry<L> {
/// Constructs a mapping which points to a next-level table
pub fn table(phys: PhysicalAddress, attrs: PageAttributes) -> Self {
Self(
u64::from(phys)
| (attrs
| PageAttributes::PRESENT
| PageAttributes::WRITABLE
| PageAttributes::USER)
.bits(),
PhantomData,
)
}
/// Returns the physical address of the table this entry refers to, returning None if it
/// does not
pub fn as_table(self) -> Option<PhysicalAddress> {
if self.0 & PageAttributes::PRESENT.bits() != 0
&& self.0 & PageAttributes::BLOCK.bits() == 0
{
Some(PhysicalAddress::from_raw(self.0 & !0xFFF))
} else {
None
}
}
/// Returns `true` if the mapping represents a "page"/"block" and not a table
pub fn is_block(self) -> bool {
self.0 & PageAttributes::BLOCK.bits() != 0
}
}
impl<L: EntryLevel> PageEntry<L> {
/// An entry that is not mapped
pub const INVALID: Self = Self(0, PhantomData);
/// Reinterprets raw [u64] as a [PageEntry].
///
/// # Safety
///
/// Unsafe: the caller must ensure the value is a valid page translation entry.
pub const unsafe fn from_raw(raw: u64) -> Self {
Self(raw, PhantomData)
}
/// Returns the translation attributes of the entry
pub fn attributes(&self) -> PageAttributes {
PageAttributes::from_bits_retain(self.0)
}
/// Returns `true` if the entry contains a valid mapping to either a table or to a page/block
pub fn is_present(&self) -> bool {
self.0 & PageAttributes::PRESENT.bits() != 0
}
}
impl<L: EntryLevel> PageTable<L> {
/// Constructs a page table filled with invalid (non-present) entries
pub const fn zeroed() -> Self {
Self {
data: [PageEntry::INVALID; 512],
}
}
/// Reinterprets given [PageEntry] slice as a reference to [PageTable].
///
/// # Safety
///
/// Unsafe: the caller must ensure the provided reference is properly aligned and contains sane
/// data.
pub unsafe fn from_raw_slice_mut(data: &mut [PageEntry<L>; 512]) -> &mut Self {
core::mem::transmute(data)
}
/// Allocates a new page table, filling it with non-present entries
pub fn new_zeroed<'a, TA: TableAllocator>(
) -> Result<PhysicalRefMut<'a, Self, KernelTableManagerImpl>, Error> {
let physical = TA::allocate_page_table()?;
let mut table =
unsafe { PhysicalRefMut::<'a, Self, KernelTableManagerImpl>::map(physical) };
for i in 0..512 {
table[i] = PageEntry::INVALID;
}
Ok(table)
}
// /// Returns the physical address of this table
// pub fn physical_address(&self) -> usize {
// unsafe { (self.data.as_ptr() as usize).physicalize() }
// }
}
impl<L: NonTerminalEntryLevel + 'static> NextPageTable for PageTable<L> {
type NextLevel = PageTable<L::NextLevel>;
type TableRef = PhysicalRef<'static, Self::NextLevel, KernelTableManagerImpl>;
type TableRefMut = PhysicalRefMut<'static, Self::NextLevel, KernelTableManagerImpl>;
fn get(&self, index: usize) -> Option<Self::TableRef> {
self[index]
.as_table()
.map(|addr| unsafe { PhysicalRef::map(addr) })
}
fn get_mut(&mut self, index: usize) -> Option<Self::TableRefMut> {
self[index]
.as_table()
.map(|addr| unsafe { PhysicalRefMut::map(addr) })
}
fn get_mut_or_alloc<TA: TableAllocator>(
&mut self,
index: usize,
) -> Result<Self::TableRefMut, Error> {
let entry = self[index];
if let Some(table) = entry.as_table() {
Ok(unsafe { PhysicalRefMut::map(table) })
} else {
let table = PageTable::new_zeroed::<TA>()?;
self[index] = PageEntry::<L>::table(
unsafe { table.as_physical_address() },
PageAttributes::WRITABLE | PageAttributes::USER,
);
Ok(table)
}
}
}
impl EntryLevelDrop for PageTable<L3> {
const FULL_RANGE: Range<usize> = 0..512;
// Do nothing
unsafe fn drop_range<TA: TableAllocator>(&mut self, _range: Range<usize>) {}
}
impl<L: NonTerminalEntryLevel + 'static> EntryLevelDrop for PageTable<L>
where
PageTable<L::NextLevel>: EntryLevelDrop,
{
const FULL_RANGE: Range<usize> = 0..512;
unsafe fn drop_range<TA: TableAllocator>(&mut self, range: Range<usize>) {
for index in range {
let entry = self[index];
if let Some(table) = entry.as_table() {
let mut table_ref: PhysicalRefMut<PageTable<L::NextLevel>, KernelTableManagerImpl> =
PhysicalRefMut::map(table);
table_ref.drop_all::<TA>();
// Drop the table
drop(table_ref);
TA::free_page_table(table);
} else if entry.is_present() {
// Memory must've been cleared beforehand, so no non-table entries must be present
panic!(
"Expected a table containing only tables, got table[{}] = {:#x?}",
index, entry.0
);
}
self[index] = PageEntry::INVALID;
}
}
}
impl<L: EntryLevel> Index<usize> for PageTable<L> {
type Output = PageEntry<L>;
fn index(&self, index: usize) -> &Self::Output {
&self.data[index]
}
}
impl<L: EntryLevel> IndexMut<usize> for PageTable<L> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.data[index]
}
}
impl From<MapAttributes> for PageAttributes {
fn from(value: MapAttributes) -> Self {
let mut res = PageAttributes::WRITABLE;
if value.intersects(MapAttributes::USER_READ | MapAttributes::USER_WRITE) {
res |= PageAttributes::USER;
}
res
}
}
impl From<PageAttributes> for MapAttributes {
fn from(value: PageAttributes) -> Self {
let mut res = MapAttributes::empty();
if value.contains(PageAttributes::USER) {
res |= MapAttributes::USER_READ;
if value.contains(PageAttributes::WRITABLE) {
res |= MapAttributes::USER_WRITE;
}
}
// TODO ???
res |= MapAttributes::NON_GLOBAL;
res
}
}
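A sketch of entry construction and the attribute round-trip; the physical address is hypothetical and 4KiB-aligned:

    fn sketch() {
        let phys = PhysicalAddress::from_raw(0x1000_u64);
        let entry = PageEntry::<L3>::page(phys, PageAttributes::WRITABLE);
        assert!(entry.is_present());
        // page() ORs in PRESENT | USER, so converting the attributes back
        // yields USER_READ | USER_WRITE | NON_GLOBAL
        let _attrs: MapAttributes = entry.attributes().into();
    }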

@@ -0,0 +1,426 @@
//! Helper types for interfacing with x86-64 registers
#![allow(unused)]
macro_rules! impl_read {
($t:ident, $register:ty, $body:expr) => {
impl tock_registers::interfaces::Readable for $t {
type T = u64;
type R = $register;
#[inline]
fn get(&self) -> u64 {
$body
}
}
};
}
macro_rules! impl_write {
($t:ident, $register:ty, $value:ident, $body:expr) => {
impl tock_registers::interfaces::Writeable for $t {
type T = u64;
type R = $register;
#[inline]
fn set(&self, $value: u64) {
$body
}
}
};
}
macro_rules! msr_impl_read {
($t:ident, $addr:expr, $register:ty) => {
impl_read!($t, $register, {
let (high, low): (u32, u32);
unsafe {
core::arch::asm!(
"rdmsr",
in("ecx") $addr,
out("eax") low,
out("edx") high,
options(att_syntax)
);
}
((high as u64) << 32) | (low as u64)
});
};
($t:ident, $addr:expr) => { msr_impl_read!($t, $addr, ()); };
}
macro_rules! msr_impl_write {
($t:ident, $addr:expr, $register:ty) => {
impl_write!($t, $register, value, {
let low = value as u32;
let high = (value >> 32) as u32;
unsafe {
core::arch::asm!(
"wrmsr",
in("ecx") $addr,
in("eax") low,
in("edx") high,
options(att_syntax)
);
}
});
};
($t:ident, $addr:expr) => { msr_impl_write!($t, $addr, ()); };
}
macro_rules! cr_impl_read {
($t:ident, $cr:ident, $register:ty) => {
impl_read!($t, $register, {
let value: u64;
unsafe {
core::arch::asm!(
concat!("mov %", stringify!($cr), ", {}"),
out(reg) value,
options(att_syntax)
);
}
value
});
};
}
macro_rules! cr_impl_write {
($t:ident, $cr:ident, $register:ty) => {
impl_write!($t, $register, value, {
unsafe {
core::arch::asm!(
concat!("mov {}, %", stringify!($cr)),
in(reg) value,
options(att_syntax)
);
}
});
};
}
mod msr_ia32_kernel_gs_base {
const ADDR: u32 = 0xC0000102;
pub struct Reg;
msr_impl_read!(Reg, ADDR);
msr_impl_write!(Reg, ADDR);
/// IA32_KERNEL_GS_BASE model-specific register. Provides the base address for %gs-relative
/// loads/stores.
pub const MSR_IA32_KERNEL_GS_BASE: Reg = Reg;
}
mod msr_ia32_apic_base {
use tock_registers::{interfaces::Readable, register_bitfields};
register_bitfields! {
u64,
#[allow(missing_docs)]
#[doc = "IA32_APIC_BASE model-specific register"]
pub MSR_IA32_APIC_BASE [
#[doc = "Contains a virtual page number of the Local APIC base address for this processor"]
AddressPage OFFSET(12) NUMBITS(40) [],
#[doc = "If set, the APIC is enabled"]
ApicEnable OFFSET(11) NUMBITS(1) [],
#[doc = "If set, x2APIC mode is enabled"]
ExtendedEnable OFFSET(10) NUMBITS(1) [],
#[doc = "If set, this CPU is a bootstrap processor"]
BootstrapCpuCore OFFSET(8) NUMBITS(1) [],
]
}
const ADDR: u32 = 0x0000001B;
pub struct Reg;
msr_impl_read!(Reg, ADDR, MSR_IA32_APIC_BASE::Register);
msr_impl_write!(Reg, ADDR, MSR_IA32_APIC_BASE::Register);
impl Reg {
#[inline]
pub fn read_base(&self) -> u64 {
self.read(MSR_IA32_APIC_BASE::AddressPage) << 12
}
}
/// IA32_APIC_BASE model-specific register
pub const MSR_IA32_APIC_BASE: Reg = Reg;
}
mod msr_ia32_sfmask {
use tock_registers::register_bitfields;
register_bitfields! {
u64,
#[allow(missing_docs)]
pub MSR_IA32_SFMASK [
IF OFFSET(9) NUMBITS(1) [
Masked = 1,
Unmasked = 0
]
]
}
const ADDR: u32 = 0xC0000084;
pub struct Reg;
msr_impl_read!(Reg, ADDR, MSR_IA32_SFMASK::Register);
msr_impl_write!(Reg, ADDR, MSR_IA32_SFMASK::Register);
/// IA32_SFMASK model-specific register
pub const MSR_IA32_SFMASK: Reg = Reg;
}
mod msr_ia32_star {
use tock_registers::register_bitfields;
register_bitfields! {
u64,
#[allow(missing_docs)]
pub MSR_IA32_STAR [
SYSCALL_CS_SS OFFSET(32) NUMBITS(16) [],
SYSRET_CS_SS OFFSET(48) NUMBITS(16) [],
]
}
const ADDR: u32 = 0xC0000081;
pub struct Reg;
msr_impl_read!(Reg, ADDR, MSR_IA32_STAR::Register);
msr_impl_write!(Reg, ADDR, MSR_IA32_STAR::Register);
/// IA32_STAR model-specific register
pub const MSR_IA32_STAR: Reg = Reg;
}
mod msr_ia32_lstar {
const ADDR: u32 = 0xC0000082;
pub struct Reg;
msr_impl_read!(Reg, ADDR);
msr_impl_write!(Reg, ADDR);
/// IA32_LSTAR model-specific register
pub const MSR_IA32_LSTAR: Reg = Reg;
}
mod msr_ia32_efer {
use tock_registers::register_bitfields;
register_bitfields! {
u64,
#[allow(missing_docs)]
pub MSR_IA32_EFER [
// If set, support for SYSCALL/SYSRET instructions is enabled
SCE OFFSET(0) NUMBITS(1) [
Enable = 1,
Disable = 0
]
]
}
const ADDR: u32 = 0xC0000080;
pub struct Reg;
msr_impl_read!(Reg, ADDR, MSR_IA32_EFER::Register);
msr_impl_write!(Reg, ADDR, MSR_IA32_EFER::Register);
/// IA32_EFER Extended Feature Enable model-specific Register
pub const MSR_IA32_EFER: Reg = Reg;
}
mod cr0 {
use tock_registers::register_bitfields;
register_bitfields! {
u64,
#[allow(missing_docs)]
pub CR0 [
PG OFFSET(31) NUMBITS(1) [],
CD OFFSET(30) NUMBITS(1) [],
NW OFFSET(29) NUMBITS(1) [],
AM OFFSET(18) NUMBITS(1) [],
WP OFFSET(16) NUMBITS(1) [],
NE OFFSET(5) NUMBITS(1) [],
ET OFFSET(4) NUMBITS(1) [],
TS OFFSET(3) NUMBITS(1) [],
EM OFFSET(2) NUMBITS(1) [],
MP OFFSET(1) NUMBITS(1) [],
PE OFFSET(0) NUMBITS(1) [],
]
}
pub struct Reg;
cr_impl_read!(Reg, cr0, CR0::Register);
cr_impl_write!(Reg, cr0, CR0::Register);
/// x86-64 control register 0
pub const CR0: Reg = Reg;
}
mod cr3 {
use tock_registers::{interfaces::ReadWriteable, register_bitfields};
register_bitfields! {
u64,
#[allow(missing_docs)]
pub CR3 [
ADDR OFFSET(12) NUMBITS(40) [],
]
}
pub struct Reg;
cr_impl_read!(Reg, cr3, CR3::Register);
cr_impl_write!(Reg, cr3, CR3::Register);
impl Reg {
pub fn set_address(&self, address: usize) {
assert_eq!(address & 0xFFF, 0);
self.modify(CR3::ADDR.val((address as u64) >> 12))
}
}
/// x86-64 control register 3
pub const CR3: Reg = Reg;
}
mod cr4 {
use tock_registers::register_bitfields;
register_bitfields! {
u64,
#[allow(missing_docs)]
pub CR4 [
/// If set, XSAVE and extended processor states are enabled
OSXSAVE OFFSET(18) NUMBITS(1) [],
/// Indicates OS support for FXSAVE and FXRSTOR instructions
OSFXSR OFFSET(9) NUMBITS(1) [],
/// Performance-Monitoring Counter enable
PCE OFFSET(8) NUMBITS(1) [],
/// If set, "page global" attribute is enabled
PGE OFFSET(7) NUMBITS(1) [],
/// Machine Check enable
MCE OFFSET(6) NUMBITS(1) [],
/// Physical Address Extension (enabled if 64-bit mode)
PAE OFFSET(5) NUMBITS(1) [],
/// Page Size Extension (should be enabled by yboot)
PSE OFFSET(4) NUMBITS(1) [],
/// Debugging extensions
DE OFFSET(3) NUMBITS(1) [],
TSD OFFSET(2) NUMBITS(1) [],
PVI OFFSET(1) NUMBITS(1) [],
VME OFFSET(0) NUMBITS(1) [],
]
}
pub struct Reg;
cr_impl_read!(Reg, cr4, CR4::Register);
cr_impl_write!(Reg, cr4, CR4::Register);
/// x86-64 control register 4
pub const CR4: Reg = Reg;
}
mod xcr0 {
use tock_registers::{
interfaces::{Readable, Writeable},
register_bitfields,
};
register_bitfields! {
u64,
#[allow(missing_docs)]
pub XCR0 [
/// If set, x87 FPU/MMX is enabled
X87 OFFSET(0) NUMBITS(1) [],
/// If set, XSAVE support for MXCSR and XMM registers is enabled
SSE OFFSET(1) NUMBITS(1) [],
/// If set, AVX is enabled and XSAVE supports YMM upper halves
AVX OFFSET(2) NUMBITS(1) [],
]
}
pub struct Reg;
impl Readable for Reg {
type T = u64;
type R = XCR0::Register;
fn get(&self) -> Self::T {
let eax: u32;
let edx: u32;
unsafe {
core::arch::asm!(
"xgetbv",
in("ecx") 0,
out("eax") eax,
out("edx") edx,
options(att_syntax)
);
}
((edx as u64) << 32) | (eax as u64)
}
}
impl Writeable for Reg {
type T = u64;
type R = XCR0::Register;
fn set(&self, value: Self::T) {
let eax = value as u32;
let edx = (value >> 32) as u32;
unsafe {
core::arch::asm!(
"xsetbv",
in("ecx") 0,
in("eax") eax,
in("edx") edx,
options(att_syntax)
);
}
}
}
/// Extended control register for SSE/AVX/FPU configuration
pub const XCR0: Reg = Reg;
}
use core::ptr::NonNull;
pub use cr0::CR0;
pub use cr3::CR3;
pub use cr4::CR4;
pub use msr_ia32_apic_base::MSR_IA32_APIC_BASE;
pub use msr_ia32_efer::MSR_IA32_EFER;
pub use msr_ia32_kernel_gs_base::MSR_IA32_KERNEL_GS_BASE;
pub use msr_ia32_lstar::MSR_IA32_LSTAR;
pub use msr_ia32_sfmask::MSR_IA32_SFMASK;
pub use msr_ia32_star::MSR_IA32_STAR;
pub use xcr0::XCR0;
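// Illustrative sketch (not part of the original sources): the conventional
// order for enabling SSE state management with the registers defined above.
// CR4.OSFXSR/OSXSAVE must be set before XCR0 is programmed via XSETBV.
#[allow(dead_code)]
fn example_enable_sse_state() {
    use tock_registers::interfaces::ReadWriteable;
    CR4.modify(CR4::OSFXSR::SET + CR4::OSXSAVE::SET);
    XCR0.modify(XCR0::X87::SET + XCR0::SSE::SET);
}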
#[repr(C, align(0x10))]
pub struct FpuContext {
data: [u8; 512],
}
impl FpuContext {
pub fn new() -> Self {
let mut value = Self { data: [0; 512] };
unsafe {
let ptr = value.data.as_mut_ptr();
core::arch::asm!("fninit; fxsave64 ({})", in(reg) ptr, options(att_syntax));
}
value
}
pub unsafe fn save(dst: *mut FpuContext) {
core::arch::asm!("fxsave64 ({})", in(reg) dst, options(att_syntax));
}
pub unsafe fn restore(src: *mut FpuContext) {
core::arch::asm!("fxrstor64 ({})", in(reg) src, options(att_syntax));
}
}
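// Illustrative sketch (not part of the original sources): how the save/restore
// pair above might be used when switching between two threads' FPU states.
// `prev` and `next` are hypothetical pointers owned by the scheduler.
#[allow(dead_code)]
unsafe fn example_fpu_switch(prev: *mut FpuContext, next: *mut FpuContext) {
    // Store the outgoing context with FXSAVE64, then load the incoming one
    // with FXRSTOR64.
    FpuContext::save(prev);
    FpuContext::restore(next);
}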

74
kernel/build.rs Normal file
View File

@ -0,0 +1,74 @@
use std::{
env, fs,
io::{self, Write},
path::{Path, PathBuf},
process::Command,
};
use abi_generator::{
abi::{ty::TypeWidth, AbiBuilder},
syntax::UnwrapFancy,
TargetEnv,
};
fn build_x86_64() {
const DEFAULT_8086_AS: &str = "nasm";
const AP_BOOTSTRAP_S: &str = "src/arch/x86_64/boot/ap_boot.S";
println!("cargo:rerun-if-changed={}", AP_BOOTSTRAP_S);
let out_dir = env::var("OUT_DIR").unwrap();
let assembler = env::var("AS8086").unwrap_or(DEFAULT_8086_AS.to_owned());
let ap_bootstrap_out = PathBuf::from(out_dir).join("__x86_64_ap_boot.bin");
// Assemble the code
let output = Command::new(assembler.as_str())
.args([
"-fbin",
"-o",
ap_bootstrap_out.to_str().unwrap(),
AP_BOOTSTRAP_S,
])
.output()
.unwrap();
if !output.status.success() {
io::stderr().write_all(&output.stderr).ok();
panic!("{}: could not assemble {}", assembler, AP_BOOTSTRAP_S);
}
}
fn generate_syscall_dispatcher<P: AsRef<Path>>(out_dir: P) {
let abi: AbiBuilder = AbiBuilder::from_string(
yggdrasil_abi_def::ABI_FILE,
TargetEnv {
thin_pointer_width: TypeWidth::U64,
fat_pointer_width: TypeWidth::U128,
},
)
.unwrap_fancy("");
let generated_dispatcher = out_dir.as_ref().join("generated_dispatcher.rs");
let file = prettyplease::unparse(
&abi.emit_syscall_dispatcher("handle_syscall", "impls")
.unwrap_fancy(""),
);
fs::write(generated_dispatcher, file.as_bytes()).unwrap();
}
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
generate_syscall_dispatcher(&out_dir);
println!("cargo:rerun-if-changed=build.rs");
match arch.as_str() {
"x86_64" => build_x86_64(),
"aarch64" => (),
_ => panic!("Unknown target arch: {:?}", arch),
}
}
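// Illustrative note (not part of the original sources): code generated into
// OUT_DIR by `generate_syscall_dispatcher` would typically be pulled into the
// kernel sources with something like:
//
//     include!(concat!(env!("OUT_DIR"), "/generated_dispatcher.rs"));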

26
kernel/driver/block/ahci/Cargo.toml Normal file
View File

@ -0,0 +1,26 @@
[package]
name = "ygg_driver_ahci"
version = "0.1.0"
edition = "2021"
authors = ["Mark Poliakov <mark@alnyan.me>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
libk-mm = { path = "../../../libk/libk-mm" }
libk-thread = { path = "../../../libk/libk-thread" }
libk-util = { path = "../../../libk/libk-util" }
device-api = { path = "../../../lib/device-api", features = ["derive"] }
vfs = { path = "../../../lib/vfs" }
ygg_driver_pci = { path = "../../bus/pci" }
ygg_driver_block = { path = "../../block/core" }
kernel-fs = { path = "../../fs/kernel-fs" }
log = "0.4.20"
futures-util = { version = "0.3.28", default-features = false, features = ["alloc", "async-await"] }
static_assertions = "1.1.0"
tock-registers = "0.8.1"
bytemuck = { version = "1.14.0", features = ["derive"] }
memoffset = "0.9.0"

140
kernel/driver/block/ahci/src/command.rs Normal file
View File

@ -0,0 +1,140 @@
use core::mem::{size_of, MaybeUninit};
use libk_mm::{
address::{AsPhysicalAddress, PhysicalAddress},
PageBox,
};
use tock_registers::register_structs;
use crate::{data::AtaString, error::AhciError, SECTOR_SIZE};
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum AtaCommandId {
Identify = 0xEC,
ReadDmaEx = 0x25,
}
pub trait AtaCommand {
type Response;
const COMMAND_ID: AtaCommandId;
fn lba(&self) -> u64;
fn sector_count(&self) -> usize;
fn regions(&self) -> &[(PhysicalAddress, usize)];
unsafe fn into_response(self) -> Self::Response;
}
register_structs! {
    // Offsets in the ATA8-ACS spec are in words, so each value taken from there is
// multiplied by two
pub AtaIdentifyResponse {
(0 => _0),
(20 => pub serial_number: AtaString<20>),
(40 => _1),
(54 => pub model_number: AtaString<40>),
(94 => _2),
(98 => pub capabilities: [u16; 2]),
(102 => _3),
(120 => pub logical_sector_count_28: [u8; 4]),
(124 => _4),
(138 => pub additional_features: u16),
(140 => _5),
(164 => pub command_sets: [u16; 6]),
(176 => _6),
(200 => pub logical_sector_count_qw: u64),
(208 => _7),
(212 => pub phys_logical_sector_size: u16),
(214 => _8),
(234 => pub logical_sector_size: [u16; 2]),
(238 => _9),
(460 => pub ext_logical_sector_count_qw: [u8; 8]),
(468 => _10),
(512 => @END),
}
}
pub struct AtaIdentify {
buffer: PageBox<MaybeUninit<AtaIdentifyResponse>>,
regions: [(PhysicalAddress, usize); 1],
}
pub struct AtaReadDmaEx {
lba: u64,
sector_count: usize,
regions: [(PhysicalAddress, usize); 1],
}
impl AtaIdentify {
pub fn create() -> Result<Self, AhciError> {
PageBox::new_uninit()
.map(Self::with_data)
.map_err(AhciError::MemoryError)
}
pub fn with_data(buffer: PageBox<MaybeUninit<AtaIdentifyResponse>>) -> Self {
Self {
regions: [(
unsafe { buffer.as_physical_address() },
size_of::<AtaIdentifyResponse>(),
)],
buffer,
}
}
}
impl AtaReadDmaEx {
pub fn new(lba: u64, sector_count: usize, buffer: &PageBox<[MaybeUninit<u8>]>) -> Self {
assert_eq!(buffer.len() % SECTOR_SIZE, 0);
assert_ne!(buffer.len(), 0);
Self {
lba,
sector_count,
regions: [(unsafe { buffer.as_physical_address() }, buffer.len())],
}
}
}
impl AtaCommand for AtaIdentify {
type Response = PageBox<AtaIdentifyResponse>;
const COMMAND_ID: AtaCommandId = AtaCommandId::Identify;
fn lba(&self) -> u64 {
0
}
fn sector_count(&self) -> usize {
0
}
fn regions(&self) -> &[(PhysicalAddress, usize)] {
&self.regions
}
unsafe fn into_response(self) -> Self::Response {
self.buffer.assume_init()
}
}
impl AtaCommand for AtaReadDmaEx {
type Response = ();
const COMMAND_ID: AtaCommandId = AtaCommandId::ReadDmaEx;
fn lba(&self) -> u64 {
self.lba
}
fn sector_count(&self) -> usize {
self.sector_count
}
fn regions(&self) -> &[(PhysicalAddress, usize)] {
&self.regions
}
unsafe fn into_response(self) -> Self::Response {}
}

252
kernel/driver/block/ahci/src/data.rs Normal file
View File

@ -0,0 +1,252 @@
use core::mem::size_of;
use alloc::string::String;
use bytemuck::{Pod, Zeroable};
use libk_mm::address::{IntoRaw, PhysicalAddress};
use libk_util::{ConstAssert, IsTrue};
use static_assertions::const_assert_eq;
use crate::{
command::{AtaCommand, AtaIdentify, AtaIdentifyResponse},
error::AhciError,
MAX_PRD_SIZE,
};
pub const COMMAND_LIST_LENGTH: usize = 32;
const AHCI_FIS_REG_H2D_COMMAND: u8 = 1 << 7;
const AHCI_FIS_REG_H2D: u8 = 0x27;
#[repr(C)]
pub struct AtaString<const N: usize>
where
ConstAssert<{ N % 2 == 0 }>: IsTrue,
{
data: [u8; N],
}
#[derive(Debug, Clone, Copy, Zeroable)]
#[repr(C)]
pub struct PhysicalRegionDescriptor {
buffer_address: u64,
_0: u32,
dbc: u32,
}
#[derive(Debug, Clone, Copy, Zeroable)]
#[repr(C)]
pub struct CommandListEntry {
attr: u16,
prdtl: u16,
prdbc: u32,
ctba: u64,
_0: [u32; 4],
}
#[derive(Clone, Copy, Zeroable)]
#[repr(C)]
pub union SentFis {
reg_h2d: RegisterHostToDeviceFis,
raw: RawFis,
}
#[derive(Clone, Copy, Zeroable)]
#[repr(C)]
pub struct RegisterHostToDeviceFis {
pub ty: u8,
pub cmd_port: u8,
pub cmd: u8,
pub feature_low: u8,
pub lba0: u8,
pub lba1: u8,
pub lba2: u8,
pub device: u8,
pub lba3: u8,
pub lba4: u8,
pub lba5: u8,
pub feature_high: u8,
pub count: u16,
pub icc: u8,
pub control: u8,
_0: u32,
}
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct RawFis {
pub bytes: [u8; 64],
}
#[derive(Clone, Copy)]
#[repr(C)]
pub struct CommandTable {
fis: SentFis, // 0x00..0x40
_0: [u8; 16], // 0x40..0x50
_1: [u8; 48], // 0x50..0x80
prdt: [PhysicalRegionDescriptor; 248], // 0x80..0x1000
}
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct ReceivedFis {
_dsfis: [u8; 0x1C], // 0x00..0x1C
_0: [u8; 0x04], // 0x1C..0x20
_psfis: [u8; 0x14], // 0x20..0x34
_1: [u8; 0x0C], // 0x34..0x40
_rfis: [u8; 0x14], // 0x40..0x54
_2: [u8; 0x04], // 0x54..0x58
_sdbfis: [u8; 0x08], // 0x58..0x60
_ufis: [u8; 0x40], // 0x60..0xA0
_3: [u8; 0x60], // 0xA0..0x100
}
const_assert_eq!(size_of::<SentFis>(), 0x40);
const_assert_eq!(size_of::<CommandTable>(), 0x1000);
const_assert_eq!(size_of::<CommandListEntry>(), 32);
const_assert_eq!(size_of::<ReceivedFis>(), 0x100);
impl CommandTable {
pub fn setup_command<C: AtaCommand>(&mut self, command: &C) -> Result<(), AhciError> {
let lba = command.lba();
assert_eq!(lba & !0xFFFFFFFFFF, 0);
let count = command.sector_count().try_into().unwrap();
if C::COMMAND_ID == AtaIdentify::COMMAND_ID {
self.fis = SentFis {
reg_h2d: RegisterHostToDeviceFis {
ty: AHCI_FIS_REG_H2D,
cmd_port: AHCI_FIS_REG_H2D_COMMAND,
cmd: C::COMMAND_ID as _,
..RegisterHostToDeviceFis::zeroed()
},
};
} else {
self.fis = SentFis {
reg_h2d: RegisterHostToDeviceFis {
ty: AHCI_FIS_REG_H2D,
cmd_port: AHCI_FIS_REG_H2D_COMMAND,
cmd: C::COMMAND_ID as _,
device: 1 << 6, // LBA mode
lba0: lba as u8,
lba1: (lba >> 8) as u8,
lba2: (lba >> 16) as u8,
lba3: (lba >> 24) as u8,
lba4: (lba >> 32) as u8,
lba5: (lba >> 40) as u8,
count,
..RegisterHostToDeviceFis::zeroed()
},
};
}
let regions = command.regions();
for (i, &(base, size)) in regions.iter().enumerate() {
let last = i == regions.len() - 1;
self.prdt[i] = PhysicalRegionDescriptor::new(base, size, last)?;
}
Ok(())
}
}
impl CommandListEntry {
pub fn new(command_table_entry: PhysicalAddress, prd_count: usize) -> Result<Self, AhciError> {
if prd_count > 0xFFFF {
todo!()
}
Ok(Self {
// attr = FIS size in dwords
attr: (size_of::<RegisterHostToDeviceFis>() / size_of::<u32>()) as _,
prdtl: prd_count as _,
prdbc: 0,
ctba: command_table_entry.into_raw(),
_0: [0; 4],
})
}
}
unsafe impl Zeroable for CommandTable {
fn zeroed() -> Self {
Self {
fis: SentFis::zeroed(),
_0: [0; 16],
_1: [0; 48],
prdt: [PhysicalRegionDescriptor::zeroed(); 248],
}
}
}
impl PhysicalRegionDescriptor {
pub fn new(
address: PhysicalAddress,
byte_count: usize,
is_last: bool,
) -> Result<Self, AhciError> {
if byte_count >= MAX_PRD_SIZE {
return Err(AhciError::RegionTooLarge);
}
let dbc_mask = (is_last as u32) << 31;
Ok(Self {
buffer_address: address.into_raw(),
_0: 0,
dbc: ((byte_count as u32 - 1) << 1) | 1 | dbc_mask,
})
}
}
impl AtaIdentifyResponse {
pub fn logical_sector_count(&self) -> u64 {
        // If logical_sector_count_28 == 0x0FFFFFFF and logical_sector_count_qw >= 0x0FFFFFFF,
        // the ACCESSIBLE CAPACITY field contains the total number of user-addressable
        // LBAs (see ATA8-ACS, 4.1).
        //
        // Bit 3 in additional_features selects which field holds the count:
        //   0 -> logical_sector_count_qw, max value = 0xFFFFFFFFFFFF (48 bits)
        //   1 -> logical_sector_count_qw is capped at 0xFFFFFFFF (32 bits) and
        //        ext_logical_sector_count_qw holds the maximum addressable LBA
        //        count, up to 0xFFFFFFFFFFFF (48 bits)
if self.command_sets[1] & (1 << 10) != 0 {
// 48-bit supported
if self.additional_features & (1 << 3) != 0 {
// Use ext_logical_sector_count_qw
todo!()
} else {
// Use logical_sector_count_qw
self.logical_sector_count_qw
}
} else {
todo!()
}
}
}
impl<const N: usize> AtaString<N>
where
ConstAssert<{ N % 2 == 0 }>: IsTrue,
{
#[allow(clippy::inherent_to_string)]
pub fn to_string(&self) -> String {
let mut buf = [0; N];
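        // ATA IDENTIFY strings pack two characters per 16-bit word with the
        // bytes swapped; e.g. the raw bytes "eGenir c" decode to "Generic ".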
for i in (0..N).step_by(2) {
buf[i] = self.data[i + 1];
buf[i + 1] = self.data[i];
}
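        // Trim the trailing ASCII-space padding.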
let mut len = 0;
for i in (0..N).rev() {
if buf[i] != b' ' {
len = i + 1;
break;
}
}
String::from(core::str::from_utf8(&buf[..len]).unwrap())
}
}

8
kernel/driver/block/ahci/src/error.rs Normal file
View File

@ -0,0 +1,8 @@
use yggdrasil_abi::error::Error;
#[derive(Debug)]
pub enum AhciError {
MemoryError(Error),
RegionTooLarge,
DeviceError,
}

256
kernel/driver/block/ahci/src/lib.rs Normal file
View File

@ -0,0 +1,256 @@
#![feature(generic_const_exprs, inline_const)]
#![allow(incomplete_features)]
#![no_std]
extern crate alloc;
use alloc::{boxed::Box, format, vec, vec::Vec};
use bytemuck::Zeroable;
use data::ReceivedFis;
use device_api::{
interrupt::{InterruptAffinity, InterruptHandler},
Device,
};
use error::AhciError;
use kernel_fs::devfs;
use libk_mm::{address::AsPhysicalAddress, device::DeviceMemoryIo, PageBox};
use libk_thread::runtime;
use libk_util::{sync::IrqSafeSpinlock, OneTimeInit};
use port::AhciPort;
use regs::{PortRegs, Regs};
use tock_registers::interfaces::{ReadWriteable, Readable, Writeable};
use ygg_driver_block::{probe_partitions, NgBlockDeviceWrapper};
use ygg_driver_pci::{
device::{PciDeviceInfo, PreferredInterruptMode},
PciCommandRegister, PciConfigurationSpace,
};
use yggdrasil_abi::error::Error;
use crate::regs::{Version, CAP, GHC, SSTS};
mod command;
mod data;
mod error;
mod port;
mod regs;
const MAX_PRD_SIZE: usize = 8192;
const MAX_COMMANDS: usize = u32::BITS as usize;
const SECTOR_SIZE: usize = 512;
const MAX_DRIVES: usize = (b'z' - b'a') as usize;
pub struct AhciController {
regs: IrqSafeSpinlock<DeviceMemoryIo<'static, Regs>>,
ports: OneTimeInit<Vec<&'static AhciPort>>,
received_fis_buffers: OneTimeInit<[Option<PageBox<ReceivedFis>>; 16]>,
version: Version,
max_port_count: usize,
ahci_only: bool,
has_64_bit: bool,
}
impl AhciController {
async fn late_init(&'static self) -> Result<(), AhciError> {
log::info!("Initializing AHCI SATA Controller {:?}", self.version);
let regs = self.regs.lock();
regs.GHC.modify(GHC::HR::SET);
while regs.GHC.matches_all(GHC::HR::SET) {
core::hint::spin_loop();
}
if !self.ahci_only {
regs.GHC.modify(GHC::AE::SET);
}
let pi = regs.PI.get();
let mut ports = vec![];
drop(regs);
let mut fis_buffers = [const { None }; 16];
// Allocate FIS receive buffers for the ports
for i in 0..self.max_port_count {
if pi & (1 << i) == 0 {
continue;
}
let regs = self.regs.lock();
let port = &regs.PORTS[i];
let buffer = PageBox::new(ReceivedFis::zeroed()).map_err(AhciError::MemoryError)?;
port.set_received_fis_address_64(unsafe { buffer.as_physical_address() });
fis_buffers[i] = Some(buffer);
}
self.received_fis_buffers.init(fis_buffers);
for i in 0..self.max_port_count {
if pi & (1 << i) == 0 {
continue;
}
let regs = self.regs.lock();
let port = &regs.PORTS[i];
if !port.SSTS.matches_all(SSTS::DET::Online + SSTS::IPM::Active) {
continue;
}
port.start()?;
// TODO wait here
let sig = port.SIG.get();
if sig != PortRegs::SIG_SATA {
log::warn!("Skipping unknown port {} with signature {:#x}", i, sig);
continue;
}
let port = unsafe { regs.extract(|regs| &regs.PORTS[i]) };
drop(regs);
let port = match AhciPort::create(port, self, i) {
Ok(port) => port,
Err(error) => {
log::warn!("Port {} init error: {:?}", i, error);
continue;
}
};
ports.push(port);
}
let ports = self.ports.init(ports);
// Enable global HC interrupts
self.regs.lock().GHC.modify(GHC::IE::SET);
// Setup the detected ports
for (i, &port) in ports.iter().enumerate() {
log::info!("Init port {}", i);
port.init().await?;
}
// Dump info about the drives
for (i, &port) in ports.iter().enumerate() {
let info = port.info().unwrap();
log::info!(
"Port {}: model={:?}, serial={:?}, lba_count={}",
i,
info.model,
info.serial,
info.lba_count
);
}
{
let mut lock = SATA_DRIVES.lock();
for &port in ports.iter() {
let n = lock.len();
if n >= MAX_DRIVES {
todo!("Too many drives, ran out of letters");
}
let n = n as u8;
lock.push(port);
let name = format!("sd{}", (n + b'a') as char);
let blk = NgBlockDeviceWrapper::new(port);
devfs::add_named_block_device(blk, name.clone()).ok();
probe_partitions(blk, move |index, partition| {
devfs::add_block_device_partition(name.clone(), index, partition)
})
.ok();
}
}
log::debug!("All ports initialized");
Ok(())
}
}
impl InterruptHandler for AhciController {
fn handle_irq(&self, _vector: Option<usize>) -> bool {
let regs = self.regs.lock();
let is = regs.IS.get();
if is != 0 {
if let Some(ports) = self.ports.try_get() {
// Clear global interrupt status
regs.IS.set(u32::MAX);
for &port in ports {
if is & (1 << port.index) != 0 {
port.handle_pending_interrupts();
}
}
}
}
false
}
}
impl Device for AhciController {
unsafe fn init(&'static self) -> Result<(), Error> {
// Do the init in background
runtime::spawn(self.late_init())?;
Ok(())
}
fn display_name(&self) -> &'static str {
"AHCI SATA Controller"
}
}
static SATA_DRIVES: IrqSafeSpinlock<Vec<&'static AhciPort>> = IrqSafeSpinlock::new(Vec::new());
pub fn probe(info: &PciDeviceInfo) -> Result<&'static dyn Device, Error> {
let bar5 = info.config_space.bar(5).ok_or(Error::InvalidOperation)?;
let bar5 = bar5.as_memory().ok_or(Error::InvalidOperation)?;
let mut cmd = PciCommandRegister::from_bits_retain(info.config_space.command());
cmd &= !(PciCommandRegister::DISABLE_INTERRUPTS | PciCommandRegister::ENABLE_IO);
cmd |= PciCommandRegister::ENABLE_MEMORY | PciCommandRegister::BUS_MASTER;
info.config_space.set_command(cmd.bits());
info.init_interrupts(PreferredInterruptMode::Msi)?;
// // TODO support regular PCI interrupts (ACPI dependency)
// let Some(mut msi) = info.config_space.capability::<MsiCapability>() else {
// log::warn!("Ignoring AHCI: does not support MSI (and the OS doesn't yet support PCI IRQ)");
// return Err(Error::InvalidOperation);
// };
// Map the registers
let regs = unsafe { DeviceMemoryIo::<Regs>::map(bar5, Default::default()) }?;
let version = Version::try_from(regs.VS.get())?;
let ahci_only = regs.CAP.matches_all(CAP::SAM::SET);
let max_port_count = regs.CAP.read(CAP::NP) as usize;
let has_64_bit = regs.CAP.matches_all(CAP::S64A::SET);
// TODO extract Number of Command Slots
let ahci = Box::leak(Box::new(AhciController {
regs: IrqSafeSpinlock::new(regs),
ports: OneTimeInit::new(),
received_fis_buffers: OneTimeInit::new(),
version,
max_port_count,
ahci_only,
has_64_bit,
}));
// TODO use multiple vectors if capable
info.map_interrupt(InterruptAffinity::Any, ahci)?;
Ok(ahci)
}

401
kernel/driver/block/ahci/src/port.rs Normal file
View File

@ -0,0 +1,401 @@
use core::{
pin::Pin,
sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll},
};
use alloc::{boxed::Box, string::String};
use bytemuck::Zeroable;
use futures_util::{task::AtomicWaker, Future};
use libk_mm::{address::AsPhysicalAddress, device::DeviceMemoryIo, PageBox};
use libk_util::{sync::IrqSafeSpinlock, waker::QueueWaker, OneTimeInit};
use tock_registers::interfaces::{Readable, Writeable};
use ygg_driver_block::{IoOperation, IoRequest, IoSubmissionId, NgBlockDevice};
use yggdrasil_abi::error::Error;
use crate::{
command::{AtaCommand, AtaIdentify, AtaReadDmaEx},
data::{CommandListEntry, CommandTable, ReceivedFis, COMMAND_LIST_LENGTH},
error::AhciError,
regs::{CommandState, CommandStatus, PortRegs, IE, TFD},
AhciController, MAX_COMMANDS, SECTOR_SIZE,
};
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum PortType {
Sata,
}
struct PortInner {
regs: DeviceMemoryIo<'static, PortRegs>,
#[allow(unused)]
received_fis: PageBox<ReceivedFis>,
command_list: PageBox<[CommandListEntry]>,
}
pub struct PortInfo {
pub model: String,
pub serial: String,
pub lba_count: u64,
}
#[allow(unused)]
pub struct AhciPort {
inner: IrqSafeSpinlock<PortInner>,
ahci: &'static AhciController,
ty: PortType,
pub(crate) index: usize,
info: OneTimeInit<PortInfo>,
command_allocation: IrqSafeSpinlock<u32>,
// One command index can only be waited for by one task, so this approach is usable
command_completion: [(AtomicWaker, AtomicU32); COMMAND_LIST_LENGTH],
command_available: QueueWaker,
}
impl PortInner {
fn submit_command<C: AtaCommand>(
&mut self,
index: usize,
command: &C,
) -> Result<(), AhciError> {
let list_entry = &mut self.command_list[index];
let mut table_entry =
PageBox::new(CommandTable::zeroed()).map_err(AhciError::MemoryError)?;
table_entry.setup_command(command)?;
*list_entry = CommandListEntry::new(
unsafe { table_entry.as_physical_address() },
command.regions().len(),
)?;
// Sync before send
// XXX do this properly
#[cfg(target_arch = "x86_64")]
unsafe {
core::arch::asm!("wbinvd");
}
// TODO deal with this async way
while self.regs.TFD.matches_any(TFD::BSY::SET + TFD::DRQ::SET) {
core::hint::spin_loop();
}
let ci = self.regs.CI.get();
assert_eq!(ci & (1 << index), 0);
self.regs.CI.set(ci | (1 << index));
Ok(())
}
}
impl AhciPort {
pub fn create(
regs: DeviceMemoryIo<'static, PortRegs>,
ahci: &'static AhciController,
index: usize,
) -> Result<&'static Self, AhciError> {
log::debug!("Initialize port {}", index);
regs.stop()?;
if !ahci.has_64_bit {
todo!("Handle controllers incapable of 64 bit");
}
let received_fis = PageBox::new(ReceivedFis::zeroed()).map_err(AhciError::MemoryError)?;
let command_list = PageBox::new_slice(CommandListEntry::zeroed(), COMMAND_LIST_LENGTH)
.map_err(AhciError::MemoryError)?;
regs.set_received_fis_address_64(unsafe { received_fis.as_physical_address() });
regs.set_command_list_address_64(unsafe { command_list.as_physical_address() });
regs.IE.write(
IE::DPE::SET
+ IE::IFE::SET
+ IE::OFE::SET
+ IE::HBDE::SET
+ IE::HBFE::SET
+ IE::TFEE::SET
+ IE::DHRE::SET,
);
regs.start()?;
let inner = PortInner {
regs,
command_list,
received_fis,
};
let command_completion = [const { (AtomicWaker::new(), AtomicU32::new(0)) }; MAX_COMMANDS];
let command_available = QueueWaker::new();
let command_allocation = IrqSafeSpinlock::new(0);
Ok(Box::leak(Box::new(Self {
inner: IrqSafeSpinlock::new(inner),
ty: PortType::Sata,
info: OneTimeInit::new(),
ahci,
index,
command_completion,
command_allocation,
command_available,
})))
}
pub async fn init(&'static self) -> Result<(), AhciError> {
let identify = self.perform_command(AtaIdentify::create()?).await?;
let model = identify.model_number.to_string();
let serial = identify.serial_number.to_string();
let lba_count = identify.logical_sector_count();
// TODO can sector size be different from 512 in ATA?
// should logical sector size be accounted for?
// TODO test for ReadDmaEx capability (?)
self.info.init(PortInfo {
model,
serial,
lba_count,
});
Ok(())
}
pub fn info(&self) -> Option<&PortInfo> {
self.info.try_get()
}
async fn perform_command<C: AtaCommand>(&self, command: C) -> Result<C::Response, AhciError> {
let slot = self.allocate_command().await?;
log::trace!(
"Submit command on port {}, cmd index = {}",
self.index,
slot
);
self.inner.lock().submit_command(slot, &command)?;
self.wait_for_completion(slot).await?;
self.free_command(slot);
Ok(unsafe { command.into_response() })
}
fn allocate_command(&self) -> impl Future<Output = Result<usize, AhciError>> + '_ {
struct F<'f> {
waker: &'f QueueWaker,
state: &'f IrqSafeSpinlock<u32>,
}
impl<'f> Future for F<'f> {
type Output = Result<usize, AhciError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.waker.register(cx.waker());
let mut state = self.state.lock();
if *state != u32::MAX {
self.waker.remove(cx.waker());
for i in 0..MAX_COMMANDS {
if *state & (1 << i) == 0 {
*state |= 1 << i;
return Poll::Ready(Ok(i));
}
}
                    unreachable!();
} else {
Poll::Pending
}
}
}
let waker = &self.command_available;
let state = &self.command_allocation;
F { waker, state }
}
fn wait_for_completion(
&self,
index: usize,
) -> impl Future<Output = Result<(), AhciError>> + '_ {
struct F<'f> {
waker: &'f AtomicWaker,
status: &'f AtomicU32,
}
impl<'f> Future for F<'f> {
type Output = Result<(), AhciError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.status.load(Ordering::Acquire) {
0 => (),
1 => return Poll::Ready(Ok(())),
_ => return Poll::Ready(Err(AhciError::DeviceError)),
}
self.waker.register(cx.waker());
match self.status.load(Ordering::Acquire) {
0 => Poll::Pending,
1 => Poll::Ready(Ok(())),
_ => Poll::Ready(Err(AhciError::DeviceError)),
}
}
}
let (waker, status) = &self.command_completion[index];
F { status, waker }
}
fn free_command(&self, index: usize) {
{
let mut alloc = self.command_allocation.lock();
assert_ne!(*alloc & (1 << index), 0);
*alloc &= !(1 << index);
}
self.command_available.wake_one();
}
pub fn handle_pending_interrupts(&self) -> bool {
let inner = self.inner.lock();
for i in 0..MAX_COMMANDS {
match inner.regs.clear_state(i) {
CommandState::Pending => (),
CommandState::Ready(status) => {
// TODO better error handling?
let val = match status {
CommandStatus::Success => 1,
_ => 2,
};
self.command_completion[i].1.store(val, Ordering::Release);
self.command_completion[i].0.wake();
}
}
}
true
}
}
impl NgBlockDevice for AhciPort {
type CompletionNotify = AtomicWaker;
fn bus_id(&self) -> u32 {
0
}
fn unit_id(&self) -> u32 {
self.index as u32
}
fn block_size(&self) -> u64 {
SECTOR_SIZE as _
}
fn block_count(&self) -> u64 {
self.info.get().lba_count
}
fn max_blocks_per_request(&self) -> u64 {
// TODO
1
}
async fn submit_request(&self, request: IoRequest<'_>) -> Result<IoSubmissionId, Error> {
// TODO better error handling
let slot = self.allocate_command().await.unwrap();
log::trace!(
"Submit command on port {}, cmd index = {}",
self.index,
slot
);
match request.operation {
IoOperation::Read { lba, count } => {
self.inner
.lock()
.submit_command(slot, &AtaReadDmaEx::new(lba, count, request.data))
.unwrap();
}
IoOperation::Write { .. } => todo!(),
}
Ok(IoSubmissionId {
queue_id: self.index,
command_id: slot,
})
}
fn poll_completion(&self, id: IoSubmissionId) -> Poll<Result<(), Error>> {
let (_, status) = &self.command_completion[id.command_id];
match status.load(Ordering::Acquire) {
0 => Poll::Pending,
1 => {
self.free_command(id.command_id);
log::debug!("COMMAND FINISHED");
Poll::Ready(Ok(()))
}
_ => todo!(), // Poll::Ready(Err(AhciError::DeviceError)),
}
}
fn completion_notify(&self, id: IoSubmissionId) -> &Self::CompletionNotify {
let (notify, _) = &self.command_completion[id.command_id];
notify
}
}
// impl BlockDevice for AhciPort {
// fn read(&'static self, mut pos: u64, buf: &mut [u8]) -> Result<usize, Error> {
// let info = self.info.try_get().ok_or(Error::PermissionDenied)?;
//
// let mut cache = self.cache.lock();
// let mut rem = buf.len();
// let mut off = 0;
//
// while rem != 0 {
// let lba = pos / SECTOR_SIZE as u64;
//
// if lba >= info.lba_count {
// break;
// }
//
// let block_offset = (pos % SECTOR_SIZE as u64) as usize;
// let count = core::cmp::min(SECTOR_SIZE - block_offset, rem);
//
// let block = cache.get_or_fetch_with(lba, |block| {
// block! {
// self.read_block(lba, block).await
// }?
// .map_err(|_| Error::InvalidOperation)
// })?;
//
// buf[off..off + count].copy_from_slice(&block[block_offset..block_offset + count]);
//
// rem -= count;
// off += count;
// pos += count as u64;
// }
//
// Ok(off)
// }
//
// fn write(&'static self, _pos: u64, _buf: &[u8]) -> Result<usize, Error> {
// todo!()
// }
//
// fn size(&self) -> Result<u64, Error> {
// let info = self.info.try_get().ok_or(Error::PermissionDenied)?;
// Ok(info.lba_count * SECTOR_SIZE as u64)
// }
//
// fn device_request(&self, _req: &mut DeviceRequest) -> Result<(), Error> {
// todo!()
// }
// }

203
kernel/driver/block/ahci/src/regs.rs Normal file
View File

@ -0,0 +1,203 @@
use libk_mm::address::{IntoRaw, PhysicalAddress};
use tock_registers::{
interfaces::{ReadWriteable, Readable, Writeable},
register_bitfields, register_structs,
registers::{ReadOnly, ReadWrite},
};
use yggdrasil_abi::error::Error;
use crate::error::AhciError;
register_bitfields! {
u32,
pub CAP [
NP OFFSET(0) NUMBITS(5) [],
NCS OFFSET(8) NUMBITS(5) [],
SAM OFFSET(18) NUMBITS(1) [],
S64A OFFSET(31) NUMBITS(1) [],
],
pub GHC [
HR OFFSET(0) NUMBITS(1) [],
IE OFFSET(1) NUMBITS(1) [],
AE OFFSET(31) NUMBITS(1) [],
],
// Read/write 1 to clear
pub IS [
TFES OFFSET(30) NUMBITS(1) [],
HBFS OFFSET(29) NUMBITS(1) [],
HBDS OFFSET(28) NUMBITS(1) [],
IFS OFFSET(27) NUMBITS(1) [],
OFS OFFSET(24) NUMBITS(1) [],
],
pub IE [
TFEE OFFSET(30) NUMBITS(1) [],
HBFE OFFSET(29) NUMBITS(1) [],
HBDE OFFSET(28) NUMBITS(1) [],
IFE OFFSET(27) NUMBITS(1) [],
OFE OFFSET(24) NUMBITS(1) [],
DPE OFFSET(5) NUMBITS(1) [],
DHRE OFFSET(0) NUMBITS(1) [],
],
pub CMD [
CR OFFSET(15) NUMBITS(1) [],
FR OFFSET(14) NUMBITS(1) [],
CCS OFFSET(8) NUMBITS(5) [],
FRE OFFSET(4) NUMBITS(1) [],
POD OFFSET(2) NUMBITS(1) [],
ST OFFSET(0) NUMBITS(1) [],
],
pub SSTS [
IPM OFFSET(8) NUMBITS(4) [
NotPresent = 0,
Active = 1,
],
DET OFFSET(0) NUMBITS(4) [
NotPresent = 0,
Online = 3,
],
],
pub TFD [
BSY OFFSET(7) NUMBITS(1) [],
DRQ OFFSET(3) NUMBITS(1) [],
ERR OFFSET(0) NUMBITS(1) [],
]
}
register_structs! {
#[allow(non_snake_case)]
pub Regs {
(0x0000 => pub CAP: ReadOnly<u32, CAP::Register>),
(0x0004 => pub GHC: ReadWrite<u32, GHC::Register>),
(0x0008 => pub IS: ReadWrite<u32>),
(0x000C => pub PI: ReadOnly<u32>),
(0x0010 => pub VS: ReadOnly<u32>),
(0x0014 => _0),
(0x0100 => pub PORTS: [PortRegs; 30]),
(0x1000 => @END),
}
}
register_structs! {
#[allow(non_snake_case)]
pub PortRegs {
(0x00 => pub CLB: ReadWrite<u32>),
(0x04 => pub CLBU: ReadWrite<u32>),
(0x08 => pub FB: ReadWrite<u32>),
(0x0C => pub FBU: ReadWrite<u32>),
(0x10 => pub IS: ReadWrite<u32, IS::Register>),
(0x14 => pub IE: ReadWrite<u32, IE::Register>),
(0x18 => pub CMD: ReadWrite<u32, CMD::Register>),
(0x1C => _0),
(0x20 => pub TFD: ReadWrite<u32, TFD::Register>),
(0x24 => pub SIG: ReadOnly<u32>),
(0x28 => pub SSTS: ReadOnly<u32, SSTS::Register>),
(0x2C => pub SCTL: ReadOnly<u32>),
(0x30 => pub SERR: ReadOnly<u32>),
(0x34 => pub SACT: ReadOnly<u32>),
(0x38 => pub CI: ReadWrite<u32>),
(0x3C => pub SNTF: ReadOnly<u32>),
(0x40 => _1),
(0x80 => @END),
}
}
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum Version {
V0_95,
V1_0,
V1_1,
V1_2,
V1_3,
V1_3_1,
}
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum CommandStatus {
Success,
TaskFileError,
}
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum CommandState {
Pending,
Ready(CommandStatus),
}
impl PortRegs {
pub const SIG_SATA: u32 = 0x101;
// NOTE: usually doesn't take long, so not async, I guess
pub fn stop(&self) -> Result<(), AhciError> {
self.CMD.modify(CMD::ST::CLEAR + CMD::FRE::CLEAR);
// TODO timeout here
while self.CMD.matches_any(CMD::FR::SET + CMD::CR::SET) {
core::hint::spin_loop();
}
Ok(())
}
pub fn start(&self) -> Result<(), AhciError> {
while self.CMD.matches_all(CMD::CR::SET) {
core::hint::spin_loop();
}
self.CMD.modify(CMD::ST::SET + CMD::FRE::SET);
Ok(())
}
pub fn set_received_fis_address_64(&self, address: PhysicalAddress) {
let address: u64 = address.into_raw();
self.FB.set(address as u32);
self.FBU.set((address >> 32) as u32);
}
pub fn set_command_list_address_64(&self, address: PhysicalAddress) {
let address: u64 = address.into_raw();
self.CLB.set(address as u32);
self.CLBU.set((address >> 32) as u32);
}
pub fn clear_state(&self, index: usize) -> CommandState {
let is = self.IS.extract();
let ci = self.CI.get();
if is.get() == 0 {
return CommandState::Pending;
}
// Clear everything
self.IS.set(0xFFFFFFFF);
if is.matches_any(IS::HBDS::SET + IS::HBFS::SET) {
todo!("Host communication error unhandled");
}
assert_eq!(ci & (1 << index), 0);
if is.matches_any(IS::TFES::SET + IS::IFS::SET + IS::OFS::SET) {
return CommandState::Ready(CommandStatus::TaskFileError);
}
CommandState::Ready(CommandStatus::Success)
}
}
impl TryFrom<u32> for Version {
type Error = Error;
fn try_from(value: u32) -> Result<Self, Self::Error> {
match value {
0x00000905 => Ok(Self::V0_95),
0x00010000 => Ok(Self::V1_0),
0x00010100 => Ok(Self::V1_1),
0x00010200 => Ok(Self::V1_2),
0x00010300 => Ok(Self::V1_3),
0x00010301 => Ok(Self::V1_3_1),
_ => Err(Error::InvalidArgument),
}
}
}

18
kernel/driver/block/core/Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "ygg_driver_block"
version = "0.1.0"
edition = "2021"
authors = ["Mark Poliakov <mark@alnyan.me>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
libk-util = { path = "../../../libk/libk-util" }
libk-mm = { path = "../../../libk/libk-mm" }
log = "0.4.20"
futures-util = { version = "0.3.28", default-features = false, features = ["alloc", "async-await"] }
bytemuck = { version = "1.14.0", features = ["derive"] }
static_assertions = "1.1.0"
uuid = { version = "1.6.1", default-features = false, features = ["bytemuck"] }

379
kernel/driver/block/core/src/device.rs Normal file
View File

@ -0,0 +1,379 @@
#![allow(unused)]
use core::{
ops::Range,
pin::Pin,
task::{Context, Poll},
};
use alloc::boxed::Box;
use futures_util::{task::AtomicWaker, Future};
use libk_mm::{address::PhysicalAddress, table::MapAttributes, PageBox, PageProvider};
use libk_util::waker::QueueWaker;
use yggdrasil_abi::{error::Error, io::DeviceRequest};
use crate::{
request::{IoOperation, IoRequest, IoSubmissionId},
BlockDevice,
};
pub trait CompletionNotify {
fn wait_for_completion<'a, D: NgBlockDevice + 'a>(
&'a self,
device: &'a D,
id: IoSubmissionId,
) -> impl Future<Output = Result<(), Error>> + Send + '_;
}
pub trait NgBlockDevice: Sync {
type CompletionNotify: CompletionNotify;
fn bus_id(&self) -> u32; // HBA, controller ID, etc.
fn unit_id(&self) -> u32; // Drive, slot, connector ID, etc.
fn block_size(&self) -> u64;
fn block_count(&self) -> u64;
fn max_blocks_per_request(&self) -> u64;
fn submit_request(
&self,
request: IoRequest,
) -> impl Future<Output = Result<IoSubmissionId, Error>> + Send;
fn poll_completion(&self, id: IoSubmissionId) -> Poll<Result<(), Error>>;
fn completion_notify(&self, id: IoSubmissionId) -> &Self::CompletionNotify;
fn wait_for_completion(
&self,
id: IoSubmissionId,
) -> impl Future<Output = Result<(), Error>> + Send + '_
where
Self: Sized,
{
self.completion_notify(id).wait_for_completion(self, id)
}
}
pub struct NgBlockDeviceWrapper<'a, D: NgBlockDevice + 'a> {
device: &'a D,
pub(crate) block_size: u64,
pub(crate) block_count: u64,
#[allow(unused)]
max_blocks_per_request: u64,
}
#[derive(Debug, PartialEq)]
struct BlockChunk {
lba_start: u64,
lba_count: usize,
buffer_offset: usize,
lba_offset: usize,
byte_count: usize,
}
struct BlockChunkIter {
remaining: usize,
buffer_offset: usize,
position: u64,
block_size: u64,
max_blocks_per_request: u64,
}
impl CompletionNotify for QueueWaker {
fn wait_for_completion<'a, D: NgBlockDevice + 'a>(
&'a self,
device: &'a D,
id: IoSubmissionId,
) -> impl Future<Output = Result<(), Error>> + Send + '_ {
struct F<'f, D: NgBlockDevice + 'f> {
device: &'f D,
notify: &'f QueueWaker,
id: IoSubmissionId,
}
impl<'f, D: NgBlockDevice + 'f> Future for F<'f, D> {
type Output = Result<(), Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.notify.register(cx.waker());
match self.device.poll_completion(self.id) {
Poll::Ready(result) => {
self.notify.remove(cx.waker());
Poll::Ready(result)
}
Poll::Pending => Poll::Pending,
}
}
}
F {
notify: self,
device,
id,
}
}
}
impl CompletionNotify for AtomicWaker {
fn wait_for_completion<'a, D: NgBlockDevice + 'a>(
&'a self,
device: &'a D,
id: IoSubmissionId,
) -> impl Future<Output = Result<(), Error>> + Send + '_ {
struct F<'f, D: NgBlockDevice + 'f> {
device: &'f D,
notify: &'f AtomicWaker,
id: IoSubmissionId,
}
impl<'f, D: NgBlockDevice + 'f> Future for F<'f, D> {
type Output = Result<(), Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Poll::Ready(result) = self.device.poll_completion(self.id) {
return Poll::Ready(result);
}
self.notify.register(cx.waker());
self.device.poll_completion(self.id)
}
}
F {
notify: self,
device,
id,
}
}
}
impl BlockChunk {
pub fn block_range(&self) -> Range<usize> {
self.lba_offset..self.lba_offset + self.byte_count
}
pub fn buffer_range(&self) -> Range<usize> {
self.buffer_offset..self.buffer_offset + self.byte_count
}
}
impl BlockChunkIter {
pub fn new(pos: u64, count: usize, lba_size: u64, max_lba_per_request: u64) -> Self {
Self {
remaining: count,
buffer_offset: 0,
position: pos,
block_size: lba_size,
max_blocks_per_request: max_lba_per_request,
}
}
}
impl Iterator for BlockChunkIter {
type Item = BlockChunk;
fn next(&mut self) -> Option<Self::Item> {
if self.remaining == 0 {
return None;
}
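        // Cover [position, position + remaining) with whole blocks: round the
        // end up to a block boundary, then cap the chunk at
        // max_blocks_per_request blocks.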
let lba_start = self.position / self.block_size;
let lba_end =
(self.position + self.remaining as u64 + self.block_size - 1) / self.block_size;
let lba_count = core::cmp::min(lba_end - lba_start, self.max_blocks_per_request);
let lba_offset = (self.position % self.block_size) as usize;
let byte_count = core::cmp::min(
(lba_count * self.block_size) as usize - lba_offset,
self.remaining,
);
let buffer_offset = self.buffer_offset;
self.position += byte_count as u64;
self.buffer_offset += byte_count;
self.remaining -= byte_count;
Some(BlockChunk {
lba_start,
lba_count: lba_count as usize,
buffer_offset,
lba_offset,
byte_count,
})
}
}
impl<'a, D: NgBlockDevice + 'a> NgBlockDeviceWrapper<'a, D> {
pub fn new(device: &'a D) -> &'a Self {
let block_size = device.block_size();
let block_count = device.block_count();
let max_blocks_per_request = device.max_blocks_per_request();
Box::leak(Box::new(Self {
device,
block_size,
block_count,
max_blocks_per_request,
}))
}
async fn read_range_inner(&self, lba: u64, count: usize) -> Result<PageBox<[u8]>, Error> {
let mut data = PageBox::new_uninit_slice(self.block_size as usize * count)?;
let id = self
.device
.submit_request(IoRequest {
operation: IoOperation::Read { lba, count },
data: &mut data,
})
.await?;
self.device.wait_for_completion(id).await?;
Ok(unsafe { data.assume_init_slice() })
}
}
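// Illustrative sketch (not part of the original sources): fetching the first
// block of a device through the helper above, in some async context.
#[allow(unused)]
async fn example_read_first_block<'a, D: NgBlockDevice + 'a>(
    dev: &NgBlockDeviceWrapper<'a, D>,
) -> Result<PageBox<[u8]>, Error> {
    // Submits one read of one block at LBA 0 and waits for its completion.
    dev.read_range_inner(0, 1).await
}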
impl<'a, D: NgBlockDevice + 'a> PageProvider for NgBlockDeviceWrapper<'a, D> {
fn get_page(&self, _offset: u64) -> Result<PhysicalAddress, Error> {
todo!()
}
fn release_page(&self, _offset: u64, _phys: PhysicalAddress) -> Result<(), Error> {
todo!()
}
fn clone_page(
&self,
_offset: u64,
_src_phys: PhysicalAddress,
_src_attrs: MapAttributes,
) -> Result<PhysicalAddress, Error> {
todo!()
}
}
impl<'a, D: NgBlockDevice + 'a> BlockDevice for NgBlockDeviceWrapper<'a, D> {
fn poll_read(
&self,
cx: &mut Context<'_>,
pos: u64,
buf: &mut [u8],
) -> Poll<Result<usize, Error>> {
todo!()
}
fn poll_write(&self, cx: &mut Context<'_>, pos: u64, buf: &[u8]) -> Poll<Result<usize, Error>> {
todo!()
}
// fn read(&'static self, pos: u64, buf: &mut [u8]) -> Result<usize, Error> {
// // TODO block cache
// block! {
// let mut bytes_read = 0;
// for chunk in
// BlockChunkIter::new(pos, buf.len(), self.block_size, self.max_blocks_per_request)
// {
// log::debug!(
// "Read chunk: lba_start={}, lba_count={}",
// chunk.lba_start,
// chunk.lba_count
// );
// let block = self.read_range_inner(chunk.lba_start, chunk.lba_count).await?;
// buf[chunk.buffer_range()].copy_from_slice(&block[chunk.block_range()]);
// bytes_read += chunk.byte_count;
// }
// Ok(bytes_read)
// }?
// }
// fn write(&'static self, _pos: u64, _buf: &[u8]) -> Result<usize, Error> {
// todo!()
// }
fn size(&self) -> Result<u64, Error> {
Ok(self.block_size * self.block_count)
}
fn device_request(&self, _req: &mut DeviceRequest) -> Result<(), Error> {
todo!()
}
}
#[cfg(test)]
mod tests {
use crate::device::BlockChunk;
use super::BlockChunkIter;
#[test]
fn block_chunk_iter() {
let mut it = BlockChunkIter {
remaining: 512 * 9 + 1,
position: 123,
block_size: 512,
buffer_offset: 0,
max_blocks_per_request: 2,
};
assert_eq!(
it.next().unwrap(),
BlockChunk {
lba_start: 0,
lba_count: 2,
buffer_offset: 0,
lba_offset: 123,
byte_count: 901
}
);
assert_eq!(
it.next().unwrap(),
BlockChunk {
lba_start: 2,
lba_count: 2,
buffer_offset: 1024 - 123,
lba_offset: 0,
byte_count: 1024
}
);
assert_eq!(
it.next().unwrap(),
BlockChunk {
lba_start: 4,
lba_count: 2,
buffer_offset: 2 * 1024 - 123,
lba_offset: 0,
byte_count: 1024
}
);
assert_eq!(
it.next().unwrap(),
BlockChunk {
lba_start: 6,
lba_count: 2,
buffer_offset: 3 * 1024 - 123,
lba_offset: 0,
byte_count: 1024
}
);
assert_eq!(
it.next().unwrap(),
BlockChunk {
lba_start: 8,
lba_count: 2,
buffer_offset: 4 * 1024 - 123,
lba_offset: 0,
byte_count: 512 + 123 + 1
}
);
}
}

110
kernel/driver/block/core/src/lib.rs Normal file
View File

@ -0,0 +1,110 @@
#![no_std]
extern crate alloc;
use core::task::{Context, Poll};
use libk_mm::PageProvider;
use yggdrasil_abi::{error::Error, io::DeviceRequest};
pub mod device;
// mod partition;
pub mod request;
pub use device::{NgBlockDevice, NgBlockDeviceWrapper};
pub use request::{IoOperation, IoRequest, IoSubmissionId};
// TODO
pub fn probe_partitions<
D: NgBlockDevice + 'static,
F: Fn(usize, &'static dyn BlockDevice) -> Result<(), Error> + Send + 'static,
>(
_dev: &'static NgBlockDeviceWrapper<D>,
_callback: F,
) -> Result<(), Error> {
Ok(())
// async fn probe_table<D: NgBlockDevice + 'static>(
// dev: &'static NgBlockDeviceWrapper<'static, D>,
// ) -> Result<Option<Vec<Partition<'static, D>>>, Error> {
// if let Some(partitions) = partition::probe_gpt(dev)? {
// return Ok(Some(partitions));
// }
// Ok(None)
// }
// runtime::spawn(async move {
// match probe_table(dev).await {
// Ok(Some(partitions)) => {
// // Create block devices for the partitions
// for (i, partition) in partitions.into_iter().enumerate() {
// let partition_blkdev = Box::leak(Box::new(partition));
// if let Err(error) = callback(i, partition_blkdev) {
// log::warn!("Could not add partition {}: {:?}", i, error);
// }
// }
// }
// Ok(None) => {
// log::warn!("Unknown or missing partition table");
// }
// Err(error) => {
// log::warn!("Could not probe partition table: {:?}", error);
// }
// }
// })
}
/// Block device interface
#[allow(unused)]
pub trait BlockDevice: PageProvider + Sync {
fn poll_read(
&self,
cx: &mut Context<'_>,
pos: u64,
buf: &mut [u8],
) -> Poll<Result<usize, Error>> {
Poll::Ready(Err(Error::NotImplemented))
}
fn poll_write(&self, cx: &mut Context<'_>, pos: u64, buf: &[u8]) -> Poll<Result<usize, Error>> {
Poll::Ready(Err(Error::NotImplemented))
}
    // /// Reads data from the given offset of the device
// fn read(&'static self, pos: u64, buf: &mut [u8]) -> Result<usize, Error> {
// Err(Error::NotImplemented)
// }
// /// Writes the data to the given offset of the device
// fn write(&'static self, pos: u64, buf: &[u8]) -> Result<usize, Error> {
// Err(Error::NotImplemented)
// }
/// Returns the size of the block device in bytes
fn size(&self) -> Result<u64, Error> {
Err(Error::NotImplemented)
}
/// Returns `true` if the device can be read from
fn is_readable(&self) -> bool {
true
}
/// Returns `true` if the device can be written to
fn is_writable(&self) -> bool {
true
}
/// Performs a device-specific function
fn device_request(&self, req: &mut DeviceRequest) -> Result<(), Error> {
Err(Error::NotImplemented)
}
// fn read_exact(&'static self, pos: u64, buf: &mut [u8]) -> Result<(), Error> {
// let count = self.read(pos, buf)?;
// if count == buf.len() {
// Ok(())
// } else {
// Err(Error::MissingData)
// }
// }
}

137
kernel/driver/block/core/src/partition.rs Normal file
View File

@ -0,0 +1,137 @@
use core::mem::{size_of, MaybeUninit};
use alloc::{vec, vec::Vec};
use bytemuck::{Pod, Zeroable};
use libk::mem::PageBox;
use static_assertions::const_assert_eq;
use uuid::Uuid;
use yggdrasil_abi::{error::Error, io::DeviceRequest};
use crate::{BlockDevice, NgBlockDevice, NgBlockDeviceWrapper};
pub struct Partition<'a, D: NgBlockDevice + 'a> {
pub device: &'a NgBlockDeviceWrapper<'a, D>,
pub lba_start: u64,
pub lba_end: u64,
}
#[derive(Clone, Copy)]
#[repr(C)]
struct GptHeader {
signature: [u8; 8],
revision: u32,
header_size: u32,
crc32: u32,
_0: u32,
header_lba: u64,
alternate_header_lba: u64,
first_usable_lba: u64,
last_usable_lba: u64,
guid: [u8; 16],
partition_table_lba: u64,
partition_table_len: u32,
partition_table_entry_size: u32,
partition_table_crc32: u32,
_1: [u8; 420],
}
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
struct GptEntry {
type_guid: Uuid,
part_guid: Uuid,
lba_start: u64,
lba_end: u64,
attrs: u64,
}
const_assert_eq!(size_of::<GptHeader>(), 512);
impl<'a, D: NgBlockDevice + 'a> Partition<'a, D> {
fn end_byte(&self) -> u64 {
self.lba_end * self.device.block_size
}
fn start_byte(&self) -> u64 {
self.lba_start * self.device.block_size
}
}
impl<'a, D: NgBlockDevice + 'a> BlockDevice for Partition<'a, D> {
fn read(&'static self, pos: u64, buf: &mut [u8]) -> Result<usize, Error> {
if pos >= self.end_byte() {
return Ok(0);
}
let start = self.start_byte() + pos;
let end = core::cmp::min(start + buf.len() as u64, self.end_byte());
let count = (end - start) as usize;
self.device.read(start, &mut buf[..count])
}
fn write(&'static self, pos: u64, buf: &[u8]) -> Result<usize, Error> {
if pos >= self.end_byte() {
return Ok(0);
}
let start = self.start_byte() + pos;
let end = core::cmp::min(start + buf.len() as u64, self.end_byte());
let count = (end - start) as usize;
self.device.write(start, &buf[..count])
}
fn size(&self) -> Result<u64, Error> {
Ok((self.lba_end - self.lba_start) * self.device.block_size)
}
fn device_request(&self, req: &mut DeviceRequest) -> Result<(), Error> {
self.device.device_request(req)
}
}
unsafe fn read_struct_lba<T>(dev: &'static dyn BlockDevice, lba: u64) -> Result<T, Error> {
assert_eq!(size_of::<T>(), 512);
let mut data = MaybeUninit::<T>::uninit();
let buffer = core::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, 512);
dev.read_exact(lba * 512, buffer)?;
Ok(data.assume_init())
}
pub(crate) fn probe_gpt<D: NgBlockDevice + 'static>(
dev: &'static NgBlockDeviceWrapper<'static, D>,
) -> Result<Option<Vec<Partition<'static, D>>>, Error> {
let header = unsafe { read_struct_lba::<GptHeader>(dev, 1) }?;
if &header.signature != b"EFI PART" {
// Not a GPT partition table
return Ok(None);
}
let pt_entsize = header.partition_table_entry_size as usize;
let pt_len = header.partition_table_len as usize;
let mut pt_data = PageBox::new_slice(0, pt_len * pt_entsize)?;
assert!(size_of::<GptEntry>() <= pt_entsize);
dev.read_exact(header.partition_table_lba * 512, &mut pt_data)?;
let mut partitions = vec![];
for i in 0..pt_len {
let pt_entry_data = &pt_data[i * pt_entsize..i * pt_entsize + size_of::<GptEntry>()];
let pt_entry: &GptEntry = bytemuck::from_bytes(pt_entry_data);
if pt_entry.type_guid.is_nil() {
continue;
}
partitions.push(Partition {
device: dev,
lba_start: pt_entry.lba_start,
lba_end: pt_entry.lba_end,
});
}
Ok(Some(partitions))
}

19
kernel/driver/block/core/src/request.rs Normal file
View File

@ -0,0 +1,19 @@
use core::mem::MaybeUninit;
use libk_mm::PageBox;
pub enum IoOperation {
Read { lba: u64, count: usize },
Write { lba: u64, count: usize },
}
pub struct IoRequest<'a> {
pub operation: IoOperation,
pub data: &'a mut PageBox<[MaybeUninit<u8>]>,
}
#[derive(Clone, Copy, Debug)]
pub struct IoSubmissionId {
pub queue_id: usize,
pub command_id: usize,
}
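// Illustrative sketch (not part of the original sources): describing a read
// of `count` blocks starting at `lba` into a caller-provided buffer.
#[allow(unused)]
fn example_read_request<'a>(
    lba: u64,
    count: usize,
    data: &'a mut PageBox<[MaybeUninit<u8>]>,
) -> IoRequest<'a> {
    IoRequest {
        operation: IoOperation::Read { lba, count },
        data,
    }
}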

25
kernel/driver/block/nvme/Cargo.toml Normal file
View File

@ -0,0 +1,25 @@
[package]
name = "ygg_driver_nvme"
version = "0.1.0"
edition = "2021"
authors = ["Mark Poliakov <mark@alnyan.me>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
libk-util = { path = "../../../libk/libk-util" }
libk-thread = { path = "../../../libk/libk-thread" }
libk-mm = { path = "../../../libk/libk-mm" }
device-api = { path = "../../../lib/device-api", features = ["derive"] }
vfs = { path = "../../../lib/vfs" }
ygg_driver_pci = { path = "../../bus/pci" }
ygg_driver_block = { path = "../../block/core" }
kernel-fs = { path = "../../fs/kernel-fs" }
log = "0.4.20"
futures-util = { version = "0.3.28", default-features = false, features = ["alloc", "async-await"] }
static_assertions = "1.1.0"
tock-registers = "0.8.1"
bytemuck = { version = "1.14.0", features = ["derive"] }

269
kernel/driver/block/nvme/src/command.rs Normal file
View File

@ -0,0 +1,269 @@
#![allow(unused)]
use core::fmt::{self, Write};
use libk_mm::address::PhysicalAddress;
use tock_registers::{interfaces::Readable, register_structs, registers::ReadOnly, UIntLike};
use crate::queue::PhysicalRegionPage;
use super::queue::SubmissionQueueEntry;
pub trait Command {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry);
}
pub trait Request: Command {
type Response;
}
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct String<const N: usize> {
data: [u8; N],
}
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[repr(u8)]
pub enum ControllerType {
Reserved,
Io,
Discovery,
Administrative,
}
// I/O commands
#[derive(Clone, Copy, Debug)]
pub struct IoRead {
pub nsid: u32,
pub lba: u64,
pub count: u32,
}
#[derive(Clone, Copy, Debug)]
pub struct IoWrite {
pub nsid: u32,
pub lba: u64,
pub count: u32,
}
// Requests
#[derive(Clone, Copy, Debug)]
pub enum SetFeatureRequest {
NumberOfQueues(u32, u32),
}
#[derive(Clone, Copy, Debug)]
pub struct IdentifyControllerRequest;
#[derive(Clone, Copy, Debug)]
pub struct IdentifyActiveNamespaceIdListRequest {
pub start_id: u32,
}
#[derive(Clone, Copy, Debug)]
pub struct IdentifyNamespaceRequest {
pub nsid: u32,
}
#[derive(Clone, Copy, Debug)]
pub struct CreateIoCompletionQueue {
pub id: u32,
pub size: usize,
pub vector: u32,
pub data: PhysicalAddress,
}
#[derive(Clone, Copy, Debug)]
pub struct CreateIoSubmissionQueue {
pub id: u32,
pub cq_id: u32,
pub size: usize,
pub data: PhysicalAddress,
}
// Replies
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct IdentifyControllerResponse {
pub pci_vid: u16,
pub pci_ssvid: u16,
pub serial_number: String<20>,
pub model_number: String<40>,
pub firmware_rev: u64,
_0: [u8; 5], // 72..77
pub mdts: u8,
pub cntlid: u16,
pub ver: u32,
_1: [u8; 12], // 84..96
pub ctratt: u32,
_2: [u8; 11], // 100..111
pub cntrltype: ControllerType,
}
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct IdentifyActiveNamespaceIdListResponse {
pub entries: [u32; 1024],
}
register_structs! {
#[allow(non_snake_case)]
pub IdentifyNamespaceResponse {
(0 => NSZE: ReadOnly<u64>),
(8 => _0),
(25 => NLBAF: ReadOnly<u8>),
(26 => FLBAS: ReadOnly<u8>),
(27 => _1),
(128 => LBAFS: [ReadOnly<u32>; 64]),
(384 => _2),
(4096 => @END),
}
}
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct LbaFormat(u32);
impl Command for IdentifyControllerRequest {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
sqe.command.set_opcode(0x06);
sqe.command_specific[0] = 0x01;
}
}
impl Request for IdentifyControllerRequest {
type Response = IdentifyControllerResponse;
}
impl Command for IdentifyActiveNamespaceIdListRequest {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
sqe.command.set_opcode(0x06);
sqe.command_specific[0] = 0x02;
sqe.nsid = self.start_id;
}
}
impl Request for IdentifyActiveNamespaceIdListRequest {
type Response = IdentifyActiveNamespaceIdListResponse;
}
impl Command for IdentifyNamespaceRequest {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
sqe.command.set_opcode(0x06);
sqe.command_specific[0] = 0x00;
sqe.nsid = self.nsid;
}
}
impl Request for IdentifyNamespaceRequest {
type Response = IdentifyNamespaceResponse;
}
impl IdentifyNamespaceResponse {
pub fn current_lba_fmt_idx(&self) -> usize {
let flbas = self.FLBAS.get();
let mut index = flbas & 0xF;
if self.NLBAF.get() > 16 {
index |= (flbas & 0xE0) >> 1;
}
index as usize
}
pub fn lba_fmt(&self, idx: usize) -> Option<LbaFormat> {
if idx > self.NLBAF.get() as usize {
return None;
}
Some(LbaFormat(self.LBAFS[idx].get()))
}
pub fn total_lba_count(&self) -> u64 {
self.NSZE.get()
}
}
impl LbaFormat {
pub fn lba_data_size(&self) -> Option<u64> {
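        // LBADS is a power-of-two exponent: 9 means 512-byte and 12 means
        // 4096-byte logical blocks; values below 9 are invalid.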
let lbads = (self.0 >> 16) & 0xFF;
if lbads < 9 {
return None;
}
Some(1 << lbads)
}
}
impl Command for SetFeatureRequest {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
sqe.command.set_opcode(0x09);
match self {
Self::NumberOfQueues(cq, sq) => {
let dw11 = (cq << 16) | sq;
sqe.command_specific[0] = 0x07;
sqe.command_specific[1] = dw11;
}
}
}
}
impl Command for CreateIoCompletionQueue {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
sqe.command.set_opcode(0x05);
sqe.data_pointer[0] = PhysicalRegionPage::with_addr(self.data);
sqe.command_specific[0] = ((self.size as u32 - 1) << 16) | self.id;
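        // Interrupt vector in the high half; bit 0 = physically contiguous,
        // bit 1 = interrupts enabled.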
sqe.command_specific[1] = (self.vector << 16) | 3;
}
}
impl Command for CreateIoSubmissionQueue {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
sqe.command.set_opcode(0x01);
sqe.data_pointer[0] = PhysicalRegionPage::with_addr(self.data);
sqe.command_specific[0] = ((self.size as u32 - 1) << 16) | self.id;
        // CQ ID in the high half; bit 0 = physically contiguous (the queue
        // priority bits are left at zero)
sqe.command_specific[1] = (self.cq_id << 16) | 1;
}
}
impl<const N: usize> fmt::Debug for String<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_char('"')?;
for ch in self.data {
if ch == b' ' || ch == 0 {
break;
}
f.write_char(ch as _)?;
}
f.write_char('"')?;
Ok(())
}
}
impl Command for IoRead {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
assert!(self.count < 65536);
sqe.command.set_opcode(0x02);
sqe.command_specific[0] = self.lba as u32;
sqe.command_specific[1] = (self.lba >> 32) as u32;
sqe.command_specific[2] = self.count;
sqe.nsid = self.nsid;
}
}
impl Command for IoWrite {
fn fill_sqe(&self, sqe: &mut SubmissionQueueEntry) {
assert!(self.count < 65536);
sqe.command.set_opcode(0x01);
sqe.command_specific[0] = self.lba as u32;
sqe.command_specific[1] = (self.lba >> 32) as u32;
sqe.command_specific[2] = self.count;
sqe.nsid = self.nsid;
}
}

128
kernel/driver/block/nvme/src/drive.rs Normal file
View File

@ -0,0 +1,128 @@
use core::task::Poll;
use alloc::{boxed::Box, format};
use kernel_fs::devfs;
use libk_mm::address::AsPhysicalAddress;
use libk_thread::cpu_index;
use libk_util::waker::QueueWaker;
use ygg_driver_block::{
probe_partitions, IoOperation, IoRequest, IoSubmissionId, NgBlockDevice, NgBlockDeviceWrapper,
};
use yggdrasil_abi::error::Error;
use crate::command::{IdentifyNamespaceRequest, IoRead};
use super::{error::NvmeError, NvmeController};
#[allow(unused)]
pub struct NvmeDrive {
controller: &'static NvmeController,
nsid: u32,
total_lba_count: u64,
lba_size: u64,
}
impl NvmeDrive {
pub async fn create(
controller: &'static NvmeController,
nsid: u32,
) -> Result<&'static NvmeDrive, NvmeError> {
let admin_q = controller.admin_q.get();
let identify = admin_q.request(IdentifyNamespaceRequest { nsid }).await?;
let current_lba_format_idx = identify.current_lba_fmt_idx();
let current_lba_format = identify.lba_fmt(current_lba_format_idx).unwrap();
let lba_size = current_lba_format.lba_data_size().unwrap();
let total_lba_count = identify.total_lba_count();
log::debug!(
"ns = {}, lba = {}B, size = {}M",
nsid,
lba_size,
(total_lba_count * lba_size) / (1024 * 1024)
);
let dev = Box::leak(Box::new(NvmeDrive {
controller,
nsid,
total_lba_count,
lba_size,
}));
let node_name = format!("nvme{}n{}", controller.controller_id.get(), nsid);
let blk = NgBlockDeviceWrapper::new(dev);
devfs::add_named_block_device(blk, node_name.clone()).ok();
probe_partitions(blk, move |index, partition| {
devfs::add_block_device_partition(format!("{}p", node_name), index, partition)
})
.ok();
Ok(dev)
}
}
impl NgBlockDevice for NvmeDrive {
type CompletionNotify = QueueWaker;
fn bus_id(&self) -> u32 {
(*self.controller.controller_id.get()) as _
}
fn unit_id(&self) -> u32 {
self.nsid
}
fn block_size(&self) -> u64 {
self.lba_size
}
fn block_count(&self) -> u64 {
self.total_lba_count
}
fn max_blocks_per_request(&self) -> u64 {
// TODO get from identify
8
}
async fn submit_request(&self, request: IoRequest<'_>) -> Result<IoSubmissionId, Error> {
let queue_id = cpu_index();
let ioq = &self.controller.ioqs.get()[queue_id as usize];
let command_id = match request.operation {
IoOperation::Read { lba, count } => {
log::debug!(
"Submit read of {} lbas from ns {} to queue {}",
count,
self.nsid,
queue_id
);
let range = unsafe { request.data.as_physical_address() };
ioq.submit(
IoRead {
lba,
count: count as _,
nsid: self.nsid,
},
&[range],
true,
)
}
IoOperation::Write { .. } => todo!(),
};
Ok(IoSubmissionId {
queue_id: queue_id as _,
command_id: command_id as _,
})
}
fn poll_completion(&self, id: IoSubmissionId) -> Poll<Result<(), Error>> {
let ioq = &self.controller.ioqs.get()[id.queue_id];
ioq.poll_completion(id.command_id as _)
}
fn completion_notify(&self, id: IoSubmissionId) -> &QueueWaker {
&self.controller.ioqs.get()[id.queue_id].completion_notify
}
}

15
kernel/driver/block/nvme/src/error.rs Normal file
View File

@ -0,0 +1,15 @@
use yggdrasil_abi::error::Error;
use super::queue::CommandError;
#[derive(Debug)]
pub enum NvmeError {
MemoryError(Error),
CommandError(CommandError),
}
impl From<CommandError> for NvmeError {
fn from(value: CommandError) -> Self {
Self::CommandError(value)
}
}

456
kernel/driver/block/nvme/src/lib.rs Normal file
View File

@ -0,0 +1,456 @@
#![feature(strict_provenance, const_trait_impl, let_chains, if_let_guard, effects)]
#![allow(missing_docs)]
#![no_std]
extern crate alloc;
use core::{
mem::size_of,
sync::atomic::{AtomicUsize, Ordering},
time::Duration,
};
use alloc::{boxed::Box, collections::BTreeMap, vec::Vec};
use command::{IdentifyActiveNamespaceIdListRequest, IdentifyControllerRequest};
use device_api::{
interrupt::{InterruptAffinity, InterruptHandler},
Device,
};
use drive::NvmeDrive;
use libk_mm::{
address::{IntoRaw, PhysicalAddress},
device::DeviceMemoryIo,
};
use libk_thread::{cpu_count, cpu_index, runtime};
use libk_util::{
sync::{IrqGuard, IrqSafeSpinlock},
OneTimeInit,
};
use tock_registers::{
interfaces::{ReadWriteable, Readable, Writeable},
register_bitfields, register_structs,
registers::{ReadOnly, ReadWrite, WriteOnly},
};
use ygg_driver_pci::{
device::{PciDeviceInfo, PreferredInterruptMode},
PciCommandRegister, PciConfigurationSpace,
};
use yggdrasil_abi::error::Error;
use crate::{
command::{IoRead, IoWrite},
queue::{CompletionQueueEntry, SubmissionQueueEntry},
};
use self::{
command::{CreateIoCompletionQueue, CreateIoSubmissionQueue, SetFeatureRequest},
error::NvmeError,
queue::QueuePair,
};
mod command;
mod drive;
mod error;
mod queue;
register_bitfields! {
u32,
CC [
IOCQES OFFSET(20) NUMBITS(4) [],
IOSQES OFFSET(16) NUMBITS(4) [],
AMS OFFSET(11) NUMBITS(3) [],
MPS OFFSET(7) NUMBITS(4) [],
CSS OFFSET(4) NUMBITS(3) [
NvmCommandSet = 0
],
ENABLE OFFSET(0) NUMBITS(1) [],
],
CSTS [
CFS OFFSET(1) NUMBITS(1) [],
RDY OFFSET(0) NUMBITS(1) [],
],
AQA [
/// Admin Completion Queue Size in entries - 1
ACQS OFFSET(16) NUMBITS(12) [],
/// Admin Submission Queue Size in entries - 1
ASQS OFFSET(0) NUMBITS(12) [],
]
}
register_bitfields! {
u64,
CAP [
/// Maximum Queue Entries Supported, minus 1: a value of 0 means a maximum queue length of 1, 1 means 2, and so on.
MQES OFFSET(0) NUMBITS(16) [],
/// Timeout. Represents the worst-case time the host software should wait for CSTS.RDY to
/// change its state.
TO OFFSET(24) NUMBITS(8) [],
/// Doorbell stride. Stride in bytes = pow(2, 2 + DSTRD).
DSTRD OFFSET(32) NUMBITS(4) [],
/// NVM Subsystem Reset Supported (see NVMe Base Specification Section 3.7.1)
NSSRS OFFSET(36) NUMBITS(1) [],
/// Controller supports one or more I/O command sets
CSS_IO_COMMANDS OFFSET(43) NUMBITS(1) [],
/// Controller only supports admin commands and no I/O commands
CSS_ADMIN_ONLY OFFSET(44) NUMBITS(1) [],
/// Memory page size minimum (bytes = pow(2, 12 + MPSMIN))
MPSMIN OFFSET(48) NUMBITS(4) [],
/// Memory page size maximum (bytes = pow(2, 12 + MPSMAX))
MPSMAX OFFSET(52) NUMBITS(4) [],
]
}
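// Worked example (illustrative values, not a specific controller): a CAP with
// MQES = 255, TO = 10, DSTRD = 0 and MPSMIN = 0 advertises queues of up to
// 256 entries, a worst-case ready timeout of 10 * 500 ms = 5 s, a doorbell
// stride of pow(2, 2 + 0) = 4 bytes and a minimum page size of
// pow(2, 12 + 0) = 4096 bytes -- the values `Device::init` below relies on.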
register_structs! {
#[allow(non_snake_case)]
Regs {
(0x00 => CAP: ReadOnly<u64, CAP::Register>),
(0x08 => VS: ReadOnly<u32>),
(0x0C => INTMS: WriteOnly<u32>),
(0x10 => INTMC: WriteOnly<u32>),
(0x14 => CC: ReadWrite<u32, CC::Register>),
(0x18 => _0),
(0x1C => CSTS: ReadOnly<u32, CSTS::Register>),
(0x20 => _1),
(0x24 => AQA: ReadWrite<u32, AQA::Register>),
(0x28 => ASQ: ReadWrite<u64>),
(0x30 => ACQ: ReadWrite<u64>),
(0x38 => _2),
(0x2000 => @END),
}
}
pub struct NvmeController {
regs: IrqSafeSpinlock<DeviceMemoryIo<'static, Regs>>,
admin_q: OneTimeInit<QueuePair>,
ioqs: OneTimeInit<Vec<QueuePair>>,
io_queue_count: AtomicUsize,
drive_table: IrqSafeSpinlock<BTreeMap<u32, &'static NvmeDrive>>,
controller_id: OneTimeInit<usize>,
pci: PciDeviceInfo,
doorbell_shift: usize,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum IoDirection {
Read,
Write,
}
impl Regs {
unsafe fn doorbell_ptr(&self, shift: usize, completion: bool, queue_index: usize) -> *mut u32 {
let doorbell_base = (self as *const Regs as *mut Regs).addr() + 0x1000;
let offset = ((queue_index << shift) + completion as usize) * 4;
(doorbell_base + offset) as *mut u32
}
}
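// Worked example for the doorbell math above, assuming shift == 1 (DSTRD = 0):
// queue 0 has its SQ tail doorbell at base + 0x1000 and its CQ head doorbell
// at base + 0x1004; queue 3 gets 0x1000 + ((3 << 1) + 0) * 4 = 0x1018 (SQ)
// and 0x1000 + ((3 << 1) + 1) * 4 = 0x101C (CQ).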
impl NvmeController {
const ADMIN_QUEUE_SIZE: usize = 32;
const IO_QUEUE_SIZE: usize = 32;
async fn create_queues(&'static self) -> Result<(), NvmeError> {
let admin_q = self.admin_q.get();
let io_queue_count = self.io_queue_count.load(Ordering::Acquire);
log::info!(
"Creating {} queue pairs for nvme{}",
io_queue_count,
self.controller_id.get()
);
// Request a CQ/SQ pair for I/O
admin_q
.request_no_data(SetFeatureRequest::NumberOfQueues(
io_queue_count as _,
io_queue_count as _,
))
.await?;
let mut queues = Vec::new();
for i in 1..=io_queue_count {
let id = i as u32;
let (sq_doorbell, cq_doorbell) = unsafe { self.doorbell_pair(i) };
let queue = QueuePair::new(id, i, Self::IO_QUEUE_SIZE, sq_doorbell, cq_doorbell)
.map_err(NvmeError::MemoryError)?;
admin_q
.request_no_data(CreateIoCompletionQueue {
id,
vector: id,
size: Self::IO_QUEUE_SIZE,
data: queue.cq_physical_pointer(),
})
.await?;
admin_q
.request_no_data(CreateIoSubmissionQueue {
id,
cq_id: id,
size: Self::IO_QUEUE_SIZE,
data: queue.sq_physical_pointer(),
})
.await?;
queues.push(queue);
}
self.ioqs.init(queues);
Ok(())
}
async fn late_init(&'static self) -> Result<(), NvmeError> {
let io_queue_count = cpu_count();
self.io_queue_count.store(io_queue_count, Ordering::Release);
{
let range = self
.pci
.map_interrupt_multiple(0..io_queue_count + 1, InterruptAffinity::Any, self)
.unwrap();
// TODO handle different MSI range allocations
for (i, msi) in range.iter().enumerate() {
assert_eq!(i, msi.vector);
}
}
register_nvme_controller(self);
let admin_q = self.admin_q.get();
// Identify the controller
let _identify = admin_q.request(IdentifyControllerRequest).await?;
// TODO do something with identify_controller
self.create_queues().await?;
// Identify namespaces
self.enumerate_namespaces().await?;
Ok(())
}
async fn enumerate_namespaces(&'static self) -> Result<(), NvmeError> {
let admin_q = self.admin_q.get();
let namespaces = admin_q
.request(IdentifyActiveNamespaceIdListRequest { start_id: 0 })
.await?;
let count = namespaces.entries.iter().position(|&x| x == 0).unwrap();
let list = &namespaces.entries[..count];
for &nsid in list {
match NvmeDrive::create(self, nsid).await {
Ok(drive) => {
self.drive_table.lock().insert(nsid, drive);
}
Err(error) => {
log::warn!("Could not create nvme drive, nsid={}: {:?}", nsid, error);
}
}
}
Ok(())
}
pub async fn perform_io(
&'static self,
nsid: u32,
lba: u64,
buffer_address: PhysicalAddress,
direction: IoDirection,
) -> Result<(), NvmeError> {
let _guard = IrqGuard::acquire();
let cpu_index = cpu_index();
let ioq = &self.ioqs.get()[cpu_index as usize];
log::debug!(
"{:?} ioq #{}, nsid={}, lba={:#x}",
direction,
cpu_index,
nsid,
lba
);
let cmd_id = match direction {
IoDirection::Read => ioq.submit(
IoRead {
nsid,
lba,
count: 1,
},
&[buffer_address],
true,
),
IoDirection::Write => ioq.submit(
IoWrite {
nsid,
lba,
count: 1,
},
&[buffer_address],
true,
),
};
ioq.wait_for_completion(cmd_id, ()).await?;
Ok(())
}
unsafe fn doorbell_pair(&self, idx: usize) -> (*mut u32, *mut u32) {
let regs = self.regs.lock();
let sq_ptr = regs.doorbell_ptr(self.doorbell_shift, false, idx);
let cq_ptr = regs.doorbell_ptr(self.doorbell_shift, true, idx);
(sq_ptr, cq_ptr)
}
}
impl InterruptHandler for NvmeController {
fn handle_irq(&self, vector: Option<usize>) -> bool {
let vector = vector.expect("Only MSI-X interrupts are supported");
if vector == 0 {
self.admin_q.get().process_completions() != 0
} else if vector <= self.io_queue_count.load(Ordering::Acquire)
&& let Some(ioqs) = self.ioqs.try_get()
{
ioqs[vector - 1].process_completions() != 0
} else {
false
}
}
}
impl Device for NvmeController {
unsafe fn init(&'static self) -> Result<(), Error> {
let regs = self.regs.lock();
let min_page_size = 1usize << (12 + regs.CAP.read(CAP::MPSMIN));
if min_page_size > 4096 {
panic!("NVMe: controller minimum page size ({}B) is larger than the kernel's 4096B pages", min_page_size);
}
let timeout = Duration::from_millis(regs.CAP.read(CAP::TO) * 500);
log::debug!("Worst-case timeout: {:?}", timeout);
while regs.CSTS.matches_any(CSTS::RDY::SET) {
core::hint::spin_loop();
}
if Self::ADMIN_QUEUE_SIZE as u64 > regs.CAP.read(CAP::MQES) + 1 {
todo!(
"queue_slots too big, max = {}",
regs.CAP.read(CAP::MQES) + 1
);
}
// Setup the admin queue (index 0)
let admin_sq_doorbell = unsafe { regs.doorbell_ptr(self.doorbell_shift, false, 0) };
let admin_cq_doorbell = unsafe { regs.doorbell_ptr(self.doorbell_shift, true, 0) };
log::debug!("sq_doorbell for adminq = {:p}", admin_sq_doorbell);
let admin_q = QueuePair::new(
0,
0,
Self::ADMIN_QUEUE_SIZE,
admin_sq_doorbell,
admin_cq_doorbell,
)
.unwrap();
regs.AQA.modify(
AQA::ASQS.val(Self::ADMIN_QUEUE_SIZE as u32 - 1)
+ AQA::ACQS.val(Self::ADMIN_QUEUE_SIZE as u32 - 1),
);
regs.ASQ.set(admin_q.sq_physical_pointer().into_raw());
regs.ACQ.set(admin_q.cq_physical_pointer().into_raw());
// Configure the controller
const IOSQES: u32 = size_of::<SubmissionQueueEntry>().ilog2();
const IOCQES: u32 = size_of::<CompletionQueueEntry>().ilog2();
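// With the entry layouts in queue.rs these evaluate to IOSQES = 6 (64-byte
// submission entries) and IOCQES = 4 (16-byte completion entries)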
regs.CC.modify(
CC::IOCQES.val(IOCQES)
+ CC::IOSQES.val(IOSQES)
+ CC::MPS.val(0)
+ CC::CSS::NvmCommandSet,
);
// Enable the controller
regs.CC.modify(CC::ENABLE::SET);
log::debug!("Reset the controller");
while !regs.CSTS.matches_any(CSTS::RDY::SET + CSTS::CFS::SET) {
core::hint::spin_loop();
}
if regs.CSTS.matches_any(CSTS::CFS::SET) {
todo!("CFS set after reset!");
}
self.admin_q.init(admin_q);
// Schedule late_init task
runtime::spawn(self.late_init())?;
Ok(())
}
fn display_name(&self) -> &'static str {
"NVM Express Controller"
}
}
static NVME_CONTROLLERS: IrqSafeSpinlock<Vec<&'static NvmeController>> =
IrqSafeSpinlock::new(Vec::new());
pub fn probe(info: &PciDeviceInfo) -> Result<&'static dyn Device, Error> {
let bar0 = info
.config_space
.bar(0)
.unwrap()
.as_memory()
.expect("Expected a memory BAR0");
info.init_interrupts(PreferredInterruptMode::Msi)?;
let mut cmd = PciCommandRegister::from_bits_retain(info.config_space.command());
cmd &= !(PciCommandRegister::DISABLE_INTERRUPTS | PciCommandRegister::ENABLE_IO);
cmd |= PciCommandRegister::ENABLE_MEMORY | PciCommandRegister::BUS_MASTER;
info.config_space.set_command(cmd.bits());
let regs = unsafe { DeviceMemoryIo::<Regs>::map(bar0, Default::default()) }?;
// Disable the controller
regs.CC.modify(CC::ENABLE::CLEAR);
let doorbell_shift = regs.CAP.read(CAP::DSTRD) as usize + 1;
Ok(Box::leak(Box::new(NvmeController {
regs: IrqSafeSpinlock::new(regs),
admin_q: OneTimeInit::new(),
ioqs: OneTimeInit::new(),
drive_table: IrqSafeSpinlock::new(BTreeMap::new()),
controller_id: OneTimeInit::new(),
pci: info.clone(),
io_queue_count: AtomicUsize::new(1),
doorbell_shift,
})))
}
pub fn register_nvme_controller(ctrl: &'static NvmeController) {
let mut list = NVME_CONTROLLERS.lock();
let id = list.len();
list.push(ctrl);
ctrl.controller_id.init(id);
}

View File

@ -0,0 +1,432 @@
use core::{
mem::size_of,
pin::Pin,
ptr::null_mut,
task::{Context, Poll},
};
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use bytemuck::{Pod, Zeroable};
use futures_util::Future;
use libk_mm::{
address::{AsPhysicalAddress, IntoRaw, PhysicalAddress},
PageBox,
};
use libk_util::{sync::IrqSafeSpinlock, waker::QueueWaker};
use static_assertions::const_assert;
use yggdrasil_abi::error::Error;
use super::{
command::{Command, Request},
error::NvmeError,
};
#[derive(Zeroable, Pod, Clone, Copy, Debug)]
#[repr(C)]
pub struct PhysicalRegionPage(u64);
// Bits:
//
// 16..32 - CID. Command identifier
// 14..16 - PSDT. PRP or SGL for data transfer.
// 0b00 - PRP used
// 0b01 - SGL used. Not implemented
// 0b10 - SGL used. Not implemented
// 0b11 - Reserved
// 10..14 - Reserved
// 08..10 - FUSE. Fused Operation
// 00..08 - OPC. Opcode
#[derive(Zeroable, Pod, Clone, Copy, Debug)]
#[repr(C)]
pub struct CommandDword0(u32);
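// Worked example (illustrative, assuming the NVM Read opcode 0x02): a Read
// submitted with command id 0x0012 packs dword 0 as 0x0012_0002 -- the id in
// bits 16..32 and the opcode in bits 0..8, per the layout above.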
#[derive(Zeroable, Pod, Clone, Copy, Debug)]
#[repr(C)]
pub struct SubmissionQueueEntry {
pub command: CommandDword0, // 0
pub nsid: u32, // 1
pub io_data: [u32; 2], // 2, 3
pub metadata_pointer: u64, // 4, 5
pub data_pointer: [PhysicalRegionPage; 2], // 6, 7, 8, 9
pub command_specific: [u32; 6], // 10, 11, 12, 13, 14, 15
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CommandError {
sct: u8,
sc: u8,
}
#[derive(Zeroable, Pod, Clone, Copy, Debug)]
#[repr(C)]
pub struct CompletionQueueEntry {
dw: [u32; 4],
}
pub struct Queue<T> {
data: PageBox<[T]>,
mask: usize,
head: usize,
tail: usize,
phase: bool,
head_doorbell: *mut u32,
tail_doorbell: *mut u32,
}
struct Inner {
sq: Queue<SubmissionQueueEntry>,
cq: Queue<CompletionQueueEntry>,
completed: BTreeMap<u32, CompletionQueueEntry>,
pending: BTreeSet<u32>,
}
pub struct QueuePair {
id: u32,
#[allow(unused)]
vector: usize,
sq_base: PhysicalAddress,
cq_base: PhysicalAddress,
pub completion_notify: QueueWaker,
inner: IrqSafeSpinlock<Inner>,
}
const_assert!(size_of::<CompletionQueueEntry>().is_power_of_two());
impl PhysicalRegionPage {
pub const fn null() -> Self {
Self(0)
}
pub const fn with_addr(address: PhysicalAddress) -> Self {
Self(address.into_raw())
}
}
impl CommandDword0 {
pub fn set_command_id(&mut self, id: u32) {
debug_assert!(id & 0xFFFF0000 == 0);
self.0 &= !(0xFFFF << 16);
self.0 |= id << 16;
}
pub fn set_opcode(&mut self, opcode: u8) {
self.0 &= !0xFF;
self.0 |= opcode as u32;
}
}
impl CompletionQueueEntry {
pub fn phase(&self) -> bool {
self.dw[3] & (1 << 16) != 0
}
pub fn sub_queue_id(&self) -> u32 {
self.dw[2] >> 16
}
pub fn sub_queue_head(&self) -> usize {
(self.dw[2] & 0xFFFF) as _
}
pub fn command_id(&self) -> u32 {
self.dw[3] & 0xFFFF
}
pub fn error(&self) -> Option<CommandError> {
let status = (self.dw[3] >> 17) as u16;
if status != 0 {
Some(CommandError {
sct: ((status >> 8) & 0x7) as u8,
sc: status as u8,
})
} else {
None
}
}
}
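// Worked example for the decoders above: dw[3] = 0x0001_0001 has the phase
// bit (16) set, command_id = 1 and a zero status field (bits 17..), i.e.
// success. dw[3] = 0x0003_0001 additionally carries status = 1, which error()
// splits into sct = 0 (generic command status) and sc = 0x01 -- per the NVMe
// spec, "Invalid Command Opcode".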
impl<T> Queue<T> {
pub fn new(
data: PageBox<[T]>,
head_doorbell: *mut u32,
tail_doorbell: *mut u32,
phase: bool,
) -> Self {
assert!(
(head_doorbell.is_null() && !tail_doorbell.is_null())
|| (!head_doorbell.is_null() && tail_doorbell.is_null())
);
Self {
mask: data.len() - 1,
head: 0,
tail: 0,
data,
head_doorbell,
tail_doorbell,
phase,
}
}
pub fn enqueue(&mut self, item: T) -> usize {
let index = self.tail;
self.data[self.tail] = item;
self.phase ^= self.set_tail(self.next_index(self.tail));
index
}
pub fn at_head(&self, offset: usize) -> (&T, bool) {
let index = (self.head + offset) & self.mask;
let expected_phase = self.phase ^ (index < self.head);
(&self.data[index], expected_phase)
}
pub fn take(&mut self, count: usize) {
let index = (self.head + count) & self.mask;
self.phase ^= self.set_head(index);
}
pub fn take_until(&mut self, new_head: usize) {
self.phase ^= self.set_head(new_head);
}
fn next_index(&self, index: usize) -> usize {
(index + 1) & self.mask
}
fn set_tail(&mut self, new_tail: usize) -> bool {
let wrapped = new_tail < self.tail;
self.tail = new_tail;
if !self.tail_doorbell.is_null() {
unsafe {
self.tail_doorbell
.write_volatile(self.tail.try_into().unwrap());
}
}
wrapped
}
fn set_head(&mut self, new_head: usize) -> bool {
let wrapped = new_head < self.head;
self.head = new_head;
if !self.head_doorbell.is_null() {
unsafe {
self.head_doorbell
.write_volatile(self.head.try_into().unwrap());
}
}
wrapped
}
}
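// Note on the phase bookkeeping above: set_tail/set_head report whether the
// index wrapped and the caller XORs that into `phase`. E.g. in a 32-entry
// queue, advancing the tail from 31 to (31 + 1) & 31 = 0 flips `phase`, so
// at_head() then expects entries tagged with the inverted phase bit -- which
// is how fresh completions are told apart from stale ones after a wrap.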
impl QueuePair {
pub fn new(
id: u32,
vector: usize,
capacity: usize,
sq_doorbell: *mut u32,
cq_doorbell: *mut u32,
) -> Result<Self, Error> {
let sq_data = PageBox::new_slice(SubmissionQueueEntry::zeroed(), capacity)?;
let cq_data = PageBox::new_slice(CompletionQueueEntry::zeroed(), capacity)?;
let sq_base = unsafe { sq_data.as_physical_address() };
let cq_base = unsafe { cq_data.as_physical_address() };
log::debug!("Allocated queue pair: sq={:p}, cq={:p}", sq_data, cq_data);
let sq = Queue::new(sq_data, null_mut(), sq_doorbell, true);
let cq = Queue::new(cq_data, cq_doorbell, null_mut(), true);
let inner = IrqSafeSpinlock::new(Inner {
sq,
cq,
pending: BTreeSet::new(),
completed: BTreeMap::new(),
});
Ok(Self {
completion_notify: QueueWaker::new(),
id,
vector,
sq_base,
cq_base,
inner,
})
}
#[inline]
pub fn sq_physical_pointer(&self) -> PhysicalAddress {
self.sq_base
}
#[inline]
pub fn cq_physical_pointer(&self) -> PhysicalAddress {
self.cq_base
}
pub fn poll_completion(&self, command_id: u32) -> Poll<Result<(), Error>> {
let mut inner = self.inner.lock();
match inner.completed.remove(&command_id) {
Some(result) if let Some(_error) = result.error() => todo!(),
Some(_) => Poll::Ready(Ok(())),
None => Poll::Pending,
}
}
pub fn wait_for_completion<'r, T: Unpin + 'r>(
&'r self,
command_id: u32,
result: T,
) -> impl Future<Output = Result<T, CommandError>> + 'r {
struct Fut<'r, R: Unpin + 'r> {
this: &'r QueuePair,
response: Option<R>,
command_id: u32,
}
impl<'r, R: Unpin + 'r> Future for Fut<'r, R> {
type Output = Result<R, CommandError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.this.completion_notify.register(cx.waker());
let mut inner = self.this.inner.lock();
if let Some(entry) = inner.completed.remove(&self.command_id) {
self.this.completion_notify.remove(cx.waker());
let result = if let Some(error) = entry.error() {
Err(error)
} else {
Ok(self.response.take().unwrap())
};
Poll::Ready(result)
} else {
Poll::Pending
}
}
}
Fut {
this: self,
response: Some(result),
command_id,
}
}
pub fn submit<C: Command>(&self, cmd: C, ranges: &[PhysicalAddress], set_pending: bool) -> u32 {
let mut inner = self.inner.lock();
let mut sqe = SubmissionQueueEntry::zeroed();
match ranges.len() {
1 => {
sqe.data_pointer[0] = PhysicalRegionPage::with_addr(ranges[0]);
sqe.data_pointer[1] = PhysicalRegionPage::null();
}
0 => {
sqe.data_pointer[0] = PhysicalRegionPage::null();
sqe.data_pointer[1] = PhysicalRegionPage::null();
}
_ => todo!(),
}
cmd.fill_sqe(&mut sqe);
let command_id = inner.sq.tail.try_into().unwrap();
sqe.command.set_command_id(command_id);
if set_pending {
inner.pending.insert(command_id);
}
inner.sq.enqueue(sqe);
command_id
}
pub fn request_no_data<C: Command>(
&self,
req: C,
) -> impl Future<Output = Result<(), CommandError>> + '_ {
let command_id = self.submit(req, &[], true);
self.wait_for_completion(command_id, ())
}
pub async fn request<'r, R: Request>(
&'r self,
req: R,
) -> Result<PageBox<R::Response>, NvmeError>
where
R::Response: 'r,
{
let response = PageBox::new_uninit().map_err(NvmeError::MemoryError)?;
let command_id = self.submit(req, &[unsafe { response.as_physical_address() }], true);
let result = self.wait_for_completion(command_id, response).await?;
Ok(unsafe { result.assume_init() })
}
pub fn process_completions(&self) -> usize {
let mut inner = self.inner.lock();
let mut n = 0;
let mut completion_list = Vec::new();
loop {
let (cmp, expected_phase) = inner.cq.at_head(n);
let cmp_phase = cmp.phase();
if cmp_phase != expected_phase {
break;
}
n += 1;
let sub_queue_id = cmp.sub_queue_id();
// TODO allow several SQs to receive completions through one CQ?
assert_eq!(sub_queue_id, self.id);
let sub_queue_head = cmp.sub_queue_head();
let cmp = *cmp;
inner.sq.take_until(sub_queue_head);
completion_list.push(cmp);
}
if n != 0 {
inner.cq.take(n);
}
for cmp in completion_list {
let command_id = cmp.command_id();
if inner.pending.remove(&command_id) {
inner.completed.insert(command_id, cmp);
}
}
if n != 0 {
self.completion_notify.wake_all();
}
n
}
}

View File

@ -0,0 +1,21 @@
[package]
name = "ygg_driver_pci"
version = "0.1.0"
edition = "2021"
authors = ["Mark Poliakov <mark@alnyan.me>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
device-api = { path = "../../../lib/device-api", features = ["derive"] }
libk-mm = { path = "../../../libk/libk-mm" }
libk-device = { path = "../../../libk/libk-device" }
libk-util = { path = "../../../libk/libk-util" }
log = "0.4.20"
bitflags = "2.3.3"
tock-registers = "0.8.1"
[target.'cfg(target_arch = "x86_64")'.dependencies]
acpi = { git = "https://github.com/alnyan/acpi.git", package = "acpi", branch = "acpi-system" }

View File

@ -0,0 +1,388 @@
//! PCI capability structures and queries
use alloc::{vec, vec::Vec};
use device_api::interrupt::{
InterruptAffinity, InterruptHandler, MessageInterruptController, MsiInfo,
};
use libk_mm::{address::PhysicalAddress, device::DeviceMemoryIoMut};
use tock_registers::{
interfaces::{Readable, Writeable},
registers::{ReadWrite, WriteOnly},
};
use yggdrasil_abi::error::Error;
use super::{PciCapability, PciCapabilityId, PciConfigurationSpace};
pub trait VirtioCapabilityData<'s, S: PciConfigurationSpace + ?Sized + 's>: Sized {
fn from_space_offset(space: &'s S, offset: usize) -> Self;
fn space(&self) -> &'s S;
fn offset(&self) -> usize;
fn bar_index(&self) -> Option<usize> {
let value = self.space().read_u8(self.offset() + 4);
(value <= 0x5).then_some(value as _)
}
fn bar_offset(&self) -> usize {
let value = self.space().read_u32(self.offset() + 8);
value as _
}
fn length(&self) -> usize {
let value = self.space().read_u32(self.offset() + 12);
value as _
}
}
pub trait VirtioCapability {
const CFG_TYPE: u8;
const MIN_LEN: usize = 0;
type Output<'a, S: PciConfigurationSpace + ?Sized + 'a>: VirtioCapabilityData<'a, S>;
}
/// MSI-X capability query
pub struct MsiXCapability;
/// MSI capability query
pub struct MsiCapability;
// VirtIO-over-PCI capabilities
/// VirtIO PCI configuration access
pub struct VirtioDeviceConfigCapability;
/// VirtIO common configuration
pub struct VirtioCommonConfigCapability;
/// VirtIO notify configuration
pub struct VirtioNotifyConfigCapability;
/// VirtIO interrupt status
pub struct VirtioInterruptStatusCapability;
/// Represents an entry in MSI-X vector table
#[repr(C)]
pub struct MsiXEntry {
/// Address to which the value is written on interrupt
pub address: WriteOnly<u64>,
/// Value which is written to trigger an interrupt
pub data: WriteOnly<u32>,
/// Vector control word
pub control: ReadWrite<u32>,
}
pub struct MsiXVectorTable<'a> {
vectors: DeviceMemoryIoMut<'a, [MsiXEntry]>,
}
/// MSI-X capability data structure
pub struct MsiXData<'s, S: PciConfigurationSpace + ?Sized + 's> {
space: &'s S,
offset: usize,
}
/// MSI capability data structure
pub struct MsiData<'s, S: PciConfigurationSpace + ?Sized + 's> {
space: &'s S,
offset: usize,
}
pub struct VirtioDeviceConfigData<'s, S: PciConfigurationSpace + ?Sized + 's> {
space: &'s S,
offset: usize,
}
pub struct VirtioCommonConfigData<'s, S: PciConfigurationSpace + ?Sized + 's> {
space: &'s S,
offset: usize,
}
pub struct VirtioNotifyConfigData<'s, S: PciConfigurationSpace + ?Sized + 's> {
space: &'s S,
offset: usize,
}
pub struct VirtioInterruptStatusData<'s, S: PciConfigurationSpace + ?Sized + 's> {
space: &'s S,
offset: usize,
}
impl<T: VirtioCapability> PciCapability for T {
const ID: PciCapabilityId = PciCapabilityId::VendorSpecific;
type CapabilityData<'a, S: PciConfigurationSpace + ?Sized + 'a> = T::Output<'a, S>;
fn check<S: PciConfigurationSpace + ?Sized>(space: &S, offset: usize, len: usize) -> bool {
let cfg_type = space.read_u8(offset + 3);
cfg_type == T::CFG_TYPE && len >= T::MIN_LEN
}
fn data<'s, S: PciConfigurationSpace + ?Sized + 's>(
space: &'s S,
offset: usize,
_len: usize,
) -> Self::CapabilityData<'s, S> {
T::Output::from_space_offset(space, offset)
}
}
impl PciCapability for MsiXCapability {
const ID: PciCapabilityId = PciCapabilityId::MsiX;
type CapabilityData<'a, S: PciConfigurationSpace + ?Sized + 'a> = MsiXData<'a, S>;
fn data<'s, S: PciConfigurationSpace + ?Sized + 's>(
space: &'s S,
offset: usize,
_len: usize,
) -> Self::CapabilityData<'s, S> {
MsiXData { space, offset }
}
}
impl PciCapability for MsiCapability {
const ID: PciCapabilityId = PciCapabilityId::Msi;
type CapabilityData<'a, S: PciConfigurationSpace + ?Sized + 'a> = MsiData<'a, S>;
fn data<'s, S: PciConfigurationSpace + ?Sized + 's>(
space: &'s S,
offset: usize,
_len: usize,
) -> Self::CapabilityData<'s, S> {
MsiData { space, offset }
}
}
impl VirtioCapability for VirtioDeviceConfigCapability {
const CFG_TYPE: u8 = 0x04;
type Output<'a, S: PciConfigurationSpace + ?Sized + 'a> = VirtioDeviceConfigData<'a, S>;
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> VirtioCapabilityData<'s, S>
for VirtioDeviceConfigData<'s, S>
{
fn from_space_offset(space: &'s S, offset: usize) -> Self {
Self { space, offset }
}
fn space(&self) -> &'s S {
self.space
}
fn offset(&self) -> usize {
self.offset
}
}
impl VirtioCapability for VirtioCommonConfigCapability {
const CFG_TYPE: u8 = 0x01;
type Output<'a, S: PciConfigurationSpace + ?Sized + 'a> = VirtioCommonConfigData<'a, S>;
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> VirtioCapabilityData<'s, S>
for VirtioCommonConfigData<'s, S>
{
fn from_space_offset(space: &'s S, offset: usize) -> Self {
Self { space, offset }
}
fn space(&self) -> &'s S {
self.space
}
fn offset(&self) -> usize {
self.offset
}
}
impl VirtioCapability for VirtioNotifyConfigCapability {
const CFG_TYPE: u8 = 0x02;
const MIN_LEN: usize = 0x14;
type Output<'a, S: PciConfigurationSpace + ?Sized + 'a> = VirtioNotifyConfigData<'a, S>;
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> VirtioNotifyConfigData<'s, S> {
pub fn offset_multiplier(&self) -> usize {
self.space.read_u32(self.offset + 16) as usize
}
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> VirtioCapabilityData<'s, S>
for VirtioNotifyConfigData<'s, S>
{
fn from_space_offset(space: &'s S, offset: usize) -> Self {
Self { space, offset }
}
fn space(&self) -> &'s S {
self.space
}
fn offset(&self) -> usize {
self.offset
}
}
impl VirtioCapability for VirtioInterruptStatusCapability {
const CFG_TYPE: u8 = 0x03;
const MIN_LEN: usize = 1;
type Output<'a, S: PciConfigurationSpace + ?Sized + 'a> = VirtioInterruptStatusData<'a, S>;
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> VirtioInterruptStatusData<'s, S> {
pub fn read_status(&self) -> (bool, bool) {
todo!()
}
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> VirtioCapabilityData<'s, S>
for VirtioInterruptStatusData<'s, S>
{
fn from_space_offset(space: &'s S, offset: usize) -> Self {
Self { space, offset }
}
fn space(&self) -> &'s S {
self.space
}
fn offset(&self) -> usize {
self.offset
}
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> MsiXData<'s, S> {
// TODO use pending bits as well
/// Maps and returns the vector table associated with the device's MSI-X capability
pub fn vector_table<'a>(&self) -> Result<MsiXVectorTable<'a>, Error> {
let w0 = self.space.read_u16(self.offset + 2);
let dw1 = self.space.read_u32(self.offset + 4);
// Table Size is an 11-bit field in Message Control (up to 2048 entries)
let table_size = (w0 as usize & 0x7FF) + 1;
// BIR occupies bits 2:0 of the Table Offset/BIR dword; the remaining bits
// are the 8-byte aligned table offset
let bir = dw1 as usize & 0x7;
let table_offset = dw1 as usize & !0x7;
let Some(base) = self.space.bar(bir) else {
return Err(Error::DoesNotExist);
};
let Some(base) = base.as_memory() else {
return Err(Error::InvalidOperation);
};
log::debug!("MSI-X table address: {:#x}", base.add(table_offset));
unsafe { MsiXVectorTable::from_raw_parts(base.add(table_offset), table_size) }
}
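// Worked example for the decoding above (illustrative register values): with
// w0 = 0x0007 the table has (0x0007 & 0x7FF) + 1 = 8 entries, and dw1 =
// 0x0000_2002 places it in BAR2 at offset 0x2000 -- 8 entries * 16 bytes =
// 128 bytes of vector table, 0x2000 bytes into the BAR2 mapping.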
/// Changes the global enable status for the device's MSI-X capability. If set, regular IRQs
/// are not generated.
pub fn set_enabled(&mut self, enabled: bool) {
let mut w0 = self.space.read_u32(self.offset);
if enabled {
w0 |= 1 << 31;
} else {
w0 &= !(1 << 31);
}
self.space.write_u32(self.offset, w0);
}
pub fn set_function_mask(&mut self, masked: bool) {
let mut w0 = self.space.read_u32(self.offset);
if masked {
w0 |= 1 << 30;
} else {
w0 &= !(1 << 30);
}
self.space.write_u32(self.offset, w0);
}
}
impl MsiXVectorTable<'_> {
unsafe fn from_raw_parts(base: PhysicalAddress, len: usize) -> Result<Self, Error> {
let vectors = DeviceMemoryIoMut::map_slice(base, len, Default::default())?;
Ok(Self { vectors })
}
pub fn mask_all(&mut self) {
for vector in self.vectors.iter_mut() {
vector.set_masked(true);
}
}
pub fn register_range<C: MessageInterruptController + ?Sized>(
&mut self,
start: usize,
end: usize,
ic: &C,
affinity: InterruptAffinity,
handler: &'static dyn InterruptHandler,
) -> Result<Vec<MsiInfo>, Error> {
assert!(end > start);
let mut range = vec![
MsiInfo {
affinity,
..Default::default()
};
end - start
];
ic.register_msi_range(&mut range, handler)?;
for (i, info) in range.iter().enumerate() {
let index = i + start;
self.vectors[index].address.set(info.address as _);
self.vectors[index].data.set(info.value);
self.vectors[index].set_masked(false);
}
Ok(range)
}
}
impl MsiXEntry {
/// If set, prevents the MSI-X interrupt from being delivered
fn set_masked(&mut self, masked: bool) {
if masked {
self.control.set(self.control.get() | 1);
} else {
self.control.set(self.control.get() & !1);
}
}
}
impl<'s, S: PciConfigurationSpace + ?Sized + 's> MsiData<'s, S> {
pub fn register<C: MessageInterruptController + ?Sized>(
&mut self,
ic: &C,
affinity: InterruptAffinity,
handler: &'static dyn InterruptHandler,
) -> Result<MsiInfo, Error> {
let info = ic.register_msi(affinity, handler)?;
let mut w0 = self.space.read_u16(self.offset + 2);
// Enable the vector first
w0 |= 1 << 0;
// Reset to one vector
w0 &= !(0x7 << 4);
self.space.write_u16(self.offset + 2, w0);
if info.value > u16::MAX as u32 {
log::warn!("Could not setup a MSI: value={:#x} > u16", info.value);
return Err(Error::InvalidOperation);
}
// Bit 7 of Message Control indicates 64-bit address capability; it also
// moves the Message Data register from offset +8 to offset +12
let is_64bit = w0 & (1 << 7) != 0;
if info.address > u32::MAX as usize && !is_64bit {
log::warn!(
"Could not setup a MSI: address={:#x} and MSI is not 64 bit capable",
info.address
);
return Err(Error::InvalidOperation);
}
self.space.write_u32(self.offset + 4, info.address as u32);
if is_64bit {
self.space
.write_u32(self.offset + 8, (info.address >> 32) as u32);
self.space.write_u16(self.offset + 12, info.value as u16);
} else {
self.space.write_u16(self.offset + 8, info.value as u16);
}
Ok(info)
}
}

View File

@ -0,0 +1,220 @@
use core::ops::Range;
use alloc::{sync::Arc, vec::Vec};
use device_api::{
interrupt::{InterruptAffinity, InterruptHandler, IrqOptions, MsiInfo},
Device,
};
use libk_device::{message_interrupt_controller, register_global_interrupt};
use libk_util::{sync::spin_rwlock::IrqSafeRwLock, OneTimeInit};
use yggdrasil_abi::error::Error;
use crate::{
capability::{MsiCapability, MsiXCapability, MsiXVectorTable},
PciAddress, PciConfigSpace, PciConfigurationSpace, PciSegmentInfo,
};
/// Describes a PCI device
#[derive(Clone)]
pub struct PciDeviceInfo {
/// Address of the device
pub address: PciAddress,
/// Configuration space access method
pub config_space: PciConfigSpace,
/// Describes the PCI segment this device is a part of
pub segment: Arc<PciSegmentInfo>,
pub(crate) interrupt_config: Arc<OneTimeInit<IrqSafeRwLock<InterruptConfig>>>,
}
pub struct InterruptConfig {
#[allow(unused)]
preferred_mode: PreferredInterruptMode,
configured_mode: ConfiguredInterruptMode,
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum PciInterruptPin {
A,
B,
C,
D,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum PreferredInterruptMode {
Msi,
Legacy,
}
enum ConfiguredInterruptMode {
MsiX(MsiXVectorTable<'static>),
Msi,
Legacy(PciInterruptPin),
None,
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct PciInterrupt {
pub address: PciAddress,
pub pin: PciInterruptPin,
}
#[derive(Clone, Copy, Debug)]
pub struct PciInterruptRoute {
pub number: u32,
pub options: IrqOptions,
}
pub enum PciMatch {
Generic(fn(&PciDeviceInfo) -> bool),
Vendor(u16, u16),
Class(u8, Option<u8>, Option<u8>),
}
pub struct PciDriver {
pub(crate) name: &'static str,
pub(crate) check: PciMatch,
pub(crate) probe: fn(&PciDeviceInfo) -> Result<&'static dyn Device, Error>,
}
/// Used to store PCI bus devices which were enumerated by the kernel
pub struct PciBusDevice {
pub(crate) info: PciDeviceInfo,
pub(crate) driver: Option<&'static dyn Device>,
}
impl PciDeviceInfo {
pub fn init_interrupts(&self, preferred_mode: PreferredInterruptMode) -> Result<(), Error> {
self.interrupt_config
.try_init_with(|| {
let configured_mode =
if self.segment.has_msi && preferred_mode == PreferredInterruptMode::Msi {
if let Some(mut msix) = self.config_space.capability::<MsiXCapability>() {
let mut vt = msix.vector_table().unwrap();
vt.mask_all();
msix.set_function_mask(false);
msix.set_enabled(true);
ConfiguredInterruptMode::MsiX(vt)
} else if self.config_space.capability::<MsiCapability>().is_some() {
ConfiguredInterruptMode::Msi
} else {
self.interrupt_mode_from_pin()
}
} else {
// Ignore preferred_mode here; legacy pin-based interrupts are the only supported mode in this case
self.interrupt_mode_from_pin()
};
IrqSafeRwLock::new(InterruptConfig {
preferred_mode,
configured_mode,
})
})
.expect("Attempted to double-configure interrupts for a PCI device");
Ok(())
}
fn interrupt_mode_from_pin(&self) -> ConfiguredInterruptMode {
match self.config_space.interrupt_pin() {
Some(pin) => ConfiguredInterruptMode::Legacy(pin),
None => ConfiguredInterruptMode::None,
}
}
pub fn map_interrupt(
&self,
affinity: InterruptAffinity,
handler: &'static dyn InterruptHandler,
) -> Result<Option<MsiInfo>, Error> {
let mut irq = self.interrupt_config.get().write();
match &mut irq.configured_mode {
ConfiguredInterruptMode::MsiX(msix) => {
let info =
msix.register_range(0, 1, message_interrupt_controller(), affinity, handler)?;
Ok(Some(info[0]))
}
ConfiguredInterruptMode::Msi => {
let mut msi = self
.config_space
.capability::<MsiCapability>()
.ok_or(Error::InvalidOperation)?;
let info = msi.register(message_interrupt_controller(), affinity, handler)?;
Ok(Some(info))
}
ConfiguredInterruptMode::Legacy(pin) => {
self.try_map_legacy(*pin, handler)?;
Ok(None)
}
ConfiguredInterruptMode::None => Err(Error::InvalidOperation),
}
}
pub fn map_interrupt_multiple(
&self,
vector_range: Range<usize>,
affinity: InterruptAffinity,
handler: &'static dyn InterruptHandler,
) -> Result<Vec<MsiInfo>, Error> {
let mut irq = self.interrupt_config.get().write();
let start = vector_range.start;
let end = vector_range.end;
match &mut irq.configured_mode {
ConfiguredInterruptMode::MsiX(msix) => msix.register_range(
start,
end,
message_interrupt_controller(),
affinity,
handler,
),
_ => Err(Error::InvalidOperation),
}
}
fn try_map_legacy(
&self,
pin: PciInterruptPin,
handler: &'static dyn InterruptHandler,
) -> Result<(), Error> {
let src = PciInterrupt {
address: self.address,
pin,
};
let route = self
.segment
.irq_translation_map
.get(&src)
.ok_or(Error::InvalidOperation)?;
log::debug!(
"PCI {} pin {:?} -> system IRQ #{}",
src.address,
src.pin,
route.number
);
register_global_interrupt(route.number, route.options, handler)
}
}
impl TryFrom<u32> for PciInterruptPin {
type Error = ();
fn try_from(value: u32) -> Result<Self, Self::Error> {
match value {
1 => Ok(Self::A),
2 => Ok(Self::B),
3 => Ok(Self::C),
4 => Ok(Self::D),
_ => Err(()),
}
}
}

View File

@ -0,0 +1,624 @@
//! PCI/PCIe bus interfaces
#![no_std]
extern crate alloc;
use core::fmt;
#[cfg(target_arch = "x86_64")]
use acpi::mcfg::McfgEntry;
use alloc::{collections::BTreeMap, sync::Arc, vec::Vec};
use bitflags::bitflags;
use device::{PciBusDevice, PciDeviceInfo, PciDriver, PciInterrupt, PciInterruptRoute, PciMatch};
use device_api::Device;
use libk_mm::address::{FromRaw, PhysicalAddress};
use libk_util::{sync::IrqSafeSpinlock, OneTimeInit};
use yggdrasil_abi::error::Error;
pub mod capability;
pub mod device;
mod space;
pub use space::{
ecam::PciEcam, PciConfigSpace, PciConfigurationSpace, PciLegacyConfigurationSpace,
};
bitflags! {
/// Command register of the PCI configuration space
pub struct PciCommandRegister: u16 {
/// If set, I/O access to the device is enabled
const ENABLE_IO = 1 << 0;
/// If set, memory-mapped access to the device is enabled
const ENABLE_MEMORY = 1 << 1;
/// If set, the device can generate PCI bus accesses on its own
const BUS_MASTER = 1 << 2;
/// If set, interrupts are masked from being raised
const DISABLE_INTERRUPTS = 1 << 10;
}
}
bitflags! {
/// Status register of the PCI configuration space
pub struct PciStatusRegister: u16 {
/// Read-only. If set, the configuration space has a pointer to the capabilities list.
const CAPABILITIES_LIST = 1 << 4;
}
}
/// Represents the address of a single object on a bus (or the bus itself)
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct PciAddress {
/// PCIe segment group, ignored (?) with PCI
pub segment: u8,
/// Bus number
pub bus: u8,
/// Slot/device number
pub device: u8,
/// Function number
pub function: u8,
}
/// Address provided by PCI configuration space Base Address Register
#[derive(Debug, Clone, Copy)]
pub enum PciBaseAddress {
/// 32-bit memory address
Memory32(u32),
/// 64-bit memory address
Memory64(u64),
/// I/O space address
Io(u16),
}
/// Unique ID assigned to PCI capability structures
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[non_exhaustive]
#[repr(u8)]
pub enum PciCapabilityId {
/// MSI (32-bit or 64-bit)
Msi = 0x05,
/// Vendor-specific capability
VendorSpecific = 0x09,
/// MSI-X
MsiX = 0x11,
/// Unknown capability missing from this list
Unknown,
}
/// Interface used for querying PCI capabilities
#[allow(unused)]
pub trait PciCapability {
/// Capability ID
const ID: PciCapabilityId;
/// Wrapper for accessing the capability data structure
type CapabilityData<'a, S: PciConfigurationSpace + ?Sized + 'a>;
fn check<S: PciConfigurationSpace + ?Sized>(space: &S, offset: usize, len: usize) -> bool {
true
}
/// Constructs an access wrapper for this capability with given offset
fn data<'s, S: PciConfigurationSpace + ?Sized + 's>(
space: &'s S,
offset: usize,
len: usize,
) -> Self::CapabilityData<'s, S>;
}
struct BusAddressAllocator {
pci_base_64: u64,
pci_base_32: u32,
// pci_base_io: u16,
host_base_64: PhysicalAddress,
host_base_32: PhysicalAddress,
// host_base_io: PhysicalAddress,
size_64: usize,
size_32: usize,
// size_io: usize,
offset_64: u64,
offset_32: u32,
}
#[cfg_attr(target_arch = "x86_64", allow(dead_code))]
impl BusAddressAllocator {
pub fn from_ranges(ranges: &[PciAddressRange]) -> Self {
let mut range_32 = None;
let mut range_64 = None;
// let mut range_io = None;
for range in ranges {
let range_val = (range.pci_base, range.host_base, range.size);
match range.ty {
// PciRangeType::Io if range_io.is_none() => {
// range_io.replace(range_val);
// }
PciRangeType::Memory32 if range_32.is_none() => {
range_32.replace(range_val);
}
PciRangeType::Memory64 if range_64.is_none() => {
range_64.replace(range_val);
}
_ => (),
}
}
let (pci_base_32, host_base_32, size_32) = range_32.unwrap();
let (pci_base_64, host_base_64, size_64) = range_64.unwrap();
// let (pci_base_io, host_base_io, size_io) = range_io.unwrap();
Self {
pci_base_64,
pci_base_32: pci_base_32.try_into().unwrap(),
// pci_base_io: pci_base_io.try_into().unwrap(),
host_base_64,
host_base_32,
// host_base_io,
size_64,
size_32,
// size_io,
offset_64: 0,
offset_32: 0,
}
}
pub fn allocate(&mut self, ty: PciRangeType, size: usize) -> (PciBaseAddress, PhysicalAddress) {
match ty {
PciRangeType::Io => todo!(),
PciRangeType::Memory32 => {
if self.offset_32 as usize + size >= self.size_32 {
todo!();
}
let bar = PciBaseAddress::Memory32(self.pci_base_32 + self.offset_32);
let host = self.host_base_32.add(self.offset_32 as usize);
self.offset_32 += size as u32;
(bar, host)
}
PciRangeType::Memory64 => {
if self.offset_64 as usize + size >= self.size_64 {
todo!();
}
let bar = PciBaseAddress::Memory64(self.pci_base_64 + self.offset_64);
let host = self.host_base_64.add(self.offset_64 as usize);
self.offset_64 += size as u64;
(bar, host)
}
PciRangeType::Configuration => unimplemented!(),
}
}
}
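// Worked example (illustrative ranges): with a 32-bit window whose PCI base is
// 0x4000_0000, two successive allocate(Memory32, 0x1000) calls return BARs at
// 0x4000_0000 and 0x4000_1000 -- the allocator is a simple per-window bump
// pointer and never frees.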
#[derive(Debug)]
pub struct PciSegmentInfo {
pub segment_number: u8,
pub bus_number_start: u8,
pub bus_number_end: u8,
pub ecam_phys_base: Option<PhysicalAddress>,
pub irq_translation_map: BTreeMap<PciInterrupt, PciInterruptRoute>,
pub has_msi: bool,
}
/// Represents a single PCIe bus segment
pub struct PciBusSegment {
allocator: Option<BusAddressAllocator>,
info: Arc<PciSegmentInfo>,
devices: Vec<PciBusDevice>,
}
pub enum PciRangeType {
Configuration,
Io,
Memory32,
Memory64,
}
pub struct PciAddressRange {
pub ty: PciRangeType,
pub bus_number: u8,
pub pci_base: u64,
pub host_base: PhysicalAddress,
pub size: usize,
}
/// Manager struct to store and control all PCI devices in the system
pub struct PciBusManager {
segments: Vec<PciBusSegment>,
}
impl PciBaseAddress {
pub fn as_memory(self) -> Option<PhysicalAddress> {
match self {
Self::Memory32(address) => Some(PhysicalAddress::from_raw(address as u64)),
Self::Memory64(address) => Some(PhysicalAddress::from_raw(address)),
_ => None,
}
}
}
impl PciBusSegment {
fn probe_config_space(&self, address: PciAddress) -> Result<Option<PciConfigSpace>, Error> {
match self.info.ecam_phys_base {
Some(ecam_phys_base) => Ok(unsafe {
PciEcam::probe_raw_parts(ecam_phys_base, self.info.bus_number_start, address)?
}
.map(PciConfigSpace::Ecam)),
None => todo!(),
}
}
fn enumerate_function(&mut self, address: PciAddress) -> Result<(), Error> {
let Some(config) = self.probe_config_space(address)? else {
return Ok(());
};
let header_type = config.header_type();
// Enumerate multi-function devices
if address.function == 0 && header_type & 0x80 != 0 {
for function in 1..8 {
self.enumerate_function(address.with_function(function))?;
}
}
// PCI-to-PCI bridge
// if config.class_code() == 0x06 && config.subclass() == 0x04 {
// let secondary_bus = config.secondary_bus();
// // TODO
// }
if let Some(allocator) = self.allocator.as_mut() {
log::debug!("Remapping BARs for {}", address);
// Find valid BARs
let mut i = 0;
let mut bar_mask = 0;
while i < 6 {
let w0 = config.read_u32(0x10 + i * 4);
let bar_width = if w0 & 1 == 0 {
// Memory BAR
match (w0 >> 1) & 3 {
// 32-bit BAR
0 => 1,
// Reserved
1 => unimplemented!(),
// 64-bit BAR
2 => 2,
// Unknown
_ => unreachable!(),
}
} else {
// I/O BAR
1
};
bar_mask |= 1 << i;
i += bar_width;
}
for i in 0..6 {
if (1 << i) & bar_mask != 0 {
let orig_value = config.bar(i).unwrap();
let size = unsafe { config.bar_size(i) };
if size != 0 {
log::debug!("BAR{}: size={:#x}", i, size);
match orig_value {
PciBaseAddress::Io(_) => (),
PciBaseAddress::Memory64(_) => {
let (bar, host) = allocator.allocate(PciRangeType::Memory64, size);
let bar_address = bar.as_memory().unwrap();
unsafe {
config.set_bar(i, bar);
}
log::debug!(
"Mapped BAR{} -> pci {:#x} host {:#x}",
i,
bar_address,
host
);
// TODO: host and PCI bus addresses are not differentiated yet
assert_eq!(bar_address, host);
}
PciBaseAddress::Memory32(_) => {
let (bar, host) = allocator.allocate(PciRangeType::Memory32, size);
let bar_address = bar.as_memory().unwrap();
unsafe {
config.set_bar(i, bar);
}
log::debug!(
"Mapped BAR{} -> pci {:#x} host {:#x}",
i,
bar_address,
host
);
// TODO: host and PCI bus addresses are not differentiated yet
assert_eq!(bar_address, host);
}
}
}
}
}
}
let info = PciDeviceInfo {
address,
segment: self.info.clone(),
config_space: config,
interrupt_config: Arc::new(OneTimeInit::new()),
};
self.devices.push(PciBusDevice { info, driver: None });
Ok(())
}
fn enumerate_bus(&mut self, bus: u8) -> Result<(), Error> {
let address = PciAddress::for_bus(self.info.segment_number, bus);
for i in 0..32 {
let device_address = address.with_device(i);
self.enumerate_function(device_address)?;
}
Ok(())
}
/// Enumerates the bus segment, placing found devices into the manager
pub fn enumerate(&mut self) -> Result<(), Error> {
for bus in self.info.bus_number_start..self.info.bus_number_end {
self.enumerate_bus(bus)?;
}
Ok(())
}
}
impl PciBusManager {
const fn new() -> Self {
Self {
segments: Vec::new(),
}
}
/// Walks the bus device list and calls init/init_irq functions on any devices with associated
/// drivers
pub fn setup_bus_devices() -> Result<(), Error> {
log::info!("Setting up bus devices");
Self::walk_bus_devices(|device| {
log::info!("Set up {}", device.info.address);
setup_bus_device(device)?;
Ok(true)
})
}
/// Iterates over the bus devices, calling the function on each of them until either an error
/// or `Ok(false)` is returned
pub fn walk_bus_devices<F: FnMut(&mut PciBusDevice) -> Result<bool, Error>>(
mut f: F,
) -> Result<(), Error> {
let mut this = PCI_MANAGER.lock();
for segment in this.segments.iter_mut() {
for device in segment.devices.iter_mut() {
if !f(device)? {
return Ok(());
}
}
}
Ok(())
}
/// Enumerates a bus segment provided by ACPI MCFG table entry
#[cfg(target_arch = "x86_64")]
pub fn add_segment_from_mcfg(entry: &McfgEntry) -> Result<(), Error> {
let mut bus_segment = PciBusSegment {
info: Arc::new(PciSegmentInfo {
segment_number: entry.pci_segment_group as u8,
bus_number_start: entry.bus_number_start,
bus_number_end: entry.bus_number_end,
ecam_phys_base: Some(PhysicalAddress::from_raw(entry.base_address)),
// TODO obtain this from ACPI SSDT
irq_translation_map: BTreeMap::new(),
has_msi: true,
}),
// Firmware has already done this for us
allocator: None,
devices: Vec::new(),
};
let mut this = PCI_MANAGER.lock();
bus_segment.enumerate()?;
this.segments.push(bus_segment);
Ok(())
}
#[cfg(target_arch = "aarch64")]
pub fn add_segment_from_device_tree(
cfg_base: PhysicalAddress,
bus_range: core::ops::Range<u8>,
ranges: Vec<PciAddressRange>,
interrupt_map: BTreeMap<PciInterrupt, PciInterruptRoute>,
) -> Result<(), Error> {
let mut bus_segment = PciBusSegment {
info: Arc::new(PciSegmentInfo {
segment_number: 0,
bus_number_start: bus_range.start,
bus_number_end: bus_range.end,
ecam_phys_base: Some(cfg_base),
irq_translation_map: interrupt_map,
has_msi: false,
}),
allocator: Some(BusAddressAllocator::from_ranges(&ranges)),
devices: Vec::new(),
};
let mut this = PCI_MANAGER.lock();
bus_segment.enumerate()?;
this.segments.push(bus_segment);
Ok(())
}
}
impl fmt::Display for PciAddress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}:{}", self.bus, self.device, self.function)
}
}
impl PciAddress {
/// Constructs a [PciAddress] representing a bus
pub const fn for_bus(segment: u8, bus: u8) -> Self {
Self {
segment,
bus,
device: 0,
function: 0,
}
}
/// Constructs a [PciAddress] representing a specific function
pub const fn for_function(segment: u8, bus: u8, device: u8, function: u8) -> Self {
Self {
segment,
bus,
device,
function,
}
}
/// Constructs a [PciAddress] representing a device on a given bus
pub const fn with_device(self, device: u8) -> Self {
Self {
device,
function: 0,
..self
}
}
/// Constructs a [PciAddress] representing a function of a given bus device
pub const fn with_function(self, function: u8) -> Self {
Self { function, ..self }
}
}
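// E.g. PciAddress::for_bus(0, 1).with_device(3).with_function(1) addresses
// segment 0, bus 1, device 3, function 1, and is displayed as "1:3:1".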
impl PciConfigurationSpace for PciConfigSpace {
fn read_u32(&self, offset: usize) -> u32 {
match self {
Self::Ecam(ecam) => ecam.read_u32(offset),
_ => todo!(),
}
}
fn write_u32(&self, offset: usize, value: u32) {
match self {
Self::Ecam(ecam) => ecam.write_u32(offset, value),
_ => todo!(),
}
}
}
fn setup_bus_device(device: &mut PciBusDevice) -> Result<(), Error> {
if device.driver.is_some() {
return Ok(());
}
let config = &device.info.config_space;
log::debug!(
"{}: {:04x}:{:04x}",
device.info.address,
config.vendor_id(),
config.device_id()
);
let class = config.class_code();
let subclass = config.subclass();
let prog_if = config.prog_if();
let drivers = PCI_DRIVERS.lock();
for driver in drivers.iter() {
if driver
.check
.check_device(&device.info, class, subclass, prog_if)
{
// TODO add the device to the bus
log::debug!(" -> {:?}", driver.name);
let instance = (driver.probe)(&device.info)?;
unsafe {
instance.init()?;
}
device.driver.replace(instance);
break;
} else {
log::debug!(" -> No driver");
}
}
Ok(())
}
impl PciMatch {
pub fn check_device(&self, info: &PciDeviceInfo, class: u8, subclass: u8, prog_if: u8) -> bool {
match self {
Self::Generic(f) => f(info),
&Self::Vendor(vendor_, device_) => {
info.config_space.vendor_id() == vendor_ && info.config_space.device_id() == device_
}
&Self::Class(class_, Some(subclass_), Some(prog_if_)) => {
class_ == class && subclass_ == subclass && prog_if_ == prog_if
}
&Self::Class(class_, Some(subclass_), _) => class_ == class && subclass_ == subclass,
&Self::Class(class_, _, _) => class_ == class,
}
}
}
pub fn register_class_driver(
name: &'static str,
class: u8,
subclass: Option<u8>,
prog_if: Option<u8>,
probe: fn(&PciDeviceInfo) -> Result<&'static dyn Device, Error>,
) {
PCI_DRIVERS.lock().push(PciDriver {
name,
check: PciMatch::Class(class, subclass, prog_if),
probe,
});
}
pub fn register_vendor_driver(
name: &'static str,
vendor_id: u16,
device_id: u16,
probe: fn(&PciDeviceInfo) -> Result<&'static dyn Device, Error>,
) {
PCI_DRIVERS.lock().push(PciDriver {
name,
check: PciMatch::Vendor(vendor_id, device_id),
probe,
});
}
pub fn register_generic_driver(
name: &'static str,
check: fn(&PciDeviceInfo) -> bool,
probe: fn(&PciDeviceInfo) -> Result<&'static dyn Device, Error>,
) {
PCI_DRIVERS.lock().push(PciDriver {
name,
check: PciMatch::Generic(check),
probe,
});
}
static PCI_DRIVERS: IrqSafeSpinlock<Vec<PciDriver>> = IrqSafeSpinlock::new(Vec::new());
static PCI_MANAGER: IrqSafeSpinlock<PciBusManager> = IrqSafeSpinlock::new(PciBusManager::new());

View File

@ -0,0 +1,60 @@
//! PCI Express ECAM interface
use libk_mm::{address::PhysicalAddress, device::DeviceMemoryMapping};
use yggdrasil_abi::error::Error;
use super::{PciAddress, PciConfigurationSpace};
/// PCI Express Enhanced Configuration Access Mechanism
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PciEcam {
mapping: DeviceMemoryMapping,
}
impl PciConfigurationSpace for PciEcam {
fn read_u32(&self, offset: usize) -> u32 {
assert_eq!(offset & 3, 0);
unsafe { ((self.mapping.address() + offset) as *const u32).read_volatile() }
}
fn write_u32(&self, offset: usize, value: u32) {
assert_eq!(offset & 3, 0);
unsafe { ((self.mapping.address() + offset) as *mut u32).write_volatile(value) }
}
}
impl PciEcam {
/// Maps the physical address of a ECAM space for kernel access.
///
/// # Safety
///
/// The `phys_addr` must be a valid ECAM address. The address must not alias any other mapped
/// regions. The address must be aligned to a 4KiB boundary and be valid for accesses within a
/// 4KiB-sized range.
pub unsafe fn map(phys_addr: PhysicalAddress) -> Result<Self, Error> {
let mapping = DeviceMemoryMapping::map(phys_addr, 0x1000, Default::default())?;
Ok(Self { mapping })
}
/// Checks if the ECAM contains a valid device configuration space, mapping and returning a
/// [PciEcam] if it does.
///
/// # Safety
///
/// See [PciEcam::map].
pub unsafe fn probe_raw_parts(
segment_phys_addr: PhysicalAddress,
bus_offset: u8,
address: PciAddress,
) -> Result<Option<Self>, Error> {
let phys_addr = segment_phys_addr.add(
((address.bus - bus_offset) as usize * 256
+ address.device as usize * 8
+ address.function as usize)
* 0x1000,
);
let this = Self::map(phys_addr)?;
Ok(if this.is_valid() { Some(this) } else { None })
}
}
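// Worked example for probe_raw_parts: with bus_offset = 0, the configuration
// space of bus 0, device 3, function 0 lives at
// segment base + (0 * 256 + 3 * 8 + 0) * 0x1000 = base + 0x18000 -- one 4 KiB
// ECAM page per function.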

View File

@ -0,0 +1,382 @@
use alloc::sync::Arc;
use super::{PciAddress, PciBaseAddress, PciCapability, PciCapabilityId, PciEcam};
use crate::{device::PciInterruptPin, PciCommandRegister, PciStatusRegister};
pub(super) mod ecam;
macro_rules! pci_config_field_getter {
($self:ident, u32, $offset:expr) => {
$self.read_u32($offset)
};
($self:ident, u16, $offset:expr) => {
$self.read_u16($offset)
};
($self:ident, u8, $offset:expr) => {
$self.read_u8($offset)
};
}
macro_rules! pci_config_field_setter {
($self:ident, u32, $offset:expr, $value:expr) => {
$self.write_u32($offset, $value)
};
($self:ident, u16, $offset:expr, $value:expr) => {{
$self.write_u16($offset, $value)
}};
($self:ident, u8, $offset:expr, $value:expr) => {
$self.write_u8($offset, $value)
};
}
macro_rules! pci_config_field {
(
$offset:expr => $ty:ident,
$(#[$getter_meta:meta])* $getter:ident
$(, $(#[$setter_meta:meta])* $setter:ident)?
) => {
$(#[$getter_meta])*
fn $getter(&self) -> $ty {
pci_config_field_getter!(self, $ty, $offset)
}
$(
$(#[$setter_meta])*
fn $setter(&self, value: $ty) {
pci_config_field_setter!(self, $ty, $offset, value)
}
)?
};
}
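// E.g. the `vendor_id` field below, declared as `0x00 => u16, vendor_id`,
// expands to `fn vendor_id(&self) -> u16 { self.read_u16(0x00) }`; fields that
// also name a setter additionally get a `fn set_*(&self, value: u16)`.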
/// Provides access to the legacy (port I/O-driven) PCI configuration space
#[derive(Debug)]
#[repr(transparent)]
pub struct PciLegacyConfigurationSpace {
#[allow(unused)]
address: PciAddress,
}
/// Describes a configuration space access method for a PCI device
#[derive(Debug, Clone)]
pub enum PciConfigSpace {
/// Legacy configuration space.
///
/// See [PciLegacyConfigurationSpace].
Legacy(PciAddress),
/// Enhanced Configuration Access Mechanism (PCIe).
///
/// See [PciEcam].
Ecam(PciEcam),
}
pub struct CapabilityIterator<'s, S: PciConfigurationSpace + ?Sized> {
space: &'s S,
current: Option<usize>,
}
impl<'s, S: PciConfigurationSpace + ?Sized> Iterator for CapabilityIterator<'s, S> {
type Item = (PciCapabilityId, usize, usize);
fn next(&mut self) -> Option<Self::Item> {
let offset = self.current? & !0x3;
// Transmuting an arbitrary byte into PciCapabilityId would be UB for
// values without a matching variant, so map unknown IDs explicitly
let id = match self.space.read_u8(offset) {
0x05 => PciCapabilityId::Msi,
0x09 => PciCapabilityId::VendorSpecific,
0x11 => PciCapabilityId::MsiX,
_ => PciCapabilityId::Unknown,
};
let len = self.space.read_u8(offset + 2);
let next_pointer = self.space.read_u8(offset + 1);
self.current = if next_pointer != 0 {
Some(next_pointer as usize)
} else {
None
};
Some((id, offset, len as usize))
}
}
/// Interface for accessing the configuration space of a device
pub trait PciConfigurationSpace {
/// Reads a 32-bit value from the device configuration space.
///
/// # Note
///
/// The `offset` must be u32-aligned.
fn read_u32(&self, offset: usize) -> u32;
/// Writes a 32-bit value to the device configuration space.
///
/// # Note
///
/// The `offset` must be u32-aligned.
fn write_u32(&self, offset: usize, value: u32);
/// Reads a 16-bit value from the device configuration space.
///
/// # Note
///
/// The `offset` must be u16-aligned.
fn read_u16(&self, offset: usize) -> u16 {
assert_eq!(offset & 1, 0);
let value = self.read_u32(offset & !3);
(value >> ((offset & 3) * 8)) as u16
}
/// Reads a byte from the device configuration space
fn read_u8(&self, offset: usize) -> u8 {
let value = self.read_u32(offset & !3);
(value >> ((offset & 3) * 8)) as u8
}
/// Writes a 16-bit value to the device configuration space.
///
/// # Note
///
/// The `offset` must be u16-aligned.
fn write_u16(&self, offset: usize, value: u16) {
let shift = ((offset >> 1) & 1) << 4;
assert_eq!(offset & 1, 0);
let mut tmp = self.read_u32(offset & !3);
tmp &= !(0xFFFF << shift);
tmp |= (value as u32) << shift;
self.write_u32(offset & !3, tmp);
}
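// E.g. read_u16(0x06) performs read_u32(0x04) and returns bits 16..32 of the
// result, and write_u16(0x06, v) read-modify-writes the same dword -- which is
// why both require a u16-aligned offset.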
/// Writes a byte to the device configuration space
fn write_u8(&self, _offset: usize, _value: u8) {
todo!()
}
/// Returns `true` if the device is present on the bus (i.e. configuration space is not filled
/// with only 1's)
fn is_valid(&self) -> bool {
self.vendor_id() != 0xFFFF && self.device_id() != 0xFFFF
}
pci_config_field!(
0x00 => u16,
#[doc = "Returns the Vendor ID"] vendor_id
);
pci_config_field!(0x02 => u16,
#[doc = "Returns the Device ID"] device_id
);
pci_config_field!(
0x04 => u16,
#[doc = "Returns the value of the command register"] command,
#[doc = "Writes to the command word register"] set_command
);
pci_config_field!(
0x06 => u16,
#[doc = "Returns the value of the status register"] status
);
pci_config_field!(
0x08 => u8,
#[doc = "Returns the device Revision ID"]
rev_id
);
pci_config_field!(
0x09 => u8,
#[doc = "Returns the device Prog IF field"]
prog_if
);
pci_config_field!(
0x0A => u8,
#[doc = "Returns the device Subclass field"]
subclass
);
pci_config_field!(
0x0B => u8,
#[doc = "Returns the device Class Code field"]
class_code
);
// ...
pci_config_field!(
0x0E => u8,
#[doc = "Returns the header type of the device"]
header_type
);
pci_config_field!(
0x19 => u8,
#[doc = r#"
Returns the secondary bus number associated with this device
# Note
The function is only valid for devices with `header_type() == 1`
"#]
secondary_bus
);
pci_config_field!(
0x34 => u8,
#[doc =
r"Returns the offset within the configuration space where the Capabilities List
is located. Only valid if the corresponding Status Register bit is set"
]
capability_pointer
);
fn interrupt_pin(&self) -> Option<PciInterruptPin> {
PciInterruptPin::try_from(self.read_u8(0x3D) as u32).ok()
}
unsafe fn bar_size(&self, index: usize) -> usize {
let cmd = self.command();
// Disable I/O and memory
self.set_command(
cmd & !(PciCommandRegister::ENABLE_IO | PciCommandRegister::ENABLE_MEMORY).bits(),
);
let orig_value = self.bar(index).unwrap();
// TODO preserve prefetch bit
let mask_value = match orig_value {
PciBaseAddress::Io(_) => PciBaseAddress::Io(0xFFFC),
PciBaseAddress::Memory32(_) => PciBaseAddress::Memory32(0xFFFFFFF0),
PciBaseAddress::Memory64(_) => PciBaseAddress::Memory64(0xFFFFFFFFFFFFFFF0),
};
self.set_bar(index, mask_value);
let new_value = self.bar(index).unwrap();
let size = match new_value {
PciBaseAddress::Io(address) if address != 0 => ((!address) + 1) as usize,
PciBaseAddress::Memory32(address) if address != 0 => ((!address) + 1) as usize,
PciBaseAddress::Memory64(address) if address != 0 => ((!address) + 1) as usize,
_ => 0,
};
self.set_bar(index, orig_value);
self.set_command(cmd);
size
}
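// Worked example for bar_size: writing all-ones to a 32-bit memory BAR and
// reading back 0xFFFF_F000 yields (!0xFFFF_F000) + 1 = 0x1000, i.e. the
// device decodes 4 KiB at that BAR.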
/// Updates the value of the Base Address Register with given index.
///
/// # Note
///
/// The function is only valid for devices with `header_type() == 0`
///
/// The `index` corresponds to the actual configuration space BAR index.
unsafe fn set_bar(&self, index: usize, value: PciBaseAddress) {
assert!(index < 6);
match value {
PciBaseAddress::Io(value) => {
self.write_u32(0x10 + index * 4, ((value as u32) & !0x3) | 1)
}
PciBaseAddress::Memory32(address) => self.write_u32(0x10 + index * 4, address & !0xF),
PciBaseAddress::Memory64(address) => {
self.write_u32(0x10 + index * 4, ((address as u32) & !0xF) | (2 << 1));
self.write_u32(0x10 + (index + 1) * 4, (address >> 32) as u32);
}
}
}
/// Returns the value of the Base Address Register with given index.
///
/// # Note
///
/// The function is only valid for devices with `header_type() == 0`
///
/// The `index` corresponds to the actual configuration space BAR index, i.e. if a 64-bit
/// address occupies [BAR0, BAR1] and BAR 1 is requested, the function will return [None].
fn bar(&self, index: usize) -> Option<PciBaseAddress> {
assert!(index < 6);
if index % 2 == 0 {
let w0 = self.read_u32(0x10 + index * 4);
match w0 & 1 {
0 => match (w0 >> 1) & 3 {
0 => {
// 32-bit memory BAR
Some(PciBaseAddress::Memory32(w0 & !0xF))
}
2 => {
// 64-bit memory BAR
let w1 = self.read_u32(0x10 + (index + 1) * 4);
Some(PciBaseAddress::Memory64(
((w1 as u64) << 32) | ((w0 as u64) & !0xF),
))
}
_ => unimplemented!(),
},
1 => Some(PciBaseAddress::Io((w0 as u16) & !0x3)),
_ => unreachable!(),
}
} else {
let prev_w0 = self.read_u32(0x10 + (index - 1) * 4);
if prev_w0 & 0x7 == 0x4 {
// Previous BAR is 64-bit memory and this one is its continuation
return None;
}
let w0 = self.read_u32(0x10 + index * 4);
match w0 & 1 {
0 => match (w0 >> 1) & 3 {
0 => {
// 32-bit memory BAR
Some(PciBaseAddress::Memory32(w0 & !0xF))
}
// TODO: can a 64-bit BAR start at an odd BAR index (i.e. off a 64-bit boundary)?
2 => todo!(),
_ => unimplemented!(),
},
1 => todo!(),
_ => unreachable!(),
}
}
}
/// Returns an iterator over the PCI capabilities
fn capability_iter(&self) -> CapabilityIterator<Self> {
let status = PciStatusRegister::from_bits_retain(self.status());
let current = if status.contains(PciStatusRegister::CAPABILITIES_LIST) {
let ptr = self.capability_pointer() as usize;
if ptr != 0 {
Some(ptr)
} else {
None
}
} else {
// Return an empty iterator
None
};
CapabilityIterator {
space: self,
current,
}
}
/// Locates a capability within this configuration space
fn capability<C: PciCapability>(&self) -> Option<C::CapabilityData<'_, Self>> {
self.capability_iter().find_map(|(id, offset, len)| {
if id == C::ID && C::check(self, offset, len) {
Some(C::data(self, offset, len))
} else {
None
}
})
}
}
impl<T: PciConfigurationSpace> PciConfigurationSpace for Arc<T> {
fn read_u32(&self, offset: usize) -> u32 {
T::read_u32(self.as_ref(), offset)
}
fn write_u32(&self, offset: usize, value: u32) {
T::write_u32(self.as_ref(), offset, value);
}
}

View File

@ -0,0 +1,18 @@
[package]
name = "ygg_driver_usb"
version = "0.1.0"
edition = "2021"
authors = ["Mark Poliakov <mark@alnyan.me>"]
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
device-api = { path = "../../../lib/device-api", features = ["derive"] }
ygg_driver_input = { path = "../../input" }
libk-util = { path = "../../../libk/libk-util" }
libk-mm = { path = "../../../libk/libk-mm" }
libk-thread = { path = "../../../libk/libk-thread" }
log = "0.4.20"
bytemuck = { version = "1.14.0", features = ["derive"] }
futures-util = { version = "0.3.28", default-features = false, features = ["alloc", "async-await"] }

View File

@ -0,0 +1,63 @@
use core::sync::atomic::{AtomicU16, Ordering};
use alloc::{collections::BTreeMap, sync::Arc};
use libk_util::{queue::UnboundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock};
use crate::{
class_driver,
device::{UsbBusAddress, UsbDeviceAccess},
UsbHostController,
};
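/// Global registry of USB buses and the devices attached to them.
///
/// A hypothetical HCD-side flow (sketch only): a host controller driver first
/// registers itself to obtain a bus number, then registers each device it
/// enumerates so that [bus_handler] can hand it off to a class driver:
///
/// ```ignore
/// let bus = UsbBusManager::register_bus(hc);
/// // ... enumerate a device on `bus`, wrap it in UsbDeviceAccess ...
/// UsbBusManager::register_device(Arc::new(device));
/// ```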
pub struct UsbBusManager {
busses: IrqSafeRwLock<BTreeMap<u16, &'static dyn UsbHostController>>,
devices: IrqSafeRwLock<BTreeMap<UsbBusAddress, Arc<UsbDeviceAccess>>>,
last_bus_address: AtomicU16,
}
impl UsbBusManager {
pub fn register_bus(hc: &'static dyn UsbHostController) -> u16 {
let i = BUS_MANAGER.last_bus_address.fetch_add(1, Ordering::AcqRel);
BUS_MANAGER.busses.write().insert(i, hc);
i
}
pub fn register_device(device: Arc<UsbDeviceAccess>) {
BUS_MANAGER
.devices
.write()
.insert(device.bus_address(), device.clone());
QUEUE.push_back(device);
}
pub fn detach_device(address: UsbBusAddress) {
if let Some(device) = BUS_MANAGER.devices.write().remove(&address) {
device.handle_detach();
}
}
}
pub async fn bus_handler() {
class_driver::register_default_class_drivers();
loop {
let new_device = QUEUE.pop_front().await;
log::info!(
"New {:?}-speed USB device connected: {}",
new_device.speed(),
new_device.bus_address()
);
class_driver::spawn_driver(new_device).await.ok();
}
}
static BUS_MANAGER: UsbBusManager = UsbBusManager {
busses: IrqSafeRwLock::new(BTreeMap::new()),
devices: IrqSafeRwLock::new(BTreeMap::new()),
last_bus_address: AtomicU16::new(0),
};
static QUEUE: UnboundedMpmcQueue<Arc<UsbDeviceAccess>> = UnboundedMpmcQueue::new();

View File

@ -0,0 +1,297 @@
use alloc::{sync::Arc, vec::Vec};
use futures_util::future::BoxFuture;
use libk_thread::runtime;
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use crate::{
device::UsbDeviceAccess,
error::UsbError,
info::{UsbDeviceClass, UsbDeviceProtocol},
};
pub struct UsbClassInfo {
pub class: UsbDeviceClass,
pub subclass: u8,
pub protocol: UsbDeviceProtocol,
}
pub trait UsbDriver: Send + Sync {
fn name(&self) -> &'static str;
fn run(
self: Arc<Self>,
device: Arc<UsbDeviceAccess>,
) -> BoxFuture<'static, Result<(), UsbError>>;
fn probe(&self, class: &UsbClassInfo, device: &UsbDeviceAccess) -> bool;
}
async fn extract_class_info(device: &UsbDeviceAccess) -> Result<Option<UsbClassInfo>, UsbError> {
if device.num_configurations != 1 {
return Ok(None);
}
let device_info = &device.info;
let config_info = device.query_configuration_info(0).await?;
if config_info.interfaces.len() == 1 {
let if_info = &config_info.interfaces[0];
let class = if device_info.device_class == UsbDeviceClass::FromInterface {
if_info.interface_class
} else {
device_info.device_class
};
let subclass = if device_info.device_subclass == 0 {
if_info.interface_subclass
} else {
device_info.device_subclass
};
let protocol = if device_info.device_protocol == UsbDeviceProtocol::FromInterface {
if_info.interface_protocol
} else {
device_info.device_protocol
};
Ok(Some(UsbClassInfo {
class,
subclass,
protocol,
}))
} else {
Ok(None)
}
}
async fn pick_driver(
device: &UsbDeviceAccess,
) -> Result<Option<Arc<dyn UsbDriver + 'static>>, UsbError> {
let Some(class) = extract_class_info(&device).await? else {
return Ok(None);
};
for driver in USB_DEVICE_DRIVERS.read().iter() {
if driver.probe(&class, device) {
return Ok(Some(driver.clone()));
}
}
Ok(None)
}
pub async fn spawn_driver(device: Arc<UsbDeviceAccess>) -> Result<(), UsbError> {
if let Some(driver) = pick_driver(&device).await? {
runtime::spawn(async move {
let name = driver.name();
match driver.run(device).await {
e @ Err(UsbError::DeviceDisconnected) => {
log::warn!(
"Driver {:?} did not exit cleanly: device disconnected",
name,
);
e
}
e => e,
}
})
.map_err(UsbError::SystemError)?;
}
Ok(())
}
pub fn register_driver(driver: Arc<dyn UsbDriver + 'static>) {
// TODO check for duplicates
USB_DEVICE_DRIVERS.write().push(driver);
}
pub fn register_default_class_drivers() {
register_driver(Arc::new(hid_keyboard::UsbHidKeyboardDriver));
}
static USB_DEVICE_DRIVERS: IrqSafeRwLock<Vec<Arc<dyn UsbDriver + 'static>>> =
IrqSafeRwLock::new(Vec::new());
pub mod hid_keyboard {
use core::mem::MaybeUninit;
use alloc::sync::Arc;
use futures_util::{future::BoxFuture, FutureExt};
use libk_mm::PageBox;
use yggdrasil_abi::io::{KeyboardKey, KeyboardKeyEvent};
use crate::{device::UsbDeviceAccess, error::UsbError, info::UsbDeviceClass};
use super::{UsbClassInfo, UsbDriver};
pub struct UsbHidKeyboardDriver;
const MODIFIER_MAP: &[KeyboardKey] = &[
KeyboardKey::LControl,
KeyboardKey::LShift,
KeyboardKey::LAlt,
KeyboardKey::Unknown,
KeyboardKey::RControl,
KeyboardKey::RShift,
KeyboardKey::RAlt,
KeyboardKey::Unknown,
];
#[derive(Default)]
struct KeyboardState {
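// 256-bit bitmap of currently-pressed keys, indexed by HID usage ID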
state: [u64; 4],
mods: u8,
}
impl KeyboardState {
pub fn new() -> Self {
Self::default()
}
pub fn translate_key(k: u8) -> KeyboardKey {
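// Usage IDs follow the USB HID Usage Tables, Keyboard/Keypad page (0x07)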
match k {
4..=29 => KeyboardKey::Char(k - 4 + b'a'),
30..=38 => KeyboardKey::Char(k - 30 + b'1'),
39 => KeyboardKey::Char(b'0'),
40 => KeyboardKey::Enter,
41 => KeyboardKey::Escape,
42 => KeyboardKey::Backspace,
43 => KeyboardKey::Tab,
44 => KeyboardKey::Char(b' '),
45 => KeyboardKey::Char(b'-'),
46 => KeyboardKey::Char(b'='),
47 => KeyboardKey::Char(b'['),
48 => KeyboardKey::Char(b']'),
49 => KeyboardKey::Char(b'\\'),
51 => KeyboardKey::Char(b';'),
52 => KeyboardKey::Char(b'\''),
53 => KeyboardKey::Char(b'`'),
54 => KeyboardKey::Char(b','),
55 => KeyboardKey::Char(b'.'),
56 => KeyboardKey::Char(b'/'),
58..=69 => KeyboardKey::F(k - 58),
_ => {
log::debug!("Unknown key: {}", k);
KeyboardKey::Unknown
}
}
}
pub fn retain_modifiers(
&mut self,
m: u8,
events: &mut [MaybeUninit<KeyboardKeyEvent>],
) -> usize {
let mut count = 0;
let released = self.mods & !m;
for i in 0..8 {
if released & (1 << i) != 0 {
events[count].write(KeyboardKeyEvent::Released(MODIFIER_MAP[i]));
count += 1;
}
}
self.mods &= m;
count
}
pub fn press_modifiers(
&mut self,
m: u8,
events: &mut [MaybeUninit<KeyboardKeyEvent>],
) -> usize {
let mut count = 0;
let pressed = m & !self.mods;
for i in 0..8 {
if pressed & (1 << i) != 0 {
events[count].write(KeyboardKeyEvent::Pressed(MODIFIER_MAP[i]));
count += 1;
}
}
self.mods = m;
count
}
pub fn retain(
&mut self,
keys: &[u8],
events: &mut [MaybeUninit<KeyboardKeyEvent>],
) -> usize {
let mut count = 0;
for i in 1..256 {
if self.state[i / 64] & (1 << (i % 64)) != 0 {
if !keys.contains(&(i as u8)) {
events[count]
.write(KeyboardKeyEvent::Released(Self::translate_key(i as u8)));
self.state[i / 64] &= !(1 << (i % 64));
count += 1;
}
}
}
count
}
pub fn press(
&mut self,
keys: &[u8],
events: &mut [MaybeUninit<KeyboardKeyEvent>],
) -> usize {
let mut count = 0;
for &k in keys {
let index = (k as usize) / 64;
if self.state[index] & (1 << (k % 64)) == 0 {
self.state[index] |= 1 << (k % 64);
events[count].write(KeyboardKeyEvent::Pressed(Self::translate_key(k)));
count += 1;
}
}
count
}
}
impl UsbDriver for UsbHidKeyboardDriver {
fn run(
self: Arc<Self>,
device: Arc<UsbDeviceAccess>,
) -> BoxFuture<'static, Result<(), UsbError>> {
async move {
// TODO not sure whether to use boot protocol (easy) or GetReport
let config = device.select_configuration(|_| true).await?.unwrap();
assert_eq!(config.endpoints.len(), 1);
let pipe = device.open_interrupt_in_pipe(1).await?;
let mut buffer = PageBox::new_slice(0, 8).map_err(UsbError::MemoryError)?;
let mut state = KeyboardState::new();
let mut events = [MaybeUninit::uninit(); 16];
loop {
let mut event_count = 0;
let data = pipe.read(&mut buffer).await?;
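// Boot-protocol keyboard report layout: byte 0 is the modifier bitmap,
// byte 1 is reserved, bytes 2..8 hold up to six pressed-key usage IDs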
event_count += state.retain_modifiers(data[0], &mut events);
event_count += state.press_modifiers(data[0], &mut events[event_count..]);
event_count += state.retain(&data[2..], &mut events[event_count..]);
event_count += state.press(&data[2..], &mut events[event_count..]);
let events =
unsafe { MaybeUninit::slice_assume_init_ref(&events[..event_count]) };
for &event in events {
log::debug!("Generic Keyboard: {:?}", event);
ygg_driver_input::send_event(event);
}
}
}
.boxed()
}
fn name(&self) -> &'static str {
"USB HID Keyboard"
}
fn probe(&self, class: &UsbClassInfo, _device: &UsbDeviceAccess) -> bool {
class.class == UsbDeviceClass::Hid && class.subclass == 0x01
}
}
}

View File

@ -0,0 +1,124 @@
use core::{
future::poll_fn,
sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll},
};
use alloc::{sync::Arc, vec::Vec};
use futures_util::task::AtomicWaker;
use libk_mm::address::PhysicalAddress;
use crate::error::UsbError;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum UsbDirection {
Out,
In,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct UsbTransferToken(pub u64);
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct UsbTransferResult(pub u32);
pub struct UsbTransferStatus {
pub result: AtomicU32,
pub notify: AtomicWaker,
}
pub struct UsbControlTransfer {
pub id: UsbTransferToken,
pub length: usize,
pub direction: UsbDirection,
pub elements: Vec<PhysicalAddress>,
pub status: Arc<UsbTransferStatus>,
}
pub struct UsbInterruptTransfer {
pub address: PhysicalAddress,
pub length: usize,
pub direction: UsbDirection,
pub status: Arc<UsbTransferStatus>,
}
impl UsbDirection {
pub const fn is_device_to_host(self) -> bool {
matches!(self, UsbDirection::In)
}
}
// TODO this is xHCI-specific
impl UsbTransferResult {
pub fn is_aborted(&self) -> bool {
self.0 == u32::MAX
}
pub fn is_success(&self) -> bool {
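// Bits 31:24 hold the xHCI completion code; code 1 means Success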
(self.0 >> 24) & 0xFF == 1
}
pub fn sub_length(&self) -> usize {
(self.0 & 0xFFFFFF) as _
}
}
impl UsbControlTransfer {
pub async fn wait(&self) -> Result<usize, UsbError> {
let sub_length = self.status.wait().await?;
Ok(self.length.saturating_sub(sub_length))
}
}
impl UsbInterruptTransfer {
pub async fn wait(&self) -> Result<usize, UsbError> {
let sub_length = self.status.wait().await?;
Ok(self.length.saturating_sub(sub_length))
}
}
impl UsbTransferStatus {
pub fn new() -> Self {
Self {
result: AtomicU32::new(0),
notify: AtomicWaker::new(),
}
}
pub(crate) async fn wait(&self) -> Result<usize, UsbError> {
poll_fn(|cx| {
self.poll(cx).map(|v| {
if v.is_success() {
Ok(v.sub_length())
} else if v.is_aborted() {
Err(UsbError::DeviceDisconnected)
} else {
Err(UsbError::TransferFailed)
}
})
})
.await
}
pub fn signal(&self, status: u32) {
self.result.store(status, Ordering::Release);
self.notify.wake();
}
pub fn abort(&self) {
self.result.store(u32::MAX, Ordering::Release);
self.notify.wake();
}
pub fn poll(&self, cx: &mut Context<'_>) -> Poll<UsbTransferResult> {
self.notify.register(cx.waker());
let value = self.result.load(Ordering::Acquire);
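// Zero is the "not signalled yet" sentinel; completion paths are expected
// to store a non-zero status (see `signal` and `abort`)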
if value != 0 {
Poll::Ready(UsbTransferResult(value))
} else {
Poll::Pending
}
}
}

View File

@ -0,0 +1,146 @@
use bytemuck::{Pod, Zeroable};
use crate::{
error::UsbError,
info::{UsbDeviceClass, UsbDeviceProtocol, UsbEndpointType},
UsbDirection,
};
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C, packed)]
pub struct UsbDeviceDescriptor {
pub length: u8,
pub ty: u8,
pub bcd_usb: u16,
pub device_class: u8,
pub device_subclass: u8,
pub device_protocol: u8,
pub max_packet_size_0: u8,
pub id_vendor: u16,
pub id_product: u16,
pub bcd_device: u16,
pub manufacturer_str: u8,
pub product_str: u8,
pub serial_number_str: u8,
pub num_configurations: u8,
}
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C, packed)]
pub struct UsbConfigurationDescriptor {
pub length: u8,
pub ty: u8,
pub total_length: u16,
pub num_interfaces: u8,
pub config_val: u8,
pub config_str: u8,
pub attributes: u8,
pub max_power: u8,
}
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C, packed)]
pub struct UsbInterfaceDescriptor {
pub length: u8,
pub ty: u8,
pub interface_number: u8,
pub alternate_setting: u8,
pub num_endpoints: u8,
pub interface_class: u8,
pub interface_subclass: u8,
pub interface_protocol: u8,
pub interface_str: u8,
}
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C, packed)]
pub struct UsbEndpointDescriptor {
pub length: u8,
pub ty: u8,
pub endpoint_address: u8,
pub attributes: u8,
pub max_packet_size: u16,
pub interval: u8,
}
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C, packed)]
pub struct UsbDeviceQualifier {
pub length: u8,
pub ty: u8,
pub bcd_usb: u16,
pub device_class: u8,
pub device_subclass: u8,
pub device_protocol: u8,
pub max_packet_size_0: u8,
pub num_configurations: u8,
pub _reserved: u8,
}
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C, packed)]
pub struct UsbOtherSpeedConfiguration {
pub length: u8,
pub ty: u8,
pub total_length: u16,
pub num_interfaces: u8,
pub config_val: u8,
pub config_str: u8,
pub attributes: u8,
pub max_power: u8,
}
impl UsbInterfaceDescriptor {
pub fn class(&self) -> UsbDeviceClass {
UsbDeviceClass::try_from(self.interface_class).unwrap_or(UsbDeviceClass::Unknown)
}
pub fn protocol(&self) -> UsbDeviceProtocol {
UsbDeviceProtocol::try_from(self.interface_protocol).unwrap_or(UsbDeviceProtocol::Unknown)
}
}
impl UsbEndpointDescriptor {
pub fn direction(&self) -> UsbDirection {
match self.endpoint_address >> 7 {
1 => UsbDirection::In,
0 => UsbDirection::Out,
_ => unreachable!(),
}
}
pub fn number(&self) -> u8 {
assert_ne!(self.endpoint_address & 0xF, 0);
self.endpoint_address & 0xF
}
pub fn transfer_type(&self) -> UsbEndpointType {
match self.attributes & 0x3 {
0 => UsbEndpointType::Control,
1 => UsbEndpointType::Isochronous,
2 => UsbEndpointType::Bulk,
3 => UsbEndpointType::Interrupt,
_ => unreachable!(),
}
}
}
impl UsbDeviceDescriptor {
pub fn class(&self) -> UsbDeviceClass {
UsbDeviceClass::try_from(self.device_class).unwrap_or(UsbDeviceClass::Unknown)
}
pub fn protocol(&self) -> UsbDeviceProtocol {
UsbDeviceProtocol::try_from(self.device_protocol).unwrap_or(UsbDeviceProtocol::Unknown)
}
pub fn max_packet_size(&self) -> Result<usize, UsbError> {
match self.max_packet_size_0 {
8 => Ok(8),
16 => Ok(16),
32 => Ok(32),
64 => Ok(64),
_ => Err(UsbError::InvalidDescriptorField),
}
}
}

View File

@ -0,0 +1,204 @@
use core::{fmt, ops::Deref};
use alloc::{boxed::Box, vec::Vec};
use futures_util::future::BoxFuture;
use libk_mm::PageBox;
use libk_util::sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard};
use crate::{
error::UsbError,
info::{UsbConfigurationInfo, UsbDeviceInfo, UsbEndpointInfo, UsbInterfaceInfo},
pipe::{
control::{ConfigurationDescriptorEntry, UsbControlPipeAccess},
interrupt::UsbInterruptInPipeAccess,
},
UsbHostController,
};
// High-level structures for info provided through descriptors
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct UsbBusAddress {
pub bus: u16,
pub device: u8,
}
pub struct UsbDeviceAccess {
pub device: Box<dyn UsbDevice>,
pub info: UsbDeviceInfo,
pub num_configurations: u8,
pub current_configuration: IrqSafeRwLock<Option<UsbConfigurationInfo>>,
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum UsbSpeed {
Low,
Full,
High,
Super,
}
#[allow(unused)]
pub trait UsbDevice: Send + Sync {
// Endpoint "0"
fn control_pipe(&self) -> &UsbControlPipeAccess;
fn open_interrupt_in_pipe<'a>(
&'a self,
number: u8,
) -> BoxFuture<Result<UsbInterruptInPipeAccess, UsbError>> {
unimplemented!()
}
fn port_number(&self) -> u8;
fn bus_address(&self) -> UsbBusAddress;
fn speed(&self) -> UsbSpeed;
fn controller(&self) -> &'static dyn UsbHostController;
fn handle_detach(&self);
fn debug(&self) {}
}
impl UsbDeviceAccess {
/// Expected device state:
///
/// * The link layer has been reset and brought up properly by the HCD
/// * Device is not yet configured
/// * Control pipe for the device has been properly set up
/// * Device has been assigned a bus address
pub async fn setup(raw: Box<dyn UsbDevice>) -> Result<Self, UsbError> {
let control = raw.control_pipe();
let mut string_buffer = PageBox::new_uninit().map_err(UsbError::MemoryError)?;
let device_desc = control.query_device_descriptor().await?;
let manufacturer = control
.query_string(device_desc.manufacturer_str, &mut string_buffer)
.await?;
let product = control
.query_string(device_desc.product_str, &mut string_buffer)
.await?;
let info = UsbDeviceInfo {
manufacturer,
product,
id_vendor: device_desc.id_vendor,
id_product: device_desc.id_product,
device_class: device_desc.class(),
device_subclass: device_desc.device_subclass,
device_protocol: device_desc.protocol(),
max_packet_size: device_desc.max_packet_size()?,
};
Ok(Self {
device: raw,
info,
num_configurations: device_desc.num_configurations,
current_configuration: IrqSafeRwLock::new(None),
})
}
pub fn read_current_configuration(
&self,
) -> IrqSafeRwLockReadGuard<Option<UsbConfigurationInfo>> {
self.current_configuration.read()
}
pub async fn select_configuration<F: Fn(&UsbConfigurationInfo) -> bool>(
&self,
predicate: F,
) -> Result<Option<UsbConfigurationInfo>, UsbError> {
let mut current_config = self.current_configuration.write();
let control_pipe = self.control_pipe();
for i in 0..self.num_configurations {
let info = self.query_configuration_info(i).await?;
if predicate(&info) {
log::debug!("Selected configuration: {:#?}", info);
let config = current_config.insert(info);
control_pipe
.set_configuration(config.config_value as _)
.await?;
return Ok(Some(config.clone()));
}
}
Ok(None)
}
pub async fn query_configuration_info(
&self,
index: u8,
) -> Result<UsbConfigurationInfo, UsbError> {
if index >= self.num_configurations {
return Err(UsbError::InvalidConfiguration);
}
let mut string_buffer = PageBox::new_uninit().map_err(UsbError::MemoryError)?;
let control_pipe = self.control_pipe();
let query = control_pipe.query_configuration_descriptor(index).await?;
let configuration_name = control_pipe
.query_string(query.configuration().config_str, &mut string_buffer)
.await?;
let mut endpoints = Vec::new();
let mut interfaces = Vec::new();
for desc in query.descriptors() {
match desc {
ConfigurationDescriptorEntry::Endpoint(ep) => {
endpoints.push(UsbEndpointInfo {
number: ep.number(),
direction: ep.direction(),
max_packet_size: ep.max_packet_size as _,
ty: ep.transfer_type(),
});
}
ConfigurationDescriptorEntry::Interface(iface) => {
let name = control_pipe
.query_string(iface.interface_str, &mut string_buffer)
.await?;
interfaces.push(UsbInterfaceInfo {
name,
number: iface.interface_number,
interface_class: iface.class(),
interface_subclass: iface.interface_subclass,
interface_protocol: iface.protocol(),
});
}
_ => (),
}
}
let info = UsbConfigurationInfo {
name: configuration_name,
config_value: query.configuration().config_val,
interfaces,
endpoints,
};
Ok(info)
}
}
impl Deref for UsbDeviceAccess {
type Target = dyn UsbDevice;
fn deref(&self) -> &Self::Target {
&*self.device
}
}
impl fmt::Display for UsbBusAddress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}", self.bus, self.device)
}
}

View File

@ -0,0 +1,30 @@
use yggdrasil_abi::error::Error;
#[derive(Debug)]
pub enum UsbError {
/// Could not allocate memory for some device structure
MemoryError(Error),
/// Other system errors
SystemError(Error),
// HC-side init stage errors
OutOfAddresses,
HostControllerCommandFailed(u8),
PortResetFailed,
PortInitFailed,
// Setup stage errors
InvalidConfiguration,
InvalidDescriptorField,
// Runtime errors
DeviceBusy,
DeviceDisconnected,
TransferFailed,
}
impl From<UsbError> for Error {
fn from(value: UsbError) -> Self {
match value {
UsbError::MemoryError(e) => e,
_ => Error::InvalidOperation,
}
}
}

View File

@ -0,0 +1,84 @@
use alloc::{string::String, vec::Vec};
use yggdrasil_abi::primitive_enum;
use crate::UsbDirection;
#[derive(Debug, Clone, Copy)]
pub enum UsbEndpointType {
Control,
Isochronous,
Bulk,
Interrupt,
}
#[derive(Debug, Clone, Copy)]
pub enum UsbSyncType {
NoSync,
Async,
Adaptive,
Sync,
}
#[derive(Debug)]
pub enum UsbUsageType {
Data,
Feedback,
ImplicitFeedbackData,
Reserved,
}
primitive_enum! {
pub enum UsbDeviceClass: u8 {
FromInterface = 0x00,
Hid = 0x03,
Unknown = 0xFF,
}
}
primitive_enum! {
pub enum UsbDeviceProtocol: u8 {
FromInterface = 0x00,
Unknown = 0xFF,
}
}
#[derive(Debug, Clone)]
pub struct UsbInterfaceInfo {
pub name: String,
pub number: u8,
pub interface_class: UsbDeviceClass,
pub interface_subclass: u8,
pub interface_protocol: UsbDeviceProtocol,
}
#[derive(Debug, Clone)]
pub struct UsbEndpointInfo {
pub number: u8,
pub direction: UsbDirection,
pub max_packet_size: usize,
pub ty: UsbEndpointType,
}
#[derive(Debug, Clone)]
pub struct UsbConfigurationInfo {
pub name: String,
pub config_value: u8,
pub interfaces: Vec<UsbInterfaceInfo>,
pub endpoints: Vec<UsbEndpointInfo>,
}
#[derive(Debug, Clone)]
pub struct UsbDeviceInfo {
pub manufacturer: String,
pub product: String,
pub id_vendor: u16,
pub id_product: u16,
pub device_class: UsbDeviceClass,
pub device_subclass: u8,
pub device_protocol: UsbDeviceProtocol,
pub max_packet_size: usize,
}

View File

@ -0,0 +1,21 @@
#![no_std]
#![feature(iter_array_chunks, maybe_uninit_slice)]
extern crate alloc;
pub mod bus;
pub mod communication;
pub mod descriptor;
pub mod device;
pub mod error;
pub mod info;
pub mod pipe;
pub mod util;
pub mod class_driver;
pub use communication::{UsbControlTransfer, UsbDirection, UsbTransferStatus, UsbTransferToken};
pub trait UsbEndpoint {}
pub trait UsbHostController {}

View File

@ -0,0 +1,326 @@
use core::{
cmp::Ordering,
mem::{size_of, MaybeUninit},
ops::Deref,
};
use alloc::{boxed::Box, string::String};
use bytemuck::{Pod, Zeroable};
use libk_mm::{
address::{AsPhysicalAddress, PhysicalAddress},
PageBox,
};
use crate::{
descriptor::{
UsbConfigurationDescriptor, UsbDeviceDescriptor, UsbDeviceQualifier, UsbEndpointDescriptor,
UsbInterfaceDescriptor, UsbOtherSpeedConfiguration,
},
error::UsbError,
UsbControlTransfer, UsbDirection,
};
use super::UsbGenericPipe;
#[derive(Debug)]
pub struct ControlTransferSetup {
pub bm_request_type: u8,
pub b_request: u8,
pub w_value: u16,
pub w_index: u16,
pub w_length: u16,
}
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
#[repr(C)]
pub struct SetConfiguration;
pub trait UsbDeviceRequest: Sized + Pod {
const BM_REQUEST_TYPE: u8;
const B_REQUEST: u8;
}
pub trait UsbDescriptorRequest: UsbDeviceRequest {
const DESCRIPTOR_TYPE: u8;
}
impl UsbDescriptorRequest for UsbDeviceDescriptor {
const DESCRIPTOR_TYPE: u8 = 1;
}
impl UsbDescriptorRequest for UsbConfigurationDescriptor {
const DESCRIPTOR_TYPE: u8 = 2;
}
impl UsbDescriptorRequest for UsbInterfaceDescriptor {
const DESCRIPTOR_TYPE: u8 = 4;
}
impl UsbDeviceRequest for SetConfiguration {
const BM_REQUEST_TYPE: u8 = 0;
const B_REQUEST: u8 = 0x09;
}
impl<U: UsbDescriptorRequest> UsbDeviceRequest for U {
const BM_REQUEST_TYPE: u8 = 0b10000000;
const B_REQUEST: u8 = 0x06;
}
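// USB string descriptors carry little-endian UTF-16 code units (USB 2.0, §9.6.7)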
fn decode_usb_string(bytes: &[u8]) -> Result<String, UsbError> {
if bytes.len() % 2 != 0 {
return Err(UsbError::InvalidDescriptorField);
}
char::decode_utf16(
bytes
.into_iter()
.array_chunks::<2>()
.map(|[&a, &b]| u16::from_le_bytes([a, b])),
)
.collect::<Result<String, _>>()
.map_err(|_| UsbError::InvalidDescriptorField)
}
// Pipe impl
pub trait UsbControlPipe: UsbGenericPipe + Send + Sync {
fn start_transfer(
&self,
setup: ControlTransferSetup,
data: Option<(PhysicalAddress, usize, UsbDirection)>,
) -> Result<UsbControlTransfer, UsbError>;
fn complete_transfer(&self, transfer: UsbControlTransfer);
}
pub struct UsbControlPipeAccess(pub Box<dyn UsbControlPipe>);
fn input_buffer<T: Pod>(
data: &mut PageBox<MaybeUninit<T>>,
) -> (PhysicalAddress, usize, UsbDirection) {
(
unsafe { data.as_physical_address() },
size_of::<T>(),
UsbDirection::In,
)
}
#[derive(Debug)]
pub enum ConfigurationDescriptorEntry<'a> {
Configuration(&'a UsbConfigurationDescriptor),
Interface(&'a UsbInterfaceDescriptor),
Endpoint(&'a UsbEndpointDescriptor),
DeviceQualifier(&'a UsbDeviceQualifier),
OtherSpeed(&'a UsbOtherSpeedConfiguration),
Other,
}
pub struct ConfigurationDescriptorIter<'a> {
buffer: &'a PageBox<[u8]>,
offset: usize,
}
pub struct ConfigurationDescriptorQuery {
buffer: PageBox<[u8]>,
}
impl<'a> Iterator for ConfigurationDescriptorIter<'a> {
type Item = ConfigurationDescriptorEntry<'a>;
fn next(&mut self) -> Option<Self::Item> {
if self.offset + 2 >= self.buffer.len() {
return None;
}
let desc_len = self.buffer[self.offset] as usize;
let desc_ty = self.buffer[self.offset + 1];
if desc_len == 0 {
return None;
}
let entry = match desc_ty {
0x02 if desc_len == size_of::<UsbConfigurationDescriptor>() => {
ConfigurationDescriptorEntry::Configuration(bytemuck::from_bytes(
&self.buffer[self.offset..self.offset + desc_len],
))
}
0x04 if desc_len == size_of::<UsbInterfaceDescriptor>() => {
ConfigurationDescriptorEntry::Interface(bytemuck::from_bytes(
&self.buffer[self.offset..self.offset + desc_len],
))
}
0x05 if desc_len == size_of::<UsbEndpointDescriptor>() => {
ConfigurationDescriptorEntry::Endpoint(bytemuck::from_bytes(
&self.buffer[self.offset..self.offset + desc_len],
))
}
0x07 if desc_len == size_of::<UsbOtherSpeedConfiguration>() => {
ConfigurationDescriptorEntry::OtherSpeed(bytemuck::from_bytes(
&self.buffer[self.offset..self.offset + desc_len],
))
}
_ => ConfigurationDescriptorEntry::Other,
};
self.offset += desc_len;
Some(entry)
}
}
impl ConfigurationDescriptorQuery {
pub fn configuration(&self) -> &UsbConfigurationDescriptor {
bytemuck::from_bytes(&self.buffer[..size_of::<UsbConfigurationDescriptor>()])
}
pub fn descriptors(&self) -> ConfigurationDescriptorIter<'_> {
ConfigurationDescriptorIter {
buffer: &self.buffer,
offset: 0,
}
}
}
impl UsbControlPipeAccess {
pub async fn perform_value_control(
&self,
setup: ControlTransferSetup,
buffer: Option<(PhysicalAddress, usize, UsbDirection)>,
) -> Result<(), UsbError> {
let transfer = self.start_transfer(setup, buffer)?;
transfer.status.wait().await?;
self.complete_transfer(transfer);
Ok(())
}
async fn fill_configuration_descriptor(
&self,
index: u8,
buffer: &mut PageBox<[MaybeUninit<u8>]>,
) -> Result<(), UsbError> {
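// Standard GET_DESCRIPTOR request: bmRequestType = device-to-host/standard,
// bRequest = GET_DESCRIPTOR (0x06), wValue = (CONFIGURATION (0x02) << 8) | index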
self.perform_value_control(
ControlTransferSetup {
bm_request_type: 0b10000000,
b_request: 0x06,
w_value: 0x200 | (index as u16),
w_index: 0,
w_length: buffer.len().try_into().unwrap(),
},
Some((
unsafe { buffer.as_physical_address() },
buffer.len(),
UsbDirection::In,
)),
)
.await
}
pub async fn query_configuration_descriptor(
&self,
index: u8,
) -> Result<ConfigurationDescriptorQuery, UsbError> {
// First, query the real length of the descriptor
let mut buffer = PageBox::new_uninit_slice(size_of::<UsbConfigurationDescriptor>())
.map_err(UsbError::MemoryError)?;
self.fill_configuration_descriptor(index, &mut buffer)
.await?;
let buffer = unsafe { PageBox::assume_init_slice(buffer) };
let desc: &UsbConfigurationDescriptor = bytemuck::from_bytes(&buffer);
let total_len = desc.total_length as usize;
// Return if everything's ready at this point
match total_len.cmp(&size_of::<UsbConfigurationDescriptor>()) {
Ordering::Less => todo!(),
Ordering::Equal => return Ok(ConfigurationDescriptorQuery { buffer }),
_ => (),
}
// Otherwise, query the rest of the data
let mut buffer = PageBox::new_uninit_slice(total_len).map_err(UsbError::MemoryError)?;
self.fill_configuration_descriptor(index, &mut buffer)
.await?;
let buffer = unsafe { PageBox::assume_init_slice(buffer) };
let desc: &UsbConfigurationDescriptor =
bytemuck::from_bytes(&buffer[..size_of::<UsbConfigurationDescriptor>()]);
let total_len = desc.total_length as usize;
if total_len != buffer.len() {
todo!();
}
Ok(ConfigurationDescriptorQuery { buffer })
}
pub async fn query_device_descriptor(&self) -> Result<PageBox<UsbDeviceDescriptor>, UsbError> {
let mut output = PageBox::new_uninit().map_err(UsbError::MemoryError)?;
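// GET_DESCRIPTOR with wValue = (DEVICE (0x01) << 8), descriptor index 0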
self.perform_value_control(
ControlTransferSetup {
bm_request_type: 0b10000000,
b_request: 0x06,
w_value: 0x100,
w_index: 0,
w_length: size_of::<UsbDeviceDescriptor>() as _,
},
Some(input_buffer(&mut output)),
)
.await?;
Ok(unsafe { output.assume_init() })
}
pub async fn query_string(
&self,
index: u8,
buffer: &mut PageBox<MaybeUninit<[u8; 4096]>>,
) -> Result<String, UsbError> {
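// GET_DESCRIPTOR with wValue = (STRING (0x03) << 8) | index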
self.perform_value_control(
ControlTransferSetup {
bm_request_type: 0b10000000,
b_request: 0x06,
w_value: 0x300 | (index as u16),
w_index: 0,
w_length: 4096,
},
Some(input_buffer(buffer)),
)
.await?;
let data = unsafe { buffer.assume_init_ref() };
let len = data[0] as usize;
decode_usb_string(&data[2..len])
}
pub async fn perform_action<D: UsbDeviceRequest>(
&self,
w_value: u16,
w_index: u16,
) -> Result<(), UsbError> {
self.perform_value_control(
ControlTransferSetup {
bm_request_type: D::BM_REQUEST_TYPE,
b_request: D::B_REQUEST,
w_value,
w_index,
w_length: 0,
},
None,
)
.await
}
pub async fn set_configuration(&self, value: u16) -> Result<(), UsbError> {
self.perform_action::<SetConfiguration>(value, 0).await
}
}
impl Deref for UsbControlPipeAccess {
type Target = dyn UsbControlPipe;
fn deref(&self) -> &Self::Target {
&*self.0
}
}

View File

@ -0,0 +1,32 @@
use core::ops::Deref;
use alloc::boxed::Box;
use libk_mm::PageBox;
use crate::{communication::UsbInterruptTransfer, error::UsbError};
use super::UsbGenericPipe;
pub trait UsbInterruptInPipe: UsbGenericPipe + Send + Sync {
fn start_read(&self, buffer: &mut PageBox<[u8]>) -> Result<UsbInterruptTransfer, UsbError>;
fn complete_transfer(&self, transfer: UsbInterruptTransfer);
}
pub struct UsbInterruptInPipeAccess(pub Box<dyn UsbInterruptInPipe>);
impl UsbInterruptInPipeAccess {
pub async fn read<'a>(&self, buffer: &'a mut PageBox<[u8]>) -> Result<&'a [u8], UsbError> {
let transfer = self.start_read(buffer)?;
let len = transfer.wait().await?;
self.complete_transfer(transfer);
Ok(&buffer[..len])
}
}
impl Deref for UsbInterruptInPipeAccess {
type Target = dyn UsbInterruptInPipe;
fn deref(&self) -> &Self::Target {
&*self.0
}
}

View File

@ -0,0 +1,8 @@
pub mod control;
pub mod interrupt;
pub trait UsbGenericPipe {}
pub enum UsbPipe {
Control(control::UsbControlPipeAccess),
}

View File

@ -0,0 +1,64 @@
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use crate::error::UsbError;
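/// Bitmap-based allocator of USB device addresses.
///
/// Address 0, which every device responds to before being assigned an address,
/// is pre-reserved and never handed out.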
pub struct UsbAddressAllocator {
// 256 bits
bitmap: IrqSafeRwLock<Bitmap>,
}
struct Bitmap {
data: [u64; 4],
}
impl Bitmap {
const fn bit_index(bit: u8) -> usize {
(bit / 64) as usize
}
const fn bit_mask(bit: u8) -> u64 {
1 << (bit % 64)
}
pub const fn new() -> Self {
// Bit 0 starts set so that address 0 is never allocated
Self { data: [1, 0, 0, 0] }
}
pub fn is_set(&self, bit: u8) -> bool {
self.data[Self::bit_index(bit)] & Self::bit_mask(bit) != 0
}
pub fn set(&mut self, bit: u8) {
self.data[Self::bit_index(bit)] |= Self::bit_mask(bit);
}
pub fn clear(&mut self, bit: u8) {
self.data[Self::bit_index(bit)] &= !Self::bit_mask(bit);
}
}
impl UsbAddressAllocator {
pub fn new() -> Self {
Self {
bitmap: IrqSafeRwLock::new(Bitmap::new()),
}
}
pub fn allocate(&self) -> Result<u8, UsbError> {
let mut bitmap = self.bitmap.write();
for bit in 0..=255 {
if !bitmap.is_set(bit) {
bitmap.set(bit);
return Ok(bit);
}
}
Err(UsbError::OutOfAddresses)
}
pub fn free(&self, address: u8) {
let mut bitmap = self.bitmap.write();
assert!(bitmap.is_set(address));
bitmap.clear(address);
}
}

View File

@ -0,0 +1,16 @@
[package]
name = "kernel-fs"
version = "0.1.0"
edition = "2021"
authors = ["Mark Poliakov <mark@alnyan.me>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
vfs = { path = "../../../lib/vfs" }
libk-util = { path = "../../../libk/libk-util" }
ygg_driver_block = { path = "../../block/core" }
log = "0.4.20"

View File

@ -0,0 +1,86 @@
//! Device virtual file system
use core::sync::atomic::{AtomicUsize, Ordering};
use alloc::{format, string::String};
use libk_util::OneTimeInit;
use vfs::{impls::MemoryDirectory, CharDevice, Node, NodeFlags, NodeRef};
use ygg_driver_block::BlockDevice;
use yggdrasil_abi::error::Error;
/// Describes the kind of a character device
#[derive(Debug)]
pub enum CharDeviceType {
/// Regular terminal
TtyRegular,
/// Serial terminal
TtySerial,
}
static DEVFS_ROOT: OneTimeInit<NodeRef> = OneTimeInit::new();
/// Sets up the device filesystem
pub fn init() {
let root = MemoryDirectory::empty();
DEVFS_ROOT.init(root);
}
/// Returns the root of the devfs.
///
/// # Panics
///
/// Will panic if the devfs hasn't yet been initialized.
pub fn root() -> &'static NodeRef {
DEVFS_ROOT.get()
}
/// Adds a character device with a custom name
pub fn add_named_char_device(dev: &'static dyn CharDevice, name: String) -> Result<(), Error> {
log::info!("Add char device: {}", name);
let node = Node::char(dev, NodeFlags::IN_MEMORY_PROPS);
DEVFS_ROOT.get().add_child(name, node)
}
/// Adds a block device with a custom name
pub fn add_named_block_device<S: Into<String>>(
dev: &'static dyn BlockDevice,
name: S,
) -> Result<(), Error> {
let name = name.into();
log::info!("Add block device: {}", name);
let node = Node::block(dev, NodeFlags::IN_MEMORY_PROPS);
DEVFS_ROOT.get().add_child(name, node)
}
pub fn add_block_device_partition<S: Into<String>>(
base_name: S,
index: usize,
partition: &'static dyn BlockDevice,
) -> Result<(), Error> {
let base_name = base_name.into();
let name = format!("{}{}", base_name, index + 1);
log::info!("Add partition: {}", name);
let node = Node::block(partition, NodeFlags::IN_MEMORY_PROPS);
DEVFS_ROOT.get().add_child(name, node)
}
/// Adds a character device to the devfs
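///
/// Devices are named with a per-kind prefix and counter: the first regular
/// terminal becomes `tty0`, the first serial one `ttyS0`, and so on.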
pub fn add_char_device(dev: &'static dyn CharDevice, kind: CharDeviceType) -> Result<(), Error> {
static TTY_COUNT: AtomicUsize = AtomicUsize::new(0);
static TTYS_COUNT: AtomicUsize = AtomicUsize::new(0);
let (count, prefix) = match kind {
CharDeviceType::TtyRegular => (&TTY_COUNT, "tty"),
CharDeviceType::TtySerial => (&TTYS_COUNT, "ttyS"),
};
let value = count.fetch_add(1, Ordering::AcqRel);
let name = format!("{}{}", prefix, value);
add_named_char_device(dev, name)
}

View File

@ -0,0 +1,5 @@
#![no_std]
extern crate alloc;
pub mod devfs;

View File

@ -0,0 +1,19 @@
[package]
name = "memfs"
version = "0.1.0"
edition = "2021"
authors = ["Mark Poliakov <mark@alnyan.me>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
libk-util = { path = "../../../libk/libk-util" }
vfs = { path = "../../../lib/vfs" }
static_assertions = "1.1.0"
log = "0.4.20"
[features]
default = []
test-io = []

View File

@ -0,0 +1,357 @@
//! Block management interfaces and structures
use core::{
marker::PhantomData,
mem::{size_of, MaybeUninit},
ops::{Deref, DerefMut},
ptr::NonNull,
};
use yggdrasil_abi::error::Error;
/// Number of bytes in a block
pub const SIZE: usize = 4096;
/// Maximum number of indirection pointers a block can hold
pub const ENTRY_COUNT: usize = SIZE / size_of::<usize>();
/// Interface for a block allocator
///
/// # Safety
///
/// This trait is unsafe to implement because it has to provide and accept raw data pointers of
/// exactly [SIZE].
pub unsafe trait BlockAllocator: Send + Sync + 'static {
/// Allocates a contiguous block of size [SIZE]
fn alloc() -> Result<NonNull<u8>, Error>;
/// Deallocates a block.
///
/// # Safety
///
/// Unsafe: accepts arbitrary data pointers.
unsafe fn dealloc(block: NonNull<u8>);
}
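// A `BlockRef` is a tagged pointer: with bit 0 clear it owns its allocation,
// with bit 0 set it borrows external data Copy-on-Write (cloned into a fresh
// allocation on first mutable access); a zero value is the null block.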
struct BlockRef<'a, A: BlockAllocator> {
ptr: usize,
_pd: PhantomData<&'a A>,
}
#[repr(transparent)]
struct BlockRaw<'a, A: BlockAllocator> {
inner: BlockRef<'a, A>,
}
/// Block containing file data
#[repr(transparent)]
pub struct BlockData<'a, A: BlockAllocator> {
inner: BlockRaw<'a, A>,
}
/// Block containing indirection pointers to other blocks
#[repr(transparent)]
pub struct BlockIndirect<'a, A: BlockAllocator> {
inner: BlockRaw<'a, A>,
}
impl<'a, A: BlockAllocator> BlockRef<'a, A> {
const fn null() -> Self {
Self {
ptr: 0,
_pd: PhantomData,
}
}
unsafe fn from_allocated(address: usize) -> Self {
debug_assert_eq!(address & 1, 0);
Self {
ptr: address,
_pd: PhantomData,
}
}
unsafe fn copy_on_write(address: usize) -> Self {
debug_assert_eq!(address & 1, 0);
Self {
ptr: address | 1,
_pd: PhantomData,
}
}
#[inline]
fn is_allocated(&self) -> bool {
self.ptr & 1 == 0
}
#[inline]
fn is_null(&self) -> bool {
self.ptr == 0
}
#[inline]
fn as_mut(&mut self) -> &'a mut [u8; SIZE] {
if self.is_null() {
panic!("Null block dereference");
}
// FIXME: if a non-full block has been marked as CoW, the function will overrun the file
// boundary
if !self.is_allocated() {
// Allocate the block
let ptr = A::alloc().expect("Could not allocate a block").as_ptr() as usize;
// Clone data
let src = self.as_ref();
let dst = unsafe { core::slice::from_raw_parts_mut(ptr as *mut u8, SIZE) };
dst.copy_from_slice(src);
self.ptr = ptr;
}
unsafe { &mut *((self.ptr & !1) as *mut [u8; SIZE]) }
}
#[inline]
fn as_ref(&self) -> &'a [u8; SIZE] {
if self.is_null() {
panic!("Null block dereference");
}
unsafe { &*((self.ptr & !1) as *const [u8; SIZE]) }
}
}
impl<'a, A: BlockAllocator> Drop for BlockRef<'a, A> {
fn drop(&mut self) {
if self.is_allocated() && !self.is_null() {
unsafe {
A::dealloc(NonNull::new_unchecked(self.ptr as *mut _));
}
}
}
}
impl<'a, A: BlockAllocator> BlockRaw<'a, A> {
const fn null() -> Self {
Self {
inner: BlockRef::null(),
}
}
fn new() -> Result<Self, Error> {
let ptr = A::alloc()?;
unsafe {
Ok(Self {
inner: BlockRef::from_allocated(ptr.as_ptr() as _),
})
}
}
unsafe fn as_uninit_indirect_mut(
&mut self,
) -> &'a mut [MaybeUninit<BlockData<'a, A>>; ENTRY_COUNT] {
if self.inner.is_null() {
panic!("Null block dereference");
}
&mut *(self.inner.ptr as *mut _)
}
#[inline]
unsafe fn as_data_ref(&self) -> &'a [u8; SIZE] {
self.inner.as_ref()
}
#[inline]
unsafe fn as_data_mut(&mut self) -> &'a mut [u8; SIZE] {
self.inner.as_mut()
}
#[inline]
fn is_null(&self) -> bool {
self.inner.is_null()
}
}
// Data block
impl<'a, A: BlockAllocator> BlockData<'a, A> {
/// Dummy entry representing a missing block
pub const fn null() -> Self {
Self {
inner: BlockRaw::null(),
}
}
/// Create a Copy-on-Write data block from existing data.
///
/// # Safety
///
/// This function is unsafe as it accepts arbitrary pointers. The caller must ensure the
/// address is properly aligned (at least to a u16 boundary), does not cross any device memory
/// and the pointer outlives the block reference.
pub unsafe fn copy_on_write(address: usize) -> Self {
Self {
inner: BlockRaw {
inner: BlockRef::copy_on_write(address),
},
}
}
/// Allocates a new block for data
pub fn new() -> Result<Self, Error> {
Ok(Self {
inner: BlockRaw::new()?,
})
}
/// Replaces self with a null block and drops any data that might've been allocated
pub fn set_null(&mut self) {
self.inner = BlockRaw::null();
}
/// Returns `true` if the block this structure refers to has not yet been allocated
#[inline]
pub fn is_null(&self) -> bool {
self.inner.is_null()
}
}
impl<A: BlockAllocator> Deref for BlockData<'_, A> {
type Target = [u8; SIZE];
fn deref(&self) -> &Self::Target {
unsafe { self.inner.as_data_ref() }
}
}
impl<A: BlockAllocator> DerefMut for BlockData<'_, A> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.inner.as_data_mut() }
}
}
// Indirect block
impl<'a, A: BlockAllocator> BlockIndirect<'a, A> {
/// Dummy entry representing a missing block
pub const fn null() -> Self {
Self {
inner: BlockRaw::null(),
}
}
/// Allocates a new indirection block
pub fn new() -> Result<Self, Error> {
let mut inner = BlockRaw::new()?;
for item in unsafe { inner.as_uninit_indirect_mut() } {
item.write(BlockData::null());
}
Ok(Self { inner })
}
/// Returns `true` if the block this structure refers to has not yet been allocated
#[inline]
pub fn is_null(&self) -> bool {
self.inner.is_null()
}
}
impl<'a, A: BlockAllocator> Deref for BlockIndirect<'a, A> {
type Target = [BlockData<'a, A>; ENTRY_COUNT];
fn deref(&self) -> &Self::Target {
unsafe { &*(self.inner.inner.as_ref() as *const _ as *const _) }
}
}
impl<'a, A: BlockAllocator> DerefMut for BlockIndirect<'a, A> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *(self.inner.inner.as_mut() as *mut _ as *mut _) }
}
}
impl<'a, A: BlockAllocator> Drop for BlockIndirect<'a, A> {
fn drop(&mut self) {
if self.is_null() {
return;
}
for item in self.iter_mut() {
item.set_null();
}
}
}
#[cfg(test)]
mod tests {
use core::sync::atomic::Ordering;
use std::vec::Vec;
use crate::block::{BlockData, BlockIndirect};
#[test]
fn block_indirect_allocation() {
test_allocator_with_counter!(A_COUNTER, A);
const N: usize = 7;
const M: usize = 13;
assert_eq!(A_COUNTER.load(Ordering::Acquire), 0);
{
let mut indirect = Vec::new();
// Allocate indirect blocks
{
for _ in 0..N {
indirect.push(BlockIndirect::<A>::new().unwrap());
}
}
assert_eq!(A_COUNTER.load(Ordering::Acquire), N);
// Allocate L1 indirection blocks
{
for l1_block in indirect.iter_mut() {
for i in 0..M {
let l0_block = BlockData::new().unwrap();
l1_block[i] = l0_block;
}
}
}
// N * M data blocks and N indirection blocks
assert_eq!(A_COUNTER.load(Ordering::Acquire), N * M + N);
// Drop 1 indirect block for test
indirect.pop();
assert_eq!(A_COUNTER.load(Ordering::Acquire), (N - 1) * M + (N - 1));
}
assert_eq!(A_COUNTER.load(Ordering::Acquire), 0);
}
#[test]
fn block_allocation() {
test_allocator_with_counter!(A_COUNTER, A);
const N: usize = 13;
{
assert_eq!(A_COUNTER.load(Ordering::Acquire), 0);
{
let mut s = Vec::new();
for _ in 0..N {
let mut block = BlockData::<A>::new().unwrap();
block.fill(1);
s.push(block);
}
assert_eq!(A_COUNTER.load(Ordering::Acquire), N);
}
assert_eq!(A_COUNTER.load(Ordering::Acquire), 0);
}
}
}

View File

@ -0,0 +1,483 @@
//! Block vector management structures
use core::{
cmp::Ordering,
mem::MaybeUninit,
ops::{Index, IndexMut},
};
use yggdrasil_abi::error::Error;
use crate::block::{self, BlockAllocator, BlockData, BlockIndirect};
// 16.125M total
const L0_BLOCKS: usize = 32; // 128K in L0
const L1_BLOCKS: usize = 8; // 16M in L1
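// Layout: the first L0_BLOCKS data blocks are referenced directly
// (32 * 4096 B = 128 KiB); the rest go through L1 indirection blocks, each
// holding block::ENTRY_COUNT pointers (512 on 64-bit targets), giving
// 8 * 512 * 4096 B = 16 MiB more, hence the 16.125M total above.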
/// Block vector for efficient in-memory files
pub struct BVec<'a, A: BlockAllocator> {
capacity: usize,
size: usize,
l0: [BlockData<'a, A>; L0_BLOCKS],
l1: [BlockIndirect<'a, A>; L1_BLOCKS],
}
impl<'a, A: BlockAllocator> BVec<'a, A> {
/// Creates an empty block vector.
///
/// # Note
///
/// The function is guaranteed to make no allocations before the vector is actually written to.
pub fn new() -> Self {
let mut l0 = MaybeUninit::uninit_array();
let mut l1 = MaybeUninit::uninit_array();
for it in l0.iter_mut() {
it.write(BlockData::null());
}
for it in l1.iter_mut() {
it.write(BlockIndirect::null());
}
Self {
capacity: 0,
size: 0,
l0: unsafe { MaybeUninit::array_assume_init(l0) },
l1: unsafe { MaybeUninit::array_assume_init(l1) },
}
}
/// Initializes the block vector with existing data, marking all blocks as Copy-on-Write
pub fn init_with_cow(&mut self, data: &'static [u8]) -> Result<(), Error> {
let data_ptr = data.as_ptr() as usize;
assert_eq!(data_ptr & 1, 0);
let blocks = (data.len() + block::SIZE - 1) / block::SIZE;
self.resize(blocks)?;
for i in 0..blocks {
let src = data_ptr + i * block::SIZE;
let block = unsafe { BlockData::copy_on_write(src) };
self[i] = block;
}
self.size = data.len();
Ok(())
}
/// Returns the size of the data inside this vector
#[inline]
pub const fn size(&self) -> usize {
self.size
}
fn grow_l1(&mut self, old_l1_cap: usize, new_l1_cap: usize) -> Result<(), Error> {
for i in old_l1_cap..new_l1_cap {
assert!(self.l1[i].is_null());
self.l1[i] = BlockIndirect::new()?;
}
Ok(())
}
fn shrink_l1(&mut self, old_l1_cap: usize, new_l1_cap: usize) {
debug_assert!(new_l1_cap <= old_l1_cap);
for i in new_l1_cap..old_l1_cap {
assert!(!self.l1[i].is_null());
self.l1[i] = BlockIndirect::null();
}
}
#[inline]
fn caps(cap: usize) -> (usize, usize) {
let l0_cap = core::cmp::min(cap, L0_BLOCKS);
let l1_cap = if cap > L0_BLOCKS {
core::cmp::min(
(cap - L0_BLOCKS + block::ENTRY_COUNT - 1) / block::ENTRY_COUNT,
L1_BLOCKS,
)
} else {
0
};
(l0_cap, l1_cap)
}
/// Resizes the vector to hold exactly `new_capacity` data blocks
pub fn resize(&mut self, new_capacity: usize) -> Result<(), Error> {
// TODO handle L2 capacity
match new_capacity.cmp(&self.capacity) {
Ordering::Less => {
let (_, new_l1_cap) = Self::caps(new_capacity);
let (_, old_l1_cap) = Self::caps(self.capacity);
// Shrink data blocks
for index in new_capacity..self.capacity {
let block = &mut self[index];
assert!(!block.is_null());
block.set_null();
}
// Shrink L1 blocks
self.shrink_l1(old_l1_cap, new_l1_cap);
}
Ordering::Greater => {
let (_, new_l1_cap) = Self::caps(new_capacity);
let (_, old_l1_cap) = Self::caps(self.capacity);
// Allocate L1 indirection blocks
assert!(new_l1_cap >= old_l1_cap);
if new_l1_cap > old_l1_cap {
self.grow_l1(old_l1_cap, new_l1_cap)?;
}
// Grow data blocks
for index in self.capacity..new_capacity {
let block = unsafe { self.index_unchecked_mut(index) };
assert!(block.is_null());
*block = BlockData::new()?;
}
}
Ordering::Equal => (),
}
self.capacity = new_capacity;
Ok(())
}
fn ensure_write_capacity(&mut self, pos: usize, need_to_write: usize) -> Result<(), Error> {
let current_capacity = self.capacity;
let need_capacity =
(core::cmp::max(pos + need_to_write, self.size) + block::SIZE - 1) / block::SIZE;
if need_capacity > current_capacity {
self.resize(need_capacity)
} else {
Ok(())
}
}
/// Writes data to the vector, growing it if needed
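///
/// A minimal usage sketch, assuming some [BlockAllocator] implementation `A`:
///
/// ```ignore
/// let mut v = BVec::<A>::new();
/// v.write(0, b"hello")?;
/// let mut buf = [0u8; 5];
/// v.read(0, &mut buf)?;
/// assert_eq!(&buf, b"hello");
/// ```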
pub fn write(&mut self, pos: u64, data: &[u8]) -> Result<usize, Error> {
let mut pos = pos as usize;
let mut rem = data.len();
let mut doff = 0usize;
self.ensure_write_capacity(pos, rem)?;
if pos + rem > self.size {
self.size = pos + rem;
}
while rem > 0 {
let index = pos / block::SIZE;
let offset = pos % block::SIZE;
let count = core::cmp::min(rem, block::SIZE - offset);
let block = &mut self[index];
let dst = &mut block[offset..offset + count];
let src = &data[doff..doff + count];
dst.copy_from_slice(src);
doff += count;
pos += count;
rem -= count;
}
Ok(doff)
}
/// Reads data from the vector
pub fn read(&self, pos: u64, data: &mut [u8]) -> Result<usize, Error> {
let mut pos = pos as usize;
if pos > self.size {
return Err(Error::InvalidFile);
}
let mut rem = core::cmp::min(self.size - pos, data.len());
let mut doff = 0usize;
while rem > 0 {
let index = pos / block::SIZE;
let offset = pos % block::SIZE;
let count = core::cmp::min(block::SIZE - offset, rem);
let block = &self[index];
let src = &block[offset..offset + count];
let dst = &mut data[doff..doff + count];
dst.copy_from_slice(src);
doff += count;
pos += count;
rem -= count;
}
Ok(doff)
}
/// Resize the block vector to requested size
pub fn truncate(&mut self, new_size: u64) -> Result<(), Error> {
let new_size: usize = new_size.try_into().unwrap();
let requested_capacity = (new_size + block::SIZE - 1) / block::SIZE;
self.resize(requested_capacity)?;
// TODO fill with zeros if resizing larger?
self.size = new_size;
Ok(())
}
unsafe fn index_unchecked(&self, mut index: usize) -> &BlockData<'a, A> {
if index < L0_BLOCKS {
return &self.l0[index];
}
index -= L0_BLOCKS;
if index < L1_BLOCKS * block::ENTRY_COUNT {
let l1i = index / block::ENTRY_COUNT;
let l0i = index % block::ENTRY_COUNT;
let l1r = &self.l1[l1i];
assert!(!l1r.is_null());
return &l1r[l0i];
}
todo!();
}
unsafe fn index_unchecked_mut(&mut self, mut index: usize) -> &mut BlockData<'a, A> {
if index < L0_BLOCKS {
return &mut self.l0[index];
}
index -= L0_BLOCKS;
if index < L1_BLOCKS * block::ENTRY_COUNT {
let l1i = index / block::ENTRY_COUNT;
let l0i = index % block::ENTRY_COUNT;
let l1r = &mut self.l1[l1i];
assert!(!l1r.is_null());
return &mut l1r[l0i];
}
todo!()
}
}
impl<'a, A: BlockAllocator> Index<usize> for BVec<'a, A> {
type Output = BlockData<'a, A>;
fn index(&self, index: usize) -> &Self::Output {
if index >= self.capacity {
panic!(
"Block index out of bounds: capacity={}, index={}",
self.capacity, index
);
}
unsafe { self.index_unchecked(index) }
}
}
impl<'a, A: BlockAllocator> IndexMut<usize> for BVec<'a, A> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
if index >= self.capacity {
panic!(
"Block index out of bounds: capacity={}, index={}",
self.capacity, index
);
}
unsafe { self.index_unchecked_mut(index) }
}
}
impl<'a, A: BlockAllocator> TryFrom<&'static [u8]> for BVec<'a, A> {
type Error = Error;
fn try_from(value: &'static [u8]) -> Result<Self, Self::Error> {
let mut res = Self::new();
res.init_with_cow(value)?;
assert_eq!(res.size(), value.len());
Ok(res)
}
}
#[cfg(test)]
mod bvec_allocation {
use core::sync::atomic::Ordering;
use crate::{
block,
bvec::{BVec, L0_BLOCKS, L1_BLOCKS},
};
#[test]
fn bvec_grow_shrink() {
test_allocator_with_counter!(A_COUNTER, A);
assert_eq!(A_COUNTER.load(Ordering::Acquire), 0);
{
let mut bvec = BVec::<A>::new();
assert_eq!(
A_COUNTER.load(Ordering::Acquire),
0,
"BVec should not allocate on creation"
);
const N: usize = 123;
bvec.resize(N).unwrap();
// N data blocks (12 in L0 + 111 in L1)
assert_eq!(A_COUNTER.load(Ordering::Acquire), N + 1);
// Test the index interface
for i in 0..N {
assert!(!bvec[i].is_null(), "Index {} must be allocated", i);
}
// Test the data structure
for i in 0..L0_BLOCKS {
assert!(!bvec.l0[i].is_null());
}
assert!(!bvec.l1[0].is_null());
for i in L0_BLOCKS..N {
let l1i = (i - L0_BLOCKS) / block::ENTRY_COUNT;
let l0i = (i - L0_BLOCKS) % block::ENTRY_COUNT;
let l1r = &bvec.l1[l1i];
assert!(!l1r.is_null());
assert!(!l1r[l0i].is_null());
}
for i in 1..L1_BLOCKS {
assert!(bvec.l1[i].is_null());
}
// Shrink to 100 blocks, test if L1 is still allocated
const M: usize = 100;
bvec.resize(M).unwrap();
assert_eq!(A_COUNTER.load(Ordering::Acquire), M + 1);
// Test the index interface
for i in 0..M {
assert!(!bvec[i].is_null(), "Index {} must be allocated", i);
}
// Test the data structure
for i in 0..L0_BLOCKS {
assert!(!bvec.l0[i].is_null());
}
assert!(!bvec.l1[0].is_null());
for i in L0_BLOCKS..M {
let l1i = (i - L0_BLOCKS) / block::ENTRY_COUNT;
let l0i = (i - L0_BLOCKS) % block::ENTRY_COUNT;
let l1r = &bvec.l1[l1i];
assert!(!l1r.is_null());
assert!(!l1r[l0i].is_null());
}
for i in M..N {
let l1i = (i - L0_BLOCKS) / block::ENTRY_COUNT;
let l0i = (i - L0_BLOCKS) % block::ENTRY_COUNT;
let l1r = &bvec.l1[l1i];
assert!(!l1r.is_null());
assert!(l1r[l0i].is_null());
}
for i in 1..L1_BLOCKS {
assert!(bvec.l1[i].is_null());
}
// Shrink to 13 blocks, test if L1 got deallocated
const O: usize = 13;
bvec.resize(O).unwrap();
assert_eq!(A_COUNTER.load(Ordering::Acquire), O);
}
assert_eq!(A_COUNTER.load(Ordering::Acquire), 0);
}
}
#[cfg(all(test, feature = "test-io"))]
mod bvec_io {
use crate::{block, bvec::L0_BLOCKS};
use super::BVec;
#[test]
fn test_bvec_write() {
test_allocator_with_counter!(A_COUNTER, A);
{
let data = [1, 2, 3, 4, 5];
let mut bvec = BVec::<A>::new();
// Write at 0
assert_eq!(bvec.write(0, &data).unwrap(), data.len());
assert_eq!(bvec.capacity, 1);
assert_eq!(bvec.size(), data.len());
assert_eq!(&bvec[0][..bvec.size()], &data[..]);
// Write at 3
assert_eq!(bvec.write(3, &data).unwrap(), data.len());
assert_eq!(bvec.capacity, 1);
assert_eq!(bvec.size(), 3 + data.len());
assert_eq!(&bvec[0][..bvec.size()], &[1, 2, 3, 1, 2, 3, 4, 5]);
}
{
let data = [5, 4, 3, 2, 1];
let mut bvec = BVec::<A>::new();
// Write at the end of L0-region
assert_eq!(
bvec.write((L0_BLOCKS * block::SIZE) as u64, &data).unwrap(),
data.len()
);
// L0_BLOCKS + 1 L1 data block
assert_eq!(bvec.capacity, L0_BLOCKS + 1);
assert_eq!(bvec.size(), L0_BLOCKS * block::SIZE + data.len());
assert_eq!(&bvec[L0_BLOCKS][..data.len()], &data[..]);
// Write at zero
assert_eq!(bvec.write(0, &data).unwrap(), data.len());
assert_eq!(bvec.capacity, L0_BLOCKS + 1);
assert_eq!(bvec.size(), L0_BLOCKS * block::SIZE + data.len());
assert_eq!(&bvec[0][..data.len()], &data[..]);
// Test write crossing L0 block boundary
assert_eq!(
bvec.write((block::SIZE - 3) as u64, &data).unwrap(),
data.len()
);
assert_eq!(bvec.capacity, L0_BLOCKS + 1);
assert_eq!(bvec.size(), L0_BLOCKS * block::SIZE + data.len());
assert_eq!(&bvec[0][block::SIZE - 3..], &[5, 4, 3]);
assert_eq!(&bvec[1][..2], &[2, 1]);
// Test write crossing L0-L1 boundary
assert_eq!(
bvec.write((L0_BLOCKS * block::SIZE) as u64 - 2, &data)
.unwrap(),
data.len()
);
assert_eq!(bvec.capacity, L0_BLOCKS + 1);
assert_eq!(bvec.size(), L0_BLOCKS * block::SIZE + data.len());
assert_eq!(&bvec[L0_BLOCKS - 1][block::SIZE - 2..], &[5, 4]);
assert_eq!(&bvec[L0_BLOCKS][..data.len()], &[3, 2, 1, 2, 1]);
}
}
}

View File

@ -0,0 +1,35 @@
use core::marker::PhantomData;
use vfs::{CommonImpl, DirectoryImpl, DirectoryOpenPosition, Node, NodeFlags, NodeRef};
use yggdrasil_abi::{error::Error, io::FileType};
use crate::{block::BlockAllocator, file::FileNode};
pub(crate) struct DirectoryNode<A: BlockAllocator> {
_pd: PhantomData<A>,
}
impl<A: BlockAllocator> DirectoryNode<A> {
pub fn new() -> NodeRef {
Node::directory(
Self { _pd: PhantomData },
NodeFlags::IN_MEMORY_SIZE | NodeFlags::IN_MEMORY_PROPS,
)
}
}
impl<A: BlockAllocator> CommonImpl for DirectoryNode<A> {}
impl<A: BlockAllocator> DirectoryImpl for DirectoryNode<A> {
fn open(&self, _node: &NodeRef) -> Result<DirectoryOpenPosition, Error> {
Ok(DirectoryOpenPosition::FromCache)
}
fn create_node(&self, _parent: &NodeRef, ty: FileType) -> Result<NodeRef, Error> {
match ty {
FileType::File => Ok(FileNode::<A>::new()),
FileType::Directory => Ok(DirectoryNode::<A>::new()),
_ => todo!(),
}
}
}

View File

@ -0,0 +1,75 @@
use core::any::Any;
use libk_util::sync::IrqSafeSpinlock;
use vfs::{CommonImpl, InstanceData, Node, NodeFlags, NodeRef, RegularImpl};
use yggdrasil_abi::{error::Error, io::OpenOptions};
use crate::{block::BlockAllocator, bvec::BVec};
pub(crate) struct FileNode<A: BlockAllocator> {
pub(crate) data: IrqSafeSpinlock<BVec<'static, A>>,
}
impl<A: BlockAllocator> FileNode<A> {
pub fn new() -> NodeRef {
Node::regular(
Self {
data: IrqSafeSpinlock::new(BVec::new()),
},
NodeFlags::IN_MEMORY_PROPS,
)
}
}
impl<A: BlockAllocator> CommonImpl for FileNode<A> {
fn as_any(&self) -> &dyn Any {
self
}
fn size(&self, _node: &NodeRef) -> Result<u64, Error> {
Ok(self.data.lock().size() as u64)
}
}
impl<A: BlockAllocator> RegularImpl for FileNode<A> {
fn open(
&self,
_node: &NodeRef,
opts: OpenOptions,
) -> Result<(u64, Option<InstanceData>), Error> {
// TODO provide APPEND by vfs driver instead
if opts.contains(OpenOptions::APPEND) {
Ok((self.data.lock().size() as u64, None))
} else {
Ok((0, None))
}
}
fn read(
&self,
_node: &NodeRef,
_instance: Option<&InstanceData>,
pos: u64,
buf: &mut [u8],
) -> Result<usize, Error> {
self.data.lock().read(pos, buf)
}
fn write(
&self,
_node: &NodeRef,
_instance: Option<&InstanceData>,
pos: u64,
buf: &[u8],
) -> Result<usize, Error> {
self.data.lock().write(pos, buf)
}
fn truncate(&self, _node: &NodeRef, new_size: u64) -> Result<(), Error> {
self.data.lock().truncate(new_size)
}
fn close(&self, _node: &NodeRef, _instance: Option<&InstanceData>) -> Result<(), Error> {
Ok(())
}
}

View File

@ -0,0 +1,192 @@
//! In-memory filesystem driver
#![no_std]
#![deny(missing_docs)]
#![allow(clippy::new_without_default, clippy::new_ret_no_self)]
#![feature(
const_mut_refs,
maybe_uninit_uninit_array,
const_maybe_uninit_uninit_array,
maybe_uninit_array_assume_init
)]
use core::{cell::RefCell, marker::PhantomData};
use alloc::rc::Rc;
use block::BlockAllocator;
use dir::DirectoryNode;
use file::FileNode;
use vfs::{AccessToken, NodeRef};
use yggdrasil_abi::{
error::Error,
io::{FileMode, FileType, GroupId, UserId},
path::Path,
};
use crate::tar::TarIterator;
#[cfg(test)]
extern crate std;
extern crate alloc;
#[cfg(test)]
macro_rules! test_allocator_with_counter {
($counter:ident, $allocator:ident) => {
static $counter: core::sync::atomic::AtomicUsize = core::sync::atomic::AtomicUsize::new(0);
struct $allocator;
unsafe impl $crate::block::BlockAllocator for $allocator {
fn alloc() -> Result<core::ptr::NonNull<u8>, yggdrasil_abi::error::Error> {
let b = std::boxed::Box::into_raw(std::boxed::Box::new([0; $crate::block::SIZE]));
$counter.fetch_add(1, core::sync::atomic::Ordering::Release);
Ok(unsafe { core::ptr::NonNull::new_unchecked(b as _) })
}
unsafe fn dealloc(block: core::ptr::NonNull<u8>) {
$counter.fetch_sub(1, core::sync::atomic::Ordering::Release);
drop(std::boxed::Box::from_raw(
block.as_ptr() as *mut [u8; $crate::block::SIZE]
));
}
}
};
}
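// The macro above is used by the test modules roughly like this (sketch):
//
//     test_allocator_with_counter!(A_COUNTER, A);
//     // ... allocations made through `A` are now tracked in A_COUNTER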
pub mod block;
pub mod bvec;
mod dir;
mod file;
mod tar;
/// In-memory read/write filesystem
pub struct MemoryFilesystem<A: BlockAllocator> {
root: RefCell<Option<NodeRef>>,
_pd: PhantomData<A>,
}
impl<A: BlockAllocator> MemoryFilesystem<A> {
fn make_path(
self: &Rc<Self>,
at: &NodeRef,
path: &Path,
kind: FileType,
create: bool,
) -> Result<NodeRef, Error> {
let access = unsafe { AccessToken::authorized() };
if path.is_empty() {
return Ok(at.clone());
}
let (element, rest) = path.split_left();
// let (element, rest) = path::split_left(path);
assert!(!element.is_empty());
assert!(!element.contains('/'));
// let node = at.lookup(element);
let node = at.lookup_or_load(element, access);
let node = match node {
Ok(node) => node,
Err(Error::DoesNotExist) => {
if !create {
return Err(Error::DoesNotExist);
}
let node = self.create_node_initial(kind);
at.add_child(element, node.clone())?;
node
}
Err(err) => {
log::warn!("{:?}: lookup failed: {:?}", path, err);
return Err(err);
}
};
if rest.is_empty() {
Ok(node)
} else {
assert!(node.is_directory());
self.make_path(&node, rest, kind, create)
}
}
fn create_node_initial(self: &Rc<Self>, kind: FileType) -> NodeRef {
match kind {
FileType::File => FileNode::<A>::new(),
FileType::Directory => DirectoryNode::<A>::new(),
_ => todo!(),
}
}
fn from_slice_internal(self: &Rc<Self>, tar_data: &'static [u8]) -> Result<NodeRef, Error> {
let root = DirectoryNode::<A>::new();
// 1. Create paths in tar
for item in TarIterator::new(tar_data) {
let Ok((hdr, _)) = item else {
return Err(Error::InvalidArgument);
};
let path = Path::from_str(hdr.name.as_str()?.trim_matches('/'));
log::debug!("Make path {:?}", path);
let (dirname, filename) = path.split_right();
let parent = self.make_path(&root, dirname, FileType::Directory, true)?;
let node = self.create_node_initial(hdr.node_kind());
parent.add_child(filename, node)?;
}
// 2. Associate files with their data
for item in TarIterator::new(tar_data) {
let Ok((hdr, data)) = item else {
// Already validated during the first pass over the archive
unreachable!();
};
let path = Path::from_str(hdr.name.as_str()?.trim_matches('/'));
let node = self.make_path(&root, path, FileType::Directory, false)?;
assert_eq!(node.ty(), hdr.node_kind());
let uid = unsafe { UserId::from_raw(usize::from(&hdr.uid) as u32) };
let gid = unsafe { GroupId::from_raw(usize::from(&hdr.gid) as u32) };
let mode = convert_mode(usize::from(&hdr.mode))?;
let access = unsafe { AccessToken::authorized() };
node.set_access(Some(uid), Some(gid), Some(mode), access)?;
if hdr.node_kind() == FileType::File {
let data = data.unwrap();
let node_data = node.data_as_ref::<FileNode<A>>();
let mut bvec = node_data.data.lock();
bvec.init_with_cow(data)?;
assert_eq!(bvec.size(), data.len());
}
}
Ok(root)
}
/// Constructs a filesystem tree from a tar image in memory
pub fn from_slice(tar_data: &'static [u8]) -> Result<Rc<Self>, Error> {
let fs = Rc::new(Self {
root: RefCell::new(None),
_pd: PhantomData,
});
let root = fs.from_slice_internal(tar_data)?;
fs.root.replace(Some(root));
Ok(fs)
}
// TODO Filesystem trait?
/// Returns the root node of the memory filesystem
pub fn root(&self) -> Result<NodeRef, Error> {
Ok(self.root.borrow().clone().unwrap())
}
}
fn convert_mode(mode: usize) -> Result<FileMode, Error> {
Ok(FileMode::new(mode as u32 & 0o777))
}
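// Hedged usage sketch (assumes a kernel-side `BlockAllocator` impl `A` and a
// tar image linked into the kernel; the names below are illustrative only):
//
//     static INITRD: &[u8] = include_bytes!(env!("INITRD_PATH"));
//     let fs = MemoryFilesystem::<A>::from_slice(INITRD)?;
//     let root = fs.root()?;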

View File

@ -0,0 +1,138 @@
use yggdrasil_abi::{error::Error, io::FileType};
#[repr(C)]
pub(crate) struct OctalField<const N: usize> {
data: [u8; N],
}
#[repr(C)]
pub(crate) struct TarString<const N: usize> {
data: [u8; N],
}
pub(crate) struct TarIterator<'a> {
data: &'a [u8],
offset: usize,
zero_blocks: usize,
}
#[repr(packed)]
pub(crate) struct TarEntry {
pub name: TarString<100>,
pub mode: OctalField<8>,
pub uid: OctalField<8>,
pub gid: OctalField<8>,
pub size: OctalField<12>,
_mtime: OctalField<12>,
_checksum: OctalField<8>,
type_: u8,
_link_name: TarString<100>,
_magic: [u8; 8],
_user: TarString<32>,
_group: TarString<32>,
_dev_major: OctalField<8>,
_dev_minor: OctalField<8>,
_prefix: TarString<155>,
__pad: [u8; 12],
}
impl<'a> TarIterator<'a> {
pub const fn new(data: &'a [u8]) -> Self {
Self {
data,
offset: 0,
zero_blocks: 0,
}
}
}
impl<'a> Iterator for TarIterator<'a> {
type Item = Result<(&'a TarEntry, Option<&'a [u8]>), Error>;
fn next(&mut self) -> Option<Self::Item> {
loop {
if self.offset + 512 > self.data.len() {
break None;
}
let hdr_ptr = &self.data[self.offset..];
let hdr = unsafe { &*(hdr_ptr.as_ptr() as *const TarEntry) };
if hdr.is_empty() {
// Two consecutive zero blocks mark the end of a tar archive
if self.zero_blocks == 1 {
self.offset = self.data.len();
return None;
}
self.zero_blocks += 1;
self.offset += 512;
continue;
}
self.zero_blocks = 0;
let size = usize::from(&hdr.size);
let size_aligned = (size + 511) & !511;
let (data, size_aligned) = match hdr.type_ {
0 | b'0' => {
if self.offset + 512 + size > self.data.len() {
return Some(Err(Error::InvalidArgument));
}
let data = &self.data[self.offset + 512..self.offset + 512 + size];
(Some(data), size_aligned)
}
// Directory
b'5' => (None, 0),
_ => {
self.offset += size_aligned + 512;
continue;
}
};
self.offset += size_aligned + 512;
break Some(Ok((hdr, data)));
}
}
}
impl<const N: usize> From<&OctalField<N>> for usize {
fn from(value: &OctalField<N>) -> Self {
let mut acc = 0;
for i in 0..N {
if !(b'0'..b'8').contains(&value.data[i]) {
break;
}
acc <<= 3;
acc |= (value.data[i] - b'0') as usize;
}
acc
}
}
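// e.g. the field bytes b"644\0..." decode to 0o644 == 420; parsing stops at
// the first non-octal byte, since tar pads these fields with NULs or spaces.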
impl<const N: usize> TarString<N> {
pub fn as_str(&self) -> Result<&str, Error> {
core::str::from_utf8(&self.data[..self.len()]).map_err(|_| Error::InvalidArgument)
}
pub fn len(&self) -> usize {
for i in 0..N {
if self.data[i] == 0 {
return i;
}
}
N
}
}
impl TarEntry {
pub fn is_empty(&self) -> bool {
self.name.data[0] == 0
}
pub fn node_kind(&self) -> FileType {
match self.type_ {
0 | b'0' => FileType::File,
b'5' => FileType::Directory,
_ => todo!(),
}
}
}

View File

@ -0,0 +1 @@
This is another test file

View File

@ -0,0 +1 @@
This is a test file

Binary file not shown.

View File

@ -0,0 +1,11 @@
[package]
name = "ygg_driver_input"
version = "0.1.0"
edition = "2021"
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git" }
libk-util = { path = "../../libk/libk-util" }
libk-thread = { path = "../../libk/libk-thread" }
libk-mm = { path = "../../libk/libk-mm" }
vfs = { path = "../../lib/vfs" }

View File

@ -0,0 +1,54 @@
#![no_std]
extern crate alloc;
use core::task::{Context, Poll};
use libk_thread::block;
use libk_util::ring::LossyRingQueue;
use vfs::{CharDevice, FileReadiness};
use yggdrasil_abi::{
error::Error,
io::{DeviceRequest, KeyboardKeyEvent},
};
pub struct KeyboardDevice;
impl FileReadiness for KeyboardDevice {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
INPUT_QUEUE.poll_readable(cx).map(Ok)
}
}
impl CharDevice for KeyboardDevice {
fn read(&'static self, buf: &mut [u8]) -> Result<usize, Error> {
if buf.len() < 4 {
return Ok(0);
}
let ev = block!(INPUT_QUEUE.read().await)?;
buf[..4].copy_from_slice(&ev.as_bytes());
Ok(4)
}
fn is_writable(&self) -> bool {
false
}
fn device_request(&self, _req: &mut DeviceRequest) -> Result<(), Error> {
todo!()
}
fn is_terminal(&self) -> bool {
false
}
}
static INPUT_QUEUE: LossyRingQueue<KeyboardKeyEvent> = LossyRingQueue::with_capacity(32);
pub static KEYBOARD_DEVICE: KeyboardDevice = KeyboardDevice;
pub fn send_event(ev: KeyboardKeyEvent) {
INPUT_QUEUE.write(ev);
}
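// Hedged usage sketch: a keyboard driver's IRQ handler decodes its scancodes
// into `KeyboardKeyEvent`s and pushes them here; a reader blocked inside
// `KEYBOARD_DEVICE.read()` is then woken through `INPUT_QUEUE`:
//
//     fn on_scancode(ev: KeyboardKeyEvent) {
//         ygg_driver_input::send_event(ev);
//     }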

View File

@ -0,0 +1,21 @@
[package]
name = "ygg_driver_net_core"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
yggdrasil-abi = { git = "https://git.alnyan.me/yggdrasil/yggdrasil-abi.git", features = ["serde_kernel", "bytemuck"] }
libk-mm = { path = "../../../libk/libk-mm" }
libk-util = { path = "../../../libk/libk-util" }
libk-thread = { path = "../../../libk/libk-thread" }
libk-device = { path = "../../../libk/libk-device" }
vfs = { path = "../../../lib/vfs" }
kernel-fs = { path = "../../fs/kernel-fs" }
log = "0.4.20"
bytemuck = { version = "1.14.0", features = ["derive"] }
serde_json = { version = "1.0.111", default-features = false, features = ["alloc"] }
serde = { version = "1.0.193", features = ["derive"], default-features = false }

View File

@ -0,0 +1,188 @@
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use serde::Serialize;
use vfs::{ChannelDescriptor, MessagePayload};
use yggdrasil_abi::{
error::Error,
io::{ChannelPublisherId, MessageDestination},
net::{
netconfig::{
InterfaceInfo, InterfaceQuery, NetConfigRequest, NetConfigResult, RouteInfo,
RoutingInfo,
},
IpAddr, SubnetAddr,
},
};
use crate::{
interface::NetworkInterface,
l3::{arp, Route},
};
async fn receive_request(
channel: &ChannelDescriptor,
) -> Result<(ChannelPublisherId, NetConfigRequest), Error> {
loop {
let raw = channel.receive_message_async().await?;
match &raw.payload {
MessagePayload::Data(message) => {
let msg =
serde_json::from_slice(message.as_ref()).map_err(|_| Error::InvalidArgument)?;
return Ok((raw.source, msg));
}
MessagePayload::File(_) => (),
}
}
}
fn send_reply<T: Serialize>(
channel: &ChannelDescriptor,
recipient: ChannelPublisherId,
message: NetConfigResult<T>,
) -> Result<(), Error> {
let data = serde_json::to_vec(&message).map_err(|_| Error::InvalidArgument)?;
channel.send_message(
MessagePayload::Data(data.into_boxed_slice()),
MessageDestination::Specific(recipient.into()),
)
}
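// Requests and replies are JSON-serialized `NetConfigRequest`/`NetConfigResult`
// values exchanged over the "@kernel-netconf" channel (see the service loop
// below); messages carrying file payloads are ignored by `receive_request`.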
fn list_interfaces() -> Vec<(Box<str>, u32)> {
let interfaces = NetworkInterface::list_ref();
interfaces
.iter()
.map(|(&id, iface)| (iface.name.clone(), id))
.collect()
}
fn describe_interface(interface: &NetworkInterface) -> InterfaceInfo {
InterfaceInfo {
interface_id: interface.id,
interface_name: interface.name.clone(),
address: interface.address.read().map(Into::into),
mac: interface.mac,
}
}
fn describe_route(route: &Route) -> RouteInfo {
// NOTE: must exist
let interface = NetworkInterface::get(route.interface).unwrap();
RouteInfo {
interface_name: interface.name.clone(),
interface_id: route.interface,
subnet: route.subnet,
gateway: route.gateway.map(Into::into),
}
}
fn query_route(destination: IpAddr) -> Option<RoutingInfo> {
let (interface_id, gateway, destination) = Route::lookup(destination)?;
let interface = NetworkInterface::get(interface_id).unwrap();
let source = *interface.address.read();
Some(RoutingInfo {
interface_name: interface.name.clone(),
interface_id,
destination,
gateway,
source,
source_mac: interface.mac,
})
}
fn query_interface(query: InterfaceQuery) -> Option<Arc<NetworkInterface>> {
match query {
InterfaceQuery::ById(id) => NetworkInterface::get(id).ok(),
InterfaceQuery::ByName(name) => {
let interfaces = NetworkInterface::list_ref();
interfaces.iter().find_map(|(_, iface)| {
if iface.name == name {
Some(iface.clone())
} else {
None
}
})
}
}
}
fn add_route(
query: InterfaceQuery,
gateway: Option<IpAddr>,
subnet: SubnetAddr,
) -> Result<(), &'static str> {
let interface = query_interface(query).ok_or("No such interface")?;
let route = Route {
interface: interface.id,
gateway,
subnet,
};
Route::insert(route).map_err(|_| "Could not insert route")?;
Ok(())
}
pub async fn network_config_service() -> Result<(), Error> {
let channel = ChannelDescriptor::open("@kernel-netconf", true);
loop {
let (sender_id, request) = receive_request(&channel).await?;
match request {
NetConfigRequest::ListRoutes => {
let routes = Route::list_ref();
let route_info: Vec<_> = routes.iter().map(describe_route).collect();
send_reply(&channel, sender_id, NetConfigResult::Ok(route_info))?;
}
NetConfigRequest::ListInterfaces => {
let interfaces = list_interfaces();
send_reply(&channel, sender_id, NetConfigResult::Ok(interfaces))?;
}
NetConfigRequest::DescribeRoutes(_query) => todo!(),
NetConfigRequest::DescribeInterface(query) => {
let result = match query_interface(query) {
Some(interface) => NetConfigResult::Ok(describe_interface(&interface)),
None => NetConfigResult::err("No such interface"),
};
send_reply(&channel, sender_id, result)?;
}
NetConfigRequest::AddRoute {
interface,
gateway,
subnet,
} => {
let result = match add_route(interface, gateway, subnet) {
Ok(()) => NetConfigResult::Ok(()),
Err(error) => NetConfigResult::err(error),
};
send_reply(&channel, sender_id, result)?;
}
NetConfigRequest::SetNetworkAddress { interface, address } => {
let result = match query_interface(interface) {
Some(interface) => {
interface.set_address(address);
NetConfigResult::Ok(())
}
None => NetConfigResult::err("No such interface"),
};
send_reply(&channel, sender_id, result)?;
}
NetConfigRequest::ClearNetworkAddress(_interface) => todo!(),
NetConfigRequest::QueryRoute(destination) => {
let result = match query_route(destination) {
Some(route) => NetConfigResult::Ok(route),
None => NetConfigResult::err("No route to host"),
};
send_reply(&channel, sender_id, result)?;
}
NetConfigRequest::QueryArp(interface_id, destination, perform_query) => {
let result = match arp::lookup(interface_id, destination, perform_query).await {
Ok(mac) => NetConfigResult::Ok(mac),
Err(_) => NetConfigResult::err("No ARP entry"),
};
send_reply(&channel, sender_id, result)?;
}
}
}
}

View File

@ -0,0 +1,80 @@
use core::mem::size_of;
use alloc::sync::Arc;
use bytemuck::Pod;
use libk_mm::PageBox;
use yggdrasil_abi::{
error::Error,
net::{
protocols::{EtherType, EthernetFrame},
types::NetValueImpl,
MacAddress,
},
};
use crate::{interface::NetworkInterface, l3, socket::RawSocket};
#[derive(Clone)]
pub struct L2Packet {
pub interface_id: u32,
pub source_address: MacAddress,
pub destination_address: MacAddress,
pub l2_offset: usize,
pub l3_offset: usize,
pub data: Arc<PageBox<[u8]>>,
}
impl L2Packet {
pub fn ethernet_frame(&self) -> &EthernetFrame {
bytemuck::from_bytes(
&self.data[self.l2_offset..self.l2_offset + size_of::<EthernetFrame>()],
)
}
pub fn l2_data(&self) -> &[u8] {
&self.data[self.l3_offset..]
}
}
pub fn send_l2<T: Pod>(
interface: &NetworkInterface,
source_mac: MacAddress,
destination_mac: MacAddress,
ethertype: EtherType,
l2_data: &T,
) -> Result<(), Error> {
let l2_frame = EthernetFrame {
source_mac,
destination_mac,
ethertype: ethertype.to_network_order(),
};
log::debug!(
"send_l2: {} -> {}",
l2_frame.source_mac,
l2_frame.destination_mac
);
interface.send_l2(&l2_frame, bytemuck::bytes_of(l2_data))
}
pub fn handle(packet: L2Packet) {
let frame = packet.ethernet_frame();
let ty = EtherType::from_network_order(frame.ethertype);
RawSocket::packet_received(packet.clone());
match ty {
EtherType::ARP => l3::arp::handle_packet(packet),
EtherType::IPV4 => l3::ip::handle_v4_packet(packet),
p => {
log::debug!(
"Unrecognized L2 protocol: {:#06x}",
bytemuck::cast::<_, u16>(p)
);
}
}
}

View File

@ -0,0 +1,148 @@
use core::{
mem::size_of,
sync::atomic::{AtomicU32, AtomicUsize, Ordering},
};
use alloc::{boxed::Box, collections::BTreeMap, format, sync::Arc};
// TODO: link state management?
use libk_mm::PageBox;
use libk_util::{
sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard},
OneTimeInit,
};
use yggdrasil_abi::{
error::Error,
net::{protocols::EthernetFrame, IpAddr, MacAddress},
};
use crate::l3::{arp::ArpTable, Route};
pub trait NetworkDevice: Sync {
fn transmit(&self, packet: PageBox<[u8]>) -> Result<(), Error>;
fn packet_prefix_size(&self) -> usize;
fn read_hardware_address(&self) -> MacAddress;
}
pub struct NetworkInterface {
pub(crate) name: Box<str>,
pub(crate) device: &'static dyn NetworkDevice,
pub(crate) mac: MacAddress,
pub(crate) address: IrqSafeRwLock<Option<IpAddr>>,
pub(crate) id: u32,
}
#[derive(PartialEq, Eq)]
pub enum NetworkInterfaceType {
Ethernet,
Loopback,
}
static INTERFACES: IrqSafeRwLock<BTreeMap<u32, Arc<NetworkInterface>>> =
IrqSafeRwLock::new(BTreeMap::new());
static LAST_INTERFACE_ID: AtomicU32 = AtomicU32::new(1);
static LOOPBACK: OneTimeInit<Arc<NetworkInterface>> = OneTimeInit::new();
impl NetworkInterface {
pub fn id(&self) -> u32 {
self.id
}
pub fn loopback() -> &'static Arc<Self> {
LOOPBACK.get()
}
pub fn get(id: u32) -> Result<Arc<Self>, Error> {
INTERFACES
.read()
.get(&id)
.cloned()
.ok_or(Error::DoesNotExist)
}
pub fn query_by_name(name: &str) -> Result<Arc<Self>, Error> {
INTERFACES
.read()
.iter()
.find_map(|(_, iface)| {
if iface.name.as_ref() == name {
Some(iface.clone())
} else {
None
}
})
.ok_or(Error::DoesNotExist)
}
pub fn list_ref() -> IrqSafeRwLockReadGuard<'static, BTreeMap<u32, Arc<NetworkInterface>>> {
INTERFACES.read()
}
pub fn set_address(&self, address: IpAddr) {
// Flush routes associated with the interface
{
let mut routes = Route::list_mut();
routes.retain(|route| route.interface != self.id);
}
let mut addr = self.address.write();
// Flush owned ARP entries related to the old address
if let Some(address) = *addr {
ArpTable::flush_address(self.id, address);
}
addr.replace(address);
ArpTable::insert_address(self.id, self.mac, address, true);
}
pub fn send_l2(&self, l2_frame: &EthernetFrame, l2_data: &[u8]) -> Result<(), Error> {
let l2_offset = self.device.packet_prefix_size();
let l2_data_offset = l2_offset + size_of::<EthernetFrame>();
let mut packet = PageBox::new_slice(0, l2_data_offset + l2_data.len())?;
packet[l2_offset..l2_data_offset].copy_from_slice(bytemuck::bytes_of(l2_frame));
packet[l2_data_offset..].copy_from_slice(l2_data);
self.device.transmit(packet)
}
}
pub fn register_interface(
ty: NetworkInterfaceType,
dev: &'static dyn NetworkDevice,
) -> Arc<NetworkInterface> {
let name = match ty {
NetworkInterfaceType::Ethernet => {
static LAST_ETHERNET_ID: AtomicUsize = AtomicUsize::new(0);
let eth_id = LAST_ETHERNET_ID.fetch_add(1, Ordering::SeqCst);
format!("eth{}", eth_id).into_boxed_str()
}
NetworkInterfaceType::Loopback => "lo".into(),
};
let mac = dev.read_hardware_address();
let id = LAST_INTERFACE_ID.fetch_add(1, Ordering::SeqCst);
log::info!("Register network interface {} (#{}): {}", name, id, mac);
let iface = NetworkInterface {
name,
device: dev,
mac,
address: IrqSafeRwLock::new(None),
id,
};
let interface = Arc::new(iface);
INTERFACES.write().insert(id, interface.clone());
if ty == NetworkInterfaceType::Loopback {
LOOPBACK.init(interface.clone());
}
interface
}

View File

@ -0,0 +1,262 @@
use core::{
future::Future,
mem::size_of,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use alloc::{boxed::Box, collections::BTreeMap};
use libk_thread::runtime;
use libk_util::{sync::spin_rwlock::IrqSafeRwLock, waker::QueueWaker};
use yggdrasil_abi::{
error::Error,
net::{
protocols::{ArpFrame, EtherType},
types::NetValueImpl,
IpAddr, Ipv4Addr, MacAddress,
},
};
use crate::{ethernet, interface::NetworkInterface, L2Packet};
struct Inner<A: Ord + Eq + Copy> {
entries: BTreeMap<(u32, A), (MacAddress, bool)>,
reverse: BTreeMap<(u32, MacAddress), (A, bool)>,
}
pub struct ArpTable {
v4: IrqSafeRwLock<Inner<Ipv4Addr>>,
notify: QueueWaker,
}
impl<A: Ord + Eq + Copy> Inner<A> {
const fn new() -> Self {
Self {
entries: BTreeMap::new(),
reverse: BTreeMap::new(),
}
}
fn query_mac(&self, interface: u32, address: A) -> Option<(MacAddress, bool)> {
self.entries.get(&(interface, address)).copied()
}
// fn query_address(&self, interface: u32, mac: MacAddress) -> Option<(A, bool)> {
// self.reverse.get(&(interface, mac)).copied()
// }
fn insert(&mut self, interface: u32, mac: MacAddress, address: A, owned: bool) -> bool {
let new = self
.entries
.insert((interface, address), (mac, owned))
.is_none();
self.reverse.insert((interface, mac), (address, owned));
new
}
fn flush(&mut self, interface: u32, address: A) -> bool {
if let Some((mac, _)) = self.entries.remove(&(interface, address)) {
self.reverse.remove(&(interface, mac));
true
} else {
false
}
}
}
impl ArpTable {
pub const fn new() -> Self {
Self {
v4: IrqSafeRwLock::new(Inner::new()),
notify: QueueWaker::new(),
}
}
pub fn lookup_cache_v4(interface: u32, address: Ipv4Addr) -> Option<(MacAddress, bool)> {
ARP_TABLE.v4.read().query_mac(interface, address)
}
pub fn lookup_cache(interface: u32, address: IpAddr) -> Option<MacAddress> {
let (address, _) = match address {
IpAddr::V4(address) => Self::lookup_cache_v4(interface, address),
IpAddr::V6(_) => todo!(),
}?;
Some(address)
}
pub fn flush_address_v4(interface: u32, address: Ipv4Addr) -> bool {
ARP_TABLE.v4.write().flush(interface, address)
}
pub fn flush_address(interface: u32, address: IpAddr) -> bool {
match address {
IpAddr::V4(address) => Self::flush_address_v4(interface, address),
IpAddr::V6(_) => todo!(),
}
}
pub fn insert_address_v4(interface: u32, mac: MacAddress, address: Ipv4Addr, owned: bool) {
ARP_TABLE.v4.write().insert(interface, mac, address, owned);
}
pub fn insert_address(interface: u32, mac: MacAddress, address: IpAddr, owned: bool) {
match address {
IpAddr::V4(address) => Self::insert_address_v4(interface, mac, address, owned),
IpAddr::V6(_) => todo!(),
}
ARP_TABLE.notify.wake_all();
}
fn poll_address(
interface: u32,
address: IpAddr,
timeout: Duration,
) -> impl Future<Output = Option<MacAddress>> {
struct F<T: Future<Output = ()>> {
interface: u32,
address: IpAddr,
timeout: Pin<Box<T>>,
}
impl<T: Future<Output = ()>> Future for F<T> {
type Output = Option<MacAddress>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.timeout.as_mut().poll(cx).is_ready() {
ARP_TABLE.notify.remove(cx.waker());
return Poll::Ready(None);
}
ARP_TABLE.notify.register(cx.waker());
if let Some(mac) = ArpTable::lookup_cache(self.interface, self.address) {
ARP_TABLE.notify.remove(cx.waker());
Poll::Ready(Some(mac))
} else {
Poll::Pending
}
}
}
F {
interface,
address,
timeout: Box::pin(runtime::sleep(timeout)),
}
}
}
static ARP_TABLE: ArpTable = ArpTable::new();
pub async fn lookup(interface: u32, ip: IpAddr, perform_query: bool) -> Result<MacAddress, Error> {
if let Some(mac) = ArpTable::lookup_cache(interface, ip) {
return Ok(mac);
}
if !perform_query {
return Err(Error::HostUnreachable);
}
query(interface, ip, 5, Duration::from_millis(200)).await
}
async fn query(
interface: u32,
ip: IpAddr,
retries: usize,
retry_timeout: Duration,
) -> Result<MacAddress, Error> {
let interface = NetworkInterface::get(interface)?;
for _ in 0..retries {
send_request(&interface, ip)?;
if let Some(mac) = ArpTable::poll_address(interface.id, ip, retry_timeout).await {
return Ok(mac);
}
}
Err(Error::HostUnreachable)
}
fn send_request_v4(interface: &NetworkInterface, query_address: Ipv4Addr) -> Result<(), Error> {
let request = ArpFrame {
protocol: EtherType::IPV4.to_network_order(),
protocol_size: 4,
hardware_type: u16::to_network_order(1),
hardware_size: 6,
opcode: 1u16.to_network_order(),
sender_mac: interface.mac,
// TODO: it might be nice to specify the interface's own IPv4 address here
sender_ip: 0u32.to_network_order(),
target_ip: u32::to_network_order(query_address.into()),
target_mac: MacAddress::UNSPECIFIED,
};
ethernet::send_l2(
interface,
interface.mac,
MacAddress::BROADCAST,
EtherType::ARP,
&request,
)
}
fn send_request(interface: &NetworkInterface, query_address: IpAddr) -> Result<(), Error> {
log::debug!("Querying address of {}", query_address);
match query_address {
IpAddr::V4(address) => send_request_v4(interface, address),
IpAddr::V6(_) => todo!(),
}
}
fn send_reply(interface_id: u32, arp: &ArpFrame, target_mac: MacAddress) -> Result<(), Error> {
let interface = NetworkInterface::get(interface_id)?;
let reply = ArpFrame {
protocol: arp.protocol,
hardware_type: arp.hardware_type,
hardware_size: arp.hardware_size,
protocol_size: arp.protocol_size,
opcode: 2u16.to_network_order(),
sender_mac: target_mac,
sender_ip: arp.target_ip,
target_ip: arp.sender_ip,
target_mac: arp.sender_mac,
};
ethernet::send_l2(
&interface,
target_mac,
arp.sender_mac,
EtherType::ARP,
&reply,
)
}
pub fn handle_packet(packet: L2Packet) {
let arp: &ArpFrame = bytemuck::from_bytes(&packet.l2_data()[..size_of::<ArpFrame>()]);
let proto = EtherType::from_network_order(arp.protocol);
let opcode = u16::from_network_order(arp.opcode);
let (target_address, sender_address) = match proto {
EtherType::IPV4 => (
Ipv4Addr::from(u32::from_network_order(arp.target_ip)),
Ipv4Addr::from(u32::from_network_order(arp.sender_ip)),
),
_ => {
log::warn!("TODO: unhandled ARP proto: {:#x?}", proto);
return;
}
};
log::debug!("ARP: {} -> {}", sender_address, target_address);
ArpTable::insert_address_v4(packet.interface_id, arp.sender_mac, sender_address, false);
if opcode == 1 {
// Don't answer with non-owned addresses
if let Some((mac, true)) = ArpTable::lookup_cache_v4(packet.interface_id, target_address) {
// Reply with own address
send_reply(packet.interface_id, arp, mac).ok();
}
}
}

View File

@ -0,0 +1,97 @@
use core::mem::size_of;
use yggdrasil_abi::net::{
protocols::{IpProtocol, Ipv4Frame, TcpFrame, UdpFrame},
types::NetValueImpl,
IpAddr, Ipv4Addr,
};
use crate::{interface::NetworkInterface, L2Packet, L3Packet, ACCEPT_QUEUE};
use super::IpFrame;
impl IpFrame for Ipv4Frame {
fn destination_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::from(u32::from_network_order(
self.destination_address,
)))
}
fn source_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::from(u32::from_network_order(self.source_address)))
}
fn data_length(&self) -> usize {
self.total_length().saturating_sub(self.header_length())
}
}
pub fn handle_v4_packet(packet: L2Packet) {
let Ok(interface) = NetworkInterface::get(packet.interface_id) else {
log::debug!("Invalid interface ID in L2 packet");
return;
};
let l2_data = packet.l2_data();
let l3_frame: &Ipv4Frame = bytemuck::from_bytes(&l2_data[..size_of::<Ipv4Frame>()]);
let header_length = l3_frame.header_length();
let l3_data = &l2_data[size_of::<Ipv4Frame>()..];
let is_input = interface
.address
.read()
.map(|address| address == l3_frame.destination_ip())
.unwrap_or(false);
if is_input {
// Extract ports from L4 proto
let (source_port, destination_port) = match l3_frame.protocol {
IpProtocol::UDP => {
// TODO check size
let l4_frame: &UdpFrame = bytemuck::from_bytes(&l3_data[..size_of::<UdpFrame>()]);
(
Some(u16::from_network_order(l4_frame.source_port)),
Some(u16::from_network_order(l4_frame.destination_port)),
)
}
IpProtocol::TCP => {
// TODO check size
let l4_frame: &TcpFrame = bytemuck::from_bytes(&l3_data[..size_of::<TcpFrame>()]);
(
Some(u16::from_network_order(l4_frame.source_port)),
Some(u16::from_network_order(l4_frame.destination_port)),
)
}
IpProtocol::ICMP => (None, None),
_ => (None, None),
};
let l3_packet = L3Packet {
interface_id: packet.interface_id,
protocol: l3_frame.protocol,
source_address: l3_frame.source_ip(),
destination_address: l3_frame.destination_ip(),
source_port,
destination_port,
l2_offset: packet.l2_offset,
l3_offset: packet.l3_offset,
l4_offset: packet.l3_offset + header_length,
data_length: l3_frame.data_length(),
data: packet.data,
};
ACCEPT_QUEUE.push_back(l3_packet);
} else {
// TODO forwarding
log::debug!(
"Dropped forwarded IPv4: {} -> {}",
l3_frame.source_ip(),
l3_frame.destination_ip()
);
}
}

View File

@ -0,0 +1,246 @@
use core::{fmt, mem::size_of};
use alloc::{sync::Arc, vec::Vec};
use bytemuck::{Pod, Zeroable};
use libk_mm::PageBox;
use libk_util::sync::spin_rwlock::{
IrqSafeRwLock, IrqSafeRwLockReadGuard, IrqSafeRwLockWriteGuard,
};
use yggdrasil_abi::{
error::Error,
net::{
protocols::{EtherType, EthernetFrame, InetChecksum, IpProtocol, Ipv4Frame},
types::NetValueImpl,
IpAddr, Ipv4Addr, MacAddress, SubnetAddr,
},
};
use crate::{interface::NetworkInterface, l4, PacketBuilder};
pub mod arp;
pub mod ip;
pub struct L3Packet {
pub interface_id: u32,
pub protocol: IpProtocol,
pub source_address: IpAddr,
pub destination_address: IpAddr,
pub source_port: Option<u16>,
pub destination_port: Option<u16>,
pub l2_offset: usize,
pub l3_offset: usize,
pub l4_offset: usize,
pub data_length: usize,
pub data: Arc<PageBox<[u8]>>,
}
pub trait IpFrame: Pod {
fn destination_ip(&self) -> IpAddr;
fn source_ip(&self) -> IpAddr;
fn data_length(&self) -> usize;
}
// TODO use range map for this?
pub struct Route {
pub subnet: SubnetAddr,
pub interface: u32,
pub gateway: Option<IpAddr>,
}
pub struct L4ResolvedPacket<'a, 'i> {
pub interface: &'i NetworkInterface,
pub source_ip: IpAddr,
pub gateway_ip: IpAddr,
pub destination_ip: IpAddr,
pub l4_frame: &'a [u8],
pub l4_options: &'a [u8],
pub l4_data: &'a [u8],
pub protocol: IpProtocol,
pub ttl: u8,
}
pub struct L4UnresolvedPacket<'a> {
pub destination_ip: IpAddr,
pub l4_frame: &'a [u8],
pub l4_options: &'a [u8],
pub l4_data: &'a [u8],
pub protocol: IpProtocol,
pub ttl: u8,
}
static ROUTES: IrqSafeRwLock<Vec<Route>> = IrqSafeRwLock::new(Vec::new());
impl L3Packet {
pub fn l3_data(&self) -> &[u8] {
&self.data[self.l4_offset..]
}
}
impl Route {
pub fn list_mut() -> IrqSafeRwLockWriteGuard<'static, Vec<Route>> {
ROUTES.write()
}
pub fn list_ref() -> IrqSafeRwLockReadGuard<'static, Vec<Route>> {
ROUTES.read()
}
pub fn lookup(address: IpAddr) -> Option<(u32, Option<IpAddr>, IpAddr)> {
// TODO sort routes based on their "specificity"?
// Check for local route
for (_, interface) in NetworkInterface::list_ref().iter() {
if interface
.address
.read()
.map(|addr| addr == address)
.unwrap_or(false)
{
// Destination is a local interface address, deliver via loopback
return Some((
NetworkInterface::loopback().id,
Some(IpAddr::V4(Ipv4Addr::LOOPBACK)),
IpAddr::V4(Ipv4Addr::LOOPBACK),
));
}
}
let routes = ROUTES.read();
for route in routes.iter() {
if route.subnet.contains(&address) {
return Some((route.interface, route.gateway, address));
}
}
None
}
pub fn insert(route: Self) -> Result<(), Error> {
// TODO check for conflicts
log::debug!("Add route: {}", route);
ROUTES.write().push(route);
Ok(())
}
}
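// Worked example (sketch): with a single route `10.0.0.0/24` on interface 1
// and no gateway, `Route::lookup` on 10.0.0.5 returns (1, None, 10.0.0.5); a
// destination matching a local interface address short-circuits to loopback.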
impl fmt::Display for Route {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} ", self.subnet)?;
if let Some(gw) = self.gateway {
write!(f, " via {}", gw)?;
}
Ok(())
}
}
pub async fn handle_accepted(l3_packet: L3Packet) -> Result<(), Error> {
match l3_packet.protocol {
IpProtocol::UDP => l4::udp::handle(l3_packet),
IpProtocol::ICMP => l4::icmp::handle(l3_packet).await,
IpProtocol::TCP => l4::tcp::handle(l3_packet).await,
_ => Ok(()),
}
}
impl<'a, 'i> L4ResolvedPacket<'a, 'i> {
pub async fn lookup_gateway_mac(&self) -> Result<MacAddress, Error> {
arp::lookup(self.interface.id, self.gateway_ip, true).await
}
pub fn make_l3_frame(&self) -> Result<Ipv4Frame, Error> {
// TODO what if source_ip/gateway_ip/destination_ip are a mix of IPv4/IPv6?
let source_ip = self.source_ip.into_ipv4().ok_or(Error::NotImplemented)?;
let destination_ip = self
.destination_ip
.into_ipv4()
.ok_or(Error::NotImplemented)?;
let total_length = (self.total_l4_len() + size_of::<Ipv4Frame>())
.try_into()
.map_err(|_| Error::InvalidArgument)?;
let mut l3_frame = Ipv4Frame {
source_address: u32::to_network_order(source_ip.into()),
destination_address: u32::to_network_order(destination_ip.into()),
protocol: self.protocol,
version_length: 0x45,
total_length: u16::to_network_order(total_length),
flags_frag: u16::to_network_order(0x4000),
id: u16::to_network_order(0),
ttl: self.ttl,
..Ipv4Frame::zeroed()
};
let l3_frame_bytes = bytemuck::bytes_of(&l3_frame);
let mut ip_checksum = InetChecksum::new();
ip_checksum.add_bytes(l3_frame_bytes, true);
let ip_checksum = ip_checksum.finish();
l3_frame.header_checksum = u16::to_network_order(ip_checksum);
Ok(l3_frame)
}
pub fn total_l4_len(&self) -> usize {
self.l4_frame.len() + self.l4_options.len() + self.l4_data.len()
}
}
pub(crate) fn resolve_l3_route(
destination_ip: IpAddr,
) -> Result<(Arc<NetworkInterface>, IpAddr, IpAddr, IpAddr), Error> {
// Find the route itself
let (interface_id, gateway, destination_ip) =
Route::lookup(destination_ip).ok_or(Error::NetworkUnreachable)?;
let interface = NetworkInterface::get(interface_id)?;
// Route exists, but has no gateway (TODO: assign subnets to interfaces)
let gateway = gateway.ok_or(Error::NetworkUnreachable)?;
// Route exists, but network has no address assigned (TODO: how?)
let source_address = interface.address.read().ok_or(Error::NetworkUnreachable)?;
Ok((interface, source_address, gateway, destination_ip))
}
pub async fn send_l4_ip_resolved(packet: &L4ResolvedPacket<'_, '_>) -> Result<(), Error> {
let gateway_mac = packet.lookup_gateway_mac().await?;
let l3_frame = packet.make_l3_frame()?;
let mut builder = PacketBuilder::new(
packet.interface.device.packet_prefix_size(),
size_of::<EthernetFrame>() + size_of::<Ipv4Frame>() + packet.total_l4_len(),
)?;
builder.push(&EthernetFrame {
source_mac: packet.interface.mac,
destination_mac: gateway_mac,
ethertype: EtherType::IPV4.to_network_order(),
})?;
builder.push(&l3_frame)?;
builder.push_bytes(packet.l4_frame)?;
builder.push_bytes(packet.l4_options)?;
builder.push_bytes(packet.l4_data)?;
let (sent_packet, _len) = builder.finish();
packet.interface.device.transmit(sent_packet)
}
pub async fn send_l4_ip(packet: &L4UnresolvedPacket<'_>) -> Result<(), Error> {
let (interface, source_ip, gateway_ip, destination_ip) =
resolve_l3_route(packet.destination_ip)?;
send_l4_ip_resolved(&L4ResolvedPacket {
interface: &interface,
source_ip,
gateway_ip,
destination_ip,
l4_frame: packet.l4_frame,
l4_options: packet.l4_options,
l4_data: packet.l4_data,
protocol: packet.protocol,
ttl: packet.ttl,
})
.await
}

View File

@ -0,0 +1,82 @@
use core::mem::size_of;
use yggdrasil_abi::{
error::Error,
net::{
protocols::{IcmpV4Frame, InetChecksum, IpProtocol},
types::NetValueImpl,
IpAddr, Ipv4Addr,
},
};
use crate::{l3, L3Packet};
async fn send_v4_reply(
destination_ip: Ipv4Addr,
icmp_frame: &IcmpV4Frame,
icmp_data: &[u8],
) -> Result<(), Error> {
let mut reply_frame = IcmpV4Frame {
ty: 0,
code: 0,
checksum: u16::to_network_order(0),
rest: icmp_frame.rest,
};
if icmp_data.len() % 2 != 0 {
todo!();
}
let l4_bytes = bytemuck::bytes_of(&reply_frame);
let mut checksum = InetChecksum::new();
checksum.add_bytes(l4_bytes, true);
checksum.add_bytes(icmp_data, true);
reply_frame.checksum = checksum.finish().to_network_order();
l3::send_l4_ip(&l3::L4UnresolvedPacket {
destination_ip: IpAddr::V4(destination_ip),
l4_frame: bytemuck::bytes_of(&reply_frame),
l4_options: &[],
l4_data: icmp_data,
protocol: IpProtocol::ICMP,
ttl: 255,
})
.await
}
async fn handle_v4(source_address: Ipv4Addr, l3_packet: L3Packet) -> Result<(), Error> {
if l3_packet.data_length < size_of::<IcmpV4Frame>() {
log::debug!("Truncated ICMPv4 packet");
return Err(Error::MissingData);
}
if l3_packet.data_length - size_of::<IcmpV4Frame>() > 576 {
log::debug!("ICMPv4 packet too large");
return Err(Error::MissingData);
}
let l3_data = l3_packet.l3_data();
let icmp_frame: &IcmpV4Frame = bytemuck::from_bytes(&l3_data[..size_of::<IcmpV4Frame>()]);
let icmp_data = &l3_data[size_of::<IcmpV4Frame>()..l3_packet.data_length];
match (icmp_frame.ty, icmp_frame.code) {
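// Type 8/code 0 is an echo request (answered below); type 0/code 0 is an
// echo reply, accepted but not yet matched against outstanding requests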
(8, 0) => send_v4_reply(source_address, icmp_frame, icmp_data).await,
(0, 0) => Ok(()),
_ => {
log::debug!(
"Ignoring unknown ICMPv4 type:code: {}:{}",
icmp_frame.ty,
icmp_frame.code
);
Ok(())
}
}
}
pub async fn handle(l3_packet: L3Packet) -> Result<(), Error> {
match l3_packet.source_address {
IpAddr::V4(v4) => handle_v4(v4, l3_packet).await,
IpAddr::V6(_) => todo!(),
}
}

View File

@ -0,0 +1,3 @@
pub mod icmp;
pub mod tcp;
pub mod udp;

View File

@ -0,0 +1,721 @@
use core::{
mem::size_of,
task::{Context, Poll},
};
use alloc::{vec, vec::Vec};
use bytemuck::Zeroable;
use libk_util::waker::QueueWaker;
use yggdrasil_abi::{
error::Error,
net::{
protocols::{InetChecksum, IpProtocol, TcpFlags, TcpFrame, TcpV4PseudoHeader},
types::NetValueImpl,
IpAddr, SocketAddr,
},
};
use crate::{
l3::{self, L3Packet},
socket::{TcpListener, TcpSocket},
util::Assembler,
};
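// A subset of the RFC 9293 connection states; LISTEN is modelled separately by
// `TcpListener`, and CLOSE-WAIT/LAST-ACK handling is still a TODO (see below).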
#[derive(PartialEq, Debug)]
pub enum TcpConnectionState {
SynSent,
SynReceived,
Established,
FinWait1,
FinWait2,
Closed,
}
pub enum TcpSocketBehavior {
None,
Accept,
Remove,
}
struct SocketBuffer {
data: Vec<u8>,
wr: usize,
rd: usize,
}
pub struct TcpConnection {
state: TcpConnectionState,
local: SocketAddr,
remote: SocketAddr,
rx_window_size: u16,
// Rx half
// RCV.WND = rx_buffer.capacity()
// RCV.NXT = rx_window_start + rx_window.len()
// TODO RCV.UP
rx_buffer: SocketBuffer,
rx_segment_buffer: Vec<u8>,
rx_assembler: Assembler,
// Relative RX sequence number of window start
rx_window_start: usize,
// Tx half
// SND.UNA = tx_window_start + tx_sent_unacknowledged
// SND.WND = tx_buffer.capacity()
// SND.NXT = tx_window_start IF tx_sent_unacknowledged == 0
tx_buffer: Vec<u8>,
tx_window_start: usize,
tx_sent_unacknowledged: usize,
// IRS
initial_rx_seq: u32,
initial_tx_seq: u32,
rx_notify: QueueWaker,
tx_notify: QueueWaker,
}
#[allow(unused)]
struct TcpPacket {
local: SocketAddr,
remote: SocketAddr,
seq: u32,
ack: u32,
window_size: u16,
flags: TcpFlags,
}
impl SocketBuffer {
pub fn with_capacity(capacity: usize) -> Self {
Self {
data: vec![0; capacity],
wr: 0,
rd: 0,
}
}
pub fn len(&self) -> usize {
if self.wr >= self.rd {
self.wr - self.rd
} else {
self.wr + self.capacity() - self.rd
}
}
pub fn capacity(&self) -> usize {
self.data.len()
}
pub fn can_read(&self) -> bool {
self.rd != self.wr
}
pub fn write(&mut self, data: &[u8]) {
for &byte in data {
self.putc(byte);
}
}
pub fn putc(&mut self, data: u8) {
if (self.wr + 1) % self.capacity() == self.rd {
self.rd = (self.rd + 1) % self.capacity();
}
self.data[self.wr] = data;
self.wr = (self.wr + 1) % self.capacity();
}
pub fn read(&mut self, buffer: &mut [u8]) -> usize {
let mut amount = 0;
while amount != buffer.len() {
if self.rd == self.wr {
break;
}
buffer[amount] = self.data[self.rd];
self.rd = (self.rd + 1) % self.capacity();
amount += 1;
}
amount
}
}
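// Worked example (sketch): the ring keeps at most `capacity - 1` bytes and
// overwrites the oldest data when full:
//
//     let mut b = SocketBuffer::with_capacity(4);
//     b.write(&[1, 2, 3, 4, 5]);
//     let mut out = [0u8; 4];
//     assert_eq!(b.read(&mut out), 3); // reads [3, 4, 5]; 1 and 2 were dropped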
impl TcpConnection {
pub fn new(
local: SocketAddr,
remote: SocketAddr,
window_size: usize,
tx_seq: u32,
rx_seq: u32,
state: TcpConnectionState,
) -> Self {
debug_assert!(
state == TcpConnectionState::SynSent
|| state == TcpConnectionState::Closed
|| state == TcpConnectionState::SynReceived
);
debug_assert!(window_size < u16::MAX as usize);
Self {
state,
local,
remote,
rx_buffer: SocketBuffer::with_capacity(window_size),
rx_assembler: Assembler::new(),
rx_segment_buffer: Vec::with_capacity(window_size),
rx_window_start: 1,
tx_buffer: Vec::with_capacity(window_size),
tx_window_start: 1,
tx_sent_unacknowledged: 0,
rx_window_size: window_size as u16,
initial_rx_seq: rx_seq,
initial_tx_seq: tx_seq,
rx_notify: QueueWaker::new(),
tx_notify: QueueWaker::new(),
}
}
fn ack_number(&self) -> u32 {
(self.initial_rx_seq as usize + self.rx_window_start + self.rx_buffer.len()) as u32
}
fn seq_number(&self) -> u32 {
(self.initial_tx_seq as usize + self.tx_window_start) as u32
}
pub fn is_closing(&self) -> bool {
self.state == TcpConnectionState::FinWait1
|| self.state == TcpConnectionState::FinWait2
|| self.state == TcpConnectionState::Closed
}
pub fn is_closed(&self) -> bool {
self.state == TcpConnectionState::Closed
}
pub fn read_nonblocking(&mut self, buffer: &mut [u8]) -> Result<usize, Error> {
let amount = self.rx_buffer.read(buffer);
if amount == 0 && self.state != TcpConnectionState::Established {
// TODO ConnectionAborted?
return Err(Error::ConnectionReset);
}
self.rx_window_start += amount;
Ok(amount)
}
pub(crate) fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.rx_notify.register(cx.waker());
if self.rx_buffer.can_read() {
self.rx_notify.remove(cx.waker());
Poll::Ready(Ok(()))
} else if self.state != TcpConnectionState::Established {
self.rx_notify.remove(cx.waker());
Poll::Ready(Err(Error::ConnectionReset))
} else {
Poll::Pending
}
}
pub(crate) fn poll_send(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.tx_notify.register(cx.waker());
if self.state == TcpConnectionState::Closed {
self.tx_notify.remove(cx.waker());
Poll::Ready(Err(Error::ConnectionReset))
} else if self.tx_sent_unacknowledged == 0 {
self.tx_notify.remove(cx.waker());
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
pub(crate) fn poll_acknowledge(&self, cx: &mut Context<'_>) -> Poll<()> {
self.tx_notify.register(cx.waker());
if self.tx_sent_unacknowledged == 0 {
self.tx_notify.remove(cx.waker());
Poll::Ready(())
} else {
Poll::Pending
}
}
pub(crate) fn poll_finish(&self, cx: &mut Context<'_>) -> Poll<()> {
self.tx_notify.register(cx.waker());
if self.state == TcpConnectionState::FinWait2 || self.state == TcpConnectionState::Closed {
self.tx_notify.remove(cx.waker());
Poll::Ready(())
} else {
Poll::Pending
}
}
pub(crate) fn poll_established(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.rx_notify.register(cx.waker());
match self.state {
TcpConnectionState::Established => Poll::Ready(Ok(())),
TcpConnectionState::Closed => Poll::Ready(Err(Error::ConnectionRefused)),
_ => Poll::Pending,
}
}
pub(crate) async fn transmit(&mut self, data: &[u8]) -> Result<(), Error> {
assert_eq!(self.tx_sent_unacknowledged, 0);
assert_eq!(self.tx_buffer.len(), 0);
self.tx_buffer.extend_from_slice(data);
self.tx_sent_unacknowledged = data.len();
send(
self.local,
self.remote,
self.seq_number(),
self.ack_number(),
self.rx_window_size,
TcpFlags::ACK,
data,
)
.await
}
pub(crate) async fn finish(&mut self) -> Result<(), Error> {
assert_eq!(self.tx_sent_unacknowledged, 0);
assert_eq!(self.tx_buffer.len(), 0);
log::debug!("Finish connection {} <-> {}", self.local, self.remote);
send(
self.local,
self.remote,
self.seq_number(),
self.ack_number(),
self.rx_window_size,
TcpFlags::FIN | TcpFlags::ACK,
&[],
)
.await?;
self.state = TcpConnectionState::FinWait1;
Ok(())
}
pub(crate) fn notify_all(&self) {
self.rx_notify.wake_all();
self.tx_notify.wake_all();
}
pub(crate) async fn send_syn(&mut self) -> Result<(), Error> {
assert!(
self.state == TcpConnectionState::SynSent || self.state == TcpConnectionState::Closed
);
log::debug!("Send SYN {} -> {}", self.local, self.remote);
self.state = TcpConnectionState::SynSent;
send(
self.local,
self.remote,
self.initial_tx_seq,
0,
self.rx_window_size,
TcpFlags::SYN,
&[],
)
.await
}
fn handle_packet_payload(&mut self, data: &[u8], seq: u32) -> bool {
// Local side
let rx_window_start = self.rx_window_start + self.rx_buffer.len();
let rx_window_end = self.rx_window_start + self.rx_buffer.capacity();
// Remote side
let rx_segment_start = seq.wrapping_sub(self.initial_rx_seq) as usize;
let rx_segment_end = rx_segment_start + data.len();
if rx_segment_end >= rx_window_end || rx_segment_start < rx_window_start {
return false;
}
// Offset from expected seq
let segment_start_offset = rx_segment_start - rx_window_start;
// Push the data into reassembler buffer
assert!(segment_start_offset + data.len() <= self.rx_segment_buffer.capacity());
if segment_start_offset + data.len() > self.rx_segment_buffer.len() {
self.rx_segment_buffer
.resize(segment_start_offset + data.len(), 0);
}
self.rx_segment_buffer[segment_start_offset..segment_start_offset + data.len()]
.copy_from_slice(data);
let amount = self
.rx_assembler
.add_then_remove_front(segment_start_offset, data.len())
.unwrap();
if amount != 0 {
// Take data from reassembly buffer and append to rx_buffer
self.rx_buffer.write(&self.rx_segment_buffer[..amount]);
self.rx_segment_buffer.drain(..amount);
self.rx_notify.wake_one();
}
true
}
async fn handle_packet(
&mut self,
packet: TcpPacket,
data: &[u8],
) -> Result<TcpSocketBehavior, Error> {
// TODO what if window_size changes?
match self.state {
TcpConnectionState::SynSent => {
if packet.flags == TcpFlags::SYN | TcpFlags::ACK {
if packet.ack != self.initial_tx_seq.wrapping_add(1) {
log::warn!(
"Expected ACK {}, got {}",
self.initial_tx_seq.wrapping_add(1),
packet.ack
);
return Ok(TcpSocketBehavior::None);
}
log::debug!(
"TCP {} -> {} got ACKed, established",
self.local,
self.remote
);
self.initial_rx_seq = packet.seq;
// ACK the SYN+ACK
send(
self.local,
self.remote,
self.initial_tx_seq.wrapping_add(1),
self.initial_rx_seq.wrapping_add(1),
self.rx_window_size,
TcpFlags::ACK,
&[],
)
.await?;
self.state = TcpConnectionState::Established;
self.rx_notify.wake_all();
} else if packet.flags == TcpFlags::RST | TcpFlags::ACK {
log::debug!("TCP {} -> {} got RSTd, closing", self.local, self.remote);
self.state = TcpConnectionState::Closed;
self.rx_notify.wake_all();
return Ok(TcpSocketBehavior::Remove);
}
// TODO try re-sending SYN?
return Ok(TcpSocketBehavior::None);
}
TcpConnectionState::FinWait1 => {
// Check if connection close initiated locally got ACKed by remote
// TODO check ack/seq
if packet.flags == TcpFlags::FIN | TcpFlags::ACK {
self.state = TcpConnectionState::Closed;
send(
self.local,
self.remote,
self.seq_number().wrapping_add(1),
self.ack_number().wrapping_add(1),
self.rx_window_size,
TcpFlags::ACK,
&[],
)
.await?;
// Socket fully closed, remove from table
return Ok(TcpSocketBehavior::Remove);
}
if packet.flags == TcpFlags::ACK {
self.state = TcpConnectionState::FinWait2;
self.tx_notify.wake_all();
}
return Ok(TcpSocketBehavior::None);
}
TcpConnectionState::FinWait2 => {
if packet.flags == TcpFlags::FIN | TcpFlags::ACK {
self.state = TcpConnectionState::Closed;
send(
self.local,
self.remote,
self.seq_number().wrapping_add(1),
self.ack_number().wrapping_add(1),
self.rx_window_size,
TcpFlags::ACK,
&[],
)
.await?;
// Socket fully closed, remove from table
return Ok(TcpSocketBehavior::Remove);
}
return Ok(TcpSocketBehavior::None);
}
TcpConnectionState::Closed => {
log::warn!("Packet received on closed connection");
return Ok(TcpSocketBehavior::None);
}
TcpConnectionState::SynReceived => {
// TODO check ack/seq
// Handshake continuation expected (ACK)
if packet.flags == TcpFlags::ACK {
self.state = TcpConnectionState::Established;
return Ok(TcpSocketBehavior::Accept);
}
return Err(Error::InvalidArgument);
}
TcpConnectionState::Established => (),
}
if self.tx_sent_unacknowledged != 0 {
let tx_acknowledge_end = packet.ack.wrapping_sub(self.initial_tx_seq) as usize;
if tx_acknowledge_end == self.tx_window_start + self.tx_sent_unacknowledged {
self.tx_window_start += self.tx_sent_unacknowledged;
self.tx_sent_unacknowledged = 0;
self.tx_buffer.clear();
self.tx_notify.wake_one();
}
}
let mut reply_flags = TcpFlags::empty();
let mut behavior = TcpSocketBehavior::None;
if !data.is_empty() && self.handle_packet_payload(data, packet.seq) {
reply_flags |= TcpFlags::ACK;
}
// TODO check window resize notification
let mut ack_number = self.ack_number();
if packet.flags.contains(TcpFlags::FIN) {
reply_flags |= TcpFlags::ACK;
ack_number = ack_number.wrapping_add(1);
// Only send an actual FIN after a FIN without any data
if data.is_empty() {
reply_flags |= TcpFlags::FIN;
// TODO go to LastAck state and wait for ACK
self.state = TcpConnectionState::Closed;
log::trace!(
"TCP connection FIN requested by remote: {} <-> {}",
self.local,
self.remote
);
behavior = TcpSocketBehavior::Remove;
}
}
if reply_flags != TcpFlags::empty() {
send(
packet.local,
packet.remote,
self.seq_number(),
ack_number,
self.rx_window_size,
reply_flags,
&[],
)
.await?;
}
Ok(behavior)
}
}
async fn send(
local: SocketAddr,
remote: SocketAddr,
seq: u32,
ack: u32,
window_size: u16,
flags: TcpFlags,
data: &[u8],
) -> Result<(), Error> {
let (interface, source_ip, gateway_ip, destination_ip) = l3::resolve_l3_route(remote.ip())?;
// TODO TCPv6
let source_ip = source_ip.into_ipv4().unwrap();
let destination_ip = destination_ip.into_ipv4().unwrap();
let tcp_length = size_of::<TcpFrame>() + data.len();
let mut frame = TcpFrame {
source_port: local.port().to_network_order(),
destination_port: remote.port().to_network_order(),
sequence_number: seq.to_network_order(),
acknowledge_number: ack.to_network_order(),
data_offset: 5 << 4,
window_size: window_size.to_network_order(),
flags,
..TcpFrame::zeroed()
};
let pseudo_header = TcpV4PseudoHeader {
source_address: u32::from(source_ip).to_network_order(),
destination_address: u32::from(destination_ip).to_network_order(),
_zero: 0,
protocol: IpProtocol::TCP,
tcp_length: (tcp_length as u16).to_network_order(),
};
let mut checksum = InetChecksum::new();
checksum.add_value(&pseudo_header, true);
checksum.add_value(&frame, true);
checksum.add_bytes(data, true);
let checksum = checksum.finish();
frame.checksum = checksum.to_network_order();
l3::send_l4_ip_resolved(&l3::L4ResolvedPacket {
interface: &interface,
source_ip: source_ip.into(),
gateway_ip,
destination_ip: destination_ip.into(),
l4_frame: bytemuck::bytes_of(&frame),
l4_options: &[],
l4_data: data,
protocol: IpProtocol::TCP,
ttl: 64,
})
.await
}
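// Checksum check: one's-complement-summing the pseudo-header, the TCP frame
// (including the transmitted checksum) and the payload folds to zero for an
// intact segment, which is exactly what `validate` tests for below.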
fn validate(source: IpAddr, destination: IpAddr, tcp_frame: &TcpFrame, data: &[u8]) -> bool {
// TODO TCPv6
let source = source.into_ipv4().unwrap();
let destination = destination.into_ipv4().unwrap();
let tcp_length = size_of::<TcpFrame>() + data.len();
let pseudo_header = TcpV4PseudoHeader {
source_address: u32::from(source).to_network_order(),
destination_address: u32::from(destination).to_network_order(),
_zero: 0,
protocol: IpProtocol::TCP,
tcp_length: (tcp_length as u16).to_network_order(),
};
let mut checksum = InetChecksum::new();
checksum.add_value(&pseudo_header, true);
checksum.add_value(tcp_frame, true);
checksum.add_bytes(data, true);
let checksum = checksum.finish();
checksum == 0
}
pub async fn handle(packet: L3Packet) -> Result<(), Error> {
if packet.data_length < size_of::<TcpFrame>() {
log::warn!("Truncated TCP packet");
return Ok(());
}
let l3_data = packet.l3_data();
let tcp_frame: &TcpFrame = bytemuck::from_bytes(&l3_data[..size_of::<TcpFrame>()]);
let tcp_data_offset = tcp_frame.data_offset();
let tcp_data = &l3_data[tcp_data_offset..packet.data_length];
let remote = SocketAddr::new(
packet.source_address,
u16::from_network_order(tcp_frame.source_port),
);
let local = SocketAddr::new(
packet.destination_address,
u16::from_network_order(tcp_frame.destination_port),
);
let seq = u32::from_network_order(tcp_frame.sequence_number);
let ack = u32::from_network_order(tcp_frame.acknowledge_number);
if !validate(
packet.source_address,
packet.destination_address,
tcp_frame,
&l3_data[size_of::<TcpFrame>()..packet.data_length],
) {
log::warn!("Invalid TCP packet received");
return Ok(());
}
match tcp_frame.flags {
TcpFlags::SYN => {
if let Some(listener) = TcpListener::get(local) {
let window_size = u16::from_network_order(tcp_frame.window_size);
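// TODO: derive a randomized initial sequence number instead of a constant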
let tx_seq = 12345;
// Create a socket and insert it into the table
TcpSocket::accept_remote(
listener.clone(),
local,
remote,
window_size as usize,
tx_seq,
seq,
)?;
// Send SYN+ACK
send(
local,
remote,
tx_seq,
seq.wrapping_add(1),
window_size,
TcpFlags::SYN | TcpFlags::ACK,
&[],
)
.await
} else {
// RST+ACK
log::warn!("SYN {} -> {}: port not listening", remote, local);
let window_size = u16::from_network_order(tcp_frame.window_size);
send(
local,
remote,
0,
seq.wrapping_add(1),
window_size,
TcpFlags::RST | TcpFlags::ACK,
&[],
)
.await
}
}
_ => {
let packet = TcpPacket {
local,
remote,
window_size: u16::from_network_order(tcp_frame.window_size),
flags: tcp_frame.flags,
ack,
seq,
};
let socket = TcpSocket::get(local, remote).ok_or(Error::DoesNotExist)?;
let mut connection = socket.connection().write();
match connection.handle_packet(packet, tcp_data).await? {
TcpSocketBehavior::None => (),
TcpSocketBehavior::Accept => {
socket.accept();
}
TcpSocketBehavior::Remove => {
drop(connection);
socket.remove_socket()?;
}
}
Ok(())
}
}
}

View File

@ -0,0 +1,86 @@
use core::mem::size_of;
use yggdrasil_abi::{
error::Error,
net::{
protocols::{IpProtocol, UdpFrame},
types::NetValueImpl,
IpAddr, SocketAddr,
},
};
use crate::{l3, socket::UdpSocket, L3Packet};
pub async fn send(
source_port: u16,
destination_ip: IpAddr,
destination_port: u16,
ttl: u8,
data: &[u8],
) -> Result<(), Error> {
let length: u16 = (data.len() + size_of::<UdpFrame>()).try_into().unwrap();
let udp_frame = UdpFrame {
source_port: source_port.to_network_order(),
destination_port: destination_port.to_network_order(),
length: length.to_network_order(),
checksum: 0u16.to_network_order(),
};
l3::send_l4_ip(&l3::L4UnresolvedPacket {
destination_ip,
l4_frame: bytemuck::bytes_of(&udp_frame),
l4_options: &[],
l4_data: data,
protocol: IpProtocol::UDP,
ttl,
})
.await
}
// pub fn send_broadcast(
// v6: bool,
// source_port: u16,
// destination_port: u16,
// data: &[u8],
// ) -> Result<(), Error> {
// let length: u16 = (data.len() + size_of::<UdpFrame>()).try_into().unwrap();
// let udp_frame = UdpFrame {
// source_port: source_port.to_network_order(),
// destination_port: destination_port.to_network_order(),
// length: length.to_network_order(),
// checksum: 0u16.to_network_order(),
// };
//
// l3::send_l4_ip_broadcast(v6, IpProtocol::UDP, &udp_frame, data)
// }
pub fn handle(l3_packet: L3Packet) -> Result<(), Error> {
if l3_packet.data_length < size_of::<UdpFrame>() {
log::warn!("Truncated UDP frame received");
return Err(Error::MissingData);
}
let l3_data = l3_packet.l3_data();
let udp_frame: &UdpFrame = bytemuck::from_bytes(&l3_data[..size_of::<UdpFrame>()]);
let data_size = core::cmp::min(
udp_frame.data_length(),
l3_packet.data_length - size_of::<UdpFrame>(),
);
let udp_data = &l3_data[size_of::<UdpFrame>()..data_size + size_of::<UdpFrame>()];
let source = SocketAddr::new(
l3_packet.source_address,
u16::from_network_order(udp_frame.source_port),
);
let destination = SocketAddr::new(
l3_packet.destination_address,
u16::from_network_order(udp_frame.destination_port),
);
if let Some(socket) = UdpSocket::get(&destination) {
socket.packet_received(source, udp_data).ok();
}
Ok(())
}

View File

@ -0,0 +1,143 @@
#![feature(map_try_insert)]
#![allow(clippy::type_complexity)]
#![no_std]
extern crate alloc;
use core::mem::size_of;
use alloc::sync::Arc;
use bytemuck::Pod;
use ethernet::L2Packet;
use l3::L3Packet;
use libk_mm::PageBox;
use libk_thread::runtime;
use libk_util::queue::UnboundedMpmcQueue;
use yggdrasil_abi::{error::Error, net::protocols::EthernetFrame};
pub mod ethernet;
pub mod l3;
pub mod l4;
pub mod socket;
pub mod config;
pub mod interface;
pub mod util;
pub use interface::register_interface;
pub struct Packet {
// TODO info about "received" interface
buffer: PageBox<[u8]>,
offset: usize,
iface: u32,
}
pub struct PacketBuilder {
data: PageBox<[u8]>,
pos: usize,
len: usize,
}
impl PacketBuilder {
pub fn new(l2_offset: usize, l2_size: usize) -> Result<Self, Error> {
let data = PageBox::new_slice(0, l2_offset + l2_size)?;
Ok(Self {
data,
pos: l2_offset,
len: l2_offset,
})
}
#[inline]
pub fn push<T: Pod>(&mut self, value: &T) -> Result<(), Error> {
self.push_bytes(bytemuck::bytes_of(value))
}
pub fn push_bytes(&mut self, bytes: &[u8]) -> Result<(), Error> {
if self.pos + bytes.len() > self.data.len() {
return Err(Error::OutOfMemory);
}
self.data[self.pos..self.pos + bytes.len()].copy_from_slice(bytes);
self.pos += bytes.len();
// Keep the total packet length in sync so finish() reports it correctly
self.len = self.pos;
Ok(())
}
pub fn finish(self) -> (PageBox<[u8]>, usize) {
(self.data, self.len)
}
}
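// Usage sketch (mirrors `send_l4_ip_resolved` in the l3 module): reserve the
// device's packet prefix, then append headers and payload in wire order:
//
//     let mut b = PacketBuilder::new(prefix, size_of::<EthernetFrame>() + len)?;
//     b.push(&eth_frame)?;
//     b.push_bytes(payload)?;
//     let (packet, _len) = b.finish();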
impl Packet {
#[inline]
pub fn new(buffer: PageBox<[u8]>, offset: usize, iface: u32) -> Self {
Self {
buffer,
offset,
iface,
}
}
}
static PACKET_QUEUE: UnboundedMpmcQueue<Packet> = UnboundedMpmcQueue::new();
static ACCEPT_QUEUE: UnboundedMpmcQueue<L3Packet> = UnboundedMpmcQueue::new();
#[inline]
pub fn receive_packet(packet: Packet) -> Result<(), Error> {
PACKET_QUEUE.push_back(packet);
Ok(())
}
pub fn start_network_tasks() -> Result<(), Error> {
runtime::spawn(l2_packet_handler_worker())?;
for _ in 0..4 {
runtime::spawn(l3_accept_worker())?;
}
runtime::spawn(config::network_config_service())?;
Ok(())
}
async fn l2_packet_handler_worker() {
loop {
let packet = PACKET_QUEUE.pop_front().await;
let eth_frame: &EthernetFrame = bytemuck::from_bytes(
&packet.buffer[packet.offset..packet.offset + size_of::<EthernetFrame>()],
);
let l2_packet = L2Packet {
interface_id: packet.iface,
source_address: eth_frame.source_mac,
destination_address: eth_frame.destination_mac,
l2_offset: packet.offset,
l3_offset: packet.offset + size_of::<EthernetFrame>(),
data: Arc::new(packet.buffer),
};
ethernet::handle(l2_packet);
}
}
async fn l3_accept_worker() {
loop {
let l3_packet = ACCEPT_QUEUE.pop_front().await;
// log::debug!(
// "INPUT {} {}:{:?} -> {}:{:?}: ACCEPT",
// l3_packet.protocol,
// l3_packet.source_address,
// l3_packet.source_port,
// l3_packet.destination_address,
// l3_packet.destination_port
// );
if let Err(error) = l3::handle_accepted(l3_packet).await {
log::error!("L3 handle error: {:?}", error);
}
}
}
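// Hedged integration sketch: during kernel init, a NIC driver registers its
// interface and the stack's workers are spawned (`nic` is illustrative):
//
//     let iface = register_interface(NetworkInterfaceType::Ethernet, nic);
//     start_network_tasks()?;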

View File

@ -0,0 +1,827 @@
use core::{
future::{poll_fn, Future},
pin::Pin,
sync::atomic::{AtomicBool, AtomicU32, Ordering},
task::{Context, Poll},
time::Duration,
};
use alloc::{collections::BTreeMap, sync::Arc, vec::Vec};
use libk_device::monotonic_timestamp;
use libk_mm::PageBox;
use libk_thread::{
block,
runtime::{run_with_timeout, FutureTimeout},
};
use libk_util::{
queue::BoundedMpmcQueue,
sync::{
spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockWriteGuard},
IrqSafeSpinlock, IrqSafeSpinlockGuard,
},
waker::QueueWaker,
};
use vfs::{ConnectionSocket, FileReadiness, ListenerSocket, PacketSocket, Socket};
use yggdrasil_abi::{
error::Error,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketInterfaceQuery, SocketOption},
};
use crate::{
ethernet::L2Packet,
interface::NetworkInterface,
l3::Route,
l4::{
self,
tcp::{TcpConnection, TcpConnectionState},
},
};
pub struct UdpSocket {
local: SocketAddr,
remote: Option<SocketAddr>,
broadcast: AtomicBool,
// TODO just place packets here for one less copy?
receive_queue: BoundedMpmcQueue<(SocketAddr, Vec<u8>)>,
}
pub struct TcpSocket {
pub(crate) local: SocketAddr,
pub(crate) remote: SocketAddr,
// Listener which accepted the socket
listener: Option<Arc<TcpListener>>,
connection: IrqSafeRwLock<TcpConnection>,
}
pub struct TcpListener {
accept: SocketAddr,
// Currently active sockets
sockets: IrqSafeRwLock<BTreeMap<SocketAddr, Arc<TcpSocket>>>,
pending_accept: IrqSafeSpinlock<Vec<Arc<TcpSocket>>>,
accept_notify: QueueWaker,
}
pub struct RawSocket {
id: u32,
bound: IrqSafeSpinlock<Option<u32>>,
receive_queue: BoundedMpmcQueue<L2Packet>,
}
pub struct SocketTable<T: Socket> {
inner: BTreeMap<SocketAddr, Arc<T>>,
}
pub struct TwoWaySocketTable<T> {
inner: BTreeMap<(SocketAddr, SocketAddr), Arc<T>>,
}
impl<T> TwoWaySocketTable<T> {
pub const fn new() -> Self {
Self {
inner: BTreeMap::new(),
}
}
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
&mut self,
local: SocketAddr,
remote: SocketAddr,
with: F,
) -> Result<Arc<T>, Error> {
if self.inner.contains_key(&(local, remote)) {
return Err(Error::AddrInUse);
}
let socket = with()?;
self.inner.insert((local, remote), socket.clone());
Ok(socket)
}
pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
&mut self,
local: IpAddr,
remote: SocketAddr,
mut with: F,
) -> Result<Arc<T>, Error> {
for port in 32768..u16::MAX - 1 {
let local = SocketAddr::new(local, port);
match self.try_insert_with(local, remote, || with(port)) {
Ok(socket) => return Ok(socket),
Err(Error::AddrInUse) => continue,
Err(error) => return Err(error),
}
}
Err(Error::AddrInUse)
}
pub fn remove(&mut self, local: SocketAddr, remote: SocketAddr) -> Result<(), Error> {
match self.inner.remove(&(local, remote)) {
Some(_) => Ok(()),
None => Err(Error::DoesNotExist),
}
}
pub fn get(&self, local: SocketAddr, remote: SocketAddr) -> Option<Arc<T>> {
self.inner.get(&(local, remote)).cloned()
}
}
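// Editor's note (sketch, not part of the original commit): keying by the full
// (local, remote) pair is what lets TCP demultiplex segments: two connections
// may share a local port as long as the peers differ.
fn demo_demux_segment(local: SocketAddr, remote: SocketAddr) -> Option<Arc<TcpSocket>> {
    TCP_SOCKETS.read().get(local, remote)
}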
impl<T: Socket> SocketTable<T> {
pub const fn new() -> Self {
Self {
inner: BTreeMap::new(),
}
}
pub fn try_insert_with_ephemeral_port<F: FnMut(u16) -> Result<Arc<T>, Error>>(
&mut self,
local: IpAddr,
mut with: F,
) -> Result<Arc<T>, Error> {
for port in 32768..u16::MAX - 1 {
let local = SocketAddr::new(local, port);
match self.try_insert_with(local, || with(port)) {
Ok(socket) => return Ok(socket),
Err(Error::AddrInUse) => continue,
Err(error) => return Err(error),
}
}
Err(Error::AddrInUse)
}
pub fn try_insert_with<F: FnOnce() -> Result<Arc<T>, Error>>(
&mut self,
address: SocketAddr,
with: F,
) -> Result<Arc<T>, Error> {
if self.inner.contains_key(&address) {
return Err(Error::AddrInUse);
}
let socket = with()?;
self.inner.insert(address, socket.clone());
Ok(socket)
}
pub fn remove(&mut self, local: SocketAddr) -> Result<(), Error> {
match self.inner.remove(&local) {
Some(_) => Ok(()),
None => Err(Error::DoesNotExist),
}
}
pub fn get_exact(&self, local: &SocketAddr) -> Option<Arc<T>> {
self.inner.get(local).cloned()
}
pub fn get(&self, local: &SocketAddr) -> Option<Arc<T>> {
if let Some(socket) = self.inner.get(local) {
return Some(socket.clone());
}
match local {
SocketAddr::V4(_v4) => {
let unspec_v4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local.port());
self.inner.get(&SocketAddr::V4(unspec_v4)).cloned()
}
SocketAddr::V6(_) => todo!(),
}
}
}
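// Editor's sketch (not part of the original commit): binding with port 0
// walks the ephemeral range (32768 and up) until a free slot is found, which
// is what `UdpSocket::bind` below does for unspecified ports.
fn demo_bind_ephemeral_udp(ip: IpAddr) -> Result<Arc<UdpSocket>, Error> {
    UDP_SOCKETS.write().try_insert_with_ephemeral_port(ip, |port| {
        Ok(UdpSocket::create_socket(SocketAddr::new(ip, port)))
    })
}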
static UDP_SOCKETS: IrqSafeRwLock<SocketTable<UdpSocket>> = IrqSafeRwLock::new(SocketTable::new());
static TCP_SOCKETS: IrqSafeRwLock<TwoWaySocketTable<TcpSocket>> =
IrqSafeRwLock::new(TwoWaySocketTable::new());
static RAW_SOCKET_ID: AtomicU32 = AtomicU32::new(0);
static RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Arc<RawSocket>>> =
IrqSafeRwLock::new(BTreeMap::new());
static BOUND_RAW_SOCKETS: IrqSafeRwLock<BTreeMap<u32, Vec<u32>>> =
IrqSafeRwLock::new(BTreeMap::new());
static TCP_LISTENERS: IrqSafeRwLock<SocketTable<TcpListener>> =
IrqSafeRwLock::new(SocketTable::new());
impl UdpSocket {
fn create_socket(local: SocketAddr) -> Arc<UdpSocket> {
log::debug!("UDP socket opened: {}", local);
Arc::new(UdpSocket {
local,
remote: None,
broadcast: AtomicBool::new(false),
receive_queue: BoundedMpmcQueue::new(128),
})
}
pub fn bind(address: SocketAddr) -> Result<Arc<UdpSocket>, Error> {
let mut sockets = UDP_SOCKETS.write();
if address.port() == 0 {
sockets.try_insert_with_ephemeral_port(address.ip(), |port| {
Ok(Self::create_socket(SocketAddr::new(address.ip(), port)))
})
} else {
sockets.try_insert_with(address, move || Ok(Self::create_socket(address)))
}
}
pub fn connect(&self, _address: SocketAddr) -> Result<(), Error> {
todo!()
}
pub fn get(local: &SocketAddr) -> Option<Arc<UdpSocket>> {
UDP_SOCKETS.read().get(local)
}
pub fn packet_received(&self, source: SocketAddr, data: &[u8]) -> Result<(), Error> {
self.receive_queue
.try_push_back((source, Vec::from(data)))
.map_err(|_| Error::QueueFull)
}
}
impl FileReadiness for UdpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.receive_queue.poll_not_empty(cx).map(Ok)
}
}
impl PacketSocket for UdpSocket {
fn send(&self, destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
let Some(destination) = destination else {
// TODO can still send without setting address if "connected"
return Err(Error::InvalidArgument);
};
// TODO check that destination family matches self family
match (self.broadcast.load(Ordering::Relaxed), destination.ip()) {
// TODO: sending to an explicit destination while broadcast is enabled
(true, _) => todo!(),
(false, _) => {
block!(
l4::udp::send(
self.local.port(),
destination.ip(),
destination.port(),
64,
data,
)
.await
)??;
}
}
Ok(data.len())
}
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
let (source, data) = block!(self.receive_queue.pop_front().await)?;
if data.len() > buffer.len() {
// TODO check how other OSs handle this
return Err(Error::BufferTooSmall);
}
buffer[..data.len()].copy_from_slice(&data);
Ok((source, data.len()))
}
}
impl Socket for UdpSocket {
fn local_address(&self) -> SocketAddr {
self.local
}
fn remote_address(&self) -> Option<SocketAddr> {
self.remote
}
fn close(&self) -> Result<(), Error> {
log::debug!("UDP socket closed: {}", self.local);
UDP_SOCKETS.write().remove(self.local)
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
&SocketOption::Broadcast(broadcast) => {
log::debug!("{} broadcast: {}", self.local, broadcast);
self.broadcast.store(broadcast, Ordering::Relaxed);
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::Broadcast(broadcast) => {
*broadcast = self.broadcast.load(Ordering::Relaxed);
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
}
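// Editor's sketch (not part of the original commit): a blocking echo
// round-trip over the traits implemented above. The port and buffer size are
// illustrative.
fn demo_udp_echo_once() -> Result<(), Error> {
    let socket = UdpSocket::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 7)))?;
    let mut buffer = [0u8; 1500];
    let (peer, len) = socket.receive(&mut buffer)?;
    socket.send(Some(peer), &buffer[..len])?;
    socket.close()
}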
impl RawSocket {
pub fn bind() -> Result<Arc<Self>, Error> {
let id = RAW_SOCKET_ID.fetch_add(1, Ordering::SeqCst);
let socket = Self {
id,
bound: IrqSafeSpinlock::new(None),
receive_queue: BoundedMpmcQueue::new(256),
};
let socket = Arc::new(socket);
RAW_SOCKETS.write().insert(id, socket.clone());
Ok(socket)
}
fn bound_packet_received(&self, packet: L2Packet) {
// TODO do something with the dropped packet?
self.receive_queue.try_push_back(packet).ok();
}
pub fn packet_received(packet: L2Packet) {
let bound_sockets = BOUND_RAW_SOCKETS.read();
let raw_sockets = RAW_SOCKETS.read();
if let Some(ids) = bound_sockets.get(&packet.interface_id) {
for id in ids {
let socket = raw_sockets.get(id).unwrap();
socket.bound_packet_received(packet.clone());
}
}
}
}
impl FileReadiness for RawSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.receive_queue.poll_not_empty(cx).map(Ok)
}
}
impl Socket for RawSocket {
fn get_option(&self, option: &mut SocketOption) -> Result<(), Error> {
match option {
SocketOption::BoundHardwareAddress(mac) => {
let bound = self.bound.lock().ok_or(Error::DoesNotExist)?;
let interface = NetworkInterface::get(bound).unwrap();
*mac = interface.mac;
Ok(())
}
_ => Err(Error::InvalidOperation),
}
}
fn set_option(&self, option: &SocketOption) -> Result<(), Error> {
match option {
SocketOption::BindInterface(query) => {
let mut bound = self.bound.lock();
if bound.is_some() {
return Err(Error::AlreadyExists);
}
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let interface = match *query {
SocketInterfaceQuery::ById(id) => NetworkInterface::get(id),
SocketInterfaceQuery::ByName(name) => NetworkInterface::query_by_name(name),
}?;
let list = bound_sockets.entry(interface.id).or_default();
bound.replace(interface.id);
list.push(self.id);
Ok(())
}
SocketOption::UnbindInterface => todo!(),
_ => Err(Error::InvalidOperation),
}
}
fn close(&self) -> Result<(), Error> {
let bound = self.bound.lock().take();
if let Some(bound) = bound {
let mut bound_sockets = BOUND_RAW_SOCKETS.write();
let mut clear = false;
if let Some(list) = bound_sockets.get_mut(&bound) {
list.retain(|&item| item != self.id);
clear = list.is_empty();
}
if clear {
bound_sockets.remove(&bound);
}
}
RAW_SOCKETS.write().remove(&self.id).unwrap();
Ok(())
}
fn local_address(&self) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
}
fn remote_address(&self) -> Option<SocketAddr> {
None
}
}
impl PacketSocket for RawSocket {
fn send(&self, _destination: Option<SocketAddr>, data: &[u8]) -> Result<usize, Error> {
// TODO cap by MTU?
let bound = self.bound.lock().ok_or(Error::InvalidOperation)?;
let interface = NetworkInterface::get(bound).unwrap();
let l2_offset = interface.device.packet_prefix_size();
if data.len() > 4096 - l2_offset {
return Err(Error::InvalidArgument);
}
let mut packet = PageBox::new_slice(0, l2_offset + data.len())?;
packet[l2_offset..l2_offset + data.len()].copy_from_slice(data);
interface.device.transmit(packet)?;
Ok(data.len())
}
fn receive(&self, buffer: &mut [u8]) -> Result<(SocketAddr, usize), Error> {
let data = block!(self.receive_queue.pop_front().await)?;
let full_len = data.data.len();
let len = full_len - data.l2_offset;
if buffer.len() < len {
return Err(Error::BufferTooSmall);
}
buffer[..len].copy_from_slice(&data.data[data.l2_offset..full_len]);
Ok((SocketAddr::NULL_V4, len))
}
}
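// Editor's sketch (not part of the original commit): a raw socket has no
// address of its own; it is attached to an interface through a socket option
// before any traffic flows. The `ByName` payload is assumed to be a borrowed
// str, per the match arms above.
fn demo_raw_socket_on(name: &str) -> Result<Arc<RawSocket>, Error> {
    let socket = RawSocket::bind()?;
    socket.set_option(&SocketOption::BindInterface(SocketInterfaceQuery::ByName(name)))?;
    Ok(socket)
}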
impl TcpSocket {
pub fn connect(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
block!(Self::connect_async(remote).await)?
}
pub fn accept_remote(
listener: Arc<TcpListener>,
local: SocketAddr,
remote: SocketAddr,
remote_window_size: usize,
tx_seq: u32,
rx_seq: u32,
) -> Result<Arc<TcpSocket>, Error> {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with(local, remote, move || {
let connection = TcpConnection::new(
local,
remote,
remote_window_size,
tx_seq,
rx_seq,
TcpConnectionState::SynReceived,
);
log::debug!("Accepted TCP socket {} -> {}", local, remote);
let socket = Self {
local,
remote,
listener: Some(listener),
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})
}
pub fn connection(&self) -> &IrqSafeRwLock<TcpConnection> {
&self.connection
}
pub(crate) fn accept(self: &Arc<Self>) {
if let Some(listener) = self.listener.as_ref() {
listener.accept_socket(self.clone());
}
}
pub fn get(local: SocketAddr, remote: SocketAddr) -> Option<Arc<Self>> {
TCP_SOCKETS.read().get(local, remote)
}
pub fn receive_async<'a>(
&'a self,
buffer: &'a mut [u8],
) -> impl Future<Output = Result<usize, Error>> + 'a {
// TODO timeout here
// TODO don't throw ConnectionReset immediately
struct F<'f> {
socket: &'f TcpSocket,
buffer: &'f mut [u8],
}
impl<'f> Future for F<'f> {
type Output = Result<usize, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.socket.poll_receive(cx) {
Poll::Ready(Ok(mut lock)) => Poll::Ready(lock.read_nonblocking(self.buffer)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
}
F {
socket: self,
buffer,
}
}
pub async fn send_async(&self, data: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
let mut rem = data.len();
while rem != 0 {
// TODO check MTU
let amount = rem.min(512);
self.send_segment_async(&data[pos..pos + amount]).await?;
pos += amount;
rem -= amount;
}
Ok(pos)
}
pub async fn close_async(&self, remove_from_listener: bool) -> Result<(), Error> {
// TODO timeout here
// Already closing
if self.connection.read().is_closing() {
return Ok(());
}
// Wait for all sent data to be acknowledged
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.finish().await?;
}
log::debug!(
"TCP socket closed (FinWait2/Closed): {} <-> {}",
self.local,
self.remote
);
// Wait for connection to get closed
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_finish(cx)
})
.await;
if remove_from_listener {
if let Some(listener) = self.listener.as_ref() {
listener.remove_socket(self.remote);
};
}
Ok(())
}
pub(crate) fn remove_socket(&self) -> Result<(), Error> {
log::debug!(
"TCP socket closed and removed: {} <-> {}",
self.local,
self.remote
);
let connection = self.connection.read();
debug_assert!(connection.is_closed());
TCP_SOCKETS.write().remove(self.local, self.remote)?;
connection.notify_all();
Ok(())
}
fn poll_receive(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<IrqSafeRwLockWriteGuard<TcpConnection>, Error>> {
let lock = self.connection.write();
match lock.poll_receive(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(lock)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
async fn send_segment_async(&self, data: &[u8]) -> Result<(), Error> {
// TODO timeout here
{
let mut connection = poll_fn(|cx| {
let connection = self.connection.write();
match connection.poll_send(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(connection)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
})
.await?;
connection.transmit(data).await?;
}
poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_acknowledge(cx)
})
.await;
Ok(())
}
async fn connect_async(remote: SocketAddr) -> Result<(SocketAddr, Arc<TcpSocket>), Error> {
// Lookup route to remote
let (interface_id, _, remote_ip) =
Route::lookup(remote.ip()).ok_or(Error::HostUnreachable)?;
let remote = SocketAddr::new(remote_ip, remote.port());
let interface = NetworkInterface::get(interface_id)?;
let local_ip = interface.address.read().ok_or(Error::NetworkUnreachable)?;
let socket = {
let mut sockets = TCP_SOCKETS.write();
sockets.try_insert_with_ephemeral_port(local_ip, remote, |port| {
let t = monotonic_timestamp()?;
let tx_seq = t.as_micros() as u32;
let local = SocketAddr::new(local_ip, port);
let connection =
TcpConnection::new(local, remote, 16384, tx_seq, 0, TcpConnectionState::Closed);
let socket = Self {
local,
remote,
listener: None,
connection: IrqSafeRwLock::new(connection),
};
Ok(Arc::new(socket))
})?
};
let mut t = 200;
for _ in 0..5 {
let timeout = Duration::from_millis(t);
log::debug!("Try SYN with timeout={:?}", timeout);
match socket.try_connect(timeout).await {
Ok(()) => return Ok((socket.local, socket)),
Err(Error::TimedOut) => (),
Err(error) => return Err(error),
}
t *= 2;
}
// Couldn't establish
Err(Error::TimedOut)
}
async fn try_connect(&self, timeout: Duration) -> Result<(), Error> {
{
let mut connection = self.connection.write();
connection.send_syn().await?;
}
let fut = poll_fn(|cx| {
let connection = self.connection.read();
connection.poll_established(cx)
});
match run_with_timeout(timeout, fut).await {
FutureTimeout::Ok(value) => value,
FutureTimeout::Timeout => Err(Error::TimedOut),
}
}
}
impl Socket for TcpSocket {
fn local_address(&self) -> SocketAddr {
self.local
}
fn remote_address(&self) -> Option<SocketAddr> {
Some(self.remote)
}
fn close(&self) -> Result<(), Error> {
block!(self.close_async(true).await)?
}
}
impl FileReadiness for TcpSocket {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_receive(cx).map_ok(|_| ())
}
}
impl ConnectionSocket for TcpSocket {
fn receive(&self, buffer: &mut [u8]) -> Result<usize, Error> {
block!(self.receive_async(buffer).await)?
}
fn send(&self, data: &[u8]) -> Result<usize, Error> {
block!(self.send_async(data).await)?
}
}
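// Editor's sketch (not part of the original commit): a blocking client
// round-trip built from the pieces above. `connect` resolves a route, picks
// an ephemeral port and retries the SYN with growing timeouts internally.
fn demo_tcp_client(remote: SocketAddr, request: &[u8], response: &mut [u8]) -> Result<usize, Error> {
    let (_local, socket) = TcpSocket::connect(remote)?;
    socket.send(request)?;
    let len = socket.receive(response)?;
    socket.close()?;
    Ok(len)
}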
impl TcpListener {
pub fn bind(accept: SocketAddr) -> Result<Arc<Self>, Error> {
TCP_LISTENERS.write().try_insert_with(accept, || {
let listener = TcpListener {
accept,
sockets: IrqSafeRwLock::new(BTreeMap::new()),
pending_accept: IrqSafeSpinlock::new(Vec::new()),
accept_notify: QueueWaker::new(),
};
log::debug!("TCP Listener opened: {}", accept);
Ok(Arc::new(listener))
})
}
pub fn get(local: SocketAddr) -> Option<Arc<Self>> {
TCP_LISTENERS.read().get(&local)
}
pub fn accept_async(&self) -> impl Future<Output = Result<Arc<TcpSocket>, Error>> + '_ {
struct F<'f> {
listener: &'f TcpListener,
}
impl<'f> Future for F<'f> {
type Output = Result<Arc<TcpSocket>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.listener.poll_accept(cx) {
Poll::Ready(mut lock) => Poll::Ready(Ok(lock.pop().unwrap())),
Poll::Pending => Poll::Pending,
}
}
}
F { listener: self }
}
fn accept_socket(&self, socket: Arc<TcpSocket>) {
log::debug!("{}: accept {}", self.accept, socket.remote);
self.sockets.write().insert(socket.remote, socket.clone());
self.pending_accept.lock().push(socket);
self.accept_notify.wake_one();
}
fn remove_socket(&self, remote: SocketAddr) {
log::debug!("Remove client {}", remote);
self.sockets.write().remove(&remote);
}
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<IrqSafeSpinlockGuard<Vec<Arc<TcpSocket>>>> {
let lock = self.pending_accept.lock();
self.accept_notify.register(cx.waker());
if !lock.is_empty() {
self.accept_notify.remove(cx.waker());
Poll::Ready(lock)
} else {
Poll::Pending
}
}
}
impl Socket for TcpListener {
fn local_address(&self) -> SocketAddr {
self.accept
}
fn remote_address(&self) -> Option<SocketAddr> {
None
}
fn close(&self) -> Result<(), Error> {
// TODO if clients not closed already, send RST?
TCP_LISTENERS.write().remove(self.accept)
}
}
impl FileReadiness for TcpListener {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_accept(cx).map(|_| Ok(()))
}
}
impl ListenerSocket for TcpListener {
fn accept(&self) -> Result<(SocketAddr, Arc<dyn ConnectionSocket>), Error> {
let socket = block!(self.accept_async().await)??;
let remote = socket.remote;
Ok((remote, socket))
}
}
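// Editor's sketch (not part of the original commit): a minimal async accept
// loop. `accept_async` resolves once `accept_socket` queues a connection that
// has completed (or is completing) its handshake.
async fn demo_accept_loop(listener: Arc<TcpListener>) -> Result<(), Error> {
    loop {
        let socket = listener.accept_async().await?;
        log::debug!("accepted {}", socket.remote);
        // A real server would spawn a per-connection task here.
    }
}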

View File

@ -0,0 +1,292 @@
// This TCP reassembler was taken from smoltcp-rs/smoltcp:
//
// https://github.com/smoltcp-rs/smoltcp
use core::fmt;
pub const ASSEMBLER_MAX_SEGMENT_COUNT: usize = 32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TooManyHolesError;
impl fmt::Display for TooManyHolesError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "too many holes")
}
}
/// A contiguous chunk of absent data, followed by a contiguous chunk of present data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Contig {
hole_size: usize,
data_size: usize,
}
impl fmt::Display for Contig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.has_hole() {
write!(f, "({})", self.hole_size)?;
}
if self.has_hole() && self.has_data() {
write!(f, " ")?;
}
if self.has_data() {
write!(f, "{}", self.data_size)?;
}
Ok(())
}
}
impl Contig {
const fn empty() -> Contig {
Contig {
hole_size: 0,
data_size: 0,
}
}
fn hole_and_data(hole_size: usize, data_size: usize) -> Contig {
Contig {
hole_size,
data_size,
}
}
fn has_hole(&self) -> bool {
self.hole_size != 0
}
fn has_data(&self) -> bool {
self.data_size != 0
}
fn total_size(&self) -> usize {
self.hole_size + self.data_size
}
fn shrink_hole_by(&mut self, size: usize) {
self.hole_size -= size;
}
fn shrink_hole_to(&mut self, size: usize) {
debug_assert!(self.hole_size >= size);
let total_size = self.total_size();
self.hole_size = size;
self.data_size = total_size - size;
}
}
/// A buffer (re)assembler.
///
/// Currently, up to a hardcoded limit of `ASSEMBLER_MAX_SEGMENT_COUNT` (32) holes can be tracked in the buffer.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Assembler {
contigs: [Contig; ASSEMBLER_MAX_SEGMENT_COUNT],
}
impl fmt::Display for Assembler {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "[ ")?;
for contig in self.contigs.iter() {
if !contig.has_data() {
break;
}
write!(f, "{contig} ")?;
}
write!(f, "]")?;
Ok(())
}
}
// Invariant on Assembler::contigs:
// - There's an index `i` where all contigs before have data, and all contigs after don't (are unused).
// - All contigs with data must have hole_size != 0, except the first.
impl Assembler {
/// Create a new buffer assembler.
pub const fn new() -> Assembler {
const EMPTY: Contig = Contig::empty();
Assembler {
contigs: [EMPTY; ASSEMBLER_MAX_SEGMENT_COUNT],
}
}
pub fn clear(&mut self) {
self.contigs.fill(Contig::empty());
}
fn front(&self) -> Contig {
self.contigs[0]
}
/// Return length of the front contiguous range without removing it from the assembler
pub fn peek_front(&self) -> usize {
let front = self.front();
if front.has_hole() {
0
} else {
front.data_size
}
}
fn back(&self) -> Contig {
self.contigs[self.contigs.len() - 1]
}
/// Return whether the assembler contains no data.
pub fn is_empty(&self) -> bool {
!self.front().has_data()
}
/// Remove a contig at the given index.
fn remove_contig_at(&mut self, at: usize) {
debug_assert!(self.contigs[at].has_data());
for i in at..self.contigs.len() - 1 {
if !self.contigs[i].has_data() {
return;
}
self.contigs[i] = self.contigs[i + 1];
}
// Removing the last one.
self.contigs[self.contigs.len() - 1] = Contig::empty();
}
/// Add a contig at the given index, and return a pointer to it.
fn add_contig_at(&mut self, at: usize) -> Result<&mut Contig, TooManyHolesError> {
if self.back().has_data() {
return Err(TooManyHolesError);
}
for i in (at + 1..self.contigs.len()).rev() {
self.contigs[i] = self.contigs[i - 1];
}
self.contigs[at] = Contig::empty();
Ok(&mut self.contigs[at])
}
/// Add a new contiguous range to the assembler,
/// or return `Err(TooManyHolesError)` if too many discontinuities are already recorded.
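///
/// Editor's example (illustrative, not part of the original commit): on an
/// empty assembler, `add(4, 4)` records a 4-byte hole followed by 4 bytes of
/// data (`[ (4)4 ]`); a later `add(0, 4)` fills the hole and the two ranges
/// coalesce into `[ 8 ]`.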
pub fn add(&mut self, mut offset: usize, size: usize) -> Result<(), TooManyHolesError> {
if size == 0 {
return Ok(());
}
let mut i = 0;
// Find index of the contig containing the start of the range.
loop {
if i == self.contigs.len() {
// The new range is after all the previous ranges, but there's no space to add it.
return Err(TooManyHolesError);
}
let contig = &mut self.contigs[i];
if !contig.has_data() {
// The new range is after all the previous ranges. Add it.
*contig = Contig::hole_and_data(offset, size);
return Ok(());
}
if offset <= contig.total_size() {
break;
}
offset -= contig.total_size();
i += 1;
}
let contig = &mut self.contigs[i];
if offset < contig.hole_size {
// Range starts within the hole.
if offset + size < contig.hole_size {
// Range also ends within the hole.
let new_contig = self.add_contig_at(i)?;
new_contig.hole_size = offset;
new_contig.data_size = size;
// Previous contigs[index] got moved to contigs[index+1]
self.contigs[i + 1].shrink_hole_by(offset + size);
return Ok(());
}
// The range being added covers both a part of the hole and a part of the data
// in this contig, shrink the hole in this contig.
contig.shrink_hole_to(offset);
}
// Coalesce contigs to the right.
let mut j = i + 1;
while j < self.contigs.len()
&& self.contigs[j].has_data()
&& offset + size >= self.contigs[i].total_size() + self.contigs[j].hole_size
{
self.contigs[i].data_size += self.contigs[j].total_size();
j += 1;
}
let shift = j - i - 1;
if shift != 0 {
for x in i + 1..self.contigs.len() {
if !self.contigs[x].has_data() {
break;
}
self.contigs[x] = self
.contigs
.get(x + shift)
.copied()
.unwrap_or_else(Contig::empty);
}
}
if offset + size > self.contigs[i].total_size() {
// The added range still extends beyond the current contig. Increase data size.
let left = offset + size - self.contigs[i].total_size();
self.contigs[i].data_size += left;
// Decrease hole size of the next, if any.
if i + 1 < self.contigs.len() && self.contigs[i + 1].has_data() {
self.contigs[i + 1].hole_size -= left;
}
}
Ok(())
}
/// Remove a contiguous range from the front of the assembler.
/// If no such range, return 0.
pub fn remove_front(&mut self) -> usize {
let front = self.front();
if front.has_hole() || !front.has_data() {
0
} else {
self.remove_contig_at(0);
debug_assert!(front.data_size > 0);
front.data_size
}
}
/// Add a segment, then remove_front.
///
/// This is equivalent to calling `add` then `remove_front` individually,
/// except it's guaranteed to not fail when offset = 0.
/// This is required for TCP: we must never drop the next expected segment, or
/// the protocol might get stuck.
pub fn add_then_remove_front(
&mut self,
offset: usize,
size: usize,
) -> Result<usize, TooManyHolesError> {
// This is the only case where a segment at offset=0 would cause the
// total amount of contigs to rise (and therefore can potentially cause
// a TooManyHolesError). Handle it in a way that is guaranteed to succeed.
if offset == 0 && size < self.contigs[0].hole_size {
self.contigs[0].hole_size -= size;
return Ok(size);
}
self.add(offset, size)?;
Ok(self.remove_front())
}
}
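// Editor's addition (sketch, not part of the original commit): a couple of
// sanity tests for the reassembler, runnable under a hosted test harness.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fills_hole_and_coalesces() {
        let mut asm = Assembler::new();
        asm.add(4, 4).unwrap();
        assert_eq!(asm.remove_front(), 0); // front is still a hole
        asm.add(0, 4).unwrap();
        assert_eq!(asm.remove_front(), 8); // hole filled, ranges coalesced
    }

    #[test]
    fn never_drops_in_order_segments() {
        let mut asm = Assembler::new();
        // Fill every slot with isolated segments to exhaust the contig array.
        for i in 0..ASSEMBLER_MAX_SEGMENT_COUNT {
            asm.add(i * 4 + 2, 1).unwrap();
        }
        // A further out-of-order segment fails...
        assert!(asm.add(ASSEMBLER_MAX_SEGMENT_COUNT * 4 + 2, 1).is_err());
        // ...but the next in-order segment (offset 0) still succeeds.
        assert_eq!(asm.add_then_remove_front(0, 1).unwrap(), 1);
    }
}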
