Skip to content

Commit

Permalink
Incorporate fft in AVX backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 6, 2024
1 parent 58647c1 commit da83eab
Showing 1 changed file with 80 additions and 23 deletions.
103 changes: 80 additions & 23 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::fft::ifft;
use super::m31::PackedBaseField;
use super::{as_cpu_vec, AVX512Backend};
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::backend::avx512::fft::rfft;
use crate::core::backend::{CPUBackend, FieldOps};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, ExtensionOf};
use crate::core::fields::{Col, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
Expand All @@ -12,41 +14,96 @@ impl PolyOps<BaseField> for AVX512Backend {
coset: CanonicCoset,
values: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField> {
// TODO(spapini): Optimize.
let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values));
CircleEvaluation::new(eval.domain, Col::<AVX512Backend, _>::from_iter(eval.values))
}

fn interpolate(eval: CircleEvaluation<Self, BaseField>) -> CirclePoly<Self, BaseField> {
let cpu_eval = CircleEvaluation::<CPUBackend, _>::new(eval.domain, as_cpu_vec(eval.values));
let cpu_poly = cpu_eval.interpolate();
CirclePoly::new(Col::<AVX512Backend, _>::from_iter(cpu_poly.coeffs))
let mut values = eval.values;

// TODO(spapini): Precompute twiddles.
let twiddles = ifft::get_itwiddle_dbls(eval.domain);
// TODO(spapini): Remove.
AVX512Backend::bit_reverse_column(&mut values);
// TODO(spapini): Handle small cases.
let log_size = values.length.ilog2();

unsafe {
ifft::ifft(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddles[1..],
log_size as usize,
);
}

// TODO(spapini): Fuse this multiplication / rotation.
let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse();
let inv = PackedBaseField::from_array([inv; 16]);
for x in values.data.iter_mut() {
*x *= inv;
}

CirclePoly::new(values)
}

fn eval_at_point<E: ExtensionOf<BaseField>>(
poly: &CirclePoly<Self, BaseField>,
point: CirclePoint<E>,
fn eval_at_point<E: crate::core::fields::ExtensionOf<BaseField>>(
_poly: &CirclePoly<Self, BaseField>,
_point: crate::core::circle::CirclePoint<E>,
) -> E {
// TODO(spapini): Unnecessary allocation here.
let cpu_poly = CirclePoly::<CPUBackend, _>::new(as_cpu_vec(poly.coeffs.clone()));
cpu_poly.eval_at_point(point)
todo!()
}

fn evaluate(
poly: &CirclePoly<Self, BaseField>,
domain: CircleDomain,
) -> CircleEvaluation<Self, BaseField> {
let cpu_poly = CirclePoly::<CPUBackend, _>::new(as_cpu_vec(poly.coeffs.clone()));
let cpu_eval = cpu_poly.evaluate(domain);
CircleEvaluation::new(
cpu_eval.domain,
Col::<AVX512Backend, _>::from_iter(cpu_eval.values),
)
let mut values = poly.coeffs.clone();

// TODO(spapini): Precompute twiddles.
let twiddles = rfft::get_twiddle_dbls(domain);
// TODO(spapini): Handle small cases.
let log_size = values.length.ilog2();

unsafe {
rfft::fft(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddles[1..],
log_size as usize,
);
}

// TODO(spapini): Remove.
AVX512Backend::bit_reverse_column(&mut values);

CircleEvaluation::new(domain, values)
}

fn extend(_poly: &CirclePoly<Self, BaseField>, _log_size: u32) -> CirclePoly<Self, BaseField> {
todo!()
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use crate::core::backend::avx512::AVX512Backend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::Column;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};

fn extend(poly: &CirclePoly<Self, BaseField>, log_size: u32) -> CirclePoly<Self, BaseField> {
let cpu_poly = CirclePoly::<CPUBackend, _>::new(as_cpu_vec(poly.coeffs.clone()));
CirclePoly::new(Col::<AVX512Backend, _>::from_iter(
cpu_poly.extend(log_size).coeffs,
))
#[test]
fn test_interpolate_and_eval() {
const LOG_SIZE: u32 = 6;
let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, _>::new(
domain,
(0..(1 << LOG_SIZE))
.map(BaseField::from_u32_unchecked)
.collect(),
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain);
assert_eq!(evaluation.values.to_vec(), evaluation2.values.to_vec());
}
}

0 comments on commit da83eab

Please sign in to comment.