Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AVX regular fft #390

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 68 additions & 69 deletions src/core/backend/avx512/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,24 @@ pub unsafe fn ifft_lower_with_vecwise(

for index_h in 0..(1 << (log_size - fft_layers)) {
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
let mut layer = VECWISE_FFT_BITS;
while fft_layers - layer >= 3 {
ifft3_loop(
values,
&twiddle_dbl[(layer - 1)..],
fft_layers - layer - 3,
layer,
index_h,
);
layer += 3;
}
match fft_layers - layer {
2 => {
ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
1 => {
ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
match fft_layers - layer {
1 => {
ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
2 => {
ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
_ => {
ifft3_loop(
values,
&twiddle_dbl[(layer - 1)..],
fft_layers - layer - 3,
layer,
index_h,
);
}
}
_ => {}
}
}
}
Expand All @@ -123,26 +122,25 @@ pub unsafe fn ifft_lower_without_vecwise(
assert!(log_size >= VECS_LOG_SIZE);

for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) {
let mut layer = 0;
while fft_layers - layer >= 3 {
ifft3_loop(
values,
&twiddle_dbl[layer..],
fft_layers - layer - 3,
layer + VECS_LOG_SIZE,
index_h,
);
layer += 3;
}
let fixed_layer = layer + VECS_LOG_SIZE;
match fft_layers - layer {
2 => {
ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
}
1 => {
ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
for layer in (0..fft_layers).step_by(3) {
let fixed_layer = layer + VECS_LOG_SIZE;
match fft_layers - layer {
1 => {
ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
}
2 => {
ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
}
_ => {
ifft3_loop(
values,
&twiddle_dbl[layer..],
fft_layers - layer - 3,
fixed_layer,
index_h,
);
}
}
_ => {}
}
}
}
Expand Down Expand Up @@ -422,7 +420,7 @@ pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
res
}

/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-8 ifft.
/// Each butterfly layer, has 3 AVX butterflies.
Expand All @@ -432,7 +430,7 @@ pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
/// offset - The offset of the first value in the array.
/// log_step - The log of the distance in the array, in M31 elements, between each pair of
/// values that need to be transformed. For layer i this is i - 4.
/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies.
/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of ibutterflies.
/// Each layer has 4/2/1 twiddles.
/// # Safety
pub unsafe fn ifft3(
Expand All @@ -453,19 +451,19 @@ pub unsafe fn ifft3(
let mut val6 = _mm512_load_epi32(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = _mm512_load_epi32(values.add(offset + (7 << log_step)).cast_const());

// Apply the first layer of butterflies.
// Apply the first layer of ibutterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
(val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2]));
(val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3]));

// Apply the second layer of butterflies.
// Apply the second layer of ibutterflies.
(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
(val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1]));
(val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1]));

// Apply the third layer of butterflies.
// Apply the third layer of ibutterflies.
(val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0]));
(val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0]));
(val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0]));
Expand All @@ -482,17 +480,17 @@ pub unsafe fn ifft3(
_mm512_store_epi32(values.add(offset + (7 << log_step)), val7);
}

/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements.
/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-4 ifft.
/// Each butterfly layer, has 2 AVX butterflies.
/// Each ibutterfly layer, has 2 AVX butterflies.
/// Total of 4 AVX butterflies.
/// Parameters:
/// values - Pointer to the entire value array.
/// offset - The offset of the first value in the array.
/// log_step - The log of the distance in the array, in M31 elements, between each pair of
/// values that need to be transformed. For layer i this is i - 4.
/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of butterflies.
/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of ibutterflies.
/// Each layer has 2/1 twiddles.
/// # Safety
pub unsafe fn ifft2(
Expand Down Expand Up @@ -523,14 +521,14 @@ pub unsafe fn ifft2(
_mm512_store_epi32(values.add(offset + (3 << log_step)), val3);
}

/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements.
/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements.
/// Vectorized over the 16 elements of the vectors.
/// Parameters:
/// values - Pointer to the entire value array.
/// offset - The offset of the first value in the array.
/// log_step - The log of the distance in the array, in M31 elements, between each pair of
/// values that need to be transformed. For layer i this is i - 4.
/// twiddles_dbl0 - The double of the twiddles for the butterfly layer.
/// twiddles_dbl0 - The double of the twiddles for the ibutterfly layer.
/// # Safety
pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
// Load the 2 AVX vectors from the array.
Expand Down Expand Up @@ -695,27 +693,28 @@ mod tests {

#[test]
fn test_ifft_lower_with_vecwise() {
let log_size = 5 + 3 + 3;
let domain = CanonicCoset::new(log_size).circle_domain();
let values = (0..domain.size())
.map(|i| BaseField::from_u32_unchecked(i as u32))
.collect::<Vec<_>>();
let expected_coeffs = ref_ifft(domain, values.clone());

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);

unsafe {
ifft_lower_with_vecwise(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddle_dbls[1..],
log_size as usize,
log_size as usize,
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
for log_size in 5..12 {
let domain = CanonicCoset::new(log_size).circle_domain();
let values = (0..domain.size())
.map(|i| BaseField::from_u32_unchecked(i as u32))
.collect::<Vec<_>>();
let expected_coeffs = ref_ifft(domain, values.clone());

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);

unsafe {
ifft_lower_with_vecwise(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddle_dbls[1..],
log_size as usize,
log_size as usize,
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
}
}
}

Expand Down Expand Up @@ -748,7 +747,7 @@ mod tests {

#[test]
fn test_ifft_full() {
for i in 5..=5 + 3 + 3 {
for i in 5..12 {
run_ifft_full_test(i);
}
}
Expand Down
Loading