306 lines
8.7 KiB
Rust

use std::{
collections::HashMap,
fs::{File, OpenOptions},
io::{Read, Seek, SeekFrom, Write},
ops::Range,
path::{Path, PathBuf},
process::ExitCode,
};
use clap::Parser;
use elf::{
abi::{EM_386, EM_AARCH64, EM_RISCV, EM_X86_64, PT_LOAD},
endian::AnyEndian,
ElfStream,
};
use memtables::any::AnyTables;
use thiserror::Error;
use crate::{aarch64::AArch64Builder, x86_64::X8664Builder};
mod aarch64;
mod x86_64;
#[derive(Error, Debug)]
pub enum GenError {
#[error("I/O error: {0}")]
IoError(#[from] std::io::Error),
#[error("ELF parse error: {0}")]
ElfParseError(#[from] elf::ParseError),
#[error("Image's arhitecture is not supported")]
UnsupportedArchitecture,
#[error("Could not determine the kernel image address range (possibly incorrect segments?)")]
NoKernelImageRange,
#[error("Kernel image is too large: {0:#x?} ({1}B). Maximum size: {2}B")]
KernelTooLarge(Range<u64>, u64, u64),
#[error("Kernel image is missing a required symbol: {0:?}")]
MissingSymbol(&'static str),
#[error("Kernel image is missing a required section: {0:?}")]
MissingSection(&'static str),
#[error("Kernel image is missing a symbol table")]
MissingSymbolTable,
#[error("Incorrect tables section placement: {0:#x}")]
IncorrectTablesPlacement(u64),
#[error("Incorrect tables section size: got {0}, expected {1}")]
IncorrectTablesSize(u64, usize),
#[error("Incorrect tables section alignment: expected at least 4K, got {0}")]
IncorrectTablesAlign(u64),
}
// Command-line arguments, parsed through clap's `Parser` derive.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments into
// help text, which would change the program's output.
#[derive(Parser)]
struct Args {
// Path of the kernel ELF image; `gentables` opens it for reading and writing.
image: PathBuf,
// Path of the symbol table file to create.
symbol_out: PathBuf,
}
/// Image layout information handed to the architecture-specific table builders.
///
/// The field descriptions reflect how `build_tables` fills this struct in.
pub struct GenData {
/// Virtual start address of the kernel image, rounded down to 4 KiB.
pub kernel_start: u64,
/// Virtual end address of the kernel image, rounded up to 4 KiB.
pub kernel_end: u64,
/// File offset of the `.data.tables` section inside the image.
pub table_offset: u64,
/// Address of the `.data.tables` section minus `kernel_virt_offset`.
pub table_physical_address: u64,
/// Value of the `KERNEL_VIRT_OFFSET` symbol.
pub kernel_virt_offset: u64,
}
/// Computes the 4 KiB-aligned virtual address range covered by the kernel image.
///
/// Only `PT_LOAD` segments whose virtual address equals their physical address
/// plus `kernel_virt_offset` are considered. The returned pair is
/// `(start, end)`: the lowest segment start rounded down and the highest
/// segment end rounded up to a 4 KiB boundary.
///
/// # Errors
///
/// Returns [`GenError::NoKernelImageRange`] when no segment qualifies, when the
/// resulting range is empty, or when a segment's end address does not fit in
/// 64 bits.
fn kernel_image_range<F: Read + Seek>(
    elf: &mut ElfStream<AnyEndian, F>,
    kernel_virt_offset: u64,
) -> Result<(u64, u64), GenError> {
    const PAGE_MASK: u64 = 0xFFF;

    let mut start = u64::MAX;
    let mut end = u64::MIN;

    for segment in elf.segments() {
        if segment.p_type != PT_LOAD {
            continue;
        }
        // Checked addition: a physical address that cannot be shifted by the
        // offset without overflowing can never match the virtual address, so
        // the segment is skipped instead of panicking (debug) or wrapping
        // (release).
        if segment.p_paddr.checked_add(kernel_virt_offset) != Some(segment.p_vaddr) {
            continue;
        }

        let aligned_start = segment.p_vaddr & !PAGE_MASK;
        // Round the segment end up to the next page, refusing ends that would
        // wrap around the address space.
        let aligned_end = segment
            .p_vaddr
            .checked_add(segment.p_memsz)
            .and_then(|segment_end| segment_end.checked_add(PAGE_MASK))
            .map(|segment_end| segment_end & !PAGE_MASK)
            .ok_or(GenError::NoKernelImageRange)?;

        start = start.min(aligned_start);
        end = end.max(aligned_end);
    }

    if start < end {
        Ok((start, end))
    } else {
        Err(GenError::NoKernelImageRange)
    }
}
/// Looks up the value of the `KERNEL_VIRT_OFFSET` symbol in the image.
///
/// # Errors
///
/// * [`GenError::MissingSection`] (`".symtab"`) when the image has no symbol table.
/// * [`GenError::MissingSymbol`] when no symbol carries that name.
/// * [`GenError::ElfParseError`] when the symbol or string table cannot be parsed.
fn kernel_virt_offset<F: Read + Seek>(elf: &mut ElfStream<AnyEndian, F>) -> Result<u64, GenError> {
    const SYMBOL: &str = "KERNEL_VIRT_OFFSET";

    let Some((symbols, names)) = elf.symbol_table()? else {
        return Err(GenError::MissingSection(".symtab"));
    };

    for symbol in symbols {
        if names.get(symbol.st_name as _)? == SYMBOL {
            // TODO symbol checks
            return Ok(symbol.st_value);
        }
    }

    Err(GenError::MissingSymbol(SYMBOL))
}
/// Locates the `.data.tables` section and validates its size and alignment.
///
/// Returns `(file_offset, address)` of the section, taken from its `sh_offset`
/// and `sh_addr` header fields.
///
/// # Errors
///
/// * [`GenError::UnsupportedArchitecture`] when the image's machine type has no
///   known fixed-tables layout.
/// * [`GenError::MissingSection`] when the section name string table or the
///   `.data.tables` section itself is absent.
/// * [`GenError::IncorrectTablesSize`] when the section size differs from the
///   size of the architecture's `FixedTables`.
/// * [`GenError::IncorrectTablesAlign`] when the section alignment is below 4 KiB.
/// * [`GenError::ElfParseError`] when the section headers cannot be parsed.
fn find_tables<F: Read + Seek>(elf: &mut ElfStream<AnyEndian, F>) -> Result<(u64, u64), GenError> {
    let section_size = match elf.ehdr.e_machine {
        EM_AARCH64 => size_of::<memtables::aarch64::FixedTables>(),
        EM_X86_64 => size_of::<memtables::x86_64::FixedTables>(),
        EM_RISCV => size_of::<memtables::riscv64::FixedTables>(),
        // Report unknown machines as an error instead of panicking.
        _ => return Err(GenError::UnsupportedArchitecture),
    };

    let (shdrs, strtab) = elf.section_headers_with_strtab()?;
    let strtab = strtab.ok_or(GenError::MissingSection(".strtab"))?;

    for shdr in shdrs {
        if strtab.get(shdr.sh_name as _)? != ".data.tables" {
            continue;
        }
        if shdr.sh_size != section_size as _ {
            return Err(GenError::IncorrectTablesSize(shdr.sh_size, section_size));
        }
        if shdr.sh_addralign < 0x1000 {
            return Err(GenError::IncorrectTablesAlign(shdr.sh_addralign));
        }
        // TODO section checks
        return Ok((shdr.sh_offset, shdr.sh_addr));
    }

    Err(GenError::MissingSection(".data.tables"))
}
/// Collects the image's symbols into a name → value map.
///
/// Every symbol whose visibility is not `STV_HIDDEN` is included; when several
/// symbols share a name, the one appearing last in the symbol table wins. The
/// number of collected symbols is printed to stdout.
///
/// NOTE: an earlier revision restricted the output to symbols from `.export.*`
/// sections plus a short forced list; that filter is currently disabled.
///
/// # Errors
///
/// * [`GenError::MissingSymbolTable`] when the image has no symbol table.
/// * [`GenError::ElfParseError`] when a symbol or its name cannot be parsed.
fn extract_symbols<F: Read + Seek>(
    elf: &mut ElfStream<AnyEndian, F>,
) -> Result<HashMap<String, usize>, GenError> {
    let Some((symbols, names)) = elf.symbol_table()? else {
        return Err(GenError::MissingSymbolTable);
    };

    let exported = symbols
        .into_iter()
        .filter(|symbol| symbol.st_vis() != elf::abi::STV_HIDDEN)
        .map(|symbol| -> Result<(String, usize), GenError> {
            let name = names.get(symbol.st_name as _)?;
            Ok((name.to_owned(), symbol.st_value as usize))
        })
        .collect::<Result<HashMap<_, _>, GenError>>()?;

    println!("{} exported symbols extracted", exported.len());
    Ok(exported)
}
/// Converts the first element of a triple into [`AnyTables`], passing the other
/// two elements through unchanged.
fn into_any<T: Into<AnyTables>, U, V>(triple: (T, U, V)) -> (AnyTables, U, V) {
    let (first, second, third) = triple;
    let tables: AnyTables = first.into();
    (tables, second, third)
}
fn build_tables<F: Read + Seek>(
file: F,
) -> Result<(Option<(AnyTables, u64)>, HashMap<String, usize>), GenError> {
let mut elf = ElfStream::<AnyEndian, F>::open_stream(file)?;
if elf.ehdr.e_machine == EM_386 {
// Locate symbol table
let symbol_table = extract_symbols(&mut elf)?;
return Ok((None, symbol_table));
}
let kernel_virt_offset = kernel_virt_offset(&mut elf)?;
let (kernel_start, kernel_end) = kernel_image_range(&mut elf, kernel_virt_offset)?;
let (table_offset, table_virt_addr) = find_tables(&mut elf)?;
let table_physical_address = table_virt_addr
.checked_sub(kernel_virt_offset)
.ok_or_else(|| GenError::IncorrectTablesPlacement(table_virt_addr))?;
println!("Kernel image range: {:#x?}", kernel_start..kernel_end);
println!("KERNEL_VIRT_OFFSET = {:#x}", kernel_virt_offset);
let (tables, table_offset, symbol_table) = match elf.ehdr.e_machine {
EM_X86_64 => X8664Builder::new(
elf,
GenData {
kernel_virt_offset,
kernel_start,
kernel_end,
table_offset,
table_physical_address,
},
)?
.build()
.map(into_any),
EM_AARCH64 => AArch64Builder::new(
elf,
GenData {
kernel_virt_offset,
kernel_start,
kernel_end,
table_offset,
table_physical_address,
},
)?
.build()
.map(into_any),
EM_RISCV => {
// TODO
std::process::exit(0);
}
_ => todo!(),
}?;
Ok((Some((tables, table_offset)), symbol_table))
}
fn write_tables<F: Write + Seek>(
mut file: F,
offset: u64,
tables: AnyTables,
) -> Result<(), GenError> {
file.seek(SeekFrom::Start(offset))?;
file.write_all(tables.as_bytes())?;
Ok(())
}
/// Writes `table` to the file at `out`, creating or truncating it.
///
/// Each symbol is stored as one record:
///
/// * name length in bytes, little-endian `u32`
/// * the name bytes
/// * the symbol value, little-endian `u64`
///
/// Records are emitted in ascending name order so that the same input always
/// produces the same file.
///
/// # Errors
///
/// Returns [`GenError::IoError`] when the file cannot be created or written,
/// or when a symbol name is too long for its length to fit in a `u32`.
fn write_symbol_table(
    out: impl AsRef<Path>,
    table: HashMap<String, usize>,
) -> Result<(), GenError> {
    // `HashMap` iteration order is unspecified; sort for reproducible output.
    let mut entries: Vec<(String, usize)> = table.into_iter().collect();
    entries.sort_unstable();

    // Buffer the many small writes instead of issuing one syscall for each.
    let mut file = std::io::BufWriter::new(File::create(out)?);
    for (name, value) in entries {
        let len = u32::try_from(name.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "symbol name is too long",
            )
        })?;
        // Lossless: `usize` is at most 64 bits wide on supported targets.
        let value = value as u64;

        file.write_all(&len.to_le_bytes())?;
        file.write_all(name.as_bytes())?;
        file.write_all(&value.to_le_bytes())?;
    }
    // Flush explicitly: errors during the implicit flush on drop are lost.
    file.flush()?;
    Ok(())
}
/// Processes the kernel image at `image`: writes its symbol table to
/// `symbol_out` and, when tables were generated, patches them into the image
/// in place.
///
/// # Errors
///
/// Propagates any [`GenError`] raised while opening the image, building the
/// tables, or writing either output.
fn gentables(image: impl AsRef<Path>, symbol_out: impl AsRef<Path>) -> Result<(), GenError> {
    // The image is read for parsing and later patched in place, so it must be
    // opened for both reading and writing without truncation.
    let mut options = OpenOptions::new();
    options.read(true).write(true).truncate(false);
    let mut file = options.open(image)?;

    let (tables, symbol_table) = build_tables(&mut file)?;
    write_symbol_table(symbol_out, symbol_table)?;

    match tables {
        Some((tables, file_offset)) => write_tables(file, file_offset, tables),
        None => Ok(()),
    }
}
/// Entry point: parses the command line and runs `gentables`.
///
/// On failure the error is printed to stderr, prefixed with the image path,
/// and the process reports a failure exit code.
fn main() -> ExitCode {
    let args = Args::parse();

    if let Err(err) = gentables(&args.image, &args.symbol_out) {
        eprintln!("{}: {}", args.image.display(), err);
        return ExitCode::FAILURE;
    }

    ExitCode::SUCCESS
}