Skip to content

Commit

Permalink
Autovectorization pass
Browse files Browse the repository at this point in the history
  • Loading branch information
shssoichiro committed May 4, 2022
1 parent 6889a67 commit 03effba
Showing 1 changed file with 62 additions and 45 deletions.
107 changes: 62 additions & 45 deletions src/denoise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,38 +235,45 @@ where
.map(|f| f[p].data_origin())
.collect::<ArrayVec<_, TB_SIZE>>();

for y in (0..effective_height).step_by(INC) {
for x in (0..=(pad_width - SB_SIZE)).step_by(INC) {
for z in 0..TB_SIZE {
self.proc0(
&src_planes[z][x..],
&self.hw[(BLOCK_AREA * z)..],
&mut dftr[(BLOCK_AREA * z)..],
src_stride,
// SAFETY: We know the size of the planes we're working on,
// so we can ensure that every offset passed below stays in bounds.
// The `unsafe` functions called here use unchecked raw-pointer access
// for optimization purposes.
// All are sound as long as we do not pass out-of-bounds parameters.
unsafe {
for y in (0..effective_height).step_by(INC) {
for x in (0..=(pad_width - SB_SIZE)).step_by(INC) {
for z in 0..TB_SIZE {
self.proc0(
&src_planes[z][x..],
&self.hw[(BLOCK_AREA * z)..],
&mut dftr[(BLOCK_AREA * z)..],
src_stride,
SB_SIZE,
self.src_scale,
);
}

self.real_to_complex_3d(&dftr, &mut dftc);
self.remove_mean(&mut dftc, &self.dftgc, &mut means);

self.filter_coeffs(&mut dftc);

self.add_mean(&mut dftc, &means);
self.complex_to_real_3d(&dftc, &mut dftr);

self.proc1(
&dftr[(TB_MIDPOINT * BLOCK_AREA)..],
&self.hw[(TB_MIDPOINT * BLOCK_AREA)..],
&mut ebuff[(y * ebuff_stride + x)..],
SB_SIZE,
self.src_scale,
ebuff_stride,
);
}

self.real_to_complex_3d(&dftr, &mut dftc);
self.remove_mean(&mut dftc, &self.dftgc, &mut means);

self.filter_coeffs(&mut dftc);

self.add_mean(&mut dftc, &means);
self.complex_to_real_3d(&dftc, &mut dftr);

self.proc1(
&dftr[(TB_MIDPOINT * BLOCK_AREA)..],
&self.hw[(TB_MIDPOINT * BLOCK_AREA)..],
&mut ebuff[(y * ebuff_stride + x)..],
SB_SIZE,
ebuff_stride,
);
}

for q in 0..TB_SIZE {
src_planes[q] = &src_planes[q][(INC * src_stride)..];
for q in 0..TB_SIZE {
src_planes[q] = &src_planes[q][(INC * src_stride)..];
}
}
}

Expand Down Expand Up @@ -313,6 +320,7 @@ where
hw
}

#[inline(always)]
// Hanning windowing
fn spatial_window(n: f64) -> f64 {
0.5 - 0.5 * (2.0 * PI * n / SB_SIZE as f64).cos()
Expand Down Expand Up @@ -345,35 +353,44 @@ where
}
}

fn proc0(
#[inline]
unsafe fn proc0(
&self, s0: &[T], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize,
src_scale: f32,
) {
let s0 = s0.chunks(p0);
let s1 = s1.chunks(p1);
let dest = dest.chunks_mut(p1);
let s0 = s0.as_ptr();
let s1 = s1.as_ptr();
let dest = dest.as_mut_ptr();

for (s0, (s1, dest)) in s0.zip(s1.zip(dest)).take(p1) {
for u in 0..p1 {
for v in 0..p1 {
dest[v] = u16::cast_from(s0[v]) as f32 * src_scale * s1[v];
let s0 = s0.add(u * p0 + v);
let s1 = s1.add(u * p1 + v);
let dest = dest.add(u * p1 + v);
dest.write(u16::cast_from(s0.read()) as f32 * src_scale * s1.read())
}
}
}

fn proc1(
#[inline]
unsafe fn proc1(
&self, s0: &[f32], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize,
) {
let s0 = s0.chunks(p0);
let s1 = s1.chunks(p0);
let dest = dest.chunks_mut(p1);
let s0 = s0.as_ptr();
let s1 = s1.as_ptr();
let dest = dest.as_mut_ptr();

for (s0, (s1, dest)) in s0.zip(s1.zip(dest)).take(p0) {
for u in 0..p0 {
for v in 0..p0 {
dest[v] += s0[v] * s1[v];
let s0 = s0.add(u * p0 + v);
let s1 = s1.add(u * p0 + v);
let dest = dest.add(u * p1 + v);
dest.write(s0.read().mul_add(s1.read(), dest.read()));
}
}
}

#[inline]
fn remove_mean(
&self, dftc: &mut [Complex<f32>; COMPLEX_COUNT],
dftgc: &[Complex<f32>; COMPLEX_COUNT],
Expand All @@ -389,6 +406,7 @@ where
}
}

#[inline]
fn add_mean(
&self, dftc: &mut [Complex<f32>; COMPLEX_COUNT],
means: &[Complex<f32>; COMPLEX_COUNT],
Expand All @@ -399,6 +417,7 @@ where
}
}

#[inline]
// Applies a generalized wiener filter
fn filter_coeffs(&self, dftc: &mut [Complex<f32>; COMPLEX_COUNT]) {
for h in 0..COMPLEX_COUNT {
Expand Down Expand Up @@ -495,11 +514,8 @@ where
for (ebuff, dest) in ebuff.zip(dest).take(dest_height) {
for x in 0..dest_width {
let fval = ebuff[x].mul_add(self.dest_scale, 0.5);
dest[x] = clamp(
T::cast_from(fval.round() as u16),
T::cast_from(0u16),
self.peak,
);
dest[x] =
clamp(T::cast_from(fval as u16), T::cast_from(0u16), self.peak);
}
}
}
Expand Down Expand Up @@ -544,6 +560,7 @@ where
}
}

#[inline(always)]
fn extra(a: usize, b: usize) -> usize {
if a % b > 0 {
b - (a % b)
Expand Down

0 comments on commit 03effba

Please sign in to comment.