Skip to content

Commit

Permalink
AVX backend (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Feb 26, 2024
1 parent 89ec9e9 commit fe76a48
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 49 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ itertools = "0.12.0"
num-traits = "0.2.17"
thiserror = "1.0.56"
merging-iterator = "1.3.0"
bytemuck = { version = "1.14.3", features = ["derive"] }

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
5 changes: 3 additions & 2 deletions benches/bit_rev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ pub fn cpu_bit_rev(c: &mut criterion::Criterion) {

c.bench_function("cpu bit_rev", |b| {
b.iter(|| {
data = stwo::core::utils::bit_reverse(std::mem::take(&mut data));
stwo::core::utils::bit_reverse(&mut data);
})
});
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_bit_rev(c: &mut criterion::Criterion) {
use bytemuck::cast_slice_mut;
use stwo::core::backend::avx512::bit_reverse::bit_reverse_m31;
use stwo::core::fields::m31::BaseField;
use stwo::platform;
Expand All @@ -35,7 +36,7 @@ pub fn avx512_bit_rev(c: &mut criterion::Criterion) {

c.bench_function("avx bit_rev", |b| {
b.iter(|| {
bit_reverse_m31(&mut data);
bit_reverse_m31(cast_slice_mut(&mut data[..]));
})
});
}
Expand Down
23 changes: 12 additions & 11 deletions src/core/backend/avx512/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32};

use crate::core::fields::m31::BaseField;
use super::PackedBaseField;
use crate::core::utils::bit_reverse_index;

const VEC_BITS: u32 = 4;
const W_BITS: u32 = 3;
const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;
pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;

// TODO(spapini): Use PackedBaseField type.
/// Bit reverses packed M31 values.
/// Given an array A[0..2^n), computes B[i] = A[bit_reverse(i)].
pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`.
pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
assert!(data.len().is_power_of_two());
assert!(data.len().ilog2() >= MIN_LOG_SIZE);

Expand Down Expand Up @@ -73,8 +72,8 @@ pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
}
}

/// Bit reverses 16 packed M31 values.
fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] {
/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each.
fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) };
// L is an input to _mm512_permutex2var_epi32, and it is used to
// interleave the first half of a with the first half of b.
Expand Down Expand Up @@ -140,6 +139,7 @@ fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] {
mod tests {
use super::bit_reverse16;
use crate::core::backend::avx512::bit_reverse::bit_reverse_m31;
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse;

Expand All @@ -159,11 +159,12 @@ mod tests {
let data: Vec<_> = (0..SIZE as u32)
.map(BaseField::from_u32_unchecked)
.collect();
let expected = bit_reverse(data.clone());
let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect();
let expected: Vec<_> = expected.into_iter().array_chunks::<16>().collect();
let mut expected = data.clone();
bit_reverse(&mut expected);
let mut data: BaseFieldVec = data.into_iter().collect();
let expected: BaseFieldVec = expected.into_iter().collect();

bit_reverse_m31(&mut data);
bit_reverse_m31(&mut data.data[..]);
assert_eq!(data, expected);
}
}
145 changes: 145 additions & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,146 @@
pub mod bit_reverse;

use std::ops::Index;

use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use num_traits::Zero;

use self::bit_reverse::bit_reverse_m31;
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Column, FieldOps};
use crate::core::utils;

/// Marker type naming the AVX512 backend.
///
/// It carries no data; in this module it is the type on which
/// `FieldOps<BaseField>` is implemented, with `BaseFieldVec` as the column type.
#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

// BaseField.
// TODO(spapini): Unite with the M31AVX512 type.
/// Number of `BaseField` elements stored in one `PackedBaseField` word.
pub const K_ELEMENTS: usize = 16;

/// `K_ELEMENTS` base field elements packed into a single 64-byte-aligned word.
///
/// `repr(C)` pins the layout to that of the inner array, which is what the
/// `bytemuck` casts in this module (and the transmute to `__m512i` in
/// `bit_reverse`) rely on. The default Rust representation makes no such
/// promise.
#[repr(C, align(64))]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct PackedBaseField([BaseField; K_ELEMENTS]);

// SAFETY: the struct is `repr(C)` with a single field, an array of
// `K_ELEMENTS` (16) `BaseField`s. `BaseField` is itself `Pod` and 4 bytes wide
// (a transparent wrapper around `u32`), so the array occupies exactly 64 bytes,
// which equals the requested alignment: there are no padding bytes and every
// bit pattern is valid.
unsafe impl Pod for PackedBaseField {}

// SAFETY: an all-zero byte pattern is a valid array of `BaseField`s, since
// `BaseField` is `Zeroable`. The trait's provided `zeroed` is used as is.
unsafe impl Zeroable for PackedBaseField {}

/// A column of `BaseField` values stored as packed words.
///
/// The constructors in this module (`Column::zeros`, `FromIterator`) allocate
/// `ceil(length / K_ELEMENTS)` words and fill the unused tail of the last word
/// with zeros. Note that the derived `PartialEq` compares `data` as a whole,
/// so those padding lanes take part in equality.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BaseFieldVec {
/// Packed storage, including any padding lanes in the last word.
pub data: Vec<PackedBaseField>,
/// Number of logical elements; the slice accessors truncate to this length.
length: usize,
}

impl BaseFieldVec {
pub fn as_slice(&self) -> &[BaseField] {
let data: &[BaseField] = cast_slice(&self.data[..]);
&data[..self.length]
}
pub fn as_mut_slice(&mut self) -> &mut [BaseField] {
let data: &mut [BaseField] = cast_slice_mut(&mut self.data[..]);
&mut data[..self.length]
}
}

impl FieldOps<BaseField> for AVX512Backend {
    type Column = BaseFieldVec;

    /// Bit-reverses `column` in place.
    ///
    /// Columns with fewer than `2^MIN_LOG_SIZE` packed words are handled by the
    /// scalar `utils::bit_reverse`; larger ones go through `bit_reverse_m31`.
    ///
    /// # Panics
    ///
    /// Panics if the logical length of `column` is not a power of two (this
    /// includes the empty column).
    fn bit_reverse_column(column: &mut Self::Column) {
        // The vectorized path permutes whole packed words, padding lanes
        // included. Without this check a column whose logical length is not a
        // power of two could have padding shuffled into its visible range, and
        // an empty column would panic inside `ilog2` with an unrelated message.
        assert!(
            column.length.is_power_of_two(),
            "bit_reverse_column requires a power-of-two length, got {}",
            column.length
        );

        // Fallback to cpu bit_reverse.
        if column.data.len().ilog2() < bit_reverse::MIN_LOG_SIZE {
            utils::bit_reverse(column.as_mut_slice());
            return;
        }
        bit_reverse_m31(&mut column.data);
    }
}

impl Column<BaseField> for BaseFieldVec {
    /// Creates a column of `len` zero elements, backed by
    /// `ceil(len / K_ELEMENTS)` packed words.
    fn zeros(len: usize) -> Self {
        Self {
            data: vec![PackedBaseField::default(); len.div_ceil(K_ELEMENTS)],
            length: len,
        }
    }

    /// Copies the logical elements (without padding lanes) into a `Vec`.
    ///
    /// This is a single contiguous copy of the flat view rather than an
    /// element-by-element walk over the packed words. Like `as_slice`, it
    /// panics if `length` exceeds the number of stored lanes, which the
    /// constructors in this module never produce.
    fn to_vec(&self) -> Vec<BaseField> {
        self.as_slice().to_vec()
    }

    /// Number of logical elements in the column.
    fn len(&self) -> usize {
        self.length
    }
}

impl Index<usize> for BaseFieldVec {
type Output = BaseField;
fn index(&self, index: usize) -> &Self::Output {
&self.data[index / K_ELEMENTS].0[index % K_ELEMENTS]
}
}

impl FromIterator<BaseField> for BaseFieldVec {
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut res: Vec<_> = (&mut chunks).map(PackedBaseField).collect();
let mut length = res.len() * K_ELEMENTS;

if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
length += remainder.len();
let pad_len = 16 - remainder.len();
let last = PackedBaseField(
remainder
.chain(std::iter::repeat(BaseField::zero()).take(pad_len))
.collect::<Vec<_>>()
.try_into()
.unwrap(),
);
res.push(last);
}
}

Self { data: res, length }
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::fields::{Col, Column};

    type B = AVX512Backend;

    /// Round-trips columns of every length below 100 through `from_iter`,
    /// `to_vec` and indexing.
    #[test]
    fn test_column() {
        for len in 0..100 {
            let col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
            let expected: Vec<_> = (0..len).map(BaseField::from).collect();
            assert_eq!(col.to_vec(), expected);
            for idx in 0..len {
                assert_eq!(col[idx], BaseField::from(idx));
            }
        }
    }

    /// Checks `bit_reverse_column` against `utils::bit_reverse_index` for
    /// power-of-two lengths from 2^1 up to 2^15.
    #[test]
    fn test_bit_reverse() {
        for log_len in 1..16 {
            let len = 1 << log_len;
            let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
            B::bit_reverse_column(&mut col);
            let expected: Vec<_> = (0..len)
                .map(|x| BaseField::from(utils::bit_reverse_index(x, log_len as u32)))
                .collect();
            assert_eq!(col.to_vec(), expected);
        }
    }
}
2 changes: 1 addition & 1 deletion src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Backend for CPUBackend {}
impl<F: Field> FieldOps<F> for CPUBackend {
type Column = Vec<F>;

fn bit_reverse_column(column: Self::Column) -> Self::Column {
fn bit_reverse_column(column: &mut Self::Column) {
bit_reverse(column)
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/core/fields/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
};

use bytemuck::{Pod, Zeroable};

use super::ComplexConjugate;
use crate::impl_field;

pub const MODULUS_BITS: u32 = 31;
pub const N_BYTES_FELT: usize = 4;
pub const P: u32 = 2147483647; // 2 ** 31 - 1

#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)]
pub struct M31(u32);
pub type BaseField = M31;

Expand Down
2 changes: 1 addition & 1 deletion src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod qm31;

pub trait FieldOps<F: Field> {
type Column: Column<F>;
fn bit_reverse_column(column: Self::Column) -> Self::Column;
fn bit_reverse_column(column: &mut Self::Column);
}

pub type Col<B, F> = <B as FieldOps<F>>::Column;
Expand Down
13 changes: 8 additions & 5 deletions src/core/poly/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ impl<F: ExtensionOf<BaseField>, B: PolyOps<F>> CircleEvaluation<B, F> {
self.values[self.domain.find(point_index).expect("Not in domain")]
}

pub fn bit_reverse(self) -> CircleEvaluation<B, F, BitReversedOrder> {
CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values))
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, BitReversedOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}
}

Expand All @@ -281,8 +282,9 @@ impl<F: ExtensionOf<BaseField>> CPUCircleEvaluation<F> {
}

impl<B: PolyOps<F>, F: ExtensionOf<BaseField>> CircleEvaluation<B, F, BitReversedOrder> {
pub fn bit_reverse(self) -> CircleEvaluation<B, F, NaturalOrder> {
CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values))
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, NaturalOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}

pub fn get_at(&self, point_index: CirclePointIndex) -> F {
Expand Down Expand Up @@ -389,7 +391,8 @@ impl<F: ExtensionOf<BaseField>, B: PolyOps<F>> CirclePoly<B, F> {
#[cfg(test)]
impl<F: ExtensionOf<BaseField>> crate::core::backend::cpu::CPUCirclePoly<F> {
pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool {
let mut coeffs = crate::core::utils::bit_reverse(self.coeffs.clone());
let mut coeffs = self.coeffs.clone();
crate::core::utils::bit_reverse(&mut coeffs);
while coeffs.last() == Some(&F::zero()) {
coeffs.pop();
}
Expand Down
Loading

0 comments on commit fe76a48

Please sign in to comment.