Skip to content

Commit

Permalink
Use fast eval_at_point (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Mar 25, 2024
1 parent 3846f82 commit 83002ed
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 105 deletions.
4 changes: 1 addition & 3 deletions benches/eval_at_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ pub fn avx512_eval_at_secure_point(c: &mut criterion::Criterion) {
let point = CirclePoint { x, y };
c.bench_function("avx eval_at_secure_field_point", |b| {
b.iter(|| {
black_box(<AVX512Backend as PolyOps>::eval_at_securefield_point(
&poly, point,
));
black_box(<AVX512Backend as PolyOps>::eval_at_point(&poly, point));
})
});
}
Expand Down
147 changes: 72 additions & 75 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::core::backend::{CPUBackend, Col};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, Field, FieldExpOps};
use crate::core::fields::{Field, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
Expand Down Expand Up @@ -163,30 +163,52 @@ impl PolyOps for AVX512Backend {
CirclePoly::new(values)
}

fn eval_at_point<E: ExtensionOf<BaseField>>(
poly: &CirclePoly<Self>,
point: CirclePoint<E>,
) -> E {
// TODO(spapini): Optimize.
let mut mappings = vec![point.y, point.x];
let mut x = point.x;
for _ in 2..poly.log_size() {
x = CirclePoint::double_x(x);
mappings.push(x);
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
// If the polynomial is small, fallback to evaluate directly.
// TODO(Ohad): it's possible to avoid falling back. Consider fixing.
if poly.log_size() <= 8 {
return slow_eval_at_point(poly, point);
}
mappings.reverse();

// If the polynomial is large, the fft does a transpose in the middle.
if poly.log_size() as usize > CACHED_FFT_LOG_SIZE {
let n = mappings.len();
let n0 = (n - VECS_LOG_SIZE) / 2;
let n1 = (n - VECS_LOG_SIZE + 1) / 2;
let (ab, c) = mappings.split_at_mut(n1);
let (a, _b) = ab.split_at_mut(n0);
// Swap content of a,c.
a.swap_with_slice(&mut c[0..n0]);
let mappings = Self::generate_evaluation_mappings(point, poly.log_size());

// 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation.
let (map_low, map_high) = mappings.split_at(4);
let twiddle_lows =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_low, i)));
let (map_mid, map_high) = map_high.split_at(4);
let twiddle_mids =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_mid, i)));

// Compute the high twiddle steps.
let twiddle_steps = Self::twiddle_steps(map_high);

// Every twiddle is a product of mappings that correspond to '1's in the bit representation
// of the current index. For every 2^n alligned chunk of 2^n elements, the twiddle
// array is the same, denoted twiddle_low. Use this to compute sums of (coeff *
// twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result.
let mut sum = PackedQM31::zeroed();
let mut twiddle_high = SecureField::one();
for (i, coeff_chunk) in poly.coeffs.data.array_chunks::<K_BLOCK_SIZE>().enumerate() {
// For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same.
// Multiply it by every mid twiddle factor to get the factors for the current chunk.
let high_twiddle_factors =
(PackedQM31::broadcast(twiddle_high) * twiddle_mids).to_array();

// Sum the coefficients multiplied by each corrseponsing twiddle. Result is effectivley
// an array[16] where the value at index 'i' is the sum of all coefficients at indices
// that are i mod 16.
for (&packed_coeffs, &mid_twiddle) in
coeff_chunk.iter().zip(high_twiddle_factors.iter())
{
sum = sum + PackedQM31::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs);
}

// Advance twiddle high.
twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i);
}
fold(cast_slice(&poly.coeffs.data), &mappings)

(sum * twiddle_lows).pointwise_sum()
}

fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
Expand Down Expand Up @@ -278,57 +300,31 @@ impl PolyOps for AVX512Backend {
itwiddles,
}
}
}

fn eval_at_securefield_point(
poly: &CirclePoly<Self>,
point: CirclePoint<SecureField>,
) -> SecureField {
// If the polynomial is small, fallback to evaluate directly.
// TODO(Ohad): it's possible to avoid falling back. Consider fixing.
if poly.log_size() <= 8 {
return Self::eval_at_point(poly, point);
}

let mappings = Self::generate_evaluation_mappings(point, poly.log_size());

// 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation.
let (map_low, map_high) = mappings.split_at(4);
let twiddle_lows =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_low, i)));
let (map_mid, map_high) = map_high.split_at(4);
let twiddle_mids =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_mid, i)));

// Compute the high twiddle steps.
let twiddle_steps = Self::twiddle_steps(map_high);

// Every twiddle is a product of mappings that correspond to '1's in the bit representation
// of the current index. For every 2^n alligned chunk of 2^n elements, the twiddle
// array is the same, denoted twiddle_low. Use this to compute sums of (coeff *
// twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result.
let mut sum = PackedQM31::zeroed();
let mut twiddle_high = SecureField::one();
for (i, coeff_chunk) in poly.coeffs.data.array_chunks::<K_BLOCK_SIZE>().enumerate() {
// For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same.
// Multiply it by every mid twiddle factor to get the factors for the current chunk.
let high_twiddle_factors =
(PackedQM31::broadcast(twiddle_high) * twiddle_mids).to_array();

// Sum the coefficients multiplied by each corrseponsing twiddle. Result is effectivley
// an array[16] where the value at index 'i' is the sum of all coefficients at indices
// that are i mod 16.
for (&packed_coeffs, &mid_twiddle) in
coeff_chunk.iter().zip(high_twiddle_factors.iter())
{
sum = sum + PackedQM31::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs);
}

// Advance twiddle high.
twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i);
}

(sum * twiddle_lows).pointwise_sum()
fn slow_eval_at_point(
poly: &CirclePoly<AVX512Backend>,
point: CirclePoint<SecureField>,
) -> SecureField {
let mut mappings = vec![point.y, point.x];
let mut x = point.x;
for _ in 2..poly.log_size() {
x = CirclePoint::double_x(x);
mappings.push(x);
}
mappings.reverse();

// If the polynomial is large, the fft does a transpose in the middle.
if poly.log_size() as usize > CACHED_FFT_LOG_SIZE {
let n = mappings.len();
let n0 = (n - VECS_LOG_SIZE) / 2;
let n1 = (n - VECS_LOG_SIZE + 1) / 2;
let (ab, c) = mappings.split_at_mut(n1);
let (a, _b) = ab.split_at_mut(n0);
// Swap content of a,c.
a.swap_with_slice(&mut c[0..n0]);
}
fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings)
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
Expand All @@ -337,6 +333,7 @@ mod tests {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::avx512::circle::slow_eval_at_point;
use crate::core::backend::avx512::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::AVX512Backend;
use crate::core::backend::Column;
Expand Down Expand Up @@ -398,8 +395,8 @@ mod tests {
for i in [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)] {
let p = domain.at(i);
assert_eq!(
poly.eval_at_point(p),
BaseField::from_u32_unchecked(i as u32),
poly.eval_at_point(p.into_ef()),
BaseField::from_u32_unchecked(i as u32).into(),
"log_size = {log_size} i = {i}"
);
}
Expand Down Expand Up @@ -455,14 +452,14 @@ mod tests {
let p = CirclePoint { x, y };

assert_eq!(
<AVX512Backend as PolyOps>::eval_at_securefield_point(&poly, p),
<AVX512Backend as PolyOps>::eval_at_point(&poly, p),
slow_eval_at_point(&poly, p),
"log_size = {log_size}"
);

println!(
"log_size = {log_size} passed, eval{}",
<AVX512Backend as PolyOps>::eval_at_securefield_point(&poly, p)
<AVX512Backend as PolyOps>::eval_at_point(&poly, p)
);
}
}
Expand Down
13 changes: 2 additions & 11 deletions src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::core::backend::{Col, ColumnOps};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
Expand Down Expand Up @@ -76,10 +77,7 @@ impl PolyOps for CPUBackend {
CirclePoly::new(values)
}

fn eval_at_point<E: ExtensionOf<BaseField>>(
poly: &CirclePoly<Self>,
point: CirclePoint<E>,
) -> E {
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
// TODO(Andrew): Allocation here expensive for small polynomials.
let mut mappings = vec![point.y, point.x];
let mut x = point.x;
Expand Down Expand Up @@ -174,13 +172,6 @@ impl PolyOps for CPUBackend {
itwiddles,
}
}

fn eval_at_securefield_point(
poly: &CirclePoly<Self>,
point: CirclePoint<crate::core::fields::qm31::SecureField>,
) -> crate::core::fields::qm31::SecureField {
Self::eval_at_point(poly, point)
}
}

fn fft_layer_loop(
Expand Down
2 changes: 1 addition & 1 deletion src/core/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ mod tests {
let mut quotient_polynomial_values = Vec::with_capacity(large_domain_size as usize);
for point in large_domain.iter() {
let line = complex_conjugate_line(vanish_point, vanish_point_value, point);
let mut value = polynomial.eval_at_point(point) - line;
let mut value = polynomial.eval_at_point(point.into_ef()) - line;
value /= pair_vanishing(
vanish_point,
vanish_point.complex_conjugate(),
Expand Down
4 changes: 2 additions & 2 deletions src/core/poly/circle/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ mod tests {
.bit_reverse();
let poly = evaluation.interpolate();
for (i, point) in domain.iter().enumerate() {
assert_eq!(poly.eval_at_point(point), m31!(i as u32));
assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into());
}
}

Expand All @@ -182,7 +182,7 @@ mod tests {
);
let poly = evaluation.interpolate();
for (i, point) in Coset::odds(3).iter().enumerate() {
assert_eq!(poly.eval_at_point(point), m31!(i as u32));
assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into());
}
}

Expand Down
13 changes: 2 additions & 11 deletions src/core/poly/circle/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::core::backend::Col;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldOps};
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;

Expand All @@ -29,11 +29,7 @@ pub trait PolyOps: FieldOps<BaseField> + Sized {

/// Evaluates the polynomial at a single point.
/// Used by the [`CirclePoly::eval_at_point()`] function.
// TODO: Consider deprecating if/when not in use.
fn eval_at_point<E: ExtensionOf<BaseField>>(
poly: &CirclePoly<Self>,
point: CirclePoint<E>,
) -> E;
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField;

/// Extends the polynomial to a larger degree bound.
/// Used by the [`CirclePoly::extend()`] function.
Expand All @@ -49,9 +45,4 @@ pub trait PolyOps: FieldOps<BaseField> + Sized {

/// Precomputes twiddles for a given coset.
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self>;

fn eval_at_securefield_point(
poly: &CirclePoly<Self>,
point: CirclePoint<SecureField>,
) -> SecureField;
}
5 changes: 3 additions & 2 deletions src/core/poly/circle/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use super::{CircleDomain, CircleEvaluation, PolyOps};
use crate::core::backend::{Col, Column};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::{ExtensionOf, FieldOps};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;

Expand Down Expand Up @@ -38,7 +39,7 @@ impl<B: PolyOps> CirclePoly<B> {
}

/// Evaluates the polynomial at a single point.
pub fn eval_at_point<E: ExtensionOf<BaseField>>(&self, point: CirclePoint<E>) -> E {
pub fn eval_at_point(&self, point: CirclePoint<SecureField>) -> SecureField {
B::eval_at_point(self, point)
}

Expand Down

0 comments on commit 83002ed

Please sign in to comment.