Skip to content

Commit

Permalink
jxl-render: Rewrite gaborish to be autovec friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
tirr-c committed Sep 25, 2023
1 parent a94905e commit 55d8954
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 32 deletions.
33 changes: 1 addition & 32 deletions crates/jxl-render/src/filter/gabor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,5 @@ use jxl_grid::SimpleGrid;

pub fn apply_gabor_like(fb: [&mut SimpleGrid<f32>; 3], weights_xyb: [[f32; 2]; 3]) {
tracing::debug!("Running gaborish");

let width = fb[0].width();
let height = fb[0].height();
let mut ud_sums = Vec::with_capacity(width * height);

let buffers = fb.map(|g| g.buf_mut());
for (c, [weight1, weight2]) in buffers.into_iter().zip(weights_xyb) {
ud_sums.clear();
let rows: Vec<_> = c.chunks_exact(width).collect();

for y in 0..height {
let up = rows[y.saturating_sub(1)];
let down = rows[(y + 1).min(height - 1)];
for (u, d) in up.iter().zip(down) {
ud_sums.push(*u + *d);
}
}

let global_weight = (1.0 + weight1 * 4.0 + weight2 * 4.0).recip();
for y in 0..height {
let mut left = c[y * width];
for x in 0..width {
let x_l = x.saturating_sub(1);
let x_r = (x + 1).min(width - 1);
let side = left + c[y * width + x_r] + ud_sums[y * width + x];
let diag = ud_sums[y * width + x_l] + ud_sums[y * width + x_r];
left = c[y * width + x];
c[y * width + x] += side * weight1 + diag * weight2;
c[y * width + x] *= global_weight;
}
}
}
super::impls::apply_gabor_like(fb, weights_xyb)
}
153 changes: 153 additions & 0 deletions crates/jxl-render/src/filter/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,156 @@ pub use generic::*;

#[cfg(target_arch = "x86_64")]
pub use x86_64::*;

#[inline(always)]
fn run_gabor_inner(fb: &mut jxl_grid::SimpleGrid<f32>, weight1: f32, weight2: f32) {
let global_weight = (1.0 + weight1 * 4.0 + weight2 * 4.0).recip();

let width = fb.width();
let height = fb.height();
let input = fb.clone();

let input = input.buf();
let output = fb.buf_mut();

let len = width * (height - 2) - 2;
let center = &input[width + 1..][..len];
let sides = [
&input[1..][..len],
&input[width..][..len],
&input[width + 2..][..len],
&input[width * 2 + 1..][..len],
];
let diags = [
&input[..len],
&input[2..][..len],
&input[width * 2..][..len],
&input[width * 2 + 2..][..len],
];

for (idx, out) in output[width + 1..][..len].iter_mut().enumerate() {
*out = (
center[idx] +
(sides[0][idx] + sides[1][idx] + sides[2][idx] + sides[3][idx]) * weight1 +
(diags[0][idx] + diags[1][idx] + diags[2][idx] + diags[3][idx]) * weight2
) * global_weight;
}

// top side
let len = width - 2;
let center = &input[1..][..len];
let sides = [
&input[1..][..len],
&input[..len],
&input[2..][..len],
&input[width + 1..][..len],
];
let diags = [
&input[..len],
&input[2..][..len],
&input[width..][..len],
&input[width + 2..][..len],
];

for (idx, out) in output[1..][..len].iter_mut().enumerate() {
*out = (
center[idx] +
(sides[0][idx] + sides[1][idx] + sides[2][idx] + sides[3][idx]) * weight1 +
(diags[0][idx] + diags[1][idx] + diags[2][idx] + diags[3][idx]) * weight2
) * global_weight;
}

// bottom side
let len = width - 2;
let base = width * (height - 1);
let center = &input[base + 1..][..len];
let sides = [
&input[base - width + 1..][..len],
&input[base..][..len],
&input[base + 2..][..len],
&input[base + 1..][..len],
];
let diags = [
&input[base - width..][..len],
&input[base - width + 2..][..len],
&input[base..][..len],
&input[base + 2..][..len],
];

for (idx, out) in output[base + 1..][..len].iter_mut().enumerate() {
*out = (
center[idx] +
(sides[0][idx] + sides[1][idx] + sides[2][idx] + sides[3][idx]) * weight1 +
(diags[0][idx] + diags[1][idx] + diags[2][idx] + diags[3][idx]) * weight2
) * global_weight;
}

// left side
let len = height - 2;
let center = &input[width..];
let sides = [
input,
&input[width..],
&input[width + 1..],
&input[width * 2..],
];
let diags = [
input,
&input[1..],
&input[width * 2..],
&input[width * 2 + 1..],
];
for idx in 0..len {
output[width + idx * width] = (
center[idx * width] +
(sides[0][idx * width] + sides[1][idx * width] + sides[2][idx * width] + sides[3][idx * width]) * weight1 +
(diags[0][idx * width] + diags[1][idx * width] + diags[2][idx * width] + diags[3][idx * width]) * weight2
) * global_weight;
}

// right side
let len = height - 2;
let center = &input[width * 2 - 1..];
let sides = [
&input[width - 1..],
&input[width * 2 - 2..],
&input[width * 2 - 1..],
&input[width * 3 - 1..],
];
let diags = [
&input[width - 2..],
&input[width - 1..],
&input[width * 3 - 2..],
&input[width * 3 - 1..],
];
for idx in 0..len {
output[width * 2 - 1 + idx * width] = (
center[idx * width] +
(sides[0][idx * width] + sides[1][idx * width] + sides[2][idx * width] + sides[3][idx * width]) * weight1 +
(diags[0][idx * width] + diags[1][idx * width] + diags[2][idx * width] + diags[3][idx * width]) * weight2
) * global_weight;
}

// corners
output[0] = (
input[0] +
(input[0] + input[0] + input[1] + input[width]) * weight1 +
(input[0] + input[1] + input[width] + input[width + 1]) * weight2
) * global_weight;
output[width - 1] = (
input[width - 1] +
(input[width - 1] + input[width - 1] + input[width - 2] + input[width * 2 - 1]) * weight1 +
(input[width - 1] + input[width - 2] + input[width * 2 - 2] + input[width * 2 - 1]) * weight2
) * global_weight;
output[base] = (
input[base] +
(input[base] + input[base] + input[base - width] + input[base + 1]) * weight1 +
(input[base] + input[base - width] + input[base - width + 1] + input[base + 1]) * weight2
) * global_weight;
let last = width * height - 1;
output[last] = (
input[last] +
(input[last] + input[last] + input[last - width] + input[last - 1]) * weight1 +
(input[last] + input[last - width] + input[last - width - 1] + input[last - 1]) * weight2
) * global_weight;
}
6 changes: 6 additions & 0 deletions crates/jxl-render/src/filter/impls/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,9 @@ pub fn epf_step2(
&[(0, 0)],
);
}

pub fn apply_gabor_like(fb: [&mut SimpleGrid<f32>; 3], weights_xyb: [[f32; 2]; 3]) {
for (fb, [weight1, weight2]) in fb.into_iter().zip(weights_xyb) {
super::run_gabor_inner(fb, weight1, weight2);
}
}
22 changes: 22 additions & 0 deletions crates/jxl-render/src/filter/impls/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,25 @@ pub fn epf_step2(
)
}
}

pub fn apply_gabor_like(fb: [&mut SimpleGrid<f32>; 3], weights_xyb: [[f32; 2]; 3]) {
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
// SAFETY: Features are checked above.
unsafe {
for (fb, [weight1, weight2]) in fb.into_iter().zip(weights_xyb) {
run_gabor_inner_avx2(fb, weight1, weight2)
}
}
return;
}

for (fb, [weight1, weight2]) in fb.into_iter().zip(weights_xyb) {
super::run_gabor_inner(fb, weight1, weight2);
}
}

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn run_gabor_inner_avx2(fb: &mut SimpleGrid<f32>, weight1: f32, weight2: f32) {
super::run_gabor_inner(fb, weight1, weight2)
}

0 comments on commit 55d8954

Please sign in to comment.