From e88cbff2f622873a7fda95d5b2000f7d5c45c8a0 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sat, 17 Feb 2024 20:20:28 +0200 Subject: [PATCH] avx poly --- src/core/backend/avx512/circle.rs | 52 +++++++++++++++++++++++++++++++ src/core/backend/avx512/mod.rs | 21 +++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 src/core/backend/avx512/circle.rs diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs new file mode 100644 index 000000000..e499584c7 --- /dev/null +++ b/src/core/backend/avx512/circle.rs @@ -0,0 +1,52 @@ +use super::{as_cpu_vec, AVX512Backend}; +use crate::core::backend::CPUBackend; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::{Col, ExtensionOf}; +use crate::core::poly::circle::{ + CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, +}; + +impl PolyOps for AVX512Backend { + fn new_canonical_ordered( + coset: CanonicCoset, + values: Col, + ) -> CircleEvaluation { + let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values)); + CircleEvaluation::new(eval.domain, Col::::from_iter(eval.values)) + } + + fn interpolate(eval: CircleEvaluation) -> CirclePoly { + let cpu_eval = CircleEvaluation::::new(eval.domain, as_cpu_vec(eval.values)); + let cpu_poly = cpu_eval.interpolate(); + CirclePoly::new(Col::::from_iter(cpu_poly.coeffs)) + } + + fn eval_at_point>( + poly: &CirclePoly, + point: CirclePoint, + ) -> E { + // TODO(spapini): Unnecessary allocation here. + let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); + cpu_poly.eval_at_point(point) + } + + fn evaluate( + poly: &CirclePoly, + domain: CircleDomain, + ) -> CircleEvaluation { + let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); + let cpu_eval = cpu_poly.evaluate(domain); + CircleEvaluation::new( + cpu_eval.domain, + Col::::from_iter(cpu_eval.values), + ) + } + + fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { + let cpu_poly = CirclePoly::::new(as_cpu_vec(poly.coeffs.clone())); + CirclePoly::new(Col::::from_iter( + cpu_poly.extend(log_size).coeffs, + )) + } +} diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs index a84fb28f0..698b54728 100644 --- a/src/core/backend/avx512/mod.rs +++ b/src/core/backend/avx512/mod.rs @@ -1,4 +1,5 @@ pub mod bit_reverse; +pub mod circle; pub mod cm31; pub mod m31; pub mod qm31; @@ -78,6 +79,19 @@ impl Column for BaseFieldVec { } } +fn as_cpu_vec(values: BaseFieldVec) -> Vec { + let capacity = values.data.capacity() * 16; + unsafe { + let res = Vec::from_raw_parts( + values.data.as_ptr() as *mut BaseField, + values.length, + capacity, + ); + std::mem::forget(values); + res + } +} + impl FromIterator for BaseFieldVec { fn from_iter>(iter: I) -> Self { let mut chunks = iter.into_iter().array_chunks(); @@ -138,4 +152,11 @@ mod tests { ); } } + + #[test] + fn test_as_cpu_vec() { + let col = Col::::from_iter((0..100).map(BaseField::from)); + let vec = as_cpu_vec(col); + assert_eq!(vec, (0..100).map(BaseField::from).collect::>()); + } }