use std::{ collections::HashMap, fs::{File, OpenOptions}, io::{Read, Seek, SeekFrom, Write}, ops::Range, path::{Path, PathBuf}, process::ExitCode, }; use clap::Parser; use elf::{ abi::{EM_386, EM_AARCH64, EM_RISCV, EM_X86_64, PT_LOAD}, endian::AnyEndian, ElfStream, }; use memtables::any::AnyTables; use riscv64::Riscv64Builder; use thiserror::Error; use crate::{aarch64::AArch64Builder, x86_64::X8664Builder}; mod aarch64; mod riscv64; mod x86_64; #[derive(Error, Debug)] pub enum GenError { #[error("I/O error: {0}")] IoError(#[from] std::io::Error), #[error("ELF parse error: {0}")] ElfParseError(#[from] elf::ParseError), #[error("Image's arhitecture is not supported")] UnsupportedArchitecture, #[error("Could not determine the kernel image address range (possibly incorrect segments?)")] NoKernelImageRange, #[error("Kernel image is too large: {0:#x?} ({1}B). Maximum size: {2}B")] KernelTooLarge(Range, u64, u64), #[error("Kernel image is missing a required symbol: {0:?}")] MissingSymbol(&'static str), #[error("Kernel image is missing a required section: {0:?}")] MissingSection(&'static str), #[error("Kernel image is missing a symbol table")] MissingSymbolTable, #[error("Incorrect tables section placement: {0:#x}")] IncorrectTablesPlacement(u64), #[error("Incorrect tables section size: got {0}, expected {1}")] IncorrectTablesSize(u64, usize), #[error("Incorrect tables section alignment: expected at least 4K, got {0}")] IncorrectTablesAlign(u64), } #[derive(Parser)] struct Args { image: PathBuf, symbol_out: PathBuf, } pub enum ImageHeader { Riscv(u64), } pub struct GenData { pub kernel_start: u64, pub kernel_end: u64, pub table_offset: u64, pub table_physical_address: u64, pub kernel_virt_offset: u64, } pub struct BuiltTables { pub image_size: u64, pub image_header: Option, pub tables: Option<(AnyTables, u64)>, pub symbol_table: HashMap, } fn kernel_image_range( elf: &mut ElfStream, kernel_virt_offset: u64, ) -> Result<(u64, u64), GenError> { let mut start = u64::MAX; let mut end = u64::MIN; for segment in elf.segments() { if segment.p_type != PT_LOAD || segment.p_vaddr != segment.p_paddr + kernel_virt_offset { continue; } let aligned_start = segment.p_vaddr & !0xFFF; let aligned_end = (segment.p_vaddr + segment.p_memsz + 0xFFF) & !0xFFF; if aligned_end > end { end = aligned_end; } if aligned_start < start { start = aligned_start; } } if start < end { Ok((start, end)) } else { Err(GenError::NoKernelImageRange) } } fn kernel_virt_offset(elf: &mut ElfStream) -> Result { let (symtab, symstrtab) = elf .symbol_table()? .ok_or_else(|| GenError::MissingSection(".symtab"))?; for sym in symtab { let name = symstrtab.get(sym.st_name as _)?; if name == "KERNEL_VIRT_OFFSET" { // TODO symbol checks return Ok(sym.st_value); } } Err(GenError::MissingSymbol("KERNEL_VIRT_OFFSET")) } fn find_tables( elf: &mut ElfStream, ) -> Result<(Option, u64, u64), GenError> { let section_size = match elf.ehdr.e_machine { EM_AARCH64 => size_of::(), EM_X86_64 => size_of::(), EM_RISCV => size_of::(), _ => unimplemented!(), }; let image_header = if let Some(text_entry) = elf.section_header_by_name(".text.entry")? { let header = text_entry.clone(); let section_offset = header.sh_offset; let (data, _) = elf.section_data(&header)?; if data.len() >= 64 { let version = u32::from_le_bytes(data[32..36].try_into().unwrap()); let magic0 = &data[48..56]; let magic1 = &data[56..60]; match (version, magic0, magic1) { (2, b"RISCV\x00\x00\x00", b"RSC\x05") => Some(ImageHeader::Riscv(section_offset)), (_, _, _) => None, } } else { None } } else { None }; let (shdrs, strtab) = elf.section_headers_with_strtab()?; let strtab = strtab.ok_or_else(|| GenError::MissingSection(".strtab"))?; let mut tables = None; for shdr in shdrs { let name = strtab.get(shdr.sh_name as _)?; if name == ".data.tables" { if shdr.sh_size != section_size as _ { return Err(GenError::IncorrectTablesSize(shdr.sh_size, section_size)); } if shdr.sh_addralign < 0x1000 { return Err(GenError::IncorrectTablesAlign(shdr.sh_addralign)); } // TODO section checks tables = Some((shdr.sh_offset, shdr.sh_addr)); } } let (tables_offset, tables_addr) = tables.ok_or(GenError::MissingSection(".data.tables"))?; Ok((image_header, tables_offset, tables_addr)) } fn extract_symbols( elf: &mut ElfStream, ) -> Result, GenError> { let mut table = HashMap::new(); // let mut export_tables = HashSet::new(); // let force_export = &[ // "rust_begin_unwind", // "__rust_alloc", // "__rust_alloc_zeroed", // "__rust_dealloc", // "__rust_realloc", // ]; // // Find all .export table indices // let (shdrs, strtab) = elf.section_headers_with_strtab()?; // let strtab = strtab.ok_or_else(|| GenError::MissingSection(".strtab"))?; // for (i, shdr) in shdrs.iter().enumerate() { // if shdr.sh_type != elf::abi::SHT_PROGBITS { // continue; // } // let name = strtab.get(shdr.sh_name as _)?; // if name.starts_with(".export.") { // println!("Export from {:?}", name); // export_tables.insert(i as u16); // } // } let (symtab, strtab) = elf.symbol_table()?.ok_or(GenError::MissingSymbolTable)?; // Only produce symbols from .export tables for sym in symtab { if sym.st_vis() == elf::abi::STV_HIDDEN { continue; } let name = strtab.get(sym.st_name as _)?; // if export_tables.contains(&sym.st_shndx) || force_export.contains(&name) { table.insert(name.to_owned(), sym.st_value as usize); // } } println!("{} exported symbols extracted", table.len()); Ok(table) } fn into_any, U, V>((x, y, z): (T, U, V)) -> (AnyTables, U, V) { (x.into(), y, z) } fn build_tables(file: F) -> Result { let mut elf = ElfStream::::open_stream(file)?; if elf.ehdr.e_machine == EM_386 { // Locate symbol table let symbol_table = extract_symbols(&mut elf)?; return Ok(BuiltTables { image_size: 0, image_header: None, tables: None, symbol_table, }); } let kernel_virt_offset = kernel_virt_offset(&mut elf)?; let (kernel_start, kernel_end) = kernel_image_range(&mut elf, kernel_virt_offset)?; let (image_header, table_offset, table_virt_addr) = find_tables(&mut elf)?; let table_physical_address = table_virt_addr .checked_sub(kernel_virt_offset) .ok_or_else(|| GenError::IncorrectTablesPlacement(table_virt_addr))?; println!("Kernel image range: {:#x?}", kernel_start..kernel_end); println!("KERNEL_VIRT_OFFSET = {:#x}", kernel_virt_offset); let gen_data = GenData { kernel_virt_offset, kernel_start, kernel_end, table_offset, table_physical_address, }; let (tables, table_offset, symbol_table) = match elf.ehdr.e_machine { EM_X86_64 => X8664Builder::new(elf, gen_data)?.build().map(into_any), EM_AARCH64 => AArch64Builder::new(elf, gen_data)?.build().map(into_any), EM_RISCV => Riscv64Builder::new(elf, gen_data)?.build().map(into_any), _ => todo!(), }?; Ok(BuiltTables { image_size: kernel_end - kernel_start, image_header, tables: Some((tables, table_offset)), symbol_table, }) } fn write_tables( mut file: F, offset: u64, tables: AnyTables, ) -> Result<(), GenError> { file.seek(SeekFrom::Start(offset))?; file.write_all(tables.as_bytes())?; Ok(()) } fn write_image_header( file: &mut F, header: ImageHeader, image_size: u64, ) -> Result<(), GenError> { match header { ImageHeader::Riscv(offset) => { let size_bytes = image_size.to_le_bytes(); println!("Writing RISC-V image header: image_size={image_size}"); file.seek(SeekFrom::Start(offset + 16))?; file.write_all(&size_bytes)?; Ok(()) } } } fn write_symbol_table( out: impl AsRef, table: HashMap, ) -> Result<(), GenError> { let mut file = File::create(out)?; for (name, value) in table { let len: u32 = name.len().try_into().unwrap(); let value: u64 = value.try_into().unwrap(); file.write_all(&len.to_le_bytes())?; file.write_all(name.as_bytes())?; file.write_all(&value.to_le_bytes())?; } Ok(()) } fn gentables(image: impl AsRef, symbol_out: impl AsRef) -> Result<(), GenError> { let mut file = OpenOptions::new() .read(true) .write(true) .truncate(false) .open(image)?; let built = build_tables(&mut file)?; write_symbol_table(symbol_out, built.symbol_table)?; if let Some(header) = built.image_header { write_image_header(&mut file, header, built.image_size)?; } if let Some((tables, file_offset)) = built.tables { write_tables(file, file_offset, tables)?; } Ok(()) } fn main() -> ExitCode { let args = Args::parse(); match gentables(&args.image, &args.symbol_out) { Ok(()) => ExitCode::SUCCESS, Err(err) => { eprintln!("{}: {}", args.image.display(), err); ExitCode::FAILURE } } }