Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement header decoding for Modular groups #54

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions jxl/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ pub enum Error {
TreeMultiplierBitsTooLarge(u32, u32),
#[error("Modular tree splits on property {0} at value {1}, which is outside the possible range of [{2}, {3}]")]
TreeSplitOnEmptyRange(u8, i32, i32, i32),
#[error("Modular stream requested a global tree but there isn't one")]
NoGlobalTree,
#[error("Invalid transform id")]
InvalidTransformId,
#[error("Invalid RCT type {0}")]
InvalidRCT(u32),
}

/// Result alias used by this crate; the error type defaults to [`Error`] when
/// only the success type is spelled out.
pub type Result<T, E = Error> = std::result::Result<T, E>;
27 changes: 19 additions & 8 deletions jxl/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ use crate::{
},
util::tracing_wrappers::*,
};
use modular::Tree;
use modular::{FullModularImage, Tree};
use quantizer::LfQuantFactors;

mod modular;
pub mod modular;
mod quantizer;

#[derive(Debug, PartialEq, Eq)]
Expand All @@ -29,18 +29,17 @@ pub enum Section {
Hf(usize, usize), // group, pass
}

/// State that `Frame` keeps in its `lf_global` field.
///
/// NOTE(review): judging by the name this is filled in while decoding the
/// LfGlobal section — confirm against the caller that builds it.
#[allow(dead_code)]
pub struct LfGlobalState {
// TODO(veluca93): patches
// TODO(veluca93): splines
// TODO(veluca93): noise
// NOTE(review): presumably the LF quantization factors read from the
// bitstream — confirm against the `quantizer` module.
#[allow(dead_code)]
lf_quant: LfQuantFactors,
// TODO(veluca93), VarDCT: HF quant matrices
// TODO(veluca93), VarDCT: block context map
// TODO(veluca93), VarDCT: LF color correlation
// TODO(veluca93): Modular data
// Optional global modular tree. It is handed to `FullModularImage::read`,
// which returns `Error::NoGlobalTree` when the group header asks for a global
// tree and this is `None`.
#[allow(dead_code)]
tree: Option<Tree>,
// Value returned by `FullModularImage::read` for this frame.
modular_global: FullModularImage,
}

pub struct Frame {
Expand Down Expand Up @@ -121,9 +120,9 @@ impl Frame {
match section {
Section::LfGlobal => 0,
Section::Lf(a) => 1 + a,
Section::HfGlobal => self.header.num_dc_groups() + 1,
Section::HfGlobal => self.header.num_lf_groups() + 1,
Section::Hf(group, pass) => {
2 + self.header.num_dc_groups() + self.header.num_groups() * pass + group
2 + self.header.num_lf_groups() + self.header.num_groups() * pass + group
}
}
}
Expand Down Expand Up @@ -169,7 +168,19 @@ impl Frame {
None
};

self.lf_global = Some(LfGlobalState { lf_quant, tree });
let modular_global = FullModularImage::read(
&self.header,
self.modular_color_channels,
&self.extra_channel_info,
&tree,
br,
)?;

self.lf_global = Some(LfGlobalState {
lf_quant,
tree,
modular_global,
});

Ok(())
}
Expand Down
262 changes: 85 additions & 177 deletions jxl/src/frame/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,205 +3,113 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use std::fmt::Debug;

use crate::{
bit_reader::BitReader,
entropy_coding::decode::Histograms,
error::{Error, Result},
headers::{
extra_channels::ExtraChannelInfo, frame_header::FrameHeader, modular::GroupHeader,
JxlHeader,
},
image::Image,
util::tracing_wrappers::*,
util::CeilLog2,
};

/// Predictor stored in a leaf of the modular tree.
///
/// Each variant carries an explicit discriminant in the range 0..=13, and
/// the enum is laid out as a `u8`. The `TryFrom<u32>` implementation in this
/// file maps the same numbers back onto the variants.
// NOTE(review): the variant names presumably describe which neighbouring
// samples feed the prediction (left, top, ...) — confirm against the JPEG XL
// specification before relying on that reading.
#[repr(u8)]
#[derive(Debug)]
enum Predictor {
Zero = 0,
Left = 1,
Top = 2,
Average0 = 3,
Select = 4,
Gradient = 5,
Weighted = 6,
TopRight = 7,
TopLeft = 8,
LeftLeft = 9,
Average1 = 10,
Average2 = 11,
Average3 = 12,
Average4 = 13,
}
mod predict;
mod transforms;
mod tree;

impl TryFrom<u32> for Predictor {
type Error = Error;

fn try_from(value: u32) -> Result<Self> {
match value {
0 => Ok(Predictor::Zero),
1 => Ok(Predictor::Left),
2 => Ok(Predictor::Top),
3 => Ok(Predictor::Average0),
4 => Ok(Predictor::Select),
5 => Ok(Predictor::Gradient),
6 => Ok(Predictor::Weighted),
7 => Ok(Predictor::TopRight),
8 => Ok(Predictor::TopLeft),
9 => Ok(Predictor::LeftLeft),
10 => Ok(Predictor::Average1),
11 => Ok(Predictor::Average2),
12 => Ok(Predictor::Average3),
13 => Ok(Predictor::Average4),
_ => Err(Error::InvalidPredictor(value)),
}
}
}
pub use predict::Predictor;
use transforms::Transform;
pub use tree::Tree;

#[allow(dead_code)]
#[derive(Debug)]
enum TreeNode {
Split {
property: u8,
val: i32,
left: u32,
right: u32,
},
Leaf {
predictor: Predictor,
offset: i32,
multiplier: u32,
id: u32,
},
struct ChannelInfo {
size: (usize, usize),
shift: Option<(isize, isize)>, // None for meta-channels
}

#[allow(dead_code)]
#[derive(Debug)]
pub struct Tree {
nodes: Vec<TreeNode>,
histograms: Histograms,
struct MetaInfo {
channels: Vec<ChannelInfo>,
transforms: Vec<Transform>,
}

/// Value produced by `FullModularImage::read`: the channel and transform
/// description of the frame's modular image, plus its global channels.
pub struct FullModularImage {
// TODO: decoding graph for processing global transforms
// Channel descriptions together with the transforms taken from the group
// header (no-op transforms are filtered out by `read`).
meta_info: MetaInfo,
// Global channel images. `read` currently always stores an empty vector
// here; reading the channels is still a TODO there.
global_channels: Vec<Image<i32>>,
}

// Context indices used by `Tree::read`: each one selects the context passed to
// the tree reader for the corresponding kind of symbol (split value, property,
// predictor, offset, multiplier log, multiplier bits).
// `NUM_TREE_CONTEXTS` is the context count handed to `Histograms::decode` for
// the tree histograms.
const SPLIT_VAL_CONTEXT: usize = 0;
const PROPERTY_CONTEXT: usize = 1;
const PREDICTOR_CONTEXT: usize = 2;
const OFFSET_CONTEXT: usize = 3;
const MULTIPLIER_LOG_CONTEXT: usize = 4;
const MULTIPLIER_BITS_CONTEXT: usize = 5;
const NUM_TREE_CONTEXTS: usize = 6;
/// Compact debug rendering: prints the meta information and only the sizes of
/// the global channels, not their contents.
impl Debug for FullModularImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Collect the sizes up front so the format call stays on one line.
        let channel_sizes: Vec<_> = self.global_channels.iter().map(Image::size).collect();
        write!(
            f,
            "[info: {:?}, global channel sizes: {:?}]",
            self.meta_info, channel_sizes
        )
    }
}

impl Tree {
#[instrument(level = "debug", skip(br), ret, err)]
pub fn read(br: &mut BitReader, size_limit: usize) -> Result<Tree> {
assert!(size_limit <= u32::MAX as usize);
trace!(pos = br.total_bits_read());
let tree_histograms = Histograms::decode(NUM_TREE_CONTEXTS, br, true)?;
let mut tree_reader = tree_histograms.make_reader(br)?;
// TODO(veluca): consider early-exiting for trees known to be infinite.
let mut tree: Vec<TreeNode> = vec![];
let mut to_decode = 1;
let mut leaf_id = 0;
let mut max_property = 0;
while to_decode > 0 {
if tree.len() > size_limit {
return Err(Error::TreeTooLarge(tree.len(), size_limit));
}
if tree.len() >= tree.capacity() {
tree.try_reserve(tree.len() * 2 + 1)?;
}
to_decode -= 1;
let property = tree_reader.read(br, PROPERTY_CONTEXT)?;
trace!(property);
if let Some(property) = property.checked_sub(1) {
// inner node.
if property > 255 {
return Err(Error::InvalidProperty(property));
}
max_property = max_property.max(property);
let splitval = tree_reader.read_signed(br, SPLIT_VAL_CONTEXT)?;
let left_child = (tree.len() + to_decode + 1) as u32;
let node = TreeNode::Split {
property: property as u8,
val: splitval,
left: left_child,
right: left_child + 1,
};
trace!("split node {:?}", node);
to_decode += 2;
tree.push(node);
} else {
let predictor = Predictor::try_from(tree_reader.read(br, PREDICTOR_CONTEXT)?)?;
let offset = tree_reader.read_signed(br, OFFSET_CONTEXT)?;
let mul_log = tree_reader.read(br, MULTIPLIER_LOG_CONTEXT)?;
if mul_log >= 31 {
return Err(Error::TreeMultiplierTooLarge(mul_log, 31));
}
let mul_bits = tree_reader.read(br, MULTIPLIER_BITS_CONTEXT)?;
let multiplier = (mul_bits as u64 + 1) << mul_log;
if multiplier > (u32::MAX as u64) {
return Err(Error::TreeMultiplierBitsTooLarge(mul_bits, mul_log));
}
let node = TreeNode::Leaf {
predictor,
offset,
id: leaf_id,
multiplier: multiplier as u32,
};
leaf_id += 1;
trace!("leaf node {:?}", node);
tree.push(node);
}
impl FullModularImage {
#[instrument(level = "debug", skip_all, ret)]
pub fn read(
header: &FrameHeader,
modular_color_channels: usize,
extra_channel_info: &[ExtraChannelInfo],
global_tree: &Option<Tree>,
br: &mut BitReader,
) -> Result<Self> {
let mut channels = vec![];
for c in 0..modular_color_channels {
let shift = (header.hshift(c) as isize, header.vshift(c) as isize);
let size = (header.width as usize, header.height as usize);
channels.push(ChannelInfo {
size: (size.0.div_ceil(1 << shift.0), size.1.div_ceil(1 << shift.1)),
shift: Some(shift),
});
}
tree_reader.check_final_state()?;

let num_properties = max_property as usize + 1;
let mut property_ranges = vec![];
property_ranges.try_reserve(num_properties * tree.len())?;
property_ranges.resize(num_properties * tree.len(), (i32::MIN, i32::MAX));
let mut height = vec![];
height.try_reserve(tree.len())?;
height.resize(tree.len(), 0);
for i in 0..tree.len() {
const HEIGHT_LIMIT: usize = 2048;
if height[i] > HEIGHT_LIMIT {
return Err(Error::TreeTooLarge(height[i], HEIGHT_LIMIT));
}
if let TreeNode::Split {
property,
val,
left,
right,
} = tree[i]
{
height[left as usize] = height[i] + 1;
height[right as usize] = height[i] + 1;
for p in 0..num_properties {
if p == property as usize {
let (l, u) = property_ranges[i * num_properties + p];
if l > val || u <= val {
return Err(Error::TreeSplitOnEmptyRange(p as u8, val, l, u));
}
trace!("splitting at node {i} on property {p}, range [{l}, {u}] at position {val}");
property_ranges[left as usize * num_properties + p] = (val + 1, u);
property_ranges[right as usize * num_properties + p] = (l, val);
} else {
property_ranges[left as usize * num_properties + p] =
property_ranges[i * num_properties + p];
property_ranges[right as usize * num_properties + p] =
property_ranges[i * num_properties + p];
}
}
} else {
#[cfg(feature = "tracing")]
{
for p in 0..num_properties {
let (l, u) = property_ranges[i * num_properties + p];
trace!("final range at node {i} property {p}: [{l}, {u}]");
}
}
}
for info in extra_channel_info {
let shift = info.dim_shift() as isize - header.upsampling.ceil_log2() as isize;
let size = header.size_upsampled();
let size = (size.0 >> info.dim_shift(), size.1 >> info.dim_shift());
channels.push(ChannelInfo {
size,
shift: Some((shift, shift)),
});
}

let histograms = Histograms::decode((tree.len() + 1) / 2, br, true)?;
trace!("reading modular header");
let header = GroupHeader::read(br)?;

if header.use_global_tree && global_tree.is_none() {
return Err(Error::NoGlobalTree);
}

let meta_info = MetaInfo {
transforms: header
.transforms
.iter()
.map(|x| Transform::from_bitstream(x, 0, &channels))
.filter(|x| !x.is_noop())
.collect(),
channels,
};

// TODO(veluca93): meta-apply transforms

Ok(Tree {
nodes: tree,
histograms,
Ok(FullModularImage {
meta_info,
global_channels: vec![], // TODO(veluca93): read global channels
})
}
}
Loading