virtio: add virtio-blk, rework virtio queues

Mark Poliakov 2025-02-15 16:41:26 +02:00
parent 99f5ad0987
commit f716c50988
18 changed files with 1149 additions and 470 deletions

Cargo.lock (generated)

@ -1145,6 +1145,7 @@ dependencies = [
name = "kernel-arch-hosted"
version = "0.1.0"
dependencies = [
"device-api",
"kernel-arch-interface",
"libk-mm-interface",
"yggdrasil-abi",
@ -2867,12 +2868,30 @@ dependencies = [
"yggdrasil-abi",
]
[[package]]
name = "ygg_driver_virtio_blk"
version = "0.1.0"
dependencies = [
"async-trait",
"bytemuck",
"device-api",
"libk",
"libk-mm",
"libk-util",
"log",
"tock-registers",
"ygg_driver_pci",
"ygg_driver_virtio_core",
"yggdrasil-abi",
]
[[package]]
name = "ygg_driver_virtio_core"
version = "0.1.0"
dependencies = [
"bitflags 2.8.0",
"device-api",
"kernel-arch-hosted",
"libk",
"libk-mm",
"libk-util",
@ -2904,6 +2923,7 @@ dependencies = [
"bitflags 2.8.0",
"bytemuck",
"device-api",
"futures-util",
"libk",
"libk-mm",
"libk-util",
@ -2983,6 +3003,7 @@ dependencies = [
"ygg_driver_pci",
"ygg_driver_usb",
"ygg_driver_usb_xhci",
"ygg_driver_virtio_blk",
"ygg_driver_virtio_gpu",
"ygg_driver_virtio_net",
"yggdrasil-abi",


@ -30,6 +30,7 @@ ygg_driver_net_core = { path = "driver/net/core" }
ygg_driver_net_loopback = { path = "driver/net/loopback" }
ygg_driver_virtio_net = { path = "driver/virtio/net", features = ["pci"] }
ygg_driver_virtio_gpu = { path = "driver/virtio/gpu", features = ["pci"] }
ygg_driver_virtio_blk = { path = "driver/virtio/blk", features = ["pci"] }
ygg_driver_nvme = { path = "driver/block/nvme" }
ygg_driver_ahci = { path = "driver/block/ahci" }
ygg_driver_input = { path = "driver/input" }


@ -7,3 +7,4 @@ edition = "2021"
kernel-arch-interface.workspace = true
yggdrasil-abi.workspace = true
libk-mm-interface.workspace = true
device-api.workspace = true


@ -1,9 +1,11 @@
#![feature(never_type)]
#![feature(never_type, allocator_api, slice_ptr_get)]
use std::{
alloc::{Allocator, Global, Layout},
marker::PhantomData,
sync::atomic::{AtomicBool, Ordering},
};
use device_api::dma::{DmaAllocation, DmaAllocator};
use kernel_arch_interface::{
cpu::{CpuData, IpiQueue},
mem::{
@ -105,6 +107,14 @@ impl Architecture for ArchitectureImpl {
fn ipi_queue(_cpu_id: u32) -> Option<&'static IpiQueue<Self>> {
None
}
fn load_barrier() {}
fn store_barrier() {}
fn memory_barrier() {}
fn flush_virtual_range(_range: std::ops::Range<usize>) {}
}
impl KernelTableManager for KernelTableManagerImpl {
@ -202,3 +212,19 @@ impl<K: KernelTableManager, PA: PhysicalMemoryAllocator> TaskContext<K, PA>
extern "Rust" fn __signal_process_group(_group_id: ProcessGroupId, _signal: Signal) {
unimplemented!()
}
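/// Test-only DMA allocator for the hosted architecture: memory comes from the
/// host heap, and the host virtual address doubles as the "physical" and bus
/// address, since no real device sits behind it.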
pub struct HostedDmaAllocator;
impl DmaAllocator for HostedDmaAllocator {
fn allocate(&self, layout: Layout) -> Result<DmaAllocation, Error> {
let ptr = Global.allocate(layout.align_to(0x1000).unwrap()).unwrap();
let base = ptr.as_non_null_ptr();
let addr: usize = base.addr().into();
Ok(DmaAllocation {
host_virtual: base.cast(),
host_physical: addr as _,
page_count: layout.size().div_ceil(0x1000),
bus_address: addr as _,
})
}
}


@ -0,0 +1,23 @@
[package]
name = "ygg_driver_virtio_blk"
version = "0.1.0"
edition = "2024"
[dependencies]
yggdrasil-abi.workspace = true
libk-util.workspace = true
libk-mm.workspace = true
libk.workspace = true
device-api = { workspace = true, features = ["derive"] }
ygg_driver_virtio_core = { path = "../core" }
ygg_driver_pci = { path = "../../bus/pci", optional = true }
log.workspace = true
bytemuck.workspace = true
tock-registers.workspace = true
async-trait.workspace = true
[features]
default = []
pci = ["ygg_driver_pci", "ygg_driver_virtio_core/pci"]


@ -0,0 +1,403 @@
#![no_std]
use core::mem::MaybeUninit;
use alloc::{boxed::Box, format, sync::Arc, vec::Vec};
use async_trait::async_trait;
use bytemuck::{Pod, Zeroable};
use device_api::{
device::{Device, DeviceInitContext},
dma::DmaAllocator,
interrupt::{InterruptAffinity, InterruptHandler, IrqVector},
};
use libk::{
device::{block::BlockDevice, manager::probe_partitions},
dma::{DmaBuffer, DmaSlice, DmaSliceMut},
error::Error,
fs::devfs,
task::runtime,
};
use libk_mm::{address::PhysicalAddress, table::MapAttributes, PageProvider};
use libk_util::sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock};
use ygg_driver_pci::{
device::{PciDeviceInfo, PreferredInterruptMode},
macros::pci_driver,
};
use ygg_driver_virtio_core::{
queue::VirtQueue,
transport::{pci::PciTransport, Transport},
DeviceStatus,
};
use yggdrasil_abi::{bitflags, io::FileMode};
extern crate alloc;
bitflags! {
pub struct Features: u64 {
const F_SIZE_MAX = 1 << 1;
const F_SEG_MAX = 1 << 2;
const F_RO = 1 << 5;
}
}
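// Mirrors struct virtio_blk_config from the virtio 1.x specification.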
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct DeviceConfig {
capacity: u64,
size_max: u32,
seg_max: u32,
// virtio_blk_geometry {
cylinders: u16,
heads: u8,
sectors: u8,
// }
blk_size: u32,
// virtio_blk_topology {
physical_block_exp: u8,
alignment_offset: u8,
min_io_size: u16,
opt_io_size: u32,
// }
writeback: u8,
_0: u8,
num_queues: u16,
max_discard_sectors: u32,
max_discard_seg: u32,
discard_sector_alignment: u32,
max_write_zeroes_sectors: u32,
max_write_zeroes_seg: u32,
write_zeroes_may_unmap: u8,
_1: [u8; 3],
max_secure_erase_sectors: u32,
max_secure_erase_seg: u32,
secure_erase_sector_alignment: u32,
}
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct CommandHeader {
ty: u32,
_0: u32,
sector: u64,
}
pub struct VirtioBlk<T: Transport + 'static> {
transport: IrqSafeSpinlock<T>,
pci_device_info: PciDeviceInfo,
dma: Arc<dyn DmaAllocator>,
segment_size: usize,
read_only: bool,
capacity: u64,
request_queue: VirtQueue,
}
impl CommandHeader {
const TYPE_READ: u32 = 0;
const TYPE_WRITE: u32 = 1;
pub fn for_read(lba: u64) -> Self {
Self {
ty: Self::TYPE_READ,
_0: 0,
sector: lba,
}
}
pub fn for_write(lba: u64) -> Self {
Self {
ty: Self::TYPE_WRITE,
_0: 0,
sector: lba,
}
}
}
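// Each request travels as one descriptor chain: the 16-byte header above
// (driver->device), the data buffer(s), and a one-byte status written back by
// the device (0 = OK). read_aligned()/write_aligned() below build this chain.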
impl<T: Transport + 'static> VirtioBlk<T> {
// Only one VQ
const VQ_REQUEST_0: u16 = 0;
fn new(
dma: Arc<dyn DmaAllocator>,
mut transport: T,
pci_device_info: PciDeviceInfo,
) -> Result<Self, Error> {
let features = Features::from(transport.read_device_features());
let device_cfg = transport
.device_cfg()
.ok_or(Error::InvalidArgument)
.inspect_err(|_| log::error!("virtio-blk does not expose device configuration"))?;
let device_cfg: &DeviceConfig =
bytemuck::from_bytes(&device_cfg[..size_of::<DeviceConfig>()]);
let read_only = features.contains(Features::F_RO);
let segment_size = if features.contains(Features::F_SIZE_MAX) {
device_cfg.size_max as usize
} else {
// Device reports no maximum segment size; default to 256 KiB segments
262144
};
let segment_limit = if features.contains(Features::F_SEG_MAX) {
device_cfg.seg_max as usize
} else {
// Device reports no maximum segment count; fall back to a conservative default
8
};
let capacity = device_cfg.capacity;
if segment_limit < 3 {
// Won't be able to send header + data + status
log::error!("virtio-blk: allowed segment count too small");
return Err(Error::InvalidArgument);
}
let request_queue = VirtQueue::with_capacity(&*dma, Self::VQ_REQUEST_0, 256)?;
Ok(Self {
transport: IrqSafeSpinlock::new(transport),
dma,
pci_device_info,
segment_size,
read_only,
capacity,
request_queue,
})
}
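// Standard virtio device initialization handshake: reset, ACKNOWLEDGE,
// DRIVER, feature negotiation, FEATURES_OK (verified by re-reading the
// status), and finally DRIVER_OK in finish_init() once queues are set up.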
fn begin_init(&self) -> Result<DeviceStatus, Error> {
let mut transport = self.transport.lock();
let mut status = DeviceStatus::RESET_VALUE;
log::debug!("Reset device");
transport.write_device_status(status);
status |= DeviceStatus::ACKNOWLEDGE;
transport.write_device_status(status);
status |= DeviceStatus::DRIVER;
transport.write_device_status(status);
let _device_features = transport.read_device_features();
// TODO: negotiate features properly; for now accept none of the device's features
transport.write_driver_features(0);
status |= DeviceStatus::FEATURES_OK;
transport.write_device_status(status);
if !transport
.read_device_status()
.contains(DeviceStatus::FEATURES_OK)
{
return Err(Error::InvalidOperation);
}
Ok(status)
}
fn finish_init(&self, status: DeviceStatus) {
let mut transport = self.transport.lock();
transport.write_device_status(status | DeviceStatus::DRIVER_OK);
}
fn setup_queues(self: &Arc<Self>) -> Result<(), Error> {
self.pci_device_info
.init_interrupts(PreferredInterruptMode::Msi(true))?;
let msi_info = self
.pci_device_info
.map_interrupt(InterruptAffinity::Any, self.clone())?;
let vector = msi_info.map(|msi| msi.vector as u16);
let mut transport = self.transport.lock();
transport.set_queue(Self::VQ_REQUEST_0, &self.request_queue, vector);
Ok(())
}
}
impl<T: Transport + 'static> InterruptHandler for VirtioBlk<T> {
fn handle_irq(self: Arc<Self>, _vector: IrqVector) -> bool {
// Only one queue
self.request_queue.handle_notify();
true
}
}
impl<T: Transport + 'static> Device for VirtioBlk<T> {
unsafe fn init(self: Arc<Self>, _cx: DeviceInitContext) -> Result<(), Error> {
let status = self.begin_init()?;
self.setup_queues()?;
self.finish_init(status);
register_virtio_block_device(self.clone());
Ok(())
}
fn display_name(&self) -> &str {
"VirtIO Block Device"
}
}
#[async_trait]
impl<T: Transport + 'static> BlockDevice for VirtioBlk<T> {
fn allocate_buffer(&self, size: usize) -> Result<DmaBuffer<[MaybeUninit<u8>]>, Error> {
DmaBuffer::new_uninit_slice(&*self.dma, size)
}
async fn read_aligned(
&self,
position: u64,
buffer: DmaSliceMut<'_, MaybeUninit<u8>>,
) -> Result<(), Error> {
if position % 512 != 0 || buffer.len() % 512 != 0 {
return Err(Error::InvalidArgument);
}
let lba = position / 512;
let lba_count = buffer.len() / 512;
if lba + lba_count as u64 >= self.capacity {
return Err(Error::InvalidArgument);
}
let mut header = DmaBuffer::new_slice(&*self.dma, 0, size_of::<CommandHeader>())?;
*bytemuck::from_bytes_mut(&mut header[..]) = CommandHeader::for_read(lba);
let mut status = DmaBuffer::new_uninit_slice(&*self.dma, 1)?;
self.request_queue
.enqueue_wait(
&[header.slice(0..size_of::<CommandHeader>())],
&[buffer, status.slice_mut(0..1)],
|| {
self.transport.lock().notify(Self::VQ_REQUEST_0);
},
)
.await?;
let status = unsafe { DmaBuffer::assume_init_slice(status) }[0];
if status == 0 {
Ok(())
} else {
Err(Error::InvalidOperation)
}
}
async fn write_aligned(&self, position: u64, buffer: DmaSlice<'_, u8>) -> Result<(), Error> {
if self.read_only {
return Err(Error::ReadOnly);
}
if position % 512 != 0 || buffer.len() % 512 != 0 {
return Err(Error::InvalidArgument);
}
let lba = position / 512;
let lba_count = buffer.len() / 512;
if lba + lba_count as u64 >= self.capacity {
return Err(Error::InvalidArgument);
}
let mut header = DmaBuffer::new_slice(&*self.dma, 0, size_of::<CommandHeader>())?;
*bytemuck::from_bytes_mut(&mut header[..]) = CommandHeader::for_write(lba);
let mut status = DmaBuffer::new_uninit_slice(&*self.dma, 1)?;
self.request_queue
.enqueue_wait(
&[header.slice(0..size_of::<CommandHeader>()), buffer],
&[status.slice_mut(0..1)],
|| {
self.transport.lock().notify(Self::VQ_REQUEST_0);
},
)
.await?;
let status = unsafe { DmaBuffer::assume_init_slice(status) }[0];
if status == 0 {
Ok(())
} else {
Err(Error::InvalidOperation)
}
}
fn block_size(&self) -> usize {
512
}
fn block_count(&self) -> u64 {
self.capacity
}
fn max_blocks_per_request(&self) -> usize {
// TODO: this limit can be raised by using scatter-gather (multiple data descriptors per request)
self.segment_size / 512
}
}
impl<T: Transport + 'static> PageProvider for VirtioBlk<T> {
fn get_page(&self, _offset: u64) -> Result<PhysicalAddress, Error> {
todo!()
}
fn clone_page(
&self,
_offset: u64,
_src_phys: PhysicalAddress,
_src_attrs: MapAttributes,
) -> Result<PhysicalAddress, Error> {
todo!()
}
fn release_page(&self, _offset: u64, _phys: PhysicalAddress) -> Result<(), Error> {
todo!()
}
}
static DEVICES: IrqSafeRwLock<Vec<Arc<dyn BlockDevice>>> = IrqSafeRwLock::new(Vec::new());
fn register_virtio_block_device(device: Arc<dyn BlockDevice>) {
let index = {
let mut devices = DEVICES.write();
let index = devices.len();
devices.push(device.clone());
index
};
let name = format!("vb{index}");
devfs::add_named_block_device(device.clone(), name.clone(), FileMode::new(0o600)).ok();
runtime::spawn(async move {
let name = name;
log::info!("Probing partitions for {name}");
probe_partitions(device, |index, partition| {
let partition_name = format!("{name}p{}", index + 1);
devfs::add_named_block_device(
Arc::new(partition),
partition_name,
FileMode::new(0o600),
)
.ok();
})
.await
.ok();
})
.ok();
}
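// PCI match: vendor 0x1AF4 (virtio) : device 0x1001 (transitional virtio-blk).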
pci_driver! {
matches: [device (0x1AF4:0x1001)],
driver: {
fn probe(
&self,
info: &PciDeviceInfo,
dma: &Arc<dyn DmaAllocator>,
) -> Result<Arc<dyn Device>, Error> {
let space = &info.config_space;
let transport = PciTransport::from_config_space(space).unwrap();
let device = VirtioBlk::new(dma.clone(), transport, info.clone())?;
let device = Arc::new(device);
Ok(device)
}
fn driver_name(&self) -> &str {
"virtio-blk"
}
}
}


@ -16,6 +16,12 @@ log.workspace = true
bitflags.workspace = true
tock-registers.workspace = true
[dev-dependencies]
kernel-arch-hosted.path = "../../../arch/hosted"
[features]
default = []
pci = ["ygg_driver_pci"]
[lints]
workspace = true


@ -1,361 +1,536 @@
//! VirtIO queue implementation.
//!
//! # Note
//!
//! The code is poorly borrowed from the `virtio-drivers` crate. I want to rewrite it properly myself.
use core::{
mem::MaybeUninit,
sync::atomic::{fence, Ordering},
use core::{future::poll_fn, mem::MaybeUninit, task::Poll, time::Duration};
use alloc::{boxed::Box, sync::Arc};
use device_api::dma::DmaAllocator;
use libk::{
dma::{BusAddress, DmaBuffer, DmaSlice, DmaSliceMut},
error::Error,
task::runtime::psleep,
};
use libk_util::{
event::OneTimeEvent, hash_table::DefaultHashTable, sync::IrqSafeSpinlock, waker::QueueWaker,
};
use device_api::dma::DmaAllocator;
use libk::dma::{BusAddress, DmaBuffer};
use crate::{error::Error, transport::Transport};
#[derive(Debug)]
#[derive(Clone, Copy, Debug)]
#[repr(C)]
struct Descriptor {
address: BusAddress,
len: u32,
length: u32,
flags: u16,
next: u16,
}
// Layout:
// {
// flags: u16,
// idx: u16,
// ring: [u16; QUEUE_SIZE],
// used_event: u16
// }
struct AvailableRing {
data: DmaBuffer<[MaybeUninit<u16>]>,
mapping: DmaBuffer<[u16]>,
capacity: usize,
}
// Layout:
// {
// flags: u16,
// idx: u16,
// ring: [UsedElem; QUEUE_SIZE],
// avail_event: u16,
// _pad: u16
// }
struct UsedRing {
data: DmaBuffer<[MaybeUninit<u32>]>,
used_count: usize,
mapping: DmaBuffer<[u32]>,
last_seen_used: u16,
capacity: usize,
}
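// For reference (illustrative helpers, not part of this change): the byte
// sizes of the three split-ring regions follow from the layouts above.
const fn descriptor_table_bytes(qsz: usize) -> usize {
    // Each Descriptor is 16 bytes: u64 address + u32 length + u16 flags + u16 next.
    qsz * 16
}
const fn available_ring_bytes(qsz: usize) -> usize {
    2 + 2 + 2 * qsz + 2 // flags + idx + ring[qsz] + used_event
}
const fn used_ring_bytes(qsz: usize) -> usize {
    2 + 2 + 8 * qsz + 2 // flags + idx + ring[qsz] of { u32 id, u32 len } + avail_event
}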
pub struct VirtQueue {
descriptor_table: DmaBuffer<[MaybeUninit<Descriptor>]>,
pub struct DescriptorTable {
descriptors: DmaBuffer<[Descriptor]>,
free_count: usize,
first_free: Option<u16>,
last_free: Option<u16>,
}
struct VqInner {
descriptors: DescriptorTable,
available: AvailableRing,
used: UsedRing,
}
pub struct VirtQueue<N: VqNotificationMechanism = VqAsyncNotification> {
inner: IrqSafeSpinlock<VqInner>,
free_descriptor_notify: QueueWaker,
#[allow(unused)]
index: u16,
capacity: usize,
queue_index: u16,
free_head: u16,
used_notify: N,
}
avail_idx: u16,
last_used_idx: u16,
pub trait VqNotificationMechanism {
type Token;
msix_vector: u16,
fn notify_used(&self, head: u16, length: u32);
fn create_token(&self, head: u16) -> Self::Token;
}
pub struct VqAsyncNotification {
completions: IrqSafeSpinlock<DefaultHashTable<u16, Arc<OneTimeEvent<u32>>>>,
}
pub struct VqManualNotification;
pub struct VqCallbackNotification(Box<dyn Fn(u16, u32) + Sync + Send>);
impl VqNotificationMechanism for VqAsyncNotification {
type Token = Arc<OneTimeEvent<u32>>;
fn notify_used(&self, head: u16, length: u32) {
let mut completions = self.completions.lock();
if let Some(completion) = completions.remove(&head) {
log::trace!("vq: completion #{head}");
completion.signal(length);
}
}
fn create_token(&self, head: u16) -> Self::Token {
let mut completions = self.completions.lock();
let token = Arc::new(OneTimeEvent::new());
completions.insert(head, token.clone());
token
}
}
impl VqNotificationMechanism for VqCallbackNotification {
type Token = u16;
fn create_token(&self, head: u16) -> Self::Token {
head
}
fn notify_used(&self, head: u16, length: u32) {
(self.0)(head, length);
}
}
impl VqNotificationMechanism for VqManualNotification {
type Token = u16;
fn create_token(&self, head: u16) -> Self::Token {
head
}
fn notify_used(&self, _head: u16, _length: u32) {
unreachable!()
}
}
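// Illustrative only (not part of this change): choosing a notification
// mechanism per queue. `dma` is any DmaAllocator; the indices, sizes and the
// callback body here are hypothetical.
fn example_queue_setup(dma: &dyn DmaAllocator) -> Result<(), Error> {
    // Async: callers await completion tokens (used by virtio-blk).
    let _request_vq = VirtQueue::with_capacity(dma, 0, 256)?;
    // Callback: invoked from handle_notify(), e.g. to recycle Tx buffers (virtio-net).
    let _tx_vq = VirtQueue::with_capacity_and_callback(
        dma,
        1,
        64,
        Box::new(|head, len| log::trace!("tx done: #{head}, len={len}")),
    )?;
    // Manual: the owner drains used entries itself via handle_notify_manual().
    let _rx_vq = VirtQueue::with_capacity_manual(dma, 2, 64)?;
    Ok(())
}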
impl AvailableRing {
const FLAGS: usize = 0;
const IDX: usize = 1;
const RING: usize = 2;
pub fn with_capacity(
dma: &dyn DmaAllocator,
no_irq: bool,
capacity: usize,
no_interrupt: bool,
) -> Result<Self, Error> {
let mut data = DmaBuffer::new_zeroed_slice(dma, capacity + 3)?;
if no_irq {
data[0].write(1);
// flags + idx + [ring] + used_event
let mut mapping = DmaBuffer::new_slice(dma, 0u16, (capacity + 6) & !3)?;
if no_interrupt {
mapping[Self::FLAGS] |= 1 << 0;
}
data[1].write(0);
Ok(Self { data })
Ok(Self { mapping, capacity })
}
pub fn set_head(&mut self, slot: u16, head: u16) {
self.data[slot as usize + 2].write(head);
}
pub fn set_index(&mut self, index: u16) {
self.data[1].write(index);
pub fn push(&mut self, head: u16) -> u16 {
log::trace!("enqueue #{head}");
let idx = self.mapping[Self::IDX];
let index = idx as usize % self.capacity;
self.mapping[Self::RING + index] = head;
self.mapping.cache_flush_element(Self::RING + index, true);
let idx = idx.wrapping_add(1);
self.mapping[Self::IDX] = idx;
self.mapping.cache_flush_element(Self::IDX, true);
idx
}
}
impl UsedRing {
pub fn with_capacity(dma: &dyn DmaAllocator, capacity: usize) -> Result<Self, Error> {
let mut data = DmaBuffer::new_zeroed_slice(dma, capacity * 2 + 2)?;
data[0].write(0);
const FLAGS_IDX: usize = 0;
const RING: usize = 1;
pub fn with_capacity(
dma: &dyn DmaAllocator,
capacity: usize,
no_notify: bool,
) -> Result<Self, Error> {
// flags + idx packed into one u32, then ring: 2 x u32 per entry (avail_event unused)
let mut mapping = DmaBuffer::new_slice(dma, 0, capacity * 2 + 1)?;
if no_notify {
mapping[Self::FLAGS_IDX] |= 1 << 0;
}
mapping.cache_flush_element(Self::FLAGS_IDX, true);
Ok(Self {
data,
used_count: 0,
mapping,
capacity,
last_seen_used: 0,
})
}
pub fn read_slot(&self, index: u16) -> (u32, u32) {
let index = unsafe { self.data[1 + index as usize * 2].assume_init() };
let len = unsafe { self.data[2 + index as usize * 2].assume_init() };
(index, len)
}
pub fn index(&self) -> u16 {
unsafe { (self.data[0].assume_init() >> 16) as u16 }
pub fn consume<F: FnMut(u16, u32)>(&mut self, mut handler: F) -> usize {
self.mapping.cache_flush_element(Self::FLAGS_IDX, false);
let idx = (self.mapping[Self::FLAGS_IDX] >> 16) as u16;
let mut count = 0;
while self.last_seen_used != idx {
let index = self.last_seen_used as usize % self.capacity;
self.mapping
.cache_flush_range(Self::RING + index * 2..Self::RING + index * 2 + 2, false);
let head = self.mapping[Self::RING + index * 2] as u16;
let len = self.mapping[Self::RING + index * 2 + 1];
handler(head, len);
count += 1;
self.last_seen_used = self.last_seen_used.wrapping_add(1);
}
count
}
}
impl VirtQueue {
pub fn with_capacity<T: Transport>(
transport: &mut T,
dma: &dyn DmaAllocator,
index: u16,
capacity: usize,
msix_vector: Option<u16>,
no_avail_irq: bool,
) -> Result<Self, Error> {
// TODO check if queue is already set up
impl Descriptor {
pub const EMPTY: Self = Self {
address: BusAddress::ZERO,
length: 0,
flags: 0,
next: 0,
};
let max_capacity = transport.max_queue_size(index);
pub const F_NEXT: u16 = 1 << 0;
pub const F_WRITE: u16 = 1 << 1;
}
if !capacity.is_power_of_two() || capacity > u16::MAX.into() {
return Err(Error::InvalidQueueSize);
impl DescriptorTable {
const FREE_CHAIN_END: u16 = u16::MAX;
pub fn with_capacity(dma: &dyn DmaAllocator, capacity: usize) -> Result<Self, Error> {
if capacity >= Self::FREE_CHAIN_END as usize - 1 {
return Err(Error::InvalidArgument);
}
if capacity > max_capacity as usize {
return Err(Error::QueueTooLarge);
let mut descriptors = DmaBuffer::new_slice(dma, Descriptor::EMPTY, capacity)?;
for i in 0..capacity {
if i == capacity - 1 {
// Last descriptor of the free chain
descriptors[i].next = Self::FREE_CHAIN_END;
} else {
descriptors[i].next = (i + 1) as u16;
}
}
let descriptor_table = DmaBuffer::new_zeroed_slice(dma, capacity)?;
let available = AvailableRing::with_capacity(dma, no_avail_irq, capacity)?;
let used = UsedRing::with_capacity(dma, capacity)?;
transport.set_queue(
index,
capacity as u16,
descriptor_table.bus_address(),
available.data.bus_address(),
used.data.bus_address(),
msix_vector,
);
Ok(Self {
descriptor_table,
available,
used,
capacity,
queue_index: index,
free_head: 0,
avail_idx: 0,
last_used_idx: 0,
msix_vector: msix_vector.unwrap_or(0xFFFF),
descriptors,
free_count: capacity,
first_free: Some(0),
last_free: Some(capacity as u16 - 1),
})
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn with_max_capacity<T: Transport>(
transport: &mut T,
dma: &dyn DmaAllocator,
index: u16,
capacity: usize,
msix_vector: Option<u16>,
no_avail_irq: bool,
) -> Result<Self, Error> {
let max_capacity = transport.max_queue_size(index);
let capacity = capacity.min(max_capacity as usize);
Self::with_capacity(transport, dma, index, capacity, msix_vector, no_avail_irq)
// Allocate a chain of descriptors
fn alloc_descriptors(
&mut self,
h2d: &[DmaSlice<u8>],
d2h: &[DmaSliceMut<MaybeUninit<u8>>],
) -> Result<u16, Error> {
if d2h.len() + h2d.len() == 0 {
// Empty transfer
return Err(Error::InvalidArgument);
}
if self.free_count < d2h.len() + h2d.len() {
// Not enough descriptor "slots" to place the buffers
return Err(Error::WouldBlock);
}
// Implied by free_count
debug_assert!(self.first_free.is_some());
debug_assert!(self.last_free.is_some());
let head = unsafe {
self.follow_descriptor_chain(h2d.len() + d2h.len(), |i, desc| {
if i < h2d.len() {
desc.address = h2d[i].bus_address();
desc.length = h2d[i].len() as u32;
desc.flags = 0;
} else {
let i = i - h2d.len();
desc.address = d2h[i].bus_address();
desc.length = d2h[i].len() as u32;
desc.flags = Descriptor::F_WRITE;
}
})
};
Ok(head)
}
/// # Safety
///
/// Invariants: DmaBuffer remains valid and allocated until it is properly dequeued.
pub unsafe fn add<'a, 'b>(
/// The following invariants must hold:
///
/// * **There must actually be `count` free descriptors**.
/// * **`count` should not be zero**.
/// * The VQ must be in a consistent state: first_free/last_free/free_count must be consistent
/// in terms of tracking the current state of the descriptor allocation.
/// * Free descriptors in the table must be chained.
/// * Last free descriptor must have [Self::FREE_CHAIN_END] as its `next` field.
unsafe fn follow_descriptor_chain<F: Fn(usize, &mut Descriptor)>(
&mut self,
input: &'a [&'b mut DmaBuffer<[MaybeUninit<u8>]>],
output: &'a [&'b DmaBuffer<[u8]>],
) -> Result<u16, Error> {
if input.is_empty() && output.is_empty() {
return Err(Error::EmptyTransaction);
}
let n_desc = input.len() + output.len();
if self.used.used_count + 1 > self.capacity || self.used.used_count + n_desc > self.capacity
{
return Err(Error::QueueFull);
}
let head = self.add_direct(input, output);
let avail_slot = self.avail_idx % self.capacity as u16;
self.available.set_head(avail_slot, head);
fence(Ordering::SeqCst);
self.avail_idx = self.avail_idx.wrapping_add(1);
self.available.set_index(self.avail_idx);
fence(Ordering::SeqCst);
Ok(head)
}
unsafe fn add_direct<'a, 'b>(
&mut self,
input: &'a [&'b mut DmaBuffer<[MaybeUninit<u8>]>],
output: &'a [&'b DmaBuffer<[u8]>],
count: usize,
visitor: F,
) -> u16 {
let head = self.free_head;
let mut last = self.free_head;
for item in output {
assert_ne!(item.len(), 0);
let desc = &mut self.descriptor_table[usize::from(self.free_head)];
let next = (self.free_head + 1) % self.capacity as u16;
desc.write(Descriptor {
address: item.bus_address(),
len: item.len().try_into().unwrap(),
// TODO
flags: (1 << 0),
next,
});
last = self.free_head;
self.free_head = next;
debug_assert_ne!(count, 0);
let mut current = self.first_free.unwrap_unchecked();
let head = current;
for i in 0..count {
let descriptor = &mut self.descriptors[current as usize];
log::trace!("vq: alloc desc #{current}");
visitor(i, descriptor);
if i == count - 1 {
debug_assert_eq!(descriptor.flags & Descriptor::F_NEXT, 0);
current = descriptor.next;
descriptor.next = 0;
} else {
current = descriptor.next;
descriptor.flags |= Descriptor::F_NEXT;
debug_assert_ne!(current, Self::FREE_CHAIN_END);
};
}
for item in input {
assert_ne!(item.len(), 0);
let desc = &mut self.descriptor_table[usize::from(self.free_head)];
let next = (self.free_head + 1) % self.capacity as u16;
desc.write(Descriptor {
address: item.bus_address(),
len: item.len().try_into().unwrap(),
// TODO MAGIC
flags: (1 << 0) | (1 << 1),
next,
});
last = self.free_head;
self.free_head = next;
if current == Self::FREE_CHAIN_END {
// No free descriptors left
debug_assert_eq!(self.free_count, count);
self.first_free = None;
self.last_free = None;
} else {
self.first_free = Some(current);
}
{
let last_desc = self.descriptor_table[last as usize].assume_init_mut();
// TODO
last_desc.flags &= !(1 << 0);
}
self.used.used_count += input.len() + output.len();
fence(Ordering::SeqCst);
self.free_count -= count;
head
}
pub fn add_notify_wait_pop<'a, 'b, T: Transport>(
&mut self,
input: &'a [&'b mut DmaBuffer<[MaybeUninit<u8>]>],
output: &'a [&'b DmaBuffer<[u8]>],
transport: &mut T,
) -> Result<u32, Error> {
let token = unsafe { self.add(input, output) }?;
fn add_free_descriptor(&mut self, idx: u16) {
log::trace!("vq: free descriptor #{idx}");
self.descriptors[idx as usize] = Descriptor {
next: Self::FREE_CHAIN_END,
..Descriptor::EMPTY
};
transport.notify(self.queue_index);
while self.is_used_empty() {
core::hint::spin_loop();
}
fence(Ordering::SeqCst);
unsafe { self.pop_used(token) }
}
pub fn is_used_empty(&self) -> bool {
fence(Ordering::SeqCst);
self.last_used_idx == self.used.index()
}
pub fn pop_last_used(&mut self) -> Option<(u16, u32)> {
let token = self.peek_used()?;
let len = unsafe { self.pop_used(token) }.unwrap();
Some((token, len))
}
fn peek_used(&mut self) -> Option<u16> {
if !self.is_used_empty() {
let last_used = self.last_used_idx % self.capacity as u16;
Some(self.used.read_slot(last_used).0 as u16)
if let Some(last) = self.last_free {
// Implies first free is Some as well
self.descriptors[last as usize].next = idx;
// last -> idx
self.last_free = Some(idx);
} else {
None
// Implies first free is None as well
self.first_free = Some(idx);
self.last_free = Some(idx);
}
self.free_count += 1;
}
unsafe fn pop_used(&mut self, token: u16) -> Result<u32, Error> {
if self.is_used_empty() {
return Err(Error::QueueEmpty);
}
let last_used_slot = self.last_used_idx % self.capacity as u16;
let (index, len) = self.used.read_slot(last_used_slot);
if index != token as u32 {
return Err(Error::WrongToken);
}
self.free_descriptor_chain(token);
fence(Ordering::SeqCst);
self.last_used_idx = self.last_used_idx.wrapping_add(1);
Ok(len)
}
unsafe fn free_descriptor_chain(&mut self, head: u16) -> usize {
let mut current_node = Some(self.descriptor_table[usize::from(head)].assume_init_mut());
fn free_descriptor_chain(&mut self, first: u16) -> usize {
let mut current = first;
let mut count = 0;
loop {
let descriptor = &self.descriptors[current as usize];
let next = if descriptor.flags & Descriptor::F_NEXT != 0 {
Some(descriptor.next)
} else {
None
};
while let Some(current) = current_node {
assert_ne!(current.len, 0);
let next_head = (current.flags & (1 << 0) != 0).then_some(current.next);
current.address = BusAddress::ZERO;
current.flags = 0;
current.next = 0;
current.len = 0;
self.used.used_count -= 1;
count += 1;
self.add_free_descriptor(current);
current_node =
next_head.map(|head| self.descriptor_table[usize::from(head)].assume_init_mut());
if let Some(next) = next {
current = next;
} else {
break;
}
}
self.free_head = head;
count
}
}
pub fn msix_vector(&self) -> u16 {
self.msix_vector
impl VqInner {
fn with_capacity(dma: &dyn DmaAllocator, capacity: usize) -> Result<Self, Error> {
let descriptors = DescriptorTable::with_capacity(dma, capacity)?;
let available = AvailableRing::with_capacity(dma, capacity, false)?;
let used = UsedRing::with_capacity(dma, capacity, false)?;
Ok(Self {
descriptors,
available,
used,
})
}
fn try_enqueue(
&mut self,
h2d: &[DmaSlice<u8>],
d2h: &[DmaSliceMut<MaybeUninit<u8>>],
) -> Result<u16, Error> {
let head = self.descriptors.alloc_descriptors(h2d, d2h)?;
self.available.push(head);
Ok(head)
}
fn consume<F: FnMut(u16, u32)>(&mut self, free_notify: &QueueWaker, mut handler: F) -> usize {
self.used.consume(|head, len| {
log::trace!("vq: used #{head}, len={len}");
self.descriptors.free_descriptor_chain(head);
free_notify.wake_all();
handler(head, len);
})
}
}
impl VirtQueue<VqAsyncNotification> {
pub fn with_capacity(
dma: &dyn DmaAllocator,
index: u16,
capacity: usize,
) -> Result<Self, Error> {
let used_notify = VqAsyncNotification::new();
Self::with_capacity_and_notify(dma, index, capacity, used_notify)
}
pub async fn enqueue_wait<F: FnOnce()>(
&self,
h2d: &[DmaSlice<'_, u8>],
d2h: &[DmaSliceMut<'_, MaybeUninit<u8>>],
notify_queue: F,
) -> Result<u32, Error> {
let completion = self.enqueue(h2d, d2h).await?;
notify_queue();
let result = completion.wait_copy().await;
Ok(result)
}
}
impl VirtQueue<VqCallbackNotification> {
pub fn with_capacity_and_callback(
dma: &dyn DmaAllocator,
index: u16,
capacity: usize,
used_callback: Box<dyn Fn(u16, u32) + Sync + Send>,
) -> Result<Self, Error> {
let used_notify = VqCallbackNotification(used_callback);
Self::with_capacity_and_notify(dma, index, capacity, used_notify)
}
}
impl VirtQueue<VqManualNotification> {
pub fn with_capacity_manual(
dma: &dyn DmaAllocator,
index: u16,
capacity: usize,
) -> Result<Self, Error> {
let used_notify = VqManualNotification;
Self::with_capacity_and_notify(dma, index, capacity, used_notify)
}
pub fn handle_notify_manual<F: FnMut(u16, u32)>(&self, handler: F) -> usize {
self.inner
.lock()
.consume(&self.free_descriptor_notify, handler)
}
// Used when the queue does not support device-side used notification
pub fn enqueue_blocking<F: FnOnce()>(
&self,
h2d: &[DmaSlice<u8>],
d2h: &[DmaSliceMut<MaybeUninit<u8>>],
notify_queue: F,
) -> Result<u32, Error> {
let token = self.try_enqueue(h2d, d2h)?;
let mut length = 0;
notify_queue();
loop {
self.handle_notify_manual(|head, len| {
assert_eq!(head, token);
length = len;
});
if length != 0 {
break;
}
psleep(Duration::from_millis(1));
}
Ok(length)
}
}
impl<N: VqNotificationMechanism> VirtQueue<N> {
pub fn with_capacity_and_notify(
dma: &dyn DmaAllocator,
index: u16,
capacity: usize,
used_notify: N,
) -> Result<Self, Error> {
let inner = VqInner::with_capacity(dma, capacity)?;
Ok(Self {
inner: IrqSafeSpinlock::new(inner),
index,
capacity,
used_notify,
free_descriptor_notify: QueueWaker::new(),
})
}
pub fn descriptor_table_base(&self) -> BusAddress {
self.inner.lock().descriptors.descriptors.bus_address()
}
pub fn available_ring_base(&self) -> BusAddress {
self.inner.lock().available.mapping.bus_address()
}
pub fn used_ring_base(&self) -> BusAddress {
self.inner.lock().used.mapping.bus_address()
}
pub fn capacity(&self) -> u16 {
self.capacity as u16
}
pub fn handle_notify(&self) -> usize {
self.inner
.lock()
.consume(&self.free_descriptor_notify, |head, len| {
self.used_notify.notify_used(head, len);
})
}
pub fn try_enqueue(
&self,
h2d: &[DmaSlice<'_, u8>],
d2h: &[DmaSliceMut<'_, MaybeUninit<u8>>],
) -> Result<N::Token, Error> {
let head = self.inner.lock().try_enqueue(h2d, d2h)?;
Ok(self.used_notify.create_token(head))
}
pub async fn enqueue(
&self,
h2d: &[DmaSlice<'_, u8>],
d2h: &[DmaSliceMut<'_, MaybeUninit<u8>>],
) -> Result<N::Token, Error> {
poll_fn(|cx| match self.try_enqueue(h2d, d2h) {
Err(Error::WouldBlock) => {
self.free_descriptor_notify.register(cx.waker());
Poll::Pending
}
result => {
self.free_descriptor_notify.remove(cx.waker());
Poll::Ready(result)
}
})
.await
}
}
impl VqAsyncNotification {
pub fn new() -> Self {
Self {
completions: IrqSafeSpinlock::new(DefaultHashTable::new()),
}
}
}


@ -7,8 +7,12 @@ use tock_registers::{
registers::WriteOnly,
};
use crate::{CommonConfiguration, DeviceStatus};
use crate::{
queue::{VirtQueue, VqNotificationMechanism},
CommonConfiguration, DeviceStatus,
};
#[cfg(any(feature = "pci", rust_analyzer))]
pub mod pci;
pub trait Transport: Send {
@ -53,7 +57,23 @@ pub trait Transport: Send {
cfg.queue_size.get().into()
}
fn set_queue(
fn set_queue<N: VqNotificationMechanism>(
&mut self,
index: u16,
queue: &VirtQueue<N>,
msix_vector: Option<u16>,
) {
self.set_queue_raw(
index,
queue.capacity(),
queue.descriptor_table_base(),
queue.available_ring_base(),
queue.used_ring_base(),
msix_vector,
);
}
fn set_queue_raw(
&mut self,
queue: u16,
capacity: u16,


@ -7,8 +7,11 @@ use libk::{
dma::{BusAddress, DmaBuffer},
error::Error,
};
use libk_util::sync::IrqSafeSpinlockGuard;
use ygg_driver_virtio_core::{queue::VirtQueue, transport::Transport};
use libk_util::sync::IrqSafeSpinlock;
use ygg_driver_virtio_core::{
queue::{VirtQueue, VqManualNotification},
transport::Transport,
};
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
@ -93,19 +96,12 @@ pub struct TransferToHost2d {
pub _0: u32,
}
pub struct ControlLock<'a, T: Transport> {
control: IrqSafeSpinlockGuard<'a, VirtQueue>,
transport: IrqSafeSpinlockGuard<'a, T>,
pub struct CommandExecution<'a, T: Transport> {
pub(super) transport: &'a IrqSafeSpinlock<T>,
pub(super) control: &'a VirtQueue<VqManualNotification>,
}
impl<'a, T: Transport> ControlLock<'a, T> {
pub const fn new(
control: IrqSafeSpinlockGuard<'a, VirtQueue>,
transport: IrqSafeSpinlockGuard<'a, T>,
) -> Self {
Self { control, transport }
}
impl<'a, T: Transport> CommandExecution<'a, T> {
fn send_recv<'r, Req: Pod>(
&mut self,
dma: &dyn DmaAllocator,
@ -116,13 +112,11 @@ impl<'a, T: Transport> ControlLock<'a, T> {
let mut request = unsafe { DmaBuffer::assume_init_slice(request) };
request.copy_from_slice(bytemuck::bytes_of(req));
let len = self
.control
.add_notify_wait_pop(&[buffer], &[&request], &mut *self.transport)
.inspect_err(|error| {
log::warn!("virtio queue: {error:?}");
})
.map_err(|_| Error::InvalidArgument)? as usize;
let len = self.control.enqueue_blocking(
&[request.slice(0..size_of::<Req>())],
&[buffer.slice_mut(0..buffer.len())],
|| self.transport.lock().notify(0),
)? as usize;
if len < size_of::<ControlHeader>() {
log::warn!("virtio-gpu: invalid device response length: {len}");


@ -6,7 +6,7 @@ extern crate alloc;
use core::mem::MaybeUninit;
use alloc::{sync::Arc, vec::Vec};
use command::{ControlLock, ScanoutInfo};
use command::{CommandExecution, ScanoutInfo};
use device_api::{
device::{Device, DeviceInitContext},
dma::DmaAllocator,
@ -31,7 +31,7 @@ use libk_util::{
};
use ygg_driver_pci::{device::PciDeviceInfo, macros::pci_driver};
use ygg_driver_virtio_core::{
queue::VirtQueue,
queue::{VirtQueue, VqManualNotification},
transport::{pci::PciTransport, Transport},
DeviceStatus,
};
@ -40,7 +40,7 @@ use yggdrasil_abi::error::Error;
mod command;
struct Queues {
control: IrqSafeSpinlock<VirtQueue>,
control: VirtQueue<VqManualNotification>,
}
struct Framebuffer {
@ -64,7 +64,7 @@ struct Config {
pub struct VirtioGpu<T: Transport> {
transport: IrqSafeSpinlock<T>,
#[allow(unused)]
pci_device_info: Option<PciDeviceInfo>,
pci_device_info: PciDeviceInfo,
queues: OneTimeInit<Queues>,
config: IrqSafeRwLock<Config>,
@ -78,7 +78,7 @@ impl<T: Transport + 'static> VirtioGpu<T> {
pub fn new(
dma: Arc<dyn DmaAllocator>,
transport: T,
info: Option<PciDeviceInfo>,
info: PciDeviceInfo,
) -> Result<Self, Error> {
// Read num-scanouts from device config
let Some(device_cfg) = transport.device_cfg() else {
@ -149,34 +149,31 @@ impl<T: Transport + 'static> VirtioGpu<T> {
transport.write_device_status(status | DeviceStatus::DRIVER_OK);
}
fn setup_queues(&self) -> Result<(), Error> {
fn setup_queues(self: &Arc<Self>) -> Result<(), Error> {
// TODO cursorq
let mut transport = self.transport.lock();
let control = VirtQueue::with_max_capacity(&mut *transport, &*self.dma, 0, 128, None, true)
.map_err(|_| Error::InvalidArgument)?;
let control = VirtQueue::with_capacity_manual(&*self.dma, 0, 128)?;
transport.set_queue(0, &control, None);
self.queues.init(Queues {
control: IrqSafeSpinlock::new(control),
});
self.queues.init(Queues { control });
Ok(())
}
fn control(&self) -> ControlLock<T> {
let queues = self.queues.get();
let control = queues.control.lock();
let transport = self.transport.lock();
ControlLock::new(control, transport)
fn begin_command(&self) -> CommandExecution<T> {
CommandExecution {
transport: &self.transport,
control: &self.queues.get().control,
}
}
fn setup_display(&self) -> Result<(), Error> {
let mut control = self.control();
let mut config = self.config.write();
let mut command = self.begin_command();
let scanouts =
control.query_scanouts(&*self.dma, self.num_scanouts, &mut config.response)?;
command.query_scanouts(&*self.dma, self.num_scanouts, &mut config.response)?;
for (i, scanout) in scanouts.iter().enumerate() {
log::info!(
"virtio-gpu: [{i}] {}x{} + {},{}",
@ -214,23 +211,23 @@ impl<T: Transport + 'static> VirtioGpu<T> {
let dma_buffer = DmaBuffer::new_uninit_slice(&*self.dma, size)?;
let mut control = self.control();
let mut command = self.begin_command();
let resource_id = control.create_resource_2d(
let resource_id = command.create_resource_2d(
&*self.dma,
&mut config.response,
w,
h,
PixelFormat::R8G8B8A8,
)?;
control.attach_backing(
command.attach_backing(
&*self.dma,
&mut config.response,
resource_id,
dma_buffer.bus_address(),
size.try_into().unwrap(),
)?;
control.set_scanout(
command.set_scanout(
&*self.dma,
&mut config.response,
index as u32,
@ -259,7 +256,7 @@ impl<T: Transport + 'static> VirtioGpu<T> {
let framebuffer = config.framebuffer.as_ref().ok_or(Error::DoesNotExist)?;
let r = config.scanouts[framebuffer.scanout_index].r;
let mut control = self.control();
let mut command = self.begin_command();
if framebuffer.double {
// Flip the buffer
@ -267,8 +264,8 @@ impl<T: Transport + 'static> VirtioGpu<T> {
} else {
let resource_id = framebuffer.resource_id;
control.transfer_to_host_2d(&*self.dma, &mut config.response, resource_id, r)?;
control.resource_flush(&*self.dma, &mut config.response, resource_id, r)?;
command.transfer_to_host_2d(&*self.dma, &mut config.response, resource_id, r)?;
command.resource_flush(&*self.dma, &mut config.response, resource_id, r)?;
Ok(())
}
@ -283,7 +280,6 @@ impl<T: Transport + 'static> Device for VirtioGpu<T> {
// Set up some initial mode
self.setup_display()?;
self.setup_mode(0)?;
DEVICE_REGISTRY.display.register(self.clone(), false)?;
@ -427,7 +423,7 @@ pci_driver! {
log::error!("Couldn't set up PCI virtio transport: {error:?}");
})
.map_err(|_| Error::InvalidArgument)?;
let device = VirtioGpu::new(dma.clone(), transport, Some(info.clone()))?;
let device = VirtioGpu::new(dma.clone(), transport, info.clone())?;
let device = Arc::new(device);
Ok(device)


@ -18,6 +18,7 @@ log.workspace = true
bitflags.workspace = true
tock-registers.workspace = true
bytemuck.workspace = true
futures-util.workspace = true
[features]
default = []


@ -5,16 +5,19 @@ extern crate alloc;
use core::mem::{size_of, MaybeUninit};
use alloc::{collections::BTreeMap, sync::Arc};
use alloc::{boxed::Box, sync::Arc};
use bytemuck::{Pod, Zeroable};
use device_api::{
device::{Device, DeviceInitContext},
dma::DmaAllocator,
interrupt::{InterruptAffinity, InterruptHandler, IrqVector},
};
use libk::dma::DmaBuffer;
use futures_util::task::AtomicWaker;
use libk::{dma::DmaBuffer, task::runtime};
use libk_util::{
sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock, IrqSafeSpinlockGuard},
event::BitmapEvent,
hash_table::DefaultHashTable,
sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock},
OneTimeInit,
};
use ygg_driver_net_core::{
@ -26,28 +29,25 @@ use ygg_driver_pci::{
macros::pci_driver,
};
use ygg_driver_virtio_core::{
queue::VirtQueue,
queue::{VirtQueue, VqCallbackNotification, VqManualNotification},
transport::{pci::PciTransport, Transport},
DeviceStatus,
};
use yggdrasil_abi::{error::Error, net::MacAddress};
struct Queues {
receive: IrqSafeSpinlock<VirtQueue>,
transmit: IrqSafeSpinlock<VirtQueue>,
}
pub struct VirtioNet<T: Transport> {
transport: IrqSafeSpinlock<T>,
queues: OneTimeInit<Queues>,
interface_id: OneTimeInit<u32>,
mac: IrqSafeRwLock<MacAddress>,
pending_packets: IrqSafeRwLock<BTreeMap<u16, DmaBuffer<[MaybeUninit<u8>]>>>,
pci_device_info: PciDeviceInfo,
dma: Arc<dyn DmaAllocator>,
pci_device_info: Option<PciDeviceInfo>,
interface_id: OneTimeInit<u32>,
mac: IrqSafeRwLock<MacAddress>,
rx_queue: OneTimeInit<VirtQueue<VqManualNotification>>,
tx_queue: OneTimeInit<VirtQueue<VqCallbackNotification>>,
tx_in_flight: IrqSafeSpinlock<DefaultHashTable<u16, DmaBuffer<[u8]>>>,
softirq: BitmapEvent<AtomicWaker>,
}
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
@ -61,23 +61,15 @@ struct VirtioPacketHeader {
csum_offset: u16,
}
impl Queues {
pub fn try_receive(&self, _index: usize) -> Option<(u16, IrqSafeSpinlockGuard<VirtQueue>)> {
let mut queue = self.receive.lock();
// TODO use len for packet size hint
let (token, _len) = queue.pop_last_used()?;
Some((token, queue))
}
}
impl<T: Transport + 'static> VirtioNet<T> {
const PACKET_SIZE: usize = 4096;
const VQ_RX: u16 = 0;
const VQ_TX: u16 = 1;
pub fn new(
dma: Arc<dyn DmaAllocator>,
transport: T,
pci_device_info: Option<PciDeviceInfo>,
) -> Self {
pci_device_info: PciDeviceInfo,
) -> Result<Self, Error> {
// Read MAC from device config
let device_cfg = transport
.device_cfg()
@ -86,60 +78,21 @@ impl<T: Transport + 'static> VirtioNet<T> {
mac_bytes.copy_from_slice(&device_cfg[..6]);
let mac = MacAddress::from(mac_bytes);
Self {
pci_device_info.init_interrupts(PreferredInterruptMode::Msi(true))?;
Ok(Self {
transport: IrqSafeSpinlock::new(transport),
queues: OneTimeInit::new(),
interface_id: OneTimeInit::new(),
mac: IrqSafeRwLock::new(mac),
pending_packets: IrqSafeRwLock::new(BTreeMap::new()),
pci_device_info,
dma,
}
}
pub fn listen(&self, buffers: usize) {
let queues = self.queues.get();
let mut queue = queues.receive.lock();
let mut packets = self.pending_packets.write();
interface_id: OneTimeInit::new(),
mac: IrqSafeRwLock::new(mac),
for _ in 0..buffers {
let mut packet = DmaBuffer::new_uninit_slice(&*self.dma, Self::PACKET_SIZE).unwrap();
let token = unsafe { queue.add(&[&mut packet], &[]).unwrap() };
packets.insert(token, packet);
}
let mut transport = self.transport.lock();
transport.notify(0);
}
fn handle_receive_interrupt(&self, queue: usize) -> bool {
let queues = self.queues.get();
let interface_id = *self.interface_id.get();
let mut count = 0;
while let Some((token, mut queue)) = queues.try_receive(queue) {
let mut pending_packets = self.pending_packets.write();
let packet = pending_packets.remove(&token).unwrap();
let mut buffer = DmaBuffer::new_uninit_slice(&*self.dma, Self::PACKET_SIZE).unwrap();
let token = unsafe { queue.add(&[&mut buffer], &[]).unwrap() };
pending_packets.insert(token, buffer);
let packet = unsafe { DmaBuffer::assume_init_slice(packet) };
let packet = RxPacket::new(packet, size_of::<VirtioPacketHeader>(), interface_id);
ygg_driver_net_core::receive_packet(packet).unwrap();
count += 1
}
if count != 0 {
self.transport.lock().notify(0);
}
count != 0
rx_queue: OneTimeInit::new(),
tx_queue: OneTimeInit::new(),
tx_in_flight: IrqSafeSpinlock::new(DefaultHashTable::new()),
softirq: BitmapEvent::new(AtomicWaker::new()),
})
}
fn begin_init(&self) -> Result<DeviceStatus, Error> {
@ -178,46 +131,85 @@ impl<T: Transport + 'static> VirtioNet<T> {
transport.write_device_status(status | DeviceStatus::DRIVER_OK);
}
unsafe fn setup_queues(
self: Arc<Self>,
receive_count: usize,
transmit_count: usize,
) -> Result<(), Error> {
let receive_vector = if let Some(pci) = self.pci_device_info.as_ref() {
pci.init_interrupts(PreferredInterruptMode::Msi(true))?;
let info = pci.map_interrupt(InterruptAffinity::Any, self.clone())?;
info.map(|info| info.vector as u16)
unsafe fn setup_queues(self: &Arc<Self>) -> Result<(), Error> {
let (rx_vector, tx_vector) = if let Ok(msis) =
self.pci_device_info
.map_interrupt_multiple(0..2, InterruptAffinity::Any, self.clone())
{
// Bound an MSI(-X) vector range; use per-queue vectors
(Some(msis[0].vector as u16), Some(msis[1].vector as u16))
} else {
None
// TODO support non-MSI-x/non-multivec setups
todo!();
};
// TODO multiqueue capability
assert_eq!(receive_count, 1);
assert_eq!(transmit_count, 1);
// Rx uses the manual mechanism: the softirq task below drains used buffers and hands them to the network stack
let rx_queue = VirtQueue::with_capacity_manual(&*self.dma, Self::VQ_RX, 64)?;
let p = self.clone();
// Set up a callback to remove completed buffers from the "in flight owned buffers" list
let tx_queue = VirtQueue::with_capacity_and_callback(
&*self.dma,
Self::VQ_TX,
64,
Box::new(move |head, _| {
p.tx_in_flight.lock().remove(&head);
}),
)?;
let rx_queue = self.rx_queue.init(rx_queue);
let tx_queue = self.tx_queue.init(tx_queue);
let mut transport = self.transport.lock();
// Setup the virtqs
let rx = VirtQueue::with_max_capacity(
&mut *transport,
&*self.dma,
0,
128,
receive_vector,
false,
)
.map_err(cvt_error)?;
let tx = VirtQueue::with_max_capacity(&mut *transport, &*self.dma, 1, 128, None, true)
.map_err(cvt_error)?;
self.queues.init(Queues {
receive: IrqSafeSpinlock::new(rx),
transmit: IrqSafeSpinlock::new(tx),
});
transport.set_queue(Self::VQ_RX, rx_queue, rx_vector);
transport.set_queue(Self::VQ_TX, tx_queue, tx_vector);
Ok(())
}
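// Deferred-work loop: handle_irq() only signals a per-queue bit in `softirq`;
// this task then drains the used rings and refills Rx buffers outside of IRQ
// context.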
async fn softirq(&self) -> Result<(), Error> {
const RX_SIZE: usize = 4096;
const RX_IN_FLIGHT: usize = 32;
let rx_queue = self.rx_queue.get();
let tx_queue = self.tx_queue.get();
let nic = *self.interface_id.get();
let mut rx_in_flight = DefaultHashTable::<u16, DmaBuffer<[MaybeUninit<u8>]>>::new();
// Set up the initial set of Rx buffers
for _ in 0..RX_IN_FLIGHT {
let mut buffer = DmaBuffer::new_uninit_slice(&*self.dma, RX_SIZE)?;
let token = rx_queue.try_enqueue(&[], &[buffer.slice_mut(0..RX_SIZE)])?;
rx_in_flight.insert(token, buffer);
}
loop {
let events = self.softirq.wait().await;
if events & (1 << Self::VQ_RX) != 0 {
let refill_rx = rx_queue.handle_notify_manual(|head, _| {
if let Some(packet) = rx_in_flight.remove(&head) {
let packet = unsafe { DmaBuffer::assume_init_slice(packet) };
let packet = RxPacket::new(packet, size_of::<VirtioPacketHeader>(), nic);
ygg_driver_net_core::receive_packet(packet).ok();
}
});
// Refill Rx buffers
for _ in 0..refill_rx {
let mut buffer = DmaBuffer::new_uninit_slice(&*self.dma, RX_SIZE)?;
let token = rx_queue.try_enqueue(&[], &[buffer.slice_mut(0..RX_SIZE)])?;
rx_in_flight.insert(token, buffer);
}
}
if events & (1 << Self::VQ_TX) != 0 {
tx_queue.handle_notify();
}
}
}
}
impl<T: Transport + 'static> NetworkDevice for VirtioNet<T> {
@ -226,13 +218,13 @@ impl<T: Transport + 'static> NetworkDevice for VirtioNet<T> {
}
fn transmit_buffer(&self, mut packet: DmaBuffer<[u8]>) -> Result<(), Error> {
let queues = self.queues.get();
let mut tx = queues.transmit.lock();
let mut transport = self.transport.lock();
let tx_queue = self.tx_queue.get();
packet[..size_of::<VirtioPacketHeader>()].fill(0);
let _len = tx
.add_notify_wait_pop(&[], &[&packet], &mut *transport)
.unwrap();
let token = tx_queue.try_enqueue(&[packet.slice(0..packet.len())], &[])?;
// Add the packet to the "in flight" list so it doesn't get dropped and invalidated
// immediately after returning from this function
self.tx_in_flight.lock().insert(token, packet);
self.transport.lock().notify(Self::VQ_TX);
Ok(())
}
@ -248,20 +240,12 @@ impl<T: Transport + 'static> NetworkDevice for VirtioNet<T> {
impl<T: Transport + 'static> InterruptHandler for VirtioNet<T> {
fn handle_irq(self: Arc<Self>, vector: IrqVector) -> bool {
match vector {
IrqVector::Msi(_) => {
// MSI/MSI-X
self.handle_receive_interrupt(0)
}
IrqVector::Irq(_) => {
// Legacy IRQ
let (queue_irq, config_irq) = self.transport.lock().read_interrupt_status();
if queue_irq {
self.handle_receive_interrupt(0);
}
queue_irq || config_irq
IrqVector::Msi(vector) => {
self.softirq.signal(1 << vector);
true
}
// TODO non-multivec/legacy IRQ setup
IrqVector::Irq(_) => todo!(),
}
}
}
@ -274,15 +258,16 @@ impl<T: Transport + 'static> Device for VirtioNet<T> {
unsafe fn init(self: Arc<Self>, _cx: DeviceInitContext) -> Result<(), Error> {
let status = self.begin_init()?;
// TODO multiqueue
self.clone().setup_queues(1, 1)?;
self.setup_queues()?;
self.finish_init(status);
let iface =
ygg_driver_net_core::register_interface(NetworkInterfaceType::Ethernet, self.clone());
self.interface_id.init(iface.id());
self.listen(64);
let p = self.clone();
runtime::spawn(async move { p.softirq().await })?;
Ok(())
}
@ -292,14 +277,6 @@ impl<T: Transport + 'static> Device for VirtioNet<T> {
}
}
fn cvt_error(error: ygg_driver_virtio_core::error::Error) -> Error {
use ygg_driver_virtio_core::error::Error as VirtioError;
match error {
VirtioError::OsError(err) => err,
_ => Error::InvalidOperation,
}
}
pci_driver! {
matches: [device (0x1AF4:0x1000)],
driver: {
@ -311,7 +288,7 @@ pci_driver! {
let space = &info.config_space;
let transport = PciTransport::from_config_space(space).unwrap();
let device = VirtioNet::new(dma.clone(), transport, Some(info.clone()));
let device = VirtioNet::new(dma.clone(), transport, info.clone())?;
let device = Arc::new(device);


@ -49,6 +49,24 @@ impl<K: Hash + Eq, V, const N: usize> HashTable<K, V, DefaultHashBuilder, N> {
}
}
pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
where
Q: Hash + Eq + ?Sized,
K: Borrow<Q>,
{
let h = self.hasher.hash_one(key);
let bucket = &mut self.buckets[h as usize % self.buckets.len()];
for i in 0..bucket.len() {
if bucket[i].0.borrow() == key {
let (_, value) = bucket.remove(i);
return Some(value);
}
}
None
}
pub fn get<Q>(&self, key: &Q) -> Option<&V>
where
Q: Hash + Eq + ?Sized,


@ -247,6 +247,7 @@ impl<T: ?Sized> Drop for DmaBuffer<T> {
log::trace!("Drop DmaBuffer @ {:#x}", self.host_physical);
unsafe {
ptr::drop_in_place(self.host_pointer.as_ptr());
#[cfg(any(rust_analyzer, target_os = "none"))]
for i in 0..self.page_count {
phys::free_page(self.host_physical.add(i * L3_PAGE_SIZE));
}


@ -66,6 +66,7 @@ extern crate ygg_driver_ahci;
extern crate ygg_driver_net_rtl81xx;
extern crate ygg_driver_nvme;
extern crate ygg_driver_usb_xhci;
extern crate ygg_driver_virtio_blk;
extern crate ygg_driver_virtio_gpu;
extern crate ygg_driver_virtio_net;


@ -13,6 +13,7 @@ pub enum QemuNic {
pub enum QemuDrive {
Nvme,
Sata,
VirtioBlk,
}
#[derive(Debug)]
@ -115,6 +116,18 @@ impl IntoArgs for QemuDevice {
file.display()
));
}
QemuDrive::VirtioBlk => {
command.arg("-drive");
command.arg(format!("file={},if=none,id=drive0", file.display()));
command.arg("-device");
let mut device = "virtio-blk-pci".to_owned();
if let Some(serial) = serial {
device.push_str(",serial=");
device.push_str(serial);
}
device.push_str(",drive=drive0");
command.arg(device);
}
}
// command.arg("-drive");
// command.arg(format!("file={},if=none,id=drive0", file.display()));
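// For example, with file "test.img" and serial "disk0" (illustrative values),
// the VirtioBlk arm above produces:
//   -drive file=test.img,if=none,id=drive0
//   -device virtio-blk-pci,serial=disk0,drive=drive0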


@ -41,6 +41,7 @@ enum QemuDiskInterface {
#[default]
Nvme,
Ahci,
VirtioBlk,
}
#[derive(Debug, Default, serde::Deserialize, serde::Serialize)]
@ -139,6 +140,7 @@ impl From<QemuDiskInterface> for QemuDrive {
match value {
QemuDiskInterface::Nvme => Self::Nvme,
QemuDiskInterface::Ahci => Self::Sata,
QemuDiskInterface::VirtioBlk => Self::VirtioBlk,
}
}
}