use core::{ future::poll_fn, mem::MaybeUninit, sync::atomic::{AtomicBool, Ordering}, task::Poll, }; use alloc::{ sync::{Arc, Weak}, vec::Vec, }; use bytemuck::{Pod, Zeroable}; use device_api::dma::DmaAllocator; use futures_util::task::AtomicWaker; use libk::dma::{BusAddress, DmaBuffer}; use libk_util::{ queue::BoundedQueue, sync::{spin_rwlock::IrqSafeRwLock, IrqSafeSpinlock, IrqSafeSpinlockGuard}, }; use ygg_driver_usb::{ communication::UsbDirection, error::{TransferError, UsbError}, pipe::control::ControlTransferSetup, }; use yggdrasil_abi::define_bitfields; use super::{CommandExecutor, LinkTrb}; struct TransferRingInner { trbs: DmaBuffer<[MaybeUninit]>, enqueue_index: usize, dequeue_index: usize, cycle_bit: bool, } pub struct TransferRing { inner: IrqSafeSpinlock, bus_base: BusAddress, capacity: usize, slot_id: u8, endpoint_id: u8, transactions: IrqSafeRwLock>>>, shutdown: AtomicBool, } pub struct TransactionBuilder<'a> { inner: IrqSafeSpinlockGuard<'a, TransferRingInner>, ring: &'a Arc, pending: Vec, } pub struct Transaction { event_queue: BoundedQueue, event_notify: AtomicWaker, next_dequeue: usize, next_cycle: bool, } #[derive(Debug)] pub enum TransactionEvent { Status(usize, u32), Shutdown, } pub enum ControlDataStage<'a> { None, In(&'a mut DmaBuffer<[MaybeUninit]>), Out(&'a DmaBuffer<[u8]>), } impl TransferRing { pub fn new( dma: &dyn DmaAllocator, slot_id: u8, endpoint_id: u8, capacity: usize, ) -> Result { let inner = TransferRingInner::new(dma, capacity)?; let bus_base = inner.trbs.bus_address(); let transactions = (0..capacity).map(|_| None).collect(); Ok(Self { inner: IrqSafeSpinlock::new(inner), bus_base, capacity, slot_id, endpoint_id, transactions: IrqSafeRwLock::new(transactions), shutdown: AtomicBool::new(false), }) } pub fn transaction_builder(self: &Arc) -> Result { if self.shutdown.load(Ordering::Acquire) { return Err(UsbError::DeviceDisconnected); } Ok(TransactionBuilder { inner: self.inner.lock(), ring: self, pending: Vec::new(), }) } async fn handle_stall( &self, executor: &E, result: &Result, transaction: &Transaction, ) { if let Err(TransferError::Stall) = result { let dequeue = self .bus_base .add(transaction.next_dequeue * size_of::()); if let Err(rerror) = executor .reset_endpoint( self.slot_id, self.endpoint_id, dequeue, transaction.next_cycle, ) .await { log::error!( "xhci: could not reset endpoint after stall {}:{}: {rerror:?}", self.slot_id, self.endpoint_id ); self.shutdown.store(true, Ordering::Release); } } } pub async fn normal_transfer( self: &Arc, executor: &E, buffer: BusAddress, length: usize, ) -> Result { if length == 0 { return Ok(0); } let mut builder = self.transaction_builder()?; let last_data_trb = builder.enqueue_normal(buffer, length)?; let transaction = builder.submit(executor); let status = transaction.wait_normal(last_data_trb).await; builder.inner.dequeue_index = builder.inner.enqueue_index; self.handle_stall(executor, &status, &transaction).await; let residual = status?; Ok(length.saturating_sub(residual)) } // Helper functions, shorthands for transaction_builder().....finish() + kick() pub async fn control_transfer( self: &Arc, executor: &E, setup: ControlTransferSetup, data: ControlDataStage<'_>, ) -> Result { let mut builder = self.transaction_builder()?; let data_len = data.len(); let (setup, data, status) = builder.enqueue_control(setup, data)?; let transaction = builder.submit(executor); // TODO timeout let status = transaction.wait_control(setup, data, status).await; builder.inner.dequeue_index = builder.inner.enqueue_index; self.handle_stall(executor, &status, &transaction).await; let residual = status?; Ok(data_len.saturating_sub(residual)) } pub fn kick(&self, executor: &E) { executor.ring_doorbell(self.slot_id as usize, self.endpoint_id); } pub fn shutdown(&self) { self.shutdown.store(true, Ordering::Release); // Shutdown transactions let transactions = self.transactions.read(); for index in 0..self.capacity { if let Some(tx) = transactions[index].as_ref().and_then(Weak::upgrade) { tx.shutdown(); } } } pub fn notify(&self, address: BusAddress, status: u32) { if status == 0 { return; } if address < self.bus_base || address - self.bus_base >= size_of::() * self.capacity { log::warn!("xhci: event outside of trb array: {address:#x}"); return; } let index = (address - self.bus_base) / size_of::(); if let Some(tx) = self.transactions.write()[index] .take() .and_then(|tx| tx.upgrade()) { tx.notify(index, status); } else { log::warn!("xhci: no transaction @ {index} to notify"); } } pub fn bus_address(&self) -> BusAddress { self.bus_base } } impl TransactionBuilder<'_> { const TRB_SIZE_LIMIT: usize = 65536; pub fn enqueue(&mut self, trb: C, ioc: bool) -> Result { let address = self.inner.enqueue(trb, ioc)?; self.pending.push(address); Ok((address - self.ring.bus_base) / size_of::()) } pub fn enqueue_normal(&mut self, buffer: BusAddress, length: usize) -> Result { let trb_count = length.div_ceil(Self::TRB_SIZE_LIMIT); if self.inner.free_capacity() <= trb_count || trb_count == 0 { return Err(UsbError::DeviceBusy); } let mut last_trb = 0; for i in 0..trb_count { let offset = i * Self::TRB_SIZE_LIMIT; let amount = (length - offset).min(Self::TRB_SIZE_LIMIT); last_trb = self .enqueue( NormalTransferTrb::new(buffer.add(offset), amount), i == trb_count - 1, ) .unwrap(); } Ok(last_trb) } pub fn enqueue_control( &mut self, setup: ControlTransferSetup, buffer: ControlDataStage, ) -> Result<(usize, Option, usize), UsbError> { // Check ring capacity first // TODO larger DATA stages let trb_count = 2 + if buffer.len() != 0 { 1 } else { 0 }; if self.inner.free_capacity() <= trb_count { return Err(UsbError::DeviceBusy); } // unwrap()s are okay here, capacity checked above let setup_stage = self .enqueue(ControlTransferSetupTrb::new(setup), true) .unwrap(); let data_stage = match buffer { ControlDataStage::None => None, ControlDataStage::In(buffer) => { let index = self.enqueue( ControlTransferDataTrb::new( buffer.bus_address(), buffer.len(), UsbDirection::In, ), true, )?; Some(index) } ControlDataStage::Out(buffer) => { let index = self.enqueue( ControlTransferDataTrb::new( buffer.bus_address(), buffer.len(), UsbDirection::Out, ), true, )?; Some(index) } }; let status_stage = self .enqueue(ControlTransferStatusTrb::new(UsbDirection::In), true) .unwrap(); Ok((setup_stage, data_stage, status_stage)) } pub fn finish(&mut self) -> Arc { let transaction = Arc::new(Transaction { event_queue: BoundedQueue::new(self.pending.len()), event_notify: AtomicWaker::new(), next_dequeue: self.inner.enqueue_index, next_cycle: self.inner.cycle_bit, }); let mut transactions = self.ring.transactions.write(); for &pending in self.pending.iter() { let index = (pending - self.ring.bus_base) / size_of::(); transactions[index] = Some(Arc::downgrade(&transaction)); } transaction } pub fn submit(&mut self, executor: &E) -> Arc { let transaction = self.finish(); self.ring.kick(executor); transaction } } impl TransferRingInner { fn new(dma: &dyn DmaAllocator, capacity: usize) -> Result { let trbs = DmaBuffer::new_zeroed_slice(dma, capacity).map_err(UsbError::MemoryError)?; Ok(Self { trbs, enqueue_index: 0, dequeue_index: 0, cycle_bit: true, }) } fn enqueue(&mut self, trb: C, ioc: bool) -> Result { if (self.enqueue_index + 1) % (self.trbs.len() - 1) == self.dequeue_index { log::warn!("xhci: transfer ring full"); return Err(UsbError::DeviceBusy); } let mut raw: RawTransferTrb = bytemuck::cast(trb); raw.flags.set_ty(C::TRB_TYPE as u32); raw.flags.set_cycle(self.cycle_bit); raw.flags.set_ioc(ioc); self.trbs[self.enqueue_index].write(raw); let address = self .trbs .bus_address() .add(self.enqueue_index * size_of::()); self.enqueue_index += 1; if self.enqueue_index >= self.trbs.len() - 1 { self.enqueue_link(); self.cycle_bit = !self.cycle_bit; self.enqueue_index = 0; } Ok(address) } fn enqueue_link(&mut self) { let base = self.trbs.bus_address(); let link = LinkTrb::new(base, self.cycle_bit); self.trbs[self.enqueue_index].write(bytemuck::cast(link)); } fn free_capacity(&self) -> usize { self.enqueue_index + self.trbs.len() - self.dequeue_index } } impl ControlDataStage<'_> { pub fn len(&self) -> usize { match self { Self::None => 0, Self::In(buf) => buf.len(), Self::Out(buf) => buf.len(), } } } impl Transaction { pub fn notify(&self, trb_index: usize, status: u32) { self.event_queue .push(TransactionEvent::Status(trb_index, status)) .ok(); self.event_notify.wake(); } pub fn shutdown(&self) { self.event_queue.push(TransactionEvent::Shutdown).ok(); self.event_notify.wake(); } pub async fn wait_normal(&self, last_trb: usize) -> Result { loop { let event = self.next_event().await; let status = event.to_result(); match event { TransactionEvent::Status(trb_index, _) => { if status.is_err() || trb_index == last_trb { break status; } } TransactionEvent::Shutdown => { log::error!("xhci: abort transaction, endpoint shutdown"); return Err(TransferError::UsbTransactionError); } } } } pub async fn wait_trb(&self, trb: usize) -> Result { let event = self.next_event().await; match event { TransactionEvent::Status(trb_index, _) => { if trb_index != trb { return Err(TransferError::InvalidTransfer); } } TransactionEvent::Shutdown => { log::error!("xhci: abort transaction, endpoint shutdown"); return Err(TransferError::UsbTransactionError); } } event.to_result() } pub async fn wait_control( &self, setup_trb: usize, last_data_trb: Option, status_trb: usize, ) -> Result { self.wait_trb(setup_trb).await?; let residual = if let Some(last_data_trb) = last_data_trb { self.wait_normal(last_data_trb).await? } else { 0 }; self.wait_trb(status_trb).await?; Ok(residual) } pub async fn next_event(&self) -> TransactionEvent { poll_fn(|cx| { if let Some(event) = self.event_queue.pop() { Poll::Ready(event) } else { self.event_notify.register(cx.waker()); Poll::Pending } }) .await } } impl TransactionEvent { pub fn to_result(&self) -> Result { match self { &Self::Status(_, status) => match status >> 24 { 1 => Ok((status as usize) & 0xFFFFFF), 4 => Err(TransferError::UsbTransactionError), 6 => Err(TransferError::Stall), 13 => Err(TransferError::ShortPacket((status as usize) & 0xFFFFFF)), code => Err(TransferError::Other(code as u8)), }, Self::Shutdown => Err(TransferError::UsbTransactionError), } } } // TRB definitions define_bitfields! { pub RawTransferFlags : u32 { (10..16) => ty + set_ty, 5 => ioc + set_ioc, 0 => cycle + set_cycle } } define_bitfields! { pub NormalTransferFlags: u64 { (0..16) => trb_length, } } define_bitfields! { pub ControlTransferSetupRequest : u64 { (0..8) => bm_request_type, (8..16) => b_request, (16..32) => w_value, (32..48) => w_index, (48..64) => w_length } } define_bitfields! { pub ControlTransferSetupFlags : u64 { (0..16) => trb_length, 38 => immediate_data, (48..50) => transfer_type } } define_bitfields! { pub ControlTransferDataFlags : u64 { (0..16) => trb_length, 48 => direction, } } define_bitfields! { pub ControlTransferStatusFlags : u32 { 16 => direction, } } #[derive(Clone, Copy, Debug, Pod, Zeroable)] #[repr(C, align(16))] pub struct NormalTransferTrb { pub buffer: BusAddress, pub flags: NormalTransferFlags, } #[derive(Clone, Copy, Debug, Pod, Zeroable)] #[repr(C, align(16))] pub struct ControlTransferSetupTrb { pub request: ControlTransferSetupRequest, pub flags: ControlTransferSetupFlags, } #[derive(Clone, Copy, Debug, Pod, Zeroable)] #[repr(C, align(16))] pub struct ControlTransferDataTrb { pub buffer: BusAddress, pub flags: ControlTransferDataFlags, } #[derive(Clone, Copy, Debug, Pod, Zeroable)] #[repr(C, align(16))] pub struct ControlTransferStatusTrb { _0: [u32; 3], pub flags: ControlTransferStatusFlags, } #[derive(Clone, Copy, Debug, Pod, Zeroable)] #[repr(C, align(16))] pub struct RawTransferTrb { _0: [u32; 3], pub flags: RawTransferFlags, } pub trait TransferTrb: Pod { const TRB_TYPE: u8; } impl NormalTransferTrb { pub fn new(buffer: BusAddress, length: usize) -> Self { Self { buffer, flags: NormalTransferFlags::new(length.try_into().unwrap()), } } } impl ControlTransferSetupTrb { pub const fn new(setup: ControlTransferSetup) -> Self { Self { request: ControlTransferSetupRequest::new( setup.bm_request_type as _, setup.b_request as _, setup.w_value as _, setup.w_index as _, setup.w_length as _, ), flags: ControlTransferSetupFlags::new(8, true, 3), } } } impl ControlTransferDataTrb { pub fn new(buffer: BusAddress, length: usize, direction: UsbDirection) -> Self { Self { buffer, flags: ControlTransferDataFlags::new( length.try_into().unwrap(), direction.is_device_to_host(), ), } } } impl ControlTransferStatusTrb { pub const fn new(direction: UsbDirection) -> Self { Self { _0: [0; 3], flags: ControlTransferStatusFlags::new(direction.is_device_to_host()), } } } impl TransferTrb for NormalTransferTrb { const TRB_TYPE: u8 = 1; } impl TransferTrb for ControlTransferSetupTrb { const TRB_TYPE: u8 = 2; } impl TransferTrb for ControlTransferDataTrb { const TRB_TYPE: u8 = 3; } impl TransferTrb for ControlTransferStatusTrb { const TRB_TYPE: u8 = 4; }