From fb7a54894f9356600bea527f9dc474ec7fe3da0e Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Thu, 22 Feb 2024 16:06:12 +0200 Subject: [PATCH] Incorporate fft in AVX backend --- src/core/backend/avx512/circle.rs | 102 +++++++++++++++++++++------- src/core/backend/avx512/fft/rfft.rs | 9 +-- 2 files changed, 82 insertions(+), 29 deletions(-) diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs index e499584c7..73303b8a9 100644 --- a/src/core/backend/avx512/circle.rs +++ b/src/core/backend/avx512/circle.rs @@ -1,8 +1,10 @@ +use super::fft::ifft; +use super::m31::PackedBaseField; use super::{as_cpu_vec, AVX512Backend}; -use crate::core::backend::CPUBackend; -use crate::core::circle::CirclePoint; +use crate::core::backend::avx512::fft::rfft; +use crate::core::backend::{CPUBackend, FieldOps}; use crate::core::fields::m31::BaseField; -use crate::core::fields::{Col, ExtensionOf}; +use crate::core::fields::{Col, Field}; use crate::core::poly::circle::{ CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, }; @@ -12,41 +14,95 @@ impl PolyOps for AVX512Backend { coset: CanonicCoset, values: Col, ) -> CircleEvaluation { + // TODO(spapini): Optimize. let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values)); CircleEvaluation::new(eval.domain, Col::::from_iter(eval.values)) } fn interpolate(eval: CircleEvaluation) -> CirclePoly { - let cpu_eval = CircleEvaluation::::new(eval.domain, as_cpu_vec(eval.values)); - let cpu_poly = cpu_eval.interpolate(); - CirclePoly::new(Col::::from_iter(cpu_poly.coeffs)) + let mut values = eval.values; + + // TODO(spapini): Precompute twiddles. + let twiddles = ifft::get_itwiddle_dbls(eval.domain); + // TODO(spapini): Remove. + AVX512Backend::bit_reverse_column(&mut values); + // TODO(spapini): Handle small cases. + let log_size = values.length.ilog2(); + + unsafe { + ifft::ifft( + std::mem::transmute(values.data.as_mut_ptr()), + &twiddles[1..], + log_size as usize, + ); + } + + // TODO(spapini): Fuse this multiplication / rotation. + let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse(); + let inv = PackedBaseField::from_array([inv; 16]); + for x in values.data.iter_mut() { + *x *= inv; + } + + CirclePoly::new(values) } - fn eval_at_point>( - poly: &CirclePoly, - point: CirclePoint, + fn eval_at_point>( + _poly: &CirclePoly, + _point: crate::core::circle::CirclePoint, ) -> E { - // TODO(spapini): Unnecessary allocation here. - let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); - cpu_poly.eval_at_point(point) + todo!() } fn evaluate( poly: &CirclePoly, domain: CircleDomain, ) -> CircleEvaluation { - let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); - let cpu_eval = cpu_poly.evaluate(domain); - CircleEvaluation::new( - cpu_eval.domain, - Col::::from_iter(cpu_eval.values), - ) + let mut values = poly.coeffs.clone(); + + // TODO(spapini): Precompute twiddles. + let twiddles = rfft::get_twiddle_dbls(domain); + // TODO(spapini): Handle small cases. + let log_size = values.length.ilog2(); + + unsafe { + rfft::fft( + std::mem::transmute(values.data.as_mut_ptr()), + &twiddles[1..], + log_size as usize, + ); + } + + // TODO(spapini): Remove. + AVX512Backend::bit_reverse_column(&mut values); + + CircleEvaluation::new(domain, values) + } + + fn extend(_poly: &CirclePoly, _log_size: u32) -> CirclePoly { + todo!() } +} + +#[cfg(test)] +mod tests { + use crate::core::backend::avx512::AVX512Backend; + use crate::core::fields::m31::BaseField; + use crate::core::fields::Column; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; - fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { - let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); - CirclePoly::new(Col::::from_iter( - cpu_poly.extend(log_size).coeffs, - )) + #[test] + fn test_interpolate_and_eval() { + const LOG_SIZE: u32 = 6; + let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..(1 << LOG_SIZE)) + .map(BaseField::from_u32_unchecked) + .collect(), + ); + let poly = evaluation.clone().interpolate(); + let evaluation2 = poly.evaluate(domain); + assert_eq!(evaluation.values.to_vec(), evaluation2.values.to_vec()); } } diff --git a/src/core/backend/avx512/fft/rfft.rs b/src/core/backend/avx512/fft/rfft.rs index 3a803e3bd..8e292c143 100644 --- a/src/core/backend/avx512/fft/rfft.rs +++ b/src/core/backend/avx512/fft/rfft.rs @@ -432,6 +432,7 @@ mod tests { use crate::core::backend::CPUBackend; use crate::core::fft::butterfly; use crate::core::fields::m31::BaseField; + use crate::core::fields::Column; use crate::core::poly::circle::{CanonicCoset, CircleDomain, CirclePoly}; use crate::core::utils::bit_reverse; @@ -587,9 +588,7 @@ mod tests { ); // Compare. - for i in 0..expected_coeffs.len() { - assert_eq!(values[i], expected_coeffs[i]); - } + assert_eq!(values.to_vec(), expected_coeffs); } } } @@ -617,9 +616,7 @@ mod tests { ); // Compare. - for i in 0..expected_coeffs.len() { - assert_eq!(values[i], expected_coeffs[i]); - } + assert_eq!(values.to_vec(), expected_coeffs); } }