From 25589faa62ae9f24ea8d74c0e270af7ed52079f0 Mon Sep 17 00:00:00 2001 From: Luca Versari Date: Sun, 24 Nov 2024 15:40:54 +0100 Subject: [PATCH] Implement header decoding for Modular groups --- jxl/src/error.rs | 6 + jxl/src/frame.rs | 27 ++- jxl/src/frame/modular.rs | 262 +++++++++------------------- jxl/src/frame/modular/predict.rs | 39 +++++ jxl/src/frame/modular/transforms.rs | 162 +++++++++++++++++ jxl/src/frame/modular/tree.rs | 165 ++++++++++++++++++ jxl/src/headers/frame_header.rs | 86 ++++++--- jxl/src/headers/mod.rs | 1 + jxl/src/headers/modular.rs | 151 ++++++++++++++++ 9 files changed, 691 insertions(+), 208 deletions(-) create mode 100644 jxl/src/frame/modular/predict.rs create mode 100644 jxl/src/frame/modular/transforms.rs create mode 100644 jxl/src/frame/modular/tree.rs create mode 100644 jxl/src/headers/modular.rs diff --git a/jxl/src/error.rs b/jxl/src/error.rs index d832a73..05fe71b 100644 --- a/jxl/src/error.rs +++ b/jxl/src/error.rs @@ -121,6 +121,12 @@ pub enum Error { TreeMultiplierBitsTooLarge(u32, u32), #[error("Modular tree splits on property {0} at value {1}, which is outside the possible range of [{2}, {3}]")] TreeSplitOnEmptyRange(u8, i32, i32, i32), + #[error("Modular stream requested a global tree but there isn't one")] + NoGlobalTree, + #[error("Invalid transform id")] + InvalidTransformId, + #[error("Invalid RCT type {0}")] + InvalidRCT(u32), } pub type Result = std::result::Result; diff --git a/jxl/src/frame.rs b/jxl/src/frame.rs index 19d7334..fe6481d 100644 --- a/jxl/src/frame.rs +++ b/jxl/src/frame.rs @@ -15,10 +15,10 @@ use crate::{ }, util::tracing_wrappers::*, }; -use modular::Tree; +use modular::{FullModularImage, Tree}; use quantizer::LfQuantFactors; -mod modular; +pub mod modular; mod quantizer; #[derive(Debug, PartialEq, Eq)] @@ -29,18 +29,17 @@ pub enum Section { Hf(usize, usize), // group, pass } +#[allow(dead_code)] pub struct LfGlobalState { // TODO(veluca93): patches // TODO(veluca93): splines // TODO(veluca93): noise - #[allow(dead_code)] lf_quant: LfQuantFactors, // TODO(veluca93), VarDCT: HF quant matrices // TODO(veluca93), VarDCT: block context map // TODO(veluca93), VarDCT: LF color correlation - // TODO(veluca93): Modular data - #[allow(dead_code)] tree: Option, + modular_global: FullModularImage, } pub struct Frame { @@ -121,9 +120,9 @@ impl Frame { match section { Section::LfGlobal => 0, Section::Lf(a) => 1 + a, - Section::HfGlobal => self.header.num_dc_groups() + 1, + Section::HfGlobal => self.header.num_lf_groups() + 1, Section::Hf(group, pass) => { - 2 + self.header.num_dc_groups() + self.header.num_groups() * pass + group + 2 + self.header.num_lf_groups() + self.header.num_groups() * pass + group } } } @@ -169,7 +168,19 @@ impl Frame { None }; - self.lf_global = Some(LfGlobalState { lf_quant, tree }); + let modular_global = FullModularImage::read( + &self.header, + self.modular_color_channels, + &self.extra_channel_info, + &tree, + br, + )?; + + self.lf_global = Some(LfGlobalState { + lf_quant, + tree, + modular_global, + }); Ok(()) } diff --git a/jxl/src/frame/modular.rs b/jxl/src/frame/modular.rs index 2ae6fa0..04d5555 100644 --- a/jxl/src/frame/modular.rs +++ b/jxl/src/frame/modular.rs @@ -3,205 +3,113 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +use std::fmt::Debug; + use crate::{ bit_reader::BitReader, - entropy_coding::decode::Histograms, error::{Error, Result}, + headers::{ + extra_channels::ExtraChannelInfo, frame_header::FrameHeader, modular::GroupHeader, + JxlHeader, + }, + image::Image, util::tracing_wrappers::*, + util::CeilLog2, }; -#[repr(u8)] -#[derive(Debug)] -enum Predictor { - Zero = 0, - Left = 1, - Top = 2, - Average0 = 3, - Select = 4, - Gradient = 5, - Weighted = 6, - TopRight = 7, - TopLeft = 8, - LeftLeft = 9, - Average1 = 10, - Average2 = 11, - Average3 = 12, - Average4 = 13, -} +mod predict; +mod transforms; +mod tree; -impl TryFrom for Predictor { - type Error = Error; - - fn try_from(value: u32) -> Result { - match value { - 0 => Ok(Predictor::Zero), - 1 => Ok(Predictor::Left), - 2 => Ok(Predictor::Top), - 3 => Ok(Predictor::Average0), - 4 => Ok(Predictor::Select), - 5 => Ok(Predictor::Gradient), - 6 => Ok(Predictor::Weighted), - 7 => Ok(Predictor::TopRight), - 8 => Ok(Predictor::TopLeft), - 9 => Ok(Predictor::LeftLeft), - 10 => Ok(Predictor::Average1), - 11 => Ok(Predictor::Average2), - 12 => Ok(Predictor::Average3), - 13 => Ok(Predictor::Average4), - _ => Err(Error::InvalidPredictor(value)), - } - } -} +pub use predict::Predictor; +use transforms::Transform; +pub use tree::Tree; #[allow(dead_code)] #[derive(Debug)] -enum TreeNode { - Split { - property: u8, - val: i32, - left: u32, - right: u32, - }, - Leaf { - predictor: Predictor, - offset: i32, - multiplier: u32, - id: u32, - }, +struct ChannelInfo { + size: (usize, usize), + shift: Option<(isize, isize)>, // None for meta-channels } #[allow(dead_code)] #[derive(Debug)] -pub struct Tree { - nodes: Vec, - histograms: Histograms, +struct MetaInfo { + channels: Vec, + transforms: Vec, +} + +pub struct FullModularImage { + // TODO: decoding graph for processing global transforms + meta_info: MetaInfo, + global_channels: Vec>, } -const SPLIT_VAL_CONTEXT: usize = 0; -const PROPERTY_CONTEXT: usize = 1; -const PREDICTOR_CONTEXT: usize = 2; -const OFFSET_CONTEXT: usize = 3; -const MULTIPLIER_LOG_CONTEXT: usize = 4; -const MULTIPLIER_BITS_CONTEXT: usize = 5; -const NUM_TREE_CONTEXTS: usize = 6; +impl Debug for FullModularImage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "[info: {:?}, global channel sizes: {:?}]", + self.meta_info, + self.global_channels + .iter() + .map(Image::size) + .collect::>() + ) + } +} -impl Tree { - #[instrument(level = "debug", skip(br), ret, err)] - pub fn read(br: &mut BitReader, size_limit: usize) -> Result { - assert!(size_limit <= u32::MAX as usize); - trace!(pos = br.total_bits_read()); - let tree_histograms = Histograms::decode(NUM_TREE_CONTEXTS, br, true)?; - let mut tree_reader = tree_histograms.make_reader(br)?; - // TODO(veluca): consider early-exiting for trees known to be infinite. - let mut tree: Vec = vec![]; - let mut to_decode = 1; - let mut leaf_id = 0; - let mut max_property = 0; - while to_decode > 0 { - if tree.len() > size_limit { - return Err(Error::TreeTooLarge(tree.len(), size_limit)); - } - if tree.len() >= tree.capacity() { - tree.try_reserve(tree.len() * 2 + 1)?; - } - to_decode -= 1; - let property = tree_reader.read(br, PROPERTY_CONTEXT)?; - trace!(property); - if let Some(property) = property.checked_sub(1) { - // inner node. - if property > 255 { - return Err(Error::InvalidProperty(property)); - } - max_property = max_property.max(property); - let splitval = tree_reader.read_signed(br, SPLIT_VAL_CONTEXT)?; - let left_child = (tree.len() + to_decode + 1) as u32; - let node = TreeNode::Split { - property: property as u8, - val: splitval, - left: left_child, - right: left_child + 1, - }; - trace!("split node {:?}", node); - to_decode += 2; - tree.push(node); - } else { - let predictor = Predictor::try_from(tree_reader.read(br, PREDICTOR_CONTEXT)?)?; - let offset = tree_reader.read_signed(br, OFFSET_CONTEXT)?; - let mul_log = tree_reader.read(br, MULTIPLIER_LOG_CONTEXT)?; - if mul_log >= 31 { - return Err(Error::TreeMultiplierTooLarge(mul_log, 31)); - } - let mul_bits = tree_reader.read(br, MULTIPLIER_BITS_CONTEXT)?; - let multiplier = (mul_bits as u64 + 1) << mul_log; - if multiplier > (u32::MAX as u64) { - return Err(Error::TreeMultiplierBitsTooLarge(mul_bits, mul_log)); - } - let node = TreeNode::Leaf { - predictor, - offset, - id: leaf_id, - multiplier: multiplier as u32, - }; - leaf_id += 1; - trace!("leaf node {:?}", node); - tree.push(node); - } +impl FullModularImage { + #[instrument(level = "debug", skip_all, ret)] + pub fn read( + header: &FrameHeader, + modular_color_channels: usize, + extra_channel_info: &[ExtraChannelInfo], + global_tree: &Option, + br: &mut BitReader, + ) -> Result { + let mut channels = vec![]; + for c in 0..modular_color_channels { + let shift = (header.hshift(c) as isize, header.vshift(c) as isize); + let size = (header.width as usize, header.height as usize); + channels.push(ChannelInfo { + size: (size.0.div_ceil(1 << shift.0), size.1.div_ceil(1 << shift.1)), + shift: Some(shift), + }); } - tree_reader.check_final_state()?; - let num_properties = max_property as usize + 1; - let mut property_ranges = vec![]; - property_ranges.try_reserve(num_properties * tree.len())?; - property_ranges.resize(num_properties * tree.len(), (i32::MIN, i32::MAX)); - let mut height = vec![]; - height.try_reserve(tree.len())?; - height.resize(tree.len(), 0); - for i in 0..tree.len() { - const HEIGHT_LIMIT: usize = 2048; - if height[i] > HEIGHT_LIMIT { - return Err(Error::TreeTooLarge(height[i], HEIGHT_LIMIT)); - } - if let TreeNode::Split { - property, - val, - left, - right, - } = tree[i] - { - height[left as usize] = height[i] + 1; - height[right as usize] = height[i] + 1; - for p in 0..num_properties { - if p == property as usize { - let (l, u) = property_ranges[i * num_properties + p]; - if l > val || u <= val { - return Err(Error::TreeSplitOnEmptyRange(p as u8, val, l, u)); - } - trace!("splitting at node {i} on property {p}, range [{l}, {u}] at position {val}"); - property_ranges[left as usize * num_properties + p] = (val + 1, u); - property_ranges[right as usize * num_properties + p] = (l, val); - } else { - property_ranges[left as usize * num_properties + p] = - property_ranges[i * num_properties + p]; - property_ranges[right as usize * num_properties + p] = - property_ranges[i * num_properties + p]; - } - } - } else { - #[cfg(feature = "tracing")] - { - for p in 0..num_properties { - let (l, u) = property_ranges[i * num_properties + p]; - trace!("final range at node {i} property {p}: [{l}, {u}]"); - } - } - } + for info in extra_channel_info { + let shift = info.dim_shift() as isize - header.upsampling.ceil_log2() as isize; + let size = header.size_upsampled(); + let size = (size.0 >> info.dim_shift(), size.1 >> info.dim_shift()); + channels.push(ChannelInfo { + size, + shift: Some((shift, shift)), + }); } - let histograms = Histograms::decode((tree.len() + 1) / 2, br, true)?; + trace!("reading modular header"); + let header = GroupHeader::read(br)?; + + if header.use_global_tree && global_tree.is_none() { + return Err(Error::NoGlobalTree); + } + + let meta_info = MetaInfo { + transforms: header + .transforms + .iter() + .map(|x| Transform::from_bitstream(x, 0, &channels)) + .filter(|x| !x.is_noop()) + .collect(), + channels, + }; + + // TODO(veluca93): meta-apply transforms - Ok(Tree { - nodes: tree, - histograms, + Ok(FullModularImage { + meta_info, + global_channels: vec![], // TODO(veluca93): read global channels }) } } diff --git a/jxl/src/frame/modular/predict.rs b/jxl/src/frame/modular/predict.rs new file mode 100644 index 0000000..58a6874 --- /dev/null +++ b/jxl/src/frame/modular/predict.rs @@ -0,0 +1,39 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +use crate::error::{Error, Result}; +use num_derive::FromPrimitive; +use num_traits::FromPrimitive; + +#[repr(u8)] +#[derive(Debug, FromPrimitive)] +pub enum Predictor { + Zero = 0, + West = 1, + North = 2, + AverageWestAndNorth = 3, + Select = 4, + Gradient = 5, + Weighted = 6, + NorthEast = 7, + NorthWest = 8, + WestWest = 9, + AverageWestAndNorthWest = 10, + AverageNorthAndNorthWest = 11, + AverageNorthAndNorthEast = 12, + AverageAll = 13, +} + +impl TryFrom for Predictor { + type Error = Error; + + fn try_from(value: u32) -> Result { + Self::from_u32(value).ok_or(Error::InvalidPredictor(value)) + } +} + +impl Predictor { + pub const NUM_PREDICTORS: u32 = Predictor::AverageAll as u32 + 1; +} diff --git a/jxl/src/frame/modular/transforms.rs b/jxl/src/frame/modular/transforms.rs new file mode 100644 index 0000000..7a86fde --- /dev/null +++ b/jxl/src/frame/modular/transforms.rs @@ -0,0 +1,162 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +use num_derive::FromPrimitive; +use num_traits::FromPrimitive; + +use crate::headers::{ + self, + modular::{SqueezeParams, TransformId}, +}; +use crate::util::tracing_wrappers::*; + +use super::ChannelInfo; + +#[derive(Debug, FromPrimitive, PartialEq)] +pub enum RctPermutation { + Rgb = 0, + Gbr = 1, + Brg = 2, + Rbg = 3, + Grb = 4, + Bgr = 5, +} + +#[derive(Debug, FromPrimitive, PartialEq)] +pub enum RctOp { + Noop = 0, + AddFirstToThird = 1, + AddFirstToSecond = 2, + AddFirstToSecondAndThird = 3, + AddAvgToSecond = 4, + AddFirstToThirdAndAvgToSecond = 5, + YCoCg = 6, +} + +#[allow(dead_code)] +#[derive(Debug)] +pub enum Transform { + Rct { + begin_channel: usize, + op: RctOp, + perm: RctPermutation, + }, + Palette { + begin_channel: usize, + num_channels: usize, + num_colors: usize, + num_deltas: usize, + }, + Squeeze(Vec), +} + +fn default_squeeze( + num_meta_channels: usize, + data_channel_info: &[ChannelInfo], +) -> Vec { + let mut w = data_channel_info[0].size.0; + let mut h = data_channel_info[0].size.1; + let nc = data_channel_info.len(); + + let mut params = vec![]; + + if nc > 2 && data_channel_info[1].size == (w, h) { + // 420 previews + let sp = SqueezeParams { + horizontal: true, + in_place: false, + begin_channel: num_meta_channels as u32 + 1, + num_channels: 2, + }; + params.push(sp); + params.push(SqueezeParams { + horizontal: false, + ..sp + }); + } + + const MAX_FIRST_PREVIEW_SIZE: usize = 8; + + let sp = SqueezeParams { + begin_channel: num_meta_channels as u32, + num_channels: nc as u32, + in_place: true, + horizontal: false, + }; + + // vertical first on tall images + if w <= h && h > MAX_FIRST_PREVIEW_SIZE { + params.push(SqueezeParams { + horizontal: false, + ..sp + }); + h = (h + 1) / 2; + } + while w > MAX_FIRST_PREVIEW_SIZE || h > MAX_FIRST_PREVIEW_SIZE { + if w > MAX_FIRST_PREVIEW_SIZE { + params.push(SqueezeParams { + horizontal: true, + ..sp + }); + w = (w + 1) / 2; + } + if h > MAX_FIRST_PREVIEW_SIZE { + params.push(SqueezeParams { + horizontal: false, + ..sp + }); + h = (h + 1) / 2; + } + } + + params +} + +impl Transform { + #[instrument(level = "trace", ret)] + pub fn from_bitstream( + t: &headers::modular::Transform, + num_meta_channels: usize, + data_channel_info: &[ChannelInfo], + ) -> Transform { + match t.id { + TransformId::Rct => Transform::Rct { + begin_channel: t.begin_channel as usize, + op: RctOp::from_u32(t.rct_type % 7).unwrap(), + perm: RctPermutation::from_u32(t.rct_type / 7) + .expect("header decoding should ensure rct_type < 42"), + }, + TransformId::Palette => Transform::Palette { + begin_channel: t.begin_channel as usize, + num_channels: t.num_channels as usize, + num_colors: t.num_colors as usize, + num_deltas: t.num_deltas as usize, + }, + TransformId::Squeeze => { + if t.squeezes.is_empty() { + Transform::Squeeze(default_squeeze(num_meta_channels, data_channel_info)) + } else { + Transform::Squeeze(t.squeezes.clone()) + } + } + TransformId::Invalid => { + unreachable!("header decoding for invalid transforms should fail") + } + } + } + + #[instrument(level = "trace", ret)] + pub fn is_noop(&self) -> bool { + match self { + Self::Rct { + begin_channel: _, + op, + perm, + } => *op == RctOp::Noop && *perm == RctPermutation::Rgb, + Self::Squeeze(x) if x.is_empty() => true, + _ => false, + } + } +} diff --git a/jxl/src/frame/modular/tree.rs b/jxl/src/frame/modular/tree.rs new file mode 100644 index 0000000..6a1b008 --- /dev/null +++ b/jxl/src/frame/modular/tree.rs @@ -0,0 +1,165 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +use super::Predictor; +use crate::{ + bit_reader::BitReader, + entropy_coding::decode::Histograms, + error::{Error, Result}, + util::tracing_wrappers::*, +}; + +#[allow(dead_code)] +#[derive(Debug)] +enum TreeNode { + Split { + property: u8, + val: i32, + left: u32, + right: u32, + }, + Leaf { + predictor: Predictor, + offset: i32, + multiplier: u32, + id: u32, + }, +} + +#[allow(dead_code)] +#[derive(Debug)] +pub struct Tree { + nodes: Vec, + histograms: Histograms, +} + +const SPLIT_VAL_CONTEXT: usize = 0; +const PROPERTY_CONTEXT: usize = 1; +const PREDICTOR_CONTEXT: usize = 2; +const OFFSET_CONTEXT: usize = 3; +const MULTIPLIER_LOG_CONTEXT: usize = 4; +const MULTIPLIER_BITS_CONTEXT: usize = 5; +const NUM_TREE_CONTEXTS: usize = 6; + +impl Tree { + #[instrument(level = "debug", skip(br), ret, err)] + pub fn read(br: &mut BitReader, size_limit: usize) -> Result { + assert!(size_limit <= u32::MAX as usize); + trace!(pos = br.total_bits_read()); + let tree_histograms = Histograms::decode(NUM_TREE_CONTEXTS, br, true)?; + let mut tree_reader = tree_histograms.make_reader(br)?; + // TODO(veluca): consider early-exiting for trees known to be infinite. + let mut tree: Vec = vec![]; + let mut to_decode = 1; + let mut leaf_id = 0; + let mut max_property = 0; + while to_decode > 0 { + if tree.len() > size_limit { + return Err(Error::TreeTooLarge(tree.len(), size_limit)); + } + if tree.len() >= tree.capacity() { + tree.try_reserve(tree.len() * 2 + 1)?; + } + to_decode -= 1; + let property = tree_reader.read(br, PROPERTY_CONTEXT)?; + trace!(property); + if let Some(property) = property.checked_sub(1) { + // inner node. + if property > 255 { + return Err(Error::InvalidProperty(property)); + } + max_property = max_property.max(property); + let splitval = tree_reader.read_signed(br, SPLIT_VAL_CONTEXT)?; + let left_child = (tree.len() + to_decode + 1) as u32; + let node = TreeNode::Split { + property: property as u8, + val: splitval, + left: left_child, + right: left_child + 1, + }; + trace!("split node {:?}", node); + to_decode += 2; + tree.push(node); + } else { + let predictor = Predictor::try_from(tree_reader.read(br, PREDICTOR_CONTEXT)?)?; + let offset = tree_reader.read_signed(br, OFFSET_CONTEXT)?; + let mul_log = tree_reader.read(br, MULTIPLIER_LOG_CONTEXT)?; + if mul_log >= 31 { + return Err(Error::TreeMultiplierTooLarge(mul_log, 31)); + } + let mul_bits = tree_reader.read(br, MULTIPLIER_BITS_CONTEXT)?; + let multiplier = (mul_bits as u64 + 1) << mul_log; + if multiplier > (u32::MAX as u64) { + return Err(Error::TreeMultiplierBitsTooLarge(mul_bits, mul_log)); + } + let node = TreeNode::Leaf { + predictor, + offset, + id: leaf_id, + multiplier: multiplier as u32, + }; + leaf_id += 1; + trace!("leaf node {:?}", node); + tree.push(node); + } + } + tree_reader.check_final_state()?; + + let num_properties = max_property as usize + 1; + let mut property_ranges = vec![]; + property_ranges.try_reserve(num_properties * tree.len())?; + property_ranges.resize(num_properties * tree.len(), (i32::MIN, i32::MAX)); + let mut height = vec![]; + height.try_reserve(tree.len())?; + height.resize(tree.len(), 0); + for i in 0..tree.len() { + const HEIGHT_LIMIT: usize = 2048; + if height[i] > HEIGHT_LIMIT { + return Err(Error::TreeTooLarge(height[i], HEIGHT_LIMIT)); + } + if let TreeNode::Split { + property, + val, + left, + right, + } = tree[i] + { + height[left as usize] = height[i] + 1; + height[right as usize] = height[i] + 1; + for p in 0..num_properties { + if p == property as usize { + let (l, u) = property_ranges[i * num_properties + p]; + if l > val || u <= val { + return Err(Error::TreeSplitOnEmptyRange(p as u8, val, l, u)); + } + trace!("splitting at node {i} on property {p}, range [{l}, {u}] at position {val}"); + property_ranges[left as usize * num_properties + p] = (val + 1, u); + property_ranges[right as usize * num_properties + p] = (l, val); + } else { + property_ranges[left as usize * num_properties + p] = + property_ranges[i * num_properties + p]; + property_ranges[right as usize * num_properties + p] = + property_ranges[i * num_properties + p]; + } + } + } else { + #[cfg(feature = "tracing")] + { + for p in 0..num_properties { + let (l, u) = property_ranges[i * num_properties + p]; + trace!("final range at node {i} property {p}: [{l}, {u}]"); + } + } + } + } + + let histograms = Histograms::decode((tree.len() + 1) / 2, br, true)?; + + Ok(Tree { + nodes: tree, + histograms, + }) + } +} diff --git a/jxl/src/headers/frame_header.rs b/jxl/src/headers/frame_header.rs index 512b704..d3bb502 100644 --- a/jxl/src/headers/frame_header.rs +++ b/jxl/src/headers/frame_header.rs @@ -296,7 +296,7 @@ pub struct FrameHeader { #[coder(u2S(1, 2, 4, 8))] #[default(1)] #[condition(flags & Flags::USE_LF_FRAME == 0)] - upsampling: u32, + pub upsampling: u32, #[size_coder(explicit(nonserialized.num_extra_channels))] #[coder(u2S(1, 2, 4, 8))] @@ -441,37 +441,27 @@ pub struct FrameHeader { // TODO(firsching): remove once we use this! #[allow(dead_code)] impl FrameHeader { - fn group_dim(&self) -> u32 { - const GROUP_DIM: u32 = 256; - (GROUP_DIM >> 1) << self.group_size_shift + const GROUP_DIM: usize = 256; + const BLOCK_DIM: usize = 8; + + fn group_dim(&self) -> usize { + (Self::GROUP_DIM >> 1) << self.group_size_shift } - fn dc_group_dim(&self) -> u32 { - const BLOCK_DIM: u32 = 8; - self.group_dim() * BLOCK_DIM + fn dc_group_dim(&self) -> usize { + self.group_dim() * Self::BLOCK_DIM } + pub fn num_groups(&self) -> usize { - let xsize = self.width as usize; - let ysize = self.height as usize; - let group_dim = self.group_dim() as usize; - let xsize_groups = xsize.div_ceil(group_dim); - let ysize_groups = ysize.div_ceil(group_dim); - xsize_groups * ysize_groups + self.size_groups().0 * self.size_groups().1 } - pub fn num_dc_groups(&self) -> usize { - const BLOCK_DIM: usize = 8; - - let xsize_blocks = (self.width as usize).div_ceil(BLOCK_DIM << self.maxhs) << self.maxhs; - let ysize_blocks = (self.height as usize).div_ceil(BLOCK_DIM << self.maxvs) << self.maxvs; - let group_dim = self.group_dim() as usize; - let xsize_dc_groups = xsize_blocks.div_ceil(group_dim); - let ysize_dc_groups = ysize_blocks.div_ceil(group_dim); - xsize_dc_groups * ysize_dc_groups + pub fn num_lf_groups(&self) -> usize { + self.size_lf_groups().0 * self.size_lf_groups().1 } pub fn num_toc_entries(&self) -> usize { let num_groups = self.num_groups(); - let num_dc_groups = self.num_dc_groups(); + let num_dc_groups = self.num_lf_groups(); if num_groups == 1 && self.passes.num_passes == 1 { 1 @@ -528,6 +518,56 @@ impl FrameHeader { self.hshift(2) == 0 && self.vshift(2) == 1 && // Cr self.hshift(1) == 0 && self.vshift(1) == 0 // Y } + + /// The dimensions of this frame, as coded in the codestream, excluding padding pixels. + pub fn size(&self) -> (usize, usize) { + ( + (self.width as usize).div_ceil(self.upsampling as usize), + (self.height as usize).div_ceil(self.upsampling as usize), + ) + } + + /// The dimensions of this frame, as coded in the codestream, in 8x8 blocks. + pub fn size_blocks(&self) -> (usize, usize) { + ( + self.size().0.div_ceil(Self::BLOCK_DIM << self.maxhs) << self.maxhs, + self.size().1.div_ceil(Self::BLOCK_DIM << self.maxvs) << self.maxvs, + ) + } + + /// The dimensions of this frame, as coded in the codestream but including padding pixels. + pub fn size_padded(&self) -> (usize, usize) { + if self.encoding == Encoding::Modular { + self.size() + } else { + ( + self.size_blocks().0 * Self::BLOCK_DIM, + self.size_blocks().1 * Self::BLOCK_DIM, + ) + } + } + + /// The dimensions of this frame, after upsampling. + pub fn size_upsampled(&self) -> (usize, usize) { + (self.width as usize, self.height as usize) + } + + /// The dimensions of this frame, in groups. + pub fn size_groups(&self) -> (usize, usize) { + ( + self.size().0.div_ceil(self.group_dim()), + self.size().1.div_ceil(self.group_dim()), + ) + } + + /// The dimensions of this frame, in LF groups. + pub fn size_lf_groups(&self) -> (usize, usize) { + ( + self.size_blocks().0.div_ceil(self.group_dim()), + self.size_blocks().1.div_ceil(self.group_dim()), + ) + } + fn check(&self, nonserialized: &FrameHeaderNonserialized) -> Result<(), Error> { if self.upsampling > 1 { if let Some((info, upsampling)) = nonserialized diff --git a/jxl/src/headers/mod.rs b/jxl/src/headers/mod.rs index 1565c40..f959b1f 100644 --- a/jxl/src/headers/mod.rs +++ b/jxl/src/headers/mod.rs @@ -9,6 +9,7 @@ pub mod encodings; pub mod extra_channels; pub mod frame_header; pub mod image_metadata; +pub mod modular; pub mod permutation; pub mod size; pub mod transform_data; diff --git a/jxl/src/headers/modular.rs b/jxl/src/headers/modular.rs new file mode 100644 index 0000000..8e887aa --- /dev/null +++ b/jxl/src/headers/modular.rs @@ -0,0 +1,151 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +use crate::{ + bit_reader::BitReader, + error::{Error, Result}, + frame::modular::Predictor, + headers::encodings::*, +}; +use jxl_macros::UnconditionalCoder; +use num_derive::FromPrimitive; + +use super::encodings; + +#[derive(UnconditionalCoder, Debug, PartialEq)] +pub struct WeightedHeader { + #[all_default] + all_default: bool, + + #[coder(Bits(5))] + #[default(16)] + pub p1c: u32, + + #[coder(Bits(5))] + #[default(10)] + pub p2c: u32, + + #[coder(Bits(5))] + #[default(7)] + pub p3ca: u32, + + #[coder(Bits(5))] + #[default(7)] + pub p3cb: u32, + + #[coder(Bits(5))] + #[default(7)] + pub p3cc: u32, + + #[coder(Bits(5))] + #[default(0)] + pub p3cd: u32, + + #[coder(Bits(5))] + #[default(0)] + pub p3ce: u32, + + #[coder(Bits(4))] + #[default(0xd)] + pub w0: u32, + + #[coder(Bits(4))] + #[default(0xc)] + pub w1: u32, + + #[coder(Bits(4))] + #[default(0xc)] + pub w2: u32, + + #[coder(Bits(4))] + #[default(0xc)] + pub w3: u32, +} + +#[derive(UnconditionalCoder, Debug, PartialEq, Clone, Copy)] +pub struct SqueezeParams { + pub horizontal: bool, + pub in_place: bool, + #[coder(u2S(Bits(3), Bits(6) + 8, Bits(10) + 72, Bits(13) + 1096))] + pub begin_channel: u32, + #[coder(u2S(1, 2, 3, Bits(4) + 4))] + pub num_channels: u32, +} + +#[derive(UnconditionalCoder, Copy, Clone, PartialEq, Debug, FromPrimitive)] +pub enum TransformId { + Rct = 0, + Palette = 1, + Squeeze = 2, + Invalid = 3, +} + +#[derive(UnconditionalCoder, Debug, PartialEq)] +#[validate] +pub struct Transform { + #[coder(Bits(2))] + pub id: TransformId, + + #[condition(id == TransformId::Rct || id == TransformId::Palette)] + #[coder(u2S(Bits(3), Bits(6) + 8, Bits(10) + 72, Bits(13) + 1096))] + #[default(0)] + pub begin_channel: u32, + + #[condition(id == TransformId::Rct)] + #[coder(u2S(6, Bits(2), Bits(4) + 2, Bits(6) + 10))] + #[default(6)] + pub rct_type: u32, + + #[condition(id == TransformId::Palette)] + #[coder(u2S(1, 3, 4, Bits(13) + 1))] + #[default(3)] + pub num_channels: u32, + + #[condition(id == TransformId::Palette)] + #[coder(u2S(Bits(8), Bits(10) + 256, Bits(12) + 1280, Bits(16)+5376))] + #[default(256)] + pub num_colors: u32, + + #[condition(id == TransformId::Palette)] + #[coder(u2S(0, Bits(8)+1, Bits(10) + 257, Bits(16)+1281))] + #[default(0)] + pub num_deltas: u32, + + #[condition(id == TransformId::Palette)] + #[coder(Bits(4))] + #[default(0)] + pub predictor_id: u32, + + #[condition(id == TransformId::Squeeze)] + #[size_coder(implicit(u2S(0, Bits(4) + 1, Bits(6) + 9, Bits(8) + 41)))] + #[default(Vec::new())] + pub squeezes: Vec, +} + +impl Transform { + fn check(&self, _: &encodings::Empty) -> Result<()> { + if self.id == TransformId::Invalid { + return Err(Error::InvalidTransformId); + } + + if self.rct_type >= 42 { + return Err(Error::InvalidRCT(self.rct_type)); + } + + if self.predictor_id >= Predictor::NUM_PREDICTORS { + return Err(Error::InvalidPredictor(self.predictor_id)); + } + + Ok(()) + } +} + +#[derive(UnconditionalCoder, Debug, PartialEq)] +pub struct GroupHeader { + pub use_global_tree: bool, + pub wp_header: WeightedHeader, + #[size_coder(implicit(u2S(0, 1, Bits(4) + 2, Bits(8) + 18)))] + pub transforms: Vec, +}