Skip to content

Commit

Permalink
Use precomputed twiddles in avx
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 12, 2024
1 parent 034a430 commit db9ada9
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 31 deletions.
2 changes: 1 addition & 1 deletion benches/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub fn avx512_ifft(c: &mut criterion::Criterion) {

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);

c.bench_function("avx ifft", |b| {
b.iter(|| unsafe {
Expand Down
55 changes: 40 additions & 15 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use bytemuck::cast_slice;
use itertools::Itertools;

use super::fft::{ifft, CACHED_FFT_LOG_SIZE};
use super::m31::PackedBaseField;
Expand All @@ -19,7 +20,7 @@ use crate::core::poly::BitReversedOrder;
// TODO(spapini): Everything is returned in redundant representation, where values can also be P.
// Decide if and when it's ok and what to do if it's not.
impl PolyOps for AVX512Backend {
type Twiddles = ();
type Twiddles = Vec<i32>;

fn new_canonical_ordered(
coset: CanonicCoset,
Expand All @@ -32,23 +33,25 @@ impl PolyOps for AVX512Backend {

fn interpolate(
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
_itwiddles: &TwiddleTree<Self>,
itwiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self> {
let mut values = eval.values;
let log_size = values.length.ilog2();

// TODO(spapini): Precompute twiddles.
let twiddle_dbls = ifft::get_itwiddle_dbls(eval.domain);
// TODO(spapini): Handle small cases.
let twiddle_buffer = &itwiddles.itwiddles;
let twiddles = (0..eval.domain.half_coset.log_size())
.map(|i| {
let len = 1 << i;
&twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len]
})
.rev()
.collect_vec();

// Safe because [PackedBaseField] is aligned on 64 bytes.
unsafe {
ifft::ifft(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddle_dbls[1..]
.iter()
.map(|x| x.as_slice())
.collect::<Vec<_>>(),
&twiddles,
log_size as usize,
);
}
Expand Down Expand Up @@ -98,7 +101,7 @@ impl PolyOps for AVX512Backend {
fn evaluate(
poly: &CirclePoly<Self>,
domain: CircleDomain,
_twiddles: &TwiddleTree<Self>,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Precompute twiddles.
// TODO(spapini): Handle small cases.
Expand All @@ -109,17 +112,24 @@ impl PolyOps for AVX512Backend {
"Can only evaluate on larger domains"
);

let twiddles = rfft::get_twiddle_dbls(domain);
let twiddle_buffer = &twiddles.twiddles;
let twiddles = (0..domain.half_coset.log_size())
.map(|i| {
let len = 1 << i;
&twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len]
})
.rev()
.collect_vec();

// Evaluate on a big domains by evaluating on several subdomains.
let log_subdomains = log_size - fft_log_size;
let mut values = Vec::with_capacity(domain.size() >> VECS_LOG_SIZE);
for i in 0..(1 << log_subdomains) {
// The subdomain twiddles are a slice of the large domain twiddles.
let subdomain_twiddles = (1..fft_log_size)
let subdomain_twiddles = (0..(fft_log_size - 1))
.map(|layer_i| {
&twiddles[layer_i]
[i << (fft_log_size - 1 - layer_i)..(i + 1) << (fft_log_size - 1 - layer_i)]
[i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)]
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -150,10 +160,25 @@ impl PolyOps for AVX512Backend {
}

fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
let mut twiddles = Vec::with_capacity(coset.size());
let mut itwiddles = Vec::with_capacity(coset.size());

// Optimize.
for layer in &rfft::get_twiddle_dbls(coset)[1..] {
twiddles.extend(layer);
}
twiddles.push(2);
assert_eq!(twiddles.len(), coset.size());
for layer in &ifft::get_itwiddle_dbls(coset)[1..] {
itwiddles.extend(layer);
}
itwiddles.push(2);
assert_eq!(itwiddles.len(), coset.size());

TwiddleTree {
root_coset: coset,
twiddles: (),
itwiddles: (),
twiddles,
itwiddles,
}
}
}
Expand Down
12 changes: 5 additions & 7 deletions src/core/backend/avx512/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use std::arch::x86_64::{
use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CircleDomain;
use crate::core::utils::bit_reverse;

/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values.
Expand Down Expand Up @@ -374,9 +374,7 @@ pub unsafe fn vecwise_ibutterflies(
val0.deinterleave_with(val1)
}

pub fn get_itwiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
let mut coset = domain.half_coset;

pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec<Vec<i32>> {
let mut res = vec![];
res.push(
coset
Expand Down Expand Up @@ -643,7 +641,7 @@ mod tests {
#[test]
fn test_vecwise_ibutterflies() {
let domain = CanonicCoset::new(5).circle_domain();
let twiddle_dbls = get_itwiddle_dbls(domain);
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
assert_eq!(twiddle_dbls.len(), 5);
let values0: [i32; 16] = std::array::from_fn(|i| i as i32);
let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32);
Expand Down Expand Up @@ -681,7 +679,7 @@ mod tests {

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);

unsafe {
ifft_lower_with_vecwise(
Expand Down Expand Up @@ -709,7 +707,7 @@ mod tests {

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain);
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);

unsafe {
ifft(
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/avx512/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod tests {

#[test]
fn test_twiddle_relation() {
let ts = get_itwiddle_dbls(CanonicCoset::new(5).circle_domain());
let ts = get_itwiddle_dbls(CanonicCoset::new(5).half_coset());
let t0 = ts[0]
.iter()
.copied()
Expand Down
12 changes: 5 additions & 7 deletions src/core/backend/avx512/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::arch::x86_64::{
use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
use crate::core::poly::circle::CircleDomain;
use crate::core::circle::Coset;
use crate::core::utils::bit_reverse;

/// Performs a Circle Fast Fourier Transform (ICFFT) on the given values.
Expand Down Expand Up @@ -347,9 +347,7 @@ pub unsafe fn vecwise_butterflies(
val0.interleave_with(val1)
}

pub fn get_twiddle_dbls(domain: CircleDomain) -> Vec<Vec<i32>> {
let mut coset = domain.half_coset;

pub fn get_twiddle_dbls(mut coset: Coset) -> Vec<Vec<i32>> {
let mut res = vec![];
res.push(coset.iter().map(|p| (p.y.0 * 2) as i32).collect::<Vec<_>>());
bit_reverse(res.last_mut().unwrap());
Expand Down Expand Up @@ -606,7 +604,7 @@ mod tests {
#[test]
fn test_vecwise_butterflies() {
let domain = CanonicCoset::new(5).circle_domain();
let twiddle_dbls = get_twiddle_dbls(domain);
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
assert_eq!(twiddle_dbls.len(), 5);
let values0: [i32; 16] = std::array::from_fn(|i| i as i32);
let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32);
Expand Down Expand Up @@ -648,7 +646,7 @@ mod tests {

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_twiddle_dbls(domain);
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);

unsafe {
fft_lower_with_vecwise(
Expand Down Expand Up @@ -676,7 +674,7 @@ mod tests {

// Compute.
let mut values = BaseFieldVec::from_iter(values);
let twiddle_dbls = get_twiddle_dbls(domain);
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);

unsafe {
transpose_vecs(
Expand Down

0 comments on commit db9ada9

Please sign in to comment.