From be0ce1c74f5ee7b55c7e0af9570ead24264bbacb Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sat, 16 Mar 2024 08:28:09 +0200 Subject: [PATCH] Implement Backend for AVX512Backend --- src/core/backend/avx512/accumulation.rs | 17 +++++++++++ src/core/backend/avx512/mod.rs | 38 ++++++++++++++++++------- src/core/backend/avx512/quotients.rs | 10 ++++--- src/core/backend/mod.rs | 12 +------- 4 files changed, 52 insertions(+), 25 deletions(-) create mode 100644 src/core/backend/avx512/accumulation.rs diff --git a/src/core/backend/avx512/accumulation.rs b/src/core/backend/avx512/accumulation.rs new file mode 100644 index 000000000..2191ccc0d --- /dev/null +++ b/src/core/backend/avx512/accumulation.rs @@ -0,0 +1,17 @@ +use super::qm31::PackedQM31; +use super::AVX512Backend; +use crate::core::air::accumulation::AccumulationOps; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumn; + +impl AccumulationOps for AVX512Backend { + fn accumulate(column: &mut SecureColumn, alpha: SecureField, other: &SecureColumn) { + let alpha = PackedQM31::broadcast(alpha); + for i in 0..column.len() { + unsafe { + let res_coeff = column.get_packed(i) * alpha + other.get_packed(i); + column.set_packed(i, res_coeff); + } + } + } +} diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs index 2ada5a4c2..6573ba7eb 100644 --- a/src/core/backend/avx512/mod.rs +++ b/src/core/backend/avx512/mod.rs @@ -1,3 +1,4 @@ +pub mod accumulation; pub mod bit_reverse; pub mod circle; pub mod cm31; @@ -10,9 +11,10 @@ use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable}; use num_traits::Zero; use self::bit_reverse::bit_reverse_m31; +use self::cm31::PackedCM31; pub use self::m31::{PackedBaseField, K_BLOCK_SIZE}; use self::qm31::PackedQM31; -use super::{Column, ColumnOps}; +use super::{Backend, Column, ColumnOps}; use crate::core::fields::m31::BaseField; use crate::core::fields::secure_column::SecureColumn; use crate::core::fields::{FieldExpOps, FieldOps}; @@ -23,8 +25,7 @@ pub const VECS_LOG_SIZE: usize = 4; #[derive(Copy, Clone, Debug)] pub struct AVX512Backend; -// BaseField. -// TODO(spapini): Unite with the M31AVX512 type. +impl Backend for AVX512Backend {} unsafe impl Pod for PackedBaseField {} unsafe impl Zeroable for PackedBaseField { @@ -130,13 +131,30 @@ impl FromIterator for BaseFieldVec { } impl SecureColumn { - pub fn set(&mut self, vec_index: usize, value: PackedQM31) { - unsafe { - *self.columns[0].data.get_unchecked_mut(vec_index) = value.a().a(); - *self.columns[1].data.get_unchecked_mut(vec_index) = value.a().b(); - *self.columns[2].data.get_unchecked_mut(vec_index) = value.b().a(); - *self.columns[3].data.get_unchecked_mut(vec_index) = value.b().b(); - } + /// # Safety + /// + /// Calling this method with an out-of-bounds index is undefined behavior. + pub unsafe fn set_packed(&mut self, pack_index: usize, value: PackedQM31) { + *self.columns[0].data.get_unchecked_mut(pack_index) = value.a().a(); + *self.columns[1].data.get_unchecked_mut(pack_index) = value.a().b(); + *self.columns[2].data.get_unchecked_mut(pack_index) = value.b().a(); + *self.columns[3].data.get_unchecked_mut(pack_index) = value.b().b(); + } + + /// # Safety + /// + /// Calling this method with an out-of-bounds index is undefined behavior. + pub unsafe fn get_packed(&self, pack_index: usize) -> PackedQM31 { + PackedQM31([ + PackedCM31([ + *self.columns[0].data.get_unchecked(pack_index), + *self.columns[1].data.get_unchecked(pack_index), + ]), + PackedCM31([ + *self.columns[2].data.get_unchecked(pack_index), + *self.columns[3].data.get_unchecked(pack_index), + ]), + ]) } } diff --git a/src/core/backend/avx512/quotients.rs b/src/core/backend/avx512/quotients.rs index def77a437..e3c43d5fb 100644 --- a/src/core/backend/avx512/quotients.rs +++ b/src/core/backend/avx512/quotients.rs @@ -20,11 +20,11 @@ impl QuotientOps for AVX512Backend { assert!(domain.log_size() >= VECS_LOG_SIZE as u32); let mut values = SecureColumn::::zeros(domain.size()); // TODO(spapini): bit reverse iterator. - for vec_row in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) { + for pack_index in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) { // TODO(spapini): Optimized this, for the small number of columns case. let points = std::array::from_fn(|i| { domain.at(bit_reverse_index( - (vec_row << VECS_LOG_SIZE) + i, + (pack_index << VECS_LOG_SIZE) + i, domain.log_size(), )) }); @@ -33,11 +33,13 @@ impl QuotientOps for AVX512Backend { let row_accumlator = accumulate_row_quotients( openings, columns, - vec_row, + pack_index, random_coeff, (domain_points_x, domain_points_y), ); - values.set(vec_row, row_accumlator); + unsafe { + values.set_packed(pack_index, row_accumlator); + } } SecureEvaluation { domain, values } } diff --git a/src/core/backend/mod.rs b/src/core/backend/mod.rs index 5319ace9e..8ae2f6f55 100644 --- a/src/core/backend/mod.rs +++ b/src/core/backend/mod.rs @@ -5,24 +5,14 @@ pub use cpu::CPUBackend; use super::air::accumulation::AccumulationOps; use super::commitment_scheme::quotients::QuotientOps; use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; use super::fields::FieldOps; -use super::fri::FriOps; use super::poly::circle::PolyOps; pub mod avx512; pub mod cpu; pub trait Backend: - Copy - + Clone - + Debug - + FieldOps - + FieldOps - + PolyOps - + QuotientOps - + FriOps - + AccumulationOps + Copy + Clone + Debug + FieldOps + PolyOps + QuotientOps + AccumulationOps { }