Skip to content

Commit

Permalink
Implement header decoding for Modular groups
Browse files Browse the repository at this point in the history
  • Loading branch information
veluca93 committed Nov 24, 2024
1 parent 7699b25 commit d3626a1
Show file tree
Hide file tree
Showing 9 changed files with 693 additions and 208 deletions.
6 changes: 6 additions & 0 deletions jxl/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ pub enum Error {
TreeMultiplierBitsTooLarge(u32, u32),
#[error("Modular tree splits on property {0} at value {1}, which is outside the possible range of [{2}, {3}]")]
TreeSplitOnEmptyRange(u8, i32, i32, i32),
#[error("Modular stream requested a global tree but there isn't one")]
NoGlobalTree,
#[error("Invalid transform id")]
InvalidTransformId,
#[error("Invalid RCT type {0}")]
InvalidRCT(u32),
}

pub type Result<T, E = Error> = std::result::Result<T, E>;
27 changes: 19 additions & 8 deletions jxl/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ use crate::{
},
util::tracing_wrappers::*,
};
use modular::Tree;
use modular::{FullModularImage, Tree};
use quantizer::LfQuantFactors;

mod modular;
pub mod modular;
mod quantizer;

#[derive(Debug, PartialEq, Eq)]
Expand All @@ -29,18 +29,17 @@ pub enum Section {
Hf(usize, usize), // group, pass
}

#[allow(dead_code)]
pub struct LfGlobalState {
// TODO(veluca93): patches
// TODO(veluca93): splines
// TODO(veluca93): noise
#[allow(dead_code)]
lf_quant: LfQuantFactors,
// TODO(veluca93), VarDCT: HF quant matrices
// TODO(veluca93), VarDCT: block context map
// TODO(veluca93), VarDCT: LF color correlation
// TODO(veluca93): Modular data
#[allow(dead_code)]
tree: Option<Tree>,
modular_global: FullModularImage,
}

pub struct Frame {
Expand Down Expand Up @@ -121,9 +120,9 @@ impl Frame {
match section {
Section::LfGlobal => 0,
Section::Lf(a) => 1 + a,
Section::HfGlobal => self.header.num_dc_groups() + 1,
Section::HfGlobal => self.header.num_lf_groups() + 1,
Section::Hf(group, pass) => {
2 + self.header.num_dc_groups() + self.header.num_groups() * pass + group
2 + self.header.num_lf_groups() + self.header.num_groups() * pass + group
}
}
}
Expand Down Expand Up @@ -169,7 +168,19 @@ impl Frame {
None
};

self.lf_global = Some(LfGlobalState { lf_quant, tree });
let modular_global = FullModularImage::read(
&self.header,
self.modular_color_channels,
&self.extra_channel_info,
&tree,
br,
)?;

self.lf_global = Some(LfGlobalState {
lf_quant,
tree,
modular_global,
});

Ok(())
}
Expand Down
262 changes: 85 additions & 177 deletions jxl/src/frame/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,205 +3,113 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use std::fmt::Debug;

use crate::{
bit_reader::BitReader,
entropy_coding::decode::Histograms,
error::{Error, Result},
headers::{
extra_channels::ExtraChannelInfo, frame_header::FrameHeader, modular::GroupHeader,
JxlHeader,
},
image::Image,
util::tracing_wrappers::*,
util::CeilLog2,
};

#[repr(u8)]
#[derive(Debug)]
enum Predictor {
Zero = 0,
Left = 1,
Top = 2,
Average0 = 3,
Select = 4,
Gradient = 5,
Weighted = 6,
TopRight = 7,
TopLeft = 8,
LeftLeft = 9,
Average1 = 10,
Average2 = 11,
Average3 = 12,
Average4 = 13,
}
mod predict;
mod transforms;
mod tree;

impl TryFrom<u32> for Predictor {
type Error = Error;

fn try_from(value: u32) -> Result<Self> {
match value {
0 => Ok(Predictor::Zero),
1 => Ok(Predictor::Left),
2 => Ok(Predictor::Top),
3 => Ok(Predictor::Average0),
4 => Ok(Predictor::Select),
5 => Ok(Predictor::Gradient),
6 => Ok(Predictor::Weighted),
7 => Ok(Predictor::TopRight),
8 => Ok(Predictor::TopLeft),
9 => Ok(Predictor::LeftLeft),
10 => Ok(Predictor::Average1),
11 => Ok(Predictor::Average2),
12 => Ok(Predictor::Average3),
13 => Ok(Predictor::Average4),
_ => Err(Error::InvalidPredictor(value)),
}
}
}
pub use predict::Predictor;
use transforms::Transform;
pub use tree::Tree;

#[allow(dead_code)]
#[derive(Debug)]
enum TreeNode {
Split {
property: u8,
val: i32,
left: u32,
right: u32,
},
Leaf {
predictor: Predictor,
offset: i32,
multiplier: u32,
id: u32,
},
struct ChannelInfo {
size: (usize, usize),
shift: Option<(isize, isize)>, // None for meta-channels
}

#[allow(dead_code)]
#[derive(Debug)]
pub struct Tree {
nodes: Vec<TreeNode>,
histograms: Histograms,
struct MetaInfo {
channels: Vec<ChannelInfo>,
transforms: Vec<Transform>,
}

pub struct FullModularImage {
// TODO: decoding graph for processing global transforms
meta_info: MetaInfo,
global_channels: Vec<Image<i32>>,
}

const SPLIT_VAL_CONTEXT: usize = 0;
const PROPERTY_CONTEXT: usize = 1;
const PREDICTOR_CONTEXT: usize = 2;
const OFFSET_CONTEXT: usize = 3;
const MULTIPLIER_LOG_CONTEXT: usize = 4;
const MULTIPLIER_BITS_CONTEXT: usize = 5;
const NUM_TREE_CONTEXTS: usize = 6;
impl Debug for FullModularImage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"[info: {:?}, global channel sizes: {:?}]",
self.meta_info,
self.global_channels
.iter()
.map(Image::size)
.collect::<Vec<_>>()
)
}
}

impl Tree {
#[instrument(level = "debug", skip(br), ret, err)]
pub fn read(br: &mut BitReader, size_limit: usize) -> Result<Tree> {
assert!(size_limit <= u32::MAX as usize);
trace!(pos = br.total_bits_read());
let tree_histograms = Histograms::decode(NUM_TREE_CONTEXTS, br, true)?;
let mut tree_reader = tree_histograms.make_reader(br)?;
// TODO(veluca): consider early-exiting for trees known to be infinite.
let mut tree: Vec<TreeNode> = vec![];
let mut to_decode = 1;
let mut leaf_id = 0;
let mut max_property = 0;
while to_decode > 0 {
if tree.len() > size_limit {
return Err(Error::TreeTooLarge(tree.len(), size_limit));
}
if tree.len() >= tree.capacity() {
tree.try_reserve(tree.len() * 2 + 1)?;
}
to_decode -= 1;
let property = tree_reader.read(br, PROPERTY_CONTEXT)?;
trace!(property);
if let Some(property) = property.checked_sub(1) {
// inner node.
if property > 255 {
return Err(Error::InvalidProperty(property));
}
max_property = max_property.max(property);
let splitval = tree_reader.read_signed(br, SPLIT_VAL_CONTEXT)?;
let left_child = (tree.len() + to_decode + 1) as u32;
let node = TreeNode::Split {
property: property as u8,
val: splitval,
left: left_child,
right: left_child + 1,
};
trace!("split node {:?}", node);
to_decode += 2;
tree.push(node);
} else {
let predictor = Predictor::try_from(tree_reader.read(br, PREDICTOR_CONTEXT)?)?;
let offset = tree_reader.read_signed(br, OFFSET_CONTEXT)?;
let mul_log = tree_reader.read(br, MULTIPLIER_LOG_CONTEXT)?;
if mul_log >= 31 {
return Err(Error::TreeMultiplierTooLarge(mul_log, 31));
}
let mul_bits = tree_reader.read(br, MULTIPLIER_BITS_CONTEXT)?;
let multiplier = (mul_bits as u64 + 1) << mul_log;
if multiplier > (u32::MAX as u64) {
return Err(Error::TreeMultiplierBitsTooLarge(mul_bits, mul_log));
}
let node = TreeNode::Leaf {
predictor,
offset,
id: leaf_id,
multiplier: multiplier as u32,
};
leaf_id += 1;
trace!("leaf node {:?}", node);
tree.push(node);
}
impl FullModularImage {
#[instrument(skip_all, ret)]
pub fn read(
header: &FrameHeader,
modular_color_channels: usize,
extra_channel_info: &[ExtraChannelInfo],
global_tree: &Option<Tree>,
br: &mut BitReader,
) -> Result<Self> {
let mut channels = vec![];
for c in 0..modular_color_channels {
let shift = (header.hshift(c) as isize, header.vshift(c) as isize);
let size = (header.width as usize, header.height as usize);
channels.push(ChannelInfo {
size: (size.0.div_ceil(1 << shift.0), size.1.div_ceil(1 << shift.1)),
shift: Some(shift),
});
}
tree_reader.check_final_state()?;

let num_properties = max_property as usize + 1;
let mut property_ranges = vec![];
property_ranges.try_reserve(num_properties * tree.len())?;
property_ranges.resize(num_properties * tree.len(), (i32::MIN, i32::MAX));
let mut height = vec![];
height.try_reserve(tree.len())?;
height.resize(tree.len(), 0);
for i in 0..tree.len() {
const HEIGHT_LIMIT: usize = 2048;
if height[i] > HEIGHT_LIMIT {
return Err(Error::TreeTooLarge(height[i], HEIGHT_LIMIT));
}
if let TreeNode::Split {
property,
val,
left,
right,
} = tree[i]
{
height[left as usize] = height[i] + 1;
height[right as usize] = height[i] + 1;
for p in 0..num_properties {
if p == property as usize {
let (l, u) = property_ranges[i * num_properties + p];
if l > val || u <= val {
return Err(Error::TreeSplitOnEmptyRange(p as u8, val, l, u));
}
trace!("splitting at node {i} on property {p}, range [{l}, {u}] at position {val}");
property_ranges[left as usize * num_properties + p] = (val + 1, u);
property_ranges[right as usize * num_properties + p] = (l, val);
} else {
property_ranges[left as usize * num_properties + p] =
property_ranges[i * num_properties + p];
property_ranges[right as usize * num_properties + p] =
property_ranges[i * num_properties + p];
}
}
} else {
#[cfg(feature = "tracing")]
{
for p in 0..num_properties {
let (l, u) = property_ranges[i * num_properties + p];
trace!("final range at node {i} property {p}: [{l}, {u}]");
}
}
}
for info in extra_channel_info {
let shift = info.dim_shift() as isize - header.upsampling.ceil_log2() as isize;
let size = header.size_upsampled();
let size = (size.0 >> info.dim_shift(), size.1 >> info.dim_shift());
channels.push(ChannelInfo {
size,
shift: Some((shift, shift)),
});
}

let histograms = Histograms::decode((tree.len() + 1) / 2, br, true)?;
trace!("reading modular header");
let header = GroupHeader::read(br)?;

if header.use_global_tree && global_tree.is_none() {
return Err(Error::NoGlobalTree);
}

let meta_info = MetaInfo {
transforms: header
.transforms
.iter()
.map(|x| Transform::from_bitstream(x, 0, &channels))
.filter(|x| !x.is_noop())
.collect(),
channels,
};

// TODO(veluca93): meta-apply transforms

Ok(Tree {
nodes: tree,
histograms,
Ok(FullModularImage {
meta_info,
global_channels: vec![], // TODO(veluca93): read global channels
})
}
}
Loading

0 comments on commit d3626a1

Please sign in to comment.