Skip to content

Commit

Permalink
Implement Backend for AVX. (#561)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 authored Apr 2, 2024
1 parent ddc8d3b commit e2d213e
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 75 deletions.
2 changes: 1 addition & 1 deletion benches/eval_at_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub fn avx512_eval_at_secure_point(c: &mut criterion::Criterion) {
let rng = &mut StdRng::seed_from_u64(0);

let domain = CanonicCoset::new(log_size as u32).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, _, NaturalOrder>::new(
let evaluation = CircleEvaluation::<AVX512Backend, BaseField, NaturalOrder>::new(
domain,
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
Expand Down
25 changes: 14 additions & 11 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use num_traits::One;

use super::fft::{ifft, CACHED_FFT_LOG_SIZE};
use super::m31::PackedBaseField;
use super::qm31::PackedQM31;
use super::qm31::PackedSecureField;
use super::{as_cpu_vec, AVX512Backend, K_BLOCK_SIZE, VECS_LOG_SIZE};
use crate::core::backend::avx512::fft::rfft;
use crate::core::backend::avx512::BaseFieldVec;
Expand Down Expand Up @@ -131,7 +131,10 @@ impl PolyOps for AVX512Backend {
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Optimize.
let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values));
CircleEvaluation::new(eval.domain, Col::<AVX512Backend, _>::from_iter(eval.values))
CircleEvaluation::new(
eval.domain,
Col::<AVX512Backend, BaseField>::from_iter(eval.values),
)
}

fn interpolate(
Expand Down Expand Up @@ -174,10 +177,10 @@ impl PolyOps for AVX512Backend {
// 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation.
let (map_low, map_high) = mappings.split_at(4);
let twiddle_lows =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_low, i)));
PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_low, i)));
let (map_mid, map_high) = map_high.split_at(4);
let twiddle_mids =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_mid, i)));
PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_mid, i)));

// Compute the high twiddle steps.
let twiddle_steps = Self::twiddle_steps(map_high);
Expand All @@ -186,21 +189,21 @@ impl PolyOps for AVX512Backend {
// of the current index. For every 2^n aligned chunk of 2^n elements, the twiddle
// array is the same, denoted twiddle_low. Use this to compute sums of (coeff *
// twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result.
let mut sum = PackedQM31::zeroed();
let mut sum = PackedSecureField::zeroed();
let mut twiddle_high = SecureField::one();
for (i, coeff_chunk) in poly.coeffs.data.array_chunks::<K_BLOCK_SIZE>().enumerate() {
// For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same.
// Multiply it by every mid twiddle factor to get the factors for the current chunk.
let high_twiddle_factors =
(PackedQM31::broadcast(twiddle_high) * twiddle_mids).to_array();
(PackedSecureField::broadcast(twiddle_high) * twiddle_mids).to_array();

// Sum the coefficients multiplied by each corresponding twiddle. Result is effectively
// an array[16] where the value at index 'i' is the sum of all coefficients at indices
// that are i mod 16.
for (&packed_coeffs, &mid_twiddle) in
coeff_chunk.iter().zip(high_twiddle_factors.iter())
{
sum += PackedQM31::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs);
sum += PackedSecureField::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs);
}

// Advance twiddle high.
Expand Down Expand Up @@ -346,7 +349,7 @@ mod tests {
fn test_interpolate_and_eval() {
for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) {
let domain = CanonicCoset::new(log_size as u32).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(
let evaluation = CircleEvaluation::<AVX512Backend, BaseField, BitReversedOrder>::new(
domain,
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
Expand All @@ -364,7 +367,7 @@ mod tests {
let log_size = log_size as u32;
let domain = CanonicCoset::new(log_size).circle_domain();
let domain_ext = CanonicCoset::new(log_size + 3).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(
let evaluation = CircleEvaluation::<AVX512Backend, BaseField, BitReversedOrder>::new(
domain,
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
Expand All @@ -384,7 +387,7 @@ mod tests {
fn test_eval_at_point() {
for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) {
let domain = CanonicCoset::new(log_size as u32).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, _, NaturalOrder>::new(
let evaluation = CircleEvaluation::<AVX512Backend, BaseField, NaturalOrder>::new(
domain,
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
Expand Down Expand Up @@ -427,7 +430,7 @@ mod tests {

for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 2) {
let domain = CanonicCoset::new(log_size as u32).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, _, NaturalOrder>::new(
let evaluation = CircleEvaluation::<AVX512Backend, BaseField, NaturalOrder>::new(
domain,
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/avx512/cm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::core::fields::FieldExpOps;

/// AVX implementation for the complex extension field of M31.
/// See [crate::core::fields::cm31::CM31] for more information.
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct PackedCM31(pub [PackedBaseField; 2]);
impl PackedCM31 {
pub fn broadcast(value: CM31) -> Self {
Expand Down
96 changes: 88 additions & 8 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ pub mod quotients;
pub mod tranpose_utils;

use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use itertools::izip;
use itertools::{izip, Itertools};
use num_traits::Zero;

use self::bit_reverse::bit_reverse_m31;
use self::cm31::PackedCM31;
pub use self::m31::{PackedBaseField, K_BLOCK_SIZE};
use self::qm31::PackedQM31;
use super::{Column, ColumnOps};
use self::qm31::PackedSecureField;
use super::{Backend, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
Expand All @@ -28,6 +28,8 @@ pub const VECS_LOG_SIZE: usize = 4;
#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

impl Backend for AVX512Backend {}

// BaseField.
// TODO(spapini): Unite with the M31AVX512 type.

Expand Down Expand Up @@ -134,10 +136,88 @@ impl FromIterator<BaseField> for BaseFieldVec {
}
}

/// A column of [`SecureField`] elements backed by packed AVX-512 vectors.
///
/// `data` stores the values K_BLOCK_SIZE-at-a-time in `PackedSecureField`
/// lanes; `length` is the logical element count, which may be smaller than
/// `data.len() * K_BLOCK_SIZE` when the last packed vector is zero-padded.
#[derive(Clone, Debug)]
pub struct SecureFieldVec {
    pub data: Vec<PackedSecureField>,
    // Logical number of elements; kept private so it stays consistent with `data`.
    length: usize,
}

impl ColumnOps<SecureField> for AVX512Backend {
    type Column = SecureFieldVec;

    /// Bit-reverse permutes the column in place.
    ///
    /// Falls back to the CPU implementation by unpacking to a scalar vector,
    /// permuting it, and repacking.
    // TODO(AlonH): Implement AVX512 bit_reverse for SecureField.
    fn bit_reverse_column(column: &mut Self::Column) {
        // BUG FIX: the original called `utils::bit_reverse` on the temporary
        // returned by `to_vec()` and discarded the result, leaving `column`
        // unchanged. Permute the scalars and write them back instead.
        let mut scalars = column.to_vec();
        utils::bit_reverse(scalars.as_mut_slice());
        *column = Self::Column::from_iter(scalars);
    }
}

impl FieldOps<SecureField> for AVX512Backend {
    // Inverts every element of `column`, writing results into `dst`; delegates
    // to the packed batch-inversion routine.
    // NOTE(review): presumably `dst` must have the same packed length as
    // `column` — confirm against `PackedSecureField::batch_inverse`'s contract.
    fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
        PackedSecureField::batch_inverse(&column.data, &mut dst.data);
    }
}

impl Column<SecureField> for SecureFieldVec {
    /// Allocates a zero-initialized column of `len` secure-field elements,
    /// rounding the packed storage up to a whole number of vectors.
    fn zeros(len: usize) -> Self {
        let n_packed = len.div_ceil(K_BLOCK_SIZE);
        Self {
            data: vec![PackedSecureField::zeroed(); n_packed],
            length: len,
        }
    }

    /// Unpacks the column into a plain `Vec`, dropping any padding lanes in
    /// the final packed vector.
    fn to_vec(&self) -> Vec<SecureField> {
        let mut scalars = Vec::with_capacity(self.data.len() * K_BLOCK_SIZE);
        for packed in &self.data {
            scalars.extend(packed.to_array());
        }
        scalars.truncate(self.length);
        scalars
    }

    /// Logical number of elements (excluding padding).
    fn len(&self) -> usize {
        self.length
    }

    /// Reads the element at `index` by unpacking the vector that contains it.
    fn at(&self, index: usize) -> SecureField {
        let vec_index = index / K_BLOCK_SIZE;
        let lane = index % K_BLOCK_SIZE;
        self.data[vec_index].to_array()[lane]
    }
}

impl FromIterator<SecureField> for SecureFieldVec {
    /// Packs scalar `SecureField` values into AVX vectors, zero-padding the
    /// last vector when the element count is not a multiple of `K_BLOCK_SIZE`.
    /// `length` records the unpadded element count.
    fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
        let mut chunks = iter.into_iter().array_chunks();
        let mut res: Vec<_> = (&mut chunks).map(PackedSecureField::from_array).collect();
        let mut length = res.len() * K_BLOCK_SIZE;

        if let Some(remainder) = chunks.into_remainder() {
            if !remainder.is_empty() {
                length += remainder.len();
                // CONSISTENCY FIX: use K_BLOCK_SIZE instead of the hard-coded
                // literal 16 so this stays correct if the block size changes.
                let pad_len = K_BLOCK_SIZE - remainder.len();
                let last = PackedSecureField::from_array(
                    remainder
                        .chain(std::iter::repeat(SecureField::zero()).take(pad_len))
                        .collect::<Vec<_>>()
                        .try_into()
                        .unwrap(),
                );
                res.push(last);
            }
        }

        Self { data: res, length }
    }
}

impl FromIterator<PackedSecureField> for SecureFieldVec {
fn from_iter<I: IntoIterator<Item = PackedSecureField>>(iter: I) -> Self {
let data = (&mut iter.into_iter()).collect_vec();
let length = data.len() * K_BLOCK_SIZE;

Self { data, length }
}
}

impl SecureColumn<AVX512Backend> {
pub fn packed_at(&self, vec_index: usize) -> PackedQM31 {
pub fn packed_at(&self, vec_index: usize) -> PackedSecureField {
unsafe {
PackedQM31([
PackedSecureField([
PackedCM31([
*self.columns[0].data.get_unchecked(vec_index),
*self.columns[1].data.get_unchecked(vec_index),
Expand All @@ -150,7 +230,7 @@ impl SecureColumn<AVX512Backend> {
}
}

pub fn set_packed(&mut self, vec_index: usize, value: PackedQM31) {
pub fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
unsafe {
*self.columns[0].data.get_unchecked_mut(vec_index) = value.a().a();
*self.columns[1].data.get_unchecked_mut(vec_index) = value.a().b();
Expand Down Expand Up @@ -202,7 +282,7 @@ mod tests {
for i in 1..16 {
let len = 1 << i;
let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
B::bit_reverse_column(&mut col);
<B as ColumnOps<BaseField>>::bit_reverse_column(&mut col);
assert_eq!(
col.to_vec(),
(0..len)
Expand All @@ -229,7 +309,7 @@ mod tests {
let expected = column.data.iter().map(|e| e.inverse()).collect::<Vec<_>>();
let mut dst = BaseFieldVec::from_iter((0..64).map(|_| BaseField::zero()));

AVX512Backend::batch_inverse(&column, &mut dst);
<AVX512Backend as FieldOps<BaseField>>::batch_inverse(&column, &mut dst);

dst.data.iter().zip(expected.iter()).for_each(|(a, b)| {
assert_eq!(a.to_array(), b.to_array());
Expand Down
Loading

0 comments on commit e2d213e

Please sign in to comment.