Skip to content

Commit

Permalink
AVX backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Feb 22, 2024
1 parent 4147e06 commit 47a0728
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 42 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ itertools = "0.12.0"
num-traits = "0.2.17"
thiserror = "1.0.56"
merging-iterator = "1.3.0"
bytemuck = { version = "1.14.3", features = ["derive"] }

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
14 changes: 7 additions & 7 deletions src/core/backend/avx512/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32};

use crate::core::fields::m31::BaseField;
use super::PackedBaseField;
use crate::core::utils::{bit_reverse_index, IteratorMutExt};

const VEC_BITS: u32 = 4;
const W_BITS: u32 = 3;
const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;
pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;

// TODO(spapini): Use PackedBaseField type.
/// Bit reverses packed M31 values.
/// Given an array A[0..2^n), computes B[i] = A[bit_reverse(i)].
pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`.
pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
assert!(data.len().is_power_of_two());
assert!(data.len().ilog2() >= MIN_LOG_SIZE);

Expand Down Expand Up @@ -65,7 +64,7 @@ pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
}
}

fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] {
fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) };
// abcd0123 => 0abc123d
const L: __m512i = unsafe {
Expand Down Expand Up @@ -128,7 +127,8 @@ mod tests {
let data: Vec<_> = (0..SIZE as u32)
.map(BaseField::from_u32_unchecked)
.collect();
let expected = bit_reverse(data.clone());
let mut expected = data.clone();
bit_reverse(&mut expected);
let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect();
let expected: Vec<_> = expected.into_iter().array_chunks::<16>().collect();

Expand Down
117 changes: 117 additions & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,118 @@
pub mod bit_reverse;

use std::ops::Index;

use bytemuck::checked::cast_slice_mut;
use num_traits::Zero;

use self::bit_reverse::bit_reverse_m31;
use super::{Column, FieldOps};
use crate::core::fields::m31::BaseField;
use crate::core::utils;

/// Marker type for the AVX512 backend.
///
/// It carries no data. In this module it implements `FieldOps<BaseField>` with
/// [`BaseFieldVec`] as the column type, and hands large bit reversals to
/// `bit_reverse::bit_reverse_m31`.
#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

// BaseField.
/// A group of 16 `BaseField` values stored contiguously.
type PackedBaseField = [BaseField; 16];
/// A column of `BaseField` values stored as groups of 16.
///
/// The constructors in this module (`zeros` and `from_iter`) allocate
/// `ceil(length / 16)` groups and fill the unused lanes of the last group with
/// zero values.
#[derive(Clone, Debug)]
pub struct BaseFieldVec {
/// Packed storage, 16 values per entry; may include trailing padding lanes.
data: Vec<PackedBaseField>,
/// Number of logical elements in the column, excluding padding lanes.
length: usize,
}
impl FieldOps<BaseField> for AVX512Backend {
    type Column = BaseFieldVec;

    /// Permutes the column in place into bit-reversed order.
    ///
    /// Columns with fewer than `2^bit_reverse::MIN_LOG_SIZE` packed groups are
    /// handled by the scalar `utils::bit_reverse` on the flattened values; larger
    /// columns go through `bit_reverse_m31`, which works on whole packed groups.
    ///
    /// # Panics
    ///
    /// Panics if the column length is not a power of two (this includes an empty
    /// column).
    fn bit_reverse_column(column: &mut Self::Column) {
        // Check the logical length up front. The packed path only sees whole
        // groups of 16, so without this check a length such as `16 * 1024 - 5`
        // would silently shuffle padding lanes into the visible data instead of
        // panicking like the scalar path does.
        assert!(
            column.length.is_power_of_two(),
            "bit_reverse_column: column length must be a power of two"
        );

        // Fallback to the scalar bit reverse for columns too small for the
        // packed implementation.
        if column.data.len().ilog2() < bit_reverse::MIN_LOG_SIZE {
            let data: &mut [BaseField] = cast_slice_mut(&mut column.data[..]);
            // Only the logical prefix is permuted; padding lanes stay in place.
            utils::bit_reverse(&mut data[..column.length]);
            return;
        }
        bit_reverse_m31(&mut column.data);
    }
}

impl Column<BaseField> for BaseFieldVec {
    /// Creates a column of `len` default (zero) values.
    ///
    /// Storage is allocated in whole groups of 16, so the last group may contain
    /// padding lanes beyond `len`.
    fn zeros(len: usize) -> Self {
        // Number of packed groups needed to hold `len` values, rounded up.
        let n_packed = (len + 15) / 16;
        let data = vec![PackedBaseField::default(); n_packed];
        Self { data, length: len }
    }

    /// Copies the logical elements (without padding lanes) into a flat `Vec`.
    fn to_vec(&self) -> Vec<BaseField> {
        let lanes = self.data.iter().flat_map(|packed| packed.iter().copied());
        // Stop at `length` so padding in the final group is not emitted.
        lanes.take(self.length).collect()
    }

    /// Returns the number of logical elements in the column.
    fn len(&self) -> usize {
        self.length
    }
}

impl Index<usize> for BaseFieldVec {
    type Output = BaseField;

    /// Returns a reference to the element at logical position `index`.
    ///
    /// The index is not checked against the column length, so positions that
    /// fall in the padding lanes of the final group are not rejected.
    ///
    /// # Panics
    ///
    /// Panics if `index / 16` is past the last packed group.
    fn index(&self, index: usize) -> &Self::Output {
        // Each packed group holds 16 values: the quotient selects the group and
        // the remainder selects the lane inside it.
        &self.data[index / 16][index % 16]
    }
}

impl FromIterator<BaseField> for BaseFieldVec {
    /// Builds a column from an iterator of field elements.
    ///
    /// Elements are grouped 16 at a time. If the element count is not a multiple
    /// of 16, the leftover elements go into one extra group whose remaining lanes
    /// are filled with zero. The recorded length counts only the real elements.
    fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
        let mut packed_iter = iter.into_iter().array_chunks();
        // Drain all complete groups first; the iterator keeps any leftover tail.
        let mut data: Vec<PackedBaseField> = packed_iter.by_ref().collect();
        let mut length = data.len() * 16;

        match packed_iter.into_remainder() {
            Some(tail) if !tail.is_empty() => {
                length += tail.len();
                // Start from an all-zero group and overwrite the leading lanes
                // with the leftover elements; the rest stays as padding.
                let mut last = [BaseField::zero(); 16];
                for (slot, value) in last.iter_mut().zip(tail) {
                    *slot = value;
                }
                data.push(last);
            }
            _ => {}
        }

        Self { data, length }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::backend::{Col, Column};

    type B = AVX512Backend;

    /// Collecting into a column and reading it back with `to_vec` must return
    /// the same values for every length, including ones that need padding.
    #[test]
    fn test_column() {
        for n in 0..100 {
            let expected = (0..n).map(BaseField::from).collect::<Vec<_>>();
            let col = Col::<B, BaseField>::from_iter((0..n).map(BaseField::from));
            assert_eq!(col.to_vec(), expected);
        }
    }

    /// Bit-reversing a column holding `0..len` must put `bit_reverse_index(x)`
    /// at position `x`, for column sizes on both sides of the packed-path
    /// threshold.
    #[test]
    fn test_bit_reverse() {
        for log_len in 1..16 {
            let len = 1 << log_len;
            let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
            B::bit_reverse_column(&mut col);

            let expected = (0..len)
                .map(|x| BaseField::from(utils::bit_reverse_index(x, log_len as u32)))
                .collect::<Vec<_>>();
            assert_eq!(col.to_vec(), expected);
        }
    }
}
4 changes: 2 additions & 2 deletions src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ impl Backend for CPUBackend {}
impl<F: Field> FieldOps<F> for CPUBackend {
type Column = Vec<F>;

fn bit_reverse_column(column: Self::Column) -> Self::Column {
bit_reverse(column)
fn bit_reverse_column(column: &mut Self::Column) {
bit_reverse(&mut column[..])
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub trait Backend:

pub trait FieldOps<F: Field> {
type Column: Column<F>;
fn bit_reverse_column(column: Self::Column) -> Self::Column;
fn bit_reverse_column(column: &mut Self::Column);
}

pub type Col<B, F> = <B as FieldOps<F>>::Column;
Expand Down
5 changes: 4 additions & 1 deletion src/core/fields/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
};

use bytemuck::{Pod, Zeroable};

use super::ComplexConjugate;
use crate::impl_field;

pub const MODULUS_BITS: u32 = 31;
pub const N_BYTES_FELT: usize = 4;
pub const P: u32 = 2147483647; // 2 ** 31 - 1

#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)]
pub struct M31(u32);
pub type BaseField = M31;

Expand Down
13 changes: 8 additions & 5 deletions src/core/poly/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,9 @@ impl<F: ExtensionOf<BaseField>, B: PolyOps<F>> CircleEvaluation<B, F> {
self.values[self.domain.find(point_index).expect("Not in domain")]
}

pub fn bit_reverse(self) -> CircleEvaluation<B, F, BitReversedOrder> {
CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values))
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, BitReversedOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}
}

Expand All @@ -282,8 +283,9 @@ impl<F: ExtensionOf<BaseField>> CPUCircleEvaluation<F> {
}

impl<B: PolyOps<F>, F: ExtensionOf<BaseField>> CircleEvaluation<B, F, BitReversedOrder> {
pub fn bit_reverse(self) -> CircleEvaluation<B, F, NaturalOrder> {
CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values))
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, NaturalOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}

pub fn get_at(&self, point_index: CirclePointIndex) -> F {
Expand Down Expand Up @@ -390,7 +392,8 @@ impl<F: ExtensionOf<BaseField>, B: PolyOps<F>> CirclePoly<B, F> {
#[cfg(test)]
impl<F: ExtensionOf<BaseField>> crate::core::backend::cpu::CPUCirclePoly<F> {
pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool {
let mut coeffs = crate::core::utils::bit_reverse(self.coeffs.clone());
let mut coeffs = self.coeffs.clone();
crate::core::utils::bit_reverse(&mut coeffs);
while coeffs.last() == Some(&F::zero()) {
coeffs.pop();
}
Expand Down
20 changes: 12 additions & 8 deletions src/core/poly/line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,19 @@ impl<B: LinePolyOps<F>, F: Field> LinePoly<B, F> {
}

/// Returns the polynomial's coefficients in their natural order.
pub fn into_ordered_coefficients(self) -> Col<B, F> {
B::bit_reverse_column(self.coeffs)
pub fn into_ordered_coefficients(mut self) -> Col<B, F> {
B::bit_reverse_column(&mut self.coeffs);
self.coeffs
}

/// Creates a new line polynomial from coefficients in their natural order.
///
/// # Panics
///
/// Panics if the number of coefficients is not a power of two.
pub fn from_ordered_coefficients(coeffs: Col<B, F>) -> Self {
Self::new(B::bit_reverse_column(coeffs))
pub fn from_ordered_coefficients(mut coeffs: Col<B, F>) -> Self {
B::bit_reverse_column(&mut coeffs);
Self::new(coeffs)
}
}

Expand Down Expand Up @@ -204,19 +206,21 @@ impl<B: LinePolyOps<F>, F: Field> LineEvaluation<B, F> {
B::interpolate(self)
}

pub fn bit_reverse(self) -> LineEvaluation<B, F, BitReversedOrder> {
pub fn bit_reverse(mut self) -> LineEvaluation<B, F, BitReversedOrder> {
B::bit_reverse_column(&mut self.values);
LineEvaluation {
values: B::bit_reverse_column(self.values),
values: self.values,
domain: self.domain,
_eval_order: PhantomData,
}
}
}

impl<B: LinePolyOps<F>, F: Field> LineEvaluation<B, F, BitReversedOrder> {
pub fn bit_reverse(self) -> LineEvaluation<B, F, NaturalOrder> {
pub fn bit_reverse(mut self) -> LineEvaluation<B, F, NaturalOrder> {
B::bit_reverse_column(&mut self.values);
LineEvaluation {
values: B::bit_reverse_column(self.values),
values: self.values,
domain: self.domain,
_eval_order: PhantomData,
}
Expand Down
12 changes: 6 additions & 6 deletions src/core/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ mod tests {
pub fn test_folded_queries() {
let log_domain_size = 7;
let domain = CanonicCoset::new(log_domain_size).circle_domain();
let values = domain.iter().collect::<Vec<_>>();
let values = bit_reverse(values.clone());
let mut values = domain.iter().collect::<Vec<_>>();
bit_reverse(&mut values);

let log_folded_domain_size = 5;
let folded_domain = CanonicCoset::new(log_folded_domain_size).circle_domain();
let folded_values = folded_domain.iter().collect::<Vec<_>>();
let folded_values = bit_reverse(folded_values.clone());
let mut folded_values = folded_domain.iter().collect::<Vec<_>>();
bit_reverse(&mut folded_values);

// Generate all possible queries.
let queries = Queries {
Expand Down Expand Up @@ -192,8 +192,8 @@ mod tests {
let channel = &mut Blake2sChannel::new(Blake2sHash::default());
let log_domain_size = 7;
let domain = CanonicCoset::new(log_domain_size).circle_domain();
let values = domain.iter().collect::<Vec<_>>();
let values = bit_reverse(values.clone());
let mut values = domain.iter().collect::<Vec<_>>();
bit_reverse(&mut values);

// Test random queries one by one because the conjugate queries are sorted.
for _ in 0..100 {
Expand Down
17 changes: 8 additions & 9 deletions src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@ pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
/// Panics if the length of the slice is not a power of two.
// TODO(AlonH): Consider benchmarking this function.
// TODO: Implement cache friendly implementation.
pub fn bit_reverse<T, U: AsMut<[T]>>(mut v: U) -> U {
let n = v.as_mut().len();
pub fn bit_reverse<T>(v: &mut [T]) {
let n = v.len();
assert!(n.is_power_of_two());
let log_n = n.ilog2();
for i in 0..n {
let j = bit_reverse_index(i, log_n);
if j > i {
v.as_mut().swap(i, j);
v.swap(i, j);
}
}
v
}

#[cfg(test)]
Expand All @@ -39,15 +38,15 @@ mod tests {

#[test]
fn bit_reverse_works() {
assert_eq!(
bit_reverse([0, 1, 2, 3, 4, 5, 6, 7]),
[0, 4, 2, 6, 1, 5, 3, 7]
);
let mut data = [0, 1, 2, 3, 4, 5, 6, 7];
bit_reverse(&mut data);
assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]);
}

#[test]
#[should_panic]
fn bit_reverse_non_power_of_two_size_fails() {
bit_reverse([0, 1, 2, 3, 4, 5]);
let mut data = [0, 1, 2, 3, 4, 5];
bit_reverse(&mut data);
}
}
Loading

0 comments on commit 47a0728

Please sign in to comment.