Skip to content

Commit

Permalink
avx slow eval
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 12, 2024
1 parent 6e9da20 commit 58568c8
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 56 deletions.
139 changes: 95 additions & 44 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use super::fft::ifft;
use bytemuck::cast_slice;

use super::fft::{ifft, CACHED_FFT_LOG_SIZE};
use super::m31::PackedBaseField;
use super::{as_cpu_vec, AVX512Backend};
use super::{as_cpu_vec, AVX512Backend, VECS_LOG_SIZE};
use crate::core::backend::avx512::fft::rfft;
use crate::core::backend::avx512::{BaseFieldVec, VECS_LOG_SIZE};
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Col, FieldExpOps};
use crate::core::fields::{Col, ExtensionOf, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
use crate::core::poly::utils::fold;
use crate::core::poly::BitReversedOrder;

// TODO(spapini): Everything is returned in redundant representation, where values can also be P.
Expand Down Expand Up @@ -55,11 +59,29 @@ impl PolyOps<BaseField> for AVX512Backend {
CirclePoly::new(values)
}

/// Evaluates `poly` at `point`, returning a value in the extension field `E`.
///
/// Builds a list of `poly.log_size()` factors derived from the point (its `y`, its `x`, and
/// repeated `CirclePoint::double_x` applications to `x`) and passes it, together with the
/// polynomial's coefficients, to `fold`.
fn eval_at_point<E: crate::core::fields::ExtensionOf<BaseField>>(
_poly: &CirclePoly<Self, BaseField>,
_point: crate::core::circle::CirclePoint<E>,
fn eval_at_point<E: ExtensionOf<BaseField>>(
poly: &CirclePoly<Self, BaseField>,
point: CirclePoint<E>,
) -> E {
todo!()
// TODO(spapini): Optimize.
// Start with [y, x]; each loop iteration appends the next doubled x-coordinate, so for
// log_size >= 2 the list ends up with exactly `poly.log_size()` entries.
let mut mappings = vec![point.y, point.x];
let mut x = point.x;
for _ in 2..poly.log_size() {
x = CirclePoint::double_x(x);
mappings.push(x);
}
// After reversing, the most-doubled x comes first and `y` comes last.
mappings.reverse();

// Only polynomials larger than CACHED_FFT_LOG_SIZE get their factor list permuted.
// NOTE(review): presumably this mirrors the coefficient layout produced by the
// non-cached FFT path for large sizes -- confirm against the fft module.
if poly.log_size() as usize > CACHED_FFT_LOG_SIZE {
let n = mappings.len();
// n0 = floor((n - VECS_LOG_SIZE) / 2), n1 = ceil((n - VECS_LOG_SIZE) / 2).
let n0 = (n - VECS_LOG_SIZE) / 2;
let n1 = (n - VECS_LOG_SIZE + 1) / 2;
// a = mappings[..n0], _b = mappings[n0..n1] (empty or one entry), c = mappings[n1..].
let (ab, c) = mappings.split_at_mut(n1);
let (a, _b) = ab.split_at_mut(n0);
// Swap content of a,c.
// Exchanges mappings[..n0] with mappings[n1..n1 + n0]; `_b` and the tail stay in place.
a.swap_with_slice(&mut c[0..n0]);
}
// `cast_slice` reinterprets the coefficient storage as a flat slice for `fold`.
// NOTE(review): assumes `poly.coeffs.data` is a slice of packed base-field vectors -- confirm.
fold(cast_slice(&poly.coeffs.data), &mappings)
}

fn extend(poly: &CirclePoly<Self, BaseField>, log_size: u32) -> CirclePoly<Self, BaseField> {
Expand Down Expand Up @@ -125,58 +147,87 @@ impl PolyOps<BaseField> for AVX512Backend {
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use crate::core::backend::avx512::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::AVX512Backend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::Column;
use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::poly::{BitReversedOrder, NaturalOrder};

// Round-trip check: interpolating an evaluation into a polynomial and evaluating that
// polynomial on the same domain must reproduce the original values.
#[test]
fn test_interpolate_and_eval() {
const LOG_SIZE: u32 = 6;
let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(
domain,
(0..(1 << LOG_SIZE))
.map(BaseField::from_u32_unchecked)
.collect(),
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain);
assert_eq!(evaluation.values.to_vec(), evaluation2.values.to_vec());
// Covers every log size from MIN_FFT_LOG_SIZE up to CACHED_FFT_LOG_SIZE + 3, i.e. sizes on
// both sides of the CACHED_FFT_LOG_SIZE threshold.
for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) {
let domain = CanonicCoset::new(log_size as u32).circle_domain();
// Values are simply 0, 1, 2, ... interpreted as field elements.
let evaluation = CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(
domain,
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
.collect(),
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain);
assert_eq!(evaluation.values.to_vec(), evaluation2.values.to_vec());
}
}

// Interpolates an evaluation given on a constraint evaluation domain, evaluates the resulting
// polynomial on a domain whose log size is larger by 3, and checks that the first
// 2^log_size values of the larger evaluation equal the original values.
#[test]
fn test_eval_extension() {
const LOG_SIZE: u32 = 6;
let domain = CircleDomain::constraint_evaluation_domain(LOG_SIZE);
let domain_ext = CircleDomain::constraint_evaluation_domain(LOG_SIZE + 3);
let evaluation = CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(
domain,
(0..(1 << LOG_SIZE))
.map(BaseField::from_u32_unchecked)
.collect(),
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain_ext);
for i in 0..(1 << LOG_SIZE) {
assert_eq!(evaluation2.values.at(i), evaluation.values.at(i));
// Covers every log size from MIN_FFT_LOG_SIZE up to CACHED_FFT_LOG_SIZE + 3.
for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) {
let log_size = log_size as u32;
let domain = CircleDomain::constraint_evaluation_domain(log_size);
let domain_ext = CircleDomain::constraint_evaluation_domain(log_size + 3);
let evaluation = CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(
domain,
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
.collect(),
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain_ext);
// Only the leading 2^log_size entries of the extended evaluation are compared.
assert_eq!(
evaluation2.values.to_vec()[..(1 << log_size)],
evaluation.values.to_vec()
);
}
}

#[test]
fn test_eval_at_point() {
    // For each log size, interpolate the sequence 0, 1, 2, ... (given in natural order over
    // the canonic coset's circle domain) and check that evaluating the polynomial at a few
    // domain points returns the index of each point.
    for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) {
        let domain = CanonicCoset::new(log_size as u32).circle_domain();
        let column = (0..(1 << log_size))
            .map(BaseField::from_u32_unchecked)
            .collect();
        let natural_eval =
            CircleEvaluation::<AVX512Backend, _, NaturalOrder>::new(domain, column);
        // Interpolation is done on the bit-reversed form of the evaluation.
        let poly = natural_eval.bit_reverse().interpolate();

        // Probe the first few indices plus the half and quarter positions of the domain.
        let probe_indices = [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)];
        for i in probe_indices {
            let p = domain.at(i);
            let actual = poly.eval_at_point(p);
            let expected = BaseField::from_u32_unchecked(i as u32);
            assert_eq!(actual, expected, "log_size = {log_size} i = {i}");
        }
    }
}

// Checks that evaluating a polynomial directly on a larger canonic-coset domain gives the same
// values as first calling `extend` to that log size and then evaluating on the same domain.
#[test]
fn test_circle_poly_extend() {
let poly = CirclePoly::<AVX512Backend, _>::new(
(0..(1 << 6)).map(BaseField::from_u32_unchecked).collect(),
);
let eval0 = poly.evaluate(CanonicCoset::new(8).circle_domain());
let eval1 = poly
.extend(8)
.evaluate(CanonicCoset::new(8).circle_domain());

// Compare.
for i in 0..eval0.values.len() {
assert_eq!(eval0.values.at(i), eval1.values.at(i));
// Covers every log size from MIN_FFT_LOG_SIZE up to CACHED_FFT_LOG_SIZE + 1; the target
// domain is always 2 log sizes larger than the polynomial.
for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 2) {
let log_size = log_size as u32;
// Coefficients are simply 0, 1, 2, ... interpreted as field elements.
let poly = CirclePoly::<AVX512Backend, _>::new(
(0..(1 << log_size))
.map(BaseField::from_u32_unchecked)
.collect(),
);
let eval0 = poly.evaluate(CanonicCoset::new(log_size + 2).circle_domain());
let eval1 = poly
.extend(log_size + 2)
.evaluate(CanonicCoset::new(log_size + 2).circle_domain());

assert_eq!(eval0.values.to_vec(), eval1.values.to_vec());
}
}
}
9 changes: 4 additions & 5 deletions src/core/backend/avx512/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::arch::x86_64::{
};

use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
use crate::core::backend::avx512::fft::{transpose_vecs, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CircleDomain;
Expand All @@ -28,8 +28,7 @@ use crate::core::utils::bit_reverse;
pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE);
let log_n_vecs = log_n_elements - VECS_LOG_SIZE;
// TODO(spapini): Use CACHED_FFT_LOG_SIZE instead.
if log_n_elements <= 1 {
if log_n_elements <= CACHED_FFT_LOG_SIZE {
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
return;
}
Expand Down Expand Up @@ -733,8 +732,8 @@ mod tests {

// Runs the full-iFFT test helper over a range of log sizes.
#[test]
fn test_ifft_full() {
for i in 5..12 {
run_ifft_full_test(i);
// Uses only the two log sizes directly above CACHED_FFT_LOG_SIZE (+1 and +2).
// NOTE(review): presumably chosen so that `ifft` does not take its
// `log_n_elements <= CACHED_FFT_LOG_SIZE` early-return path -- confirm in the helper.
for i in (CACHED_FFT_LOG_SIZE + 1)..(CACHED_FFT_LOG_SIZE + 3) {
run_ifft_full_test(i as u32);
}
}
}
9 changes: 4 additions & 5 deletions src/core/backend/avx512/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::arch::x86_64::{
};

use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
use crate::core::backend::avx512::fft::{transpose_vecs, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
use crate::core::poly::circle::CircleDomain;
use crate::core::utils::bit_reverse;
Expand All @@ -27,8 +27,7 @@ use crate::core::utils::bit_reverse;
pub unsafe fn fft(values: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE);
let log_n_vecs = log_n_elements - VECS_LOG_SIZE;
// TODO(spapini): Use CACHED_FFT_LOG_SIZE instead.
if log_n_elements <= 1 {
if log_n_elements <= CACHED_FFT_LOG_SIZE {
fft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
return;
}
Expand Down Expand Up @@ -700,8 +699,8 @@ mod tests {

// Runs the full-FFT test helper over a range of log sizes.
#[test]
fn test_fft_full() {
for i in 5..12 {
run_fft_full_test(i);
// Uses only the two log sizes directly above CACHED_FFT_LOG_SIZE (+1 and +2).
// NOTE(review): presumably chosen so that `fft` does not take its
// `log_n_elements <= CACHED_FFT_LOG_SIZE` early-return path -- confirm in the helper.
for i in (CACHED_FFT_LOG_SIZE + 1)..(CACHED_FFT_LOG_SIZE + 3) {
run_fft_full_test(i as u32);
}
}
}
2 changes: 0 additions & 2 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ use crate::core::fields::{Column, FieldExpOps, FieldOps};
use crate::core::utils;

const VECS_LOG_SIZE: usize = 4;
pub const CACHED_FFT_LOG_SIZE: usize = 16;
pub const MIN_FFT_LOG_SIZE: usize = 5;

#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;
Expand Down

0 comments on commit 58568c8

Please sign in to comment.