Skip to content

Commit

Permalink
jxl-render: SSE2 version of EPF (#86)
Browse files Browse the repository at this point in the history
* jxl-render: SSE2 version of EPF

* Remove redundant bounds checks
  • Loading branch information
tirr-c authored Sep 25, 2023
1 parent bb4cbbe commit b7afe10
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 125 deletions.
11 changes: 11 additions & 0 deletions crates/jxl-grid/src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub trait SimdVector: Copy {
fn sub(self, lhs: Self) -> Self;
fn mul(self, lhs: Self) -> Self;
fn div(self, lhs: Self) -> Self;
fn abs(self) -> Self;

fn muladd(self, mul: Self, add: Self) -> Self;
fn mulsub(self, mul: Self, sub: Self) -> Self;
Expand Down Expand Up @@ -122,6 +123,16 @@ impl SimdVector for std::arch::x86_64::__m128 {
unsafe { std::arch::x86_64::_mm_div_ps(self, lhs) }
}

#[inline]
fn abs(self) -> Self {
unsafe {
std::arch::x86_64::_mm_andnot_ps(
Self::splat_f32(f32::from_bits(0x80000000)),
self,
)
}
}

#[inline]
#[cfg(target_feature = "fma")]
fn muladd(self, mul: Self, add: Self) -> Self {
Expand Down
207 changes: 82 additions & 125 deletions crates/jxl-render/src/filter/epf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,92 +5,6 @@ use jxl_grid::SimpleGrid;

use crate::{region::ImageWithRegion, Region};

#[inline]
fn weight(scaled_distance: f32, sigma: f32, step_multiplier: f32) -> f32 {
let inv_sigma = step_multiplier * 6.6 * (1.0 - std::f32::consts::FRAC_1_SQRT_2) / sigma;
(1.0 - scaled_distance * inv_sigma).max(0.0)
}

#[allow(clippy::too_many_arguments)]
fn epf_step(
input: &[SimpleGrid<f32>; 3],
output: &mut [SimpleGrid<f32>; 3],
sigma_grid: &SimpleGrid<f32>,
channel_scale: [f32; 3],
border_sad_mul: f32,
step_multiplier: f32,
kernel_coords: &'static [(isize, isize)],
dist_coords: &'static [(isize, isize)],
) {
let width = input[0].width();
let height = input[0].height();
for y in 0..height - 6 {
let y8 = y / 8;
let is_y_border = (y % 8) == 0 || (y % 8) == 7;
let y = y + 3;

for x in 0..width - 6 {
let x8 = x / 8;
let is_border = is_y_border || (x % 8) == 0 || (x % 8) == 7;
let x = x + 3;

let sigma_val = *sigma_grid.get(x8, y8).unwrap();
if sigma_val < 0.3 {
for (input, ch) in input.iter().zip(output.iter_mut()) {
let input_ch = input.buf();
let output_ch = ch.buf_mut();
output_ch[y * width + x] = input_ch[y * width + x];
}
continue;
}

let mut sum_weights = 1.0f32;
let mut sum_channels = [0.0f32; 3];
for (sum, ch) in sum_channels.iter_mut().zip(input) {
let ch = ch.buf();
*sum = ch[y * width + x];
}

for &(dx, dy) in kernel_coords {
let tx = x as isize + dx;
let ty = y as isize + dy;
let mut dist = 0.0f32;
for (ch, scale) in input.iter().zip(channel_scale) {
let ch = ch.buf();
for &(dx, dy) in dist_coords {
let x = x as isize + dx;
let y = y as isize + dy;
let tx = (tx + dx) as usize;
let ty = (ty + dy) as usize;
let x = x as usize;
let y = y as usize;
dist += (ch[y * width + x] - ch[ty * width + tx]).abs() * scale;
}
}

let weight = weight(
dist,
sigma_val,
step_multiplier * if is_border { border_sad_mul } else { 1.0 },
);
sum_weights += weight;

let tx = tx as usize;
let ty = ty as usize;
for (sum, ch) in sum_channels.iter_mut().zip(input) {
let ch = ch.buf();
*sum += ch[ty * width + tx] * weight;
}
}

for (sum, ch) in sum_channels.into_iter().zip(output.iter_mut()) {
let ch = ch.buf_mut();
ch[y * width + x] = sum / sum_weights;
}
}
}
}

pub fn apply_epf(
fb: &mut ImageWithRegion,
lf_groups: &HashMap<u32, LfGroup>,
Expand All @@ -114,15 +28,16 @@ pub fn apply_epf(

let width = region.width as usize;
let height = region.height as usize;
// Extra padding for SIMD
let mut fb_in = [
SimpleGrid::new(width + 6, height + 6),
SimpleGrid::new(width + 6, height + 6),
SimpleGrid::new(width + 6, height + 6),
SimpleGrid::new(width + 6, height + 7),
SimpleGrid::new(width + 6, height + 7),
SimpleGrid::new(width + 6, height + 7),
];
let mut fb_out = [
SimpleGrid::new(width + 6, height + 6),
SimpleGrid::new(width + 6, height + 6),
SimpleGrid::new(width + 6, height + 6),
SimpleGrid::new(width + 6, height + 7),
SimpleGrid::new(width + 6, height + 7),
SimpleGrid::new(width + 6, height + 7),
];
for (output, input) in fb_in.iter_mut().zip(&*fb) {
let output = output.buf_mut();
Expand Down Expand Up @@ -204,19 +119,33 @@ pub fn apply_epf(
}
}

epf_step(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
sigma.pass0_sigma_scale,
&[
(0, -1), (-1, 0), (1, 0), (0, 1),
(0, -2), (-1, -1), (1, -1), (-2, 0), (2, 0), (-1, 1), (1, 1), (0, 2),
],
&[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
);
#[cfg(target_arch = "x86_64")]
{
super::x86_64::epf_step0_sse2(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
sigma.pass0_sigma_scale,
);
}
#[cfg(not(target_arch = "x86_64"))]
{
super::generic::epf_step(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
sigma.pass0_sigma_scale,
&[
(0, -1), (-1, 0), (1, 0), (0, 1),
(0, -2), (-1, -1), (1, -1), (-2, 0), (2, 0), (-1, 1), (1, 1), (0, 2),
],
&[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
);
}
std::mem::swap(&mut fb_in, &mut fb_out);
}

Expand Down Expand Up @@ -247,16 +176,30 @@ pub fn apply_epf(
}
}

epf_step(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
1.0,
&[(0, -1), (-1, 0), (1, 0), (0, 1)],
&[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
);
#[cfg(target_arch = "x86_64")]
{
super::x86_64::epf_step1_sse2(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
1.0,
);
}
#[cfg(not(target_arch = "x86_64"))]
{
super::generic::epf_step(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
1.0,
&[(0, -1), (-1, 0), (1, 0), (0, 1)],
&[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
);
}
std::mem::swap(&mut fb_in, &mut fb_out);
}

Expand Down Expand Up @@ -287,16 +230,30 @@ pub fn apply_epf(
}
}

epf_step(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
sigma.pass2_sigma_scale,
&[(0, -1), (-1, 0), (1, 0), (0, 1)],
&[(0, 0)],
);
#[cfg(target_arch = "x86_64")]
{
super::x86_64::epf_step2_sse2(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
sigma.pass2_sigma_scale,
);
}
#[cfg(not(target_arch = "x86_64"))]
{
super::generic::epf_step(
&fb_in,
&mut fb_out,
sigma_grid,
channel_scale,
sigma.border_sad_mul,
sigma.pass2_sigma_scale,
&[(0, -1), (-1, 0), (1, 0), (0, 1)],
&[(0, 0)],
);
}
std::mem::swap(&mut fb_in, &mut fb_out);
}

Expand Down
87 changes: 87 additions & 0 deletions crates/jxl-render/src/filter/generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use jxl_grid::SimpleGrid;

#[inline]
fn weight(scaled_distance: f32, sigma: f32, step_multiplier: f32) -> f32 {
let inv_sigma = step_multiplier * 6.6 * (1.0 - std::f32::consts::FRAC_1_SQRT_2) / sigma;
(1.0 - scaled_distance * inv_sigma).max(0.0)
}

#[allow(clippy::too_many_arguments)]
pub fn epf_step(
input: &[SimpleGrid<f32>; 3],
output: &mut [SimpleGrid<f32>; 3],
sigma_grid: &SimpleGrid<f32>,
channel_scale: [f32; 3],
border_sad_mul: f32,
step_multiplier: f32,
kernel_coords: &'static [(isize, isize)],
dist_coords: &'static [(isize, isize)],
) {
let width = input[0].width();
let height = input[0].height();
for y in 0..height - 7 {
let y8 = y / 8;
let is_y_border = (y % 8) == 0 || (y % 8) == 7;
let y = y + 3;

for x in 0..width - 6 {
let x8 = x / 8;
let is_border = is_y_border || (x % 8) == 0 || (x % 8) == 7;
let x = x + 3;

let sigma_val = *sigma_grid.get(x8, y8).unwrap();
if sigma_val < 0.3 {
for (input, ch) in input.iter().zip(output.iter_mut()) {
let input_ch = input.buf();
let output_ch = ch.buf_mut();
output_ch[y * width + x] = input_ch[y * width + x];
}
continue;
}

let mut sum_weights = 1.0f32;
let mut sum_channels = [0.0f32; 3];
for (sum, ch) in sum_channels.iter_mut().zip(input) {
let ch = ch.buf();
*sum = ch[y * width + x];
}

for &(dx, dy) in kernel_coords {
let tx = x as isize + dx;
let ty = y as isize + dy;
let mut dist = 0.0f32;
for (ch, scale) in input.iter().zip(channel_scale) {
let ch = ch.buf();
for &(dx, dy) in dist_coords {
let x = x as isize + dx;
let y = y as isize + dy;
let tx = (tx + dx) as usize;
let ty = (ty + dy) as usize;
let x = x as usize;
let y = y as usize;
dist += (ch[y * width + x] - ch[ty * width + tx]).abs() * scale;
}
}

let weight = weight(
dist,
sigma_val,
step_multiplier * if is_border { border_sad_mul } else { 1.0 },
);
sum_weights += weight;

let tx = tx as usize;
let ty = ty as usize;
for (sum, ch) in sum_channels.iter_mut().zip(input) {
let ch = ch.buf();
*sum += ch[ty * width + tx] * weight;
}
}

for (sum, ch) in sum_channels.into_iter().zip(output.iter_mut()) {
let ch = ch.buf_mut();
ch[y * width + x] = sum / sum_weights;
}
}
}
}
5 changes: 5 additions & 0 deletions crates/jxl-render/src/filter/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
#[cfg(not(target_arch = "x86_64"))]
mod generic;
#[cfg(target_arch = "x86_64")]
mod x86_64;

mod epf;
mod gabor;
mod ycbcr;
Expand Down
Loading

0 comments on commit b7afe10

Please sign in to comment.