Compare commits

...

8 Commits

74 changed files with 2199 additions and 814 deletions
+58
View File
@@ -0,0 +1,58 @@
use core::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct UsbRoute {
bus: u16,
ports: [u8; 8],
len: u8,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct UsbBusAddress {
pub bus: u16,
pub device: u8,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct UsbInterfaceAddress {
pub device: UsbBusAddress,
pub interface: u8,
}
impl UsbBusAddress {
pub fn with_interface(self, interface: u8) -> UsbInterfaceAddress {
UsbInterfaceAddress {
device: self,
interface,
}
}
}
impl fmt::Display for UsbBusAddress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "<Bus {} Device {}>", self.bus, self.device)
}
}
impl fmt::Display for UsbInterfaceAddress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"<Bus {} Device {} Interface {}>",
self.device.bus, self.device.device, self.interface
)
}
}
impl fmt::Display for UsbRoute {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}-", self.bus)?;
for (i, &port) in self.ports[..self.len as usize].iter().enumerate() {
if i != 0 {
write!(f, ".")?;
}
write!(f, "{port}")?;
}
Ok(())
}
}
+34 -8
View File
@@ -1,34 +1,56 @@
use core::sync::atomic::{AtomicU16, Ordering};
use alloc::{collections::BTreeMap, sync::Arc};
use libk_util::{queue::UnboundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock};
use alloc::{collections::BTreeMap, format, sync::Arc};
use libk_util::{queue::UnboundedMpmcQueue, sync::spin_rwlock::IrqSafeRwLock, OneTimeInit};
use crate::{
address::UsbBusAddress,
class_driver,
device::{UsbBusAddress, UsbDeviceAccess},
device::{UsbDevice, UsbDeviceAccess},
sysfs::{self, UsbBusKObject},
UsbHostController,
};
pub struct UsbBusWrapper {
pub(crate) hc: Arc<dyn UsbHostController>,
pub(crate) index: u16,
kobject: OneTimeInit<UsbBusKObject>,
}
pub struct UsbBusManager {
busses: IrqSafeRwLock<BTreeMap<u16, Arc<dyn UsbHostController>>>,
busses: IrqSafeRwLock<BTreeMap<u16, Arc<UsbBusWrapper>>>,
devices: IrqSafeRwLock<BTreeMap<UsbBusAddress, Arc<UsbDeviceAccess>>>,
last_bus_address: AtomicU16,
}
impl UsbBusWrapper {
pub fn kobject(&self) -> &UsbBusKObject {
self.kobject.get()
}
}
impl UsbBusManager {
pub fn register_bus(hc: Arc<dyn UsbHostController>) -> u16 {
pub fn register_bus(hc: Arc<dyn UsbHostController>) -> (u16, Arc<UsbBusWrapper>) {
let i = BUS_MANAGER.last_bus_address.fetch_add(1, Ordering::AcqRel);
BUS_MANAGER.busses.write().insert(i, hc);
i
let wrapper = Arc::new(UsbBusWrapper {
hc,
index: i,
kobject: OneTimeInit::new(),
});
BUS_MANAGER.busses.write().insert(i, wrapper.clone());
wrapper.kobject.init(sysfs::register_bus_kobject(&wrapper));
(i, wrapper)
}
pub fn register_device(device: Arc<UsbDeviceAccess>) {
log::info!("usb: register device {}", device.bus_address());
BUS_MANAGER
.devices
.write()
.insert(device.bus_address(), device.clone());
device.kobject.init(sysfs::register_device_kobject(&device));
QUEUE.push_back(device);
}
@@ -51,7 +73,11 @@ pub async fn bus_handler() {
new_device.bus_address()
);
class_driver::spawn_driver(new_device).await.ok();
let address = new_device.bus_address();
if let Err(error) = class_driver::setup_device(new_device).await {
log::warn!("USB device {address} setup error: {error:?}",);
}
// class_driver::spawn_driver(new_device).await.ok();
}
}
@@ -4,9 +4,12 @@ use alloc::{boxed::Box, sync::Arc};
use async_trait::async_trait;
use yggdrasil_abi::io::{KeyboardKey, KeyboardKeyEvent};
use crate::{device::UsbDeviceAccess, error::UsbError, info::UsbDeviceClass};
use super::{UsbClassInfo, UsbDriver};
use crate::{
class_driver::{UsbInterfaceClass, UsbInterfaceDriver},
device::UsbDeviceAccess,
error::UsbError,
info::UsbInterfaceInfo,
};
pub struct UsbHidKeyboardDriver;
@@ -125,10 +128,14 @@ impl KeyboardState {
}
#[async_trait]
impl UsbDriver for UsbHidKeyboardDriver {
async fn run(self: Arc<Self>, device: Arc<UsbDeviceAccess>) -> Result<(), UsbError> {
impl UsbInterfaceDriver for UsbHidKeyboardDriver {
async fn run(
self: Arc<Self>,
device: Arc<UsbDeviceAccess>,
interface: UsbInterfaceInfo,
) -> Result<(), UsbError> {
// TODO not sure whether to use boot protocol (easy) or GetReport
let config = device.select_configuration(|_| true).await?.unwrap();
let config = device.current_configuration().unwrap();
log::info!("Setup HID keyboard");
let pipe = device
@@ -156,7 +163,7 @@ impl UsbDriver for UsbHidKeyboardDriver {
for &event in events {
log::trace!("Generic Keyboard: {:?}", event);
ygg_driver_input::send_event(event);
ygg_driver_input::send_keyboard_event(event);
}
}
}
@@ -165,12 +172,18 @@ impl UsbDriver for UsbHidKeyboardDriver {
"USB HID Keyboard"
}
fn probe(&self, class: &UsbClassInfo, _device: &UsbDeviceAccess) -> bool {
log::info!(
"class = {:?}, subclass = {:02x}",
class.class,
class.subclass
);
class.class == UsbDeviceClass::Hid && (class.subclass == 0x00 || class.subclass == 0x01)
fn probe(
&self,
device: &UsbDeviceAccess,
interface: &UsbInterfaceInfo,
class: UsbInterfaceClass,
) -> bool {
let _ = (device, interface);
class
== UsbInterfaceClass {
class: 3,
subclass: 1,
protocol: 1,
}
}
}
@@ -0,0 +1,64 @@
use alloc::{boxed::Box, sync::Arc};
use async_trait::async_trait;
use yggdrasil_abi::io::{ButtonMask, MouseEvent};
use crate::{
class_driver::{UsbInterfaceClass, UsbInterfaceDriver},
device::UsbDeviceAccess,
error::UsbError,
info::UsbInterfaceInfo,
};
pub struct UsbHidMouseDriver;
#[async_trait]
impl UsbInterfaceDriver for UsbHidMouseDriver {
async fn run(
self: Arc<Self>,
device: Arc<UsbDeviceAccess>,
interface: UsbInterfaceInfo,
) -> Result<(), UsbError> {
let config = device.current_configuration().unwrap();
log::info!("Setup HID mouse");
let pipe = device
.open_interrupt_in_pipe(1, config.endpoints[0].max_packet_size as u16)
.await?;
let mut buffer = [0; 16];
let mut button_state = 0;
loop {
let len = pipe.read(&mut buffer).await?;
if len < 4 {
continue;
}
let event = MouseEvent {
buttons: ButtonMask(buffer[0]),
dx: (buffer[1] as i8) as i32,
dy: (buffer[2] as i8) as i32,
};
ygg_driver_input::send_mouse_event(event);
}
}
fn name(&self) -> &'static str {
"USB HID Mouse"
}
fn probe(
&self,
device: &UsbDeviceAccess,
interface: &UsbInterfaceInfo,
class: UsbInterfaceClass,
) -> bool {
class
== UsbInterfaceClass {
class: 3,
subclass: 1,
protocol: 2,
}
}
}
@@ -1,273 +1,273 @@
use core::mem::MaybeUninit;
use alloc::{boxed::Box, sync::Arc};
use async_trait::async_trait;
use bytemuck::{Pod, Zeroable};
use libk::{
dma::{DmaBuffer, DmaSliceMut},
error::Error,
};
use ygg_driver_scsi::{transport::ScsiTransport, ScsiEnclosure};
use crate::{
communication::UsbDirection,
device::{UsbDeviceAccess, UsbDeviceDetachHandler},
error::UsbError,
info::{UsbDeviceClass, UsbEndpointType},
pipe::{
control::{ControlTransferSetup, UsbClassSpecificRequest},
normal::{UsbBulkInPipeAccess, UsbBulkOutPipeAccess},
},
};
use super::{UsbClassInfo, UsbDriver};
pub struct UsbMassStorageDriverBulkOnly;
#[derive(Debug, Clone, Copy, Zeroable, Pod)]
#[repr(C)]
struct Cbw {
signature: u32, // 0x00
tag: u32, // 0x04
transfer_length: u32, // 0x08
flags: u8, // 0x0C
lun: u8, // 0x0D
cb_length: u8, // 0x0E
cb_data: [u8; 16], // 0x0F
// Not sent
_0: u8,
}
#[derive(Debug, Clone, Copy, Zeroable, Pod)]
#[repr(C)]
struct Csw {
signature: u32,
tag: u32,
data_residue: u32,
status: u8,
_0: [u8; 3],
}
struct Bbb {
#[allow(unused)]
device: Arc<UsbDeviceAccess>,
in_pipe: UsbBulkInPipeAccess,
out_pipe: UsbBulkOutPipeAccess,
last_tag: u32,
}
struct DetachHandler(Arc<ScsiEnclosure>);
impl Bbb {
pub fn new(
device: Arc<UsbDeviceAccess>,
in_pipe: UsbBulkInPipeAccess,
out_pipe: UsbBulkOutPipeAccess,
) -> Result<Self, UsbError> {
Ok(Self {
device,
in_pipe,
out_pipe,
last_tag: 0,
})
}
}
impl Bbb {
async fn send_cbw(
&mut self,
lun: u8,
host_to_dev: bool,
command: &[u8],
response_len: usize,
) -> Result<u32, Error> {
self.last_tag = self.last_tag.wrapping_add(1);
let flags = if !host_to_dev { 1 << 7 } else { 0 };
let tag = self.last_tag;
let mut cbw_bytes = [0; 32];
let cbw = bytemuck::from_bytes_mut::<Cbw>(&mut cbw_bytes);
cbw.signature = 0x43425355;
cbw.transfer_length = response_len as u32;
cbw.flags = flags;
cbw.tag = tag;
cbw.lun = lun;
cbw.cb_length = command.len() as u8;
cbw.cb_data[..command.len()].copy_from_slice(command);
self.out_pipe
.write(&cbw_bytes[..31])
.await
.inspect_err(|error| log::error!("msc: CBW send error: {error:?}"))?;
Ok(tag)
}
async fn read_csw(&mut self, tag: u32) -> Result<(), Error> {
let mut csw_bytes = [0; 16];
self.in_pipe
.read_exact(&mut csw_bytes[..13])
.await
.inspect_err(|error| log::error!("msc: CSW receive error: {error:?}"))?;
let csw = bytemuck::from_bytes::<Csw>(&csw_bytes);
if csw.signature != 0x53425355 {
log::warn!("msc: invalid csw signature");
return Err(Error::InvalidArgument);
}
if csw.tag != tag {
let csw_tag = csw.tag;
log::warn!("msc: invalid csw tag (got {}, expected {tag})", csw_tag);
return Err(Error::InvalidArgument);
}
if csw.status != 0x00 {
return Err(Error::InvalidArgument);
}
Ok(())
}
async fn read_response_data(
&mut self,
buffer: DmaSliceMut<'_, MaybeUninit<u8>>,
) -> Result<usize, Error> {
if buffer.len() == 0 {
return Ok(0);
}
let len = self
.in_pipe
.read_dma(buffer)
.await
.inspect_err(|error| log::error!("msc: DMA read error: {error:?}"))?;
Ok(len)
}
}
#[async_trait]
impl ScsiTransport for Bbb {
fn allocate_buffer(&self, size: usize) -> Result<DmaBuffer<[MaybeUninit<u8>]>, Error> {
Ok(self.in_pipe.allocate_dma_buffer(size)?)
}
async fn perform_request_raw(
&mut self,
lun: u8,
request_data: &[u8],
response_buffer: DmaSliceMut<'_, MaybeUninit<u8>>,
) -> Result<usize, Error> {
if request_data.len() > 16 || response_buffer.len() > self.max_bytes_per_request() {
return Err(Error::InvalidArgument);
}
let tag = self
.send_cbw(lun, false, request_data, response_buffer.len())
.await?;
let response_len = self.read_response_data(response_buffer).await?;
self.read_csw(tag).await?;
Ok(response_len)
}
fn max_bytes_per_request(&self) -> usize {
32768
}
}
impl UsbDeviceDetachHandler for DetachHandler {
fn handle_device_detach(&self) {
log::info!("Mass storage detached");
self.0.detach();
}
}
#[derive(Debug, Pod, Zeroable, Clone, Copy)]
#[repr(C)]
pub struct BulkOnlyMassStorageReset;
#[derive(Debug, Pod, Zeroable, Clone, Copy)]
#[repr(C)]
pub struct GetMaxLun;
impl UsbClassSpecificRequest for BulkOnlyMassStorageReset {
const BM_REQUEST_TYPE: u8 = 0b00100001;
const B_REQUEST: u8 = 0b11111111;
}
impl UsbClassSpecificRequest for GetMaxLun {
const BM_REQUEST_TYPE: u8 = 0b10100001;
const B_REQUEST: u8 = 0b11111110;
}
#[async_trait]
impl UsbDriver for UsbMassStorageDriverBulkOnly {
async fn run(self: Arc<Self>, device: Arc<UsbDeviceAccess>) -> Result<(), UsbError> {
// TODO filter to only accept BBB config
let config = device.select_configuration(|_| true).await?.unwrap();
// Bulk-in, bulk-out
assert_eq!(config.endpoints.len(), 2);
let control_pipe = device.control_pipe();
let (in_index, in_info) = config
.find_endpoint(|ep| ep.is(UsbEndpointType::Bulk, UsbDirection::In))
.ok_or(UsbError::InvalidConfiguration)?;
let (out_index, out_info) = config
.find_endpoint(|ep| ep.is(UsbEndpointType::Bulk, UsbDirection::Out))
.ok_or(UsbError::InvalidConfiguration)?;
let in_pipe = device
.open_bulk_in_pipe(in_index, in_info.max_packet_size as u16)
.await?;
let out_pipe = device
.open_bulk_out_pipe(out_index, out_info.max_packet_size as u16)
.await?;
// Perform a Bulk-Only Mass Storage Reset
// TODO interface id?
control_pipe
.control_transfer(ControlTransferSetup {
bm_request_type: BulkOnlyMassStorageReset::BM_REQUEST_TYPE,
b_request: BulkOnlyMassStorageReset::B_REQUEST,
w_value: 0,
w_index: 0,
w_length: 0,
})
.await?;
// Get max LUN
// TODO on devices which do not support multiple LUNs, this command may STALL
let mut buffer = [MaybeUninit::uninit()];
let len = control_pipe
.control_transfer_in(
ControlTransferSetup {
bm_request_type: GetMaxLun::BM_REQUEST_TYPE,
b_request: GetMaxLun::B_REQUEST,
w_value: 0,
w_index: 0,
w_length: 1,
},
&mut buffer,
)
.await?;
let max_lun = if len < 1 {
0
} else {
unsafe { buffer[0].assume_init() }
};
let bbb = Bbb::new(device.clone(), in_pipe, out_pipe)?;
let scsi = ScsiEnclosure::setup(Box::new(bbb), max_lun as usize + 1)
.await
.inspect_err(|error| log::error!("msc: scsi error {error:?}"))
.map_err(|_| UsbError::DriverError)?;
let detach = DetachHandler(scsi.clone());
device.set_detach_handler(Arc::new(detach));
Ok(())
}
fn name(&self) -> &'static str {
"USB Mass Storage"
}
fn probe(&self, class: &UsbClassInfo, _device: &UsbDeviceAccess) -> bool {
// TODO support other protocols
class.class == UsbDeviceClass::MassStorage && class.interface_protocol_number == 0x50
}
}
// use core::mem::MaybeUninit;
//
// use alloc::{boxed::Box, sync::Arc};
// use async_trait::async_trait;
// use bytemuck::{Pod, Zeroable};
// use libk::{
// dma::{DmaBuffer, DmaSliceMut},
// error::Error,
// };
// use ygg_driver_scsi::{transport::ScsiTransport, ScsiEnclosure};
//
// use crate::{
// communication::UsbDirection,
// device::{UsbDeviceAccess, UsbDeviceDetachHandler},
// error::UsbError,
// info::{UsbDeviceClass, UsbEndpointType},
// pipe::{
// control::{ControlTransferSetup, UsbClassSpecificRequest},
// normal::{UsbBulkInPipeAccess, UsbBulkOutPipeAccess},
// },
// };
//
// use super::{UsbClassInfo, UsbDriver};
//
// pub struct UsbMassStorageDriverBulkOnly;
//
// #[derive(Debug, Clone, Copy, Zeroable, Pod)]
// #[repr(C)]
// struct Cbw {
// signature: u32, // 0x00
// tag: u32, // 0x04
// transfer_length: u32, // 0x08
// flags: u8, // 0x0C
// lun: u8, // 0x0D
// cb_length: u8, // 0x0E
// cb_data: [u8; 16], // 0x0F
// // Not sent
// _0: u8,
// }
//
// #[derive(Debug, Clone, Copy, Zeroable, Pod)]
// #[repr(C)]
// struct Csw {
// signature: u32,
// tag: u32,
// data_residue: u32,
// status: u8,
// _0: [u8; 3],
// }
//
// struct Bbb {
// #[allow(unused)]
// device: Arc<UsbDeviceAccess>,
// in_pipe: UsbBulkInPipeAccess,
// out_pipe: UsbBulkOutPipeAccess,
// last_tag: u32,
// }
//
// struct DetachHandler(Arc<ScsiEnclosure>);
//
// impl Bbb {
// pub fn new(
// device: Arc<UsbDeviceAccess>,
// in_pipe: UsbBulkInPipeAccess,
// out_pipe: UsbBulkOutPipeAccess,
// ) -> Result<Self, UsbError> {
// Ok(Self {
// device,
// in_pipe,
// out_pipe,
// last_tag: 0,
// })
// }
// }
//
// impl Bbb {
// async fn send_cbw(
// &mut self,
// lun: u8,
// host_to_dev: bool,
// command: &[u8],
// response_len: usize,
// ) -> Result<u32, Error> {
// self.last_tag = self.last_tag.wrapping_add(1);
//
// let flags = if !host_to_dev { 1 << 7 } else { 0 };
// let tag = self.last_tag;
// let mut cbw_bytes = [0; 32];
// let cbw = bytemuck::from_bytes_mut::<Cbw>(&mut cbw_bytes);
//
// cbw.signature = 0x43425355;
// cbw.transfer_length = response_len as u32;
// cbw.flags = flags;
// cbw.tag = tag;
// cbw.lun = lun;
// cbw.cb_length = command.len() as u8;
// cbw.cb_data[..command.len()].copy_from_slice(command);
//
// self.out_pipe
// .write(&cbw_bytes[..31])
// .await
// .inspect_err(|error| log::error!("msc: CBW send error: {error:?}"))?;
//
// Ok(tag)
// }
//
// async fn read_csw(&mut self, tag: u32) -> Result<(), Error> {
// let mut csw_bytes = [0; 16];
// self.in_pipe
// .read_exact(&mut csw_bytes[..13])
// .await
// .inspect_err(|error| log::error!("msc: CSW receive error: {error:?}"))?;
// let csw = bytemuck::from_bytes::<Csw>(&csw_bytes);
//
// if csw.signature != 0x53425355 {
// log::warn!("msc: invalid csw signature");
// return Err(Error::InvalidArgument);
// }
// if csw.tag != tag {
// let csw_tag = csw.tag;
// log::warn!("msc: invalid csw tag (got {}, expected {tag})", csw_tag);
// return Err(Error::InvalidArgument);
// }
// if csw.status != 0x00 {
// return Err(Error::InvalidArgument);
// }
// Ok(())
// }
//
// async fn read_response_data(
// &mut self,
// buffer: DmaSliceMut<'_, MaybeUninit<u8>>,
// ) -> Result<usize, Error> {
// if buffer.len() == 0 {
// return Ok(0);
// }
// let len = self
// .in_pipe
// .read_dma(buffer)
// .await
// .inspect_err(|error| log::error!("msc: DMA read error: {error:?}"))?;
// Ok(len)
// }
// }
//
// #[async_trait]
// impl ScsiTransport for Bbb {
// fn allocate_buffer(&self, size: usize) -> Result<DmaBuffer<[MaybeUninit<u8>]>, Error> {
// Ok(self.in_pipe.allocate_dma_buffer(size)?)
// }
//
// async fn perform_request_raw(
// &mut self,
// lun: u8,
// request_data: &[u8],
// response_buffer: DmaSliceMut<'_, MaybeUninit<u8>>,
// ) -> Result<usize, Error> {
// if request_data.len() > 16 || response_buffer.len() > self.max_bytes_per_request() {
// return Err(Error::InvalidArgument);
// }
//
// let tag = self
// .send_cbw(lun, false, request_data, response_buffer.len())
// .await?;
// let response_len = self.read_response_data(response_buffer).await?;
// self.read_csw(tag).await?;
// Ok(response_len)
// }
//
// fn max_bytes_per_request(&self) -> usize {
// 32768
// }
// }
//
// impl UsbDeviceDetachHandler for DetachHandler {
// fn handle_device_detach(&self) {
// log::info!("Mass storage detached");
// self.0.detach();
// }
// }
//
// #[derive(Debug, Pod, Zeroable, Clone, Copy)]
// #[repr(C)]
// pub struct BulkOnlyMassStorageReset;
//
// #[derive(Debug, Pod, Zeroable, Clone, Copy)]
// #[repr(C)]
// pub struct GetMaxLun;
//
// impl UsbClassSpecificRequest for BulkOnlyMassStorageReset {
// const BM_REQUEST_TYPE: u8 = 0b00100001;
// const B_REQUEST: u8 = 0b11111111;
// }
//
// impl UsbClassSpecificRequest for GetMaxLun {
// const BM_REQUEST_TYPE: u8 = 0b10100001;
// const B_REQUEST: u8 = 0b11111110;
// }
//
// #[async_trait]
// impl UsbDriver for UsbMassStorageDriverBulkOnly {
// async fn run(self: Arc<Self>, device: Arc<UsbDeviceAccess>) -> Result<(), UsbError> {
// // TODO filter to only accept BBB config
// let config = device.select_configuration(|_| true).await?.unwrap();
// // Bulk-in, bulk-out
// assert_eq!(config.endpoints.len(), 2);
// let control_pipe = device.control_pipe();
// let (in_index, in_info) = config
// .find_endpoint(|ep| ep.is(UsbEndpointType::Bulk, UsbDirection::In))
// .ok_or(UsbError::InvalidConfiguration)?;
// let (out_index, out_info) = config
// .find_endpoint(|ep| ep.is(UsbEndpointType::Bulk, UsbDirection::Out))
// .ok_or(UsbError::InvalidConfiguration)?;
// let in_pipe = device
// .open_bulk_in_pipe(in_index, in_info.max_packet_size as u16)
// .await?;
// let out_pipe = device
// .open_bulk_out_pipe(out_index, out_info.max_packet_size as u16)
// .await?;
//
// // Perform a Bulk-Only Mass Storage Reset
// // TODO interface id?
// control_pipe
// .control_transfer(ControlTransferSetup {
// bm_request_type: BulkOnlyMassStorageReset::BM_REQUEST_TYPE,
// b_request: BulkOnlyMassStorageReset::B_REQUEST,
// w_value: 0,
// w_index: 0,
// w_length: 0,
// })
// .await?;
//
// // Get max LUN
// // TODO on devices which do not support multiple LUNs, this command may STALL
// let mut buffer = [MaybeUninit::uninit()];
// let len = control_pipe
// .control_transfer_in(
// ControlTransferSetup {
// bm_request_type: GetMaxLun::BM_REQUEST_TYPE,
// b_request: GetMaxLun::B_REQUEST,
// w_value: 0,
// w_index: 0,
// w_length: 1,
// },
// &mut buffer,
// )
// .await?;
// let max_lun = if len < 1 {
// 0
// } else {
// unsafe { buffer[0].assume_init() }
// };
//
// let bbb = Bbb::new(device.clone(), in_pipe, out_pipe)?;
// let scsi = ScsiEnclosure::setup(Box::new(bbb), max_lun as usize + 1)
// .await
// .inspect_err(|error| log::error!("msc: scsi error {error:?}"))
// .map_err(|_| UsbError::DriverError)?;
// let detach = DetachHandler(scsi.clone());
// device.set_detach_handler(Arc::new(detach));
//
// Ok(())
// }
//
// fn name(&self) -> &'static str {
// "USB Mass Storage"
// }
//
// fn probe(&self, class: &UsbClassInfo, _device: &UsbDeviceAccess) -> bool {
// // TODO support other protocols
// class.class == UsbDeviceClass::MassStorage && class.interface_protocol_number == 0x50
// }
// }
+174 -83
View File
@@ -1,117 +1,208 @@
use core::mem::MaybeUninit;
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use async_trait::async_trait;
use libk::task::runtime;
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use crate::{
address::UsbInterfaceAddress,
device::UsbDeviceAccess,
error::UsbError,
info::{UsbDeviceClass, UsbDeviceProtocol},
info::{UsbInterfaceInfo, CLASS_FROM_INTERFACE},
pipe::control::{ControlTransferSetup, UsbClassSpecificRequest},
};
// use alloc::{boxed::Box, sync::Arc, vec::Vec};
// use async_trait::async_trait;
// use libk::task::runtime;
// use libk_util::sync::spin_rwlock::IrqSafeRwLock;
//
// use crate::{
// device::UsbDeviceAccess,
// error::UsbError,
// info::{UsbDeviceClass, UsbDeviceProtocol},
// };
//
pub mod hid_keyboard;
pub mod mass_storage;
#[derive(Debug)]
pub struct UsbClassInfo {
pub class: UsbDeviceClass,
pub mod hid_mouse;
// pub mod mass_storage;
//
// #[derive(Debug)]
// pub struct UsbClassInfo {
// pub class: UsbDeviceClass,
// pub subclass: u8,
// pub protocol: UsbDeviceProtocol,
// pub device_protocol_number: u8,
// pub interface_protocol_number: u8,
// }
//
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UsbInterfaceClass {
pub class: u8,
pub subclass: u8,
pub protocol: UsbDeviceProtocol,
pub device_protocol_number: u8,
pub interface_protocol_number: u8,
pub protocol: u8,
}
#[async_trait]
pub trait UsbDriver: Send + Sync {
async fn run(self: Arc<Self>, device: Arc<UsbDeviceAccess>) -> Result<(), UsbError>;
pub trait UsbInterfaceDriver: Send + Sync {
async fn run(
self: Arc<Self>,
device: Arc<UsbDeviceAccess>,
interface: UsbInterfaceInfo,
) -> Result<(), UsbError>;
fn name(&self) -> &'static str;
fn probe(&self, class: &UsbClassInfo, device: &UsbDeviceAccess) -> bool;
fn probe(
&self,
device: &UsbDeviceAccess,
interface: &UsbInterfaceInfo,
class: UsbInterfaceClass,
) -> bool;
}
async fn extract_class_info(device: &UsbDeviceAccess) -> Result<Option<UsbClassInfo>, UsbError> {
if device.info.num_configurations != 1 {
return Ok(None);
}
let device_info = &device.info;
let config_info = device.query_configuration_info(0).await?;
// async fn extract_class_info(device: &UsbDeviceAccess) -> Result<Option<UsbClassInfo>, UsbError> {
// if device.info.num_configurations != 1 {
// return Ok(None);
// }
// let device_info = &device.info;
// let config_info = device.query_configuration_info(0).await?;
//
// if config_info.interfaces.len() >= 1 {
// let if_info = &config_info.interfaces[0];
//
// let class = if device_info.device_class == UsbDeviceClass::FromInterface {
// if_info.interface_class
// } else {
// device_info.device_class
// };
// let subclass = if device_info.device_subclass == 0 {
// if_info.interface_subclass
// } else {
// device_info.device_subclass
// };
// let protocol = if device_info.device_protocol == UsbDeviceProtocol::FromInterface {
// if_info.interface_protocol
// } else {
// device_info.device_protocol
// };
//
// Ok(Some(UsbClassInfo {
// class,
// subclass,
// protocol,
// interface_protocol_number: if_info.interface_protocol_number,
// device_protocol_number: device_info.device_protocol_number,
// }))
// } else {
// Ok(None)
// }
// }
//
// async fn pick_driver(
// device: &UsbDeviceAccess,
// ) -> Result<Option<Arc<dyn UsbDriver + 'static>>, UsbError> {
// let Some(class) = extract_class_info(device).await? else {
// return Ok(None);
// };
//
// for driver in USB_DEVICE_DRIVERS.read().iter() {
// if driver.probe(&class, device) {
// return Ok(Some(driver.clone()));
// }
// }
// Ok(None)
// }
//
// pub async fn spawn_driver(device: Arc<UsbDeviceAccess>) -> Result<(), UsbError> {
// // if let Some(driver) = pick_driver(&device).await? {
// // runtime::spawn(async move {
// // let name = driver.name();
// // match driver.run(device).await {
// // e @ Err(UsbError::DeviceDisconnected) => {
// // log::warn!(
// // "Driver {:?} did not exit cleanly: device disconnected",
// // name,
// // );
//
// // e
// // }
// // e => e,
// // }
// // })
// // .map_err(UsbError::SystemError)?;
// // }
// Ok(())
// }
if config_info.interfaces.len() >= 1 {
let if_info = &config_info.interfaces[0];
let class = if device_info.device_class == UsbDeviceClass::FromInterface {
if_info.interface_class
} else {
device_info.device_class
};
let subclass = if device_info.device_subclass == 0 {
if_info.interface_subclass
} else {
device_info.device_subclass
};
let protocol = if device_info.device_protocol == UsbDeviceProtocol::FromInterface {
if_info.interface_protocol
} else {
device_info.device_protocol
};
Ok(Some(UsbClassInfo {
class,
subclass,
protocol,
interface_protocol_number: if_info.interface_protocol_number,
device_protocol_number: device_info.device_protocol_number,
}))
} else {
Ok(None)
}
}
async fn pick_driver(
device: &UsbDeviceAccess,
) -> Result<Option<Arc<dyn UsbDriver + 'static>>, UsbError> {
let Some(class) = extract_class_info(device).await? else {
return Ok(None);
async fn setup_interface(
device: &Arc<UsbDeviceAccess>,
address: UsbInterfaceAddress,
interface: &UsbInterfaceInfo,
) -> Result<(), UsbError> {
let class = UsbInterfaceClass {
class: interface.interface_class,
subclass: interface.interface_subclass,
protocol: interface.interface_protocol,
};
for driver in USB_DEVICE_DRIVERS.read().iter() {
if driver.probe(&class, device) {
return Ok(Some(driver.clone()));
}
}
Ok(None)
}
pub async fn spawn_driver(device: Arc<UsbDeviceAccess>) -> Result<(), UsbError> {
if let Some(driver) = pick_driver(&device).await? {
runtime::spawn(async move {
let name = driver.name();
match driver.run(device).await {
e @ Err(UsbError::DeviceDisconnected) => {
log::warn!(
"Driver {:?} did not exit cleanly: device disconnected",
name,
);
e
let drivers = USB_INTERFACE_DRIVERS.read();
for driver in drivers.iter() {
if driver.probe(device, interface, class) {
let driver = driver.clone();
let device = device.clone();
let interface = interface.clone();
runtime::spawn(async move {
let name = driver.name();
match driver.run(device, interface).await {
e @ Err(UsbError::DeviceDisconnected) => {
log::warn!("{address} did not exit cleanly: device disconnected ({name})");
e
}
e => e,
}
e => e,
}
})
.map_err(UsbError::SystemError)?;
})
.map_err(UsbError::SystemError)?;
break;
}
}
Ok(())
}
pub fn register_driver(driver: Arc<dyn UsbDriver + 'static>) {
pub async fn setup_device(device: Arc<UsbDeviceAccess>) -> Result<(), UsbError> {
// If device has only one configuration available, use it
// TODO support devices with multiple configurations
let address = device.bus_address();
log::info!("Setup USB device @ {address}");
let Some(config_info) = device.use_default_configuration().await? else {
log::warn!("{address} has multiple configurations, not supported yet",);
return Ok(());
};
// Setup drivers for interfaces
log::info!("{address}: {config_info:#?}");
// TODO device-level drivers
for interface in config_info.interfaces.iter() {
let address = address.with_interface(interface.number);
if let Err(error) = setup_interface(&device, address, interface).await {
log::error!("{}: {:?}", address, error);
}
}
Ok(())
}
pub fn register_driver(driver: Arc<dyn UsbInterfaceDriver + 'static>) {
// TODO check for duplicates
USB_DEVICE_DRIVERS.write().push(driver);
USB_INTERFACE_DRIVERS.write().push(driver);
}
pub fn register_default_class_drivers() {
register_driver(Arc::new(hid_keyboard::UsbHidKeyboardDriver));
register_driver(Arc::new(mass_storage::UsbMassStorageDriverBulkOnly));
register_driver(Arc::new(hid_mouse::UsbHidMouseDriver));
// register_driver(Arc::new(mass_storage::UsbMassStorageDriverBulkOnly));
}
static USB_DEVICE_DRIVERS: IrqSafeRwLock<Vec<Arc<dyn UsbDriver + 'static>>> =
static USB_INTERFACE_DRIVERS: IrqSafeRwLock<Vec<Arc<dyn UsbInterfaceDriver + 'static>>> =
IrqSafeRwLock::new(Vec::new());
+23 -19
View File
@@ -2,9 +2,9 @@ use bytemuck::{Pod, Zeroable};
use crate::{
communication::UsbDirection,
device::UsbSpeed,
device::{self, UsbSpeed},
error::UsbError,
info::{UsbDeviceClass, UsbDeviceProtocol, UsbEndpointType, UsbVersion},
info::UsbEndpointType,
};
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
@@ -91,15 +91,15 @@ pub struct UsbOtherSpeedConfiguration {
pub max_power: u8,
}
impl UsbInterfaceDescriptor {
pub fn class(&self) -> UsbDeviceClass {
UsbDeviceClass::try_from(self.interface_class).unwrap_or(UsbDeviceClass::Unknown)
}
pub fn protocol(&self) -> UsbDeviceProtocol {
UsbDeviceProtocol::try_from(self.interface_protocol).unwrap_or(UsbDeviceProtocol::Unknown)
}
}
// impl UsbInterfaceDescriptor {
// pub fn class(&self) -> UsbDeviceClass {
// UsbDeviceClass::try_from(self.interface_class).unwrap_or(UsbDeviceClass::Unknown)
// }
//
// pub fn protocol(&self) -> UsbDeviceProtocol {
// UsbDeviceProtocol::try_from(self.interface_protocol).unwrap_or(UsbDeviceProtocol::Unknown)
// }
// }
impl UsbEndpointDescriptor {
pub fn direction(&self) -> UsbDirection {
@@ -127,16 +127,16 @@ impl UsbEndpointDescriptor {
}
impl UsbDeviceDescriptor {
pub fn class(&self) -> UsbDeviceClass {
UsbDeviceClass::try_from(self.device_class).unwrap_or(UsbDeviceClass::Unknown)
}
// pub fn class(&self) -> UsbDeviceClass {
// UsbDeviceClass::try_from(self.device_class).unwrap_or(UsbDeviceClass::Unknown)
// }
pub fn protocol(&self) -> UsbDeviceProtocol {
UsbDeviceProtocol::try_from(self.device_protocol).unwrap_or(UsbDeviceProtocol::Unknown)
}
// pub fn protocol(&self) -> UsbDeviceProtocol {
// UsbDeviceProtocol::try_from(self.device_protocol).unwrap_or(UsbDeviceProtocol::Unknown)
// }
pub fn max_packet_size(&self, version: UsbVersion, speed: UsbSpeed) -> Result<usize, UsbError> {
match (version.is_version_3(), speed, self.max_packet_size_0) {
pub fn max_packet_size(&self, version: u16, speed: UsbSpeed) -> Result<usize, UsbError> {
match (is_version_3(version), speed, self.max_packet_size_0) {
(true, UsbSpeed::Super, 9) => Ok(1 << 9),
(true, _, _) => todo!("Non-GenX speed USB3+ maxpacketsize0"),
(false, _, 8) => Ok(8),
@@ -147,3 +147,7 @@ impl UsbDeviceDescriptor {
}
}
}
pub fn is_version_3(version: u16) -> bool {
version & 0xFF00 == 0x300
}
+89 -65
View File
@@ -1,14 +1,28 @@
use core::{fmt, ops::Deref};
use core::{
fmt,
ops::{Deref, Sub},
};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use alloc::{
boxed::Box,
format,
string::{String, ToString},
sync::Arc,
vec::Vec,
};
use async_trait::async_trait;
use libk_util::sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard};
use libk::error::Error;
use libk_util::{
sync::spin_rwlock::{IrqSafeRwLock, IrqSafeRwLockReadGuard},
OneTimeInit,
};
use crate::{
address::UsbBusAddress,
bus::UsbBusWrapper,
error::UsbError,
info::{
UsbConfigurationInfo, UsbDeviceInfo, UsbEndpointInfo, UsbEndpointType, UsbInterfaceInfo,
UsbVersion,
},
pipe::{
control::{ConfigurationDescriptorEntry, UsbControlPipeAccess},
@@ -17,21 +31,26 @@ use crate::{
UsbNormalPipeOut,
},
},
sysfs::UsbDeviceKObject,
UsbHostController,
};
// High-level structures for info provided through descriptors
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct UsbBusAddress {
pub bus: u16,
pub device: u8,
}
// #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
// pub struct UsbBusAddress {
// pub bus: u16,
// pub device: u8,
// }
pub struct UsbDeviceAccess {
pub device: Arc<dyn UsbDevice>,
pub bus: Arc<UsbBusWrapper>,
pub info: UsbDeviceInfo,
pub current_configuration: IrqSafeRwLock<Option<UsbConfigurationInfo>>,
pub configurations: Vec<UsbConfigurationInfo>,
current_configuration: IrqSafeRwLock<Option<usize>>,
pub(crate) kobject: OneTimeInit<UsbDeviceKObject>,
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
@@ -69,7 +88,7 @@ pub trait UsbDevice: Send + Sync {
fn port_number(&self) -> u8;
fn bus_address(&self) -> UsbBusAddress;
fn speed(&self) -> UsbSpeed;
fn controller_ref(&self) -> &dyn UsbHostController;
fn host_controller(&self) -> Arc<dyn UsbHostController>;
fn set_detach_handler(&self, handler: Arc<dyn UsbDeviceDetachHandler>);
fn handle_detach(&self);
@@ -84,47 +103,43 @@ impl UsbDeviceAccess {
/// * Device is not yet configured
/// * Control pipe for the device has been properly set up
/// * Device has been assigned a bus address
pub async fn setup(raw: Arc<dyn UsbDevice>) -> Result<Self, UsbError> {
pub async fn setup(bus: Arc<UsbBusWrapper>, raw: Arc<dyn UsbDevice>) -> Result<Self, UsbError> {
let control = raw.control_pipe();
let device_desc = control.query_device_descriptor().await?;
let bcd_usb = device_desc.bcd_usb;
let usb_version = UsbVersion::from_bcd_usb(device_desc.bcd_usb)
.ok_or(UsbError::InvalidDescriptorField)
.inspect_err(|_| {
log::error!(
"{}: unsupported/invalid USB version: {:#x}",
raw.bus_address(),
bcd_usb
)
})?;
let manufacturer = control.query_string(device_desc.manufacturer_str).await?;
let product = control.query_string(device_desc.product_str).await?;
// Query device
let info = UsbDeviceInfo {
manufacturer,
product,
usb_version,
usb_version: device_desc.bcd_usb,
id_vendor: device_desc.id_vendor,
id_product: device_desc.id_product,
device_class: device_desc.class(),
device_class: device_desc.device_class,
device_subclass: device_desc.device_subclass,
device_protocol: device_desc.protocol(),
device_protocol_number: device_desc.device_protocol,
device_protocol: device_desc.device_protocol,
num_configurations: device_desc.num_configurations,
max_packet_size: device_desc.max_packet_size(usb_version, raw.speed())?,
max_packet_size: device_desc.max_packet_size(device_desc.bcd_usb, raw.speed())?,
};
let configurations =
Self::query_configurations(control, device_desc.num_configurations).await?;
Ok(Self {
device: raw,
bus,
info,
current_configuration: IrqSafeRwLock::new(None),
configurations,
kobject: OneTimeInit::new(),
})
}
@@ -164,45 +179,54 @@ impl UsbDeviceAccess {
Ok(UsbBulkOutPipeAccess(pipe))
}
pub fn read_current_configuration(
&self,
) -> IrqSafeRwLockReadGuard<'_, Option<UsbConfigurationInfo>> {
self.current_configuration.read()
pub fn current_configuration(&self) -> Option<&UsbConfigurationInfo> {
let index = (*self.current_configuration.read())?;
Some(&self.configurations[index])
}
pub async fn select_configuration<F: Fn(&UsbConfigurationInfo) -> bool>(
pub async fn use_default_configuration(
&self,
predicate: F,
) -> Result<Option<UsbConfigurationInfo>, UsbError> {
let mut current_config = self.current_configuration.write();
let control_pipe = self.control_pipe();
for i in 0..self.info.num_configurations {
let info = self.query_configuration_info(i).await?;
if predicate(&info) {
log::debug!("Selected configuration: {:#?}", info);
let config = current_config.insert(info);
control_pipe
.set_configuration(config.config_value as _)
.await?;
return Ok(Some(config.clone()));
}
if self.configurations.len() != 1 {
return Ok(None);
}
Ok(None)
self.set_configuration(0).await.map(Some)
}
pub async fn query_configuration_info(
&self,
index: u8,
) -> Result<UsbConfigurationInfo, UsbError> {
if index >= self.info.num_configurations {
pub async fn set_configuration(&self, index: usize) -> Result<UsbConfigurationInfo, UsbError> {
if index >= self.configurations.len() {
return Err(UsbError::InvalidConfiguration);
}
let mut current = self.current_configuration.write();
let control_pipe = self.control_pipe();
let info = self.configurations[index].clone();
control_pipe
.set_configuration(info.config_value as _)
.await?;
*current = Some(index);
Ok(info)
}
async fn query_configurations(
control_pipe: &UsbControlPipeAccess,
num_configurations: u8,
) -> Result<Vec<UsbConfigurationInfo>, UsbError> {
let mut configurations = Vec::new();
for i in 0..num_configurations {
let configuration = Self::query_configuration(control_pipe, i).await?;
configurations.push(configuration);
}
Ok(configurations)
}
async fn query_configuration(
control_pipe: &UsbControlPipeAccess,
index: u8,
) -> Result<UsbConfigurationInfo, UsbError> {
let query = control_pipe.query_configuration_descriptor(index).await?;
let configuration_name = control_pipe
@@ -228,10 +252,9 @@ impl UsbDeviceAccess {
name,
number: iface.interface_number,
interface_class: iface.class(),
interface_class: iface.interface_class,
interface_subclass: iface.interface_subclass,
interface_protocol: iface.protocol(),
interface_protocol_number: iface.interface_protocol,
interface_protocol: iface.interface_protocol,
});
}
_ => (),
@@ -248,6 +271,13 @@ impl UsbDeviceAccess {
Ok(info)
}
// pub async fn query_configuration_info(
// &self,
// index: u8,
// ) -> Result<UsbConfigurationInfo, UsbError> {
// let control_pipe = self.control_pipe();
// }
pub fn set_detach_handler(&self, handler: Arc<dyn UsbDeviceDetachHandler>) {
self.device.set_detach_handler(handler);
}
@@ -260,9 +290,3 @@ impl Deref for UsbDeviceAccess {
&*self.device
}
}
impl fmt::Display for UsbBusAddress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}", self.bus, self.device)
}
}
+70 -61
View File
@@ -29,41 +29,51 @@ pub enum UsbUsageType {
Reserved,
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum UsbVersion {
Usb11,
Usb20,
Usb21,
Usb30,
Usb31,
Usb32,
}
// #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
// pub enum UsbVersion {
// Usb11,
// Usb20,
// Usb21,
// Usb30,
// Usb31,
// Usb32,
// }
primitive_enum! {
pub enum UsbDeviceClass: u8 {
FromInterface = 0x00,
Hid = 0x03,
MassStorage = 0x08,
Unknown = 0xFF,
}
}
pub const CLASS_FROM_INTERFACE: u8 = 0x00;
pub const CLASS_HID: u8 = 0x03;
pub const CLASS_MASS_STORAGE: u8 = 0x08;
primitive_enum! {
pub enum UsbDeviceProtocol: u8 {
FromInterface = 0x00,
Unknown = 0xFF,
}
}
// primitive_enum! {
// pub enum UsbDeviceClass: u8 {
// FromInterface = 0x00,
// Hid = 0x03,
// MassStorage = 0x08,
// Unknown = 0xFF,
// }
// }
//
// primitive_enum! {
// pub enum UsbDeviceProtocol: u8 {
// FromInterface = 0x00,
// Unknown = 0xFF,
// }
// }
#[derive(Debug, Clone)]
pub struct UsbInterfaceInfo {
pub name: String,
pub number: u8,
pub interface_class: UsbDeviceClass,
pub interface_class: u8,
pub interface_subclass: u8,
pub interface_protocol: UsbDeviceProtocol,
pub interface_protocol_number: u8,
pub interface_protocol: u8,
// pub name: String,
// pub number: u8,
// pub interface_class: UsbDeviceClass,
// pub interface_subclass: u8,
// pub interface_protocol: UsbDeviceProtocol,
// pub interface_protocol_number: u8,
}
#[derive(Debug, Clone)]
@@ -87,15 +97,14 @@ pub struct UsbDeviceInfo {
pub manufacturer: String,
pub product: String,
pub usb_version: UsbVersion,
pub usb_version: u16,
pub id_vendor: u16,
pub id_product: u16,
pub device_class: UsbDeviceClass,
pub device_class: u8,
pub device_subclass: u8,
pub device_protocol: UsbDeviceProtocol,
pub device_protocol_number: u8,
pub device_protocol: u8,
/// Max packet size for endpoint zero
pub max_packet_size: usize,
@@ -103,37 +112,37 @@ pub struct UsbDeviceInfo {
pub num_configurations: u8,
}
impl UsbVersion {
pub fn is_version_3(&self) -> bool {
matches!(self, Self::Usb30 | Self::Usb31 | Self::Usb32)
}
pub fn from_bcd_usb(value: u16) -> Option<Self> {
match value {
0x110 => Some(UsbVersion::Usb11),
0x200..=0x20F => Some(UsbVersion::Usb20),
0x210..=0x21F => Some(UsbVersion::Usb21),
0x300 => Some(UsbVersion::Usb30),
0x310 => Some(UsbVersion::Usb31),
0x320 => Some(UsbVersion::Usb32),
_ => None,
}
}
}
impl fmt::Display for UsbVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let string = match self {
Self::Usb11 => "USB1.1",
Self::Usb20 => "USB2.0",
Self::Usb21 => "USB2.1",
Self::Usb30 => "USB3.0",
Self::Usb31 => "USB3.1",
Self::Usb32 => "USB3.2",
};
f.write_str(string)
}
}
// impl UsbVersion {
// pub fn is_version_3(&self) -> bool {
// matches!(self, Self::Usb30 | Self::Usb31 | Self::Usb32)
// }
//
// pub fn from_bcd_usb(value: u16) -> Option<Self> {
// match value {
// 0x110 => Some(UsbVersion::Usb11),
// 0x200..=0x20F => Some(UsbVersion::Usb20),
// 0x210..=0x21F => Some(UsbVersion::Usb21),
// 0x300 => Some(UsbVersion::Usb30),
// 0x310 => Some(UsbVersion::Usb31),
// 0x320 => Some(UsbVersion::Usb32),
// _ => None,
// }
// }
// }
//
// impl fmt::Display for UsbVersion {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// let string = match self {
// Self::Usb11 => "USB1.1",
// Self::Usb20 => "USB2.0",
// Self::Usb21 => "USB2.1",
// Self::Usb30 => "USB3.0",
// Self::Usb31 => "USB3.1",
// Self::Usb32 => "USB3.2",
// };
// f.write_str(string)
// }
// }
impl UsbEndpointInfo {
pub fn is(&self, ty: UsbEndpointType, dir: UsbDirection) -> bool {
+1
View File
@@ -0,0 +1 @@
pub struct UsbInterface {}
+10 -1
View File
@@ -8,15 +8,20 @@
maybe_uninit_fill
)]
use crate::sysfs::UsbBusKObject;
extern crate alloc;
pub mod address;
pub mod bus;
pub mod communication;
pub mod descriptor;
pub mod device;
pub mod error;
pub mod info;
pub mod interface;
pub mod pipe;
pub mod sysfs;
pub mod util;
pub mod class_driver;
@@ -25,4 +30,8 @@ pub mod class_driver;
pub trait UsbEndpoint: Sync {}
pub trait UsbHostController: Sync + Send {}
pub trait UsbHostController: Sync + Send {
fn register_sysfs_properties(&self, kobject: &UsbBusKObject) {
let _ = kobject;
}
}
@@ -37,6 +37,7 @@ pub trait UsbDeviceRequest: Sized + Pod {
pub trait UsbClassSpecificRequest: Sized + Pod {
const BM_REQUEST_TYPE: u8;
const B_REQUEST: u8;
const W_VALUE: u16 = 0;
}
pub trait UsbDescriptorRequest: UsbDeviceRequest {
+129
View File
@@ -0,0 +1,129 @@
use alloc::{format, sync::Arc};
use libk::{
error::Error,
fs::sysfs::{
self,
attribute::{IntegerAttribute, IntegerAttributeFormat, IntegerAttributeOps},
object::KObject,
},
};
use libk_util::OneTimeInit;
use crate::{bus::UsbBusWrapper, device::UsbDeviceAccess};
pub type UsbBusKObject = Arc<KObject<Arc<UsbBusWrapper>>>;
pub type UsbDeviceKObject = Arc<KObject<Arc<UsbDeviceAccess>>>;
pub(crate) fn register_bus_kobject(bus: &Arc<UsbBusWrapper>) -> UsbBusKObject {
let root = sysfs_usb_root();
let bus_kobject = KObject::new(bus.clone());
bus.hc.register_sysfs_properties(&bus_kobject);
root.add_object(format!("{}", bus.index), bus_kobject.clone())
.ok();
bus_kobject
}
pub(crate) fn register_device_kobject(device: &Arc<UsbDeviceAccess>) -> UsbDeviceKObject {
struct Class;
struct Subclass;
struct Protocol;
struct Version;
struct IdVendor;
struct IdProduct;
impl IntegerAttributeOps<u8> for Class {
type Data = Arc<UsbDeviceAccess>;
const NAME: &'static str = "class";
const FORMAT: IntegerAttributeFormat = IntegerAttributeFormat::Hex;
fn read(state: &Self::Data) -> Result<u8, Error> {
Ok(state.info.device_class)
}
}
impl IntegerAttributeOps<u8> for Subclass {
type Data = Arc<UsbDeviceAccess>;
const NAME: &'static str = "subclass";
const FORMAT: IntegerAttributeFormat = IntegerAttributeFormat::Hex;
fn read(state: &Self::Data) -> Result<u8, Error> {
Ok(state.info.device_subclass)
}
}
impl IntegerAttributeOps<u8> for Protocol {
type Data = Arc<UsbDeviceAccess>;
const NAME: &'static str = "protocol";
const FORMAT: IntegerAttributeFormat = IntegerAttributeFormat::Hex;
fn read(state: &Self::Data) -> Result<u8, Error> {
Ok(state.info.device_protocol)
}
}
impl IntegerAttributeOps<u16> for Version {
type Data = Arc<UsbDeviceAccess>;
const NAME: &'static str = "version";
const FORMAT: IntegerAttributeFormat = IntegerAttributeFormat::Hex;
fn read(state: &Self::Data) -> Result<u16, Error> {
Ok(state.info.usb_version)
}
}
impl IntegerAttributeOps<u16> for IdVendor {
type Data = Arc<UsbDeviceAccess>;
const NAME: &'static str = "vendor";
const FORMAT: IntegerAttributeFormat = IntegerAttributeFormat::Hex;
fn read(state: &Self::Data) -> Result<u16, Error> {
Ok(state.info.id_vendor)
}
}
impl IntegerAttributeOps<u16> for IdProduct {
type Data = Arc<UsbDeviceAccess>;
const NAME: &'static str = "product";
const FORMAT: IntegerAttributeFormat = IntegerAttributeFormat::Hex;
fn read(state: &Self::Data) -> Result<u16, Error> {
Ok(state.info.id_product)
}
}
let bus_kobject = device.bus.kobject();
let device_kobject = KObject::new(device.clone());
device_kobject
.add_attribute(IntegerAttribute::from(Class))
.ok();
device_kobject
.add_attribute(IntegerAttribute::from(Subclass))
.ok();
device_kobject
.add_attribute(IntegerAttribute::from(Protocol))
.ok();
device_kobject
.add_attribute(IntegerAttribute::from(Version))
.ok();
device_kobject
.add_attribute(IntegerAttribute::from(IdVendor))
.ok();
device_kobject
.add_attribute(IntegerAttribute::from(IdProduct))
.ok();
let address = device.bus_address();
bus_kobject
.add_object(format!("{}", address.device), device_kobject.clone())
.ok();
device_kobject
}
fn sysfs_usb_root() -> &'static Arc<KObject<()>> {
static USB_ROOT: OneTimeInit<Arc<KObject<()>>> = OneTimeInit::new();
USB_ROOT.or_init_with(|| {
let bus_object = sysfs::bus().expect("bus object");
let usb_object = KObject::new(());
bus_object.add_object("usb", usb_object.clone()).ok();
usb_object
})
}
+67 -8
View File
@@ -9,14 +9,20 @@ use async_trait::async_trait;
use device_api::device::Device;
use libk::{device::char::CharDevice, vfs::FileReadiness};
use libk_util::{ring::LossyRingQueue, OneTimeInit};
use yggdrasil_abi::{error::Error, io::KeyboardKeyEvent};
use yggdrasil_abi::{
abi_serde::wire,
error::Error,
io::{KeyboardKeyEvent, MouseEvent},
};
#[derive(Clone, Copy)]
pub struct KeyboardDevice;
#[derive(Clone, Copy)]
pub struct MouseDevice;
impl FileReadiness for KeyboardDevice {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
INPUT_QUEUE.poll_readable(cx).map(Ok)
KEYBOARD_INPUT_QUEUE.poll_readable(cx).map(Ok)
}
}
@@ -33,7 +39,7 @@ impl CharDevice for KeyboardDevice {
return Ok(0);
}
let ev = INPUT_QUEUE.read().await;
let ev = KEYBOARD_INPUT_QUEUE.read().await;
buf[..4].copy_from_slice(&ev.as_bytes());
@@ -45,7 +51,7 @@ impl CharDevice for KeyboardDevice {
return Ok(0);
}
let ev = INPUT_QUEUE.try_read().ok_or(Error::WouldBlock)?;
let ev = KEYBOARD_INPUT_QUEUE.try_read().ok_or(Error::WouldBlock)?;
buf[..4].copy_from_slice(&ev.as_bytes());
@@ -68,15 +74,68 @@ impl CharDevice for KeyboardDevice {
}
}
static INPUT_QUEUE: LossyRingQueue<KeyboardKeyEvent> = LossyRingQueue::with_capacity(32);
impl FileReadiness for MouseDevice {
fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
MOUSE_INPUT_QUEUE.poll_readable(cx).map(Ok)
}
}
impl Device for MouseDevice {
fn display_name(&self) -> &str {
"Mouse input pseudo-device"
}
}
#[async_trait]
impl CharDevice for MouseDevice {
async fn read(&self, buf: &mut [u8]) -> Result<usize, Error> {
let ev = MOUSE_INPUT_QUEUE.read().await;
let len = wire::to_slice(&ev, buf)?;
Ok(len)
}
fn read_nonblocking(&self, buf: &mut [u8]) -> Result<usize, Error> {
let ev = MOUSE_INPUT_QUEUE.try_read().ok_or(Error::WouldBlock)?;
let len = wire::to_slice(&ev, buf)?;
Ok(len)
}
fn is_writeable(&self) -> bool {
false
}
fn device_request(&self, option: u32, buffer: &mut [u8], len: usize) -> Result<usize, Error> {
let _ = option;
let _ = buffer;
let _ = len;
Err(Error::InvalidOperation)
}
fn is_terminal(&self) -> bool {
false
}
}
static KEYBOARD_INPUT_QUEUE: LossyRingQueue<KeyboardKeyEvent> = LossyRingQueue::with_capacity(32);
static KEYBOARD_DEVICE: OneTimeInit<Arc<KeyboardDevice>> = OneTimeInit::new();
pub fn setup() -> Arc<KeyboardDevice> {
static MOUSE_INPUT_QUEUE: LossyRingQueue<MouseEvent> = LossyRingQueue::with_capacity(32);
static MOUSE_DEVICE: OneTimeInit<Arc<MouseDevice>> = OneTimeInit::new();
pub fn setup_keyboard() -> Arc<KeyboardDevice> {
KEYBOARD_DEVICE
.or_init_with(|| Arc::new(KeyboardDevice))
.clone()
}
pub fn send_event(ev: KeyboardKeyEvent) {
INPUT_QUEUE.write(ev);
pub fn setup_mouse() -> Arc<MouseDevice> {
MOUSE_DEVICE.or_init_with(|| Arc::new(MouseDevice)).clone()
}
pub fn send_keyboard_event(ev: KeyboardKeyEvent) {
KEYBOARD_INPUT_QUEUE.write(ev);
}
pub fn send_mouse_event(ev: MouseEvent) {
MOUSE_INPUT_QUEUE.write(ev);
}
+17 -10
View File
@@ -125,24 +125,31 @@ impl<'a, M: MdioBus> PhyAccess<'a, M> {
})
}
pub fn setup_link(&self, have_pause: bool, force_gbesr: GBESR) -> Result<(), Error> {
pub fn setup_link(&self, have_pause: bool, force_gbesr: Option<GBESR>) -> Result<(), Error> {
let bmsr = BMSR::from(self.read_reg(REG_BMSR)?);
let mut gbesr = if bmsr.contains(BMSR::EXT_STATUS_1000BASET) {
GBESR::from(self.read_reg(REG_GBESR)?)
let gbesr = if let Some(force_gbesr) = force_gbesr {
let mut gbesr = if bmsr.contains(BMSR::EXT_STATUS_1000BASET) {
GBESR::from(self.read_reg(REG_GBESR)?)
} else {
GBESR::empty()
};
gbesr |= force_gbesr;
Some(gbesr)
} else {
GBESR::empty()
None
};
gbesr |= force_gbesr;
let mut anar = ANAR::from_capabilities(bmsr);
if have_pause {
anar |= ANAR::HAVE_PAUSE | ANAR::ASM_DIR;
}
let mut gbcr = GBCR::empty();
if gbesr.contains(GBESR::HAVE_1000BASET_HALF) {
gbcr |= GBCR::HAVE_1000BASET_HALF;
}
if gbesr.contains(GBESR::HAVE_1000BASET_FULL) {
gbcr |= GBCR::HAVE_1000BASET_FULL;
if let Some(gbesr) = gbesr {
if gbesr.contains(GBESR::HAVE_1000BASET_HALF) {
gbcr |= GBCR::HAVE_1000BASET_HALF;
}
if gbesr.contains(GBESR::HAVE_1000BASET_FULL) {
gbcr |= GBCR::HAVE_1000BASET_FULL;
}
}
self.write_reg(REG_ANAR, anar.bits())?;
+22 -3
View File
@@ -26,12 +26,15 @@ use ygg_driver_pci::{
};
use yggdrasil_abi::net::{link::LinkState, MacAddress};
use crate::regs::Revision;
extern crate alloc;
mod regs;
mod ring;
struct Igbe {
chip: Revision,
regs: IrqSafeSpinlock<Regs>,
dma: Arc<dyn DmaAllocator>,
pci: PciDeviceInfo,
@@ -43,8 +46,9 @@ struct Igbe {
}
impl Igbe {
pub fn new(dma: Arc<dyn DmaAllocator>, regs: Regs, pci: PciDeviceInfo) -> Self {
pub fn new(dma: Arc<dyn DmaAllocator>, regs: Regs, chip: Revision, pci: PciDeviceInfo) -> Self {
Self {
chip,
dma,
pci,
mac: OneTimeInit::new(),
@@ -74,7 +78,7 @@ impl Device for Igbe {
regs.reset(Duration::from_millis(200))?;
// Intel 8257x manuals say an additional interrupt disable is needed after a global reset
regs.disable_interrupts();
regs.set_link_up()?;
regs.set_link_up(self.chip)?;
// Initialize Rx
regs.initialize_receiver(&rx_ring);
@@ -175,6 +179,10 @@ impl NetworkDevice for Igbe {
pci_driver! {
matches: [
device (0x8086:0x100E), // 82540EM (E1000)
device (0x8086:0x100C), // 82544GC (E1000)
device (0x8086:0x100F), // 82545EM (E1000)
device (0x8086:0x10D3), // 82574L (E1000E) [[BROKEN]]
device (0x8086:0x10C9), // 82576 GbE
device (0x8086:0x1502), // 82579LM GbE (Lewisville)
],
@@ -197,11 +205,22 @@ pci_driver! {
}
};
let chip = match info.device_id {
0x100E | 0x100C | 0x100F => Revision::I8254x,
0x10D3 => Revision::I82574L,
0x10C9 => Revision::I82576,
0x1502 => Revision::I82579LM,
id => {
log::error!("Invalid igbe chip variant: {id:#04x}");
return Err(Error::InvalidOperation)
},
};
info.init_interrupts(PreferredInterruptMode::Msi(true))?;
info.set_command(true, use_mmio, !use_mmio, true);
let regs = unsafe { Regs::map(base) }?;
let device = Igbe::new(dma.clone(), regs, info.clone());
let device = Igbe::new(dma.clone(), regs, chip, info.clone());
Ok(Arc::new(device))
}
+21 -3
View File
@@ -42,6 +42,14 @@ pub trait Reg {
const OFFSET: u16;
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Revision {
I8254x,
I82574L,
I82576,
I82579LM,
}
register_bitfields! {
u32,
pub CTRL [
@@ -329,6 +337,7 @@ impl MdioBus for Regs {
let mdic = self.inner.extract();
if mdic.matches_all(MDIC::E::SET) {
log::warn!("MDIO read error: phyaddr={phyaddr:#x}, regaddr={regaddr:#x}");
return Err(Error::InvalidOperation);
}
@@ -350,6 +359,9 @@ impl MdioBus for Regs {
)?;
if self.inner.matches_all(MDIC::E::SET) {
log::warn!(
"MDIO write error: phyaddr={phyaddr:#x}, regaddr={regaddr:#x}, value={value:#x}"
);
return Err(Error::InvalidOperation);
}
@@ -403,7 +415,7 @@ impl Regs {
})
}
pub fn set_link_up(&mut self) -> Result<(), Error> {
pub fn set_link_up(&mut self, chip: Revision) -> Result<(), Error> {
self.inner
.modify(CTRL::SLU::SET + CTRL::RFCE::SET + CTRL::TFCE::SET);
@@ -412,8 +424,14 @@ impl Regs {
let (id0, id1) = phy.id()?;
log::info!("PHY {:04x}:{:04x}", id0, id1);
phy.reset(Duration::from_millis(200))?;
phy.setup_link(true, GBESR::empty())?;
phy.reset(Duration::from_millis(200))
.inspect_err(|e| log::error!("PHY reset error {e:?}"))?;
let force_gbesr = match chip {
Revision::I82576 | Revision::I82579LM => Some(GBESR::empty()),
_ => None,
};
phy.setup_link(true, force_gbesr)
.inspect_err(|e| log::error!("PHY setup error: {e:?}"))?;
Ok(())
}
+1 -1
View File
@@ -498,7 +498,7 @@ impl Regs {
phy.write_reg(0x0E, 0x00)?;
phy.reset(timeout)?;
phy.setup_link(true, GBESR::empty())?;
phy.setup_link(true, Some(GBESR::empty()))?;
psleep(Duration::from_millis(100));
+1 -1
View File
@@ -337,7 +337,7 @@ impl Device for Stmmac {
let (id0, id1) = phy.id()?;
log::info!("stmmac: PHY {id0:04x}:{id1:04x}");
phy.reset(Duration::from_millis(100))?;
phy.setup_link(true, GBESR::empty())?;
phy.setup_link(true, Some(GBESR::empty()))?;
self.inner.init(Inner {
regs: IrqSafeSpinlock::new(regs),
+14 -11
View File
@@ -25,10 +25,11 @@ use tock_registers::{
};
use ygg_driver_pci::{device::PciDeviceInfo, PciConfigurationSpace};
use ygg_driver_usb::{
bus::UsbBusManager,
device::{UsbBusAddress, UsbDeviceAccess, UsbSpeed},
address::UsbBusAddress,
bus::{UsbBusManager, UsbBusWrapper},
descriptor,
device::{UsbDeviceAccess, UsbSpeed},
error::UsbError,
info::UsbVersion,
pipe::control::UsbControlPipeAccess,
UsbHostController,
};
@@ -67,7 +68,7 @@ struct ScratchpadArray {
}
struct RootHubPort {
version: UsbVersion,
version: u16,
slot_type: u8,
}
@@ -90,6 +91,7 @@ pub struct Xhci {
pub(crate) slots: Vec<IrqSafeRwLock<Option<Arc<XhciBusDevice>>>>,
pub(crate) port_slot_map: Vec<AtomicU8>,
bus_index: OneTimeInit<u16>,
bus: OneTimeInit<Arc<UsbBusWrapper>>,
port_event_map: EventBitmap,
}
@@ -146,9 +148,7 @@ impl Xhci {
for cap in regs.extended_capabilities.iter() {
match cap {
ExtendedCapability::ProtocolSupport(support) => {
let Some(version) = support.usb_revision() else {
continue;
};
let version = support.usb_revision();
for port in support.port_range() {
log::info!("* Port {port}: {version}");
@@ -195,6 +195,7 @@ impl Xhci {
root_hub_ports,
bus_index: OneTimeInit::new(),
bus: OneTimeInit::new(),
endpoints: IrqSafeRwLock::new(BTreeMap::new()),
slots,
port_slot_map,
@@ -282,7 +283,7 @@ impl Xhci {
.as_ref()
.ok_or(UsbError::PortInitFailed)?;
let need_reset = !root_hub_port.version.is_version_3();
let need_reset = !descriptor::is_version_3(root_hub_port.version);
if need_reset {
self.reset_port(regs).await?;
@@ -341,7 +342,8 @@ impl Xhci {
device: bus_address,
});
let device = UsbDeviceAccess::setup(slot).await?;
let bus = self.bus.get();
let device = UsbDeviceAccess::setup(bus.clone(), slot).await?;
UsbBusManager::register_device(device.into());
Ok(())
@@ -523,8 +525,9 @@ impl Device for Xhci {
op.wait_usbsts_bit(USBSTS::CNR::CLEAR, 100000000)?;
let bus = UsbBusManager::register_bus(self.clone());
self.bus_index.init(bus);
let (bus_index, bus) = UsbBusManager::register_bus(self.clone());
self.bus_index.init(bus_index);
self.bus.init(bus);
runtime::spawn(self.clone().port_handler_task()).ok();
+4 -3
View File
@@ -6,8 +6,9 @@ use libk_util::{
};
use xhci_lib::context;
use ygg_driver_usb::{
address::UsbBusAddress,
communication::UsbDirection,
device::{UsbBusAddress, UsbDevice, UsbDeviceDetachHandler, UsbSpeed},
device::{UsbDevice, UsbDeviceDetachHandler, UsbSpeed},
error::UsbError,
info::UsbEndpointType,
pipe::{
@@ -63,8 +64,8 @@ impl UsbDevice for XhciBusDevice {
*self.detach_handler.lock() = Some(handler);
}
fn controller_ref(&self) -> &dyn UsbHostController {
self.xhci.as_ref()
fn host_controller(&self) -> Arc<dyn UsbHostController> {
self.xhci.clone()
}
fn debug(&self) {}
+3 -3
View File
@@ -7,7 +7,7 @@ use alloc::vec::Vec;
use libk::error::Error;
use libk_mm::{address::PhysicalAddress, device::DeviceMemoryIo};
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use ygg_driver_usb::{error::UsbError, info::UsbVersion};
use ygg_driver_usb::error::UsbError;
pub struct ProtocolSupport {
words: [u32; 4],
@@ -69,8 +69,8 @@ impl ExtendedCapability {
}
impl ProtocolSupport {
pub fn usb_revision(&self) -> Option<UsbVersion> {
UsbVersion::from_bcd_usb((self.words[0] >> 16) as u16)
pub fn usb_revision(&self) -> u16 {
(self.words[0] >> 16) as u16
}
pub fn slot_type(&self) -> u8 {
@@ -0,0 +1,239 @@
use core::{
any::Any,
marker::PhantomData,
sync::atomic::{AtomicBool, Ordering},
};
use alloc::{
string::{String, ToString},
sync::Arc,
vec::Vec,
};
use libk_util::sync::spin_rwlock::IrqSafeRwLock;
use yggdrasil_abi::{
error::Error,
io::{FileMode, OpenOptions},
};
use crate::{
fs::sysfs::object::KObject,
vfs::{CommonImpl, Filename, InstanceData, Metadata, Node, NodeFlags, NodeRef, RegularImpl},
};
use super::Attribute;
pub trait IntegerAttributeValue: Copy + Sync + Send + 'static {
fn to_bytes(&self, format: IntegerAttributeFormat) -> Vec<u8>;
}
pub enum IntegerAttributeFormat {
Decimal,
Octal,
Hex,
}
macro_rules! impl_integer_value {
($($ty:ty),+) => {
$(
impl IntegerAttributeValue for $ty {
fn to_bytes(&self, format: IntegerAttributeFormat) -> Vec<u8> {
match format {
IntegerAttributeFormat::Decimal => alloc::format!("{self}"),
IntegerAttributeFormat::Octal => alloc::format!("{self:o}"),
IntegerAttributeFormat::Hex => alloc::format!("{self:x}"),
}.into_bytes()
}
}
)+
};
}
impl_integer_value!(u8, u16, u32, u64);
pub trait IntegerAttributeOps<T: IntegerAttributeValue>: Sync + Send + 'static {
type Data: Send + 'static = ();
const WRITEABLE: bool = false;
const NAME: &'static str;
const FORMAT: IntegerAttributeFormat = IntegerAttributeFormat::Decimal;
fn read(state: &Self::Data) -> Result<T, Error> {
let _ = state;
Err(Error::NotImplemented)
}
fn write(state: &Self::Data, value: T) -> Result<(), Error> {
let _ = state;
let _ = value;
Err(Error::ReadOnly)
}
}
pub struct IntegerAttribute<T: IntegerAttributeValue, V: IntegerAttributeOps<T>>(
PhantomData<(T, V)>,
);
struct IntegerAttributeNode<T: IntegerAttributeValue, V: IntegerAttributeOps<T>> {
object: Arc<KObject<V::Data>>,
_pd: PhantomData<V>,
}
struct IntegerAttributeState<T: IntegerAttributeValue> {
value: IrqSafeRwLock<Vec<u8>>,
modified: AtomicBool,
_pd: PhantomData<T>,
}
impl<T: IntegerAttributeValue, V: IntegerAttributeOps<T>> CommonImpl
for IntegerAttributeNode<T, V>
{
fn size(&self, _node: &NodeRef) -> Result<u64, Error> {
Ok(0)
}
fn as_any(&self) -> &dyn Any {
self as _
}
}
impl<T: IntegerAttributeValue, V: IntegerAttributeOps<T>> RegularImpl
for IntegerAttributeNode<T, V>
{
fn open(
&self,
_node: &NodeRef,
opts: OpenOptions,
) -> Result<(u64, Option<InstanceData>), Error> {
if opts.contains(OpenOptions::WRITE) && !V::WRITEABLE {
return Err(Error::ReadOnly);
}
let mut value = V::read(self.object.data())?.to_bytes(V::FORMAT);
value.push(b'\n');
let instance = IntegerAttributeState {
value: IrqSafeRwLock::new(value),
modified: AtomicBool::new(false),
_pd: PhantomData::<T>,
};
Ok((0, Some(Arc::new(instance))))
}
fn close(&self, _node: &NodeRef, instance: Option<&InstanceData>) -> Result<(), Error> {
if V::WRITEABLE {
todo!()
// let instance = instance.ok_or(Error::InvalidFile)?;
// let instance = instance
// .downcast_ref::<StringAttributeState>()
// .ok_or(Error::InvalidFile)?;
// if instance.modified.load(Ordering::Acquire) {
// let value = instance.value.read();
// let value_str =
// core::str::from_utf8(&value[..]).map_err(|_| Error::InvalidArgument)?;
// // Trim whitespace and newlines
// V::write(&self.object.data, value_str.trim())?;
// }
}
Ok(())
}
fn read(
&self,
_node: &NodeRef,
instance: Option<&InstanceData>,
pos: u64,
buf: &mut [u8],
) -> Result<usize, Error> {
let instance = instance.ok_or(Error::InvalidFile)?;
let instance = instance
.downcast_ref::<IntegerAttributeState<T>>()
.ok_or(Error::InvalidFile)?;
let value = instance.value.read();
let len = value.len();
if pos >= len as u64 {
return Ok(0);
}
let pos = pos as usize;
let amount = (len - pos).min(buf.len());
buf[..amount].copy_from_slice(&value[pos..pos + amount]);
Ok(amount)
}
fn write(
&self,
_node: &NodeRef,
instance: Option<&InstanceData>,
pos: u64,
buf: &[u8],
) -> Result<usize, Error> {
todo!()
// if !V::WRITEABLE {
// return Err(Error::InvalidFile);
// }
// let instance = instance.ok_or(Error::InvalidFile)?;
// let instance = instance
// .downcast_ref::<StringAttributeState>()
// .ok_or(Error::InvalidFile)?;
// let mut value = instance.value.write();
// let pos: usize = pos.try_into().map_err(|_| Error::InvalidFile)?;
// if pos > value.len() {
// return Err(Error::InvalidArgument);
// }
// if pos + buf.len() > V::LIMIT {
// return Err(Error::InvalidArgument);
// }
// let amount_copy = (value.len() - pos).min(buf.len());
// value[pos..pos + amount_copy].copy_from_slice(&buf[..amount_copy]);
// if amount_copy < buf.len() {
// value.extend_from_slice(&buf[amount_copy..]);
// }
// instance.modified.store(true, Ordering::Release);
// Ok(buf.len())
}
fn truncate(&self, _node: &NodeRef, _new_size: u64) -> Result<(), Error> {
Ok(())
}
}
impl<T: IntegerAttributeValue, V: IntegerAttributeOps<T>> From<V> for IntegerAttribute<T, V> {
fn from(_value: V) -> Self {
Self(PhantomData)
}
}
impl<T: IntegerAttributeValue, V: IntegerAttributeOps<T>> Attribute<V::Data>
for IntegerAttribute<T, V>
{
fn instantiate(&self, parent: &Arc<KObject<V::Data>>) -> Result<Arc<Node>, Error> {
let mode = match V::WRITEABLE {
false => FileMode::new(0o444),
true => FileMode::new(0o644),
};
Ok(Node::regular(
IntegerAttributeNode {
object: parent.clone(),
_pd: PhantomData::<V>,
},
NodeFlags::IN_MEMORY_PROPS,
Some(Metadata::now_root(mode, 0)),
None,
))
}
// TODO implement this properly
fn name(&self) -> &Filename {
unsafe { Filename::from_str_unchecked(V::NAME) }
}
}
@@ -6,6 +6,7 @@ use crate::vfs::{Filename, NodeRef};
use super::object::KObject;
mod bytes;
mod integer;
mod string;
pub trait Attribute<D>: Sync + Send {
@@ -14,4 +15,7 @@ pub trait Attribute<D>: Sync + Send {
}
pub use bytes::{BytesAttribute, BytesAttributeOps};
pub use integer::{
IntegerAttribute, IntegerAttributeFormat, IntegerAttributeOps, IntegerAttributeValue,
};
pub use string::{StringAttribute, StringAttributeOps};
+1
View File
@@ -396,6 +396,7 @@ impl SymlinkImpl for FixedPathSymlink {
buf[..self.target.len()].copy_from_slice(self.target.as_bytes());
Ok(self.target.len())
} else {
log::warn!("FixedPathSymlink: BufferTooSmall");
Err(Error::BufferTooSmall)
}
}
+5
View File
@@ -94,6 +94,11 @@ impl PidFile {
pub fn read(&self, buf: &mut [u8], _non_blocking: bool) -> Result<usize, Error> {
if buf.len() < size_of::<u32>() + size_of::<i32>() {
log::warn!(
"PidFd: BufferTooSmall (need {}, got {})",
size_of::<u32>() + size_of::<i32>(),
buf.len()
);
return Err(Error::BufferTooSmall);
}
match self {
+3 -1
View File
@@ -162,7 +162,9 @@ impl<O: TerminalOutput> Terminal<O> {
if config.line.contains(TerminalLineOptions::SIGNAL) {
self.output.notify_readers();
if let Some(group_id) = *self.input.signal_pgroup.read() {
let pgrp = *self.input.signal_pgroup.read();
log::info!("Send terminal SIGINT to {pgrp:?}");
if let Some(group_id) = pgrp {
Process::signal_group(None, group_id, Signal::Interrupted);
self.input.ready_ring.notify_all();
return;
+1
View File
@@ -35,6 +35,7 @@ impl TimerFile {
pub fn read(&self, buf: &mut [u8], non_blocking: bool) -> Result<usize, Error> {
if buf.len() < size_of::<u8>() {
log::warn!("TimerFile: BufferTooSmall");
return Err(Error::BufferTooSmall);
}
if non_blocking {
+1 -1
View File
@@ -140,7 +140,7 @@ impl I686 {
TerminalInput::with_capacity(256)?,
ConsoleWrapper(textfb),
));
let keyboard_input = ygg_driver_input::setup();
let keyboard_input = ygg_driver_input::setup_keyboard();
runtime::spawn(
textfb_console
+1 -1
View File
@@ -108,7 +108,7 @@ impl InterruptHandler for PS2Controller {
inner.e0 = false;
ygg_driver_input::send_event(event);
ygg_driver_input::send_keyboard_event(event);
}
count != 0
+12 -3
View File
@@ -168,11 +168,20 @@ pub fn kernel_main() -> ! {
CPU_INIT_FENCE.wait_all(ArchitectureImpl::cpu_count());
// Add keyboard device
if let Err(error) =
devfs::add_named_char_device(ygg_driver_input::setup(), "kbd", FileMode::new(0o660))
{
if let Err(error) = devfs::add_named_char_device(
ygg_driver_input::setup_keyboard(),
"kbd",
FileMode::new(0o660),
) {
log::error!("Couldn't add keyboard device: {error:?}");
}
if let Err(error) = devfs::add_named_char_device(
ygg_driver_input::setup_mouse(),
"mouse",
FileMode::new(0o440),
) {
log::error!("Couldn't add pointer device: {error:?}");
}
task::init().expect("Failed to initialize the scheduler");
+39
View File
@@ -1,3 +1,5 @@
use abi_serde::{impl_newtype_serde, impl_struct_serde};
/// Describes a key pressed/released on a keyboard device
// Missing docs: self-explanatory names
#[allow(missing_docs)]
@@ -47,6 +49,43 @@ pub enum KeyboardKeyEvent {
Released(KeyboardKey),
}
/// Representation for button press state
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct ButtonMask(pub u8);
impl_newtype_serde!(ButtonMask);
/// Representation for mouse event
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MouseEvent {
/// Current button state
pub buttons: ButtonMask,
/// X delta
pub dx: i32,
/// Y delta
pub dy: i32,
}
impl_struct_serde!(MouseEvent: [
buttons,
dx,
dy
]);
impl MouseEvent {
/// Returns `true` if the event has corresponding button set
pub fn button(&self, index: usize) -> bool {
if index >= 8 {
false
} else {
self.buttons.0 & (1 << index) != 0
}
}
}
impl KeyboardKey {
/// Converts [KeyboardKey] to its related [KeyboardKeyCode]
pub const fn code(self) -> KeyboardKeyCode {
+1 -1
View File
@@ -13,7 +13,7 @@ pub use crate::generated::{
OpenOptions, PollControl, RawFd, RemoveFlags, TimerOptions, UnmountOptions, UserId,
};
pub use file::{FileMetadataUpdate, FileMetadataUpdateMode, FileTimesUpdate, SeekFrom};
pub use input::{KeyboardKey, KeyboardKeyCode, KeyboardKeyEvent};
pub use input::{ButtonMask, KeyboardKey, KeyboardKeyCode, KeyboardKeyEvent, MouseEvent};
pub use terminal::{
TerminalControlCharacters, TerminalInputOptions, TerminalLineOptions, TerminalOptions,
TerminalOutputOptions, TerminalSize,
+27 -5
View File
@@ -2,11 +2,26 @@ use std::{path::PathBuf, process::Command};
use crate::IntoArgs;
#[derive(Debug)]
pub enum IntelGigabitRev {
I82574L,
I82544GC,
I82545EM,
I82540EM,
}
#[derive(Debug)]
pub enum QemuNic {
VirtioPci { mac: Option<String> },
Rtl8139 { mac: Option<String> },
IntelGigabit { mac: Option<String> },
VirtioPci {
mac: Option<String>,
},
Rtl8139 {
mac: Option<String>,
},
IntelGigabit {
mac: Option<String>,
rev: Option<IntelGigabitRev>,
},
}
#[derive(Debug, PartialEq, Eq)]
@@ -58,9 +73,16 @@ impl IntoArgs for QemuNic {
}
command.arg(val);
}
Self::IntelGigabit { mac } => {
Self::IntelGigabit { mac, rev } => {
let name = match rev {
Some(IntelGigabitRev::I82574L) => "e1000e",
Some(IntelGigabitRev::I82544GC) => "e1000-82544gc",
Some(IntelGigabitRev::I82545EM) => "e1000-82545em",
Some(IntelGigabitRev::I82540EM) => "e1000",
None => "igb",
};
command.arg("-device");
let mut val = "igb,netdev=net0".to_owned();
let mut val = format!("{name},netdev=net0");
if let Some(mac) = mac {
val.push_str(",mac=");
val.push_str(mac);
+7
View File
@@ -0,0 +1,7 @@
use abi::error::Error;
pub use abi::io::{ButtonMask, MouseEvent};
use abi_serde::wire;
pub fn parse_mouse_event(buffer: &[u8]) -> Result<MouseEvent, Error> {
wire::from_slice(buffer).map_err(Error::from)
}
+14 -1
View File
@@ -13,13 +13,26 @@ pub use abi::option::OptionSizeHint;
pub mod device;
pub mod filesystem;
pub mod input;
pub mod paths;
pub mod terminal;
pub use paths::*;
use core::mem::MaybeUninit;
use abi::{error::Error, option::OptionValue};
use abi::{error::Error, option::OptionValue, process::ProcessId};
pub fn read_pid_fd(fd: RawFd) -> Result<(ProcessId, i32), Error> {
let mut buffer = [0; size_of::<u32>() + size_of::<i32>()];
let len = unsafe { crate::sys::read(fd, &mut buffer) }?;
assert_eq!(len, buffer.len());
let mut word = [0; size_of::<u32>()];
word.copy_from_slice(&buffer[0..4]);
let pid = unsafe { ProcessId::from_raw(u32::from_ne_bytes(word)) };
word.copy_from_slice(&buffer[4..8]);
let status = i32::from_ne_bytes(word);
Ok((pid, status))
}
pub fn remove_file(at: Option<RawFd>, path: &str) -> Result<(), Error> {
unsafe { crate::sys::remove(at, path, RemoveFlags::empty()) }
+10 -10
View File
@@ -293,11 +293,11 @@ fn setup_dtv(image: &TlsImage, tls_info: &TlsInfo) -> Result<(), Error> {
// NOTE if module 1 is specified again by the dynamic loader, it will be overriden with
// what dynamic loader says
if let Some(module0_offset) = tls_info.module0_offset {
crate::debug_trace!(
Info,
"DTV[1] = {:#x}",
tls_info.base + module0_offset + DTV_OFFSET
);
// crate::debug_trace!(
// Info,
// "DTV[1] = {:#x}",
// tls_info.base + module0_offset + DTV_OFFSET
// );
dtv.set(
1,
core::ptr::without_provenance_mut(tls_info.base + module0_offset + DTV_OFFSET),
@@ -309,11 +309,11 @@ fn setup_dtv(image: &TlsImage, tls_info: &TlsInfo) -> Result<(), Error> {
}
for &(module_id, module_offset) in image.module_offsets.iter() {
assert!(module_offset < image.full_size);
crate::debug_trace!(
Info,
"DTV[{module_id}] = {:#x}",
tls_info.base + module_offset + DTV_OFFSET
);
// crate::debug_trace!(
// Info,
// "DTV[{module_id}] = {:#x}",
// tls_info.base + module_offset + DTV_OFFSET
// );
dtv.set(
module_id,
core::ptr::with_exposed_provenance_mut(tls_info.base + module_offset + DTV_OFFSET),
+16 -17
View File
@@ -629,7 +629,6 @@ dependencies = [
"cross",
"hmac",
"rand 0.8.5 (git+https://git.alnyan.me/yggdrasil/rand.git?branch=alnyan%2Fyggdrasil-rng_core-0.6.4)",
"ring 0.17.7",
"rustls",
"sha2",
"x25519-dalek",
@@ -2378,20 +2377,6 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "ring"
version = "0.17.7"
source = "git+https://git.alnyan.me/yggdrasil/ring.git?branch=alnyan%2Fyggdrasil#487ee1292bfd8eb68e4e89f590dd5b95dd0e1463"
dependencies = [
"cc",
"cfg-if",
"getrandom 0.2.12",
"libc",
"spin",
"untrusted",
"windows-sys 0.48.0",
]
[[package]]
name = "ring"
version = "0.17.14"
@@ -2434,9 +2419,9 @@ dependencies = [
"clap",
"cross",
"ed25519-dalek",
"env_logger",
"libterm",
"log",
"logsink",
"rand 0.8.5 (git+https://git.alnyan.me/yggdrasil/rand.git?branch=alnyan%2Fyggdrasil-rng_core-0.6.4)",
"sha2",
"thiserror",
@@ -2509,7 +2494,7 @@ version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring 0.17.14",
"ring",
"rustls-pki-types",
"untrusted",
]
@@ -2898,6 +2883,7 @@ dependencies = [
"sha2",
"thiserror",
"tui",
"usb-ids",
"yggdrasil-abi",
"yggdrasil-rt",
]
@@ -3181,6 +3167,19 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "usb-ids"
version = "1.2025.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f464d03993287ba27fae1c81bfa368df4493983de7e340429fc10e470043383"
dependencies = [
"nom",
"phf",
"phf_codegen",
"proc-macro2",
"quote",
]
[[package]]
name = "utf8_iter"
version = "1.0.4"
+1 -1
View File
@@ -51,7 +51,7 @@ raqote = { version = "0.8.3", default-features = false }
# Vendored/patched dependencies
rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil" }
rand_core = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil" }
ring = { git = "https://git.alnyan.me/yggdrasil/ring.git", branch = "alnyan/yggdrasil" }
# ring = { git = "https://git.alnyan.me/yggdrasil/ring.git", branch = "alnyan/yggdrasil" }
rsa = { git = "https://git.alnyan.me/yggdrasil/rsa.git", branch = "alnyan/yggdrasil" }
rustls = { git = "https://git.alnyan.me/yggdrasil/rustls.git", branch = "alnyan/yggdrasil", default-features = false, features = ["std", "logging", "tls12", "custom-provider"] }
curve25519-dalek = { git = "https://git.alnyan.me/yggdrasil/curve25519-dalek.git", branch = "alnyan/yggdrasil" }
+4
View File
@@ -140,6 +140,10 @@ impl RawStdin {
pub fn open() -> io::Result<Self> {
sys::RawStdinImpl::open().map(Self)
}
pub fn set_options(&mut self, options: TerminalOptionsImpl) -> io::Result<TerminalOptionsImpl> {
self.0.set_options(options)
}
}
impl Read for RawStdin {
+1
View File
@@ -47,6 +47,7 @@ pub(crate) trait Pipe: Read + Write + AsRawFd + Sized {
pub(crate) trait RawStdin: Sized + Read + AsRawFd {
fn open() -> io::Result<Self>;
fn set_options(&mut self, options: TerminalOptionsImpl) -> io::Result<TerminalOptionsImpl>;
// fn new(stdin: &'a mut Stdin) -> io::Result<Self>;
}
+1 -1
View File
@@ -13,7 +13,7 @@ pub struct PidFdImpl {
impl PidFd for PidFdImpl {
fn new(pid: u32) -> io::Result<Self> {
let pid = pid as i32;
let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, 0) } as i32;
let fd = unsafe { libc::syscall(libc::SYS_pidfd_open, pid, libc::PIDFD_NONBLOCK) } as i32;
if fd < 0 {
return Err(io::Error::last_os_error());
}
+14 -7
View File
@@ -1,6 +1,7 @@
use std::{
io::{self, Read, Write},
os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd},
process::Stdio,
};
use crate::sys::Pipe;
@@ -25,26 +26,32 @@ impl Pipe for PipeImpl {
Ok((Self { fd: read }, Self { fd: write }))
}
fn to_child_stdio(&self) -> std::process::Stdio {
todo!()
fn to_child_stdio(&self) -> Stdio {
unsafe { Stdio::from_raw_fd(self.as_raw_fd()) }
}
}
impl Read for PipeImpl {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let _ = buf;
todo!()
let len = unsafe { libc::read(self.fd.as_raw_fd(), buf.as_mut_ptr().cast(), buf.len()) };
if len < 0 {
return Err(io::Error::last_os_error());
}
Ok(len as usize)
}
}
impl Write for PipeImpl {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let _ = buf;
todo!()
let len = unsafe { libc::write(self.fd.as_raw_fd(), buf.as_ptr().cast(), buf.len()) };
if len < 0 {
return Err(io::Error::last_os_error());
}
Ok(len as usize)
}
fn flush(&mut self) -> io::Result<()> {
todo!()
Ok(())
}
}
+7 -1
View File
@@ -5,7 +5,7 @@ use std::{
os::fd::{AsRawFd, RawFd},
};
use crate::sys::RawStdin;
use crate::{io::TerminalOptionsImpl, sys::RawStdin};
enum Inner {
Stdin(Stdin),
@@ -72,6 +72,12 @@ impl RawStdin for RawStdinImpl {
})?;
Ok(Self { inner, saved })
}
fn set_options(&mut self, options: TerminalOptionsImpl) -> io::Result<TerminalOptionsImpl> {
self.inner.update_options(|t| {
*t = options;
})
}
}
impl io::Read for RawStdinImpl {
+17 -8
View File
@@ -1,28 +1,37 @@
use std::{
io,
os::{
fd::{AsRawFd, RawFd},
yggdrasil::io::pid::{PidFd as YggPidFd, ProcessId},
},
os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd},
};
use runtime::rt::{
process::{ProcessId, ProcessWait, WaitFlags},
sys as syscall,
};
use crate::sys::PidFd;
pub struct PidFdImpl(YggPidFd);
pub struct PidFdImpl {
fd: OwnedFd,
}
impl PidFd for PidFdImpl {
fn new(pid: u32) -> io::Result<Self> {
let pid = unsafe { ProcessId::from_raw(pid) };
YggPidFd::child(pid, false).map(Self)
let target = ProcessWait::Process(pid);
let fd = unsafe { syscall::create_pid(&target, WaitFlags::NON_BLOCKING) }
.map_err(io::Error::from)?;
let fd = unsafe { OwnedFd::from_raw_fd(fd) };
Ok(Self { fd })
}
fn exit_status(&self) -> io::Result<i32> {
self.0.status()
let (_, code) = runtime::rt::io::read_pid_fd(self.fd.as_raw_fd())?;
Ok(code)
}
}
impl AsRawFd for PidFdImpl {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
self.fd.as_raw_fd()
}
}
@@ -63,6 +63,10 @@ impl RawStdin for RawStdinImpl {
let saved = inner.update_options(|_| TerminalOptions::raw_input())?;
Ok(Self { inner, saved })
}
fn set_options(&mut self, options: TerminalOptions) -> io::Result<TerminalOptions> {
self.inner.update_options(|_| options)
}
}
impl io::Read for RawStdinImpl {
+5 -1
View File
@@ -4,7 +4,7 @@ use std::{
time::Duration,
};
use crate::io::{Poll, RawStdin};
use crate::io::{Poll, RawStdin, TerminalOptionsImpl};
pub struct TerminalInput {
stdin: RawStdin,
@@ -75,6 +75,10 @@ impl TerminalInput {
})
}
pub fn set_options(&mut self, options: TerminalOptionsImpl) -> io::Result<TerminalOptionsImpl> {
self.stdin.set_options(options)
}
fn take(&mut self, count: usize) {
self.buffer.copy_within(count..self.buffer_len, 0);
self.buffer_len -= count;
+1 -1
View File
@@ -9,7 +9,7 @@ chrono.workspace = true
rustls.workspace = true
rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" }
ring.workspace = true
# ring.workspace = true
sha2.workspace = true
hmac.workspace = true
chacha20poly1305.workspace = true
@@ -0,0 +1,96 @@
use std::{io, net::TcpStream};
use crate::{
connection::{tls::TlsConnection, TlsConnector},
HttpConnection, HttpConnector, TcpConnector,
};
pub struct AutoConnector {
tcp: TcpConnector,
tls: TlsConnector,
}
pub enum AutoConnection {
Tls(TlsConnection),
Tcp(TcpStream),
}
impl AutoConnector {
pub fn new(tcp: TcpConnector, tls: TlsConnector) -> Self {
Self { tcp, tls }
}
pub fn insecure() -> Self {
let tls = TlsConnector::insecure();
let tcp = TcpConnector;
Self { tcp, tls }
}
}
impl HttpConnector for AutoConnector {
type Error = io::Error;
type Connection = AutoConnection;
fn connect(
&mut self,
remote: &std::net::SocketAddr,
scheme: &str,
server_name: &str,
timeout: Option<std::time::Duration>,
options: super::HttpConnectionOptions,
) -> Result<Self::Connection, Self::Error> {
if scheme == "https" {
self.tls
.connect(remote, scheme, server_name, timeout, options)
.map(AutoConnection::Tls)
} else {
self.tcp
.connect(remote, scheme, server_name, timeout, options)
.map(AutoConnection::Tcp)
}
}
fn supports_scheme(&self, scheme: &str) -> bool {
scheme == "https" || scheme == "http"
}
fn default_port(&self, scheme: &str) -> u16 {
if scheme == "https" {
443
} else {
80
}
}
}
impl HttpConnection for AutoConnection {
type Error = io::Error;
fn send(&mut self, buffer: &[u8]) -> Result<usize, Self::Error> {
match self {
Self::Tls(c) => c.send(buffer),
Self::Tcp(c) => c.send(buffer),
}
}
fn send_all(&mut self, buffer: &[u8]) -> Result<(), Self::Error> {
match self {
Self::Tls(c) => c.send_all(buffer),
Self::Tcp(c) => c.send_all(buffer),
}
}
fn recv(&mut self, buffer: &mut [u8]) -> Result<usize, Self::Error> {
match self {
Self::Tls(c) => c.recv(buffer),
Self::Tcp(c) => c.recv(buffer),
}
}
fn recv_exact(&mut self, buffer: &mut [u8]) -> Result<(), Self::Error> {
match self {
Self::Tls(c) => c.recv_exact(buffer),
Self::Tcp(c) => c.recv_exact(buffer),
}
}
}
+7 -2
View File
@@ -1,10 +1,16 @@
use std::{net::SocketAddr, time::Duration};
pub use tcp::TcpConnector;
#[cfg(feature = "https")]
pub use auto::AutoConnector;
#[cfg(feature = "https")]
pub use tls::TlsConnector;
pub mod tcp;
#[cfg(feature = "https")]
pub mod auto;
#[cfg(feature = "https")]
pub mod tls;
@@ -21,8 +27,6 @@ pub trait HttpConnector {
type Connection: HttpConnection<Error = Self::Error>;
type Error: std::error::Error + Send + 'static;
const DEFAULT_PORT: u16;
fn connect(
&mut self,
remote: &SocketAddr,
@@ -33,6 +37,7 @@ pub trait HttpConnector {
) -> Result<Self::Connection, Self::Error>;
fn supports_scheme(&self, scheme: &str) -> bool;
fn default_port(&self, scheme: &str) -> u16;
}
#[derive(Clone)]
+4 -2
View File
@@ -12,8 +12,6 @@ impl HttpConnector for TcpConnector {
type Connection = TcpStream;
type Error = io::Error;
const DEFAULT_PORT: u16 = 80;
fn connect(
&mut self,
remote: &SocketAddr,
@@ -34,6 +32,10 @@ impl HttpConnector for TcpConnector {
fn supports_scheme(&self, scheme: &str) -> bool {
scheme.eq_ignore_ascii_case("http")
}
fn default_port(&self, _scheme: &str) -> u16 {
80
}
}
impl HttpConnection for TcpStream {
+4 -2
View File
@@ -29,8 +29,6 @@ impl HttpConnector for TlsConnector {
type Connection = TlsConnection;
type Error = io::Error;
const DEFAULT_PORT: u16 = 443;
fn connect(
&mut self,
remote: &SocketAddr,
@@ -59,6 +57,10 @@ impl HttpConnector for TlsConnector {
fn supports_scheme(&self, scheme: &str) -> bool {
scheme == "https"
}
fn default_port(&self, _scheme: &str) -> u16 {
443
}
}
impl HttpConnection for TlsConnection {
+15 -1
View File
@@ -66,7 +66,9 @@ impl<C: HttpConnector> HttpClient<C> {
}
let host = url.host().ok_or(HttpError::MalformedUrl)?;
let port = url.port_u16().unwrap_or(C::DEFAULT_PORT);
let port = url
.port_u16()
.unwrap_or(self.connector.default_port(scheme));
let request = if let Some(path) = url.path_and_query() {
request.uri(path.as_str())
@@ -127,6 +129,18 @@ impl Default for HttpClient<TcpConnector> {
}
impl<C: HttpConnector, U: TryInto<Uri>> HttpRequestBuilder<'_, C, U> {
pub fn header<V: TryInto<HeaderValue>>(
mut self,
name: HeaderName,
value: V,
) -> Result<Self, HttpError<C::Error>>
where
http::Error: From<V::Error>,
{
self.builder = self.builder.header(name, value);
Ok(self)
}
pub fn call<R: HttpBody>(self) -> Result<HttpResponse<C::Connection, R>, HttpError<C::Error>> {
self.client.send(self.url, self.builder, ())
}
+71 -35
View File
@@ -3,11 +3,17 @@ use std::{
io::{self, stdout, Stdout, Write},
path::{Path, PathBuf},
process::ExitCode,
str::FromStr,
};
use clap::{Parser, Subcommand};
use hclient::{connection::TlsConnector, HttpClient, HttpError};
use http::{Method, Uri};
use hclient::{
connection::AutoConnector, HttpBody, HttpClient, HttpConnector, HttpError, HttpResponse,
};
use http::{
header::{self, ToStrError},
Method, Uri,
};
#[derive(Debug, thiserror::Error)]
enum Error {
@@ -15,12 +21,22 @@ enum Error {
IoError(#[from] io::Error),
#[error("HTTP error: {0}")]
HttpError(#[from] HttpError<io::Error>),
#[error("Malformed response: {0}")]
MalformedResponse(&'static str),
#[error("Too many redirects")]
TooManyRedirects,
#[error("Invalid header '{0}' value: {1}")]
InvalidHeaderValue(header::HeaderName, ToStrError),
#[error("Invalid URL ('{0}'): {1}")]
InvalidUrl(String, http::uri::InvalidUri),
}
#[derive(Debug, Parser)]
struct Arguments {
#[clap(short, long)]
output: Option<PathBuf>,
#[clap(short, long)]
follow: bool,
#[clap(subcommand)]
method: RequestMethod,
}
@@ -64,51 +80,71 @@ impl Write for Output {
}
}
fn do_request_tls(url: Uri, mut output: Output) -> Result<(), Error> {
let connector = TlsConnector::insecure();
let mut client = HttpClient::new_default(connector);
fn request_call<C: HttpConnector, B: HttpBody>(
client: &mut HttpClient<C>,
method: Method,
mut url: Uri,
) -> Result<HttpResponse<C::Connection, B>, Error>
where
Error: From<HttpError<C::Error>>,
{
let mut redirect_count = 0;
let max_redirects = 5;
while redirect_count < max_redirects {
log::info!("Try URL: {url:?}");
let response = client
.request(method.clone(), url)
.header(header::CONNECTION, "close")?
.call()?;
let mut buffer = [0; 4096];
let mut response = client.request(Method::GET, url).call().unwrap();
loop {
let len = response.read(&mut buffer)?;
if len == 0 {
break;
if response.status().is_redirection() {
let location =
response
.headers()
.get(header::LOCATION)
.ok_or(Error::MalformedResponse(
"No \"Location\" header in a redirect response",
))?;
let redirect_url = location
.to_str()
.map_err(|e| Error::InvalidHeaderValue(header::LOCATION, e))?;
let redirect_url = Uri::from_str(redirect_url)
.map_err(|e| Error::InvalidUrl(redirect_url.into(), e))?;
log::info!("Redirect to {redirect_url:?}");
url = redirect_url;
redirect_count += 1;
} else {
return Ok(response);
}
output.write_all(&buffer[..len])?;
}
Ok(())
}
fn do_request_tcp(url: Uri, mut output: Output) -> Result<(), Error> {
let mut client = HttpClient::default();
let mut buffer = [0; 4096];
let mut response = client.request(Method::GET, url).call().unwrap();
loop {
let len = response.read(&mut buffer)?;
if len == 0 {
break;
}
output.write_all(&buffer[..len])?;
}
Ok(())
Err(Error::TooManyRedirects)
}
fn get(url: Uri, output: Option<PathBuf>) -> Result<(), Error> {
let output = Output::open(output)?;
let use_https = url.scheme_str().map_or(false, |scheme| scheme == "https");
let mut output = Output::open(output)?;
let connector = AutoConnector::insecure();
let mut client = HttpClient::new_default(connector);
let mut buffer = [0; 4096];
if use_https {
do_request_tls(url, output)
} else {
do_request_tcp(url, output)
let mut response = request_call(&mut client, Method::GET, url)?;
loop {
let len = response.read(&mut buffer)?;
if len == 0 {
break;
}
output.write_all(&buffer[..len])?;
}
Ok(())
}
fn main() -> ExitCode {
logsink::setup_logging(false);
let args = Arguments::parse();
let result = match args.method {
+2
View File
@@ -19,4 +19,6 @@ pub enum Error {
TimedOut,
#[error("TOML deserialize error: {0}")]
TomlDeserializeErr(#[from] toml::de::Error),
#[error("Cannot resolve network address: {0:?}")]
UnresolvedAddress(String),
}
+24 -5
View File
@@ -1,10 +1,16 @@
#![feature(yggdrasil_os, rustc_private)]
use std::{
mem::size_of, net::{IpAddr, Ipv4Addr}, os::{
mem::size_of,
net::{IpAddr, Ipv4Addr, ToSocketAddrs},
os::{
fd::AsRawFd,
yggdrasil::io::{net::raw_socket::RawSocket, poll::PollChannel, timer::TimerFd},
}, process::ExitCode, sync::atomic::{AtomicBool, Ordering}, time::Duration
},
process::ExitCode,
str::FromStr,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use bytemuck::Zeroable;
@@ -51,7 +57,7 @@ struct Args {
data_size: usize,
#[clap(help = "Address to ping")]
address: core::net::IpAddr,
address: String,
}
fn valid_interval(s: &str) -> Result<u32, String> {
@@ -270,12 +276,25 @@ fn ping_once(
}
fn ping(
address: IpAddr,
input_address: &str,
times: usize,
data_len: usize,
interval: Duration,
timeout: Duration,
) -> Result<PingStats, Error> {
let address = if let Ok(ip) = IpAddr::from_str(input_address) {
ip
} else {
let socket_address = format!("{input_address}:0");
let mut addrs = ToSocketAddrs::to_socket_addrs(&socket_address)?;
let addr = addrs
.next()
.ok_or_else(|| Error::UnresolvedAddress(input_address.into()))?;
addr.ip()
};
println!("Pinging {input_address:?} ({address})");
let routing = resolve_routing(address)?;
let mut stats = PingStats {
@@ -326,7 +345,7 @@ fn main() -> ExitCode {
let args = Args::parse();
let stats = match ping(
args.address,
&args.address,
args.count,
args.data_size,
Duration::from_millis(args.inteval.into()),
+5
View File
@@ -25,6 +25,7 @@ tui.workspace = true
# Own regex implementation?
regex = "1.11.1"
pci-ids = { version = "0.2.5" }
usb-ids = { version = "1.2025.2" }
cryptic.workspace = true
rustls.workspace = true
@@ -145,6 +146,10 @@ path = "src/sleep.rs"
name = "lspci"
path = "src/lspci.rs"
[[bin]]
name = "lsusb"
path = "src/lsusb.rs"
[[bin]]
name = "ps"
path = "src/ps.rs"
+122
View File
@@ -0,0 +1,122 @@
use std::{
fs::{self, ReadDir},
io,
path::PathBuf,
process::ExitCode,
};
struct BusAddress {
bus: u16,
device: u8,
}
struct Device {
path: PathBuf,
address: BusAddress,
id_vendor: Option<u16>,
id_product: Option<u16>,
}
impl Device {
pub fn read(address: BusAddress) -> Self {
let path = PathBuf::from("/sys/bus/usb")
.join(format!("bus{}", address.bus))
.join(address.device.to_string());
let id_vendor = fs::read_to_string(path.join("vendor"))
.inspect_err(|e| eprintln!("{path:?}/vendor: {e}"))
.ok()
.and_then(|v| u16::from_str_radix(v.trim(), 16).ok());
let id_product = fs::read_to_string(path.join("product"))
.inspect_err(|e| eprintln!("{path:?}/product: {e}"))
.ok()
.and_then(|v| u16::from_str_radix(v.trim(), 16).ok());
Self {
path,
address,
id_vendor,
id_product,
}
}
pub fn format_short(&self) -> String {
let product = if let (Some(vendor), Some(product)) = (self.id_vendor, self.id_product) {
let s = usb_ids::Device::from_vid_pid(vendor, product)
.map(|v| v.name())
.unwrap_or("unknown device");
format!("{vendor:04x}:{product:04x} {s}")
} else {
"????:???? Unknown device".into()
};
format!(
"Bus {:03} Device {:03}: {product}",
self.address.bus, self.address.device
)
}
}
fn list_busses() -> io::Result<Vec<u16>> {
let mut res = vec![];
let dir = fs::read_dir("/sys/bus/usb")?;
for item in dir {
let Ok(item) = item else {
continue;
};
let Some(index) = item
.file_name()
.to_str()
.and_then(|s| s.strip_prefix("bus"))
.and_then(|s| s.parse().ok())
else {
continue;
};
res.push(index);
}
Ok(res)
}
fn list_bus(bus: u16, devices: &mut Vec<Device>) -> io::Result<()> {
let path = PathBuf::from("/sys/bus/usb").join(format!("bus{bus}"));
let dir = fs::read_dir(path)?;
for item in dir {
let Ok(item) = item else {
continue;
};
let Some(name) = item.file_name().to_str().and_then(|s| s.parse().ok()) else {
continue;
};
let address = BusAddress { bus, device: name };
devices.push(Device::read(address));
}
Ok(())
}
fn list_devices(busses: &[u16]) -> Vec<Device> {
let mut devices = vec![];
for &bus in busses {
if let Err(error) = list_bus(bus, &mut devices) {
eprintln!("bus {bus}: {error}");
}
}
devices
}
fn list_all() -> Vec<Device> {
let Ok(busses) = list_busses() else {
return vec![];
};
list_devices(&busses)
}
fn main() {
let devices = list_all();
for device in devices {
let text = device.format_short();
println!("{text}");
}
}
+1 -1
View File
@@ -10,6 +10,7 @@ path = "src/rshd/main.rs"
[dependencies]
libterm.workspace = true
cross.workspace = true
logsink.workspace = true
clap.workspace = true
thiserror.workspace = true
@@ -21,7 +22,6 @@ log.workspace = true
rand = { git = "https://git.alnyan.me/yggdrasil/rand.git", branch = "alnyan/yggdrasil-rng_core-0.6.4" }
aes = { version = "0.8.4" }
env_logger = "0.11.5"
[lints]
workspace = true
+58 -10
View File
@@ -8,7 +8,7 @@ use std::{
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{
crypt::signature::VerificationMethod,
crypt::{self, signature::VerificationMethod},
proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder},
};
@@ -19,10 +19,13 @@ use super::{
ClientNegotiationMessage, ServerNegotiationMessage,
};
pub const MESSAGE_SIZE_MAX: usize = 4096;
const FRAMING_BUFFER_SIZE: usize = 8192;
pub struct ClientSocket {
pub(crate) stream: TcpStream,
pub(crate) remote: SocketAddr,
pub(crate) buffer: [u8; 512],
pub(crate) buffer: [u8; MESSAGE_SIZE_MAX],
pub(crate) recv_buf: FramingBuffer,
pub(crate) signer: SignatureMethod,
@@ -37,13 +40,16 @@ pub enum Message<T> {
}
pub(crate) struct FramingBuffer {
buffer: [u8; 512],
buffer: [u8; FRAMING_BUFFER_SIZE],
len: usize,
}
impl FramingBuffer {
pub fn new() -> Self {
Self { buffer: [0; 512], len: 0 }
Self {
buffer: [0; FRAMING_BUFFER_SIZE],
len: 0,
}
}
pub fn get_mut(&mut self) -> &mut [u8] {
@@ -117,7 +123,7 @@ impl ClientSocket {
}
pub fn write_all<E: Encode>(&mut self, message: &E) -> Result<(), Error> {
let mut buf = [0; 256];
let mut buf = [0; MESSAGE_SIZE_MAX - 256];
let mut encoder = Encoder::new(&mut buf);
message.encode(&mut encoder)?;
@@ -126,20 +132,27 @@ impl ClientSocket {
let payload_len = payload.len();
let signature_len = self.signer.sign(payload, rest)?;
let len = self.symmetric.encrypt(&buf[..payload_len + signature_len], &mut self.buffer[size_of::<u16>()..])?;
let len = self.symmetric.encrypt(
&buf[..payload_len + signature_len],
&mut self.buffer[size_of::<u16>()..],
)?;
let len_bytes: u16 = len.try_into().unwrap();
self.buffer[..size_of::<u16>()].copy_from_slice(&len_bytes.to_le_bytes());
self.stream.write_all(&self.buffer[..len + size_of::<u16>()])?;
self.stream
.write_all(&self.buffer[..len + size_of::<u16>()])?;
Ok(())
}
pub fn poll_read<'de, D: Decode<'de>>(&mut self, buffer: &'de mut [u8]) -> Result<Message<D>, Error> {
pub fn poll_read<'de, D: Decode<'de>>(
&mut self,
buffer: &'de mut [u8],
) -> Result<Message<D>, Error> {
if self.poll()? == 0 {
return Ok(Message::Closed);
}
match self.read(buffer)? {
Some(message) => Ok(Message::Data(message)),
None => Ok(Message::Incomplete)
None => Ok(Message::Incomplete),
}
}
@@ -156,6 +169,7 @@ impl ClientSocket {
pub fn read<'de, D: Decode<'de>>(&mut self, buffer: &'de mut [u8]) -> Result<Option<D>, Error> {
if let Some(len) = self.recv_buf.pop(&mut self.buffer) {
let data_len = self.symmetric.decrypt(&self.buffer[..len], buffer)?;
let mut decoder = Decoder::new(&buffer[..data_len]);
let message = D::decode(&mut decoder)?;
@@ -188,6 +202,7 @@ impl Negotiation {
}
fn hello(&mut self, recv_buf: &mut [u8]) -> Result<(SignatureMethod, u8, u8), Error> {
log::info!("Send ClientHello v1");
self.send(None, &ClientNegotiationMessage::Hello { protocol: 1 })?;
let hello = match self.recv(None, recv_buf)? {
@@ -196,6 +211,23 @@ impl Negotiation {
_ => return Err(Error::UnexpectedServerReply),
};
log::info!("Server ciphersuites:");
for &cipher in hello.symmetric_ciphersuites {
if let Some(name) = crypt::ciphersuite_name(cipher) {
log::info!(" * {name:?} ({cipher:#x})");
} else {
log::info!(" * {cipher:#x}");
}
}
log::info!("Server signature algorithms:");
for &sig in hello.sig_algos {
if let Some(name) = crypt::sig_algo_name(sig) {
log::info!(" * {name:?} ({sig:#x})");
} else {
log::info!(" * {sig:#x}");
}
}
let signer = self
.config
.signature_keystore
@@ -218,6 +250,15 @@ impl Negotiation {
) -> Result<VerificationMethod, Error> {
let sig_algorithm = signer.algorithm();
let key_data = signer.verifying_key_bytes();
let sig_algorithm_name = crypt::sig_algo_name(sig_algorithm).unwrap_or("unknown");
let offered_fingerprint =
crypt::signature::fingerprint_sha256(sig_algorithm_name, &key_data);
log::info!("Offer {offered_fingerprint}");
let ciphersuite_name = crypt::ciphersuite_name(ciphersuite).unwrap_or("???");
log::info!("With ciphersuite {ciphersuite_name:?} ({ciphersuite:#x})");
self.send(
None,
&ClientNegotiationMessage::StartKex {
@@ -237,6 +278,11 @@ impl Negotiation {
_ => return Err(Error::UnexpectedServerReply),
};
let server_sig_algorithm_name = crypt::sig_algo_name(server_sig).unwrap_or("unknown");
let server_fingerprint =
crypt::signature::fingerprint_sha256(server_sig_algorithm_name, server_verifying_key);
log::info!("Server fingerprint {server_fingerprint}");
let verifier = self
.config
.signature_keystore
@@ -301,10 +347,12 @@ impl Negotiation {
self.key_exchange(&mut recv_buf, &mut signer, &mut verifier, ciphersuite)?;
self.finish(&mut recv_buf, &mut signer, &mut verifier)?;
log::info!("Established");
Ok(ClientSocket {
stream: self.stream,
remote: self.remote,
buffer: [0; 512],
buffer: [0; MESSAGE_SIZE_MAX],
recv_buf: FramingBuffer::new(),
signer,
verifier,
+3 -8
View File
@@ -11,9 +11,7 @@ use super::{
};
fn default_select_kex_algorithm(offer: &[u8]) -> Option<u8> {
const ACCEPTED: &[u8] = &[
V1_KEX_X25519_DALEK
];
const ACCEPTED: &[u8] = &[V1_KEX_X25519_DALEK];
for accepted in ACCEPTED {
if offer.contains(accepted) {
@@ -24,10 +22,7 @@ fn default_select_kex_algorithm(offer: &[u8]) -> Option<u8> {
}
fn default_select_ciphersuite(offer: &[u8]) -> Option<u8> {
const ACCEPTED: &[u8] = &[
V1_CIPHER_AES_256_CBC,
V1_CIPHER_AES_256_ECB,
];
const ACCEPTED: &[u8] = &[V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB];
for accepted in ACCEPTED {
if offer.contains(accepted) {
@@ -133,7 +128,7 @@ impl ServerConfig {
offer_ciphersuites: default_offer_ciphersuites,
offer_sig_algorhtms: default_offer_sig_algorithms,
offer_kex_algorithms: default_offer_kex_algorithms,
signature_keystore: Box::new(signature_keystore)
signature_keystore: Box::new(signature_keystore),
}
}
}
+10 -4
View File
@@ -1,12 +1,13 @@
use crate::proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder};
pub mod client;
pub mod config;
pub mod server;
pub mod signature;
pub mod util;
pub mod symmetric;
pub mod config;
pub mod util;
pub const V1_CIPHER_NULL: u8 = 0x00;
pub const V1_CIPHER_AES_256_ECB: u8 = 0x10;
pub const V1_CIPHER_AES_256_CBC: u8 = 0x11;
@@ -139,7 +140,9 @@ impl<'de> Decode<'de> for ClientNegotiationMessage<'de> {
}
Self::TAG_DH_PUBLIC_KEY => buffer.read_variable_bytes().map(Self::DHPublicKey),
Self::TAG_AGREED => Ok(Self::Agreed),
_ => Err(DecodeError::InvalidMessage),
_ => Err(DecodeError::InvalidMessage(
"Invalid ClientNegotiationMessage tag",
)),
}
}
}
@@ -193,13 +196,16 @@ impl<'de> Decode<'de> for ServerNegotiationMessage<'de> {
Self::TAG_DH_PUBLIC_KEY => buffer.read_variable_bytes().map(Self::DHPublicKey),
Self::TAG_AGREED => Ok(Self::Agreed),
_ => Err(DecodeError::InvalidMessage),
_ => Err(DecodeError::InvalidMessage(
"Invalid ServerNegotiationMessage tag",
)),
}
}
}
pub fn ciphersuite_name(cipher: u8) -> Option<&'static str> {
match cipher {
V1_CIPHER_NULL => Some("null"),
V1_CIPHER_AES_256_ECB => Some("aes-256-ecb"),
V1_CIPHER_AES_256_CBC => Some("aes-256-cbc"),
_ => None,
+6 -3
View File
@@ -9,7 +9,10 @@ use cross::io::Poll;
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{
crypt::{sig_algo_name, signature::fingerprint_sha256, ServerHello, ServerNegotiationMessage},
crypt::{
client::MESSAGE_SIZE_MAX, sig_algo_name, signature::fingerprint_sha256, ServerHello,
ServerNegotiationMessage,
},
proto::{Decode, DecodeError, Decoder, Encode, EncodeError, Encoder},
};
@@ -128,7 +131,7 @@ impl ServerSocket {
remote: address,
verifier: client.verifier.unwrap(),
symmetric: client.symmetric.unwrap(),
buffer: [0; 512],
buffer: [0; MESSAGE_SIZE_MAX],
}))
}
Err(error) => {
@@ -182,7 +185,7 @@ impl PendingClient {
let sig_algorithm_name = sig_algo_name(sig_algorithm).unwrap_or("???");
let their_fingerprint = fingerprint_sha256(sig_algorithm_name, key_data);
log::info!("{address}: {their_fingerprint}");
log::info!("{address}: their fingerprint {their_fingerprint}");
let verifier = match config
.signature_keystore
+96 -17
View File
@@ -3,7 +3,7 @@ use aes::{
Aes256, Block,
};
use crate::crypt::util;
use crate::crypt::{util, V1_CIPHER_NULL};
use super::{V1_CIPHER_AES_256_CBC, V1_CIPHER_AES_256_ECB};
@@ -19,12 +19,14 @@ pub struct Aes256BlockCipher<M: AesBlockMode> {
pub struct CipherModeEcb;
pub struct CipherModeCbc {
iv: Block,
iv_encrypt: Block,
iv_decrypt: Block,
}
pub enum SymmetricCipher {
Aes256Ecb(Aes256BlockCipher<CipherModeEcb>),
Aes256Cbc(Aes256BlockCipher<CipherModeCbc>),
Null,
}
pub struct Pkcs7Padder<'src> {
@@ -193,17 +195,17 @@ impl AesBlockMode for CipherModeEcb {
impl AesBlockMode for CipherModeCbc {
fn encryption(&mut self, aes: &Aes256, block: Block) -> Block {
let mut block = util::xor16b(block, self.iv);
let mut block = util::xor16b(block, self.iv_encrypt);
aes.encrypt_block(&mut block);
self.iv = block;
self.iv_encrypt = block;
block
}
fn decryption(&mut self, aes: &Aes256, ciphertext: Block) -> Block {
let mut block = ciphertext;
aes.decrypt_block(&mut block);
let block = util::xor16b(block, self.iv);
self.iv = ciphertext;
let block = util::xor16b(block, self.iv_decrypt);
self.iv_decrypt = ciphertext;
block
}
}
@@ -211,13 +213,18 @@ impl AesBlockMode for CipherModeCbc {
impl SymmetricCipher {
pub fn new(suite: u8, shared_key: &[u8]) -> Result<Self, Error> {
match suite {
V1_CIPHER_NULL => Ok(Self::Null),
V1_CIPHER_AES_256_ECB => {
Aes256BlockCipher::new(shared_key, CipherModeEcb).map(Self::Aes256Ecb)
}
V1_CIPHER_AES_256_CBC => {
Aes256BlockCipher::new(shared_key, CipherModeCbc { iv: [0; 16].into() })
.map(Self::Aes256Cbc)
}
V1_CIPHER_AES_256_CBC => Aes256BlockCipher::new(
shared_key,
CipherModeCbc {
iv_encrypt: [0; 16].into(),
iv_decrypt: [0; 16].into(),
},
)
.map(Self::Aes256Cbc),
_ => unreachable!(),
}
}
@@ -226,6 +233,14 @@ impl SymmetricCipher {
match self {
Self::Aes256Ecb(cipher) => cipher.encrypt(src, dst),
Self::Aes256Cbc(cipher) => cipher.encrypt(src, dst),
Self::Null => {
if src.len() > dst.len() {
return Err(Error::MessageTooLarge(dst.len(), src.len()));
}
dst[..src.len()].copy_from_slice(src);
Ok(src.len())
}
}
}
@@ -233,6 +248,14 @@ impl SymmetricCipher {
match self {
Self::Aes256Ecb(cipher) => cipher.decrypt(src, dst),
Self::Aes256Cbc(cipher) => cipher.decrypt(src, dst),
Self::Null => {
if src.len() > dst.len() {
return Err(Error::MessageTooLarge(dst.len(), src.len()));
}
dst[..src.len()].copy_from_slice(src);
Ok(src.len())
}
}
}
}
@@ -300,12 +323,13 @@ mod tests {
#[test]
fn test_pkcs7_pad_reversible() {
let texts = [&b"Hello"[..], &[16; 16], &[32; 16], &[1; 16]];
let text = "1234567890ABCDEF";
let mut buffer = [0; 256];
for text in texts {
for i in 0..text.len() {
let text = &text.as_bytes()[..i];
let output = pad_unpad(text, &mut buffer);
assert_eq!(text, output);
assert_eq!(output, text);
}
}
@@ -315,10 +339,22 @@ mod tests {
let key = b"1234ABCD1234ABCD1234ABCD1234ABCD";
let mut encrypted = [0; 256];
let mut decrypted = [0; 256];
let mut enc_cipher =
Aes256BlockCipher::new(key, CipherModeCbc { iv: [0; 16].into() }).unwrap();
let mut dec_cipher =
Aes256BlockCipher::new(key, CipherModeCbc { iv: [0; 16].into() }).unwrap();
let mut enc_cipher = Aes256BlockCipher::new(
key,
CipherModeCbc {
iv_encrypt: [0; 16].into(),
iv_decrypt: [0; 16].into(),
},
)
.unwrap();
let mut dec_cipher = Aes256BlockCipher::new(
key,
CipherModeCbc {
iv_encrypt: [0; 16].into(),
iv_decrypt: [0; 16].into(),
},
)
.unwrap();
for text in messages {
let len = enc_cipher.encrypt(text, &mut encrypted).unwrap();
@@ -328,4 +364,47 @@ mod tests {
assert_eq!(&decrypted[..len], text);
}
}
#[test]
fn test_aes256cbc_large_message() {
let data = include_bytes!("../../tests/test-message.dat");
let key = b"1234ABCD1234ABCD1234ABCD1234ABCD";
let mut encrypt_buffer = [0; 512];
let mut decrypt_buffer = [0; 512];
let mut enc_cipher = Aes256BlockCipher::new(
key,
CipherModeCbc {
iv_encrypt: [0; 16].into(),
iv_decrypt: [0; 16].into(),
},
)
.unwrap();
let mut dec_cipher = Aes256BlockCipher::new(
key,
CipherModeCbc {
iv_encrypt: [0; 16].into(),
iv_decrypt: [0; 16].into(),
},
)
.unwrap();
let mut position = 0;
while position < data.len() {
let count = (data.len() - position).min(400);
let enc_len = enc_cipher
.encrypt(&data[position..position + count], &mut encrypt_buffer)
.unwrap();
let dec_len = dec_cipher
.decrypt(&encrypt_buffer[..enc_len], &mut decrypt_buffer)
.unwrap();
assert_eq!(dec_len, count);
assert_eq!(
&decrypt_buffer[..dec_len],
&data[position..position + count]
);
position += count;
}
}
}
+1 -1
View File
@@ -311,7 +311,7 @@ fn run(args: Args) -> Result<ExitCode, Error> {
}
fn main() -> ExitCode {
env_logger::init();
logsink::setup_logging(false);
let args = Args::parse();
match run(args) {
+14 -28
View File
@@ -42,8 +42,8 @@ pub enum EncodeError {
pub enum DecodeError {
#[error("Truncated message received")]
Truncated,
#[error("Malformed message received")]
InvalidMessage,
#[error("Malformed message received: {0}")]
InvalidMessage(&'static str),
#[error("Malformed string in the message")]
InvalidString(core::str::Utf8Error),
}
@@ -194,7 +194,7 @@ impl<'de> Decode<'de> for StreamIndex {
match tag {
Self::TAG_STDOUT => Ok(Self::Stdout),
Self::TAG_STDERR => Ok(Self::Stderr),
_ => Err(DecodeError::InvalidMessage)
_ => Err(DecodeError::InvalidMessage("Invalid StreamIndex tag")),
}
}
}
@@ -234,7 +234,7 @@ impl Encode for ClientMessage<'_> {
buffer.write(&[Self::TAG_INPUT])?;
buffer.write_variable_bytes(data)
}
Self::CloseStdin => buffer.write(&[Self::TAG_CLOSE_STDIN])
Self::CloseStdin => buffer.write(&[Self::TAG_CLOSE_STDIN]),
}
}
}
@@ -243,22 +243,12 @@ impl<'de> Decode<'de> for ClientMessage<'de> {
fn decode(buffer: &mut Decoder<'de>) -> Result<Self, DecodeError> {
let tag = buffer.read_u8()?;
match tag {
Self::TAG_OPEN_SESSION => {
TerminalInfo::decode(buffer).map(Self::OpenSession)
}
Self::TAG_RUN_COMMAND => {
buffer.read_str().map(Self::RunCommand)
}
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
}
Self::TAG_INPUT => {
buffer.read_variable_bytes().map(Self::Input)
}
Self::TAG_CLOSE_STDIN => {
Ok(Self::CloseStdin)
}
_ => Err(DecodeError::InvalidMessage)
Self::TAG_OPEN_SESSION => TerminalInfo::decode(buffer).map(Self::OpenSession),
Self::TAG_RUN_COMMAND => buffer.read_str().map(Self::RunCommand),
Self::TAG_BYE => buffer.read_str().map(Self::Bye),
Self::TAG_INPUT => buffer.read_variable_bytes().map(Self::Input),
Self::TAG_CLOSE_STDIN => Ok(Self::CloseStdin),
_ => Err(DecodeError::InvalidMessage("Invalid ClientMessage tag")),
}
}
}
@@ -274,9 +264,7 @@ impl Encode for ServerMessage<'_> {
fn encode(&self, buffer: &mut Encoder) -> Result<(), EncodeError> {
match self {
Self::SessionOpen => buffer.write(&[Self::TAG_SESSION_OPEN]),
Self::CommandStatus(status) => {
buffer.write(&status.to_le_bytes())
}
Self::CommandStatus(status) => buffer.write(&status.to_le_bytes()),
Self::Bye(reason) => {
buffer.write(&[Self::TAG_BYE])?;
buffer.write_str(reason)
@@ -301,15 +289,13 @@ impl<'de> Decode<'de> for ServerMessage<'de> {
status.copy_from_slice(bytes);
Ok(Self::CommandStatus(i32::from_le_bytes(status)))
}
Self::TAG_BYE => {
buffer.read_str().map(Self::Bye)
}
Self::TAG_BYE => buffer.read_str().map(Self::Bye),
Self::TAG_OUTPUT => {
let index = StreamIndex::decode(buffer)?;
let data = buffer.read_variable_bytes()?;
Ok(Self::Output(index, data))
},
_ => Err(DecodeError::InvalidMessage),
}
_ => Err(DecodeError::InvalidMessage("Invalid ServerMessage tag")),
}
}
}
+52 -58
View File
@@ -1,14 +1,25 @@
#![cfg_attr(target_os = "yggdrasil", feature(yggdrasil_os, rustc_private))]
#![feature(if_let_guard)]
use std::{
collections::HashSet, net::SocketAddr, os::fd::{self, IntoRawFd}, path::PathBuf, process::ExitCode, str::FromStr, time::Duration
collections::HashSet,
io::{self, Read, Write},
net::SocketAddr,
os::fd::{self, AsRawFd, FromRawFd, IntoRawFd, RawFd},
path::PathBuf,
process::{self, Command, ExitCode, Stdio},
str::FromStr,
time::Duration,
};
use clap::Parser;
use cross::io::PidFd;
use cross::{
io::{PidFd, PtyMaster, TerminalOptions, TerminalOptionsImpl, TerminalSize},
process::CommandSpawnExt,
};
use rsh::{
crypt::config::{ServerConfig, SimpleServerKeyStore},
server::{self, Server, SessionClient},
proto::TerminalInfo,
server::{self, Server, Session, SessionClient},
};
pub const PING_INTERVAL: Duration = Duration::from_millis(500);
@@ -26,55 +37,42 @@ struct Args {
keystore: PathBuf,
}
#[cfg(target_os = "yggdrasil")]
pub struct YggdrasilSession {
pty_master: std::fs::File,
fds: [std::os::fd::RawFd; 2],
pub struct SessionImpl {
fds: [RawFd; 2],
pty_master: PtyMaster,
remote: SocketAddr,
shell: std::process::Child,
pidfd: PidFd,
shell: process::Child,
pidfd: Option<PidFd>,
}
#[cfg(target_os = "yggdrasil")]
impl rsh::server::Session for YggdrasilSession {
impl Session for SessionImpl {
type Error = std::io::Error;
fn open(remote: &SocketAddr, terminal: &rsh::proto::TerminalInfo) -> Result<Self, Self::Error> {
use std::{
os::{
fd::{AsRawFd, FromRawFd},
yggdrasil::{
self,
io::terminal::{create_pty, TerminalSize},
process::CommandExt,
},
},
process::{Command, Stdio},
};
fn open(remote: &SocketAddr, terminal: &TerminalInfo) -> Result<Self, Self::Error> {
let remote = *remote;
// TODO unix version
let (pty_master, pty_slave) = create_pty(
Default::default(),
let termios = TerminalOptionsImpl::normal();
let (pty_master, pty_slave) = cross::io::open_pty(
&termios,
TerminalSize {
columns: terminal.columns as _,
rows: terminal.rows as _,
columns: terminal.columns as _,
x_pixels: 0,
y_pixels: 0,
},
)?;
let pty_slave_stdin = pty_slave.into_raw_fd();
let pty_slave_stdout = fd::clone_fd(pty_slave_stdin)?;
let pty_slave_stderr = fd::clone_fd(pty_slave_stdin)?;
let pty_slave_stdout = cross::io::clone_fd(pty_slave_stdin)?;
let pty_slave_stderr = cross::io::clone_fd(pty_slave_stdin)?;
let group_id = yggdrasil::process::create_process_group();
let shell = unsafe {
Command::new("/bin/sh")
.arg("-l")
.stdin(Stdio::from_raw_fd(pty_slave_stdin))
.stdout(Stdio::from_raw_fd(pty_slave_stdout))
.stderr(Stdio::from_raw_fd(pty_slave_stderr))
.process_group(group_id)
.gain_terminal(0)
.create_session()?
.spawn()?
};
let pidfd = PidFd::new(shell.id())?;
@@ -82,16 +80,17 @@ impl rsh::server::Session for YggdrasilSession {
let fds = [pty_master.as_raw_fd(), pidfd.as_raw_fd()];
Ok(Self {
pty_master,
pidfd,
shell,
remote,
pty_master,
pidfd: Some(pidfd),
shell,
fds,
})
}
fn close(mut self) -> Result<(), Self::Error> {
self.shell.wait()?;
self.pidfd = None;
Ok(())
}
@@ -99,12 +98,7 @@ impl rsh::server::Session for YggdrasilSession {
self.remote
}
fn handle_input(
&mut self,
input: &[u8],
_client: SessionClient,
) -> Result<bool, Self::Error> {
use std::io::Write;
fn handle_input(&mut self, input: &[u8], _client: SessionClient) -> Result<bool, Self::Error> {
self.pty_master.write_all(input)?;
Ok(false)
}
@@ -113,29 +107,32 @@ impl rsh::server::Session for YggdrasilSession {
&mut self,
fd: std::os::fd::RawFd,
buffer: &mut [u8],
) -> Result<usize, Self::Error> {
use std::io::Read;
) -> Result<Option<usize>, Self::Error> {
if fd == self.fds[0] {
self.pty_master.read(buffer)
self.pty_master.read(buffer).map(Some)
} else if fd == self.fds[1] {
let status = self.pidfd.exit_status()?;
log::info!("Shell exited with status: {status}");
Ok(0)
let Some(pidfd) = self.pidfd.as_mut() else {
return Ok(None);
};
match pidfd.exit_status() {
Ok(status) => {
log::info!("Shell exited with status: {status}");
Ok(Some(0))
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
Err(e) => Err(e),
}
} else {
unreachable!()
}
}
fn event_fds(&self) -> &[std::os::fd::RawFd] {
fn event_fds(&self) -> &[RawFd] {
&self.fds
}
}
#[cfg(any(unix, rust_analyzer))]
pub type SessionImpl = rsh::server::EchoSession;
#[cfg(any(target_os = "yggdrasil", rust_analyzer))]
pub type SessionImpl = YggdrasilSession;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Server error: {0}")]
@@ -156,10 +153,7 @@ fn run(args: Args) -> Result<(), Error> {
}
fn main() -> ExitCode {
env_logger::Builder::new()
.filter_level(log::LevelFilter::Debug)
.format_timestamp(None)
.init();
logsink::setup_logging(true);
let args = Args::parse();
if let Err(error) = run(args) {
eprintln!("Finished with error: {error}");
+19 -6
View File
@@ -38,7 +38,7 @@ pub trait Session: Sized {
fn open(peer: &SocketAddr, terminal: &TerminalInfo) -> Result<Self, Self::Error>;
fn peer(&self) -> SocketAddr;
fn handle_input(&mut self, input: &[u8], client: SessionClient) -> Result<bool, Self::Error>;
fn read_output(&mut self, fd: RawFd, buffer: &mut [u8]) -> Result<usize, Self::Error>;
fn read_output(&mut self, fd: RawFd, buffer: &mut [u8]) -> Result<Option<usize>, Self::Error>;
fn event_fds(&self) -> &[RawFd];
fn close(self) -> Result<(), Self::Error>;
}
@@ -145,6 +145,7 @@ impl<T: Session> ClientSet<T> {
}
};
log::info!("Add client {key}");
poll.add(&client.stream)?;
self.socket_fd_map.insert(client.stream.as_raw_fd(), key);
@@ -159,6 +160,7 @@ impl<T: Session> ClientSet<T> {
match client.session {
ClientSession::Terminal(terminal) => {
for fd in terminal.event_fds() {
log::info!("{key}: remove fd {fd:?}");
poll.remove(fd)?;
self.session_fd_map.remove(fd);
}
@@ -179,6 +181,8 @@ impl<T: Session> ClientSet<T> {
}
ClientSession::None => (),
}
log::info!("Remove client {key}");
}
Ok(())
}
@@ -248,6 +252,7 @@ impl<T: Session> ClientSet<T> {
};
let terminal = client.session.set_terminal(terminal);
for fd in terminal.event_fds() {
log::info!("{key}: add fd {fd:?}");
poll.add(fd)?;
self.session_fd_map.insert(*fd, key);
}
@@ -336,7 +341,7 @@ impl<T: Session> ClientSet<T> {
return self.remove(key, poll);
}
CommandEvent::Exited(status) => {
log::info!("{peer}: command exited: {:?}", status);
log::info!("{peer}: command exited: {status:?}");
poll.remove(&fd)?;
self.socket_fd_map.remove(&fd);
command.child_pid = None;
@@ -350,7 +355,10 @@ impl<T: Session> ClientSet<T> {
}
},
ClientSession::Terminal(terminal) => match terminal.read_output(fd, &mut buffer) {
Ok(0) => {
Ok(None) => {
return Ok(());
}
Ok(Some(0)) => {
poll.remove(&fd)?;
self.socket_fd_map.remove(&fd);
@@ -358,7 +366,7 @@ impl<T: Session> ClientSet<T> {
log::info!("{peer}: terminal closed");
return self.remove(key, poll);
}
Ok(len) => (len, StreamIndex::Stdout),
Ok(Some(len)) => (len, StreamIndex::Stdout),
Err(error) => {
log::error!("{peer}: terminal error: {error}");
return self.remove(key, poll);
@@ -444,6 +452,7 @@ impl<T: Session> Server<T> {
let mut poll = Poll::new()?;
let socket = ServerSocket::bind(listen_addr, crypto_config)?;
poll.add(&socket)?;
log::info!("Listening on {listen_addr}");
Ok(Self {
poll,
socket,
@@ -499,8 +508,12 @@ impl Session for EchoSession {
&[]
}
fn read_output(&mut self, _fd: RawFd, _buffer: &mut [u8]) -> Result<usize, Self::Error> {
Ok(0)
fn read_output(
&mut self,
_fd: RawFd,
_buffer: &mut [u8],
) -> Result<Option<usize>, Self::Error> {
Ok(None)
}
fn handle_input(
Binary file not shown.
+1
View File
@@ -318,6 +318,7 @@ pub fn wait_for_pipeline(handles: Vec<Handle>) -> Result<(Outcome, Option<ExitCo
#[cfg(any(unix, rust_analyzer))]
unsafe {
use std::os::fd::AsRawFd;
libc::tcsetpgrp(io::stdout().as_raw_fd(), libc::getpgrp());
}
// set_terminal_group(&stdin(), pgid)?;
+19 -1
View File
@@ -18,7 +18,10 @@ use std::{
use clap::Parser;
use command::env::Environment;
use cross::term::TerminalInput;
use cross::{
io::{TerminalOptions, TerminalOptionsImpl},
term::TerminalInput,
};
use error::Error;
use exec::Outcome;
@@ -138,7 +141,22 @@ fn run(mut input: ShellInput, env: &mut Environment) -> Result<(), Error> {
continue;
}
};
let old_termios = match &mut input {
ShellInput::Interactive(interactive) => {
let new = TerminalOptionsImpl::normal();
Some(interactive.set_options(new)?)
}
_ => None,
};
let (outcome, exit) = command::eval::evaluate(&expr, env);
match &mut input {
ShellInput::Interactive(interactive) => {
interactive.set_options(old_termios.unwrap()).ok();
}
_ => (),
}
command_text.clear();
if !outcome.is_success() {
eprintln!("{outcome:?}");
+1
View File
@@ -63,6 +63,7 @@ impl<'e> CargoBuilder<'e> {
};
command.env("LD_LIBRARY_PATH", ld_library_path);
command.env("CC", "clang");
command
.arg("+ygg-stage1")
+1
View File
@@ -53,6 +53,7 @@ const PROGRAMS: &[(&str, &str)] = &[
("sync", "bin/sync"),
("sleep", "bin/sleep"),
("lspci", "bin/lspci"),
("lsusb", "bin/lsusb"),
("ps", "bin/ps"),
("top", "bin/top"),
("tst", "bin/tst"),
+28 -3
View File
@@ -6,7 +6,7 @@ use std::{
use qemu::{
aarch64,
device::{QemuDevice, QemuDrive, QemuNic, QemuSerialTarget},
device::{IntelGigabitRev, QemuDevice, QemuDrive, QemuNic, QemuSerialTarget},
i386, riscv64, x86_64, Qemu,
};
@@ -17,13 +17,24 @@ use crate::{
util::run_external_command,
};
#[derive(Debug, Clone, Copy, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
enum QemuIgbeRev {
I82574L,
I82544GC,
I82545EM,
I82540EM,
}
#[derive(Debug, Default, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "kebab-case")]
enum QemuNetworkInterface {
#[default]
VirtioNet,
Rtl8139,
IntelGigabit,
IntelGigabit {
rev: Option<QemuIgbeRev>,
},
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
@@ -332,7 +343,10 @@ fn add_devices_from_config(
let nic = match config.network.interface {
QemuNetworkInterface::VirtioNet => QemuNic::VirtioPci { mac },
QemuNetworkInterface::Rtl8139 => QemuNic::Rtl8139 { mac },
QemuNetworkInterface::IntelGigabit => QemuNic::IntelGigabit { mac },
QemuNetworkInterface::IntelGigabit { rev } => QemuNic::IntelGigabit {
mac,
rev: rev.map(Into::into),
},
};
devices.push(QemuDevice::NetworkTap {
nic,
@@ -404,3 +418,14 @@ pub fn run(
Ok(())
}
impl From<QemuIgbeRev> for IntelGigabitRev {
fn from(value: QemuIgbeRev) -> Self {
match value {
QemuIgbeRev::I82574L => IntelGigabitRev::I82574L,
QemuIgbeRev::I82544GC => IntelGigabitRev::I82544GC,
QemuIgbeRev::I82545EM => IntelGigabitRev::I82545EM,
QemuIgbeRev::I82540EM => IntelGigabitRev::I82540EM,
}
}
}