diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 32cdfa6..7e397ea 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -25,11 +25,20 @@ jobs: run: ./ci/pull_request_checks.sh - rust: - name: Test, Format and Clippy - runs-on: [ubuntu-latest] + checks: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + check: [format, clippy, test] + features: [all, default] + exclude: + # Remove the "format+all" combination, since it is the same as "format+default" + - check: format + features: all steps: - - uses: actions/checkout@v4 + - name: Checkout + uses: actions/checkout@v4 with: submodules: recursive @@ -38,20 +47,28 @@ jobs: with: components: clippy, rustfmt - - name: Build artefact caching + - name: Rust cache uses: Swatinem/rust-cache@v2.7.3 - - name: Format + # format + - name: Cargo fmt (check) + if: ${{ matrix.check == 'format' }} run: cargo fmt --all -- --check - - name: Clippy + # clippy + - name: Clippy with all features + if: ${{ matrix.check == 'clippy' && matrix.features == 'all' }} run: cargo clippy --release --all-targets --all-features --tests --all -- -D warnings - name: Clippy with default features + if: ${{ matrix.check == 'clippy' && matrix.features == 'default' }} run: cargo clippy --release --all-targets --tests --all -- -D warnings - - name: Run tests + # test + - name: Tests with all features + if: ${{ matrix.check == 'test' && matrix.features == 'all' }} run: cargo test --release --all --no-fail-fast --all-features - - name: Run tests with default features + - name: Tests with default features + if: ${{ matrix.check == 'test' && matrix.features == 'default' }} run: cargo test --release --all --no-fail-fast diff --git a/jxl/src/entropy_coding/huffman.rs b/jxl/src/entropy_coding/huffman.rs index cedf990..371fd65 100644 --- a/jxl/src/entropy_coding/huffman.rs +++ b/jxl/src/entropy_coding/huffman.rs @@ -8,8 +8,7 @@ use std::fmt::Debug; use crate::bit_reader::BitReader; use crate::entropy_coding::decode::*; use crate::error::{Error, Result}; -use crate::util::tracing_wrappers::*; -use crate::util::*; +use crate::util::{tracing_wrappers::*, CeilLog2, NewWithCapacity}; pub const HUFFMAN_MAX_BITS: usize = 15; const TABLE_BITS: usize = 8; @@ -104,7 +103,7 @@ impl Table { TABLE_SIZE ]), (2, _) => { - let mut ret = Vec::with_capacity(TABLE_SIZE); + let mut ret = Vec::new_with_capacity(TABLE_SIZE)?; for _ in 0..(TABLE_SIZE >> 1) { ret.push(TableEntry { bits: 1, @@ -118,7 +117,7 @@ impl Table { Ok(ret) } (3, _) => { - let mut ret = Vec::with_capacity(TABLE_SIZE); + let mut ret = Vec::new_with_capacity(TABLE_SIZE)?; for _ in 0..(TABLE_SIZE >> 2) { ret.push(TableEntry { bits: 1, @@ -140,7 +139,7 @@ impl Table { Ok(ret) } (4, false) => { - let mut ret = Vec::with_capacity(TABLE_SIZE); + let mut ret = Vec::new_with_capacity(TABLE_SIZE)?; for _ in 0..(TABLE_SIZE >> 2) { ret.push(TableEntry { bits: 2, @@ -162,7 +161,7 @@ impl Table { Ok(ret) } (4, true) => { - let mut ret = Vec::with_capacity(TABLE_SIZE); + let mut ret = Vec::new_with_capacity(TABLE_SIZE)?; symbols[2..4].sort_unstable(); for _ in 0..(TABLE_SIZE >> 3) { ret.push(TableEntry { diff --git a/jxl/src/entropy_coding/hybrid_uint.rs b/jxl/src/entropy_coding/hybrid_uint.rs index 9e645c7..2077110 100644 --- a/jxl/src/entropy_coding/hybrid_uint.rs +++ b/jxl/src/entropy_coding/hybrid_uint.rs @@ -6,7 +6,7 @@ use crate::bit_reader::BitReader; use crate::error::Error; -use crate::util::*; +use crate::util::CeilLog2; #[derive(Debug)] pub struct HybridUint { diff --git a/jxl/src/error.rs b/jxl/src/error.rs index 641e6f2..ef4b232 100644 --- a/jxl/src/error.rs +++ b/jxl/src/error.rs @@ -113,6 +113,24 @@ pub enum Error { InvalidPredictor(u32), #[error("Invalid modular mode property: {0}")] InvalidProperty(u32), + #[error("Invalid alpha channel for blending: {0}, limit is {1}")] + PatchesInvalidAlphaChannel(usize, usize), + #[error("Invalid patch blend mode: {0}, limit is {1}")] + PatchesInvalidBlendMode(u8, u8), + #[error("Invalid Patch: negative {0}-coordinate: {1} base {0}, {2} delta {0}")] + PatchesInvalidDelta(String, usize, i32), + #[error("Invalid position specified in reference frame in {0}-coordinate: {0}0 + {0}size = {1} + {2} > {3} = reference_frame {0}size")] + PatchesInvalidPosition(String, usize, usize, usize), + #[error("Patches invalid reference frame at index {0}")] + PatchesInvalidReference(usize), + #[error("Invalid Patch {0}: at {1} + {2} > {3}")] + PatchesOutOfBounds(String, usize, usize, usize), + #[error("Patches cannot use frames saved post color transforms")] + PatchesPostColorTransform(), + #[error("Too many {0}: {1}, limit is {2}")] + PatchesTooMany(String, usize, usize), + #[error("Reference too large: {0}, limit is {1}")] + PatchesRefTooLarge(usize, usize), #[error("Point list is empty")] PointListEmpty, #[error("Too large area for spline: {0}, limit is {1}")] @@ -122,7 +140,7 @@ pub enum Error { #[error("Too many splines: {0}, limit is {1}")] SplinesTooMany(u32, u32), #[error("Spline has adjacent coinciding control points: point[{0}]: {1:?}, point[{2}]: {3:?}")] - SplineAdjacentCoincidingControlPoints(u32, Point, u32, Point), + SplineAdjacentCoincidingControlPoints(usize, Point, usize, Point), #[error("Too many control points for splines: {0}, limit is {1}")] SplinesTooManyControlPoints(u32, u32), #[error( diff --git a/jxl/src/features/mod.rs b/jxl/src/features/mod.rs index 5494033..685ca5e 100644 --- a/jxl/src/features/mod.rs +++ b/jxl/src/features/mod.rs @@ -4,4 +4,5 @@ // license that can be found in the LICENSE file. pub mod noise; +pub mod patches; pub mod spline; diff --git a/jxl/src/features/patches.rs b/jxl/src/features/patches.rs new file mode 100644 index 0000000..f164d70 --- /dev/null +++ b/jxl/src/features/patches.rs @@ -0,0 +1,343 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO(firsching): remove once we use this! +#![allow(dead_code)] + +use num_derive::FromPrimitive; +use num_traits::FromPrimitive; + +use crate::{ + bit_reader::BitReader, + entropy_coding::decode::Histograms, + error::{Error, Result}, + frame::DecoderState, + util::{tracing_wrappers::*, NewWithCapacity}, +}; + +// Context numbers as specified in Section C.4.5, Listing C.2: +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[repr(usize)] +pub enum PatchContext { + NumRefPatch = 0, + ReferenceFrame = 1, + PatchSize = 2, + PatchReferencePosition = 3, + PatchPosition = 4, + PatchBlendMode = 5, + PatchOffset = 6, + PatchCount = 7, + PatchAlphaChannel = 8, + PatchClamp = 9, +} + +impl PatchContext { + const NUM: usize = 10; +} + +/// Blend modes +#[derive(Debug, PartialEq, Eq, Clone, Copy, FromPrimitive)] +#[repr(u8)] +pub enum PatchBlendMode { + // The new values are the old ones. Useful to skip some channels. + None = 0, + // The new values (in the crop) replace the old ones: sample = new + Replace = 1, + // The new values (in the crop) get added to the old ones: sample = old + new + Add = 2, + // The new values (in the crop) get multiplied by the old ones: + // sample = old * new + // This blend mode is only supported if BlendColorSpace is kEncoded. The + // range of the new value matters for multiplication purposes, and its + // nominal range of 0..1 is computed the same way as this is done for the + // alpha values in kBlend and kAlphaWeightedAdd. + Mul = 3, + // The new values (in the crop) replace the old ones if alpha>0: + // For first alpha channel: + // alpha = old + new * (1 - old) + // For other channels if !alpha_associated: + // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha + // For other channels if alpha_associated: + // sample = (1 - new_alpha) * old + new + // The alpha formula applies to the alpha used for the division in the other + // channels formula, and applies to the alpha channel itself if its + // blend_channel value matches itself. + // If using kBlendAbove, new is the patch and old is the original image; if + // using kBlendBelow, the meaning is inverted. + BlendAbove = 4, + BlendBelow = 5, + // The new values (in the crop) are added to the old ones if alpha>0: + // For first alpha channel: sample = sample = old + new * (1 - old) + // For other channels: sample = old + alpha * new + AlphaWeightedAddAbove = 6, + AlphaWeightedAddBelow = 7, +} + +impl PatchBlendMode { + pub const NUM_BLEND_MODES: u8 = 8; + + pub fn uses_alpha(self) -> bool { + matches!( + self, + Self::BlendAbove + | Self::BlendBelow + | Self::AlphaWeightedAddAbove + | Self::AlphaWeightedAddBelow + ) + } + + pub fn uses_clamp(self) -> bool { + self.uses_alpha() || self == Self::Mul + } +} + +#[derive(Debug, Clone, Copy)] +struct PatchBlending { + mode: PatchBlendMode, + alpha_channel: usize, + clamp: bool, +} + +#[derive(Debug, Clone, Copy)] +pub struct PatchReferencePosition { + // Not using `ref` like in the spec here, because it is a keyword. + reference: usize, + x0: usize, + y0: usize, + xsize: usize, + ysize: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct PatchPosition { + x: usize, + y: usize, + ref_pos_idx: usize, +} + +#[derive(Debug, Default)] +pub struct PatchesDictionary { + pub positions: Vec, + ref_positions: Vec, + blendings: Vec, + blendings_stride: usize, +} + +impl PatchesDictionary { + #[instrument(level = "debug", skip(br), ret, err)] + pub fn read( + br: &mut BitReader, + xsize: usize, + ysize: usize, + decoder_state: &DecoderState, + ) -> Result { + let num_extra_channels = decoder_state.extra_channel_info().len(); + let blendings_stride = num_extra_channels + 1; + let patches_histograms = Histograms::decode(PatchContext::NUM, br, true)?; + let mut patches_reader = patches_histograms.make_reader(br)?; + let num_ref_patch = patches_reader.read(br, PatchContext::NumRefPatch as usize)? as usize; + let num_pixels = xsize * ysize; + let max_ref_patches = 1024 + num_pixels / 4; + let max_patches = max_ref_patches * 4; + let max_blending_infos = max_patches * 4; + if num_ref_patch > max_ref_patches { + return Err(Error::PatchesTooMany( + "reference patches".to_string(), + num_ref_patch, + max_ref_patches, + )); + } + let mut total_patches = 0; + let mut next_size = 1; + let mut positions: Vec = Vec::new(); + let mut blendings = Vec::new(); + let mut ref_positions = Vec::new_with_capacity(num_ref_patch)?; + for _ in 0..num_ref_patch { + let reference = + patches_reader.read(br, PatchContext::ReferenceFrame as usize)? as usize; + if reference >= DecoderState::MAX_STORED_FRAMES { + return Err(Error::PatchesRefTooLarge( + reference, + DecoderState::MAX_STORED_FRAMES, + )); + } + + let x0 = + patches_reader.read(br, PatchContext::PatchReferencePosition as usize)? as usize; + let y0 = + patches_reader.read(br, PatchContext::PatchReferencePosition as usize)? as usize; + let ref_pos_xsize = + patches_reader.read(br, PatchContext::PatchSize as usize)? as usize + 1; + let ref_pos_ysize = + patches_reader.read(br, PatchContext::PatchSize as usize)? as usize + 1; + let reference_frame = decoder_state.reference_frame(reference); + // TODO(firsching): make sure this check is correct in the presence of downsampled extra channels (also in libjxl). + match reference_frame { + None => return Err(Error::PatchesInvalidReference(reference)), + Some(reference) => { + if !reference.saved_before_color_transform { + return Err(Error::PatchesPostColorTransform()); + } + if x0 + ref_pos_xsize > reference.frame[0].size.0 { + return Err(Error::PatchesInvalidPosition( + "x".to_string(), + x0, + ref_pos_xsize, + reference.frame[0].size.0, + )); + } + if y0 + ref_pos_ysize > reference.frame[0].size.1 { + return Err(Error::PatchesInvalidPosition( + "y".to_string(), + y0, + ref_pos_ysize, + reference.frame[0].size.1, + )); + } + } + } + + let id_count = patches_reader.read(br, PatchContext::PatchCount as usize)? as usize + 1; + if id_count > max_patches + 1 { + return Err(Error::PatchesTooMany( + "patches".to_string(), + id_count, + max_patches, + )); + } + total_patches += id_count; + + if total_patches > max_patches { + return Err(Error::PatchesTooMany( + "patches".to_string(), + total_patches, + max_patches, + )); + } + + if next_size < total_patches { + next_size *= 2; + next_size = std::cmp::min(next_size, max_patches); + } + if next_size * blendings_stride > max_blending_infos { + return Err(Error::PatchesTooMany( + "blending_info".to_string(), + total_patches, + max_patches, + )); + } + positions.try_reserve(next_size.saturating_sub(positions.len()))?; + blendings.try_reserve( + (next_size * PatchBlendMode::NUM_BLEND_MODES as usize) + .saturating_sub(blendings.len()), + )?; + + for i in 0..id_count { + let mut pos = PatchPosition { + x: 0, + y: 0, + ref_pos_idx: ref_positions.len(), + }; + if i == 0 { + // Read initial position + pos.x = patches_reader.read(br, PatchContext::PatchPosition as usize)? as usize; + pos.y = patches_reader.read(br, PatchContext::PatchPosition as usize)? as usize; + } else { + // Read offsets and calculate new position + let delta_x = + patches_reader.read_signed(br, PatchContext::PatchOffset as usize)?; + if delta_x < 0 && (-delta_x as usize) > positions.last().unwrap().x { + return Err(Error::PatchesInvalidDelta( + "x".to_string(), + positions.last().unwrap().x, + delta_x, + )); + } + pos.x = (positions.last().unwrap().x as i32 + delta_x) as usize; + + let delta_y = + patches_reader.read_signed(br, PatchContext::PatchOffset as usize)?; + if delta_y < 0 && (-delta_y as usize) > positions.last().unwrap().y { + return Err(Error::PatchesInvalidDelta( + "y".to_string(), + positions.last().unwrap().y, + delta_y, + )); + } + pos.y = (positions.last().unwrap().y as i32 + delta_y) as usize; + } + + if pos.x + ref_pos_xsize > xsize { + return Err(Error::PatchesOutOfBounds( + "x".to_string(), + pos.x, + ref_pos_xsize, + xsize, + )); + } + if pos.y + ref_pos_ysize > ysize { + return Err(Error::PatchesOutOfBounds( + "y".to_string(), + pos.y, + ref_pos_ysize, + ysize, + )); + } + + let mut alpha_channel = 0; + let mut clamp = false; + for _ in 0..blendings_stride { + let maybe_blend_mode = + patches_reader.read(br, PatchContext::PatchBlendMode as usize)? as u8; + let blend_mode = match PatchBlendMode::from_u8(maybe_blend_mode) { + None => { + return Err(Error::PatchesInvalidBlendMode( + maybe_blend_mode, + PatchBlendMode::NUM_BLEND_MODES, + )) + } + Some(blend_mode) => blend_mode, + }; + + if PatchBlendMode::uses_alpha(blend_mode) { + alpha_channel = patches_reader + .read(br, PatchContext::PatchAlphaChannel as usize)? + as usize; + if alpha_channel >= num_extra_channels { + return Err(Error::PatchesInvalidAlphaChannel( + alpha_channel, + num_extra_channels, + )); + } + } + + if PatchBlendMode::uses_clamp(blend_mode) { + clamp = patches_reader.read(br, PatchContext::PatchClamp as usize)? != 0; + } + blendings.push(PatchBlending { + mode: blend_mode, + alpha_channel, + clamp, + }); + } + positions.push(pos); + } + + ref_positions.push(PatchReferencePosition { + reference, + x0, + y0, + xsize: ref_pos_xsize, + ysize: ref_pos_ysize, + }) + } + Ok(PatchesDictionary { + positions, + blendings, + ref_positions, + blendings_stride, + }) + } +} diff --git a/jxl/src/features/spline.rs b/jxl/src/features/spline.rs index ec29afb..7b5bb68 100644 --- a/jxl/src/features/spline.rs +++ b/jxl/src/features/spline.rs @@ -5,13 +5,17 @@ // TODO(firsching): remove once we use this! #![allow(dead_code)] -use std::{f32::consts::FRAC_1_SQRT_2, iter, ops}; +use std::{ + f32::consts::{FRAC_1_SQRT_2, PI, SQRT_2}, + iter::{self, zip}, + ops, +}; use crate::{ bit_reader::BitReader, entropy_coding::decode::{unpack_signed, Histograms, Reader}, error::{Error, Result}, - util::{tracing_wrappers::*, CeilLog2}, + util::{tracing_wrappers::*, CeilLog2, NewWithCapacity}, }; const MAX_NUM_CONTROL_POINTS: u32 = 1 << 20; const MAX_NUM_CONTROL_POINTS_PER_PIXEL_RATIO: u32 = 2; @@ -25,6 +29,7 @@ const NUM_CONTROL_POINTS_CONTEXT: usize = 3; const CONTROL_POINTS_CONTEXT: usize = 4; const DCT_CONTEXT: usize = 5; const NUM_SPLINE_CONTEXTS: usize = 6; +const DESIRED_RENDERING_DISTANCE: f32 = 1.0; #[derive(Debug, Clone, Copy, Default)] pub struct Point { @@ -88,32 +93,34 @@ impl ops::Div for Point { } } +#[derive(Default)] pub struct Spline { control_points: Vec, // X, Y, B. - color_dct: [[f32; 32]; 3], + color_dct: [Dct32; 3], // Splines are drawn by normalized Gaussian splatting. This controls the // Gaussian's parameter along the spline. - sigma_dct: [f32; 32], + sigma_dct: Dct32, // The estimated area in pixels covered by the spline. estimated_area_reached: u64, } impl Spline { pub fn validate_adjacent_point_coincidence(&self) -> Result<()> { - if let Some(item) = self - .control_points - .iter() - .enumerate() - .find(|(index, point)| { - index + 1 < self.control_points.len() && self.control_points[index + 1] == **point - }) + if let Some(((index, p0), p1)) = zip( + self.control_points + .iter() + .take(self.control_points.len() - 1) + .enumerate(), + self.control_points.iter().skip(1), + ) + .find(|((_, p0), p1)| **p0 == **p1) { return Err(Error::SplineAdjacentCoincidingControlPoints( - item.0 as u32, - *item.1, - (item.0 + 1) as u32, - self.control_points[item.0 + 1], + index, + *p0, + index + 1, + *p1, )); } Ok(()) @@ -185,7 +192,7 @@ impl QuantizedSpline { max_control_points, )); } - let mut control_points = Vec::with_capacity(num_control_points as usize); + let mut control_points = Vec::new_with_capacity(num_control_points as usize)?; for _ in 0..num_control_points { let x = splines_reader.read_signed(br, CONTROL_POINTS_CONTEXT)? as i64; let y = splines_reader.read_signed(br, CONTROL_POINTS_CONTEXT)? as i64; @@ -230,10 +237,8 @@ impl QuantizedSpline { let area_limit = area_limit(image_size); let mut result = Spline { - control_points: Vec::with_capacity(self.control_points.len() + 1), - color_dct: [[0.0; 32]; 3], - sigma_dct: [0.0; 32], - estimated_area_reached: 0, + control_points: Vec::new_with_capacity(self.control_points.len() + 1)?, + ..Default::default() }; let px = starting_point.x.round(); @@ -279,14 +284,14 @@ impl QuantizedSpline { for (c, weight) in CHANNEL_WEIGHT.iter().enumerate().take(3) { for i in 0..32 { let inv_dct_factor = if i == 0 { FRAC_1_SQRT_2 } else { 1.0 }; - result.color_dct[c][i] = + result.color_dct[c].0[i] = self.color_dct[c][i] as f32 * inv_dct_factor * weight * inv_quant; } } for i in 0..32 { - result.color_dct[0][i] += y_to_x * result.color_dct[1][i]; - result.color_dct[2][i] += y_to_b * result.color_dct[1][i]; + result.color_dct[0].0[i] += y_to_x * result.color_dct[1].0[i]; + result.color_dct[2].0[i] += y_to_b * result.color_dct[1].0[i]; } let mut width_estimate = 0; @@ -310,7 +315,7 @@ impl QuantizedSpline { for i in 0..32 { let inv_dct_factor = if i == 0 { FRAC_1_SQRT_2 } else { 1.0 }; - result.sigma_dct[i] = + result.sigma_dct.0[i] = self.sigma_dct[i] as f32 * inv_dct_factor * CHANNEL_WEIGHT[3] * inv_quant; let weight_f = (inv_quant * self.sigma_dct[i] as f32).abs().ceil(); @@ -341,15 +346,15 @@ pub struct Splines { pub starting_points: Vec, segments: Vec, segment_indices: Vec, - segment_y_start: Vec, + segment_y_start: Vec, } fn draw_centripetal_catmull_rom_spline(points: &[Point]) -> Result> { if points.is_empty() { - return Ok([].to_vec()); + return Ok(vec![]); } if points.len() == 1 { - return Ok([points[0]].to_vec()); + return Ok(vec![points[0]]); } const NUM_POINTS: usize = 16; // Create a view of points with one prepended and one appended point. @@ -439,20 +444,120 @@ fn for_each_equally_spaced_point( f(points[points.len() - 1], accumulated_distance); } -const DESIRED_RENDERING_DISTANCE: f32 = 1.0; +#[derive(Default, Clone, Copy, Debug)] +struct Dct32([f32; 32]); + +impl Dct32 { + fn continuous_idct(&self, t: f32) -> f32 { + const MULTIPLIERS: [f32; 32] = [ + PI / 32.0 * 0.0, + PI / 32.0 * 1.0, + PI / 32.0 * 2.0, + PI / 32.0 * 3.0, + PI / 32.0 * 4.0, + PI / 32.0 * 5.0, + PI / 32.0 * 6.0, + PI / 32.0 * 7.0, + PI / 32.0 * 8.0, + PI / 32.0 * 9.0, + PI / 32.0 * 10.0, + PI / 32.0 * 11.0, + PI / 32.0 * 12.0, + PI / 32.0 * 13.0, + PI / 32.0 * 14.0, + PI / 32.0 * 15.0, + PI / 32.0 * 16.0, + PI / 32.0 * 17.0, + PI / 32.0 * 18.0, + PI / 32.0 * 19.0, + PI / 32.0 * 20.0, + PI / 32.0 * 21.0, + PI / 32.0 * 22.0, + PI / 32.0 * 23.0, + PI / 32.0 * 24.0, + PI / 32.0 * 25.0, + PI / 32.0 * 26.0, + PI / 32.0 * 27.0, + PI / 32.0 * 28.0, + PI / 32.0 * 29.0, + PI / 32.0 * 30.0, + PI / 32.0 * 31.0, + ]; + let tandhalf = t + 0.5; + zip(MULTIPLIERS.iter(), self.0.iter()) + .map(|(multiplier, coeff)| SQRT_2 * coeff * (multiplier * tandhalf).cos()) + .sum() + } +} impl Splines { + fn add_segment( + &mut self, + center: &Point, + intensity: f32, + color: [f32; 3], + sigma: f32, + segments_by_y: &mut Vec<(u64, usize)>, + ) { + if sigma.is_infinite() + || sigma == 0.0 + || (1.0 / sigma).is_infinite() + || intensity.is_infinite() + { + return; + } + // TODO(zond): Use 3 if not JXL_HIGH_PRECISION + const DISTANCE_EXP: f32 = 5.0; + let max_color = color + .iter() + .map(|chan| (chan * intensity).abs()) + .max_by(|a, b| a.total_cmp(b)) + .unwrap(); + let max_distance = + (-2.0 * sigma * sigma * (0.1f32.ln() * DISTANCE_EXP - max_color.ln())).sqrt(); + let segment = SplineSegment { + center_x: center.x, + center_y: center.y, + color, + inv_sigma: 1.0 / sigma, + sigma_over_4_times_intensity: 0.25 * sigma * intensity, + maximum_distance: max_distance, + }; + let y0 = (center.y - max_distance).round() as i64; + let y1 = (center.y + max_distance).round() as i64 + 1; + for y in 0.max(y0)..y1 { + segments_by_y.push((y as u64, self.segments.len())); + } + self.segments.push(segment); + } + + fn add_segments_from_points( + &mut self, + spline: &Spline, + points_to_draw: &[(Point, f32)], + length: f32, + desired_distance: f32, + segments_by_y: &mut Vec<(u64, usize)>, + ) { + let inv_length = 1.0 / length; + for (point_index, (point, multiplier)) in points_to_draw.iter().enumerate() { + let progress = (point_index as f32 * desired_distance * inv_length).min(1.0); + let mut color = [0.0; 3]; + for (index, coeffs) in spline.color_dct.iter().enumerate() { + color[index] = coeffs.continuous_idct((32.0 - 1.0) * progress); + } + let sigma = spline.sigma_dct.continuous_idct((32.0 - 1.0) * progress); + self.add_segment(point, *multiplier, color, sigma, segments_by_y); + } + } + fn has_any(&self) -> bool { !self.splines.is_empty() } // TODO(zond): Add color correlation as parameter. pub fn initialize_draw_cache(&mut self, image_xsize: u64, image_ysize: u64) -> Result<()> { - self.segments.clear(); - self.segment_indices.clear(); self.segment_y_start.clear(); - // let mut segments_by_y = Vec::new(); - // let mut intermediate_points = Vec::new(); let mut total_estimated_area_reached = 0u64; let mut splines = Vec::new(); // TODO(zond): Use color correlation here. @@ -487,6 +592,9 @@ impl Splines { ); } + let mut segments_by_y = Vec::new(); + + self.segments.clear(); for spline in splines { let mut points_to_draw = Vec::<(Point, f32)>::new(); let intermediate_points = draw_centripetal_catmull_rom_spline(&spline.control_points)?; @@ -500,9 +608,37 @@ impl Splines { if length <= 0.0 { continue; } + self.add_segments_from_points( + &spline, + &points_to_draw, + length, + DESIRED_RENDERING_DISTANCE, + &mut segments_by_y, + ); } - todo!("finish translating this function from C++"); + // TODO(from libjxl): Consider linear sorting here. + segments_by_y.sort_by_key(|segment| segment.0); + + self.segment_indices.clear(); + self.segment_indices.try_reserve(segments_by_y.len())?; + self.segment_indices.resize(segments_by_y.len(), 0); + + self.segment_y_start.clear(); + self.segment_y_start.try_reserve(image_ysize as usize + 1)?; + self.segment_y_start.resize(image_ysize as usize + 1, 0); + + for (i, segment) in segments_by_y.iter().enumerate() { + self.segment_indices[i] = segment.1; + let y = segment.0; + if y < image_ysize { + self.segment_y_start[y as usize + 1] += 1; + } + } + for y in 0..image_ysize { + self.segment_y_start[y as usize + 1] += self.segment_y_start[y as usize]; + } + Ok(()) } #[instrument(level = "debug", skip(br), ret, err)] @@ -575,11 +711,17 @@ impl Splines { #[cfg(test)] mod test_splines { - use crate::{error::Error, util::test::assert_all_almost_eq}; + use std::{f32::consts::SQRT_2, iter::zip}; + + use crate::{ + error::Error, + features::spline::SplineSegment, + util::test::{assert_all_almost_eq, assert_almost_eq}, + }; use super::{ - draw_centripetal_catmull_rom_spline, for_each_equally_spaced_point, Point, QuantizedSpline, - Spline, + draw_centripetal_catmull_rom_spline, for_each_equally_spaced_point, Dct32, Point, + QuantizedSpline, Spline, Splines, DESIRED_RENDERING_DISTANCE, }; #[test] @@ -627,31 +769,31 @@ mod test_splines { Point { x: 17.0, y: 277.0 }, ], color_dct: [ - [ + Dct32([ 36.3005, 39.6984, 23.2008, 67.4982, 4.4016, 71.5008, 62.2986, 32.298, 92.1984, 10.101, 10.7982, 9.198, 6.0984, 10.5, 79.0986, 7.0014, 24.5994, 90.7998, 5.502, 84.0, 43.8018, 49.0014, 33.4992, 78.9012, 54.4992, 77.9016, 62.1012, 51.3996, 36.4014, 14.301, 83.7018, 35.4018, - ], - [ + ]), + Dct32([ 9.38684, 53.4, 9.525, 74.925, 72.675, 26.7, 7.875, 0.9, 84.9, 23.175, 26.475, 31.125, 90.975, 11.7, 74.1, 39.3, 23.7, 82.5, 4.8, 2.7, 61.2, 96.375, 13.725, 66.675, 62.925, 82.425, 5.925, 98.7, 21.525, 7.875, 51.675, 63.075, - ], - [ + ]), + Dct32([ 47.9949, 39.33, 6.865, 26.275, 33.265, 6.19, 1.715, 98.9, 59.91, 59.575, 95.005, 61.295, 82.715, 53.0, 6.13, 30.41, 34.69, 96.92, 93.42, 16.98, 38.8, 80.765, 63.005, 18.585, 43.605, 32.305, 61.015, 20.23, 24.325, 28.315, 69.105, 62.375, - ], + ]), ], - sigma_dct: [ + sigma_dct: Dct32([ 32.7593, 21.6645, 44.3289, 1.6665, 45.6621, 90.6576, 29.3304, 59.3274, 23.6643, 85.3248, 84.6582, 27.3306, 41.9958, 83.9916, 50.6616, 17.6649, 93.6573, 4.9995, 2.6664, 69.6597, 94.9905, 51.9948, 24.3309, 18.6648, 11.9988, 95.6571, 28.6638, 81.3252, 89.991, 31.3302, 74.6592, 51.9948, - ], + ]), estimated_area_reached: 19843491681, }, ), @@ -700,31 +842,31 @@ mod test_splines { Point { x: 233.0, y: 267.0 }, ], color_dct: [ - [ + Dct32([ 15.0007, 28.9002, 21.9996, 6.5982, 41.7984, 83.0004, 8.6016, 56.8008, 68.901, 9.702, 5.4012, 19.7988, 70.7994, 90.0018, 52.5, 65.2008, 7.7994, 23.499, 26.4012, 72.198, 64.701, 87.0996, 1.302, 67.4982, 45.9984, 68.4012, 65.3982, 35.4984, 29.1018, 12.999, 41.601, 23.898, - ], - [ + ]), + Dct32([ 47.6767, 79.425, 62.7, 29.1, 96.825, 18.525, 17.625, 15.225, 80.475, 56.025, 96.225, 59.925, 26.7, 96.075, 92.325, 42.075, 35.775, 54.0, 23.175, 54.975, 75.975, 35.775, 58.425, 88.725, 2.4, 78.075, 95.625, 27.525, 6.6, 78.525, 24.075, 69.825, - ], - [ + ]), + Dct32([ 43.8159, 96.505, 0.889999, 95.11, 49.085, 71.165, 25.115, 33.565, 75.225, 95.015, 82.085, 19.675, 10.53, 44.905, 49.975, 93.315, 83.515, 99.5, 64.615, 53.995, 3.52501, 99.685, 45.265, 82.075, 22.42, 37.895, 59.995, 32.215, 12.62, 4.605, 65.515, 96.425, - ], + ]), ], - sigma_dct: [ + sigma_dct: Dct32([ 72.589, 2.6664, 41.6625, 2.3331, 39.6627, 78.9921, 69.6597, 19.998, 92.3241, 71.6595, 41.9958, 61.9938, 29.997, 49.3284, 70.3263, 45.3288, 62.6604, 47.3286, 46.662, 41.3292, 90.6576, 46.662, 91.3242, 54.9945, 7.9992, 69.6597, 25.3308, 84.6582, 61.6605, 27.6639, 3.6663, 46.9953, - ], + ]), estimated_area_reached: 25829781306, }, ), @@ -777,31 +919,31 @@ mod test_splines { Point { x: 390.0, y: 88.0 }, ], color_dct: [ - [ + Dct32([ 16.9014, 64.8018, 4.2, 10.6008, 23.499, 17.0016, 79.3002, 5.6994, 60.4002, 16.5984, 94.899, 63.7014, 87.5994, 10.5, 3.801, 61.1016, 22.8984, 81.9, 80.4006, 40.5006, 45.9018, 25.4016, 39.7992, 30.0006, 50.1984, 90.4008, 27.9006, 93.702, 65.1, 48.1992, 22.302, 43.8984, - ], - [ + ]), + Dct32([ 24.9255, 66.0, 3.525, 90.225, 97.125, 15.825, 35.625, 0.6, 68.025, 39.6, 24.375, 85.875, 57.675, 77.625, 47.475, 67.875, 4.275, 5.4, 91.2, 58.5, 0.075, 52.2, 3.525, 47.775, 63.225, 43.5, 85.8, 35.775, 50.175, 35.925, 19.2, 48.225, - ], - [ + ]), + Dct32([ 82.7881, 44.93, 76.395, 39.475, 94.115, 14.285, 89.805, 9.98, 10.485, 74.53, 56.295, 65.785, 7.765, 23.305, 52.795, 99.305, 56.775, 46.0, 76.71, 13.49, 66.995, 22.38, 29.915, 43.295, 70.295, 26.0, 74.32, 53.905, 62.005, 19.125, 49.3, 46.685, - ], + ]), ], - sigma_dct: [ + sigma_dct: Dct32([ 83.4303, 1.6665, 24.9975, 18.6648, 46.662, 75.3258, 27.9972, 62.3271, 50.3283, 23.331, 85.6581, 95.9904, 45.6621, 32.9967, 33.33, 52.9947, 26.3307, 58.6608, 19.6647, 69.993, 92.6574, 22.6644, 56.9943, 21.6645, 76.659, 87.6579, 22.9977, 66.3267, 35.6631, 35.6631, 56.661, 67.3266, - ], + ]), estimated_area_reached: 47263284396, }, ), @@ -846,12 +988,16 @@ mod test_splines { ); for index in 0..got_dequantized.color_dct.len() { assert_all_almost_eq!( - got_dequantized.color_dct[index], - want_dequantized.color_dct[index], + got_dequantized.color_dct[index].0, + want_dequantized.color_dct[index].0, 1e-4, ); } - assert_all_almost_eq!(got_dequantized.sigma_dct, want_dequantized.sigma_dct, 1e-4); + assert_all_almost_eq!( + got_dequantized.sigma_dct.0, + want_dequantized.sigma_dct.0, + 1e-4 + ); assert_eq!( got_dequantized.estimated_area_reached, want_dequantized.estimated_area_reached, @@ -953,4 +1099,323 @@ mod test_splines { ); Ok(()) } + + #[test] + fn dct32() -> Result<(), Error> { + let mut dct = Dct32::default(); + for (i, coeff) in dct.0.iter_mut().enumerate() { + *coeff = 0.05f32 * i as f32; + } + // Golden numbers come from libjxl. + let want_out = [ + 16.7353, -18.6042, 7.99317, -7.12508, 4.66999, -4.33676, 3.24505, -3.06945, 2.44468, + -2.33509, 1.92438, -1.8484, 1.55314, -1.49642, 1.27014, -1.22549, 1.04345, -1.00677, + 0.854484, -0.823243, 0.691654, -0.66428, 0.547331, -0.522654, 0.416109, -0.393396, + 0.294056, -0.272631, 0.178113, -0.157472, 0.0656886, -0.0454512, + ]; + for (t, want) in want_out.iter().enumerate() { + let got_out = dct.continuous_idct(t as f32); + assert_almost_eq!(got_out, *want, 1e-3); + } + Ok(()) + } + + fn verify_segment_almost_equal(seg1: &SplineSegment, seg2: &SplineSegment) { + assert_almost_eq!(seg1.center_x, seg2.center_x, 1e-3); + assert_almost_eq!(seg1.center_y, seg2.center_y, 1e-3); + for (got, want) in zip(seg1.color.iter(), seg2.color.iter()) { + assert_almost_eq!(*got, *want, 1e-2); + } + assert_almost_eq!(seg1.inv_sigma, seg2.inv_sigma, 1e-3); + assert_almost_eq!(seg1.maximum_distance, seg2.maximum_distance, 1e-1); + assert_almost_eq!( + seg1.sigma_over_4_times_intensity, + seg2.sigma_over_4_times_intensity, + 1e-3 + ); + } + + #[test] + fn spline_segments_add_segment() -> Result<(), Error> { + let mut splines = Splines::default(); + let mut segments_by_y = Vec::<(u64, usize)>::new(); + + splines.add_segment( + &Point { x: 10.0, y: 20.0 }, + 0.5, + [0.5, 0.6, 0.7], + 0.8, + &mut segments_by_y, + ); + // Golden numbers come from libjxl. + let want_segment = SplineSegment { + center_x: 10.0, + center_y: 20.0, + color: [0.5, 0.6, 0.7], + inv_sigma: 1.25, + maximum_distance: 3.65961, + sigma_over_4_times_intensity: 0.1, + }; + assert_eq!(splines.segments.len(), 1); + verify_segment_almost_equal(&splines.segments[0], &want_segment); + let want_segments_by_y = [ + (16, 0), + (17, 0), + (18, 0), + (19, 0), + (20, 0), + (21, 0), + (22, 0), + (23, 0), + (24, 0), + ]; + for (got, want) in zip(segments_by_y.iter(), want_segments_by_y.iter()) { + assert_eq!(got.0, want.0); + assert_eq!(got.1, want.1); + } + Ok(()) + } + + #[test] + fn spline_segments_add_segments_from_points() -> Result<(), Error> { + let mut splines = Splines::default(); + let mut segments_by_y = Vec::<(u64, usize)>::new(); + let mut color_dct = [Dct32::default(); 3]; + for (channel_index, channel_dct) in color_dct.iter_mut().enumerate() { + for (coeff_index, coeff) in channel_dct.0.iter_mut().enumerate() { + *coeff = 0.1 * channel_index as f32 + 0.05 * coeff_index as f32; + } + } + let mut sigma_dct = Dct32::default(); + for (coeff_index, coeff) in sigma_dct.0.iter_mut().enumerate() { + *coeff = 0.06 * coeff_index as f32; + } + let spline = Spline { + control_points: vec![], + color_dct, + sigma_dct, + estimated_area_reached: 0, + }; + let points_to_draw = vec![ + (Point { x: 10.0, y: 20.0 }, 1.0), + (Point { x: 11.0, y: 21.0 }, 1.0), + (Point { x: 12.0, y: 21.0 }, 1.0), + ]; + splines.add_segments_from_points( + &spline, + &points_to_draw, + SQRT_2 + 1.0, + DESIRED_RENDERING_DISTANCE, + &mut segments_by_y, + ); + // Golden numbers come from libjxl. + let want_segments = [ + SplineSegment { + center_x: 10.0, + center_y: 20.0, + color: [16.7353, 19.6865, 22.6376], + inv_sigma: 0.0497949, + maximum_distance: 108.64, + sigma_over_4_times_intensity: 5.02059, + }, + SplineSegment { + center_x: 11.0, + center_y: 21.0, + color: [-0.819923, -0.79605, -0.772177], + inv_sigma: -1.01636, + maximum_distance: 4.68042, + sigma_over_4_times_intensity: -0.245977, + }, + SplineSegment { + center_x: 12.0, + center_y: 21.0, + color: [-0.776775, -0.754424, -0.732072], + inv_sigma: -1.07281, + maximum_distance: 4.42351, + sigma_over_4_times_intensity: -0.233033, + }, + ]; + assert_eq!(splines.segments.len(), want_segments.len()); + for (got, want) in zip(splines.segments.iter(), want_segments.iter()) { + verify_segment_almost_equal(got, want); + } + let want_segments_by_y: Vec<(u64, usize)> = (0..=129) + .map(|c| (c, 0)) + .chain((16..=26).map(|c| (c, 1))) + .chain((17..=25).map(|c| (c, 2))) + .collect(); + for (got, want) in zip(segments_by_y.iter(), want_segments_by_y.iter()) { + assert_eq!(got.0, want.0); + assert_eq!(got.1, want.1); + } + Ok(()) + } + + #[test] + fn init_draw_cache() -> Result<(), Error> { + let mut splines = Splines { + splines: vec![ + QuantizedSpline { + control_points: vec![ + (109, 105), + (-247, -261), + (168, 427), + (-46, -360), + (-61, 181), + ], + color_dct: [ + [ + 12223, 9452, 5524, 16071, 1048, 17024, 14833, 7690, 21952, 2405, 2571, + 2190, 1452, 2500, 18833, 1667, 5857, 21619, 1310, 20000, 10429, 11667, + 7976, 18786, 12976, 18548, 14786, 12238, 8667, 3405, 19929, 8429, + ], + [ + 177, 712, 127, 999, 969, 356, 105, 12, 1132, 309, 353, 415, 1213, 156, + 988, 524, 316, 1100, 64, 36, 816, 1285, 183, 889, 839, 1099, 79, 1316, + 287, 105, 689, 841, + ], + [ + 780, -201, -38, -695, -563, -293, -88, 1400, -357, 520, 979, 431, -118, + 590, -971, -127, 157, 206, 1266, 204, -320, -223, 704, -687, -276, + -716, 787, -1121, 40, 292, 249, -10, + ], + ], + sigma_dct: [ + 139, 65, 133, 5, 137, 272, 88, 178, 71, 256, 254, 82, 126, 252, 152, 53, + 281, 15, 8, 209, 285, 156, 73, 56, 36, 287, 86, 244, 270, 94, 224, 156, + ], + }, + QuantizedSpline { + control_points: vec![ + (24, -32), + (-178, -7), + (226, 151), + (121, -172), + (-184, 39), + (-201, -182), + (301, 404), + ], + color_dct: [ + [ + 5051, 6881, 5238, 1571, 9952, 19762, 2048, 13524, 16405, 2310, 1286, + 4714, 16857, 21429, 12500, 15524, 1857, 5595, 6286, 17190, 15405, + 20738, 310, 16071, 10952, 16286, 15571, 8452, 6929, 3095, 9905, 5690, + ], + [ + 899, 1059, 836, 388, 1291, 247, 235, 203, 1073, 747, 1283, 799, 356, + 1281, 1231, 561, 477, 720, 309, 733, 1013, 477, 779, 1183, 32, 1041, + 1275, 367, 88, 1047, 321, 931, + ], + [ + -78, 244, -883, 943, -682, 752, 107, 262, -75, 557, -202, -575, -231, + -731, -605, 732, 682, 650, 592, -14, -1035, 913, -188, -95, 286, -574, + -509, 67, 86, -1056, 592, 380, + ], + ], + sigma_dct: [ + 308, 8, 125, 7, 119, 237, 209, 60, 277, 215, 126, 186, 90, 148, 211, 136, + 188, 142, 140, 124, 272, 140, 274, 165, 24, 209, 76, 254, 185, 83, 11, 141, + ], + }, + ], + starting_points: vec![Point { x: 10.0, y: 20.0 }, Point { x: 5.0, y: 40.0 }], + ..Default::default() + }; + splines.initialize_draw_cache(1 << 15, 1 << 15)?; + assert_eq!(splines.segments.len(), 1940); + let want_segments_sample = [ + ( + 22, + SplineSegment { + center_x: 25.7765, + center_y: 35.333, + color: [-524.997, -509.905, 43.3884], + inv_sigma: -0.00197347, + maximum_distance: 3021.38, + sigma_over_4_times_intensity: -126.68, + }, + ), + ( + 474, + SplineSegment { + center_x: -16.456, + center_y: 78.8185, + color: [-117.671, -133.552, 343.563], + inv_sigma: -0.00263185, + maximum_distance: 2238.38, + sigma_over_4_times_intensity: -94.9904, + }, + ), + ( + 835, + SplineSegment { + center_x: -71.937, + center_y: 230.064, + color: [44.7951, 298.941, -395.357], + inv_sigma: 0.0186913, + maximum_distance: 316.45, + sigma_over_4_times_intensity: 13.3752, + }, + ), + ( + 1066, + SplineSegment { + center_x: -126.259, + center_y: -22.9786, + color: [-136.42, 194.757, -98.1878], + inv_sigma: 0.00753185, + maximum_distance: 769.254, + sigma_over_4_times_intensity: 33.1924, + }, + ), + ( + 1328, + SplineSegment { + center_x: 73.7087, + center_y: 56.3141, + color: [-13.4439, 162.614, 93.7842], + inv_sigma: 0.00366418, + maximum_distance: 1572.71, + sigma_over_4_times_intensity: 68.2281, + }, + ), + ( + 1545, + SplineSegment { + center_x: 77.4889, + center_y: -92.3388, + color: [-220.681, 66.1304, -32.2618], + inv_sigma: 0.0316616, + maximum_distance: 183.675, + sigma_over_4_times_intensity: 7.89601, + }, + ), + ( + 1774, + SplineSegment { + center_x: -16.4359, + center_y: -144.863, + color: [57.3154, -46.3684, 92.1495], + inv_sigma: -0.0152451, + maximum_distance: 371.483, + sigma_over_4_times_intensity: -16.3988, + }, + ), + ( + 1929, + SplineSegment { + center_x: 61.1934, + center_y: -10.7072, + color: [-69.7881, 300.608, -476.514], + inv_sigma: 0.00322928, + maximum_distance: 1841.38, + sigma_over_4_times_intensity: 77.4166, + }, + ), + ]; + for (index, segment) in want_segments_sample { + verify_segment_almost_equal(&segment, &splines.segments[index]); + } + Ok(()) + } } diff --git a/jxl/src/frame.rs b/jxl/src/frame.rs index 4bdb3a0..8e50463 100644 --- a/jxl/src/frame.rs +++ b/jxl/src/frame.rs @@ -6,7 +6,7 @@ use crate::{ bit_reader::BitReader, error::Result, - features::{noise::Noise, spline::Splines}, + features::{noise::Noise, patches::PatchesDictionary, spline::Splines}, headers::{ color_encoding::ColorSpace, encodings::UnconditionalCoder, @@ -14,6 +14,7 @@ use crate::{ frame_header::{Encoding, FrameHeader, Toc, TocNonserialized}, FileHeader, }, + image::Image, util::tracing_wrappers::*, }; use modular::{FullModularImage, Tree}; @@ -32,8 +33,7 @@ pub enum Section { #[allow(dead_code)] pub struct LfGlobalState { - // TODO(veluca93): patches - // TODO(veluca93): splines + patches: Option, splines: Option, noise: Option, lf_quant: LfQuantFactors, @@ -44,19 +44,72 @@ pub struct LfGlobalState { modular_global: FullModularImage, } +#[derive(Debug)] +pub struct ReferenceFrame { + pub frame: Vec>, + pub saved_before_color_transform: bool, +} + +impl ReferenceFrame { + // TODO(firsching): make this #[cfg(test)] + fn blank( + width: usize, + height: usize, + num_channels: usize, + saved_before_color_transform: bool, + ) -> Result { + let frame = (0..num_channels) + .map(|_| Image::new_constant((width, height), 0.0)) + .collect::>()?; + Ok(Self { + frame, + saved_before_color_transform, + }) + } +} + +#[derive(Debug)] +pub struct DecoderState { + file_header: FileHeader, + reference_frames: [Option; Self::MAX_STORED_FRAMES], +} + +impl DecoderState { + pub const MAX_STORED_FRAMES: usize = 4; + + pub fn new(file_header: FileHeader) -> Self { + Self { + file_header, + reference_frames: [None, None, None, None], + } + } + + pub fn extra_channel_info(&self) -> &Vec { + &self.file_header.image_metadata.extra_channel_info + } + + pub fn reference_frame(&self, i: usize) -> Option<&ReferenceFrame> { + assert!(i < Self::MAX_STORED_FRAMES); + self.reference_frames[i].as_ref() + } +} + pub struct Frame { header: FrameHeader, toc: Toc, modular_color_channels: usize, - extra_channel_info: Vec, lf_global: Option, + decoder_state: DecoderState, } impl Frame { - pub fn new(br: &mut BitReader, file_header: &FileHeader) -> Result { - let frame_header = - FrameHeader::read_unconditional(&(), br, &file_header.frame_header_nonserialized()) - .unwrap(); + pub fn new(br: &mut BitReader, decoder_state: DecoderState) -> Result { + let frame_header = FrameHeader::read_unconditional( + &(), + br, + &decoder_state.file_header.frame_header_nonserialized(), + ) + .unwrap(); let num_toc_entries = frame_header.num_toc_entries(); let toc = Toc::read_unconditional( &(), @@ -69,7 +122,13 @@ impl Frame { br.jump_to_byte_boundary()?; let modular_color_channels = if frame_header.encoding == Encoding::VarDCT { 0 - } else if file_header.image_metadata.color_encoding.color_space == ColorSpace::Gray { + } else if decoder_state + .file_header + .image_metadata + .color_encoding + .color_space + == ColorSpace::Gray + { 1 } else { 3 @@ -77,9 +136,9 @@ impl Frame { Ok(Self { header: frame_header, modular_color_channels, - extra_channel_info: file_header.image_metadata.extra_channel_info.clone(), toc, lf_global: None, + decoder_state, }) } @@ -91,10 +150,6 @@ impl Frame { self.toc.entries.iter().map(|x| *x as usize).sum() } - pub fn is_last(&self) -> bool { - self.header.is_last - } - /// Given a bit reader pointing at the end of the TOC, returns a vector of `BitReader`s, each /// of which reads a specific section. pub fn sections<'a>(&self, br: &'a mut BitReader) -> Result>> { @@ -135,11 +190,20 @@ impl Frame { assert!(self.lf_global.is_none()); trace!(pos = br.total_bits_read()); - if self.header.has_patches() { + let patches = if self.header.has_patches() { info!("decoding patches"); - todo!("patches not implemented"); - } + Some(PatchesDictionary::read( + br, + self.header.width as usize, + self.header.height as usize, + &self.decoder_state, + )?) + } else { + None + }; + let splines = if self.header.has_splines() { + info!("decoding splines"); Some(Splines::read(br, self.header.width * self.header.height)?) } else { None @@ -164,7 +228,8 @@ impl Frame { let size_limit = (1024 + self.header.width as usize * self.header.height as usize - * (self.modular_color_channels + self.extra_channel_info.len()) + * (self.modular_color_channels + + self.decoder_state.extra_channel_info().len()) / 16) .min(1 << 22); Some(Tree::read(br, size_limit)?) @@ -175,12 +240,13 @@ impl Frame { let modular_global = FullModularImage::read( &self.header, self.modular_color_channels, - &self.extra_channel_info, + self.decoder_state.extra_channel_info(), &tree, br, )?; self.lf_global = Some(LfGlobalState { + patches, splines, noise, lf_quant, @@ -190,6 +256,26 @@ impl Frame { Ok(()) } + + pub fn finalize(mut self) -> Result> { + if self.header.can_be_referenced { + // TODO(firsching): actually use real reference images here, instead of setting it + // to a blank image here, once we can decode images. + self.decoder_state.reference_frames[self.header.save_as_reference as usize] = + Some(ReferenceFrame::blank( + self.header.width as usize, + self.header.height as usize, + // Set num_channels to "3 + self.decoder_state.extra_channel_info().len()" here unconditionally for now. + 3 + self.decoder_state.extra_channel_info().len(), + self.header.save_before_ct, + )?); + } + Ok(if self.header.is_last { + None + } else { + Some(self.decoder_state) + }) + } } #[cfg(test)] @@ -207,28 +293,34 @@ mod test { util::test::assert_almost_eq, }; - use super::{Frame, Section}; + use super::{DecoderState, Frame, Section}; - fn read_frames(image: &[u8]) -> Result, Error> { + fn read_frames( + image: &[u8], + mut callback: impl FnMut(Frame) -> Result, Error>, + ) -> Result<(), Error> { let codestream = ContainerParser::collect_codestream(image).unwrap(); let mut br = BitReader::new(&codestream); let file_header = FileHeader::read(&mut br).unwrap(); - let mut frames = vec![]; + let mut decoder_state = DecoderState::new(file_header); loop { - let mut frame = Frame::new(&mut br, &file_header)?; - let is_last = frame.is_last(); + let mut frame = Frame::new(&mut br, decoder_state)?; let mut sections = frame.sections(&mut br)?; frame.decode_lf_global(&mut sections[frame.get_section_idx(Section::LfGlobal)])?; - frames.push(frame); - if is_last { + + // Call the callback with the frame + if let Some(state) = callback(frame)? { + decoder_state = state; + } else { break; } } - Ok(frames) + Ok(()) } + fn read_frames_from_path(path: &Path) -> Result<(), Error> { let data = std::fs::read(path).unwrap(); - let result = panic::catch_unwind(|| read_frames(data.as_slice())); + let result = panic::catch_unwind(|| read_frames(data.as_slice(), |frame| frame.finalize())); match result { Ok(Ok(_frame)) => {} @@ -240,8 +332,6 @@ mod test { if let Some(msg) = e.downcast_ref::<&str>() { if msg.contains("VarDCT not implemented") { println!("Skipping {}: VarDCT not implemented", path.display()); - } else if msg.contains("patches not implemented") { - println!("Skipping {}: patches not implented", path.display()); } else { panic::resume_unwind(e); } @@ -258,7 +348,11 @@ mod test { #[test] fn splines() -> Result<(), Error> { - let frames = read_frames(include_bytes!("../resources/test/splines.jxl"))?; + let mut frames = Vec::new(); + read_frames(include_bytes!("../resources/test/splines.jxl"), |frame| { + frames.push(frame); + Ok(None) + })?; assert_eq!(frames.len(), 1); let frame = &frames[0]; let lf_global = frame.lf_global.as_ref().unwrap(); @@ -313,7 +407,12 @@ mod test { #[test] fn noise() -> Result<(), Error> { - let frames = read_frames(include_bytes!("../resources/test/8x8_noise.jxl"))?; + let mut frames = Vec::new(); + read_frames(include_bytes!("../resources/test/8x8_noise.jxl"), |frame| { + frames.push(frame); + Ok(None) + })?; + assert_eq!(frames.len(), 1); let frame = &frames[0]; let lf_global = frame.lf_global.as_ref().unwrap(); @@ -326,4 +425,18 @@ mod test { } Ok(()) } + + #[test] + fn patches() -> Result<(), Error> { + let mut frames = Vec::new(); + read_frames( + include_bytes!("../resources/test/grayscale_patches_modular.jxl"), + |frame| { + frames.push(frame); + Ok(None) + }, + )?; + // TODO(firsching) add test for patches + Ok(()) + } } diff --git a/jxl/src/frame/modular/tree.rs b/jxl/src/frame/modular/tree.rs index 3bb4247..7cf7f81 100644 --- a/jxl/src/frame/modular/tree.rs +++ b/jxl/src/frame/modular/tree.rs @@ -10,7 +10,7 @@ use crate::{ bit_reader::BitReader, entropy_coding::decode::Histograms, error::{Error, Result}, - util::tracing_wrappers::*, + util::{tracing_wrappers::*, NewWithCapacity}, }; #[allow(dead_code)] @@ -115,11 +115,9 @@ impl Tree { tree_reader.check_final_state()?; let num_properties = max_property as usize + 1; - let mut property_ranges = vec![]; - property_ranges.try_reserve(num_properties * tree.len())?; + let mut property_ranges = Vec::new_with_capacity(num_properties * tree.len())?; property_ranges.resize(num_properties * tree.len(), (i32::MIN, i32::MAX)); - let mut height = vec![]; - height.try_reserve(tree.len())?; + let mut height = Vec::new_with_capacity(tree.len())?; height.resize(tree.len(), 0); for i in 0..tree.len() { const HEIGHT_LIMIT: usize = 2048; diff --git a/jxl/src/headers/frame_header.rs b/jxl/src/headers/frame_header.rs index 4a333bc..b2da142 100644 --- a/jxl/src/headers/frame_header.rs +++ b/jxl/src/headers/frame_header.rs @@ -408,13 +408,13 @@ pub struct FrameHeader { #[coder(Bits(2))] #[default(0)] #[condition(frame_type != FrameType::LFFrame && !is_last)] - save_as_reference: u32, + pub save_as_reference: u32, // The following 3 fields are not actually serialized, but just used as variables to help with // defining later conditions. #[default(!is_last && frame_type != FrameType::LFFrame && (duration == 0 || save_as_reference != 0))] #[condition(false)] - can_be_referenced: bool, + pub can_be_referenced: bool, #[default(!have_crop || frame_width >= nonserialized.img_width && frame_height >= nonserialized.img_height && x0 == 0 && y0 == 0)] #[condition(false)] @@ -427,7 +427,7 @@ pub struct FrameHeader { #[default(frame_type == FrameType::LFFrame)] #[condition(frame_type == FrameType::ReferenceOnly || save_before_ct_def_false)] - save_before_ct: bool, + pub save_before_ct: bool, name: String, diff --git a/jxl/src/headers/permutation.rs b/jxl/src/headers/permutation.rs index d12b584..9449667 100644 --- a/jxl/src/headers/permutation.rs +++ b/jxl/src/headers/permutation.rs @@ -6,7 +6,7 @@ use crate::bit_reader::BitReader; use crate::entropy_coding::decode::Reader; use crate::error::{Error, Result}; -use crate::util::{tracing_wrappers::instrument, value_of_lowest_1_bit, CeilLog2}; +use crate::util::{tracing_wrappers::instrument, value_of_lowest_1_bit, CeilLog2, NewWithCapacity}; #[derive(Debug, PartialEq, Default)] pub struct Permutation(pub Vec); @@ -41,8 +41,7 @@ impl Permutation { return Err(Error::InvalidPermutationSize { size, skip, end }); } - let mut lehmer = Vec::new(); - lehmer.try_reserve(end as usize)?; + let mut lehmer = Vec::new_with_capacity(end as usize)?; let mut prev_val = 0u32; for idx in skip..(skip + end) { @@ -59,8 +58,7 @@ impl Permutation { } // Initialize the full permutation vector with skipped elements intact - let mut permutation: Vec = Vec::new(); - permutation.try_reserve((size - skip) as usize)?; + let mut permutation = Vec::new_with_capacity((size - skip) as usize)?; permutation.extend(0..size); // Decode the Lehmer code into the slice starting at `skip` @@ -88,15 +86,13 @@ fn decode_lehmer_code(code: &[u32], permutation_slice: &[u32]) -> Result Result Resu let header_size = output_size.min(ICC_HEADER_SIZE); let header_data = data_stream.read_to_vec_exact(header_size as usize)?; - let mut profile = Vec::new(); - profile.try_reserve(output_size as usize)?; + let mut profile = Vec::new_with_capacity(output_size as usize)?; for (idx, &e) in header_data.iter().enumerate() { let p = predict_header(idx, output_size as u32, &header_data); diff --git a/jxl/src/icc/stream.rs b/jxl/src/icc/stream.rs index fd8bf4a..ff6453b 100644 --- a/jxl/src/icc/stream.rs +++ b/jxl/src/icc/stream.rs @@ -11,6 +11,7 @@ use crate::bit_reader::*; use crate::entropy_coding::decode::{Histograms, Reader}; use crate::error::{Error, Result}; use crate::util::tracing_wrappers::{instrument, warn}; +use crate::util::NewWithCapacity; fn read_varint(mut read_one: impl FnMut() -> Result) -> Result { let mut value = 0u64; @@ -129,8 +130,7 @@ impl<'br, 'buf, 'hist> IccStream<'br, 'buf, 'hist> { return Err(Error::IccEndOfStream); } - let mut out = Vec::new(); - out.try_reserve(len)?; + let mut out = Vec::new_with_capacity(len)?; for _ in 0..len { out.push(self.read_one()?); diff --git a/jxl/src/icc/tag.rs b/jxl/src/icc/tag.rs index c69bc6d..2c9b232 100644 --- a/jxl/src/icc/tag.rs +++ b/jxl/src/icc/tag.rs @@ -9,6 +9,7 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use crate::error::{Error, Result}; use crate::util::tracing_wrappers::warn; +use crate::util::NewWithCapacity; use super::{read_varint_from_reader, IccStream, ICC_HEADER_SIZE}; @@ -96,8 +97,7 @@ pub(super) fn read_tag_list( fn shuffle_w2(bytes: &[u8]) -> Result> { let len = bytes.len(); - let mut out = Vec::new(); - out.try_reserve(len)?; + let mut out = Vec::new_with_capacity(len)?; let height = len / 2; let odd = len % 2; @@ -113,8 +113,7 @@ fn shuffle_w2(bytes: &[u8]) -> Result> { fn shuffle_w4(bytes: &[u8]) -> Result> { let len = bytes.len(); - let mut out = Vec::new(); - out.try_reserve(len)?; + let mut out = Vec::new_with_capacity(len)?; let step = len / 4; let wide_count = len % 4; diff --git a/jxl/src/image.rs b/jxl/src/image.rs index 417f8ce..9cfb0b1 100644 --- a/jxl/src/image.rs +++ b/jxl/src/image.rs @@ -108,10 +108,15 @@ impl ImageDataType for half::f16 { impl_image_data_type!(f64, F64); pub struct Image { - size: (usize, usize), + pub size: (usize, usize), data: Vec, } +impl Debug for Image { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?} {}x{}", T::DATA_TYPE_ID, self.size.0, self.size.1,) + } +} #[derive(Clone, Copy)] pub struct ImageRect<'a, T: ImageDataType> { origin: (usize, usize), @@ -194,7 +199,7 @@ impl Image { Ok(img) } - #[cfg(test)] + // TODO(firsching): make this #[cfg(test)] pub fn new_constant(size: (usize, usize), val: T) -> Result> { let mut img = Self::new(size)?; img.data.iter_mut().for_each(|x| *x = val); diff --git a/jxl/src/render/stages/noise.rs b/jxl/src/render/stages/noise.rs index 4587057..c0b2fe0 100644 --- a/jxl/src/render/stages/noise.rs +++ b/jxl/src/render/stages/noise.rs @@ -146,7 +146,7 @@ mod test { }; use test_log::test; - // TODO(mo271): Add more relevant ConvolveNoise tests as per discussions in https://github.com/libjxl/jxl-rs/pull/60. + // TODO(firsching): Add more relevant ConvolveNoise tests as per discussions in https://github.com/libjxl/jxl-rs/pull/60. #[test] fn convolve_noise_process_row_chunk() -> Result<()> { @@ -170,7 +170,7 @@ mod test { ) } - // TODO(mo271): Add more relevant AddNoise tests as per discussions in https://github.com/libjxl/jxl-rs/pull/60. + // TODO(firsching): Add more relevant AddNoise tests as per discussions in https://github.com/libjxl/jxl-rs/pull/60. #[test] fn add_noise_process_row_chunk() -> Result<()> { diff --git a/jxl/src/util.rs b/jxl/src/util.rs index fd00579..9ac8ca9 100644 --- a/jxl/src/util.rs +++ b/jxl/src/util.rs @@ -7,14 +7,15 @@ pub mod test; mod bits; -#[allow(unused)] mod concat_slice; mod log2; mod shift_right_ceil; pub mod tracing_wrappers; +mod vec_helpers; pub use bits::*; #[allow(unused)] pub use concat_slice::*; pub use log2::*; pub use shift_right_ceil::*; +pub use vec_helpers::*; diff --git a/jxl/src/util/vec_helpers.rs b/jxl/src/util/vec_helpers.rs new file mode 100644 index 0000000..e9af8c8 --- /dev/null +++ b/jxl/src/util/vec_helpers.rs @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO(firsching): as soon as "Vec::try_with_capacity" is available from the +// standard library use this instead of the functions here. +pub trait NewWithCapacity { + type Output; + type Error; + fn new_with_capacity(capacity: usize) -> Result; +} + +impl NewWithCapacity for Vec { + type Output = Vec; + type Error = std::collections::TryReserveError; + + fn new_with_capacity(capacity: usize) -> Result { + let mut vec = Vec::new(); + vec.try_reserve(capacity)?; + Ok(vec) + } +} + +impl NewWithCapacity for String { + type Output = String; + type Error = std::collections::TryReserveError; + fn new_with_capacity(capacity: usize) -> Result { + let mut s = String::new(); + s.try_reserve(capacity)?; + Ok(s) + } +} diff --git a/jxl_cli/src/bin/jxlinspect.rs b/jxl_cli/src/bin/jxlinspect.rs index a4570e0..730b8d5 100644 --- a/jxl_cli/src/bin/jxlinspect.rs +++ b/jxl_cli/src/bin/jxlinspect.rs @@ -6,7 +6,7 @@ use clap::{Arg, Command}; use jxl::bit_reader::BitReader; use jxl::container::{ContainerParser, ParseEvent}; -use jxl::frame::Frame; +use jxl::frame::{DecoderState, Frame}; use jxl::headers::color_encoding::{ColorEncoding, Primaries, WhitePoint}; use jxl::headers::{FileHeader, JxlHeader}; use jxl::icc::read_icc; @@ -128,9 +128,9 @@ fn parse_jxl_codestream(data: &[u8], verbose: bool) -> Result<(), jxl::error::Er // TODO(firsching): handle frames which are blended together, also within animations. if let Some(ref animation) = file_header.image_metadata.animation { let mut total_duration = 0.0f64; - let mut not_is_last = true; - while not_is_last { - let frame = Frame::new(&mut br, &file_header)?; + let mut decoder_state = DecoderState::new(file_header.clone()); + loop { + let frame = Frame::new(&mut br, decoder_state)?; let ms = frame.header().duration(animation); total_duration += ms; println!( @@ -141,8 +141,12 @@ fn parse_jxl_codestream(data: &[u8], verbose: bool) -> Result<(), jxl::error::Er frame.header().y0 ); br.jump_to_byte_boundary()?; - not_is_last = !frame.is_last(); br.skip_bits(frame.total_bytes_in_toc() * 8)?; + if let Some(state) = frame.finalize()? { + decoder_state = state; + } else { + break; + } } print!( "Animation length: {} seconds", diff --git a/jxl_cli/src/main.rs b/jxl_cli/src/main.rs index 1d524b1..20b6adf 100644 --- a/jxl_cli/src/main.rs +++ b/jxl_cli/src/main.rs @@ -5,7 +5,7 @@ use jxl::bit_reader::BitReader; use jxl::container::{ContainerParser, ParseEvent}; -use jxl::frame::{Frame, Section}; +use jxl::frame::{DecoderState, Frame, Section}; use jxl::headers::FileHeader; use jxl::icc::read_icc; use std::env; @@ -26,9 +26,9 @@ fn parse_jxl_codestream(data: &[u8]) -> Result<(), jxl::error::Error> { let r = read_icc(&mut br)?; println!("found {}-byte ICC", r.len()); }; - + let mut decoder_state = DecoderState::new(file_header); loop { - let mut frame = Frame::new(&mut br, &file_header)?; + let mut frame = Frame::new(&mut br, decoder_state)?; br.jump_to_byte_boundary()?; let mut section_readers = frame.sections(&mut br)?; @@ -36,8 +36,9 @@ fn parse_jxl_codestream(data: &[u8]) -> Result<(), jxl::error::Error> { println!("read frame with {} sections", section_readers.len()); frame.decode_lf_global(&mut section_readers[frame.get_section_idx(Section::LfGlobal)])?; - - if frame.header().is_last { + if let Some(state) = frame.finalize()? { + decoder_state = state; + } else { break; } }