diff --git a/crates/jxl-frame/src/data/lf_global.rs b/crates/jxl-frame/src/data/lf_global.rs index 61100f51..705ee0a1 100644 --- a/crates/jxl-frame/src/data/lf_global.rs +++ b/crates/jxl-frame/src/data/lf_global.rs @@ -85,12 +85,6 @@ impl Bundle<(&ImageHeader, &FrameHeader)> for LfGlobal { } } -impl LfGlobal { - pub(crate) fn apply_modular_inverse_transform(&mut self) { - self.gmodular.modular.inverse_transform(); - } -} - define_bundle! { #[derive(Debug)] pub struct LfGlobalVarDct error(crate::Error) { @@ -100,7 +94,7 @@ define_bundle! { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct GlobalModular { pub ma_config: Option, pub modular: Modular, diff --git a/crates/jxl-frame/src/data/pass_group.rs b/crates/jxl-frame/src/data/pass_group.rs index 94369cc8..d28fdb10 100644 --- a/crates/jxl-frame/src/data/pass_group.rs +++ b/crates/jxl-frame/src/data/pass_group.rs @@ -1,135 +1,102 @@ -use jxl_bitstream::{read_bits, Bitstream, Bundle}; +use jxl_bitstream::{Bitstream, Bundle}; +use jxl_grid::CutGrid; use jxl_modular::{ChannelShift, Modular}; -use jxl_vardct::{HfCoeff, HfCoeffParams}; +use jxl_vardct::{HfCoeffParams, write_hf_coeff}; use crate::{FrameHeader, Result}; use super::{ GlobalModular, - LfGlobal, LfGlobalVarDct, LfGroup, HfGlobal, }; -#[derive(Debug, Clone, Copy)] -pub struct PassGroupParams<'a> { - frame_header: &'a FrameHeader, - gmodular: &'a GlobalModular, - lf_vardct: Option<&'a LfGlobalVarDct>, - lf_group: &'a LfGroup, - hf_global: Option<&'a HfGlobal>, - pass_idx: u32, - group_idx: u32, - shift: Option<(i32, i32)>, -} - -impl<'a> PassGroupParams<'a> { - pub fn new( - frame_header: &'a FrameHeader, - lf_global: &'a LfGlobal, - lf_group: &'a LfGroup, - hf_global: Option<&'a HfGlobal>, - pass_idx: u32, - group_idx: u32, - shift: Option<(i32, i32)>, - ) -> Self { - Self { - frame_header, - gmodular: &lf_global.gmodular, - lf_vardct: lf_global.vardct.as_ref(), - lf_group, - hf_global, - pass_idx, - group_idx, - shift, - } - } +#[derive(Debug)] +pub struct PassGroupParams<'frame, 'buf, 'g> { + pub frame_header: &'frame FrameHeader, + pub lf_group: &'frame LfGroup, + pub pass_idx: u32, + pub group_idx: u32, + pub shift: Option<(i32, i32)>, + pub gmodular: &'g mut GlobalModular, + pub vardct: Option>, } #[derive(Debug)] -pub struct PassGroup { - pub hf_coeff: Option, - pub modular: Modular, +pub struct PassGroupParamsVardct<'frame, 'buf, 'g> { + pub lf_vardct: &'frame LfGlobalVarDct, + pub hf_global: &'frame HfGlobal, + pub hf_coeff_output: &'buf mut [CutGrid<'g, f32>; 3], } -impl Bundle> for PassGroup { - type Error = crate::Error; - - fn parse(bitstream: &mut Bitstream, params: PassGroupParams<'_>) -> Result { - let PassGroupParams { - frame_header, - gmodular, - lf_vardct, - lf_group, - hf_global, - pass_idx, - group_idx, - shift, - } = params; +pub fn decode_pass_group( + bitstream: &mut Bitstream, + params: PassGroupParams, +) -> Result<()> { + let PassGroupParams { + frame_header, + lf_group, + pass_idx, + group_idx, + shift, + gmodular, + vardct, + } = params; - let hf_coeff = lf_vardct - .zip(lf_group.hf_meta.as_ref()) - .zip(hf_global) - .map(|((lf_vardct, hf_meta), hf_global)| { - let hf_pass = &hf_global.hf_passes[pass_idx as usize]; - let coeff_shift = frame_header.passes.shift.get(pass_idx as usize) - .copied() - .unwrap_or(0); + if let (Some(PassGroupParamsVardct { lf_vardct, hf_global, hf_coeff_output }), Some(hf_meta)) = (vardct, &lf_group.hf_meta) { + let hf_pass = &hf_global.hf_passes[pass_idx as usize]; + let coeff_shift = frame_header.passes.shift.get(pass_idx as usize) + .copied() + .unwrap_or(0); - let group_col = group_idx % frame_header.groups_per_row(); - let group_row = group_idx / frame_header.groups_per_row(); - let lf_col = (group_col % 8) as usize; - let lf_row = (group_row % 8) as usize; - let group_dim_blocks = (frame_header.group_dim() / 8) as usize; + let group_col = group_idx % frame_header.groups_per_row(); + let group_row = group_idx / frame_header.groups_per_row(); + let lf_col = (group_col % 8) as usize; + let lf_row = (group_row % 8) as usize; + let group_dim_blocks = (frame_header.group_dim() / 8) as usize; - let block_info = &hf_meta.block_info; + let block_info = &hf_meta.block_info; - let block_left = lf_col * group_dim_blocks; - let block_top = lf_row * group_dim_blocks; - let block_width = (block_info.width() - block_left).min(group_dim_blocks); - let block_height = (block_info.height() - block_top).min(group_dim_blocks); + let block_left = lf_col * group_dim_blocks; + let block_top = lf_row * group_dim_blocks; + let block_width = (block_info.width() - block_left).min(group_dim_blocks); + let block_height = (block_info.height() - block_top).min(group_dim_blocks); - let jpeg_upsampling = frame_header.jpeg_upsampling; - let block_info = block_info.subgrid(block_left, block_top, block_width, block_height); - let lf_quant: Option<[_; 3]> = lf_group.lf_coeff.as_ref().map(|lf_coeff| { - let lf_quant_channels = lf_coeff.lf_quant.image().channel_data(); - std::array::from_fn(|idx| { - let lf_quant = &lf_quant_channels[[1, 0, 2][idx]]; - let shift = ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx); + let jpeg_upsampling = frame_header.jpeg_upsampling; + let block_info = block_info.subgrid(block_left, block_top, block_width, block_height); + let lf_quant: Option<[_; 3]> = lf_group.lf_coeff.as_ref().map(|lf_coeff| { + let lf_quant_channels = lf_coeff.lf_quant.image().channel_data(); + std::array::from_fn(|idx| { + let lf_quant = &lf_quant_channels[[1, 0, 2][idx]]; + let shift = ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx); - let block_left = block_left >> shift.hshift(); - let block_top = block_top >> shift.vshift(); - let (block_width, block_height) = shift.shift_size((block_width as u32, block_height as u32)); - lf_quant.subgrid(block_left, block_top, block_width as usize, block_height as usize) - }) - }); - - let params = HfCoeffParams { - num_hf_presets: hf_global.num_hf_presets, - hf_block_ctx: &lf_vardct.hf_block_ctx, - block_info, - jpeg_upsampling, - lf_quant, - hf_pass, - coeff_shift, - }; - HfCoeff::parse(bitstream, params) + let block_left = block_left >> shift.hshift(); + let block_top = block_top >> shift.vshift(); + let (block_width, block_height) = shift.shift_size((block_width as u32, block_height as u32)); + lf_quant.subgrid(block_left, block_top, block_width as usize, block_height as usize) }) - .transpose()?; + }); - let modular = if let Some((minshift, maxshift)) = shift { - let modular_params = gmodular.modular.make_subimage_params_pass_group(gmodular.ma_config.as_ref(), group_idx, minshift, maxshift); - let mut modular = read_bits!(bitstream, Bundle(Modular), modular_params)?; - modular.decode_image(bitstream, 1 + 3 * frame_header.num_lf_groups() + 17 + pass_idx * frame_header.num_groups() + group_idx)?; - modular.inverse_transform(); - modular - } else { - Modular::empty() + let params = HfCoeffParams { + num_hf_presets: hf_global.num_hf_presets, + hf_block_ctx: &lf_vardct.hf_block_ctx, + block_info, + jpeg_upsampling, + lf_quant, + hf_pass, + coeff_shift, }; - Ok(Self { - hf_coeff, - modular, - }) + write_hf_coeff(bitstream, params, hf_coeff_output)?; + } + + if let Some((minshift, maxshift)) = shift { + let modular_params = gmodular.modular.make_subimage_params_pass_group(gmodular.ma_config.as_ref(), group_idx, minshift, maxshift); + let mut modular = Modular::parse(bitstream, modular_params)?; + modular.decode_image(bitstream, 1 + 3 * frame_header.num_lf_groups() + 17 + pass_idx * frame_header.num_groups() + group_idx)?; + modular.inverse_transform(); + gmodular.modular.copy_from_modular(modular); } + + Ok(()) } diff --git a/crates/jxl-frame/src/data/toc.rs b/crates/jxl-frame/src/data/toc.rs index 06fb1822..d548da3a 100644 --- a/crates/jxl-frame/src/data/toc.rs +++ b/crates/jxl-frame/src/data/toc.rs @@ -15,7 +15,8 @@ pub struct Toc { num_lf_groups: usize, num_groups: usize, groups: Vec, - bitstream_order: Vec, + bitstream_to_original: Vec, + original_to_bitstream: Vec, total_size: u64, } @@ -38,7 +39,7 @@ impl std::fmt::Debug for Toc { "bitstream_order", &format_args!( "({})", - if self.bitstream_order.is_empty() { "empty" } else { "non-empty" }, + if self.bitstream_to_original.is_empty() { "empty" } else { "non-empty" }, ), ) .finish_non_exhaustive() @@ -97,7 +98,7 @@ impl PartialOrd for TocGroupKind { impl Toc { /// Returns the offset to the beginning of the data. pub fn bookmark(&self) -> Bookmark { - let idx = self.bitstream_order.first().copied().unwrap_or(0); + let idx = self.bitstream_to_original.first().copied().unwrap_or(0); self.groups[idx].offset } @@ -106,31 +107,22 @@ impl Toc { self.groups.len() <= 1 } - pub fn lf_global(&self) -> TocGroup { - self.groups[0] - } - - pub fn lf_group(&self, idx: u32) -> TocGroup { - if self.is_single_entry() { - panic!("cannot obtain LfGroup offset of single entry frame"); - } else if (idx as usize) >= self.num_lf_groups { - panic!("index out of range: {} >= {} (num_lf_groups)", idx, self.num_lf_groups); - } else { - self.groups[idx as usize + 1] - } - } - - pub fn hf_global(&self) -> TocGroup { - self.groups[self.num_lf_groups + 1] - } + pub fn group_index_bitstream_order(&self, kind: TocGroupKind) -> usize { + let original_order = match kind { + TocGroupKind::All if self.is_single_entry() => 0, + _ if self.is_single_entry() => panic!("Cannot request group type of {:?} for single-group frame", kind), + TocGroupKind::All => panic!("Cannot request group type of All for multi-group frame"), + TocGroupKind::LfGlobal => 0, + TocGroupKind::LfGroup(lf_group_idx) => 1 + lf_group_idx as usize, + TocGroupKind::HfGlobal => 1 + self.num_lf_groups, + TocGroupKind::GroupPass { pass_idx, group_idx } => + 1 + self.num_lf_groups + 1 + pass_idx as usize * self.num_groups + group_idx as usize, + }; - pub fn pass_group(&self, pass_idx: u32, group_idx: u32) -> TocGroup { - if self.is_single_entry() { - panic!("cannot obtain PassGroup offset of single entry frame"); + if self.original_to_bitstream.is_empty() { + original_order } else { - let mut idx = 1 + self.num_lf_groups + 1; - idx += (pass_idx as usize * self.num_groups) + group_idx as usize; - self.groups[idx] + self.original_to_bitstream[original_order] } } @@ -140,10 +132,10 @@ impl Toc { } pub fn iter_bitstream_order(&self) -> impl Iterator + Send { - let groups = if self.bitstream_order.is_empty() { + let groups = if self.bitstream_to_original.is_empty() { self.groups.clone() } else { - self.bitstream_order.iter().map(|&idx| self.groups[idx]).collect() + self.bitstream_to_original.iter().map(|&idx| self.groups[idx]).collect() }; groups.into_iter() } @@ -211,18 +203,18 @@ impl Bundle<&crate::FrameHeader> for Toc { out }; - let (offsets, sizes, bitstream_order) = if permutated_toc { - let mut bitstream_order = vec![0usize; permutation.len()]; + let (offsets, sizes, bitstream_to_original, original_to_bitstream) = if permutated_toc { + let mut bitstream_to_original = vec![0usize; permutation.len()]; let mut offsets_out = Vec::with_capacity(permutation.len()); let mut sizes_out = Vec::with_capacity(permutation.len()); - for (idx, perm) in permutation.into_iter().enumerate() { + for (idx, &perm) in permutation.iter().enumerate() { offsets_out.push(offsets[perm]); sizes_out.push(sizes[perm]); - bitstream_order[perm] = idx; + bitstream_to_original[perm] = idx; } - (offsets_out, sizes_out, bitstream_order) + (offsets_out, sizes_out, bitstream_to_original, permutation) } else { - (offsets, sizes, Vec::new()) + (offsets, sizes, Vec::new(), Vec::new()) }; let groups = sizes @@ -240,7 +232,8 @@ impl Bundle<&crate::FrameHeader> for Toc { num_lf_groups: ctx.num_lf_groups() as usize, num_groups: num_groups as usize, groups, - bitstream_order, + bitstream_to_original, + original_to_bitstream, total_size, }) } diff --git a/crates/jxl-frame/src/lib.rs b/crates/jxl-frame/src/lib.rs index f88bdbfa..538ea378 100644 --- a/crates/jxl-frame/src/lib.rs +++ b/crates/jxl-frame/src/lib.rs @@ -16,11 +16,10 @@ //! [`num_lf_groups`]: FrameHeader::num_lf_groups //! [`num_groups`]: FrameHeader::num_groups //! [`num_passes`]: header::Passes::num_passes -use std::collections::{HashMap, HashSet}; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, io::Cursor}; use std::io::Read; +use std::sync::Arc; -use header::Encoding; use jxl_bitstream::{read_bits, Bitstream, Bundle}; use jxl_image::ImageHeader; @@ -34,56 +33,26 @@ pub use header::FrameHeader; use crate::data::*; +type ByteSliceBitstream<'buf> = Bitstream>; + /// JPEG XL frame. /// /// A frame represents a single unit of image that can be displayed or referenced by other frames. #[derive(Debug)] -pub struct Frame<'a> { - image_header: &'a ImageHeader, +pub struct Frame { + image_header: Arc, header: FrameHeader, toc: Toc, - plan: Vec, - next_instr: usize, - buf_slot: HashMap)>, - data: FrameData, + data: Vec>, pass_shifts: BTreeMap, } -#[derive(Debug, Copy, Clone)] -enum GroupInstr { - Read(usize, TocGroup), - Decode(usize), - ProgressiveScan { - pass_idx: Option, - downsample_factor: u32, - done: bool, - }, -} - -/// Result of progressive loading of a frame. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum ProgressiveResult { - /// More data is needed to complete a frame or a progressive scan. - NeedMoreData, - /// A progressive scan is ready to be rendered. - SingleScan { - /// Pass index, `None` if the scan is from the LF image. - pass_idx: Option, - /// Downsample factor of the scan. - downsample_factor: u32, - /// Whether the scan completes an image with the given downsample factor. - done: bool, - }, - /// A frame is fully loaded. - FrameComplete, -} - -impl<'a> Bundle<&'a ImageHeader> for Frame<'a> { +impl Bundle> for Frame { type Error = crate::Error; - fn parse(bitstream: &mut Bitstream, image_header: &'a ImageHeader) -> Result { + fn parse(bitstream: &mut Bitstream, image_header: Arc) -> Result { bitstream.zero_pad_to_byte()?; - let header = read_bits!(bitstream, Bundle(FrameHeader), image_header)?; + let header = read_bits!(bitstream, Bundle(FrameHeader), &image_header)?; for blending_info in std::iter::once(&header.blending_info).chain(&header.ec_blending_info) { if blending_info.mode.use_alpha() @@ -127,7 +96,6 @@ impl<'a> Bundle<&'a ImageHeader> for Frame<'a> { } let toc = read_bits!(bitstream, Bundle(Toc), &header)?; - let data = FrameData::new(&header); let passes = &header.passes; let mut pass_shifts = BTreeMap::new(); @@ -143,170 +111,21 @@ impl<'a> Bundle<&'a ImageHeader> for Frame<'a> { image_header, header, toc, - plan: Vec::new(), - next_instr: 0, - buf_slot: HashMap::new(), - data, + data: Vec::new(), pass_shifts, }) } } -impl Frame<'_> { - fn prepare_default_plan(&mut self) { - let header = &self.header; - let passes = &header.passes; - let toc = &self.toc; - let plan = &mut self.plan; - if toc.is_single_entry() { - let group = toc.lf_global(); - plan.push(GroupInstr::Read(0, group)); - plan.push(GroupInstr::Decode(0)); - } else { - let groups = toc.iter_bitstream_order(); - let num_lf_groups = header.num_lf_groups() as usize; - let num_groups = header.num_groups() as usize; - - let mut read_slot = HashMap::new(); - let mut decoded_slots = HashSet::new(); - let mut need_lf_global = true; - let mut need_hf_global = header.encoding == Encoding::VarDct; - let mut next_slot_idx = 0usize; - let mut lf_group_count = 0usize; - let mut group_count_per_pass = (0..passes.num_passes) - .map(|pass| (pass, 0usize)) - .collect::>(); - for group in groups { - if !need_hf_global && group.kind == TocGroupKind::HfGlobal { - continue; - } - - let current_slot_idx = next_slot_idx; - plan.push(GroupInstr::Read(current_slot_idx, group)); - next_slot_idx += 1; - - let mut update_lf_groups = false; - let mut update_pass_groups = false; - match group.kind { - TocGroupKind::All => panic!("unexpected TocGroupKind::All"), - TocGroupKind::LfGlobal => { - plan.push(GroupInstr::Decode(current_slot_idx)); - decoded_slots.insert(group.kind); - update_lf_groups = true; - need_lf_global = false; - }, - TocGroupKind::HfGlobal => { - if need_lf_global { - read_slot.insert(group.kind, current_slot_idx); - } else { - plan.push(GroupInstr::Decode(current_slot_idx)); - decoded_slots.insert(group.kind); - update_pass_groups = true; - need_hf_global = false; - } - }, - TocGroupKind::LfGroup(_) => { - if need_lf_global { - read_slot.insert(group.kind, current_slot_idx); - } else { - plan.push(GroupInstr::Decode(current_slot_idx)); - decoded_slots.insert(group.kind); - lf_group_count += 1; - update_pass_groups = true; - } - }, - TocGroupKind::GroupPass { pass_idx, group_idx } => { - let lf_group_idx = header.lf_group_idx_from_group_idx(group_idx); - if need_lf_global || need_hf_global || !decoded_slots.contains(&TocGroupKind::LfGroup(lf_group_idx)) { - read_slot.insert(group.kind, current_slot_idx); - } else { - plan.push(GroupInstr::Decode(current_slot_idx)); - decoded_slots.insert(group.kind); - *group_count_per_pass - .get_mut(&pass_idx) - .unwrap() += 1; - } - }, - } - - if update_lf_groups { - let mut decoded = Vec::new(); - for (&kind, &slot_idx) in &read_slot { - if let TocGroupKind::LfGroup(_) = kind { - plan.push(GroupInstr::Decode(slot_idx)); - decoded.push(kind); - } - } - lf_group_count += decoded.len(); - for kind in decoded { - read_slot.remove(&kind); - decoded_slots.insert(kind); - } - update_pass_groups = true; - } - - if update_pass_groups && !need_hf_global { - let mut decoded = Vec::new(); - for (&kind, &slot_idx) in &read_slot { - if let TocGroupKind::GroupPass { pass_idx, group_idx } = kind { - let lf_group_idx = header.lf_group_idx_from_group_idx(group_idx); - if decoded_slots.contains(&TocGroupKind::LfGroup(lf_group_idx)) { - plan.push(GroupInstr::Decode(slot_idx)); - decoded.push(kind); - *group_count_per_pass - .get_mut(&pass_idx) - .unwrap() += 1; - } - } - } - for kind in decoded { - read_slot.remove(&kind); - decoded_slots.insert(kind); - } - } - - if lf_group_count == num_lf_groups && !need_hf_global { - let done = passes.downsample.first().copied().unwrap_or(1) != 8; - plan.push(GroupInstr::ProgressiveScan { - pass_idx: None, - downsample_factor: 8, - done, - }); - lf_group_count += 1; - } - if lf_group_count > num_lf_groups { - while let Some((&pass_idx, &v)) = group_count_per_pass.first_key_value() { - if v == num_groups { - let search_result = passes.last_pass.binary_search(&pass_idx); - let factor_idx = match search_result { - Ok(v) | Err(v) => v, - }; - let done = search_result.is_ok() || passes.last_pass.len() == factor_idx; - let downsample_factor = passes.downsample - .get(factor_idx) - .copied() - .unwrap_or(1); - plan.push(GroupInstr::ProgressiveScan { - pass_idx: Some(pass_idx), - downsample_factor, - done, - }); - group_count_per_pass.pop_first(); - } else { - break; - } - } - } - } - } +impl Frame { + pub fn image_header(&self) -> &ImageHeader { + &self.image_header + } - if let Some(GroupInstr::ProgressiveScan { .. }) = plan.last() { - plan.pop(); - } + pub fn clone_image_header(&self) -> Arc { + Arc::clone(&self.image_header) } -} -impl Frame<'_> { /// Returns the frame header. pub fn header(&self) -> &FrameHeader { &self.header @@ -319,13 +138,143 @@ impl Frame<'_> { &self.toc } - /// Returns the frame data. - pub fn data(&self) -> &FrameData { - &self.data + pub fn pass_shifts(&self, pass_idx: u32) -> Option<(i32, i32)> { + self.pass_shifts.get(&pass_idx).copied() + } + + pub fn data(&self, group: TocGroupKind) -> Option<&[u8]> { + let idx = self.toc.group_index_bitstream_order(group); + self.data.get(idx).map(|b| &**b) } } -impl Frame<'_> { +impl Frame { + pub fn read_all(&mut self, bitstream: &mut Bitstream) -> Result<()> { + assert!(self.data.is_empty()); + + for group in self.toc.iter_bitstream_order() { + tracing::trace!(?group); + bitstream.zero_pad_to_byte()?; + + let mut data = vec![0u8; group.size as usize]; + bitstream.read_bytes_aligned(&mut data)?; + + self.data.push(data); + } + + Ok(()) + } +} + +struct AllParseResult<'buf> { + #[allow(unused)] + lf_global: LfGlobal, + lf_group: LfGroup, + hf_global: Option, + pass_group_bitstream: ByteSliceBitstream<'buf>, +} + +impl Frame { + fn try_parse_all(&self) -> Option> { + if !self.toc.is_single_entry() { + panic!(); + } + + let group = self.data.get(0)?; + let mut bitstream = Bitstream::new(Cursor::new(&**group)); + let result = (|| -> Result<_> { + let lf_global = LfGlobal::parse(&mut bitstream, (&self.image_header, &self.header))?; + let lf_group = LfGroup::parse(&mut bitstream, LfGroupParams::new(&self.header, &lf_global, 0))?; + let hf_global = (self.header.encoding == header::Encoding::VarDct).then(|| { + HfGlobal::parse(&mut bitstream, HfGlobalParams::new(&self.image_header.metadata, &self.header, &lf_global)) + }).transpose()?; + Ok((lf_global, lf_group, hf_global)) + })(); + + match result { + Ok((lf_global, lf_group, hf_global)) => Some(Ok(AllParseResult { + lf_global, + lf_group, + hf_global, + pass_group_bitstream: bitstream, + })), + Err(e) => Some(Err(e)), + } + } + + pub fn try_parse_lf_global(&self) -> Option> { + Some(if self.toc.is_single_entry() { + let group = self.data.get(0)?; + let mut bitstream = Bitstream::new(Cursor::new(group)); + LfGlobal::parse(&mut bitstream, (&self.image_header, &self.header)) + } else { + let idx = self.toc.group_index_bitstream_order(TocGroupKind::LfGlobal); + let group = self.data.get(idx)?; + let mut bitstream = Bitstream::new(Cursor::new(group)); + LfGlobal::parse(&mut bitstream, (&self.image_header, &self.header)) + }) + } + + pub fn try_parse_lf_group(&self, cached_lf_global: Option<&LfGlobal>, lf_group_idx: u32) -> Option> { + if self.toc.is_single_entry() { + if lf_group_idx != 0 { + return None; + } + Some(self.try_parse_all()?.map(|x| x.lf_group)) + } else { + let idx = self.toc.group_index_bitstream_order(TocGroupKind::LfGroup(lf_group_idx)); + let group = self.data.get(idx)?; + let mut bitstream = Bitstream::new(Cursor::new(group)); + let lf_global = if cached_lf_global.is_none() { + match self.try_parse_lf_global()? { + Ok(lf_global) => Some(lf_global), + Err(e) => return Some(Err(e)), + } + } else { + None + }; + let lf_global = cached_lf_global.or(lf_global.as_ref()).unwrap(); + Some(LfGroup::parse(&mut bitstream, LfGroupParams::new(&self.header, lf_global, lf_group_idx))) + } + } + + pub fn try_parse_hf_global(&self, cached_lf_global: Option<&LfGlobal>) -> Option> { + if self.header.encoding == header::Encoding::Modular { + return None; + } + + if self.toc.is_single_entry() { + Some(self.try_parse_all()?.map(|x| x.hf_global.unwrap())) + } else { + let idx = self.toc.group_index_bitstream_order(TocGroupKind::HfGlobal); + let group = self.data.get(idx)?; + let mut bitstream = Bitstream::new(Cursor::new(group)); + let lf_global = if cached_lf_global.is_none() { + match self.try_parse_lf_global()? { + Ok(lf_global) => Some(lf_global), + Err(e) => return Some(Err(e)), + } + } else { + None + }; + let lf_global = cached_lf_global.or(lf_global.as_ref()).unwrap(); + let params = HfGlobalParams::new(&self.image_header.metadata, &self.header, lf_global); + Some(HfGlobal::parse(&mut bitstream, params)) + } + } + + pub fn pass_group_bitstream(&self, pass_idx: u32, group_idx: u32) -> Option> { + if self.toc.is_single_entry() { + Some(self.try_parse_all()?.map(|x| x.pass_group_bitstream)) + } else { + let idx = self.toc.group_index_bitstream_order(TocGroupKind::GroupPass { pass_idx, group_idx }); + let group = self.data.get(idx)?; + Some(Ok(Bitstream::new(Cursor::new(&**group)))) + } + } +} + +impl Frame { /// Adjusts the cropping region of the image to the actual decoding region of the frame. /// /// The cropping region of the *image* needs to be adjusted to be used in a *frame*, for a few @@ -359,337 +308,4 @@ impl Frame<'_> { *height += delta_h + padding; } } - - /// Loads the data of the frame with the given TOC filter. A group is loaded if the filter - /// returns `true` for a given TOC group. - /// - /// If `progressive` is true, then the method pauses loading the data at the next progressive - /// scan marker. - pub fn load_with_filter( - &mut self, - bitstream: &mut Bitstream, - progressive: bool, - mut filter_fn: impl FnMut(&FrameHeader, &FrameData, TocGroupKind) -> bool, - ) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "Load with filter"); - let _guard = span.enter(); - - if self.plan.is_empty() { - self.prepare_default_plan(); - } - - while let Some(&instr) = self.plan.get(self.next_instr) { - let result = self.process_instr(bitstream, instr, &mut filter_fn); - match result { - Err(e) if e.unexpected_eof() => return Ok(ProgressiveResult::NeedMoreData), - result => result?, - } - - self.next_instr += 1; - if progressive { - if let GroupInstr::ProgressiveScan { pass_idx, downsample_factor, done } = instr { - return Ok(ProgressiveResult::SingleScan { - pass_idx, - downsample_factor, - done, - }); - } - } - } - - self.data.complete()?; - Ok(ProgressiveResult::FrameComplete) - } - - /// Loads the data of the frame with the given cropping region of the frame. - /// - /// The region is expected in the frame coordinate. Use [`adjust_region`][Self::adjust_region] - /// to convert from the region of the image. - pub fn load_cropped( - &mut self, - bitstream: &mut Bitstream, - region: Option<(u32, u32, u32, u32)>, - ) -> Result<()> { - let span = tracing::span!(tracing::Level::TRACE, "Load cropped"); - let _guard = span.enter(); - - self.load_with_filter(bitstream, false, crop_filter(region)).map(drop) - } - - /// Loads all data of the frame. - pub fn load_all(&mut self, bitstream: &mut Bitstream) -> Result<()> { - self.load_cropped(bitstream, None) - } - - fn process_instr( - &mut self, - bitstream: &mut Bitstream, - instr: GroupInstr, - mut filter_fn: impl FnMut(&FrameHeader, &FrameData, TocGroupKind) -> bool, - ) -> Result<()> { - let span = tracing::span!(tracing::Level::TRACE, "Process instruction", instr = format_args!("{:?}", instr)); - let _guard = span.enter(); - - match instr { - GroupInstr::Read(slot_idx, group) => { - tracing::trace!(group_kind = format_args!("{:?}", group.kind), "Reading group into memory"); - bitstream.skip_to_bookmark(group.offset)?; - - let mut b = bitstream.rewindable(); - let mut buf = vec![0u8; group.size as usize]; - b.read_bytes_aligned(&mut buf)?; - b.commit(); - - self.buf_slot.insert(slot_idx, (group.kind, buf)); - }, - GroupInstr::Decode(slot_idx) => { - let &(kind, ref buf) = self.buf_slot.get(&slot_idx).unwrap(); - tracing::trace!(group_kind = format_args!("{:?}", kind), "Decoding group"); - if !filter_fn(&self.header, &self.data, kind) { - return Ok(()); - } - - let mut bitstream = Bitstream::new(std::io::Cursor::new(buf)); - self.data.load_single(&mut bitstream, kind, self.image_header, &self.header, &self.pass_shifts)?; - }, - GroupInstr::ProgressiveScan { downsample_factor, done, .. } => { - tracing::debug!(downsample_factor, done, "Single progressive scan"); - }, - } - Ok(()) - } -} - -/// Data of a frame. -#[derive(Debug)] -pub struct FrameData { - pub lf_global: Option, - pub lf_group: HashMap, - pub hf_global: Option>, - pub group_pass: HashMap<(u32, u32), PassGroup>, -} - -impl FrameData { - fn new(frame_header: &FrameHeader) -> Self { - let has_hf_global = frame_header.encoding == crate::header::Encoding::VarDct; - let hf_global = if has_hf_global { - None - } else { - Some(None) - }; - - Self { - lf_global: None, - lf_group: Default::default(), - hf_global, - group_pass: Default::default(), - } - } - - fn complete(&mut self) -> Result<&mut Self> { - let Self { - lf_global, - lf_group, - group_pass, - .. - } = self; - - let Some(lf_global) = lf_global else { - return Err(Error::IncompleteFrameData { field: "lf_global" }); - }; - for lf_group in lf_group.values_mut() { - let mlf_group = std::mem::take(&mut lf_group.mlf_group); - lf_global.gmodular.modular.copy_from_modular(mlf_group); - } - for group in group_pass.values_mut() { - let modular = std::mem::take(&mut group.modular); - lf_global.gmodular.modular.copy_from_modular(modular); - } - lf_global.apply_modular_inverse_transform(); - Ok(self) - } -} - -impl FrameData { - fn load_single( - &mut self, - bitstream: &mut Bitstream, - kind: TocGroupKind, - image_header: &ImageHeader, - frame_header: &FrameHeader, - pass_shifts: &BTreeMap, - ) -> Result<()> { - match kind { - TocGroupKind::All => { - let shift = pass_shifts.get(&0).copied(); - self.read_merged_group(bitstream, shift, image_header, frame_header)?; - }, - TocGroupKind::LfGlobal => { - self.lf_global = Some(self.read_lf_global(bitstream, image_header, frame_header)?); - }, - TocGroupKind::LfGroup(lf_group_idx) => { - let lf_global = self.lf_global.as_ref().expect("invalid decode plan: LfGlobal not decoded"); - self.lf_group.insert(lf_group_idx, self.read_lf_group(bitstream, lf_global, lf_group_idx, frame_header)?); - }, - TocGroupKind::HfGlobal => { - let lf_global = self.lf_global.as_ref().expect("invalid decode plan: LfGlobal not decoded"); - self.hf_global = Some(self.read_hf_global(bitstream, lf_global, image_header, frame_header)?); - }, - TocGroupKind::GroupPass { pass_idx, group_idx } => { - let lf_global = self.lf_global.as_ref().expect("invalid decode plan: LfGlobal not decoded"); - let lf_group_idx = frame_header.lf_group_idx_from_group_idx(group_idx); - let lf_group = self.lf_group.get(&lf_group_idx).expect("invalid decode plan: LfGroup not decoded"); - let hf_global = self.hf_global.as_ref().expect("invalid decode plan: HfGlobal not decoded"); - - let shift = pass_shifts.get(&pass_idx).copied(); - let group = self.read_group_pass( - bitstream, - lf_global, - lf_group, - hf_global.as_ref(), - pass_idx, - group_idx, - shift, - frame_header, - )?; - self.group_pass.insert((pass_idx, group_idx), group); - }, - } - Ok(()) - } - - fn read_lf_global( - &mut self, - bitstream: &mut Bitstream, - image_header: &ImageHeader, - frame_header: &FrameHeader, - ) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "Decode LfGlobal"); - let _guard = span.enter(); - read_bits!(bitstream, Bundle(LfGlobal), (image_header, frame_header)) - } - - fn read_lf_group( - &self, - bitstream: &mut Bitstream, - lf_global: &LfGlobal, - lf_group_idx: u32, - frame_header: &FrameHeader, - ) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "Decode LfGroup", lf_group_idx); - let _guard = span.enter(); - let lf_group_params = LfGroupParams::new(frame_header, lf_global, lf_group_idx); - read_bits!(bitstream, Bundle(LfGroup), lf_group_params) - } - - fn read_hf_global( - &self, - bitstream: &mut Bitstream, - lf_global: &LfGlobal, - image_header: &ImageHeader, - frame_header: &FrameHeader, - ) -> Result> { - let has_hf_global = frame_header.encoding == crate::header::Encoding::VarDct; - let hf_global = if has_hf_global { - let span = tracing::span!(tracing::Level::TRACE, "Decode HfGlobal"); - let _guard = span.enter(); - let params = HfGlobalParams::new(&image_header.metadata, frame_header, lf_global); - Some(HfGlobal::parse(bitstream, params)?) - } else { - None - }; - Ok(hf_global) - } - - #[allow(clippy::too_many_arguments)] - fn read_group_pass( - &self, - bitstream: &mut Bitstream, - lf_global: &LfGlobal, - lf_group: &LfGroup, - hf_global: Option<&HfGlobal>, - pass_idx: u32, - group_idx: u32, - shift: Option<(i32, i32)>, - frame_header: &FrameHeader, - ) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "Decode PassGroup", pass_idx, group_idx); - let _guard = span.enter(); - let params = PassGroupParams::new( - frame_header, - lf_global, - lf_group, - hf_global, - pass_idx, - group_idx, - shift, - ); - read_bits!(bitstream, Bundle(PassGroup), params) - } - - fn read_merged_group( - &mut self, - bitstream: &mut Bitstream, - shift: Option<(i32, i32)>, - image_header: &ImageHeader, - frame_header: &FrameHeader, - ) -> Result<()> { - let lf_global = self.read_lf_global(bitstream, image_header, frame_header)?; - let lf_group = self.read_lf_group(bitstream, &lf_global, 0, frame_header)?; - let hf_global = self.read_hf_global(bitstream, &lf_global, image_header, frame_header)?; - let group_pass = self.read_group_pass(bitstream, &lf_global, &lf_group, hf_global.as_ref(), 0, 0, shift, frame_header)?; - - self.lf_global = Some(lf_global); - self.lf_group.insert(0, lf_group); - self.hf_global = Some(hf_global); - self.group_pass.insert((0, 0), group_pass); - - Ok(()) - } -} - -/// Creates a filter that loads only groups within the given cropping region of the frame. -/// -/// The region is expected in the frame coordinate. Use [`Frame::adjust_region`] to convert from -/// the region of the image. -pub fn crop_filter(adjusted_region: Option<(u32, u32, u32, u32)>) -> impl for<'a, 'b> FnMut(&'a FrameHeader, &'b FrameData, TocGroupKind) -> bool { - let mut region = adjusted_region; - let mut region_adjust_done = false; - - move |frame_header: &FrameHeader, frame_data: &FrameData, kind| { - if !region_adjust_done { - let Some(lf_global) = frame_data.lf_global.as_ref() else { - return true; - }; - if lf_global.gmodular.modular.has_delta_palette() { - if region.take().is_some() { - tracing::debug!("GlobalModular has delta palette, forcing full decode"); - } - } else if lf_global.gmodular.modular.has_squeeze() { - if let Some((left, top, width, height)) = &mut region { - *width += *left; - *height += *top; - *left = 0; - *top = 0; - tracing::debug!("GlobalModular has squeeze, decoding from top-left"); - } - } - if let Some(region) = ®ion { - tracing::debug!("Cropped decoding: {:?}", region); - } - region_adjust_done = true; - } - - let Some(region) = region else { return true; }; - - match kind { - TocGroupKind::LfGroup(lf_group_idx) => { - frame_header.is_lf_group_collides_region(lf_group_idx, region) - }, - TocGroupKind::GroupPass { group_idx, .. } => { - frame_header.is_group_collides_region(group_idx, region) - }, - _ => true, - } - } } diff --git a/crates/jxl-modular/src/image.rs b/crates/jxl-modular/src/image.rs index 37377bb0..f87c5f3f 100644 --- a/crates/jxl-modular/src/image.rs +++ b/crates/jxl-modular/src/image.rs @@ -14,7 +14,7 @@ use crate::{ /// /// A decoded Modular image consists of multiple channels. Those channels may not be in the same /// size. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Image { group_dim: u32, bit_depth: u32, diff --git a/crates/jxl-modular/src/lib.rs b/crates/jxl-modular/src/lib.rs index 4409e918..d2c0acce 100644 --- a/crates/jxl-modular/src/lib.rs +++ b/crates/jxl-modular/src/lib.rs @@ -26,12 +26,12 @@ pub use param::*; /// - creating a subimage of existing image by calling [self.make_subimage_params_lf_group] or /// [self.make_subimage_params_pass_group]. /// 2. Decode pixels by calling [self.decode_image] or [self.decode_image_gmodular]. -#[derive(Debug, Default)] +#[derive(Debug, Clone, Default)] pub struct Modular { inner: Option, } -#[derive(Debug)] +#[derive(Debug, Clone)] struct ModularData { group_dim: u32, header: ModularHeader, @@ -318,7 +318,7 @@ impl Bundle> for ModularData { } define_bundle! { - #[derive(Debug)] + #[derive(Debug, Clone)] struct ModularHeader error(crate::Error) { use_global_tree: ty(Bool), wp_params: ty(Bundle(predictor::WpHeader)), diff --git a/crates/jxl-modular/src/transform.rs b/crates/jxl-modular/src/transform.rs index 34e6a9ca..3c3e4aac 100644 --- a/crates/jxl-modular/src/transform.rs +++ b/crates/jxl-modular/src/transform.rs @@ -6,7 +6,7 @@ use jxl_grid::Grid; use crate::{Error, Result}; use super::{ModularChannelInfo, Image, predictor::{Predictor, PredictorState, WpHeader}}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum TransformInfo { Rct(Rct), Palette(Palette), @@ -68,19 +68,19 @@ impl Bundle<&WpHeader> for TransformInfo { } define_bundle! { - #[derive(Debug)] + #[derive(Debug, Clone)] pub struct Rct error(crate::Error) { begin_c: ty(U32(u(3), 8 + u(6), 72 + u(10), 1096 + u(13))), rct_type: ty(U32(6, u(2), 2 + u(4), 10 + u(6))), } - #[derive(Debug)] + #[derive(Debug, Clone)] pub struct Squeeze error(crate::Error) { num_sq: ty(U32(0, 1 + u(4), 9 + u(6), 41 + u(8))), sp: ty(Vec[Bundle(SqueezeParams)]; num_sq), } - #[derive(Debug)] + #[derive(Debug, Clone)] struct SqueezeParams error(crate::Error) { horizontal: ty(Bool), in_place: ty(Bool), @@ -89,7 +89,7 @@ define_bundle! { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Palette { begin_c: u32, num_c: u32, diff --git a/crates/jxl-oxide/src/lib.rs b/crates/jxl-oxide/src/lib.rs index c6119582..597c4812 100644 --- a/crates/jxl-oxide/src/lib.rs +++ b/crates/jxl-oxide/src/lib.rs @@ -57,6 +57,7 @@ use std::{ fs::File, io::Read, path::Path, + sync::Arc, }; mod fb; @@ -81,7 +82,7 @@ pub type Result = std::result::Result { bitstream: Bitstream>, - image_header: ImageHeader, + image_header: Arc, embedded_icc: Option>, } @@ -89,7 +90,7 @@ impl JxlImage { /// Creates a `JxlImage` from the reader. pub fn from_reader(reader: R) -> Result { let mut bitstream = Bitstream::new_detect(reader); - let image_header = ImageHeader::parse(&mut bitstream, ())?; + let image_header = Arc::new(ImageHeader::parse(&mut bitstream, ())?); let embedded_icc = image_header.metadata.colour_encoding.want_icc.then(|| { tracing::debug!("Image has an embedded ICC profile"); @@ -101,7 +102,7 @@ impl JxlImage { tracing::debug!("Skipping preview frame"); bitstream.zero_pad_to_byte()?; - let frame = Frame::parse(&mut bitstream, &image_header)?; + let frame = Frame::parse(&mut bitstream, image_header.clone())?; let toc = frame.toc(); let bookmark = toc.bookmark() + (toc.total_byte_size() * 8); bitstream.skip_to_bookmark(bookmark)?; @@ -139,10 +140,10 @@ impl JxlImage { /// Starts rendering the image. #[inline] pub fn renderer(&mut self) -> JxlRenderer<'_, R> { - let ctx = RenderContext::new(&self.image_header); + let ctx = RenderContext::new(self.image_header.clone()); JxlRenderer { bitstream: &mut self.bitstream, - image_header: &self.image_header, + image_header: self.image_header.clone(), embedded_icc: self.embedded_icc.as_deref(), ctx, render_spot_colour: !self.image_header.metadata.grayscale(), @@ -165,9 +166,9 @@ impl JxlImage { #[derive(Debug)] pub struct JxlRenderer<'img, R> { bitstream: &'img mut Bitstream>, - image_header: &'img ImageHeader, + image_header: Arc, embedded_icc: Option<&'img [u8]>, - ctx: RenderContext<'img>, + ctx: RenderContext, render_spot_colour: bool, crop_region: Option, end_of_image: bool, @@ -176,8 +177,8 @@ pub struct JxlRenderer<'img, R> { impl<'img, R: Read> JxlRenderer<'img, R> { /// Returns the image header. #[inline] - pub fn image_header(&self) -> &'img ImageHeader { - self.image_header + pub fn image_header(&self) -> &ImageHeader { + &self.image_header } /// Sets the cropping region of the image. @@ -194,6 +195,7 @@ impl<'img, R: Read> JxlRenderer<'img, R> { } #[inline] + #[allow(unused)] fn crop_region_flattened(&self) -> Option<(u32, u32, u32, u32)> { self.crop_region.map(|info| (info.left, info.top, info.width, info.height)) } @@ -264,17 +266,11 @@ impl<'img, R: Read> JxlRenderer<'img, R> { return Ok(LoadResult::NoMoreFrames); } - let region = self.crop_region_flattened(); - let result = self.ctx.load_until_keyframe(self.bitstream, false, region)?; - match result { - jxl_frame::ProgressiveResult::NeedMoreData => Ok(LoadResult::NeedMoreData), - jxl_frame::ProgressiveResult::FrameComplete => { - let keyframe_index = self.ctx.loaded_keyframes() - 1; - self.end_of_image = self.frame_header(keyframe_index).unwrap().is_last; - Ok(LoadResult::Done(keyframe_index)) - }, - _ => unreachable!(), - } + self.ctx.load_until_keyframe(self.bitstream)?; + + let keyframe_index = self.ctx.loaded_keyframes() - 1; + self.end_of_image = self.frame_header(keyframe_index).unwrap().is_last; + Ok(LoadResult::Done(keyframe_index)) } /// Returns the frame header for the given keyframe index, or `None` if the keyframe does not @@ -291,8 +287,7 @@ impl<'img, R: Read> JxlRenderer<'img, R> { /// Renders the given keyframe. pub fn render_frame(&mut self, keyframe_index: usize) -> Result { - let region = self.crop_region_flattened(); - let mut grids = self.ctx.render_keyframe_cropped(keyframe_index, region)?; + let mut grids = self.ctx.render_keyframe(keyframe_index)?; let color_channels = if self.image_header.metadata.grayscale() { 1 } else { 3 }; let mut color_channels: Vec<_> = grids.drain(..color_channels).collect(); diff --git a/crates/jxl-oxide/tests/decode b/crates/jxl-oxide/tests/decode index 8fcea5bd..c6080763 160000 --- a/crates/jxl-oxide/tests/decode +++ b/crates/jxl-oxide/tests/decode @@ -1 +1 @@ -Subproject commit 8fcea5bdfdbafba391bb28eee1d546eff13977e8 +Subproject commit c60807634656b7d11d9bb9dfe69d917d8d15ebcf diff --git a/crates/jxl-render/src/blend.rs b/crates/jxl-render/src/blend.rs index bf62a249..93d2fdce 100644 --- a/crates/jxl-render/src/blend.rs +++ b/crates/jxl-render/src/blend.rs @@ -141,7 +141,7 @@ fn source_and_alpha_from_blending_info(blending_info: &BlendingInfo) -> (usize, pub fn blend( image_header: &ImageHeader, reference_grids: [Option<&[SimpleGrid]>; 4], - new_frame: &Frame<'_>, + new_frame: &Frame, new_grid: &[SimpleGrid], ) -> Vec> { let header = new_frame.header(); diff --git a/crates/jxl-render/src/cut_grid.rs b/crates/jxl-render/src/cut_grid.rs index 514899a8..de3d2972 100644 --- a/crates/jxl-render/src/cut_grid.rs +++ b/crates/jxl-render/src/cut_grid.rs @@ -1,8 +1,5 @@ -use std::{ptr::NonNull, collections::HashMap}; - use jxl_grid::{CutGrid, Grid, SimpleGrid}; use jxl_modular::ChannelShift; -use jxl_vardct::HfCoeff; pub fn make_quant_cut_grid<'g>( buf: &'g mut SimpleGrid, @@ -37,84 +34,3 @@ pub fn make_quant_cut_grid<'g>( } grid } - -pub fn cut_with_block_info<'g>( - grid: &'g mut SimpleGrid, - group_coeffs: &HashMap, - group_dim: usize, - jpeg_upsampling: ChannelShift, -) -> HashMap>> { - let grid_width = grid.width(); - let grid_height = grid.height(); - let buf = grid.buf_mut(); - let ptr = NonNull::new(buf.as_mut_ptr()).unwrap(); - - let hshift = jpeg_upsampling.hshift(); - let vshift = jpeg_upsampling.vshift(); - let groups_per_row = (grid_width + group_dim - 1) / group_dim; - - group_coeffs - .iter() - .map(|(&idx, group)| { - let group_y = idx / groups_per_row; - let group_x = idx % groups_per_row; - let base_y = (group_y * group_dim) >> vshift; - let base_x = (group_x * group_dim) >> hshift; - let mut check_flags = vec![false; group_dim * group_dim]; - - let mut subgrids = HashMap::new(); - for coeff in group.data() { - let x = coeff.bx; - let y = coeff.by; - let sx = x >> hshift; - let sy = y >> vshift; - if (sx << hshift) != x || (sy << vshift) != y { - continue; - } - - let dct_select = coeff.dct_select; - let x8 = sx * 8; - let y8 = sy * 8; - let (bw, bh) = dct_select.dct_select_size(); - for dy in 0..bh as usize { - for dx in 0..bw as usize { - let idx = (sy + dy) * group_dim + (sx + dx); - if check_flags[idx] { - panic!("Invalid block_info"); - } - check_flags[idx] = true; - } - } - - let block_width = bw as usize * 8; - let block_height = bh as usize * 8; - let grid_x = base_x + x8; - let grid_y = base_y + y8; - if grid_x + block_width > grid_width || grid_y + block_height > grid_height { - panic!( - "Invalid group_coeffs? \ - grid_x={grid_x}, grid_y={grid_y}, \ - block_width={block_width}, block_height={block_height}, \ - grid_width={grid_width}, grid_height={grid_height}" - ); - } - - let offset = grid_y * grid_width + grid_x; - let stride = grid_width; - - // SAFETY: check_flags makes sure that the subgrids are disjoint. - let subgrid = unsafe { - CutGrid::new( - NonNull::new_unchecked(ptr.as_ptr().add(offset)), - block_width, - block_height, - stride, - ) - }; - subgrids.insert((x, y), subgrid); - } - - (idx, subgrids) - }) - .collect() -} diff --git a/crates/jxl-render/src/dct.rs b/crates/jxl-render/src/dct.rs index db661f8c..5a1177c0 100644 --- a/crates/jxl-render/src/dct.rs +++ b/crates/jxl-render/src/dct.rs @@ -1,6 +1,12 @@ mod consts; mod generic; +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum DctDirection { + Forward, + Inverse, +} + #[cfg( not(target_arch = "x86_64") )] diff --git a/crates/jxl-render/src/dct/generic.rs b/crates/jxl-render/src/dct/generic.rs index caa2a7ea..bcf3e027 100644 --- a/crates/jxl-render/src/dct/generic.rs +++ b/crates/jxl-render/src/dct/generic.rs @@ -1,74 +1,15 @@ -use jxl_grid::{CutGrid, SimpleGrid}; +use jxl_grid::CutGrid; -use super::consts; +use super::{consts, DctDirection}; -pub fn dct_2d(io: &mut SimpleGrid) { +pub fn dct_2d(io: &mut CutGrid<'_>, direction: DctDirection) { let width = io.width(); let height = io.height(); - let io_buf = io.buf_mut(); - dct_2d_generic(io_buf, width, height, false) -} - -pub fn dct_2d_generic(io_buf: &mut [f32], width: usize, height: usize, inverse: bool) { let mut buf = vec![0f32; width.max(height)]; let row = &mut buf[..width]; for y in 0..height { - dct(&mut io_buf[y * width..][..width], row, inverse); - } - - let block_size = width.min(height); - for by in (0..height).step_by(block_size) { - for bx in (0..width).step_by(block_size) { - for dy in 0..block_size { - for dx in (dy + 1)..block_size { - io_buf.swap((by + dy) * width + (bx + dx), (by + dx) * width + (bx + dy)); - } - } - } - } - - let scratch = &mut buf[..height]; - if block_size == height { - for row in io_buf.chunks_exact_mut(height) { - dct(row, scratch, inverse); - } - } else { - let mut row = vec![0f32; height]; - for y in 0..width { - for (idx, chunk) in row.chunks_exact_mut(width).enumerate() { - let y = y + idx * block_size; - chunk.copy_from_slice(&io_buf[y * width..][..width]); - } - dct(&mut row, scratch, inverse); - for (idx, chunk) in row.chunks_exact(width).enumerate() { - let y = y + idx * block_size; - io_buf[y * width..][..width].copy_from_slice(chunk); - } - } - } - - if width != height { - for by in (0..height).step_by(block_size) { - for bx in (0..width).step_by(block_size) { - for dy in 0..block_size { - for dx in (dy + 1)..block_size { - io_buf.swap((by + dy) * width + (bx + dx), (by + dx) * width + (bx + dy)); - } - } - } - } - } -} - -pub fn idct_2d(io: &mut CutGrid<'_>) { - let width = io.width(); - let height = io.height(); - let mut buf = vec![0f32; width.max(height)]; - - let row = &mut buf[..width]; - for y in 0..height { - dct(io.get_row_mut(y), row, true); + dct(io.get_row_mut(y), row, direction); } let block_size = width.min(height); @@ -87,7 +28,7 @@ pub fn idct_2d(io: &mut CutGrid<'_>) { for y in 0..height { let grouped_row = io.get_row_mut(y); for row in grouped_row.chunks_exact_mut(height) { - dct(row, scratch, true); + dct(row, scratch, direction); } } } else { @@ -97,7 +38,7 @@ pub fn idct_2d(io: &mut CutGrid<'_>) { let y = y + idx * block_size; chunk.copy_from_slice(io.get_row(y)); } - dct(&mut row, scratch, true); + dct(&mut row, scratch, direction); for (idx, chunk) in row.chunks_exact(width).enumerate() { let y = y + idx * block_size; io.get_row_mut(y).copy_from_slice(chunk); @@ -105,24 +46,22 @@ pub fn idct_2d(io: &mut CutGrid<'_>) { } } - if width != height { - for by in (0..height).step_by(block_size) { - for bx in (0..width).step_by(block_size) { - for dy in 0..block_size { - for dx in (dy + 1)..block_size { - io.swap((bx + dx, by + dy), (bx + dy, by + dx)); - } + for by in (0..height).step_by(block_size) { + for bx in (0..width).step_by(block_size) { + for dy in 0..block_size { + for dx in (dy + 1)..block_size { + io.swap((bx + dx, by + dy), (bx + dy, by + dx)); } } } } } -fn dct4(input: [f32; 4], inverse: bool) -> [f32; 4] { +fn dct4(input: [f32; 4], direction: DctDirection) -> [f32; 4] { let sec0 = 0.5411961; let sec1 = 1.306563; - if !inverse { + if direction == DctDirection::Forward { let sum03 = input[0] + input[3]; let sum12 = input[1] + input[2]; let tmp0 = (input[0] - input[3]) * sec0; @@ -153,7 +92,7 @@ fn dct4(input: [f32; 4], inverse: bool) -> [f32; 4] { } } -fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { +fn dct(input_output: &mut [f32], scratch: &mut [f32], direction: DctDirection) { let n = input_output.len(); assert!(scratch.len() == n); @@ -166,26 +105,26 @@ fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { if n == 2 { let tmp0 = input_output[0] + input_output[1]; let tmp1 = input_output[0] - input_output[1]; - if inverse { - input_output[0] = tmp0; - input_output[1] = tmp1; - } else { + if direction == DctDirection::Forward { input_output[0] = tmp0 / 2.0; input_output[1] = tmp1 / 2.0; + } else { + input_output[0] = tmp0; + input_output[1] = tmp1; } return; } if n == 4 { let io = input_output; - io.copy_from_slice(&dct4([io[0], io[1], io[2], io[3]], inverse)); + io.copy_from_slice(&dct4([io[0], io[1], io[2], io[3]], direction)); return; } if n == 8 { let io = input_output; let sec = consts::sec_half_small(8); - if !inverse { + if direction == DctDirection::Forward { let input0 = [ (io[0] + io[7]) / 2.0, (io[1] + io[6]) / 2.0, @@ -198,11 +137,11 @@ fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { (io[2] - io[5]) * sec[2] / 2.0, (io[3] - io[4]) * sec[3] / 2.0, ]; - let output0 = dct4(input0, false); + let output0 = dct4(input0, DctDirection::Forward); for (idx, v) in output0.into_iter().enumerate() { io[idx * 2] = v; } - let mut output1 = dct4(input1, false); + let mut output1 = dct4(input1, DctDirection::Forward); output1[0] *= std::f32::consts::SQRT_2; for idx in 0..3 { io[idx * 2 + 1] = output1[idx] + output1[idx + 1]; @@ -216,8 +155,8 @@ fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { io[5] + io[3], io[7] + io[5], ]; - let output0 = dct4(input0, true); - let output1 = dct4(input1, true); + let output0 = dct4(input0, DctDirection::Inverse); + let output1 = dct4(input1, DctDirection::Inverse); for (idx, &sec) in sec.iter().enumerate() { let r = output1[idx] * sec; io[idx] = output0[idx] + r; @@ -229,7 +168,7 @@ fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { assert!(n.is_power_of_two()); - if !inverse { + if direction == DctDirection::Forward { let (input0, input1) = scratch.split_at_mut(n / 2); for idx in 0..(n / 2) { input0[idx] = (input_output[idx] + input_output[n - idx - 1]) / 2.0; @@ -239,8 +178,8 @@ fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { for (v, &sec) in input1.iter_mut().zip(consts::sec_half(n)) { *v *= sec; } - dct(input0, output0, false); - dct(input1, output1, false); + dct(input0, output0, DctDirection::Forward); + dct(input1, output1, DctDirection::Forward); input1[0] *= std::f32::consts::SQRT_2; for idx in 0..(n / 2 - 1) { input1[idx] += input1[idx + 1]; @@ -262,8 +201,8 @@ fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { } input1[0] *= std::f32::consts::SQRT_2; let (output0, output1) = input_output.split_at_mut(n / 2); - dct(input0, output0, true); - dct(input1, output1, true); + dct(input0, output0, DctDirection::Inverse); + dct(input1, output1, DctDirection::Inverse); for (v, &sec) in input1.iter_mut().zip(consts::sec_half(n)) { *v *= sec; } @@ -276,12 +215,14 @@ fn dct(input_output: &mut [f32], scratch: &mut [f32], inverse: bool) { #[cfg(test)] mod tests { + use crate::dct::DctDirection; + #[test] fn forward_dct_2() { let original = [-1.0, 3.0]; let mut io = original; let mut scratch = [0.0f32; 2]; - super::dct(&mut io, &mut scratch, false); + super::dct(&mut io, &mut scratch, DctDirection::Forward); let s = original.len(); for (k, output) in io.iter().enumerate() { @@ -306,7 +247,7 @@ mod tests { let original = [-1.0, 2.0, 3.0, -4.0]; let mut io = original; let mut scratch = [0.0f32; 4]; - super::dct(&mut io, &mut scratch, false); + super::dct(&mut io, &mut scratch, DctDirection::Forward); let s = original.len(); for (k, output) in io.iter().enumerate() { @@ -331,7 +272,7 @@ mod tests { let original = [1.0, 0.3, 1.0, 2.0, -2.0, -0.1, 1.0, 0.1]; let mut io = original; let mut scratch = [0.0f32; 8]; - super::dct(&mut io, &mut scratch, false); + super::dct(&mut io, &mut scratch, DctDirection::Forward); let s = original.len(); for (k, output) in io.iter().enumerate() { @@ -356,7 +297,7 @@ mod tests { let original = [3.0, 0.2]; let mut io = original; let mut scratch = [0.0f32; 2]; - super::dct(&mut io, &mut scratch, true); + super::dct(&mut io, &mut scratch, DctDirection::Inverse); let s = original.len(); for (k, output) in io.iter().enumerate() { @@ -377,7 +318,7 @@ mod tests { let original = [3.0, 0.2, 0.3, -1.0]; let mut io = original; let mut scratch = [0.0f32; 4]; - super::dct(&mut io, &mut scratch, true); + super::dct(&mut io, &mut scratch, DctDirection::Inverse); let s = original.len(); for (k, output) in io.iter().enumerate() { @@ -398,7 +339,7 @@ mod tests { let original = [3.0, 0.0, 0.0, -1.0, 0.0, 0.3, 0.2, 0.0]; let mut io = original; let mut scratch = [0.0f32; 8]; - super::dct(&mut io, &mut scratch, true); + super::dct(&mut io, &mut scratch, DctDirection::Inverse); let s = original.len(); for (k, output) in io.iter().enumerate() { diff --git a/crates/jxl-render/src/dct/x86_64/mod.rs b/crates/jxl-render/src/dct/x86_64/mod.rs index b8ec6a55..220d4c50 100644 --- a/crates/jxl-render/src/dct/x86_64/mod.rs +++ b/crates/jxl-render/src/dct/x86_64/mod.rs @@ -1,6 +1,6 @@ -use jxl_grid::{CutGrid, SimdVector, SimpleGrid}; +use jxl_grid::{CutGrid, SimdVector}; -use super::consts; +use super::{consts, DctDirection}; use std::arch::x86_64::*; const LANE_SIZE: usize = 4; @@ -11,47 +11,31 @@ fn transpose_lane(lanes: &mut [Lane]) { unsafe { _MM_TRANSPOSE4_PS(row0, row1, row2, row3); } } -pub fn dct_2d(io: &mut SimpleGrid) { - let width = io.width(); - let height = io.height(); - if width % LANE_SIZE != 0 || height % LANE_SIZE != 0 { - return super::generic::dct_2d(io); +pub fn dct_2d(io: &mut CutGrid<'_>, direction: DctDirection) { + if io.width() % LANE_SIZE != 0 || io.height() % LANE_SIZE != 0 { + return super::generic::dct_2d(io, direction); } - let io_buf = io.buf_mut(); - dct_2d_generic(io_buf, width, height, false) -} - -pub fn dct_2d_generic(io_buf: &mut [f32], width: usize, height: usize, inverse: bool) { - let mut io = CutGrid::from_buf(io_buf, width, height, width); let Some(mut io) = io.as_vectored() else { tracing::trace!("Input buffer is not aligned"); - return super::generic::dct_2d_generic(io_buf, width, height, inverse); + return super::generic::dct_2d(io, direction); }; - dct_2d_lane(&mut io, inverse); + dct_2d_lane(&mut io, direction); } -pub fn idct_2d(io: &mut CutGrid<'_>) { - let Some(mut io) = io.as_vectored() else { - tracing::trace!("Input buffer is not aligned"); - return super::generic::idct_2d(io); - }; - dct_2d_lane(&mut io, true); -} - -fn dct_2d_lane(io: &mut CutGrid<'_, Lane>, inverse: bool) { +fn dct_2d_lane(io: &mut CutGrid<'_, Lane>, direction: DctDirection) { let scratch_size = io.height().max(io.width() * LANE_SIZE) * 2; unsafe { let mut scratch_lanes = vec![_mm_setzero_ps(); scratch_size]; - column_dct_lane(io, &mut scratch_lanes, inverse); - row_dct_lane(io, &mut scratch_lanes, inverse); + column_dct_lane(io, &mut scratch_lanes, direction); + row_dct_lane(io, &mut scratch_lanes, direction); } } fn column_dct_lane( io: &mut CutGrid<'_, Lane>, scratch: &mut [Lane], - inverse: bool, + direction: DctDirection, ) { let width = io.width(); let height = io.height(); @@ -60,7 +44,7 @@ fn column_dct_lane( for (y, input) in io_lanes.iter_mut().enumerate() { *input = io.get(x, y); } - dct(io_lanes, scratch_lanes, inverse); + dct(io_lanes, scratch_lanes, direction); for (y, output) in io_lanes.chunks_exact_mut(LANE_SIZE).enumerate() { transpose_lane(output); for (dy, output) in output.iter_mut().enumerate() { @@ -73,7 +57,7 @@ fn column_dct_lane( fn row_dct_lane( io: &mut CutGrid<'_, Lane>, scratch: &mut [Lane], - inverse: bool, + direction: DctDirection, ) { let width = io.width() * LANE_SIZE; let height = io.height(); @@ -84,36 +68,23 @@ fn row_dct_lane( *input = io.get(x, y + dy); } } - dct(io_lanes, scratch_lanes, inverse); + dct(io_lanes, scratch_lanes, direction); for (x, output) in io_lanes.chunks_exact_mut(LANE_SIZE).enumerate() { - if width != height { - transpose_lane(output); - } + transpose_lane(output); for (dy, output) in output.iter_mut().enumerate() { *io.get_mut(x, y + dy) = *output; } } } - - if width == height { - for y in 0..height / LANE_SIZE { - for x in (y + 1)..width / LANE_SIZE { - io.swap((x, y * LANE_SIZE), (y, x * LANE_SIZE)); - io.swap((x, y * LANE_SIZE + 1), (y, x * LANE_SIZE + 1)); - io.swap((x, y * LANE_SIZE + 2), (y, x * LANE_SIZE + 2)); - io.swap((x, y * LANE_SIZE + 3), (y, x * LANE_SIZE + 3)); - } - } - } } -fn dct4(input: [Lane; 4], inverse: bool) -> [Lane; 4] { +fn dct4(input: [Lane; 4], direction: DctDirection) -> [Lane; 4] { let sec0 = Lane::splat_f32(0.5411961); let sec1 = Lane::splat_f32(1.306563); let quarter = Lane::splat_f32(0.25); let sqrt2 = Lane::splat_f32(std::f32::consts::SQRT_2); - if !inverse { + if direction == DctDirection::Forward { let sum03 = input[0].add(input[3]); let sum12 = input[1].add(input[2]); let tmp0 = input[0].sub(input[3]).mul(sec0); @@ -144,7 +115,7 @@ fn dct4(input: [Lane; 4], inverse: bool) -> [Lane; 4] { } } -fn dct(io: &mut [Lane], scratch: &mut [Lane], inverse: bool) { +fn dct(io: &mut [Lane], scratch: &mut [Lane], direction: DctDirection) { let n = io.len(); assert!(scratch.len() == n); @@ -159,25 +130,25 @@ fn dct(io: &mut [Lane], scratch: &mut [Lane], inverse: bool) { if n == 2 { let tmp0 = io[0].add(io[1]); let tmp1 = io[0].sub(io[1]); - if inverse { - io[0] = tmp0; - io[1] = tmp1; - } else { + if direction == DctDirection::Forward { io[0] = tmp0.mul(half); io[1] = tmp1.mul(half); + } else { + io[0] = tmp0; + io[1] = tmp1; } return; } if n == 4 { - io.copy_from_slice(&dct4([io[0], io[1], io[2], io[3]], inverse)); + io.copy_from_slice(&dct4([io[0], io[1], io[2], io[3]], direction)); return; } let sqrt2 = Lane::splat_f32(std::f32::consts::SQRT_2); if n == 8 { let sec = consts::sec_half_small(8); - if !inverse { + if direction == DctDirection::Forward { let input0 = [ io[0].add(io[7]).mul(half), io[1].add(io[6]).mul(half), @@ -190,11 +161,11 @@ fn dct(io: &mut [Lane], scratch: &mut [Lane], inverse: bool) { io[2].sub(io[5]).mul(Lane::splat_f32(sec[2] / 2.0)), io[3].sub(io[4]).mul(Lane::splat_f32(sec[3] / 2.0)), ]; - let output0 = dct4(input0, false); + let output0 = dct4(input0, DctDirection::Forward); for (idx, v) in output0.into_iter().enumerate() { io[idx * 2] = v; } - let mut output1 = dct4(input1, false); + let mut output1 = dct4(input1, DctDirection::Forward); output1[0] = output1[0].mul(sqrt2); for idx in 0..3 { io[idx * 2 + 1] = output1[idx].add(output1[idx + 1]); @@ -208,8 +179,8 @@ fn dct(io: &mut [Lane], scratch: &mut [Lane], inverse: bool) { io[5].add(io[3]), io[7].add(io[5]), ]; - let output0 = dct4(input0, true); - let output1 = dct4(input1, true); + let output0 = dct4(input0, DctDirection::Inverse); + let output1 = dct4(input1, DctDirection::Inverse); for (idx, &sec) in sec.iter().enumerate() { let r = output1[idx].mul(Lane::splat_f32(sec)); io[idx] = output0[idx].add(r); @@ -221,15 +192,15 @@ fn dct(io: &mut [Lane], scratch: &mut [Lane], inverse: bool) { assert!(n.is_power_of_two()); - if !inverse { + if direction == DctDirection::Forward { let (input0, input1) = scratch.split_at_mut(n / 2); for (idx, &sec) in consts::sec_half(n).iter().enumerate() { input0[idx] = io[idx].add(io[n - idx - 1]).mul(half); input1[idx] = io[idx].sub(io[n - idx - 1]).mul(Lane::splat_f32(sec / 2.0)); } let (output0, output1) = io.split_at_mut(n / 2); - dct(input0, output0, false); - dct(input1, output1, false); + dct(input0, output0, DctDirection::Forward); + dct(input1, output1, DctDirection::Forward); for (idx, v) in input0.iter().enumerate() { io[idx * 2] = *v; } @@ -248,8 +219,8 @@ fn dct(io: &mut [Lane], scratch: &mut [Lane], inverse: bool) { input0[0] = io[0]; input1[0] = io[1].mul(sqrt2); let (output0, output1) = io.split_at_mut(n / 2); - dct(input0, output0, true); - dct(input1, output1, true); + dct(input0, output0, DctDirection::Inverse); + dct(input1, output1, DctDirection::Inverse); for (idx, &sec) in consts::sec_half(n).iter().enumerate() { let r = input1[idx].mul(Lane::splat_f32(sec)); output0[idx] = input0[idx].add(r); diff --git a/crates/jxl-render/src/inner.rs b/crates/jxl-render/src/inner.rs index 3f470d0f..e7bd1217 100644 --- a/crates/jxl-render/src/inner.rs +++ b/crates/jxl-render/src/inner.rs @@ -1,19 +1,19 @@ use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, io::Read, + sync::Arc, }; use jxl_bitstream::{Bitstream, Bundle}; use jxl_frame::{ + data::*, filter::Gabor, header::{Encoding, FrameType}, Frame, - ProgressiveResult, }; -use jxl_grid::SimpleGrid; +use jxl_grid::{SimpleGrid, CutGrid}; use jxl_image::{ImageHeader, ImageMetadata}; use jxl_modular::ChannelShift; -use jxl_vardct::HfCoeff; use crate::{ blend, @@ -27,20 +27,20 @@ use crate::{ }; #[derive(Debug)] -pub struct ContextInner<'a> { - image_header: &'a ImageHeader, - pub(crate) frames: Vec>, +pub struct ContextInner { + image_header: Arc, + pub(crate) frames: Vec, pub(crate) keyframes: Vec, pub(crate) keyframe_in_progress: Option, pub(crate) refcounts: Vec, pub(crate) frame_deps: Vec, pub(crate) lf_frame: [usize; 4], pub(crate) reference: [usize; 4], - pub(crate) loading_frame: Option>, + pub(crate) loading_frame: Option, } -impl<'a> ContextInner<'a> { - pub fn new(image_header: &'a ImageHeader) -> Self { +impl ContextInner { + pub fn new(image_header: Arc) -> Self { Self { image_header, frames: Vec::new(), @@ -55,7 +55,7 @@ impl<'a> ContextInner<'a> { } } -impl<'a> ContextInner<'a> { +impl ContextInner { #[inline] pub fn width(&self) -> u32 { self.image_header.size.width @@ -67,7 +67,7 @@ impl<'a> ContextInner<'a> { } #[inline] - pub fn metadata(&self) -> &'a ImageMetadata { + pub fn metadata(&self) -> &ImageMetadata { &self.image_header.metadata } @@ -81,7 +81,7 @@ impl<'a> ContextInner<'a> { self.keyframes.len() + (self.keyframe_in_progress.is_some() as usize) } - pub fn keyframe(&self, keyframe_idx: usize) -> Option<&IndexedFrame<'a>> { + pub fn keyframe(&self, keyframe_idx: usize) -> Option<&IndexedFrame> { if keyframe_idx == self.keyframes.len() { if let Some(idx) = self.keyframe_in_progress { Some(&self.frames[idx]) @@ -145,20 +145,18 @@ impl<'a> ContextInner<'a> { } } -impl ContextInner<'_> { - pub fn load_cropped_single( +impl ContextInner { + pub fn load_single( &mut self, bitstream: &mut Bitstream, - progressive: bool, - mut region: Option<(u32, u32, u32, u32)>, - ) -> Result<(ProgressiveResult, &IndexedFrame)> { - let image_header = self.image_header; + ) -> Result<&IndexedFrame> { + let image_header = &self.image_header; let frame = match &mut self.loading_frame { Some(frame) => frame, slot => { let mut bitstream = bitstream.rewindable(); - let frame = Frame::parse(&mut bitstream, image_header)?; + let frame = Frame::parse(&mut bitstream, image_header.clone())?; bitstream.commit(); *slot = Some(IndexedFrame::new(frame, self.frames.len())); slot.as_mut().unwrap() @@ -182,30 +180,15 @@ impl ContextInner<'_> { return Err(Error::UninitializedLfFrame(header.lf_level)); } - if let Some(region) = &mut region { - frame.adjust_region(region); - }; - let filter = if region.is_some() { - Box::new(jxl_frame::crop_filter(region)) as Box bool> - } else { - Box::new(|_: &_, _: &_, _| true) - }; - - let result = if header.frame_type == FrameType::RegularFrame { - frame.load_with_filter(bitstream, progressive, filter)? - } else { - frame.load_all(bitstream)?; - ProgressiveResult::FrameComplete - }; - - Ok((result, frame)) + frame.read_all(bitstream)?; + Ok(frame) } } -impl<'f> ContextInner<'f> { +impl ContextInner { pub fn render_frame<'a>( &'a self, - frame: &'a IndexedFrame<'f>, + frame: &'a IndexedFrame, reference_frames: ReferenceFrames<'a>, cache: &mut RenderCache, mut region: Option<(u32, u32, u32, u32)>, @@ -215,7 +198,7 @@ impl<'f> ContextInner<'f> { frame.adjust_region(region); } - let mut fb = match frame_header.encoding { + let (mut fb, gmodular) = match frame_header.encoding { Encoding::Modular => self.render_modular(frame, cache, region), Encoding::VarDct => self.render_vardct(frame, reference_frames.lf, cache, region), }?; @@ -227,13 +210,13 @@ impl<'f> ContextInner<'f> { if let Gabor::Enabled(weights) = frame_header.restoration_filter.gab { filter::apply_gabor_like([a, b, c], weights); } - filter::apply_epf([a, b, c], &frame.data().lf_group, frame_header); + filter::apply_epf([a, b, c], &cache.lf_groups, frame_header); let [a, b, c] = fb; let mut ret = vec![a, b, c]; - self.append_extra_channels(frame, &mut ret); + self.append_extra_channels(frame, &mut ret, gmodular); - self.render_features(frame, &mut ret, reference_frames.refs)?; + self.render_features(frame, &mut ret, reference_frames.refs, cache)?; if !frame_header.save_before_ct { if frame_header.do_ycbcr { @@ -267,21 +250,20 @@ impl<'f> ContextInner<'f> { } cropped } else { - blend::blend(self.image_header, reference_frames.refs, frame, &ret) + blend::blend(&self.image_header, reference_frames.refs, frame, &ret) }) } fn append_extra_channels<'a>( &'a self, - frame: &'a IndexedFrame<'f>, + frame: &'a IndexedFrame, fb: &mut Vec>, + gmodular: GlobalModular, ) { tracing::debug!("Attaching extra channels"); - let frame_data = frame.data(); - let lf_global = frame_data.lf_global.as_ref().unwrap(); - let extra_channel_from = lf_global.gmodular.extra_channel_from(); - let gmodular = &lf_global.gmodular.modular; + let extra_channel_from = gmodular.extra_channel_from(); + let gmodular = &gmodular.modular; let channel_data = &gmodular.image().channel_data()[extra_channel_from..]; @@ -319,13 +301,13 @@ impl<'f> ContextInner<'f> { fn render_features<'a>( &'a self, - frame: &'a IndexedFrame<'f>, + frame: &'a IndexedFrame, grid: &mut [SimpleGrid], reference_grids: [Option<&[SimpleGrid]>; 4], + cache: &mut RenderCache, ) -> Result<()> { - let frame_data = frame.data(); let frame_header = frame.header(); - let lf_global = frame_data.lf_global.as_ref().unwrap(); + let lf_global = cache.lf_global.as_ref().unwrap(); let base_correlations_xb = lf_global.vardct.as_ref().map(|x| { ( x.lf_chan_corr.base_correlation_x, @@ -334,7 +316,7 @@ impl<'f> ContextInner<'f> { }); for (idx, g) in grid.iter_mut().enumerate() { - features::upsample(g, self.image_header, frame_header, idx); + features::upsample(g, &self.image_header, frame_header, idx); } if let Some(patches) = &lf_global.patches { @@ -342,7 +324,7 @@ impl<'f> ContextInner<'f> { let Some(ref_grid) = reference_grids[patch.ref_idx as usize] else { return Err(Error::InvalidReference(patch.ref_idx)); }; - blend::patch(self.image_header, grid, ref_grid, patch); + blend::patch(&self.image_header, grid, ref_grid, patch); } } @@ -413,24 +395,29 @@ impl<'f> ContextInner<'f> { fn render_modular<'a>( &'a self, - frame: &'a IndexedFrame<'f>, - _cache: &mut RenderCache, + frame: &'a IndexedFrame, + cache: &mut RenderCache, _region: Option<(u32, u32, u32, u32)>, - ) -> Result<[SimpleGrid; 3]> { + ) -> Result<([SimpleGrid; 3], GlobalModular)> { let metadata = self.metadata(); let xyb_encoded = self.xyb_encoded(); let frame_header = frame.header(); - let frame_data = frame.data(); - let lf_global = frame_data.lf_global.as_ref().ok_or(Error::IncompleteFrame)?; - let gmodular = &lf_global.gmodular.modular; + + let lf_global = if let Some(x) = &cache.lf_global { + x + } else { + let lf_global = frame.try_parse_lf_global().ok_or(Error::IncompleteFrame)??; + cache.lf_global = Some(lf_global); + cache.lf_global.as_ref().unwrap() + }; + let mut gmodular = lf_global.gmodular.clone(); + let jpeg_upsampling = frame_header.jpeg_upsampling; let shifts_cbycr = [0, 1, 2].map(|idx| { ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx) }); let channels = metadata.encoded_color_channels(); - let channel_data = &gmodular.image().channel_data()[..channels]; - let width = frame_header.color_sample_width() as usize; let height = frame_header.color_sample_height() as usize; let bit_depth = metadata.bit_depth; @@ -440,6 +427,34 @@ impl<'f> ContextInner<'f> { SimpleGrid::new(width, height), ]; + let lf_groups = &mut cache.lf_groups; + load_lf_groups(frame, lf_global, lf_groups, _region, &mut gmodular)?; + + for pass_idx in 0..frame_header.passes.num_passes { + for group_idx in 0..frame_header.num_groups() { + let lf_group_idx = frame_header.lf_group_idx_from_group_idx(group_idx); + let Some(lf_group) = lf_groups.get(&lf_group_idx) else { continue; }; + let Some(mut bitstream) = frame.pass_group_bitstream(pass_idx, group_idx).transpose()? else { continue; }; + + let shift = frame.pass_shifts(pass_idx); + decode_pass_group( + &mut bitstream, + PassGroupParams { + frame_header, + lf_group, + pass_idx, + group_idx, + shift, + gmodular: &mut gmodular, + vardct: None, + }, + )?; + } + } + + gmodular.modular.inverse_transform(); + let channel_data = gmodular.modular.image().channel_data(); + for ((g, shift), buffer) in channel_data.iter().zip(shifts_cbycr).zip(fb_xyb.iter_mut()) { let buffer = buffer.buf_mut(); let (gw, gh) = g.group_dim(); @@ -487,57 +502,45 @@ impl<'f> ContextInner<'f> { } } - Ok(fb_xyb) + Ok((fb_xyb, gmodular)) } fn render_vardct<'a>( &'a self, - frame: &'a IndexedFrame<'f>, + frame: &'a IndexedFrame, lf_frame: Option<&'a [SimpleGrid]>, cache: &mut RenderCache, - region: Option<(u32, u32, u32, u32)>, - ) -> Result<[SimpleGrid; 3]> { + _region: Option<(u32, u32, u32, u32)>, + ) -> Result<([SimpleGrid; 3], GlobalModular)> { let span = tracing::span!(tracing::Level::TRACE, "RenderContext::render_vardct"); let _guard = span.enter(); - let metadata = self.metadata(); let frame_header = frame.header(); - let frame_data = frame.data(); - let lf_global = frame_data.lf_global.as_ref().ok_or(Error::IncompleteFrame)?; + + let lf_global = if let Some(x) = &cache.lf_global { + x + } else { + let lf_global = frame.try_parse_lf_global().ok_or(Error::IncompleteFrame)??; + cache.lf_global = Some(lf_global); + cache.lf_global.as_ref().unwrap() + }; + let mut gmodular = lf_global.gmodular.clone(); let lf_global_vardct = lf_global.vardct.as_ref().unwrap(); - let hf_global = frame_data.hf_global.as_ref().ok_or(Error::IncompleteFrame)?; - let hf_global = hf_global.as_ref().expect("HfGlobal not found for VarDCT frame"); + + let hf_global = if let Some(x) = &cache.hf_global { + x + } else { + let hf_global = frame.try_parse_hf_global(Some(lf_global)).ok_or(Error::IncompleteFrame)??; + cache.hf_global = Some(hf_global); + cache.hf_global.as_ref().unwrap() + }; + let jpeg_upsampling = frame_header.jpeg_upsampling; let shifts_cbycr: [_; 3] = std::array::from_fn(|idx| { ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx) }); let subsampled = jpeg_upsampling.into_iter().any(|x| x != 0); - // Modular extra channels are already merged into GlobalModular, - // so it's okay to drop PassGroup modular - for (&(pass_idx, group_idx), group_pass) in &frame_data.group_pass { - if let Some(region) = region { - if !frame_header.is_group_collides_region(group_idx, region) { - continue; - } - } - if !cache.coeff_merged.insert((pass_idx, group_idx)) { - continue; - } - - let hf_coeff = group_pass.hf_coeff.as_ref().unwrap(); - cache.group_coeffs - .entry(group_idx as usize) - .or_insert_with(HfCoeff::empty) - .merge(hf_coeff); - } - let group_coeffs = &cache.group_coeffs; - - let quantizer = &lf_global_vardct.quantizer; - let oim = &metadata.opsin_inverse_matrix; - let dequant_matrices = &hf_global.dequant_matrices; - let lf_chan_corr = &lf_global_vardct.lf_chan_corr; - let width = frame_header.color_sample_width() as usize; let height = frame_header.color_sample_height() as usize; let (width_rounded, height_rounded) = { @@ -559,171 +562,138 @@ impl<'f> ContextInner<'f> { SimpleGrid::new(width_rounded, height_rounded), ]; - let mut subgrids = { - let [x, y, b] = &mut fb_xyb; - let group_dim = frame_header.group_dim() as usize; - [ - cut_grid::cut_with_block_info(x, group_coeffs, group_dim, shifts_cbycr[0]), - cut_grid::cut_with_block_info(y, group_coeffs, group_dim, shifts_cbycr[1]), - cut_grid::cut_with_block_info(b, group_coeffs, group_dim, shifts_cbycr[2]), - ] - }; + let lf_groups = &mut cache.lf_groups; + load_lf_groups(frame, lf_global, lf_groups, _region, &mut gmodular)?; - let lf_group_it = frame_data.lf_group - .iter() - .filter(|(&lf_group_idx, _)| { - let Some(region) = region else { return true; }; - frame_header.is_lf_group_collides_region(lf_group_idx, region) - }); - let mut hf_meta_map = HashMap::new(); - let mut lf_image_changed = false; - for (&lf_group_idx, data) in lf_group_it { - let group_x = lf_group_idx % frame_header.lf_groups_per_row(); - let group_y = lf_group_idx / frame_header.lf_groups_per_row(); - - let lf_group_idx = lf_group_idx as usize; - hf_meta_map.insert(lf_group_idx, data.hf_meta.as_ref().unwrap()); - - if lf_frame.is_some() { - continue; - } - if !cache.inserted_lf_groups.insert(lf_group_idx) { - continue; + let mut lf_xyb_buf; + let lf_xyb; + if let Some(x) = lf_frame { + lf_xyb = x; + } else { + lf_xyb_buf = [ + SimpleGrid::new(width_rounded / 8, height_rounded / 8), + SimpleGrid::new(width_rounded / 8, height_rounded / 8), + SimpleGrid::new(width_rounded / 8, height_rounded / 8), + ]; + for idx in 0..frame_header.num_lf_groups() { + let Some(lf_group) = lf_groups.get(&idx) else { continue; }; + + let lf_group_x = idx % frame_header.lf_groups_per_row(); + let lf_group_y = idx / frame_header.lf_groups_per_row(); + let left = lf_group_x * frame_header.group_dim(); + let top = lf_group_y * frame_header.group_dim(); + + let lf_coeff = lf_group.lf_coeff.as_ref().unwrap(); + let channel_data = lf_coeff.lf_quant.image().channel_data(); + + let [lf_x, lf_y, lf_b] = &mut lf_xyb_buf; + let lf_x = cut_grid::make_quant_cut_grid(lf_x, left as usize, top as usize, shifts_cbycr[0], &channel_data[1]); + let lf_y = cut_grid::make_quant_cut_grid(lf_y, left as usize, top as usize, shifts_cbycr[1], &channel_data[0]); + let lf_b = cut_grid::make_quant_cut_grid(lf_b, left as usize, top as usize, shifts_cbycr[2], &channel_data[2]); + let mut lf = [lf_x, lf_y, lf_b]; + + vardct::dequant_lf( + &mut lf, + &lf_global.lf_dequant, + &lf_global_vardct.quantizer, + lf_coeff.extra_precision, + ); + if !subsampled { + vardct::chroma_from_luma_lf( + &mut lf, + &lf_global_vardct.lf_chan_corr, + ); + } } - let group_dim = frame_header.group_dim(); - let lf_coeff = data.lf_coeff.as_ref().unwrap(); - let quant_channel_data = lf_coeff.lf_quant.image().channel_data(); - let [lf_x, lf_y, lf_b] = &mut cache.dequantized_lf; - - let left = (group_x * group_dim) as usize; - let top = (group_y * group_dim) as usize; - let lf_x = cut_grid::make_quant_cut_grid(lf_x, left, top, shifts_cbycr[0], &quant_channel_data[1]); - let lf_y = cut_grid::make_quant_cut_grid(lf_y, left, top, shifts_cbycr[1], &quant_channel_data[0]); - let lf_b = cut_grid::make_quant_cut_grid(lf_b, left, top, shifts_cbycr[2], &quant_channel_data[2]); - let mut lf = [lf_x, lf_y, lf_b]; - - vardct::dequant_lf( - &mut lf, - &lf_global.lf_dequant, - quantizer, - lf_coeff.extra_precision, - ); - if !subsampled { - vardct::chroma_from_luma_lf( - &mut lf, - &lf_global_vardct.lf_chan_corr, + if !frame_header.flags.skip_adaptive_lf_smoothing() { + vardct::adaptive_lf_smoothing( + &mut lf_xyb_buf, + &lf_global.lf_dequant, + &lf_global_vardct.quantizer, ); } - lf_image_changed = true; + lf_xyb = &lf_xyb_buf; } - if lf_image_changed && lf_frame.is_none() && !frame_header.flags.skip_adaptive_lf_smoothing() { - let smoothed_lf = match &mut cache.smoothed_lf { - Some(smoothed_lf) => smoothed_lf, - x => { - let width = cache.dequantized_lf[0].width(); - let height = cache.dequantized_lf[0].height(); - *x = Some(std::array::from_fn(|_| SimpleGrid::new(width, height))); - x.as_mut().unwrap() - }, - }; - vardct::adaptive_lf_smoothing( - &cache.dequantized_lf, - smoothed_lf, - &lf_global.lf_dequant, - quantizer, - ); + let group_dim = frame_header.group_dim(); + for pass_idx in 0..frame_header.passes.num_passes { + for group_idx in 0..frame_header.num_groups() { + let lf_group_idx = frame_header.lf_group_idx_from_group_idx(group_idx); + let Some(lf_group) = lf_groups.get(&lf_group_idx) else { continue; }; + let Some(mut bitstream) = frame.pass_group_bitstream(pass_idx, group_idx).transpose()? else { continue; }; + + let group_x = group_idx % frame_header.groups_per_row(); + let group_y = group_idx / frame_header.groups_per_row(); + let left = group_x * group_dim; + let top = group_y * group_dim; + let group_width = group_dim.min(width_rounded as u32 - left); + let group_height = group_dim.min(height_rounded as u32 - top); + + let [fb_x, fb_y, fb_b] = &mut fb_xyb; + let mut grid_xyb = [(0usize, fb_x), (1, fb_y), (2, fb_b)].map(|(idx, fb)| { + let hshift = shifts_cbycr[idx].hshift(); + let vshift = shifts_cbycr[idx].vshift(); + let group_width = group_width >> hshift; + let group_height = group_height >> vshift; + let left = left >> hshift; + let top = top >> vshift; + let offset = top as usize * width_rounded + left as usize; + CutGrid::from_buf(&mut fb.buf_mut()[offset..], group_width as usize, group_height as usize, width_rounded) + }); + + let shift = frame.pass_shifts(pass_idx); + decode_pass_group( + &mut bitstream, + PassGroupParams { + frame_header, + lf_group, + pass_idx, + group_idx, + shift, + gmodular: &mut gmodular, + vardct: Some(PassGroupParamsVardct { + lf_vardct: lf_global_vardct, + hf_global, + hf_coeff_output: &mut grid_xyb, + }), + }, + )?; + } } - let dequantized_lf = if let Some(lf_frame) = lf_frame { - lf_frame - } else if let Some(smoothed_lf) = &cache.smoothed_lf { - smoothed_lf - } else { - &cache.dequantized_lf - }; - - let group_dim = frame_header.group_dim() as usize; - let groups_per_row = frame_header.groups_per_row() as usize; - - for (group_idx, hf_coeff) in group_coeffs { - let mut x = subgrids[0].remove(group_idx).unwrap(); - let mut y = subgrids[1].remove(group_idx).unwrap(); - let mut b = subgrids[2].remove(group_idx).unwrap(); - let lf_group_id = frame_header.lf_group_idx_from_group_idx(*group_idx as u32) as usize; - let hf_meta = hf_meta_map.get(&lf_group_id).unwrap(); - let x_from_y = &hf_meta.x_from_y; - let b_from_y = &hf_meta.b_from_y; - - let group_row = group_idx / groups_per_row; - let group_col = group_idx % groups_per_row; - - for coeff_data in hf_coeff.data() { - let bx = coeff_data.bx; - let by = coeff_data.by; - let coord = (bx, by); - let mut x = x.get_mut(&coord); - let mut y = y.get_mut(&coord); - let mut b = b.get_mut(&coord); - let dct_select = coeff_data.dct_select; - - if let Some(x) = &mut x { - vardct::dequant_hf_varblock(coeff_data, x, 0, oim, quantizer, dequant_matrices, Some(frame_header.x_qm_scale)); - } - if let Some(y) = &mut y { - vardct::dequant_hf_varblock(coeff_data, y, 1, oim, quantizer, dequant_matrices, None); - } - if let Some(b) = &mut b { - vardct::dequant_hf_varblock(coeff_data, b, 2, oim, quantizer, dequant_matrices, Some(frame_header.b_qm_scale)); - } - - let lf_left = (group_col * group_dim) / 8 + bx; - let lf_top = (group_row * group_dim) / 8 + by; - if !subsampled { - let lf_left = (lf_left % group_dim) * 8; - let lf_top = (lf_top % group_dim) * 8; - let mut xyb = [ - &mut **x.as_mut().unwrap(), - &mut **y.as_mut().unwrap(), - &mut **b.as_mut().unwrap(), - ]; - vardct::chroma_from_luma_hf(&mut xyb, lf_left, lf_top, x_from_y, b_from_y, lf_chan_corr); - } - - for ((coeff, lf_dequant), shift) in [x, y, b].into_iter().zip(dequantized_lf.iter()).zip(shifts_cbycr) { - let Some(coeff) = coeff else { continue; }; - - let s_lf_left = lf_left >> shift.hshift(); - let s_lf_top = lf_top >> shift.vshift(); - if s_lf_left << shift.hshift() != lf_left || s_lf_top << shift.vshift() != lf_top { - continue; - } - - let llf = vardct::llf_from_lf(lf_dequant, s_lf_left, s_lf_top, dct_select); - for y in 0..llf.height() { - for x in 0..llf.width() { - *coeff.get_mut(x, y) = *llf.get(x, y).unwrap(); - } - } - - vardct::transform(coeff, dct_select); - } - } + gmodular.modular.inverse_transform(); + vardct::dequant_hf_varblock( + &mut fb_xyb, + &self.image_header, + frame_header, + lf_global, + &*lf_groups, + hf_global, + ); + if !subsampled { + vardct::chroma_from_luma_hf( + &mut fb_xyb, + frame_header, + lf_global, + &*lf_groups, + ); } + vardct::transform_with_lf(lf_xyb, &mut fb_xyb, frame_header, &*lf_groups); - if width == width_rounded && height == width_rounded { - Ok(fb_xyb) + let fb = if width == width_rounded && height == width_rounded { + fb_xyb } else { - Ok(fb_xyb.map(|g| { + fb_xyb.map(|g| { let mut new_g = SimpleGrid::new(width, height); for (new_row, row) in new_g.buf_mut().chunks_exact_mut(width).zip(g.buf().chunks_exact(width_rounded)) { new_row.copy_from_slice(&row[..width]); } new_g - })) - } + }) + }; + Ok((fb, gmodular)) } } @@ -747,15 +717,13 @@ pub struct ReferenceFrames<'state> { #[derive(Debug)] pub struct RenderCache { - dequantized_lf: [SimpleGrid; 3], - smoothed_lf: Option<[SimpleGrid; 3]>, - inserted_lf_groups: HashSet, - group_coeffs: HashMap, - coeff_merged: HashSet<(u32, u32)>, + lf_global: Option, + hf_global: Option, + lf_groups: HashMap, } impl RenderCache { - pub fn new(frame: &IndexedFrame<'_>) -> Self { + pub fn new(frame: &IndexedFrame) -> Self { let frame_header = frame.header(); let jpeg_upsampling = frame_header.jpeg_upsampling; let shifts_cbycr: [_; 3] = std::array::from_fn(|idx| { @@ -770,13 +738,33 @@ impl RenderCache { *w = shift_w; *h = shift_h; } - let dequantized_lf = whd.map(|(w, h)| SimpleGrid::new(w as usize, h as usize)); Self { - dequantized_lf, - smoothed_lf: None, - inserted_lf_groups: HashSet::new(), - group_coeffs: HashMap::new(), - coeff_merged: HashSet::new(), + lf_global: None, + hf_global: None, + lf_groups: HashMap::new(), } } } + +fn load_lf_groups( + frame: &IndexedFrame, + lf_global: &LfGlobal, + lf_groups: &mut HashMap, + _region: Option<(u32, u32, u32, u32)>, + gmodular: &mut GlobalModular, +) -> Result<()> { + let frame_header = frame.header(); + for idx in 0..frame_header.num_lf_groups() { + let lf_group = lf_groups.entry(idx); + let lf_group = match lf_group { + std::collections::hash_map::Entry::Occupied(x) => x.into_mut(), + std::collections::hash_map::Entry::Vacant(x) => { + let Some(lf_group) = frame.try_parse_lf_group(Some(lf_global), idx).transpose()? else { continue; }; + &*x.insert(lf_group) + }, + }; + gmodular.modular.copy_from_modular(lf_group.mlf_group.clone()); + } + + Ok(()) +} diff --git a/crates/jxl-render/src/lib.rs b/crates/jxl-render/src/lib.rs index 568cd12a..03982245 100644 --- a/crates/jxl-render/src/lib.rs +++ b/crates/jxl-render/src/lib.rs @@ -1,8 +1,9 @@ //! This crate is the core of jxl-oxide that provides JPEG XL renderer. use std::io::Read; +use std::sync::Arc; use jxl_bitstream::Bitstream; -use jxl_frame::{Frame, ProgressiveResult}; +use jxl_frame::{Frame, data::TocGroupKind}; use jxl_grid::SimpleGrid; use jxl_image::{ImageHeader, ExtraChannelType}; @@ -20,14 +21,14 @@ use inner::*; /// Render context that tracks loaded and rendered frames. #[derive(Debug)] -pub struct RenderContext<'a> { - inner: ContextInner<'a>, +pub struct RenderContext { + inner: ContextInner, state: RenderState, } -impl<'a> RenderContext<'a> { +impl RenderContext { /// Creates a new render context. - pub fn new(image_header: &'a ImageHeader) -> Self { + pub fn new(image_header: Arc) -> Self { Self { inner: ContextInner::new(image_header), state: RenderState::new(), @@ -35,7 +36,7 @@ impl<'a> RenderContext<'a> { } } -impl RenderContext<'_> { +impl RenderContext { /// Returns the image width. #[inline] pub fn width(&self) -> u32 { @@ -55,26 +56,14 @@ impl RenderContext<'_> { } } -impl RenderContext<'_> { - /// Load all frames in the bitstream, with the given cropping region. - /// - /// `region` is expected to be in the order `(left, top, width, height)`. - pub fn load_all_frames_cropped( +impl RenderContext { + /// Load all frames in the bitstream. + pub fn load_all_frames( &mut self, bitstream: &mut Bitstream, - progressive: bool, - region: Option<(u32, u32, u32, u32)>, - ) -> Result { + ) -> Result<()> { loop { - let result = self.inner.load_cropped_single(bitstream, progressive, region); - let (result, frame) = match result { - Ok(val) => val, - Err(Error::Frame(e)) if e.unexpected_eof() => return Ok(ProgressiveResult::NeedMoreData), - Err(e) => return Err(e), - }; - if result != ProgressiveResult::FrameComplete { - return Ok(result); - } + let frame = self.inner.load_single(bitstream)?; let is_last = frame.header().is_last; let toc = frame.toc(); @@ -88,28 +77,16 @@ impl RenderContext<'_> { bitstream.skip_to_bookmark(bookmark)?; } - Ok(ProgressiveResult::FrameComplete) + Ok(()) } /// Load a single keyframe from the bitstream. - /// - /// `region` is expected to be in the order `(left, top, width, height)`. pub fn load_until_keyframe( &mut self, bitstream: &mut Bitstream, - progressive: bool, - region: Option<(u32, u32, u32, u32)>, - ) -> Result { + ) -> Result<()> { loop { - let result = self.inner.load_cropped_single(bitstream, progressive, region); - let (result, frame) = match result { - Ok(val) => val, - Err(Error::Frame(e)) if e.unexpected_eof() => return Ok(ProgressiveResult::NeedMoreData), - Err(e) => return Err(e), - }; - if result != ProgressiveResult::FrameComplete { - return Ok(result); - } + let frame = self.inner.load_single(bitstream)?; let is_keyframe = frame.header().is_keyframe(); let toc = frame.toc(); @@ -123,29 +100,20 @@ impl RenderContext<'_> { bitstream.skip_to_bookmark(bookmark)?; } - Ok(ProgressiveResult::FrameComplete) - } - - /// Load all frames in the bitstream. - pub fn load_all_frames( - &mut self, - bitstream: &mut Bitstream, - progressive: bool, - ) -> Result { - self.load_all_frames_cropped(bitstream, progressive, None) + Ok(()) } } -impl<'a> RenderContext<'a> { +impl RenderContext { /// Returns the frame with the keyframe index, or `None` if the keyframe does not exist. #[inline] - pub fn keyframe(&self, keyframe_idx: usize) -> Option<&IndexedFrame<'a>> { + pub fn keyframe(&self, keyframe_idx: usize) -> Option<&IndexedFrame> { self.inner.keyframe(keyframe_idx) } } -impl RenderContext<'_> { - fn render_by_index(&mut self, index: usize, region: Option<(u32, u32, u32, u32)>) -> Result<()> { +impl RenderContext { + fn render_by_index(&mut self, index: usize) -> Result<()> { let span = tracing::span!(tracing::Level::TRACE, "RenderContext::render_by_index", index); let _guard = span.enter(); @@ -155,10 +123,10 @@ impl RenderContext<'_> { let deps = self.inner.frame_deps[index]; for dep in deps.indices() { - self.render_by_index(dep, None)?; + self.render_by_index(dep)?; } - tracing::debug!(index, region = format_args!("{:?}", region), "Rendering frame"); + tracing::debug!(index, "Rendering frame"); let frame = &self.inner.frames[index]; let (prev, state) = self.state.renders.split_at_mut(index); let state = &mut state[0]; @@ -177,7 +145,7 @@ impl RenderContext<'_> { }, }; - let grid = self.inner.render_frame(frame, reference_frames, cache, region)?; + let grid = self.inner.render_frame(frame, reference_frames, cache, None)?; *state = FrameRender::Done(grid); let mut unref = |idx: usize| { @@ -201,27 +169,25 @@ impl RenderContext<'_> { Ok(()) } - /// Renders the first keyframe with the given cropping region. + /// Renders the first keyframe. /// /// The keyframe should be loaded in prior to rendering, with one of the loading methods. #[inline] - pub fn render_cropped( + pub fn render( &mut self, - region: Option<(u32, u32, u32, u32)>, ) -> Result>> { - self.render_keyframe_cropped(0, region) + self.render_keyframe(0) } - /// Renders the keyframe with the given cropping region. + /// Renders the keyframe. /// /// The keyframe should be loaded in prior to rendering, with one of the loading methods. - pub fn render_keyframe_cropped( + pub fn render_keyframe( &mut self, keyframe_idx: usize, - region: Option<(u32, u32, u32, u32)>, ) -> Result>> { - let (frame, grid) = if let Some(&idx) = self.inner.keyframes.get(keyframe_idx) { - self.render_by_index(idx, region)?; + let (frame, mut grid) = if let Some(&idx) = self.inner.keyframes.get(keyframe_idx) { + self.render_by_index(idx)?; let FrameRender::Done(grid) = &self.state.renders[idx] else { panic!(); }; let frame = &self.inner.frames[idx]; (frame, grid.clone()) @@ -229,7 +195,7 @@ impl RenderContext<'_> { let mut current_frame_grid = None; if let Some(frame) = &self.inner.loading_frame { if frame.header().frame_type.is_normal_frame() { - let ret = self.render_loading_frame(region); + let ret = self.render_loading_frame(); match ret { Ok(grid) => current_frame_grid = Some(grid), Err(Error::IncompleteFrame) => {}, @@ -242,7 +208,7 @@ impl RenderContext<'_> { let frame = self.inner.loading_frame.as_ref().unwrap(); (frame, grid) } else if let Some(idx) = self.inner.keyframe_in_progress { - self.render_by_index(idx, region)?; + self.render_by_index(idx)?; let FrameRender::Done(grid) = &self.state.renders[idx] else { panic!(); }; let frame = &self.inner.frames[idx]; (frame, grid.clone()) @@ -253,45 +219,29 @@ impl RenderContext<'_> { let frame_header = frame.header(); - let mut cropped = if let Some((l, t, w, h)) = region { - let mut cropped = Vec::with_capacity(grid.len()); - for g in grid { - let mut new_grid = SimpleGrid::new(w as usize, h as usize); - for (idx, v) in new_grid.buf_mut().iter_mut().enumerate() { - let y = idx / w as usize; - let x = idx % w as usize; - *v = *g.get(x + l as usize, y + t as usize).unwrap(); - } - cropped.push(new_grid); - } - cropped - } else { - grid - }; - if frame_header.save_before_ct { if frame_header.do_ycbcr { - let [cb, y, cr, ..] = &mut *cropped else { panic!() }; + let [cb, y, cr, ..] = &mut *grid else { panic!() }; jxl_color::ycbcr_to_rgb([cb, y, cr]); } - self.inner.convert_color(&mut cropped); + self.inner.convert_color(&mut grid); } let channels = if self.inner.metadata().grayscale() { 1 } else { 3 }; - cropped.drain(channels..3); - Ok(cropped) + grid.drain(channels..3); + Ok(grid) } - fn render_loading_frame(&mut self, region: Option<(u32, u32, u32, u32)>) -> Result>> { + fn render_loading_frame(&mut self) -> Result>> { let frame = self.inner.loading_frame.as_ref().unwrap(); let header = frame.header(); - if frame.data().lf_global.is_none() { + if frame.data(TocGroupKind::LfGlobal).is_none() { return Err(Error::IncompleteFrame); } let lf_frame = if header.flags.use_lf_frame() { let lf_frame_idx = self.inner.lf_frame[header.lf_level as usize]; - self.render_by_index(lf_frame_idx, None)?; + self.render_by_index(lf_frame_idx)?; Some(self.state.renders[lf_frame_idx].as_grid().unwrap()) } else { None @@ -308,7 +258,7 @@ impl RenderContext<'_> { refs: [None; 4], }; - self.inner.render_frame(frame, reference_frames, cache, region) + self.inner.render_frame(frame, reference_frames, cache, None) } } @@ -356,13 +306,13 @@ impl FrameRender { /// Frame with its index in the image. #[derive(Debug)] -pub struct IndexedFrame<'a> { - f: Frame<'a>, +pub struct IndexedFrame { + f: Frame, idx: usize, } -impl<'a> IndexedFrame<'a> { - fn new(frame: Frame<'a>, index: usize) -> Self { +impl IndexedFrame { + fn new(frame: Frame, index: usize) -> Self { IndexedFrame { f: frame, idx: index } } @@ -372,15 +322,15 @@ impl<'a> IndexedFrame<'a> { } } -impl<'a> std::ops::Deref for IndexedFrame<'a> { - type Target = Frame<'a>; +impl std::ops::Deref for IndexedFrame { + type Target = Frame; fn deref(&self) -> &Self::Target { &self.f } } -impl<'a> std::ops::DerefMut for IndexedFrame<'a> { +impl std::ops::DerefMut for IndexedFrame { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.f } diff --git a/crates/jxl-render/src/vardct/generic.rs b/crates/jxl-render/src/vardct/generic.rs index 9d78750a..62188d1e 100644 --- a/crates/jxl-render/src/vardct/generic.rs +++ b/crates/jxl-render/src/vardct/generic.rs @@ -2,8 +2,7 @@ pub fn adaptive_lf_smoothing_impl( width: usize, height: usize, - [in_x, in_y, in_b]: [&[f32]; 3], - [out_x, out_y, out_b]: [&mut [f32]; 3], + [in_x, in_y, in_b]: [&mut [f32]; 3], [lf_x, lf_y, lf_b]: [f32; 3], ) { const SCALE_SELF: f32 = 0.052262735; @@ -17,16 +16,13 @@ pub fn adaptive_lf_smoothing_impl( assert_eq!(in_x.len(), in_y.len()); assert_eq!(in_y.len(), in_b.len()); - assert_eq!(in_x.len(), out_x.len()); - assert_eq!(in_y.len(), out_y.len()); - assert_eq!(in_b.len(), out_b.len()); assert_eq!(in_x.len(), width * height); let mut udsum_x = vec![0.0f32; width * (height - 2)]; let mut udsum_y = vec![0.0f32; width * (height - 2)]; let mut udsum_b = vec![0.0f32; width * (height - 2)]; - for (g, out) in [(in_x, &mut udsum_x), (in_y, &mut udsum_y), (in_b, &mut udsum_b)] { + for (g, out) in [(&mut *in_x, &mut udsum_x), (&mut *in_y, &mut udsum_y), (&mut *in_b, &mut udsum_b)] { let up = g.chunks_exact(width); let down = g[width * 2..].chunks_exact(width); let out = out.chunks_exact_mut(width); @@ -37,16 +33,9 @@ pub fn adaptive_lf_smoothing_impl( } } - let mut in_x_row = in_x.chunks_exact(width); - let mut in_y_row = in_y.chunks_exact(width); - let mut in_b_row = in_b.chunks_exact(width); - let mut out_x_row = out_x.chunks_exact_mut(width); - let mut out_y_row = out_y.chunks_exact_mut(width); - let mut out_b_row = out_b.chunks_exact_mut(width); - - out_x_row.next().unwrap().copy_from_slice(in_x_row.next().unwrap()); - out_y_row.next().unwrap().copy_from_slice(in_y_row.next().unwrap()); - out_b_row.next().unwrap().copy_from_slice(in_b_row.next().unwrap()); + let mut in_x_row = in_x.chunks_exact_mut(width).skip(1); + let mut in_y_row = in_y.chunks_exact_mut(width).skip(1); + let mut in_b_row = in_b.chunks_exact_mut(width).skip(1); let mut udsum_x_row = udsum_x.chunks_exact(width); let mut udsum_y_row = udsum_y.chunks_exact(width); @@ -59,29 +48,25 @@ pub fn adaptive_lf_smoothing_impl( let in_x = in_x_row.next().unwrap(); let in_y = in_y_row.next().unwrap(); let in_b = in_b_row.next().unwrap(); - let out_x = out_x_row.next().unwrap(); - let out_y = out_y_row.next().unwrap(); - let out_b = out_b_row.next().unwrap(); - - out_x[0] = in_x[0]; - out_y[0] = in_y[0]; - out_b[0] = in_b[0]; + let mut in_x_prev = in_x[0]; + let mut in_y_prev = in_y[0]; + let mut in_b_prev = in_b[0]; for x in 1..(width - 1) { let x_self = in_x[x]; - let x_side = in_x[x - 1] + in_x[x + 1] + udsum_x[x]; + let x_side = in_x_prev + in_x[x + 1] + udsum_x[x]; let x_diag = udsum_x[x - 1] + udsum_x[x + 1]; let x_wa = x_self * SCALE_SELF + x_side * SCALE_SIDE + x_diag * SCALE_DIAG; let x_gap_t = (x_wa - x_self).abs() / lf_x; let y_self = in_y[x]; - let y_side = in_y[x - 1] + in_y[x + 1] + udsum_y[x]; + let y_side = in_y_prev + in_y[x + 1] + udsum_y[x]; let y_diag = udsum_y[x - 1] + udsum_y[x + 1]; let y_wa = y_self * SCALE_SELF + y_side * SCALE_SIDE + y_diag * SCALE_DIAG; let y_gap_t = (y_wa - y_self).abs() / lf_y; let b_self = in_b[x]; - let b_side = in_b[x - 1] + in_b[x + 1] + udsum_b[x]; + let b_side = in_b_prev + in_b[x + 1] + udsum_b[x]; let b_diag = udsum_b[x - 1] + udsum_b[x + 1]; let b_wa = b_self * SCALE_SELF + b_side * SCALE_SIDE + b_diag * SCALE_DIAG; let b_gap_t = (b_wa - b_self).abs() / lf_b; @@ -89,17 +74,12 @@ pub fn adaptive_lf_smoothing_impl( let gap = 0.5f32.max(x_gap_t).max(y_gap_t).max(b_gap_t); let gap_scale = (3.0 - 4.0 * gap).max(0.0); - out_x[x] = (x_wa - x_self) * gap_scale + x_self; - out_y[x] = (y_wa - y_self) * gap_scale + y_self; - out_b[x] = (b_wa - b_self) * gap_scale + b_self; + in_x[x] = (x_wa - x_self) * gap_scale + x_self; + in_y[x] = (y_wa - y_self) * gap_scale + y_self; + in_b[x] = (b_wa - b_self) * gap_scale + b_self; + in_x_prev = x_self; + in_y_prev = y_self; + in_b_prev = b_self; } - - out_x[width - 1] = in_x[width - 1]; - out_y[width - 1] = in_y[width - 1]; - out_b[width - 1] = in_b[width - 1]; } - - out_x_row.next().unwrap().copy_from_slice(in_x_row.next().unwrap()); - out_y_row.next().unwrap().copy_from_slice(in_y_row.next().unwrap()); - out_b_row.next().unwrap().copy_from_slice(in_b_row.next().unwrap()); } diff --git a/crates/jxl-render/src/vardct/mod.rs b/crates/jxl-render/src/vardct/mod.rs index 54e12d0d..da44aa07 100644 --- a/crates/jxl-render/src/vardct/mod.rs +++ b/crates/jxl-render/src/vardct/mod.rs @@ -1,15 +1,18 @@ -use jxl_color::OpsinInverseMatrix; +use std::collections::HashMap; + +use jxl_frame::{data::{LfGroup, LfGlobal, HfGlobal}, FrameHeader}; use jxl_grid::{CutGrid, SimpleGrid}; +use jxl_image::ImageHeader; +use jxl_modular::ChannelShift; use jxl_vardct::{ - CoeffData, - DequantMatrixSet, LfChannelCorrelation, LfChannelDequantization, Quantizer, TransformType, + BlockInfo, }; -use crate::dct::dct_2d; +use crate::dct; mod transform; pub use transform::transform; @@ -49,8 +52,7 @@ pub fn dequant_lf( } pub fn adaptive_lf_smoothing( - lf_image: &[SimpleGrid; 3], - out: &mut [SimpleGrid; 3], + lf_image: &mut [SimpleGrid; 3], lf_dequant: &LfChannelDequantization, quantizer: &Quantizer, ) { @@ -60,78 +62,114 @@ pub fn adaptive_lf_smoothing( let lf_b = 512.0 * lf_dequant.m_b_lf / scale_inv as f32; let [in_x, in_y, in_b] = lf_image; - let [out_x, out_y, out_b] = out; - let width = out_x.width(); - let height = out_x.height(); + let width = in_x.width(); + let height = in_x.height(); - let in_x = in_x.buf(); - let in_y = in_y.buf(); - let in_b = in_b.buf(); - let out_x = out_x.buf_mut(); - let out_y = out_y.buf_mut(); - let out_b = out_b.buf_mut(); + let in_x = in_x.buf_mut(); + let in_y = in_y.buf_mut(); + let in_b = in_b.buf_mut(); impls::adaptive_lf_smoothing_impl( width, height, [in_x, in_y, in_b], - [out_x, out_y, out_b], [lf_x, lf_y, lf_b], ); } pub fn dequant_hf_varblock( - coeff_data: &CoeffData, - out: &mut CutGrid<'_>, - channel: usize, - oim: &OpsinInverseMatrix, - quantizer: &Quantizer, - dequant_matrices: &DequantMatrixSet, - qm_scale: Option, + out: &mut [SimpleGrid; 3], + image_header: &ImageHeader, + frame_header: &FrameHeader, + lf_global: &LfGlobal, + lf_groups: &HashMap, + hf_global: &HfGlobal, ) { - let CoeffData { dct_select, hf_mul, ref coeff, .. } = *coeff_data; - let need_transpose = dct_select.need_transpose(); + let shifts_cbycr: [_; 3] = std::array::from_fn(|idx| { + ChannelShift::from_jpeg_upsampling(frame_header.jpeg_upsampling, idx) + }); + let oim = &image_header.metadata.opsin_inverse_matrix; + let quantizer = &lf_global.vardct.as_ref().unwrap().quantizer; + let dequant_matrices = &hf_global.dequant_matrices; + + let qm_scale = [ + 0.8f32.powi(frame_header.x_qm_scale as i32 - 2), + 1.0f32, + 0.8f32.powi(frame_header.b_qm_scale as i32 - 2), + ]; - let mut mul = 65536.0 / (quantizer.global_scale as i32 * hf_mul) as f32; - if let Some(qm_scale) = qm_scale { - let scale = 0.8f32.powi(qm_scale as i32 - 2); - mul *= scale; - } - let quant_bias = oim.quant_bias[channel]; let quant_bias_numerator = oim.quant_bias_numerator; - let coeff = &coeff[channel]; - let mut buf = vec![0f32; coeff.width() * coeff.height()]; - - for (&quant, out) in coeff.buf().iter().zip(&mut buf) { - *out = match quant { - -1 => -quant_bias, - 0 => 0.0, - 1 => quant_bias, - quant => { - let q = quant as f32; - q - (quant_bias_numerator / q) - }, - }; - } + for lf_group_idx in 0..frame_header.num_lf_groups() { + let Some(lf_group) = lf_groups.get(&lf_group_idx) else { continue; }; + let hf_meta = lf_group.hf_meta.as_ref().unwrap(); - let matrix = dequant_matrices.get(channel, dct_select); - for (out, &mat) in buf.iter_mut().zip(matrix) { - let val = *out * mat; - *out = val * mul; - } + let lf_left = (lf_group_idx % frame_header.lf_groups_per_row()) * frame_header.lf_group_dim(); + let lf_top = (lf_group_idx / frame_header.lf_groups_per_row()) * frame_header.lf_group_dim(); + + let block_info = &hf_meta.block_info; + let w8 = block_info.width(); + let h8 = block_info.height(); - if need_transpose { - for y in 0..coeff.height() { - for x in 0..coeff.width() { - *out.get_mut(y, x) = buf[y * coeff.width() + x]; + for (channel, coeff) in out.iter_mut().enumerate() { + let shift = shifts_cbycr[channel]; + let vshift = shift.vshift(); + let hshift = shift.hshift(); + + let quant_bias = oim.quant_bias[channel]; + let stride = coeff.width(); + for by in 0..h8 { + for bx in 0..w8 { + let &BlockInfo::Data { dct_select, hf_mul } = block_info.get(bx, by).unwrap() else { continue; }; + if ((bx >> hshift) << hshift) != bx || ((by >> vshift) << vshift) != by { + continue; + } + + let (bw, bh) = dct_select.dct_select_size(); + let width = bw * 8; + let height = bh * 8; + let need_transpose = dct_select.need_transpose(); + let mul = 65536.0 / (quantizer.global_scale as i32 * hf_mul) as f32 * qm_scale[channel]; + + let mut new_matrix; + let mut matrix = dequant_matrices.get(channel, dct_select); + if need_transpose { + new_matrix = vec![0f32; matrix.len()]; + for (idx, val) in new_matrix.iter_mut().enumerate() { + let mat_x = idx % width as usize; + let mat_y = idx / width as usize; + *val = matrix[mat_x * height as usize + mat_y]; + } + matrix = &new_matrix; + } + + let left = lf_left as usize + bx * 8; + let top = lf_top as usize + by * 8; + let left = left >> hshift; + let top = top >> vshift; + + let mut coeff = CutGrid::from_buf( + &mut coeff.buf_mut()[top * stride + left..], + width as usize, + height as usize, + stride, + ); + for y in 0..height { + let row = coeff.get_row_mut(y as usize); + let matrix_row = &matrix[(y * width) as usize..][..width as usize]; + for (q, &m) in row.iter_mut().zip(matrix_row) { + if q.abs() <= 1.0f32 { + *q *= quant_bias; + } else { + *q -= quant_bias_numerator / *q; + } + *q *= m; + *q *= mul; + } + } + } } } - } else { - for y in 0..coeff.height() { - let row = out.get_row_mut(y); - row.copy_from_slice(&buf[y * coeff.width()..][..coeff.width()]); - } } } @@ -168,52 +206,60 @@ pub fn chroma_from_luma_lf( } pub fn chroma_from_luma_hf( - coeff_xyb: &mut [&mut CutGrid<'_>; 3], - lf_left: usize, - lf_top: usize, - x_from_y: &SimpleGrid, - b_from_y: &SimpleGrid, - lf_chan_corr: &LfChannelCorrelation, + coeff_xyb: &mut [SimpleGrid; 3], + frame_header: &FrameHeader, + lf_global: &LfGlobal, + lf_groups: &HashMap, ) { let LfChannelCorrelation { colour_factor, base_correlation_x, base_correlation_b, .. - } = *lf_chan_corr; + } = lf_global.vardct.as_ref().unwrap().lf_chan_corr; let [coeff_x, coeff_y, coeff_b] = coeff_xyb; let width = coeff_x.width(); let height = coeff_x.height(); + let lf_group_dim = frame_header.lf_group_dim() as usize; - for cy in 0..height { - for cx in 0..width { - let (x, y) = if width == height { - (lf_left + cy, lf_top + cx) - } else { - (lf_left + cx, lf_top + cy) - }; - let cfactor_x = x / 64; - let cfactor_y = y / 64; - - let x_factor = *x_from_y.get(cfactor_x, cfactor_y).unwrap(); - let b_factor = *b_from_y.get(cfactor_x, cfactor_y).unwrap(); - let kx = base_correlation_x + (x_factor as f32 / colour_factor as f32); - let kb = base_correlation_b + (b_factor as f32 / colour_factor as f32); - - let coeff_y = coeff_y.get(cx, cy); - *coeff_x.get_mut(cx, cy) += kx * coeff_y; - *coeff_b.get_mut(cx, cy) += kb * coeff_y; + for lf_group_idx in 0..frame_header.num_lf_groups() { + let Some(lf_group) = lf_groups.get(&lf_group_idx) else { continue; }; + let hf_meta = lf_group.hf_meta.as_ref().unwrap(); + let x_from_y = &hf_meta.x_from_y; + let b_from_y = &hf_meta.b_from_y; + + let lf_left = ((lf_group_idx % frame_header.lf_groups_per_row()) * frame_header.lf_group_dim()) as usize; + let lf_top = ((lf_group_idx / frame_header.lf_groups_per_row()) * frame_header.lf_group_dim()) as usize; + let lf_group_width = lf_group_dim.min(width - lf_left); + let lf_group_height = lf_group_dim.min(height - lf_top); + + for cy in 0..lf_group_height { + for cx in 0..lf_group_width { + let x = lf_left + cx; + let y = lf_top + cy; + let cfactor_x = cx / 64; + let cfactor_y = cy / 64; + + let x_factor = *x_from_y.get(cfactor_x, cfactor_y).unwrap(); + let b_factor = *b_from_y.get(cfactor_x, cfactor_y).unwrap(); + let kx = base_correlation_x + (x_factor as f32 / colour_factor as f32); + let kb = base_correlation_b + (b_factor as f32 / colour_factor as f32); + + let coeff_y = *coeff_y.get(x, y).unwrap(); + *coeff_x.get_mut(x, y).unwrap() += kx * coeff_y; + *coeff_b.get_mut(x, y).unwrap() += kb * coeff_y; + } } } } -pub fn llf_from_lf( - lf: &SimpleGrid, - left: usize, - top: usize, - dct_select: TransformType, -) -> SimpleGrid { +pub fn transform_with_lf( + lf: &[SimpleGrid], + coeff_out: &mut [SimpleGrid; 3], + frame_header: &FrameHeader, + lf_groups: &HashMap, +) { use TransformType::*; fn scale_f(c: usize, b: usize) -> f32 { @@ -224,30 +270,76 @@ pub fn llf_from_lf( recip.recip() } - let (bw, bh) = dct_select.dct_select_size(); - let bw = bw as usize; - let bh = bh as usize; - - if matches!(dct_select, Hornuss | Dct2 | Dct4 | Dct8x4 | Dct4x8 | Dct8 | Afv0 | Afv1 | Afv2 | Afv3) { - debug_assert_eq!(bw * bh, 1); - let mut out = SimpleGrid::new(1, 1); - out.buf_mut()[0] = *lf.get(left, top).unwrap(); - out - } else { - let mut out = SimpleGrid::new(bw, bh); - for y in 0..bh { - for x in 0..bw { - out.buf_mut()[y * bw + x] = *lf.get(left + x, top + y).unwrap(); - } - } - dct_2d(&mut out); + let shifts_cbycr: [_; 3] = std::array::from_fn(|idx| { + ChannelShift::from_jpeg_upsampling(frame_header.jpeg_upsampling, idx) + }); + + for lf_group_idx in 0..frame_header.num_lf_groups() { + let Some(lf_group) = lf_groups.get(&lf_group_idx) else { continue; }; + let hf_meta = lf_group.hf_meta.as_ref().unwrap(); - for y in 0..bh { - for x in 0..bw { - out.buf_mut()[y * bw + x] *= scale_f(y, bh * 8) * scale_f(x, bw * 8); + let lf_left = (lf_group_idx % frame_header.lf_groups_per_row()) * frame_header.lf_group_dim(); + let lf_top = (lf_group_idx / frame_header.lf_groups_per_row()) * frame_header.lf_group_dim(); + + let block_info = &hf_meta.block_info; + let w8 = block_info.width(); + let h8 = block_info.height(); + + for (channel, (coeff, lf)) in coeff_out.iter_mut().zip(lf).enumerate() { + let shift = shifts_cbycr[channel]; + let vshift = shift.vshift(); + let hshift = shift.hshift(); + + let stride = coeff.width(); + for by in 0..h8 { + for bx in 0..w8 { + let &BlockInfo::Data { dct_select, .. } = block_info.get(bx, by).unwrap() else { continue; }; + if ((bx >> hshift) << hshift) != bx || ((by >> vshift) << vshift) != by { + continue; + } + + let (bw, bh) = dct_select.dct_select_size(); + let bw = bw as usize; + let bh = bh as usize; + + let left = lf_left as usize + bx * 8; + let top = lf_top as usize + by * 8; + let left = left >> hshift; + let top = top >> vshift; + + if matches!(dct_select, Hornuss | Dct2 | Dct4 | Dct8x4 | Dct4x8 | Dct8 | Afv0 | Afv1 | Afv2 | Afv3) { + debug_assert_eq!(bw * bh, 1); + *coeff.get_mut(left, top).unwrap() = *lf.get(left / 8, top / 8).unwrap(); + } else { + let mut out = CutGrid::from_buf( + &mut coeff.buf_mut()[top * stride + left..], + bw, + bh, + stride, + ); + + for y in 0..bh { + for x in 0..bw { + *out.get_mut(x, y) = *lf.get(left / 8 + x, top / 8 + y).unwrap(); + } + } + dct::dct_2d(&mut out, dct::DctDirection::Forward); + for y in 0..bh { + for x in 0..bw { + *out.get_mut(x, y) *= scale_f(y, bh * 8) * scale_f(x, bw * 8); + } + } + } + + let mut block = CutGrid::from_buf( + &mut coeff.buf_mut()[top * stride + left..], + bw * 8, + bh * 8, + stride, + ); + transform(&mut block, dct_select); + } } } - - out } } diff --git a/crates/jxl-render/src/vardct/transform.rs b/crates/jxl-render/src/vardct/transform.rs index bc8b9202..9c6e598d 100644 --- a/crates/jxl-render/src/vardct/transform.rs +++ b/crates/jxl-render/src/vardct/transform.rs @@ -1,7 +1,7 @@ use jxl_grid::CutGrid; use jxl_vardct::TransformType; -use crate::dct::{idct_2d, dct_2d_generic}; +use crate::dct::{dct_2d, DctDirection}; fn aux_idct2_in_place(block: &mut CutGrid<'_>, size: usize) { debug_assert!(size.is_power_of_two()); @@ -40,13 +40,13 @@ fn transform_dct4(coeff: &mut CutGrid<'_>) { let mut scratch = [0.0f32; 64]; for y in 0..2 { for x in 0..2 { - let scratch = &mut scratch[(y * 2 + x) * 16..][..16]; + let mut scratch = CutGrid::from_buf(&mut scratch[(y * 2 + x) * 16..], 4, 4, 4); for iy in 0..4 { for ix in 0..4 { - scratch[iy * 4 + ix] = coeff.get(x + ix * 2, y + iy * 2); + *scratch.get_mut(iy, ix) = coeff.get(x + ix * 2, y + iy * 2); } } - dct_2d_generic(scratch, 4, 4, true); + dct_2d(&mut scratch, DctDirection::Inverse); } } @@ -107,13 +107,13 @@ fn transform_dct4x8(coeff: &mut CutGrid<'_>, transpose: bool) { let mut scratch = [0.0f32; 64]; for idx in [0, 1] { - let scratch = &mut scratch[(idx * 32)..][..32]; + let mut scratch = CutGrid::from_buf(&mut scratch[(idx * 32)..], 8, 4, 8); for iy in 0..4 { for ix in 0..8 { - scratch[iy * 8 + ix] = coeff.get(ix, iy * 2 + idx); + *scratch.get_mut(ix, iy) = coeff.get(ix, iy * 2 + idx); } } - dct_2d_generic(scratch, 8, 4, true); + dct_2d(&mut scratch, DctDirection::Inverse); } if transpose { @@ -158,10 +158,10 @@ fn transform_afv(coeff: &mut CutGrid<'_>) { if ix | iy == 0 { continue; } - scratch_4x4[iy * 4 + ix] = coeff.get(2 * ix + 1, 2 * iy); + scratch_4x4[ix * 4 + iy] = coeff.get(2 * ix + 1, 2 * iy); } } - dct_2d_generic(&mut scratch_4x4, 4, 4, true); + dct_2d(&mut CutGrid::from_buf(&mut scratch_4x4, 4, 4, 4), DctDirection::Inverse); scratch_4x8[0] = coeff.get(0, 0) - coeff.get(0, 1); for iy in 0..4 { @@ -172,7 +172,7 @@ fn transform_afv(coeff: &mut CutGrid<'_>) { scratch_4x8[iy * 8 + ix] = coeff.get(ix, 2 * iy + 1); } } - dct_2d_generic(&mut scratch_4x8, 8, 4, true); + dct_2d(&mut CutGrid::from_buf(&mut scratch_4x8, 8, 4, 8), DctDirection::Inverse); for iy in 0..4 { let afv_y = if flip_y == 0 { iy } else { 3 - iy }; @@ -197,7 +197,7 @@ fn transform_afv(coeff: &mut CutGrid<'_>) { } fn transform_dct(coeff: &mut CutGrid<'_>) { - idct_2d(coeff); + dct_2d(coeff, DctDirection::Inverse); } pub fn transform(coeff: &mut CutGrid<'_>, dct_select: TransformType) { diff --git a/crates/jxl-render/src/vardct/x86_64.rs b/crates/jxl-render/src/vardct/x86_64.rs index c4144d4e..43c8799f 100644 --- a/crates/jxl-render/src/vardct/x86_64.rs +++ b/crates/jxl-render/src/vardct/x86_64.rs @@ -3,18 +3,17 @@ use super::generic; pub fn adaptive_lf_smoothing_impl( width: usize, height: usize, - lf_image: [&[f32]; 3], - out: [&mut [f32]; 3], + lf_image: [&mut [f32]; 3], lf_scale: [f32; 3], ) { if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { // SAFETY: Feature set is checked above. return unsafe { - adaptive_lf_smoothing_core_avx2(width, height, lf_image, out, lf_scale) + adaptive_lf_smoothing_core_avx2(width, height, lf_image, lf_scale) }; } - generic::adaptive_lf_smoothing_impl(width, height, lf_image, out, lf_scale) + generic::adaptive_lf_smoothing_impl(width, height, lf_image, lf_scale) } #[target_feature(enable = "avx2")] @@ -22,9 +21,8 @@ pub fn adaptive_lf_smoothing_impl( unsafe fn adaptive_lf_smoothing_core_avx2( width: usize, height: usize, - lf_image: [&[f32]; 3], - out: [&mut [f32]; 3], + lf_image: [&mut [f32]; 3], lf_scale: [f32; 3], ) { - generic::adaptive_lf_smoothing_impl(width, height, lf_image, out, lf_scale) + generic::adaptive_lf_smoothing_impl(width, height, lf_image, lf_scale) } diff --git a/crates/jxl-vardct/src/dct_select.rs b/crates/jxl-vardct/src/dct_select.rs index 81912cca..51f85cc6 100644 --- a/crates/jxl-vardct/src/dct_select.rs +++ b/crates/jxl-vardct/src/dct_select.rs @@ -116,7 +116,13 @@ impl TransformType { /// Returns whether DCT coefficients should be transposed. #[inline] pub fn need_transpose(&self) -> bool { - let (w, h) = self.dct_select_size(); - h > w + use TransformType::*; + + if matches!(self, Hornuss | Dct2 | Dct4 | Dct4x8 | Dct8x4 | Afv0 | Afv1 | Afv2 | Afv3) { + false + } else { + let (w, h) = self.dct_select_size(); + h >= w + } } } diff --git a/crates/jxl-vardct/src/hf_coeff.rs b/crates/jxl-vardct/src/hf_coeff.rs index 62fdc981..55d837ce 100644 --- a/crates/jxl-vardct/src/hf_coeff.rs +++ b/crates/jxl-vardct/src/hf_coeff.rs @@ -1,5 +1,5 @@ -use jxl_bitstream::{Bundle, Bitstream}; -use jxl_grid::{Subgrid, SimpleGrid, Grid}; +use jxl_bitstream::Bitstream; +use jxl_grid::{Subgrid, Grid, CutGrid}; use jxl_modular::ChannelShift; use crate::{ @@ -7,7 +7,6 @@ use crate::{ HfBlockContext, HfPass, Result, - TransformType, }; /// Parameters for decoding `HfCoeff`. @@ -22,258 +21,185 @@ pub struct HfCoeffParams<'a> { pub coeff_shift: u32, } -/// HF coefficient data in a group. -#[derive(Debug, Clone)] -pub struct HfCoeff { - data: Vec, -} - -impl HfCoeff { - /// Creates an empty `HfCoeff`. - #[inline] - pub fn empty() -> Self { - Self { data: Vec::new() } - } - - /// Returns the HF coefficient data in raster order. - #[inline] - pub fn data(&self) -> &[CoeffData] { - &self.data - } - - /// Merge coefficients from another `HfCoeff`. - /// - /// # Panics - /// Panics if `other` is not from the same group. - pub fn merge(&mut self, other: &HfCoeff) { - let reserve_size = other.data.len().saturating_sub(self.data.len()); - self.data.reserve_exact(reserve_size); - - for (target_data, other_data) in self.data.iter_mut().zip(&other.data) { - assert_eq!(target_data.bx, other_data.bx); - assert_eq!(target_data.by, other_data.by); - assert_eq!(target_data.dct_select, other_data.dct_select); - for (target, v) in target_data.coeff.iter_mut().zip(other_data.coeff.iter()) { - assert_eq!(target.width(), v.width()); - assert_eq!(target.height(), v.height()); - for (target, v) in target.buf_mut().iter_mut().zip(v.buf()) { - *target += *v; - } - } - } - - if reserve_size > 0 { - self.data.extend_from_slice(&other.data[self.data.len()..]); +pub fn write_hf_coeff( + bitstream: &mut Bitstream, + params: HfCoeffParams, + hf_coeff_output: &mut [CutGrid<'_, f32>; 3], +) -> Result<()> { + const COEFF_FREQ_CONTEXT: [u32; 64] = [ + 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, + 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, + 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, + ]; + const COEFF_NUM_NONZERO_CONTEXT: [u32; 64] = [ + 0, 0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, + 152, 152, 152, 152, 152, 152, 152, 152, 180, 180, 180, 180, 180, + 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, + ]; + + let HfCoeffParams { + num_hf_presets, + hf_block_ctx, + block_info, + jpeg_upsampling, + lf_quant, + hf_pass, + coeff_shift, + } = params; + let mut dist = hf_pass.clone_decoder(); + let span = tracing::span!(tracing::Level::TRACE, "HfCoeff::parse"); + let _guard = span.enter(); + + let HfBlockContext { + qf_thresholds, + lf_thresholds, + block_ctx_map, + num_block_clusters, + } = hf_block_ctx; + let upsampling_shifts: [_; 3] = std::array::from_fn(|idx| ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx)); + + let hfp_bits = num_hf_presets.next_power_of_two().trailing_zeros(); + let hfp = bitstream.read_bits(hfp_bits)?; + let ctx_offset = 495 * *num_block_clusters * hfp; + + dist.begin(bitstream)?; + + let width = block_info.width(); + let height = block_info.height(); + let mut non_zeros_grid = upsampling_shifts.map(|shift| { + let (width, height) = shift.shift_size((width as u32, height as u32)); + Grid::new(width, height, width, height) + }); + let predict_non_zeros = |grid: &Grid, x: usize, y: usize| { + if x == 0 && y == 0 { + 32u32 + } else if x == 0 { + *grid.get(x, y - 1).unwrap() + } else if y == 0 { + *grid.get(x - 1, y).unwrap() + } else { + ( + *grid.get(x, y - 1).unwrap() + + *grid.get(x - 1, y).unwrap() + + 1 + ) >> 1 } - } -} - -/// Data for a single varblock. -#[derive(Debug, Clone)] -pub struct CoeffData { - /// X coordinate within a group, in 8x8 blocks. - pub bx: usize, - /// Y coordinate within a group, in 8x8 blocks. - pub by: usize, - /// Transform type for the varblock. - pub dct_select: TransformType, - /// Quantization multiplier for the varblock. - pub hf_mul: i32, - /// Quantized coefficients in XYB order. - pub coeff: [SimpleGrid; 3], // x, y, b -} - -impl Bundle> for HfCoeff { - type Error = crate::Error; - - fn parse(bitstream: &mut Bitstream, params: HfCoeffParams<'_>) -> Result { - const COEFF_FREQ_CONTEXT: [u32; 64] = [ - 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, - 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, - 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, - ]; - const COEFF_NUM_NONZERO_CONTEXT: [u32; 64] = [ - 0, 0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, - 152, 152, 152, 152, 152, 152, 152, 152, 180, 180, 180, 180, 180, - 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206, - 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, - 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, - ]; - - let mut data = Vec::new(); - - let HfCoeffParams { - num_hf_presets, - hf_block_ctx, - block_info, - jpeg_upsampling, - lf_quant, - hf_pass, - coeff_shift, - } = params; - let mut dist = hf_pass.clone_decoder(); - let span = tracing::span!(tracing::Level::TRACE, "HfCoeff::parse"); - let _guard = span.enter(); - - let HfBlockContext { - qf_thresholds, - lf_thresholds, - block_ctx_map, - num_block_clusters, - } = hf_block_ctx; - let upsampling_shifts: [_; 3] = std::array::from_fn(|idx| ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx)); - - let hfp_bits = num_hf_presets.next_power_of_two().trailing_zeros(); - let hfp = bitstream.read_bits(hfp_bits)?; - let ctx_offset = 495 * *num_block_clusters * hfp; - - dist.begin(bitstream)?; - - let width = block_info.width(); - let height = block_info.height(); - let mut non_zeros_grid = upsampling_shifts.map(|shift| { - let (width, height) = shift.shift_size((width as u32, height as u32)); - Grid::new(width, height, width, height) - }); - let predict_non_zeros = |grid: &Grid, x: usize, y: usize| { - if x == 0 && y == 0 { - 32u32 - } else if x == 0 { - *grid.get(x, y - 1).unwrap() - } else if y == 0 { - *grid.get(x - 1, y).unwrap() - } else { - ( - *grid.get(x, y - 1).unwrap() + - *grid.get(x - 1, y).unwrap() + - 1 - ) >> 1 - } - }; - - for y in 0..height { - for x in 0..width { - let BlockInfo::Data { dct_select, hf_mul: qf } = *block_info.get(x, y).unwrap() else { - continue; - }; - let (w8, h8) = dct_select.dct_select_size(); - let coeff_size = dct_select.dequant_matrix_size(); - let num_blocks = w8 * h8; - let order_id = dct_select.order_id(); - let qdc: Option<[_; 3]> = lf_quant.as_ref().map(|lf_quant| { - std::array::from_fn(|idx| { - let shift = upsampling_shifts[idx]; - let x = x >> shift.hshift(); - let y = y >> shift.vshift(); - *lf_quant[idx].get(x, y).unwrap() - }) - }); - - let hf_idx = { - let mut idx = 0usize; - for &threshold in qf_thresholds { - if qf > threshold as i32 { + }; + + for y in 0..height { + for x in 0..width { + let BlockInfo::Data { dct_select, hf_mul: qf } = *block_info.get(x, y).unwrap() else { + continue; + }; + let (w8, h8) = dct_select.dct_select_size(); + let num_blocks = w8 * h8; + let order_id = dct_select.order_id(); + let qdc: Option<[_; 3]> = lf_quant.as_ref().map(|lf_quant| { + std::array::from_fn(|idx| { + let shift = upsampling_shifts[idx]; + let x = x >> shift.hshift(); + let y = y >> shift.vshift(); + *lf_quant[idx].get(x, y).unwrap() + }) + }); + + let hf_idx = { + let mut idx = 0usize; + for &threshold in qf_thresholds { + if qf > threshold as i32 { + idx += 1; + } + } + idx + }; + let lf_idx = if let Some(qdc) = qdc { + let mut idx = 0usize; + for c in [0, 2, 1] { + let lf_thresholds = &lf_thresholds[c]; + idx *= lf_thresholds.len() + 1; + + let q = qdc[c]; + for &threshold in lf_thresholds { + if q > threshold { idx += 1; } } - idx - }; - let lf_idx = if let Some(qdc) = qdc { - let mut idx = 0usize; - for c in [0, 2, 1] { - let lf_thresholds = &lf_thresholds[c]; - idx *= lf_thresholds.len() + 1; + } + idx + } else { + 0 + }; + let lf_idx_mul = (lf_thresholds[0].len() + 1) * (lf_thresholds[1].len() + 1) * (lf_thresholds[2].len() + 1); + + for c in [1, 0, 2] { // y, x, b + let shift = upsampling_shifts[c]; + let sx = x >> shift.hshift(); + let sy = y >> shift.vshift(); + if sx << shift.hshift() != x || sy << shift.vshift() != y { + continue; + } - let q = qdc[c]; - for &threshold in lf_thresholds { - if q > threshold { - idx += 1; - } - } - } - idx - } else { - 0 + let ch_idx = [1, 0, 2][c] * 13 + order_id as usize; + let idx = (ch_idx * (qf_thresholds.len() + 1) + hf_idx) * lf_idx_mul + lf_idx; + let block_ctx = block_ctx_map[idx] as u32; + let non_zeros_ctx = { + let predicted = predict_non_zeros(&non_zeros_grid[c], sx, sy).min(64); + let idx = if predicted >= 8 { + 4 + predicted / 2 + } else { + predicted + }; + block_ctx + idx * num_block_clusters }; - let lf_idx_mul = (lf_thresholds[0].len() + 1) * (lf_thresholds[1].len() + 1) * (lf_thresholds[2].len() + 1); - let mut coeff = [ - SimpleGrid::new(coeff_size.0 as usize, coeff_size.1 as usize), - SimpleGrid::new(coeff_size.0 as usize, coeff_size.1 as usize), - SimpleGrid::new(coeff_size.0 as usize, coeff_size.1 as usize), - ]; - for c in [1, 0, 2] { // y, x, b - let shift = upsampling_shifts[c]; - let sx = x >> shift.hshift(); - let sy = y >> shift.vshift(); - if sx << shift.hshift() != x || sy << shift.vshift() != y { - continue; + let mut non_zeros = dist.read_varint(bitstream, ctx_offset + non_zeros_ctx)?; + let non_zeros_val = (non_zeros + num_blocks - 1) / num_blocks; + let non_zeros_grid = &mut non_zeros_grid[c]; + for dy in 0..h8 as usize { + for dx in 0..w8 as usize { + non_zeros_grid.set(sx + dx, sy + dy, non_zeros_val); } + } - let ch_idx = [1, 0, 2][c] * 13 + order_id as usize; - let idx = (ch_idx * (qf_thresholds.len() + 1) + hf_idx) * lf_idx_mul + lf_idx; - let block_ctx = block_ctx_map[idx] as u32; - let non_zeros_ctx = { - let predicted = predict_non_zeros(&non_zeros_grid[c], sx, sy).min(64); - let idx = if predicted >= 8 { - 4 + predicted / 2 - } else { - predicted - }; - block_ctx + idx * num_block_clusters - }; - - let mut non_zeros = dist.read_varint(bitstream, ctx_offset + non_zeros_ctx)?; - let non_zeros_val = (non_zeros + num_blocks - 1) / num_blocks; - let non_zeros_grid = &mut non_zeros_grid[c]; - for dy in 0..h8 as usize { - for dx in 0..w8 as usize { - non_zeros_grid.set(sx + dx, sy + dy, non_zeros_val); - } + let size = (w8 * 8) * (h8 * 8); + let coeff_grid = &mut hf_coeff_output[c]; + let mut prev_coeff = (non_zeros <= size / 16) as i32; + let order_it = hf_pass.order(order_id as usize, c); + for (idx, coeff_coord) in order_it.enumerate().skip(num_blocks as usize) { + if non_zeros == 0 { + break; } - let size = (w8 * 8) * (h8 * 8); - let coeff_grid = &mut coeff[c]; - let mut prev_coeff = (non_zeros <= size / 16) as i32; - let order_it = hf_pass.order(order_id as usize, c); - for (idx, coeff_coord) in order_it.enumerate().skip(num_blocks as usize) { - if non_zeros == 0 { - break; - } - - let idx = idx as u32; - let coeff_ctx = { - let prev = (prev_coeff != 0) as u32; - let non_zeros = (non_zeros + num_blocks - 1) / num_blocks; - let idx = idx / num_blocks; - (COEFF_NUM_NONZERO_CONTEXT[non_zeros as usize] + COEFF_FREQ_CONTEXT[idx as usize]) * 2 + - prev + block_ctx * 458 + 37 * num_block_clusters - }; - let ucoeff = dist.read_varint(bitstream, ctx_offset + coeff_ctx)?; - let coeff = jxl_bitstream::unpack_signed(ucoeff) << coeff_shift; - let (x, y) = coeff_coord; - *coeff_grid.get_mut(x as usize, y as usize).unwrap() = coeff; - prev_coeff = coeff; + let idx = idx as u32; + let coeff_ctx = { + let prev = (prev_coeff != 0) as u32; + let non_zeros = (non_zeros + num_blocks - 1) / num_blocks; + let idx = idx / num_blocks; + (COEFF_NUM_NONZERO_CONTEXT[non_zeros as usize] + COEFF_FREQ_CONTEXT[idx as usize]) * 2 + + prev + block_ctx * 458 + 37 * num_block_clusters + }; + let ucoeff = dist.read_varint(bitstream, ctx_offset + coeff_ctx)?; + let coeff = jxl_bitstream::unpack_signed(ucoeff) << coeff_shift; + let (x, y) = if dct_select.need_transpose() { + (sx * 8 + coeff_coord.1 as usize, sy * 8 + coeff_coord.0 as usize) + } else { + (sx * 8 + coeff_coord.0 as usize, sy * 8 + coeff_coord.1 as usize) + }; + *coeff_grid.get_mut(x, y) += coeff as f32; + prev_coeff = coeff; - if coeff != 0 { - non_zeros -= 1; - } + if coeff != 0 { + non_zeros -= 1; } } - - data.push(CoeffData { - bx: x, - by: y, - dct_select, - hf_mul: qf, - coeff, - }); } } + } - dist.finalize()?; + dist.finalize()?; - Ok(Self { data }) - } + Ok(()) }