Files
yggdrasil/userspace/lib/pixie/src/jpeg/decoder.rs
T

577 lines
19 KiB
Rust

use std::{collections::HashMap, io::Read};
use bytemuck::Pod;
use crate::jpeg::{
data::{self, Block8x8},
error::Error,
header::App0Header,
huffman::{HuffmanDecoder, HuffmanTable},
};
use crate::{RgbImage, YCbCrImage};
/// Upper bound on the number of image components tracked by the decoder.
///
/// Sizes the per-component arrays in `JpegState`, `FrameInfo` and `ScanInfo`,
/// and the quantization table array.
pub const MAX_COMPONENTS: usize = 4;
/// Human-readable component names, indexed by zero-based component index
/// (component id minus one). Used when reporting an unsupported sampling factor.
pub const COMPONENT_NAMES: &[&str] = &["luma", "chroma_b", "chroma_r", "<unimp>"];
/// Everything the decoder has learned from the marker segments parsed so far.
#[derive(Debug, Default)]
struct JpegState {
// Frame header; filled in by `consume_sof`, at most once per stream.
frame_info: Option<FrameInfo>,
// Quantization tables by table index; filled in by `consume_dqt`.
quantization_tables: [Option<Block8x8<u16>>; MAX_COMPONENTS],
// Scan header; filled in by `consume_sos`.
scan_info: Option<ScanInfo>,
// DC Huffman tables keyed by the low nibble of the DHT table-info byte.
dc_huffman_tables: HashMap<u8, HuffmanTable>,
// AC Huffman tables keyed by the low nibble of the DHT table-info byte.
ac_huffman_tables: HashMap<u8, HuffmanTable>,
// Set once an SOS segment has been parsed; makes `decode_headers` a no-op.
headers_parsed: bool,
// Set when an EOI marker (0xD9) is consumed while scanning headers.
end_of_image: bool,
// Set by `consume_sof` when the SOF marker is 0xC2 or 0xC3.
is_progressive: bool,
}
/// JPEG decoder that pulls its input from any [`Read`] implementation.
pub struct JpegDecoder<R: Read> {
// Byte source; consumed sequentially by header parsing and frame decoding.
reader: R,
// Tables and headers accumulated while parsing marker segments.
state: JpegState,
}
/// Supported per-component sampling factors.
///
/// Variants are named `Sample{H}x{V}`: the SOF sampling byte `0x21`
/// (high nibble = horizontal factor, low nibble = vertical factor) maps to
/// `Sample2x1`, and `0x12` maps to `Sample1x2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SamplingMode {
// Horizontal factor 1, vertical factor 1.
Sample1x1,
// Horizontal factor 1, vertical factor 2.
Sample1x2,
// Horizontal factor 2, vertical factor 1.
Sample2x1,
// Horizontal factor 2, vertical factor 2.
Sample2x2,
}
/// Per-component parameters read from the SOF segment.
#[derive(Debug, Clone, Copy)]
pub struct FrameComponent {
/// Sampling factors of this component.
pub sampling: SamplingMode,
/// Index of the quantization table this component is dequantized with.
pub qt_index: usize,
}
/// Contents of the SOF (start of frame) segment.
#[derive(Debug)]
pub struct FrameInfo {
/// Sample precision in bits; `consume_sof` only accepts 8.
pub precision: u8,
/// Components indexed by component id minus one; `None` where the frame
/// does not declare that component.
pub components: [Option<FrameComponent>; MAX_COMPONENTS],
/// Number of lines (image height).
pub lines: usize,
/// Number of samples per line (image width).
pub samples_per_line: usize,
}
/// Contents of the SOS (start of scan) segment.
#[derive(Debug)]
pub struct ScanInfo {
/// DC Huffman table selector (high nibble of the selector byte), indexed by
/// component id minus one.
pub dc_entropy_selectors: [u8; MAX_COMPONENTS],
/// AC Huffman table selector (low nibble of the selector byte), indexed by
/// component id minus one.
pub ac_entropy_selectors: [u8; MAX_COMPONENTS],
/// First of the three bytes that follow the component list in the segment.
pub spectral_predict_start: u8,
/// Second of the three bytes that follow the component list in the segment.
pub spectral_predict_end: u8,
/// Third of the three bytes that follow the component list in the segment.
pub successive_approximation: u8,
/// Bit `n` is set when the component with zero-based index `n` is part of
/// the scan.
pub component_mask: u8,
}
impl SamplingMode {
    /// Vertical sampling factor (the `V` in `Sample{H}x{V}`): 1 or 2.
    pub fn vertical(&self) -> usize {
        match self {
            Self::Sample1x2 | Self::Sample2x2 => 2,
            Self::Sample1x1 | Self::Sample2x1 => 1,
        }
    }

    /// Horizontal sampling factor (the `H` in `Sample{H}x{V}`): 1 or 2.
    pub fn horizontal(&self) -> usize {
        match self {
            Self::Sample2x1 | Self::Sample2x2 => 2,
            Self::Sample1x1 | Self::Sample1x2 => 1,
        }
    }

    /// Product of the horizontal and vertical factors: 1, 2 or 4.
    pub fn sample_count(&self) -> usize {
        self.horizontal() * self.vertical()
    }
}
impl<R: Read> JpegDecoder<R> {
    /// Wraps `reader` in a decoder with empty parsing state.
    pub fn new(reader: R) -> Self {
        let state = JpegState::default();
        Self { reader, state }
    }

    /// Parses marker segments up to and including the first SOS segment (or
    /// EOI). Does nothing if an SOS segment has already been parsed.
    pub fn decode_headers(&mut self) -> Result<(), Error> {
        let Self { reader, state } = self;
        state.decode_headers(reader)
    }

    /// Parses the headers if necessary and decodes the image into YCbCr planes.
    ///
    /// # Errors
    /// Returns `Error::UnimplementedFeature` for progressive images, and any
    /// error raised while parsing headers or decoding the scan.
    pub fn decode_ycbcr(&mut self) -> Result<YCbCrImage, Error> {
        let Self { reader, state } = self;
        state.decode_frame(reader)
    }

    /// Decodes the image as YCbCr and converts the result with `to_rgb`.
    pub fn decode_rgb(&mut self) -> Result<RgbImage, Error> {
        self.decode_ycbcr().map(|ycbcr| ycbcr.to_rgb())
    }
}
impl JpegState {
/// Checks the SOI marker, then scans the stream byte by byte and dispatches
/// every marker found to `consume_marker`.
///
/// Scanning stops as soon as a scan header has been stored or an EOI marker
/// has been seen.
///
/// # Errors
/// `Error::MalformedHeader` if the stream does not start with `0xFFD8`, plus
/// any read error or error returned by `consume_marker`.
fn do_decode_headers<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
    // SOI marker
    if read_u16be(reader)? != 0xFFD8 {
        return Err(Error::MalformedHeader);
    }

    let mut previous = 0u8;
    loop {
        let mut current = read_u8(reader)?;

        // After a 0xFF, swallow any run of 0xFF / 0x00 bytes. `previous`
        // tracks the last swallowed byte, so a trailing 0x00 cancels the
        // pending marker prefix while a trailing 0xFF keeps it.
        if previous == 0xFF {
            while matches!(current, 0x00 | 0xFF) {
                previous = current;
                current = read_u8(reader)?;
            }
        }

        if previous == 0xFF {
            self.consume_marker(reader, current)?;
            if self.end_of_image || self.scan_info.is_some() {
                return Ok(());
            }
        }

        previous = current;
    }
}
/// Parses the headers unless a previous call already reached an SOS segment.
fn decode_headers<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
    if self.headers_parsed {
        Ok(())
    } else {
        self.do_decode_headers(reader)
    }
}
/// Makes sure the headers are parsed, then decodes the scan.
///
/// Only non-progressive frames are handled; progressive ones yield
/// `Error::UnimplementedFeature`.
fn decode_frame<R: Read>(&mut self, reader: &mut R) -> Result<YCbCrImage, Error> {
    self.decode_headers(reader)?;
    if !self.is_progressive {
        return decode_ycbcr_baseline(reader, self);
    }
    Err(Error::UnimplementedFeature("Progressive image decoding"))
}
fn consume_marker<R: Read>(&mut self, reader: &mut R, m: u8) -> Result<(), Error> {
match m {
0xC0..=0xC3 => self.consume_sof(reader, m - 0xC0),
0xC4 => self.consume_dht(reader),
0xE0 => self.consume_app0(reader),
0xD9 => {
self.end_of_image = true;
Ok(())
}
0xDA => self.consume_sos(reader),
0xDB => self.consume_dqt(reader),
_ => Err(Error::InvalidMarker(m)),
}
}
fn consume_sof<R: Read>(&mut self, reader: &mut R, i: u8) -> Result<(), Error> {
self.is_progressive = i >= 2;
if self.frame_info.is_some() {
return Err(Error::MultipleSofMarkers);
}
let length = read_u16be(reader)?;
let sample_precision = read_u8(reader)?;
let line_count = read_u16be(reader)?;
let samples_per_line = read_u16be(reader)?;
let component_count = read_u8(reader)?;
// TODO parse non-8bpp images
if sample_precision != 8 {
return Err(Error::UnimplementedSamplePrecision(sample_precision));
}
if length as usize != 8 + 3 * component_count as usize {
return Err(Error::MalformedHeader);
}
if component_count != 3 {
return Err(Error::UnimplementedComponentCount(component_count));
}
log::debug!("SOF {sample_precision}bpp {samples_per_line}x{line_count}");
let mut frame_info = FrameInfo {
precision: sample_precision,
samples_per_line: samples_per_line as usize,
lines: line_count as usize,
components: [const { None }; MAX_COMPONENTS],
};
for _ in 0..component_count {
let component_id = read_u8(reader)? as usize;
if component_id > MAX_COMPONENTS || component_id == 0 {
return Err(Error::InvalidComponentIndex(component_id));
}
let component_id = component_id - 1;
if frame_info.components[component_id].is_some() {
return Err(Error::DuplicateComponent(component_id));
}
let sampling_factor = read_u8(reader)?;
let sampling = match sampling_factor {
0x11 => SamplingMode::Sample1x1,
0x21 => SamplingMode::Sample2x1,
0x12 => SamplingMode::Sample1x2,
0x22 => SamplingMode::Sample2x2,
_ => {
let hsampling = sampling_factor >> 4;
let vsampling = sampling_factor & 0xF;
return Err(Error::UnimplementedSamplingFactor(
COMPONENT_NAMES[component_id],
hsampling,
vsampling,
));
}
};
let qt_index = read_u8(reader)? as usize;
frame_info.components[component_id] = Some(FrameComponent { sampling, qt_index });
log::debug!(" [{component_id}] {sampling:?}, qt {qt_index}");
}
self.frame_info = Some(frame_info);
Ok(())
}
fn consume_dht<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
let mut length = read_u16be(reader)?.saturating_sub(2) as usize;
let mut huffman_code_counts = [0; 16];
let mut huffman_code_values = [0; 256];
while length > 16 {
let ht_info = read_u8(reader)?;
let dc_or_ac = ht_info >> 4;
let index = ht_info & 0x0F;
reader.read_exact(&mut huffman_code_counts)?;
let total_value_count: usize =
huffman_code_counts.into_iter().map(|c| c as usize).sum();
length -= 17;
if total_value_count > 256 || total_value_count > length {
return Err(Error::InvalidHuffmanTable);
}
reader.read_exact(&mut huffman_code_values[..total_value_count])?;
length -= total_value_count;
match dc_or_ac {
0 => {
log::debug!(
"DC[{index}] {huffman_code_counts:?}, {:?}",
&huffman_code_values[..total_value_count]
);
let table = HuffmanTable::new(
huffman_code_counts,
huffman_code_values,
true,
self.is_progressive,
)?;
self.dc_huffman_tables.insert(index, table);
}
1 => {
log::debug!(
"AC[{index}] {huffman_code_counts:?}, {:?}",
&huffman_code_values[..total_value_count]
);
let table = HuffmanTable::new(
huffman_code_counts,
huffman_code_values,
false,
self.is_progressive,
)?;
self.ac_huffman_tables.insert(index, table);
}
_ => {
return Err(Error::InvalidHuffmanTable);
}
}
}
if length > 0 {
return Err(Error::InvalidHuffmanTable);
}
Ok(())
}
/// Parses a DQT segment, which may define several quantization tables, and
/// stores each one in `self.quantization_tables`.
///
/// Each table starts with a byte whose high nibble selects the element size
/// (0 = one byte, 1 = two bytes) and whose low nibble is the table index,
/// followed by 64 elements. Two-byte elements are stored most significant byte
/// first in the stream, so they are decoded explicitly as big-endian rather
/// than reinterpreted in host byte order.
///
/// # Errors
/// - `Error::InvalidQtIndex` if the table index is `>= MAX_COMPONENTS`
/// - `Error::DuplicateQt` if the slot is already occupied
/// - `Error::MalformedHeader` if the table does not fit in the segment
/// - `Error::InvalidQtPrecision` for element sizes other than 1 or 2 bytes
fn consume_dqt<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
    let mut length = read_u16be(reader)?.saturating_sub(2) as usize;

    while length != 0 {
        let qt_param = read_u8(reader)?;
        let qt_precision = (qt_param >> 4) as usize;
        let qt_index = (qt_param & 0x0F) as usize;

        if qt_index >= MAX_COMPONENTS {
            return Err(Error::InvalidQtIndex(qt_index));
        }
        if self.quantization_tables[qt_index].is_some() {
            return Err(Error::DuplicateQt(qt_index));
        }

        let qt_bytes_per_element: usize = 1 << qt_precision;
        // One parameter byte plus 64 elements.
        let table_len = 1 + qt_bytes_per_element * 64;
        if table_len > length {
            return Err(Error::MalformedHeader);
        }

        // Read 64 elements of the quantization table
        let qt_block = match qt_bytes_per_element {
            1 => {
                let mut qt_values = [0u8; 64];
                reader.read_exact(&mut qt_values)?;
                Block8x8::new(qt_values).map(|x| u16::from(x))
            }
            2 => {
                let mut raw = [0u8; 128];
                reader.read_exact(&mut raw)?;
                let mut qt_values = [0u16; 64];
                for (value, bytes) in qt_values.iter_mut().zip(raw.chunks_exact(2)) {
                    *value = u16::from_be_bytes([bytes[0], bytes[1]]);
                }
                Block8x8::new(qt_values)
            }
            _ => return Err(Error::InvalidQtPrecision(qt_bytes_per_element)),
        };

        log::debug!("QT[{qt_index}]: {qt_bytes_per_element}Bpe\n{qt_block:3}");
        self.quantization_tables[qt_index] = Some(qt_block);
        length -= table_len;
    }

    Ok(())
}
/// Parses a JFIF APP0 segment and skips the optional embedded thumbnail.
///
/// The fixed part of the segment is read as an `App0Header`; it is followed by
/// `3 * xthumbnail * ythumbnail` bytes of thumbnail data, which are read and
/// discarded. Nothing from the segment is stored.
///
/// # Errors
/// `Error::MalformedHeader` if the identifier is not `"JFIF\0"` or if the
/// segment length does not equal the header size plus the thumbnail size.
/// Malformed files are reported as errors rather than panicking.
fn consume_app0<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
    let app0: App0Header = read_struct(reader)?;
    if app0.identifier != *b"JFIF\x00" {
        return Err(Error::MalformedHeader);
    }

    let thumbnail_len = 3 * app0.xthumbnail as usize * app0.ythumbnail as usize;
    if app0.length.read() as usize != size_of::<App0Header>() + thumbnail_len {
        return Err(Error::MalformedHeader);
    }

    // Discard the thumbnail pixels so the next marker is found where expected.
    let mut remaining = thumbnail_len;
    let mut scratch = [0u8; 256];
    while remaining > 0 {
        let chunk = remaining.min(scratch.len());
        reader.read_exact(&mut scratch[..chunk])?;
        remaining -= chunk;
    }

    Ok(())
}
fn consume_sos<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
let _length = read_u16be(reader)?.saturating_sub(2) as usize;
let mut scan = ScanInfo {
dc_entropy_selectors: [0; MAX_COMPONENTS],
ac_entropy_selectors: [0; MAX_COMPONENTS],
spectral_predict_start: 0,
spectral_predict_end: 0,
successive_approximation: 0,
component_mask: 0,
};
let component_count = read_u8(reader)? as usize;
for _ in 0..component_count {
let component_id = read_u8(reader)? as usize;
let entropy_dst_selector = read_u8(reader)?;
if component_id > MAX_COMPONENTS || component_id == 0 {
return Err(Error::InvalidComponentIndex(component_id));
}
let component_id = component_id - 1;
let dc_selector = entropy_dst_selector >> 4;
let ac_selector = entropy_dst_selector & 0x0F;
scan.dc_entropy_selectors[component_id] = dc_selector;
scan.ac_entropy_selectors[component_id] = ac_selector;
scan.component_mask |= 1 << component_id;
}
scan.spectral_predict_start = read_u8(reader)?;
scan.spectral_predict_end = read_u8(reader)?;
scan.successive_approximation = read_u8(reader)?;
self.scan_info = Some(scan);
self.headers_parsed = true;
Ok(())
}
}
/// Reads exactly one byte from `reader`.
///
/// # Errors
/// Propagates the reader's error, including end of input.
pub fn read_u8<R: Read>(reader: &mut R) -> Result<u8, Error> {
    let mut byte = 0u8;
    reader.read_exact(std::slice::from_mut(&mut byte))?;
    Ok(byte)
}
/// Reads two bytes from `reader` and combines them as a big-endian `u16`
/// (first byte is the most significant).
///
/// # Errors
/// Propagates the reader's error, including end of input.
pub fn read_u16be<R: Read>(reader: &mut R) -> Result<u16, Error> {
    let mut raw = [0u8; 2];
    reader.read_exact(&mut raw)?;
    let [high, low] = raw;
    Ok((u16::from(high) << 8) | u16::from(low))
}
/// Reads `size_of::<T>()` bytes from `reader` directly into a zero-initialized
/// `T` and returns it.
///
/// The bytes are stored as they appear in the stream; no byte-order conversion
/// is performed here.
///
/// # Errors
/// Propagates the reader's error, including end of input.
pub fn read_struct<T: Pod, R: Read>(reader: &mut R) -> Result<T, Error> {
    let mut out = T::zeroed();
    {
        let raw = bytemuck::bytes_of_mut(&mut out);
        reader.read_exact(raw)?;
    }
    Ok(out)
}
/// Converts the raw `bits` read for a coefficient of size category `code`
/// into its signed value.
///
/// For a category of `code` bits, raw values with the top bit set
/// (`bits >= 2^(code-1)`) are positive and returned as is; smaller raw values
/// encode negatives and map to `bits - (2^code - 1)`. Category 0 is always 0.
///
/// The arithmetic is done in `i32` so that no shift or subtraction can
/// overflow, whatever the input. Categories above 16 are treated as 16, and
/// results that do not fit in `i16` (only possible for categories >= 16)
/// saturate.
fn expand_number(code: u8, bits: u16) -> i16 {
    if code == 0 {
        return 0;
    }

    let size = u32::from(code.min(16));
    let value = i32::from(bits);
    let threshold = 1i32 << (size - 1);

    let expanded = if value >= threshold {
        value
    } else {
        value - ((1i32 << size) - 1)
    };

    // Lossless after the clamp.
    expanded.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16
}
/// Decodes one 8x8 block from the entropy-coded stream.
///
/// Steps, in stream order:
/// 1. Decode the DC symbol with `dc_table`, read that many extra bits, add the
///    result to `*dc_predictor` (which is updated), and dequantize.
/// 2. Decode AC symbols with `ac_table` until 63 coefficients are covered or a
///    zero symbol ends the block. The high nibble of a symbol is the number of
///    coefficients to skip, the low nibble the bit size of the next one.
/// 3. Pass the 64 dequantized coefficients through `Block8x8::unzigzag` and
///    `idct`, then convert each element to `u8` with an `as` cast.
///
/// Only the low nibble of the DC symbol is used, both for the number of bits
/// read and for the value expansion, so a malformed table cannot make the two
/// disagree.
///
/// # Errors
/// Propagates errors from `HuffmanDecoder::decode_symbol` / `get_bits`.
fn decode_block<R: Read>(
    reader: &mut R,
    huffman: &mut HuffmanDecoder,
    dc_table: &HuffmanTable,
    ac_table: &HuffmanTable,
    quant_table: &Block8x8<u16>,
    dc_predictor: &mut i16,
) -> Result<Block8x8<u8>, Error> {
    let mut block_data = [0i16; 64];

    // Read the DC coefficient
    let dc_size = huffman.decode_symbol(reader, dc_table)? & 0x0F;
    let dc_bits = huffman.get_bits(reader, dc_size)?;
    let dccoeff = expand_number(dc_size, dc_bits).wrapping_add(*dc_predictor);
    *dc_predictor = dccoeff;
    block_data[0] = dccoeff.wrapping_mul(quant_table.data[0] as i16);

    // Read the AC coefficients
    let mut l = 1;
    while l < 64 {
        let symbol = huffman.decode_symbol(reader, ac_table)?;
        if symbol == 0 {
            break;
        }

        let run = (symbol >> 4) as usize;
        let size = symbol & 0x0F;
        l += run;

        // The bits are consumed even when the run overshoots the block, so the
        // bit reader stays in step with the stream.
        let bits = huffman.get_bits(reader, size)?;
        if l < 64 {
            let coeff = expand_number(size, bits);
            block_data[l] = coeff.wrapping_mul(quant_table.data[l] as i16);
            l += 1;
        }
    }

    let block = Block8x8::unzigzag(&block_data);
    let block = block.idct();
    let block = block.map(|f| f as u8);
    Ok(block)
}
fn decode_ycbcr_baseline<R: Read>(
reader: &mut R,
state: &mut JpegState,
) -> Result<YCbCrImage, Error> {
let frame_info = state.frame_info.as_ref().unwrap();
let scan_info = state.scan_info.as_ref().unwrap();
let component_count = scan_info.component_mask.count_ones() as usize;
let components = [
frame_info.components[0].as_ref().unwrap(),
frame_info.components[1].as_ref().unwrap(),
frame_info.components[2].as_ref().unwrap(),
];
let max_hsample = components
.iter()
.map(|c| c.sampling.horizontal())
.max()
.unwrap();
let max_vsample = components
.iter()
.map(|c| c.sampling.vertical())
.max()
.unwrap();
let component_len = frame_info.samples_per_line * frame_info.lines;
let mut component_data = [
vec![0; component_len],
vec![0; component_len],
vec![0; component_len],
];
let mut huffman = HuffmanDecoder::default();
let mut dc_predictors = [0; MAX_COMPONENTS];
let mcu_width_pixels = 8 * max_hsample;
let mcu_height_pixels = 8 * max_vsample;
let mcu_columns = frame_info.samples_per_line.div_ceil(mcu_width_pixels);
let mcu_rows = frame_info.lines.div_ceil(mcu_height_pixels);
'outer: for mcu_y in 0..mcu_rows {
for mcu_x in 0..mcu_columns {
for component_id in 0..component_count {
// TODO move this to block decode
if let Some(marker) = huffman.marker {
match marker {
0xD9 => {
break 'outer;
}
_ => todo!(),
}
}
let hsamples = components[component_id].sampling.horizontal();
let vsamples = components[component_id].sampling.vertical();
let dc_index = scan_info.dc_entropy_selectors[component_id];
let ac_index = scan_info.ac_entropy_selectors[component_id];
let qt_index = frame_info.components[component_id]
.as_ref()
.unwrap()
.qt_index;
let dc_table = state
.dc_huffman_tables
.get(&dc_index)
.ok_or(Error::MissingDCHT(dc_index))?;
let ac_table = state
.ac_huffman_tables
.get(&ac_index)
.ok_or(Error::MissingACHT(ac_index))?;
let quant_table = state.quantization_tables[qt_index]
.as_ref()
.ok_or(Error::MissingQT(qt_index))?;
let upsample_y = max_vsample / vsamples;
let upsample_x = max_hsample / hsamples;
let output = &mut component_data[component_id];
for vsample in 0..vsamples {
for hsample in 0..hsamples {
let block = decode_block(
reader,
&mut huffman,
dc_table,
ac_table,
quant_table,
&mut dc_predictors[component_id],
)?;
let dst_block_x = (mcu_x * max_hsample + hsample) * 8;
let dst_block_y = (mcu_y * max_vsample + vsample) * 8;
let write_fn = match (upsample_x, upsample_y) {
(1, 1) => data::write_block_1x1,
(1, 2) => data::write_block_1x2,
(2, 1) => data::write_block_2x1,
(2, 2) => data::write_block_2x2,
_ => unreachable!(),
};
write_fn(
output,
frame_info.samples_per_line,
dst_block_x,
dst_block_y,
&block,
);
}
}
}
}
}
let [y, cb, cr] = component_data;
let ycbcr = YCbCrImage {
luma: y,
chroma_b: cb,
chroma_r: cr,
width: frame_info.samples_per_line,
height: frame_info.lines,
};
Ok(ycbcr)
}