Skip to content

Commit

Permalink
Optimize AVX quotienting. (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 authored Mar 28, 2024
1 parent d4ebbc6 commit 2a18eb8
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl PolyOps for AVX512Backend {
for (&packed_coeffs, &mid_twiddle) in
coeff_chunk.iter().zip(high_twiddle_factors.iter())
{
sum = sum + PackedQM31::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs);
sum += PackedQM31::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs);
}

// Advance twiddle high.
Expand Down
7 changes: 6 additions & 1 deletion src/core/backend/avx512/qm31.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, Mul, MulAssign, Sub};
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub};

use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
Expand Down Expand Up @@ -106,6 +106,11 @@ impl One for PackedQM31 {
Self([PackedCM31::one(), PackedCM31::zero()])
}
}
impl AddAssign for PackedQM31 {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl MulAssign for PackedQM31 {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
Expand Down
79 changes: 41 additions & 38 deletions src/core/backend/avx512/quotients.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use itertools::zip_eq;

use super::qm31::PackedQM31;
use super::{AVX512Backend, VECS_LOG_SIZE};
use crate::core::backend::avx512::PackedBaseField;
use crate::core::backend::cpu::quotients::column_constants;
use crate::core::circle::CirclePoint;
use crate::core::commitment_scheme::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -15,10 +19,12 @@ impl QuotientOps for AVX512Backend {
domain: CircleDomain,
columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
random_coeff: SecureField,
samples: &[ColumnSampleBatch],
sample_batches: &[ColumnSampleBatch],
) -> SecureEvaluation<Self> {
assert!(domain.log_size() >= VECS_LOG_SIZE as u32);
let mut values = SecureColumn::<AVX512Backend>::zeros(domain.size());
let column_constants = column_constants(sample_batches, random_coeff);

// TODO(spapini): bit reverse iterator.
for vec_row in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) {
// TODO(spapini): Optimize this, for the small number of columns case.
Expand All @@ -30,75 +36,73 @@ impl QuotientOps for AVX512Backend {
});
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
let row_accumlator = accumulate_row_quotients(
samples,
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&column_constants,
vec_row,
random_coeff,
(domain_points_x, domain_points_y),
);
values.set_packed(vec_row, row_accumlator);
values.set_packed(vec_row, row_accumulator);
}
SecureEvaluation { domain, values }
}
}

pub fn accumulate_row_quotients(
samples: &[ColumnSampleBatch],
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>],
column_constants: &[Vec<(SecureField, SecureField, SecureField)>],
vec_row: usize,
random_coeff: SecureField,
domain_point_vec: (PackedBaseField, PackedBaseField),
) -> PackedQM31 {
let mut row_accumlator = PackedQM31::zero();
for sample in samples {
let mut row_accumulator = PackedQM31::zero();
for (sample_batch, sample_constants) in zip_eq(sample_batches, column_constants) {
let mut numerator = PackedQM31::zero();
for (column_index, sample_value) in &sample.columns_and_values {
for ((column_index, _), (a, b, c)) in
zip_eq(&sample_batch.columns_and_values, sample_constants)
{
let column = &columns[*column_index];
let value = column.data[vec_row];
// TODO(alonh): Optimize and simplify this.
let value = PackedQM31::broadcast(*c) * column.data[vec_row];
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
let current_numerator = cross(
(domain_point_vec.1, value),
(sample.point.y, *sample_value),
(
sample.point.y.complex_conjugate(),
sample_value.complex_conjugate(),
),
);
numerator = numerator * PackedQM31::broadcast(random_coeff) + current_numerator;
// TODO(AlonH): Use single point vanishing to save a multiplication.
let linear_term =
PackedQM31::broadcast(*a) * domain_point_vec.1 + PackedQM31::broadcast(*b);
numerator += value - linear_term;
}

let denominator = cross(
let denominator = packed_pair_vanishing(
sample_batch.point,
sample_batch.point.complex_conjugate(),
domain_point_vec,
(sample.point.x, sample.point.y),
(
sample.point.x.complex_conjugate(),
sample.point.y.complex_conjugate(),
),
);

row_accumlator = row_accumlator
* PackedQM31::broadcast(random_coeff.pow(sample.columns_and_values.len() as u128))
row_accumulator = row_accumulator
* PackedQM31::broadcast(
random_coeff.pow(sample_batch.columns_and_values.len() as u128),
)
+ numerator * denominator.inverse();
}
row_accumlator
row_accumulator
}

/// Computes the cross product of of the vectors (a.0, a.1), (b.0, b.1), (c.0, c.1).
/// This is a multilinear function of the inputs that vanishes when the inputs are collinear.
fn cross(
a: (PackedBaseField, PackedBaseField),
b: (SecureField, SecureField),
c: (SecureField, SecureField),
/// Pair vanishing for the packed representation of the points. See
/// [crate::core::constraints::pair_vanishing] for more details.
fn packed_pair_vanishing(
excluded0: CirclePoint<SecureField>,
excluded1: CirclePoint<SecureField>,
packed_p: (PackedBaseField, PackedBaseField),
) -> PackedQM31 {
PackedQM31::broadcast(b.0 - c.0) * a.1 - PackedQM31::broadcast(b.1 - c.1) * a.0
+ PackedQM31::broadcast(b.1 * c.0 - b.0 * c.1)
PackedQM31::broadcast(excluded0.y - excluded1.y) * packed_p.0
+ PackedQM31::broadcast(excluded1.x - excluded0.x) * packed_p.1
+ PackedQM31::broadcast(excluded0.x * excluded1.y - excluded0.y * excluded1.x)
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
Expand Down Expand Up @@ -160,7 +164,6 @@ mod tests {
.values
.to_vec();

// TODO(spapini): This is calculated in a different way from CPUBackend right now.
assert_ne!(avx_result, cpu_result);
assert_eq!(avx_result, cpu_result);
}
}
2 changes: 1 addition & 1 deletion src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub fn accumulate_row_quotients(
let mut row_accumulator = SecureField::zero();
for (sample_batch, sample_constants) in zip_eq(sample_batches, column_constants) {
let mut numerator = SecureField::zero();
for ((column_index, _sampled_value), (a, b, c)) in
for ((column_index, _), (a, b, c)) in
zip_eq(&sample_batch.columns_and_values, sample_constants)
{
let column = &columns[*column_index];
Expand Down
7 changes: 3 additions & 4 deletions src/core/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,9 @@ pub fn pair_vanishing<F: ExtensionOf<BaseField>>(
// | e1.x e1.y 1 |
// This is a polynomial of degree 1 in p.x and p.y, and thus it is a line.
// It vanishes at e0 and e1.
p.x * excluded0.y + excluded0.x * excluded1.y + excluded1.x * p.y
- p.x * excluded1.y
- excluded0.x * p.y
- excluded1.x * excluded0.y
(excluded0.y - excluded1.y) * p.x
+ (excluded1.x - excluded0.x) * p.y
+ (excluded0.x * excluded1.y - excluded0.y * excluded1.x)
}

/// Evaluates a vanishing polynomial of the vanish_point at a point.
Expand Down

0 comments on commit 2a18eb8

Please sign in to comment.