577 lines
19 KiB
Rust
577 lines
19 KiB
Rust
use std::{collections::HashMap, io::Read};
|
|
|
|
use bytemuck::Pod;
|
|
|
|
use crate::jpeg::{
|
|
data::{self, Block8x8},
|
|
error::Error,
|
|
header::App0Header,
|
|
huffman::{HuffmanDecoder, HuffmanTable},
|
|
};
|
|
use crate::{RgbImage, YCbCrImage};
|
|
|
|
/// Maximum number of frame components tracked per image; also the number of
/// quantization table slots.
pub const MAX_COMPONENTS: usize = 4;

/// Human-readable component names indexed by `component id - 1`, used when reporting
/// unsupported sampling factors. The fourth entry is a placeholder: only
/// three-component frames are decoded.
pub const COMPONENT_NAMES: &[&str] = &["luma", "chroma_b", "chroma_r", "<unimp>"];

// Every component index below MAX_COMPONENTS must have a name; otherwise the SOF parser's
// error reporting would index past the end of COMPONENT_NAMES.
const _: () = assert!(COMPONENT_NAMES.len() == MAX_COMPONENTS);
|
|
|
|
#[derive(Debug, Default)]
|
|
struct JpegState {
|
|
frame_info: Option<FrameInfo>,
|
|
quantization_tables: [Option<Block8x8<u16>>; MAX_COMPONENTS],
|
|
scan_info: Option<ScanInfo>,
|
|
dc_huffman_tables: HashMap<u8, HuffmanTable>,
|
|
ac_huffman_tables: HashMap<u8, HuffmanTable>,
|
|
headers_parsed: bool,
|
|
end_of_image: bool,
|
|
is_progressive: bool,
|
|
}
|
|
|
|
pub struct JpegDecoder<R: Read> {
|
|
reader: R,
|
|
state: JpegState,
|
|
}
|
|
|
|
/// Sampling factors of a frame component, named `Sample{H}x{V}` after the horizontal (H)
/// and vertical (V) factors packed into the SOF component's sampling byte (`H << 4 | V`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum SamplingMode {
    /// H = 1, V = 1 (sampling byte `0x11`).
    Sample1x1,
    /// H = 1, V = 2 (sampling byte `0x12`).
    Sample1x2,
    /// H = 2, V = 1 (sampling byte `0x21`).
    Sample2x1,
    /// H = 2, V = 2 (sampling byte `0x22`).
    Sample2x2,
}
|
|
|
|
#[derive(Debug, Clone, Copy)]
|
|
pub struct FrameComponent {
|
|
pub sampling: SamplingMode,
|
|
pub qt_index: usize,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct FrameInfo {
|
|
pub precision: u8,
|
|
pub components: [Option<FrameComponent>; MAX_COMPONENTS],
|
|
pub lines: usize,
|
|
pub samples_per_line: usize,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct ScanInfo {
|
|
pub dc_entropy_selectors: [u8; MAX_COMPONENTS],
|
|
pub ac_entropy_selectors: [u8; MAX_COMPONENTS],
|
|
pub spectral_predict_start: u8,
|
|
pub spectral_predict_end: u8,
|
|
pub successive_approximation: u8,
|
|
pub component_mask: u8,
|
|
}
|
|
|
|
impl SamplingMode {
|
|
pub fn vertical(&self) -> usize {
|
|
match self {
|
|
Self::Sample1x1 | Self::Sample2x1 => 1,
|
|
_ => 2,
|
|
}
|
|
}
|
|
|
|
pub fn horizontal(&self) -> usize {
|
|
match self {
|
|
Self::Sample1x1 | Self::Sample1x2 => 1,
|
|
_ => 2,
|
|
}
|
|
}
|
|
|
|
pub fn sample_count(&self) -> usize {
|
|
match self {
|
|
Self::Sample1x1 => 1,
|
|
Self::Sample1x2 | Self::Sample2x1 => 2,
|
|
Self::Sample2x2 => 4,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<R: Read> JpegDecoder<R> {
|
|
pub fn new(reader: R) -> Self {
|
|
Self {
|
|
reader,
|
|
state: JpegState::default(),
|
|
}
|
|
}
|
|
|
|
pub fn decode_headers(&mut self) -> Result<(), Error> {
|
|
self.state.decode_headers(&mut self.reader)
|
|
}
|
|
|
|
pub fn decode_ycbcr(&mut self) -> Result<YCbCrImage, Error> {
|
|
self.state.decode_frame(&mut self.reader)
|
|
}
|
|
|
|
pub fn decode_rgb(&mut self) -> Result<RgbImage, Error> {
|
|
let ycbcr = self.decode_ycbcr()?;
|
|
Ok(ycbcr.to_rgb())
|
|
}
|
|
}
|
|
|
|
impl JpegState {
|
|
fn do_decode_headers<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
|
|
// SOI marker
|
|
let soi_marker = read_u16be(reader)?;
|
|
if soi_marker != 0xFFD8 {
|
|
return Err(Error::MalformedHeader);
|
|
}
|
|
|
|
let mut last_byte = 0;
|
|
|
|
loop {
|
|
let mut m = read_u8(reader)?;
|
|
if last_byte == 0xFF && (m == 0xFF || m == 0x00) {
|
|
while m == 0xFF || m == 0x00 {
|
|
last_byte = m;
|
|
m = read_u8(reader)?;
|
|
}
|
|
}
|
|
if last_byte == 0xFF {
|
|
self.consume_marker(reader, m)?;
|
|
|
|
if self.scan_info.is_some() || self.end_of_image {
|
|
break;
|
|
}
|
|
}
|
|
last_byte = m;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn decode_headers<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
|
|
if self.headers_parsed {
|
|
return Ok(());
|
|
}
|
|
self.do_decode_headers(reader)
|
|
}
|
|
|
|
fn decode_frame<R: Read>(&mut self, reader: &mut R) -> Result<YCbCrImage, Error> {
|
|
self.decode_headers(reader)?;
|
|
if self.is_progressive {
|
|
Err(Error::UnimplementedFeature("Progressive image decoding"))
|
|
} else {
|
|
decode_ycbcr_baseline(reader, self)
|
|
}
|
|
}
|
|
|
|
fn consume_marker<R: Read>(&mut self, reader: &mut R, m: u8) -> Result<(), Error> {
|
|
match m {
|
|
0xC0..=0xC3 => self.consume_sof(reader, m - 0xC0),
|
|
0xC4 => self.consume_dht(reader),
|
|
0xE0 => self.consume_app0(reader),
|
|
0xD9 => {
|
|
self.end_of_image = true;
|
|
Ok(())
|
|
}
|
|
0xDA => self.consume_sos(reader),
|
|
0xDB => self.consume_dqt(reader),
|
|
_ => Err(Error::InvalidMarker(m)),
|
|
}
|
|
}
|
|
|
|
fn consume_sof<R: Read>(&mut self, reader: &mut R, i: u8) -> Result<(), Error> {
|
|
self.is_progressive = i >= 2;
|
|
if self.frame_info.is_some() {
|
|
return Err(Error::MultipleSofMarkers);
|
|
}
|
|
let length = read_u16be(reader)?;
|
|
let sample_precision = read_u8(reader)?;
|
|
let line_count = read_u16be(reader)?;
|
|
let samples_per_line = read_u16be(reader)?;
|
|
let component_count = read_u8(reader)?;
|
|
// TODO parse non-8bpp images
|
|
if sample_precision != 8 {
|
|
return Err(Error::UnimplementedSamplePrecision(sample_precision));
|
|
}
|
|
if length as usize != 8 + 3 * component_count as usize {
|
|
return Err(Error::MalformedHeader);
|
|
}
|
|
if component_count != 3 {
|
|
return Err(Error::UnimplementedComponentCount(component_count));
|
|
}
|
|
log::debug!("SOF {sample_precision}bpp {samples_per_line}x{line_count}");
|
|
let mut frame_info = FrameInfo {
|
|
precision: sample_precision,
|
|
samples_per_line: samples_per_line as usize,
|
|
lines: line_count as usize,
|
|
components: [const { None }; MAX_COMPONENTS],
|
|
};
|
|
for _ in 0..component_count {
|
|
let component_id = read_u8(reader)? as usize;
|
|
if component_id > MAX_COMPONENTS || component_id == 0 {
|
|
return Err(Error::InvalidComponentIndex(component_id));
|
|
}
|
|
let component_id = component_id - 1;
|
|
if frame_info.components[component_id].is_some() {
|
|
return Err(Error::DuplicateComponent(component_id));
|
|
}
|
|
let sampling_factor = read_u8(reader)?;
|
|
let sampling = match sampling_factor {
|
|
0x11 => SamplingMode::Sample1x1,
|
|
0x21 => SamplingMode::Sample2x1,
|
|
0x12 => SamplingMode::Sample1x2,
|
|
0x22 => SamplingMode::Sample2x2,
|
|
_ => {
|
|
let hsampling = sampling_factor >> 4;
|
|
let vsampling = sampling_factor & 0xF;
|
|
return Err(Error::UnimplementedSamplingFactor(
|
|
COMPONENT_NAMES[component_id],
|
|
hsampling,
|
|
vsampling,
|
|
));
|
|
}
|
|
};
|
|
let qt_index = read_u8(reader)? as usize;
|
|
frame_info.components[component_id] = Some(FrameComponent { sampling, qt_index });
|
|
log::debug!(" [{component_id}] {sampling:?}, qt {qt_index}");
|
|
}
|
|
self.frame_info = Some(frame_info);
|
|
Ok(())
|
|
}
|
|
|
|
fn consume_dht<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
|
|
let mut length = read_u16be(reader)?.saturating_sub(2) as usize;
|
|
let mut huffman_code_counts = [0; 16];
|
|
let mut huffman_code_values = [0; 256];
|
|
|
|
while length > 16 {
|
|
let ht_info = read_u8(reader)?;
|
|
let dc_or_ac = ht_info >> 4;
|
|
let index = ht_info & 0x0F;
|
|
reader.read_exact(&mut huffman_code_counts)?;
|
|
let total_value_count: usize =
|
|
huffman_code_counts.into_iter().map(|c| c as usize).sum();
|
|
length -= 17;
|
|
if total_value_count > 256 || total_value_count > length {
|
|
return Err(Error::InvalidHuffmanTable);
|
|
}
|
|
reader.read_exact(&mut huffman_code_values[..total_value_count])?;
|
|
length -= total_value_count;
|
|
match dc_or_ac {
|
|
0 => {
|
|
log::debug!(
|
|
"DC[{index}] {huffman_code_counts:?}, {:?}",
|
|
&huffman_code_values[..total_value_count]
|
|
);
|
|
let table = HuffmanTable::new(
|
|
huffman_code_counts,
|
|
huffman_code_values,
|
|
true,
|
|
self.is_progressive,
|
|
)?;
|
|
self.dc_huffman_tables.insert(index, table);
|
|
}
|
|
1 => {
|
|
log::debug!(
|
|
"AC[{index}] {huffman_code_counts:?}, {:?}",
|
|
&huffman_code_values[..total_value_count]
|
|
);
|
|
let table = HuffmanTable::new(
|
|
huffman_code_counts,
|
|
huffman_code_values,
|
|
false,
|
|
self.is_progressive,
|
|
)?;
|
|
self.ac_huffman_tables.insert(index, table);
|
|
}
|
|
_ => {
|
|
return Err(Error::InvalidHuffmanTable);
|
|
}
|
|
}
|
|
}
|
|
|
|
if length > 0 {
|
|
return Err(Error::InvalidHuffmanTable);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn consume_dqt<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
|
|
let mut length = read_u16be(reader)?.saturating_sub(2) as usize;
|
|
while length != 0 {
|
|
let qt_param = read_u8(reader)?;
|
|
let qt_precision = (qt_param >> 4) as usize;
|
|
let qt_index = (qt_param & 0x0F) as usize;
|
|
if qt_index >= MAX_COMPONENTS {
|
|
return Err(Error::InvalidQtIndex(qt_index));
|
|
}
|
|
if self.quantization_tables[qt_index].is_some() {
|
|
return Err(Error::DuplicateQt(qt_index));
|
|
}
|
|
let qt_bytes_per_element = 1 << qt_precision;
|
|
if qt_bytes_per_element * 64 + 1 > length {
|
|
return Err(Error::MalformedHeader);
|
|
}
|
|
// Read 64 elements of the quantization table
|
|
let qt_block = match qt_bytes_per_element {
|
|
1 => {
|
|
let mut qt_values = [0; 64];
|
|
reader.read_exact(&mut qt_values)?;
|
|
Block8x8::new(qt_values).map(|x| x as u16)
|
|
}
|
|
2 => {
|
|
let mut qt_values = [0u16; 64];
|
|
reader.read_exact(bytemuck::cast_slice_mut(&mut qt_values))?;
|
|
Block8x8::new(qt_values)
|
|
}
|
|
_ => return Err(Error::InvalidQtPrecision(qt_bytes_per_element)),
|
|
};
|
|
|
|
log::debug!("QT[{qt_index}]: {qt_bytes_per_element}Bpe\n{qt_block:3}");
|
|
|
|
self.quantization_tables[qt_index] = Some(qt_block);
|
|
length -= 1 + qt_bytes_per_element * 64;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn consume_app0<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
|
|
let app0: App0Header = read_struct(reader)?;
|
|
if app0.identifier != *b"JFIF\x00" {
|
|
return Err(Error::MalformedHeader);
|
|
}
|
|
assert_eq!(app0.xthumbnail, 0);
|
|
assert_eq!(app0.ythumbnail, 0);
|
|
assert_eq!(app0.length.read() as usize, size_of::<App0Header>());
|
|
|
|
// TODO skip thumbnail bytes
|
|
Ok(())
|
|
}
|
|
|
|
fn consume_sos<R: Read>(&mut self, reader: &mut R) -> Result<(), Error> {
|
|
let _length = read_u16be(reader)?.saturating_sub(2) as usize;
|
|
let mut scan = ScanInfo {
|
|
dc_entropy_selectors: [0; MAX_COMPONENTS],
|
|
ac_entropy_selectors: [0; MAX_COMPONENTS],
|
|
spectral_predict_start: 0,
|
|
spectral_predict_end: 0,
|
|
successive_approximation: 0,
|
|
component_mask: 0,
|
|
};
|
|
let component_count = read_u8(reader)? as usize;
|
|
for _ in 0..component_count {
|
|
let component_id = read_u8(reader)? as usize;
|
|
let entropy_dst_selector = read_u8(reader)?;
|
|
if component_id > MAX_COMPONENTS || component_id == 0 {
|
|
return Err(Error::InvalidComponentIndex(component_id));
|
|
}
|
|
let component_id = component_id - 1;
|
|
let dc_selector = entropy_dst_selector >> 4;
|
|
let ac_selector = entropy_dst_selector & 0x0F;
|
|
scan.dc_entropy_selectors[component_id] = dc_selector;
|
|
scan.ac_entropy_selectors[component_id] = ac_selector;
|
|
scan.component_mask |= 1 << component_id;
|
|
}
|
|
scan.spectral_predict_start = read_u8(reader)?;
|
|
scan.spectral_predict_end = read_u8(reader)?;
|
|
scan.successive_approximation = read_u8(reader)?;
|
|
self.scan_info = Some(scan);
|
|
self.headers_parsed = true;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub fn read_u8<R: Read>(reader: &mut R) -> Result<u8, Error> {
|
|
let mut buf = [0; 1];
|
|
reader.read_exact(&mut buf)?;
|
|
Ok(buf[0])
|
|
}
|
|
|
|
pub fn read_u16be<R: Read>(reader: &mut R) -> Result<u16, Error> {
|
|
let mut buf = [0; 2];
|
|
reader.read_exact(&mut buf)?;
|
|
Ok(u16::from_be_bytes(buf))
|
|
}
|
|
|
|
pub fn read_struct<T: Pod, R: Read>(reader: &mut R) -> Result<T, Error> {
|
|
let mut value = T::zeroed();
|
|
reader.read_exact(bytemuck::bytes_of_mut(&mut value))?;
|
|
Ok(value)
|
|
}
|
|
|
|
/// Converts a magnitude category `code` and its `code` raw extra `bits` into a signed
/// coefficient value (the EXTEND procedure of the JPEG entropy decoder).
///
/// Values whose top bit is set are positive and taken as-is. Otherwise the value is
/// negative: `bits - (2^code - 1)`. Category 0 always yields 0.
///
/// The arithmetic is done in `i32`: with `i16`, category 15 overflowed `(1 << 15) - 1`,
/// and categories above 16 overflowed the shift. Categories above 16 never occur in
/// valid data, so they are clamped to keep malformed input from panicking.
fn expand_number(code: u8, bits: u16) -> i16 {
    if code == 0 {
        return 0;
    }
    let code = u32::from(code.min(16));
    let bits = i32::from(bits);
    let half_range = 1i32 << (code - 1);
    let value = if bits >= half_range {
        bits
    } else {
        bits - ((1i32 << code) - 1)
    };
    // Truncation only matters for the invalid category 16.
    value as i16
}
|
|
|
|
fn decode_block<R: Read>(
|
|
reader: &mut R,
|
|
huffman: &mut HuffmanDecoder,
|
|
dc_table: &HuffmanTable,
|
|
ac_table: &HuffmanTable,
|
|
quant_table: &Block8x8<u16>,
|
|
dc_predictor: &mut i16,
|
|
) -> Result<Block8x8<u8>, Error> {
|
|
let mut block_data = [0; 64];
|
|
|
|
let mut l = 1;
|
|
|
|
// Read the DC coefficient
|
|
let code = huffman.decode_symbol(reader, dc_table)?;
|
|
let bits = huffman.get_bits(reader, code & 0x0F)?;
|
|
let dccoeff = expand_number(code, bits).wrapping_add(*dc_predictor);
|
|
*dc_predictor = dccoeff;
|
|
block_data[0] = dccoeff.wrapping_mul(quant_table.data[0] as i16);
|
|
|
|
while l < 64 {
|
|
let code = huffman.decode_symbol(reader, ac_table)?;
|
|
if code == 0 {
|
|
break;
|
|
}
|
|
|
|
let code = if code > 15 {
|
|
l += (code >> 4) as usize;
|
|
code & 0x0F
|
|
} else {
|
|
code
|
|
};
|
|
|
|
let bits = huffman.get_bits(reader, code)?;
|
|
|
|
if l < 64 {
|
|
let coeff = expand_number(code, bits);
|
|
block_data[l] = coeff.wrapping_mul(quant_table.data[l] as i16);
|
|
l += 1;
|
|
}
|
|
}
|
|
|
|
let block = Block8x8::unzigzag(&block_data);
|
|
let block = block.idct();
|
|
let block = block.map(|f| f as u8);
|
|
Ok(block)
|
|
}
|
|
|
|
fn decode_ycbcr_baseline<R: Read>(
|
|
reader: &mut R,
|
|
state: &mut JpegState,
|
|
) -> Result<YCbCrImage, Error> {
|
|
let frame_info = state.frame_info.as_ref().unwrap();
|
|
let scan_info = state.scan_info.as_ref().unwrap();
|
|
let component_count = scan_info.component_mask.count_ones() as usize;
|
|
|
|
let components = [
|
|
frame_info.components[0].as_ref().unwrap(),
|
|
frame_info.components[1].as_ref().unwrap(),
|
|
frame_info.components[2].as_ref().unwrap(),
|
|
];
|
|
|
|
let max_hsample = components
|
|
.iter()
|
|
.map(|c| c.sampling.horizontal())
|
|
.max()
|
|
.unwrap();
|
|
let max_vsample = components
|
|
.iter()
|
|
.map(|c| c.sampling.vertical())
|
|
.max()
|
|
.unwrap();
|
|
|
|
let component_len = frame_info.samples_per_line * frame_info.lines;
|
|
let mut component_data = [
|
|
vec![0; component_len],
|
|
vec![0; component_len],
|
|
vec![0; component_len],
|
|
];
|
|
|
|
let mut huffman = HuffmanDecoder::default();
|
|
|
|
let mut dc_predictors = [0; MAX_COMPONENTS];
|
|
|
|
let mcu_width_pixels = 8 * max_hsample;
|
|
let mcu_height_pixels = 8 * max_vsample;
|
|
|
|
let mcu_columns = frame_info.samples_per_line.div_ceil(mcu_width_pixels);
|
|
let mcu_rows = frame_info.lines.div_ceil(mcu_height_pixels);
|
|
|
|
'outer: for mcu_y in 0..mcu_rows {
|
|
for mcu_x in 0..mcu_columns {
|
|
for component_id in 0..component_count {
|
|
// TODO move this to block decode
|
|
if let Some(marker) = huffman.marker {
|
|
match marker {
|
|
0xD9 => {
|
|
break 'outer;
|
|
}
|
|
_ => todo!(),
|
|
}
|
|
}
|
|
|
|
let hsamples = components[component_id].sampling.horizontal();
|
|
let vsamples = components[component_id].sampling.vertical();
|
|
let dc_index = scan_info.dc_entropy_selectors[component_id];
|
|
let ac_index = scan_info.ac_entropy_selectors[component_id];
|
|
let qt_index = frame_info.components[component_id]
|
|
.as_ref()
|
|
.unwrap()
|
|
.qt_index;
|
|
let dc_table = state
|
|
.dc_huffman_tables
|
|
.get(&dc_index)
|
|
.ok_or(Error::MissingDCHT(dc_index))?;
|
|
let ac_table = state
|
|
.ac_huffman_tables
|
|
.get(&ac_index)
|
|
.ok_or(Error::MissingACHT(ac_index))?;
|
|
let quant_table = state.quantization_tables[qt_index]
|
|
.as_ref()
|
|
.ok_or(Error::MissingQT(qt_index))?;
|
|
|
|
let upsample_y = max_vsample / vsamples;
|
|
let upsample_x = max_hsample / hsamples;
|
|
let output = &mut component_data[component_id];
|
|
|
|
for vsample in 0..vsamples {
|
|
for hsample in 0..hsamples {
|
|
let block = decode_block(
|
|
reader,
|
|
&mut huffman,
|
|
dc_table,
|
|
ac_table,
|
|
quant_table,
|
|
&mut dc_predictors[component_id],
|
|
)?;
|
|
let dst_block_x = (mcu_x * max_hsample + hsample) * 8;
|
|
let dst_block_y = (mcu_y * max_vsample + vsample) * 8;
|
|
|
|
let write_fn = match (upsample_x, upsample_y) {
|
|
(1, 1) => data::write_block_1x1,
|
|
(1, 2) => data::write_block_1x2,
|
|
(2, 1) => data::write_block_2x1,
|
|
(2, 2) => data::write_block_2x2,
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
write_fn(
|
|
output,
|
|
frame_info.samples_per_line,
|
|
dst_block_x,
|
|
dst_block_y,
|
|
&block,
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let [y, cb, cr] = component_data;
|
|
|
|
let ycbcr = YCbCrImage {
|
|
luma: y,
|
|
chroma_b: cb,
|
|
chroma_r: cr,
|
|
width: frame_info.samples_per_line,
|
|
height: frame_info.lines,
|
|
};
|
|
|
|
Ok(ycbcr)
|
|
}
|