Skip to content

Commit

Permalink
Implement Backend for AVX512Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 28, 2024
1 parent 42f9f4c commit 0a61546
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 31 deletions.
15 changes: 15 additions & 0 deletions src/core/backend/avx512/accumulation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use super::qm31::PackedQM31;
use super::AVX512Backend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;

impl AccumulationOps for AVX512Backend {
fn accumulate(column: &mut SecureColumn<Self>, alpha: SecureField, other: &SecureColumn<Self>) {
let alpha = PackedQM31::broadcast(alpha);
for i in 0..column.len() {
let res_coeff = column.packed_at(i) * alpha + other.packed_at(i);
column.set_packed(i, res_coeff);
}
}
}
6 changes: 3 additions & 3 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod accumulation;
pub mod bit_reverse;
mod blake2s;
pub mod blake2s_avx;
Expand All @@ -17,7 +18,7 @@ use self::bit_reverse::bit_reverse_m31;
use self::cm31::PackedCM31;
pub use self::m31::{PackedBaseField, K_BLOCK_SIZE};
use self::qm31::PackedQM31;
use super::{Column, ColumnOps};
use super::{Backend, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
Expand All @@ -29,8 +30,7 @@ pub const VECS_LOG_SIZE: usize = 4;
#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

// BaseField.
// TODO(spapini): Unite with the M31AVX512 type.
impl Backend for AVX512Backend {}

unsafe impl Pod for PackedBaseField {}
unsafe impl Zeroable for PackedBaseField {
Expand Down
10 changes: 5 additions & 5 deletions src/core/backend/avx512/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ impl QuotientOps for AVX512Backend {
assert!(domain.log_size() >= VECS_LOG_SIZE as u32);
let mut values = SecureColumn::<AVX512Backend>::zeros(domain.size());
// TODO(spapini): bit reverse iterator.
for vec_row in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) {
// TODO(spapini): Optimize this, for the small number of columns case.
for pack_index in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) {
// TODO(spapini): Optimized this, for the small number of columns case.
let points = std::array::from_fn(|i| {
domain.at(bit_reverse_index(
(vec_row << VECS_LOG_SIZE) + i,
(pack_index << VECS_LOG_SIZE) + i,
domain.log_size(),
))
});
Expand All @@ -33,11 +33,11 @@ impl QuotientOps for AVX512Backend {
let row_accumlator = accumulate_row_quotients(
samples,
columns,
vec_row,
pack_index,
random_coeff,
(domain_points_x, domain_points_y),
);
values.set_packed(vec_row, row_accumlator);
values.set_packed(pack_index, row_accumlator);
}
SecureEvaluation { domain, values }
}
Expand Down
12 changes: 1 addition & 11 deletions src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,15 @@ pub use cpu::CPUBackend;
use super::air::accumulation::AccumulationOps;
use super::commitment_scheme::quotients::QuotientOps;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::FieldOps;
use super::fri::FriOps;
use super::poly::circle::PolyOps;

#[cfg(target_arch = "x86_64")]
pub mod avx512;
pub mod cpu;

pub trait Backend:
Copy
+ Clone
+ Debug
+ FieldOps<BaseField>
+ FieldOps<SecureField>
+ PolyOps
+ QuotientOps
+ FriOps
+ AccumulationOps
Copy + Clone + Debug + FieldOps<BaseField> + PolyOps + QuotientOps + AccumulationOps
{
}

Expand Down
14 changes: 11 additions & 3 deletions src/core/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::ops::MerkleOps;
use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
use crate::core::backend::Backend;
use crate::core::backend::{Backend, CPUBackend};
use crate::core::channel::Channel;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::poly::circle::{CircleEvaluation, CirclePoly, SecureEvaluation};

type MerkleHasher = Blake2sMerkleHasher;
type ProofChannel = Blake2sChannel;
Expand Down Expand Up @@ -89,9 +89,17 @@ impl<B: Backend + MerkleOps<MerkleHasher>> CommitmentSchemeProver<B> {
let columns = self.evaluations().flatten();
let quotients = compute_fri_quotients(&columns, &samples.flatten(), channel.draw_felt());

// TODO(spapini): Conversion to CircleEvaluation can be removed when FRI supports
// SecureColumn.
let quotients = quotients
.into_iter()
.map(SecureEvaluation::to_cpu)
.collect_vec();

// Run FRI commitment phase on the oods quotients.
let fri_config = FriConfig::new(LOG_LAST_LAYER_DEGREE_BOUND, LOG_BLOWUP_FACTOR, N_QUERIES);
let fri_prover = FriProver::<B, MerkleHasher>::commit(channel, fri_config, &quotients);
let fri_prover =
FriProver::<CPUBackend, MerkleHasher>::commit(channel, fri_config, &quotients);

// Proof of work.
let proof_of_work = ProofOfWork::new(PROOF_OF_WORK_BITS).prove(channel);
Expand Down
14 changes: 9 additions & 5 deletions src/core/commitment_scheme/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,17 @@ pub fn fri_answers_for_log_size(
CircleEvaluation::new(domain, q.take(domain.size()).copied().collect_vec())
})
.collect_vec();
CPUBackend::accumulate_quotients(
CircleEvaluation::new(
domain,
&column_evals.iter().collect_vec(),
random_coeff,
&batched_samples,
CPUBackend::accumulate_quotients(
domain,
&column_evals.iter().collect_vec(),
random_coeff,
&batched_samples,
)
.into_iter()
.collect(),
)
.to_cpu()
})
.collect(),
);
Expand Down
10 changes: 6 additions & 4 deletions src/core/poly/circle/secure_poly.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::ops::Deref;

use super::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps};
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -64,10 +63,13 @@ impl<B: FieldOps<BaseField>> Deref for SecureEvaluation<B> {
}
}

impl SecureEvaluation<CPUBackend> {
impl<B: FieldOps<BaseField>> SecureEvaluation<B> {
// TODO(spapini): Remove when we no longer use CircleEvaluation<SecureField>.
pub fn to_cpu(self) -> CPUCircleEvaluation<SecureField, BitReversedOrder> {
CPUCircleEvaluation::new(self.domain, self.values.to_vec())
pub fn to_cpu(self) -> SecureEvaluation<CPUBackend> {
SecureEvaluation {
domain: self.domain,
values: self.values.to_cpu(),
}
}
}

Expand Down

0 comments on commit 0a61546

Please sign in to comment.