diff --git a/src/core/backend/avx512/bit_reverse.rs b/src/core/backend/avx512/bit_reverse.rs index 140f62807..420125b48 100644 --- a/src/core/backend/avx512/bit_reverse.rs +++ b/src/core/backend/avx512/bit_reverse.rs @@ -72,7 +72,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { } } -/// Bit reverses 16 packed M31 values. +/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) }; // L is an input to _mm512_permutex2var_epi32, and it is used to diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs new file mode 100644 index 000000000..45a3d92d5 --- /dev/null +++ b/src/core/backend/avx512/circle.rs @@ -0,0 +1,51 @@ +use super::{as_cpu_vec, AVX512Backend}; +use crate::core::backend::CPUBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::Col; +use crate::core::poly::circle::{ + CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, +}; + +impl PolyOps for AVX512Backend { + fn new_canonical_ordered( + coset: CanonicCoset, + values: Col, + ) -> CircleEvaluation { + let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values)); + CircleEvaluation::new(eval.domain, Col::::from_iter(eval.values)) + } + + fn interpolate(eval: CircleEvaluation) -> CirclePoly { + let cpu_eval = CircleEvaluation::::new(eval.domain, as_cpu_vec(eval.values)); + let cpu_poly = cpu_eval.interpolate(); + CirclePoly::new(Col::::from_iter(cpu_poly.coeffs)) + } + + fn eval_at_point>( + poly: &CirclePoly, + point: crate::core::circle::CirclePoint, + ) -> E { + // TODO(spapini): Unnecessary allocation here. + let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); + cpu_poly.eval_at_point(point) + } + + fn evaluate( + poly: &CirclePoly, + domain: CircleDomain, + ) -> CircleEvaluation { + let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); + let cpu_eval = cpu_poly.evaluate(domain); + CircleEvaluation::new( + cpu_eval.domain, + Col::::from_iter(cpu_eval.values), + ) + } + + fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { + let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); + CirclePoly::new(Col::::from_iter( + cpu_poly.extend(log_size).coeffs, + )) + } +} diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs index 8ed77812c..cbfb95643 100644 --- a/src/core/backend/avx512/mod.rs +++ b/src/core/backend/avx512/mod.rs @@ -1,4 +1,5 @@ pub mod bit_reverse; +pub mod circle; use std::ops::Index; @@ -56,6 +57,17 @@ impl Column for BaseFieldVec { } } +fn as_cpu_vec(values: BaseFieldVec) -> Vec { + let capacity = values.len() * 16; + unsafe { + Vec::from_raw_parts( + values.data.as_ptr() as *mut BaseField, + values.length, + capacity, + ) + } +} + impl Index for BaseFieldVec { type Output = BaseField; fn index(&self, index: usize) -> &Self::Output {