385 lines
11 KiB
Rust
385 lines
11 KiB
Rust
use core::{
|
|
future::poll_fn,
|
|
mem::MaybeUninit,
|
|
sync::atomic::{AtomicU32, Ordering},
|
|
task::{Context, Poll},
|
|
};
|
|
|
|
use alloc::{boxed::Box, string::String, sync::Arc};
|
|
use async_trait::async_trait;
|
|
use bytemuck::Zeroable;
|
|
use device_api::{device::Device, dma::DmaAllocator};
|
|
use futures_util::task::AtomicWaker;
|
|
use libk::{
|
|
device::block::BlockDevice,
|
|
dma::{DmaBuffer, DmaSlice, DmaSliceMut},
|
|
error::Error,
|
|
};
|
|
use libk_mm::{
|
|
address::PhysicalAddress, device::DeviceMemoryIo, table::MapAttributes, PageProvider,
|
|
};
|
|
use libk_util::{sync::IrqSafeSpinlock, waker::QueueWaker, OneTimeInit};
|
|
use tock_registers::interfaces::{Readable, Writeable};
|
|
|
|
use crate::{
|
|
command::{AtaCommand, AtaIdentify, AtaReadDmaEx},
|
|
data::{CommandListEntry, CommandTable, ReceivedFis, COMMAND_LIST_LENGTH},
|
|
error::AhciError,
|
|
regs::{PortRegs, CMD_PENDING, CMD_READY, IE, TFD},
|
|
AhciController, MAX_COMMANDS, MAX_PRD_SIZE, SECTOR_SIZE,
|
|
};
|
|
|
|
/// Kind of device attached to an AHCI port.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum PortType {
    /// A SATA drive (the only kind this file ever constructs).
    Sata,
}
|
|
|
|
struct PortInner {
|
|
regs: DeviceMemoryIo<'static, PortRegs>,
|
|
|
|
#[allow(unused)]
|
|
received_fis: DmaBuffer<ReceivedFis>,
|
|
command_list: DmaBuffer<[CommandListEntry]>,
|
|
}
|
|
|
|
/// Drive identification data, filled in by `AhciPort::init_inner` from the IDENTIFY response.
pub struct PortInfo {
    /// Model number string reported by the drive.
    pub model: String,
    /// Serial number string reported by the drive.
    pub serial: String,
    /// Number of logical sectors reported by the drive.
    pub lba_count: u64,
}
|
|
|
|
#[allow(unused)]
|
|
pub struct AhciPort {
|
|
inner: IrqSafeSpinlock<PortInner>,
|
|
ahci: Arc<AhciController>,
|
|
ty: PortType,
|
|
pub(crate) index: usize,
|
|
info: OneTimeInit<PortInfo>,
|
|
|
|
command_allocation: IrqSafeSpinlock<u32>,
|
|
// One command index can only be waited for by one task, so this approach is usable
|
|
command_completion: [(AtomicWaker, AtomicU32); COMMAND_LIST_LENGTH],
|
|
command_available: QueueWaker,
|
|
}
|
|
|
|
struct SubmittedCommand<'a> {
|
|
port: &'a AhciPort,
|
|
index: usize,
|
|
}
|
|
|
|
impl SubmittedCommand<'_> {
|
|
pub async fn wait_for_completion(self) -> Result<(), AhciError> {
|
|
let result = poll_fn(|cx| self.port.poll_slot(cx, self.index)).await;
|
|
|
|
// Free the command without dropping it
|
|
self.port.free_command(self.index);
|
|
core::mem::forget(self);
|
|
|
|
result
|
|
}
|
|
}
|
|
|
|
impl Drop for SubmittedCommand<'_> {
|
|
fn drop(&mut self) {
|
|
panic!(
|
|
"Cannot drop command in flight: port{}, slot{}",
|
|
self.port.index, self.index
|
|
)
|
|
}
|
|
}
|
|
|
|
impl PortInner {
|
|
fn submit_command<C: AtaCommand>(
|
|
&mut self,
|
|
dma: &dyn DmaAllocator,
|
|
index: usize,
|
|
command: &C,
|
|
) -> Result<(), AhciError> {
|
|
let list_entry = &mut self.command_list[index];
|
|
let mut table_entry =
|
|
DmaBuffer::new(dma, CommandTable::zeroed()).map_err(AhciError::MemoryError)?;
|
|
|
|
table_entry.setup_command(command)?;
|
|
*list_entry = CommandListEntry::new(table_entry.bus_address(), command.prd_count())?;
|
|
|
|
// Sync before send
|
|
// XXX do this properly
|
|
#[cfg(target_arch = "x86_64")]
|
|
unsafe {
|
|
core::arch::asm!("wbinvd");
|
|
}
|
|
|
|
// TODO deal with this async way
|
|
while self.regs.TFD.matches_any(&[TFD::BSY::SET, TFD::DRQ::SET]) {
|
|
core::hint::spin_loop();
|
|
}
|
|
|
|
let ci = self.regs.CI.get();
|
|
assert_eq!(ci & (1 << index), 0);
|
|
self.regs.CI.set(ci | (1 << index));
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl AhciPort {
|
|
pub fn create(
|
|
regs: DeviceMemoryIo<'static, PortRegs>,
|
|
ahci: Arc<AhciController>,
|
|
index: usize,
|
|
) -> Result<Arc<Self>, AhciError> {
|
|
log::debug!("Initialize port {}", index);
|
|
regs.stop()?;
|
|
|
|
if !ahci.has_64_bit {
|
|
log::error!("Handle controllers incapable of 64 bit");
|
|
return Err(AhciError::DeviceError);
|
|
}
|
|
|
|
let received_fis =
|
|
DmaBuffer::new(&*ahci.dma, ReceivedFis::zeroed()).map_err(AhciError::MemoryError)?;
|
|
let command_list =
|
|
DmaBuffer::new_slice(&*ahci.dma, CommandListEntry::zeroed(), COMMAND_LIST_LENGTH)
|
|
.map_err(AhciError::MemoryError)?;
|
|
|
|
regs.set_received_fis_address_64(received_fis.bus_address());
|
|
regs.set_command_list_address_64(command_list.bus_address());
|
|
|
|
regs.IE.write(
|
|
IE::DPE::SET
|
|
+ IE::IFE::SET
|
|
+ IE::OFE::SET
|
|
+ IE::HBDE::SET
|
|
+ IE::HBFE::SET
|
|
+ IE::TFEE::SET
|
|
+ IE::DHRE::SET,
|
|
);
|
|
|
|
regs.start()?;
|
|
|
|
let inner = PortInner {
|
|
regs,
|
|
command_list,
|
|
received_fis,
|
|
};
|
|
let command_completion =
|
|
[const { (AtomicWaker::new(), AtomicU32::new(CMD_READY)) }; MAX_COMMANDS];
|
|
let command_available = QueueWaker::new();
|
|
let command_allocation = IrqSafeSpinlock::new(0);
|
|
|
|
let port = Arc::new(Self {
|
|
inner: IrqSafeSpinlock::new(inner),
|
|
ty: PortType::Sata,
|
|
info: OneTimeInit::new(),
|
|
ahci,
|
|
index,
|
|
|
|
command_completion,
|
|
command_allocation,
|
|
command_available,
|
|
});
|
|
|
|
Ok(port)
|
|
}
|
|
|
|
pub async fn init_inner(&self) -> Result<(), AhciError> {
|
|
let identify = self
|
|
.perform_command(AtaIdentify::create(&*self.ahci.dma)?)
|
|
.await?;
|
|
|
|
let model = identify.model_number.to_string();
|
|
let serial = identify.serial_number.to_string();
|
|
let lba_count = identify.logical_sector_count();
|
|
|
|
// TODO can sector size be different from 512 in ATA?
|
|
// should logical sector size be accounted for?
|
|
// TODO test for ReadDmaEx capability (?)
|
|
|
|
self.info.init(PortInfo {
|
|
model,
|
|
serial,
|
|
lba_count,
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Returns the drive information stored by `init_inner`, or `None` when
/// `OneTimeInit::try_get` has nothing to hand out (presumably: not initialized yet).
pub fn info(&self) -> Option<&PortInfo> {
    self.info.try_get()
}
|
|
|
|
async fn allocate_command(&self) -> usize {
|
|
poll_fn(|cx| {
|
|
self.command_available.register(cx.waker());
|
|
|
|
let mut state = self.command_allocation.lock();
|
|
if *state != u32::MAX {
|
|
self.command_available.remove(cx.waker());
|
|
|
|
for i in 0..MAX_COMMANDS {
|
|
if *state & (1 << i) == 0 {
|
|
*state |= 1 << i;
|
|
self.command_completion[i]
|
|
.1
|
|
.store(CMD_PENDING, Ordering::Release);
|
|
return Poll::Ready(i);
|
|
}
|
|
}
|
|
|
|
unreachable!()
|
|
} else {
|
|
Poll::Pending
|
|
}
|
|
})
|
|
.await
|
|
}
|
|
|
|
async fn submit<C: AtaCommand>(&self, command: &C) -> Result<SubmittedCommand, AhciError> {
|
|
if command.prd_count() > 2 {
|
|
log::warn!("TODO: AHCI doesn't like 3+ PRD transfers");
|
|
return Err(AhciError::RegionTooLarge);
|
|
}
|
|
let index = self.allocate_command().await;
|
|
if let Err(error) = self
|
|
.inner
|
|
.lock()
|
|
.submit_command(&*self.ahci.dma, index, command)
|
|
{
|
|
self.free_command(index);
|
|
return Err(error);
|
|
}
|
|
Ok(SubmittedCommand { port: self, index })
|
|
}
|
|
|
|
/// Issues `command`, waits for it to finish and converts it into its response.
///
/// # Errors
/// Forwards errors from `submit` and from `wait_for_completion`.
async fn perform_command<C: AtaCommand>(&self, command: C) -> Result<C::Response, AhciError> {
    // Run the command
    let in_flight = self.submit(&command).await?;
    in_flight.wait_for_completion().await?;

    // SAFETY: only reached after the command was reported as completed without error.
    // NOTE(review): confirm this is the full contract of `AtaCommand::into_response`.
    let response = unsafe { command.into_response() };
    Ok(response)
}
|
|
|
|
fn poll_slot(&self, cx: &mut Context<'_>, index: usize) -> Poll<Result<(), AhciError>> {
|
|
let (waker, status) = &self.command_completion[index];
|
|
|
|
match status.load(Ordering::Acquire) {
|
|
CMD_PENDING => (),
|
|
CMD_READY => return Poll::Ready(Ok(())),
|
|
_ => return Poll::Ready(Err(AhciError::DeviceError)),
|
|
}
|
|
|
|
waker.register(cx.waker());
|
|
|
|
match status.load(Ordering::Acquire) {
|
|
CMD_PENDING => Poll::Pending,
|
|
CMD_READY => Poll::Ready(Ok(())),
|
|
_ => Poll::Ready(Err(AhciError::DeviceError)),
|
|
}
|
|
}
|
|
|
|
fn free_command(&self, index: usize) {
|
|
{
|
|
let mut alloc = self.command_allocation.lock();
|
|
assert_ne!(*alloc & (1 << index), 0);
|
|
*alloc &= !(1 << index);
|
|
}
|
|
self.command_available.wake_one();
|
|
}
|
|
|
|
pub fn handle_pending_interrupts(&self) -> bool {
|
|
let inner = self.inner.lock();
|
|
|
|
let Some(status) = inner.regs.clear_interrupt() else {
|
|
return false;
|
|
};
|
|
|
|
let ci = inner.regs.CI.get();
|
|
|
|
for i in 0..MAX_COMMANDS {
|
|
if ci & (1 << i) == 0
|
|
&& self.command_completion[i].1.swap(status, Ordering::Release) == CMD_PENDING
|
|
{
|
|
log::trace!(target: "io", "port{}: completion on slot {}", self.index, i);
|
|
self.command_completion[i].0.wake();
|
|
}
|
|
}
|
|
|
|
true
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl BlockDevice for AhciPort {
|
|
fn allocate_buffer(&self, size: usize) -> Result<DmaBuffer<[MaybeUninit<u8>]>, Error> {
|
|
DmaBuffer::new_uninit_slice(&*self.ahci.dma, size)
|
|
}
|
|
|
|
async fn read_aligned(
|
|
&self,
|
|
position: u64,
|
|
buffer: DmaSliceMut<'_, MaybeUninit<u8>>,
|
|
) -> Result<(), Error> {
|
|
if buffer.len() % SECTOR_SIZE != 0 {
|
|
log::warn!("ahci: misaligned buffer size: {}", buffer.len());
|
|
return Err(Error::InvalidOperation);
|
|
}
|
|
if position % SECTOR_SIZE as u64 != 0 {
|
|
log::warn!("ahci: misaligned read");
|
|
return Err(Error::InvalidOperation);
|
|
}
|
|
|
|
let lba = position / SECTOR_SIZE as u64;
|
|
let lba_count = buffer.len() / SECTOR_SIZE;
|
|
if lba + lba_count as u64 >= self.block_count() {
|
|
log::warn!("ahci: read crosses medium end");
|
|
return Err(Error::InvalidOperation);
|
|
}
|
|
|
|
let command = AtaReadDmaEx::new(lba, lba_count, buffer);
|
|
self.submit(&command).await?.wait_for_completion().await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn write_aligned(&self, _position: u64, _buffer: DmaSlice<'_, u8>) -> Result<(), Error> {
|
|
// TODO AtaWriteDmaEx
|
|
Err(Error::NotImplemented)
|
|
}
|
|
|
|
fn block_size(&self) -> usize {
|
|
SECTOR_SIZE
|
|
}
|
|
|
|
fn block_count(&self) -> u64 {
|
|
self.info().as_ref().map(|i| i.lba_count).unwrap() as _
|
|
}
|
|
|
|
fn max_blocks_per_request(&self) -> usize {
|
|
(MAX_PRD_SIZE * 2) / SECTOR_SIZE
|
|
}
|
|
}
|
|
|
|
impl Device for AhciPort {
    /// Fixed human-readable name shared by every AHCI port device.
    fn display_name(&self) -> &str {
        "AHCI SATA Drive"
    }
}
|
|
|
|
impl PageProvider for AhciPort {
|
|
fn get_page(&self, _offset: u64) -> Result<PhysicalAddress, Error> {
|
|
todo!()
|
|
}
|
|
|
|
fn clone_page(
|
|
&self,
|
|
_offset: u64,
|
|
_src_phys: PhysicalAddress,
|
|
_src_attrs: MapAttributes,
|
|
) -> Result<PhysicalAddress, Error> {
|
|
todo!()
|
|
}
|
|
|
|
fn release_page(&self, _offset: u64, _phys: PhysicalAddress) -> Result<(), Error> {
|
|
todo!()
|
|
}
|
|
}
|