Skip to content

Commit

Permalink
Implement Backend for AVX512Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 21, 2024
1 parent 0e6f2ac commit be0ce1c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 25 deletions.
17 changes: 17 additions & 0 deletions src/core/backend/avx512/accumulation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use super::qm31::PackedQM31;
use super::AVX512Backend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;

impl AccumulationOps for AVX512Backend {
fn accumulate(column: &mut SecureColumn<Self>, alpha: SecureField, other: &SecureColumn<Self>) {
let alpha = PackedQM31::broadcast(alpha);
for i in 0..column.len() {
unsafe {
let res_coeff = column.get_packed(i) * alpha + other.get_packed(i);
column.set_packed(i, res_coeff);
}
}
}
}
38 changes: 28 additions & 10 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod accumulation;
pub mod bit_reverse;
pub mod circle;
pub mod cm31;
Expand All @@ -10,9 +11,10 @@ use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use num_traits::Zero;

use self::bit_reverse::bit_reverse_m31;
use self::cm31::PackedCM31;
pub use self::m31::{PackedBaseField, K_BLOCK_SIZE};
use self::qm31::PackedQM31;
use super::{Column, ColumnOps};
use super::{Backend, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::{FieldExpOps, FieldOps};
Expand All @@ -23,8 +25,7 @@ pub const VECS_LOG_SIZE: usize = 4;
#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

// BaseField.
// TODO(spapini): Unite with the M31AVX512 type.
impl Backend for AVX512Backend {}

unsafe impl Pod for PackedBaseField {}
unsafe impl Zeroable for PackedBaseField {
Expand Down Expand Up @@ -130,13 +131,30 @@ impl FromIterator<BaseField> for BaseFieldVec {
}

impl SecureColumn<AVX512Backend> {
pub fn set(&mut self, vec_index: usize, value: PackedQM31) {
unsafe {
*self.columns[0].data.get_unchecked_mut(vec_index) = value.a().a();
*self.columns[1].data.get_unchecked_mut(vec_index) = value.a().b();
*self.columns[2].data.get_unchecked_mut(vec_index) = value.b().a();
*self.columns[3].data.get_unchecked_mut(vec_index) = value.b().b();
}
/// # Safety
///
/// Calling this method with an out-of-bounds index is undefined behavior.
pub unsafe fn set_packed(&mut self, pack_index: usize, value: PackedQM31) {
*self.columns[0].data.get_unchecked_mut(pack_index) = value.a().a();
*self.columns[1].data.get_unchecked_mut(pack_index) = value.a().b();
*self.columns[2].data.get_unchecked_mut(pack_index) = value.b().a();
*self.columns[3].data.get_unchecked_mut(pack_index) = value.b().b();
}

/// # Safety
///
/// Calling this method with an out-of-bounds index is undefined behavior.
pub unsafe fn get_packed(&self, pack_index: usize) -> PackedQM31 {
PackedQM31([
PackedCM31([
*self.columns[0].data.get_unchecked(pack_index),
*self.columns[1].data.get_unchecked(pack_index),
]),
PackedCM31([
*self.columns[2].data.get_unchecked(pack_index),
*self.columns[3].data.get_unchecked(pack_index),
]),
])
}
}

Expand Down
10 changes: 6 additions & 4 deletions src/core/backend/avx512/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ impl QuotientOps for AVX512Backend {
assert!(domain.log_size() >= VECS_LOG_SIZE as u32);
let mut values = SecureColumn::<AVX512Backend>::zeros(domain.size());
// TODO(spapini): bit reverse iterator.
for vec_row in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) {
for pack_index in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) {
// TODO(spapini): Optimized this, for the small number of columns case.
let points = std::array::from_fn(|i| {
domain.at(bit_reverse_index(
(vec_row << VECS_LOG_SIZE) + i,
(pack_index << VECS_LOG_SIZE) + i,
domain.log_size(),
))
});
Expand All @@ -33,11 +33,13 @@ impl QuotientOps for AVX512Backend {
let row_accumlator = accumulate_row_quotients(
openings,
columns,
vec_row,
pack_index,
random_coeff,
(domain_points_x, domain_points_y),
);
values.set(vec_row, row_accumlator);
unsafe {
values.set_packed(pack_index, row_accumlator);
}
}
SecureEvaluation { domain, values }
}
Expand Down
12 changes: 1 addition & 11 deletions src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,14 @@ pub use cpu::CPUBackend;
use super::air::accumulation::AccumulationOps;
use super::commitment_scheme::quotients::QuotientOps;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::FieldOps;
use super::fri::FriOps;
use super::poly::circle::PolyOps;

pub mod avx512;
pub mod cpu;

pub trait Backend:
Copy
+ Clone
+ Debug
+ FieldOps<BaseField>
+ FieldOps<SecureField>
+ PolyOps
+ QuotientOps
+ FriOps
+ AccumulationOps
Copy + Clone + Debug + FieldOps<BaseField> + PolyOps + QuotientOps + AccumulationOps
{
}

Expand Down

0 comments on commit be0ce1c

Please sign in to comment.