385 lines
11 KiB
Rust

use core::{
future::poll_fn,
mem::MaybeUninit,
sync::atomic::{AtomicU32, Ordering},
task::{Context, Poll},
};
use alloc::{boxed::Box, string::String, sync::Arc};
use async_trait::async_trait;
use bytemuck::Zeroable;
use device_api::{device::Device, dma::DmaAllocator};
use futures_util::task::AtomicWaker;
use libk::{
device::block::BlockDevice,
dma::{DmaBuffer, DmaSlice, DmaSliceMut},
error::Error,
};
use libk_mm::{
address::PhysicalAddress, device::DeviceMemoryIo, table::MapAttributes, PageProvider,
};
use libk_util::{sync::IrqSafeSpinlock, waker::QueueWaker, OneTimeInit};
use tock_registers::interfaces::{Readable, Writeable};
use crate::{
command::{AtaCommand, AtaIdentify, AtaReadDmaEx},
data::{CommandListEntry, CommandTable, ReceivedFis, COMMAND_LIST_LENGTH},
error::AhciError,
regs::{PortRegs, CMD_PENDING, CMD_READY, IE, TFD},
AhciController, MAX_COMMANDS, MAX_PRD_SIZE, SECTOR_SIZE,
};
/// Kind of device attached to an AHCI port.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum PortType {
    /// Serial ATA drive (the only kind this driver creates).
    Sata,
}
/// Per-port hardware state; stored behind `AhciPort::inner`'s lock, which both the
/// submission path and the interrupt handler take.
struct PortInner {
    /// Memory-mapped register block of this port.
    regs: DeviceMemoryIo<'static, PortRegs>,
    // Never read after `AhciPort::create` programs its bus address into the port; it is kept
    // here so the DMA area stays owned for as long as the port exists.
    #[allow(unused)]
    received_fis: DmaBuffer<ReceivedFis>,
    /// Command list shared with the HBA, indexed by command slot
    /// (`COMMAND_LIST_LENGTH` entries are allocated in `AhciPort::create`).
    command_list: DmaBuffer<[CommandListEntry]>,
}
/// Identification data read from the drive by `AhciPort::init_inner`.
pub struct PortInfo {
    /// Model number string from the IDENTIFY response.
    pub model: String,
    /// Serial number string from the IDENTIFY response.
    pub serial: String,
    /// Number of logical sectors, as returned by the IDENTIFY response's
    /// `logical_sector_count()`.
    pub lba_count: u64,
}
/// A single port of an AHCI controller with a SATA drive attached, exposed as a block
/// device.
// NOTE(review): `allow(unused)` presumably covers `ty`, which is set in `create` but never
// read in this file — confirm and narrow the attribute to that field.
#[allow(unused)]
pub struct AhciPort {
    /// Registers and DMA structures; locked for command submission and by the interrupt
    /// handler.
    inner: IrqSafeSpinlock<PortInner>,
    /// Owning controller; supplies the DMA allocator and capability flags.
    ahci: Arc<AhciController>,
    /// Kind of attached device (always `PortType::Sata` when built by `create`).
    ty: PortType,
    /// Port number on the controller, used in log and panic messages.
    pub(crate) index: usize,
    /// Drive identification, filled once by `init_inner`.
    info: OneTimeInit<PortInfo>,
    /// Bitmap of allocated command slots: bit `i` set means slot `i` is in use.
    command_allocation: IrqSafeSpinlock<u32>,
    // Per-slot (waker, status) pair. Status is `CMD_PENDING` while the command is in
    // flight and afterwards holds the value reported by the interrupt handler. A slot is
    // only ever waited on by the single task that owns it, so one `AtomicWaker` per slot
    // is sufficient.
    command_completion: [(AtomicWaker, AtomicU32); COMMAND_LIST_LENGTH],
    /// Tasks waiting for a command slot to become free.
    command_available: QueueWaker,
}
/// Handle for a command that has been issued to the HBA and still occupies a slot.
///
/// It must be consumed with `wait_for_completion`, which frees the slot; dropping the
/// handle any other way panics (see its `Drop` impl).
struct SubmittedCommand<'a> {
    /// Port the command was issued on.
    port: &'a AhciPort,
    /// Command slot occupied by the command.
    index: usize,
}
impl SubmittedCommand<'_> {
    /// Waits for the interrupt handler to report a result for this command's slot, then
    /// releases the slot.
    ///
    /// Resolves to `Ok(())` when the slot reports `CMD_READY`, and to
    /// [`AhciError::DeviceError`] for any other non-pending status. If this future is
    /// dropped before completing, the handle is dropped with it and `Drop` panics.
    pub async fn wait_for_completion(self) -> Result<(), AhciError> {
        let (port, slot) = (self.port, self.index);
        let outcome = poll_fn(|cx| port.poll_slot(cx, slot)).await;
        // Release the slot by hand and skip `Drop`, which would otherwise panic
        port.free_command(slot);
        core::mem::forget(self);
        outcome
    }
}
impl Drop for SubmittedCommand<'_> {
    /// Always panics: an in-flight command must be finished via `wait_for_completion`,
    /// which frees its slot and bypasses this destructor.
    fn drop(&mut self) {
        let Self { port, index } = self;
        panic!(
            "Cannot drop command in flight: port{}, slot{}",
            port.index, index
        )
    }
}
impl PortInner {
    /// Builds a command table for `command` in slot `index` and issues it to the HBA.
    ///
    /// The caller must own slot `index`; its bit in PxCI must be clear, which is asserted.
    /// The function spins until the task file reports neither BSY nor DRQ before issuing.
    ///
    /// # Errors
    ///
    /// [`AhciError::MemoryError`] if the command table cannot be allocated, or any error
    /// from filling the table or building the command list entry. On error nothing has
    /// been issued.
    fn submit_command<C: AtaCommand>(
        &mut self,
        dma: &dyn DmaAllocator,
        index: usize,
        command: &C,
    ) -> Result<(), AhciError> {
        let list_entry = &mut self.command_list[index];
        // NOTE(review): `table_entry` is dropped when this function returns, while the HBA
        // may still fetch the command table through the bus address stored in the list
        // entry. If `DmaBuffer`'s destructor releases the memory this is a DMA
        // use-after-free; the table likely has to be kept alive per slot until the command
        // completes — confirm against `DmaBuffer`'s `Drop`.
        let mut table_entry =
            DmaBuffer::new(dma, CommandTable::zeroed()).map_err(AhciError::MemoryError)?;
        table_entry.setup_command(command)?;
        *list_entry = CommandListEntry::new(table_entry.bus_address(), command.prd_count())?;
        // Sync before send
        // XXX do this properly
        #[cfg(target_arch = "x86_64")]
        // SAFETY: NOTE(review) `wbinvd` only writes back/invalidates caches, but it is a
        // privileged instruction; this assumes the driver runs at CPL 0.
        unsafe {
            core::arch::asm!("wbinvd");
        }
        // TODO deal with this async way
        while self.regs.TFD.matches_any(&[TFD::BSY::SET, TFD::DRQ::SET]) {
            core::hint::spin_loop();
        }
        assert_eq!(self.regs.CI.get() & (1 << index), 0);
        // PxCI is write-1-to-set: zero bits are ignored by the HBA. Writing back a value
        // read earlier would race with the HBA clearing bits of commands that complete in
        // between, re-issuing those already-completed slots.
        self.regs.CI.set(1 << index);
        Ok(())
    }
}
impl AhciPort {
    /// Stops the port, sets up its DMA structures and interrupts, restarts it and returns
    /// the shared port object.
    ///
    /// The received-FIS area and a `COMMAND_LIST_LENGTH`-entry command list are allocated
    /// from the controller's DMA allocator and their bus addresses are programmed into the
    /// port. Drive identification is done separately by [`AhciPort::init_inner`].
    ///
    /// # Errors
    ///
    /// * [`AhciError::DeviceError`] if the controller lacks 64-bit addressing support.
    /// * [`AhciError::MemoryError`] if a DMA allocation fails.
    /// * Any error returned while stopping or starting the port.
    pub fn create(
        regs: DeviceMemoryIo<'static, PortRegs>,
        ahci: Arc<AhciController>,
        index: usize,
    ) -> Result<Arc<Self>, AhciError> {
        log::debug!("Initialize port {}", index);
        regs.stop()?;
        if !ahci.has_64_bit {
            log::error!("Handle controllers incapable of 64 bit");
            return Err(AhciError::DeviceError);
        }
        let received_fis =
            DmaBuffer::new(&*ahci.dma, ReceivedFis::zeroed()).map_err(AhciError::MemoryError)?;
        let command_list =
            DmaBuffer::new_slice(&*ahci.dma, CommandListEntry::zeroed(), COMMAND_LIST_LENGTH)
                .map_err(AhciError::MemoryError)?;
        regs.set_received_fis_address_64(received_fis.bus_address());
        regs.set_command_list_address_64(command_list.bus_address());
        regs.IE.write(
            IE::DPE::SET
                + IE::IFE::SET
                + IE::OFE::SET
                + IE::HBDE::SET
                + IE::HBFE::SET
                + IE::TFEE::SET
                + IE::DHRE::SET,
        );
        regs.start()?;
        let inner = PortInner {
            regs,
            command_list,
            received_fis,
        };
        let command_completion =
            [const { (AtomicWaker::new(), AtomicU32::new(CMD_READY)) }; MAX_COMMANDS];
        let command_available = QueueWaker::new();
        let command_allocation = IrqSafeSpinlock::new(0);
        let port = Arc::new(Self {
            inner: IrqSafeSpinlock::new(inner),
            ty: PortType::Sata,
            info: OneTimeInit::new(),
            ahci,
            index,
            command_completion,
            command_allocation,
            command_available,
        });
        Ok(port)
    }

    /// Sends ATA IDENTIFY to the drive and stores its model, serial number and logical
    /// sector count in `info`.
    ///
    /// # Errors
    ///
    /// Any error from building, submitting or completing the IDENTIFY command.
    pub async fn init_inner(&self) -> Result<(), AhciError> {
        let identify = self
            .perform_command(AtaIdentify::create(&*self.ahci.dma)?)
            .await?;
        let model = identify.model_number.to_string();
        let serial = identify.serial_number.to_string();
        let lba_count = identify.logical_sector_count();
        // TODO can sector size be different from 512 in ATA?
        // should logical sector size be accounted for?
        // TODO test for ReadDmaEx capability (?)
        self.info.init(PortInfo {
            model,
            serial,
            lba_count,
        });
        Ok(())
    }

    /// Returns the drive identification, or `None` until `init_inner` has stored it.
    pub fn info(&self) -> Option<&PortInfo> {
        self.info.try_get()
    }

    /// Waits until one of the first `MAX_COMMANDS` command slots is free and claims it.
    ///
    /// The slot's completion status is not touched here; `submit` marks it pending once
    /// the command has actually been issued.
    async fn allocate_command(&self) -> usize {
        poll_fn(|cx| {
            // Register before inspecting the bitmap so a `free_command` racing with this
            // poll (it wakes after releasing the lock) cannot be missed.
            self.command_available.register(cx.waker());
            let mut state = self.command_allocation.lock();
            // Lowest clear bit = lowest free slot; 32 when every bit is set. Bounding by
            // `MAX_COMMANDS` keeps a full set of slots from being mistaken for a free one
            // when fewer than 32 slots exist.
            let slot = (!*state).trailing_zeros() as usize;
            if slot >= MAX_COMMANDS {
                return Poll::Pending;
            }
            self.command_available.remove(cx.waker());
            *state |= 1 << slot;
            Poll::Ready(slot)
        })
        .await
    }

    /// Allocates a slot and issues `command` in it.
    ///
    /// # Errors
    ///
    /// * [`AhciError::RegionTooLarge`] if the command needs more than two PRDs.
    /// * Any error from `PortInner::submit_command`; the slot is released in that case.
    async fn submit<C: AtaCommand>(&self, command: &C) -> Result<SubmittedCommand, AhciError> {
        if command.prd_count() > 2 {
            log::warn!("TODO: AHCI doesn't like 3+ PRD transfers");
            return Err(AhciError::RegionTooLarge);
        }
        let index = self.allocate_command().await;
        let result = {
            let mut inner = self.inner.lock();
            let result = inner.submit_command(&*self.ahci.dma, index, command);
            if result.is_ok() {
                // Marked pending only after the PxCI bit is set and while still holding
                // `inner`: the interrupt handler takes the same lock, so it can never see
                // this slot as "pending with CI clear" before the HBA has received it.
                self.command_completion[index]
                    .1
                    .store(CMD_PENDING, Ordering::Release);
            }
            result
        };
        // The `inner` guard is already released here
        if let Err(error) = result {
            self.free_command(index);
            return Err(error);
        }
        Ok(SubmittedCommand { port: self, index })
    }

    /// Issues `command`, waits for it to finish and converts it into its response.
    async fn perform_command<C: AtaCommand>(&self, command: C) -> Result<C::Response, AhciError> {
        // Run the command
        self.submit(&command).await?.wait_for_completion().await?;
        // SAFETY: NOTE(review) the command completed without error, which is presumably
        // the precondition of `into_response` — confirm against `AtaCommand`'s contract.
        Ok(unsafe { command.into_response() })
    }

    /// Polls the completion status of slot `index`.
    ///
    /// Returns `Ready(Ok(()))` for `CMD_READY`, `Ready(Err(DeviceError))` for any other
    /// non-pending status, and `Pending` otherwise. The status is re-read after registering
    /// the waker so a completion landing between the two reads is not lost.
    fn poll_slot(&self, cx: &mut Context<'_>, index: usize) -> Poll<Result<(), AhciError>> {
        let (waker, status) = &self.command_completion[index];
        match status.load(Ordering::Acquire) {
            CMD_PENDING => (),
            CMD_READY => return Poll::Ready(Ok(())),
            _ => return Poll::Ready(Err(AhciError::DeviceError)),
        }
        waker.register(cx.waker());
        match status.load(Ordering::Acquire) {
            CMD_PENDING => Poll::Pending,
            CMD_READY => Poll::Ready(Ok(())),
            _ => Poll::Ready(Err(AhciError::DeviceError)),
        }
    }

    /// Releases slot `index` and wakes one task waiting in `allocate_command`.
    ///
    /// # Panics
    ///
    /// If slot `index` is not currently allocated.
    fn free_command(&self, index: usize) {
        {
            let mut alloc = self.command_allocation.lock();
            assert_ne!(*alloc & (1 << index), 0);
            *alloc &= !(1 << index);
        }
        self.command_available.wake_one();
    }

    /// Services this port's interrupt.
    ///
    /// Returns `false` when `clear_interrupt` reports nothing for this port. Otherwise each
    /// slot that is still pending and whose PxCI bit is clear receives the status returned
    /// by `clear_interrupt`, and its waiting task is woken.
    pub fn handle_pending_interrupts(&self) -> bool {
        let inner = self.inner.lock();
        let Some(status) = inner.regs.clear_interrupt() else {
            return false;
        };
        let ci = inner.regs.CI.get();
        for i in 0..MAX_COMMANDS {
            if ci & (1 << i) != 0 {
                // Still being processed by the HBA
                continue;
            }
            let (waker, slot_status) = &self.command_completion[i];
            // Only pending slots receive a result: a slot that already got its status but
            // whose task has not polled it yet must not have it overwritten by a later
            // interrupt (e.g. turning a success into another slot's error).
            if slot_status
                .compare_exchange(CMD_PENDING, status, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                log::trace!(target: "io", "port{}: completion on slot {}", self.index, i);
                waker.wake();
            }
        }
        true
    }
}
#[async_trait]
impl BlockDevice for AhciPort {
    /// Allocates an uninitialized DMA buffer of `size` bytes from the controller's DMA
    /// allocator.
    fn allocate_buffer(&self, size: usize) -> Result<DmaBuffer<[MaybeUninit<u8>]>, Error> {
        DmaBuffer::new_uninit_slice(&*self.ahci.dma, size)
    }

    /// Reads `buffer.len()` bytes starting at byte offset `position` using a single
    /// `AtaReadDmaEx` command. An empty buffer is a no-op.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidOperation`] if `position` or `buffer.len()` is not a multiple of
    ///   [`SECTOR_SIZE`], or if the range extends past the last sector of the medium.
    /// * Any error from submitting the command or reported on its completion.
    async fn read_aligned(
        &self,
        position: u64,
        buffer: DmaSliceMut<'_, MaybeUninit<u8>>,
    ) -> Result<(), Error> {
        if buffer.len() % SECTOR_SIZE != 0 {
            log::warn!("ahci: misaligned buffer size: {}", buffer.len());
            return Err(Error::InvalidOperation);
        }
        if position % SECTOR_SIZE as u64 != 0 {
            log::warn!("ahci: misaligned read");
            return Err(Error::InvalidOperation);
        }
        let lba = position / SECTOR_SIZE as u64;
        let lba_count = buffer.len() / SECTOR_SIZE;
        if lba_count == 0 {
            // Nothing to transfer; don't issue a zero-count command (how `AtaReadDmaEx`
            // encodes a count of 0 is not visible here, and in ATA a sector count of 0 can
            // mean 65536 sectors).
            return Ok(());
        }
        // Sectors are `0..block_count()`, so the request may end exactly at
        // `block_count()`; only going beyond it is an error.
        if lba + lba_count as u64 > self.block_count() {
            log::warn!("ahci: read crosses medium end");
            return Err(Error::InvalidOperation);
        }
        let command = AtaReadDmaEx::new(lba, lba_count, buffer);
        self.submit(&command).await?.wait_for_completion().await?;
        Ok(())
    }

    /// Writing is not supported yet; always fails with [`Error::NotImplemented`].
    async fn write_aligned(&self, _position: u64, _buffer: DmaSlice<'_, u8>) -> Result<(), Error> {
        // TODO AtaWriteDmaEx
        Err(Error::NotImplemented)
    }

    /// Sector size in bytes ([`SECTOR_SIZE`]).
    fn block_size(&self) -> usize {
        SECTOR_SIZE
    }

    /// Number of logical sectors reported by the drive's IDENTIFY data.
    ///
    /// # Panics
    ///
    /// If called before `init_inner` has stored the identification data.
    fn block_count(&self) -> u64 {
        self.info()
            .map(|info| info.lba_count)
            .expect("ahci: block_count() called before the port was identified")
    }

    /// `submit` rejects commands needing more than two PRDs, so one request covers at most
    /// two full PRD regions.
    fn max_blocks_per_request(&self) -> usize {
        (MAX_PRD_SIZE * 2) / SECTOR_SIZE
    }
}
impl Device for AhciPort {
    /// Human-readable device name; the same string for every AHCI port.
    fn display_name(&self) -> &str {
        const NAME: &str = "AHCI SATA Drive";
        NAME
    }
}
impl PageProvider for AhciPort {
fn get_page(&self, _offset: u64) -> Result<PhysicalAddress, Error> {
todo!()
}
fn clone_page(
&self,
_offset: u64,
_src_phys: PhysicalAddress,
_src_attrs: MapAttributes,
) -> Result<PhysicalAddress, Error> {
todo!()
}
fn release_page(&self, _offset: u64, _phys: PhysicalAddress) -> Result<(), Error> {
todo!()
}
}