Skip to content

Commit

Permalink
avx poly (#365)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/365)
<!-- Reviewable:end -->
  • Loading branch information
spapinistarkware committed Mar 3, 2024
1 parent 54fa6ca commit ad739e3
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 6 deletions.
52 changes: 52 additions & 0 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use super::{as_cpu_vec, AVX512Backend};
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, ExtensionOf};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};

/// [`PolyOps`] for [`AVX512Backend`], implemented by delegating to [`CPUBackend`].
///
/// Every method converts the packed column into a `Vec<BaseField>` with `as_cpu_vec`, runs the
/// corresponding CPU routine, and (where a column is returned) packs the result again through
/// `from_iter`.
impl PolyOps<BaseField> for AVX512Backend {
    /// Builds a canonically ordered evaluation over `coset` using the CPU implementation.
    fn new_canonical_ordered(
        coset: CanonicCoset,
        values: Col<Self, BaseField>,
    ) -> CircleEvaluation<Self, BaseField> {
        let cpu_values = as_cpu_vec(values);
        let cpu_eval = CPUBackend::new_canonical_ordered(coset, cpu_values);
        let domain = cpu_eval.domain;
        let packed_values = Col::<AVX512Backend, BaseField>::from_iter(cpu_eval.values);
        CircleEvaluation::new(domain, packed_values)
    }

    /// Interpolates `eval` on the CPU and repacks the resulting coefficients.
    fn interpolate(eval: CircleEvaluation<Self, BaseField>) -> CirclePoly<Self, BaseField> {
        let domain = eval.domain;
        let cpu_values = as_cpu_vec(eval.values);
        let cpu_coeffs = CircleEvaluation::<CPUBackend, BaseField>::new(domain, cpu_values)
            .interpolate()
            .coeffs;
        CirclePoly::new(Col::<AVX512Backend, BaseField>::from_iter(cpu_coeffs))
    }

    /// Evaluates `poly` at a single `point` via the CPU implementation.
    fn eval_at_point<E: ExtensionOf<BaseField>>(
        poly: &CirclePoly<Self, BaseField>,
        point: CirclePoint<E>,
    ) -> E {
        cpu_poly_from_avx512(poly).eval_at_point(point)
    }

    /// Evaluates `poly` over `domain` on the CPU and repacks the values.
    fn evaluate(
        poly: &CirclePoly<Self, BaseField>,
        domain: CircleDomain,
    ) -> CircleEvaluation<Self, BaseField> {
        let cpu_eval = cpu_poly_from_avx512(poly).evaluate(domain);
        let eval_domain = cpu_eval.domain;
        let packed_values = Col::<AVX512Backend, BaseField>::from_iter(cpu_eval.values);
        CircleEvaluation::new(eval_domain, packed_values)
    }

    /// Extends `poly` to `log_size` on the CPU and repacks the coefficients.
    fn extend(poly: &CirclePoly<Self, BaseField>, log_size: u32) -> CirclePoly<Self, BaseField> {
        let extended_coeffs = cpu_poly_from_avx512(poly).extend(log_size).coeffs;
        CirclePoly::new(Col::<AVX512Backend, BaseField>::from_iter(extended_coeffs))
    }
}

/// Clones the coefficients of an AVX512 polynomial into an equivalent CPU polynomial.
// TODO(spapini): Unnecessary allocation here.
fn cpu_poly_from_avx512(
    poly: &CirclePoly<AVX512Backend, BaseField>,
) -> CirclePoly<CPUBackend, BaseField> {
    CirclePoly::new(as_cpu_vec(poly.coeffs.clone()))
}
32 changes: 26 additions & 6 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod bit_reverse;
pub mod circle;
pub mod cm31;
pub mod m31;
pub mod qm31;
Expand All @@ -7,17 +8,15 @@ use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use num_traits::Zero;

use self::bit_reverse::bit_reverse_m31;
use self::m31::PackedBaseField;
pub use self::m31::{PackedBaseField, K_BLOCK_SIZE};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Column, FieldOps};
use crate::core::utils;

#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

// BaseField.
// TODO(spapini): Unite with the M31AVX512 type.
pub const K_ELEMENTS: usize = 16;

unsafe impl Pod for PackedBaseField {}
unsafe impl Zeroable for PackedBaseField {
Expand Down Expand Up @@ -59,7 +58,7 @@ impl FieldOps<BaseField> for AVX512Backend {
impl Column<BaseField> for BaseFieldVec {
// Creates a column of logical length `len` whose backing storage is filled with zeroed
// packed values.
fn zeros(len: usize) -> Self {
Self {
// NOTE(review): this listing appears to be a diff rendered without +/- markers, so two
// `data:` lines are shown. The `K_ELEMENTS` line looks like the pre-change version of the
// `K_BLOCK_SIZE` line that follows it — confirm against the actual file.
data: vec![PackedBaseField::zeroed(); len.div_ceil(K_ELEMENTS)],
// Number of packed values is `len` divided by the block size, rounded up.
data: vec![PackedBaseField::zeroed(); len.div_ceil(K_BLOCK_SIZE)],
// The logical element count is stored separately from the packed storage size.
length: len,
}
}
Expand All @@ -74,15 +73,28 @@ impl Column<BaseField> for BaseFieldVec {
self.length
}
// Returns the element at `index`: the quotient by the block size selects the packed value,
// the remainder selects the lane inside its `to_array()` expansion.
fn at(&self, index: usize) -> BaseField {
// NOTE(review): this listing appears to be a diff rendered without +/- markers; the
// `K_ELEMENTS` line looks like the pre-change version of the `K_BLOCK_SIZE` line below —
// confirm against the actual file.
self.data[index / K_ELEMENTS].to_array()[index % K_ELEMENTS]
self.data[index / K_BLOCK_SIZE].to_array()[index % K_BLOCK_SIZE]
}
}

fn as_cpu_vec(values: BaseFieldVec) -> Vec<BaseField> {
let capacity = values.data.capacity() * K_BLOCK_SIZE;
unsafe {
let res = Vec::from_raw_parts(
values.data.as_ptr() as *mut BaseField,
values.length,
capacity,
);
std::mem::forget(values);
res
}
}

impl FromIterator<BaseField> for BaseFieldVec {
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut res: Vec<_> = (&mut chunks).map(PackedBaseField::from_array).collect();
let mut length = res.len() * K_ELEMENTS;
let mut length = res.len() * K_BLOCK_SIZE;

if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
Expand Down Expand Up @@ -138,4 +150,12 @@ mod tests {
);
}
}

/// Round-trips 100 field elements through a packed column and checks that `as_cpu_vec`
/// returns them unchanged.
#[test]
fn test_as_cpu_vec() {
    let expected: Vec<BaseField> = (1000..1100).map(BaseField::from).collect();
    let packed_column = Col::<B, BaseField>::from_iter(expected.clone());
    let unpacked = as_cpu_vec(packed_column);
    assert_eq!(unpacked, expected);
}
}

0 comments on commit ad739e3

Please sign in to comment.