Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AVX backend #364

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ itertools = "0.12.0"
num-traits = "0.2.17"
thiserror = "1.0.56"
merging-iterator = "1.3.0"
bytemuck = { version = "1.14.3", features = ["derive"] }

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
5 changes: 3 additions & 2 deletions benches/bit_rev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ pub fn cpu_bit_rev(c: &mut criterion::Criterion) {

c.bench_function("cpu bit_rev", |b| {
b.iter(|| {
data = stwo::core::utils::bit_reverse(std::mem::take(&mut data));
stwo::core::utils::bit_reverse(&mut data);
})
});
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_bit_rev(c: &mut criterion::Criterion) {
use bytemuck::cast_slice_mut;
use stwo::core::backend::avx512::bit_reverse::bit_reverse_m31;
use stwo::core::fields::m31::BaseField;
use stwo::platform;
Expand All @@ -35,7 +36,7 @@ pub fn avx512_bit_rev(c: &mut criterion::Criterion) {

c.bench_function("avx bit_rev", |b| {
b.iter(|| {
bit_reverse_m31(&mut data);
bit_reverse_m31(cast_slice_mut(&mut data[..]));
})
});
}
Expand Down
23 changes: 12 additions & 11 deletions src/core/backend/avx512/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32};

use crate::core::fields::m31::BaseField;
use super::PackedBaseField;
use crate::core::utils::bit_reverse_index;

const VEC_BITS: u32 = 4;
const W_BITS: u32 = 3;
const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;
pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;

// TODO(spapini): Use PackedBaseField type.
/// Bit reverses packed M31 values.
/// Given an array A[0..2^n), computes B[i] = A[bit_reverse(i)].
pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`.
pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
assert!(data.len().is_power_of_two());
assert!(data.len().ilog2() >= MIN_LOG_SIZE);

Expand Down Expand Up @@ -73,8 +72,8 @@ pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) {
}
}

/// Bit reverses 16 packed M31 values.
fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] {
/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each.
fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) };
// L is an input to _mm512_permutex2var_epi32, and it is used to
// interleave the first half of a with the first half of b.
Expand Down Expand Up @@ -140,6 +139,7 @@ fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] {
mod tests {
use super::bit_reverse16;
use crate::core::backend::avx512::bit_reverse::bit_reverse_m31;
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse;

Expand All @@ -159,11 +159,12 @@ mod tests {
let data: Vec<_> = (0..SIZE as u32)
.map(BaseField::from_u32_unchecked)
.collect();
let expected = bit_reverse(data.clone());
let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect();
let expected: Vec<_> = expected.into_iter().array_chunks::<16>().collect();
let mut expected = data.clone();
bit_reverse(&mut expected);
let mut data: BaseFieldVec = data.into_iter().collect();
let expected: BaseFieldVec = expected.into_iter().collect();

bit_reverse_m31(&mut data);
bit_reverse_m31(&mut data.data[..]);
assert_eq!(data, expected);
}
}
145 changes: 145 additions & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,146 @@
pub mod bit_reverse;

use std::ops::Index;

use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use num_traits::Zero;

use self::bit_reverse::bit_reverse_m31;
use crate::core::fields::m31::BaseField;
use crate::core::fields::{Column, FieldOps};
use crate::core::utils;

/// Marker type selecting the AVX512 backend.
///
/// It carries no data; backend behavior is attached through trait impls such as
/// `FieldOps<BaseField>` in this module.
#[derive(Copy, Clone, Debug)]
pub struct AVX512Backend;

// BaseField.
// TODO(spapini): Unite with the M31AVX512 type.
/// Number of `BaseField` elements stored in a single `PackedBaseField`.
pub const K_ELEMENTS: usize = 16;

/// `K_ELEMENTS` base field values stored contiguously in one 64-byte aligned word.
///
/// The layout is pinned with `repr(C)` because the type is reinterpreted as raw
/// memory: it is cast to and from `[BaseField]` via `bytemuck`, and the bit-reverse
/// kernel transmutes arrays of it to `__m512i`.
#[repr(C, align(64))]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct PackedBaseField([BaseField; K_ELEMENTS]);

// SAFETY: the struct is `repr(C)` with a single field, an array of `BaseField`,
// which is itself `Pod`. The array occupies 16 * 4 = 64 bytes, equal to the
// requested alignment, so there are no padding bytes and every bit pattern that is
// valid for the array is valid for the struct.
unsafe impl Pod for PackedBaseField {}

// SAFETY: the only field is an array of `Zeroable` values, so the all-zero bit
// pattern is a valid value. The trait's default `zeroed()` is used as is.
unsafe impl Zeroable for PackedBaseField {}

/// A column of `BaseField` values stored as packed words.
///
/// `data` may hold more element slots than `length`; slots past `length` in the
/// last word are padding (filled with zeros by the constructors in this module).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BaseFieldVec {
// Packed storage; each entry holds `K_ELEMENTS` field elements.
pub data: Vec<PackedBaseField>,
// Logical number of field elements in the column.
length: usize,
}

impl BaseFieldVec {
pub fn as_slice(&self) -> &[BaseField] {
let data: &[BaseField] = cast_slice(&self.data[..]);
&data[..self.length]
}
pub fn as_mut_slice(&mut self) -> &mut [BaseField] {
let data: &mut [BaseField] = cast_slice_mut(&mut self.data[..]);
&mut data[..self.length]
}
}

impl FieldOps<BaseField> for AVX512Backend {
    type Column = BaseFieldVec;

    /// Bit reverses the column in place.
    ///
    /// Columns with at least `2^MIN_LOG_SIZE` packed words go through the
    /// vectorized `bit_reverse_m31`; smaller ones fall back to the generic
    /// `utils::bit_reverse` over the flat element slice.
    ///
    /// # Panics
    ///
    /// Panics if the column holds no packed words, since `ilog2` is undefined
    /// for zero.
    fn bit_reverse_column(column: &mut Self::Column) {
        let log_packed_len = column.data.len().ilog2();
        if log_packed_len >= bit_reverse::MIN_LOG_SIZE {
            bit_reverse_m31(&mut column.data);
        } else {
            // Too small for the vectorized kernel: use the cpu bit_reverse.
            utils::bit_reverse(column.as_mut_slice());
        }
    }
}

impl Column<BaseField> for BaseFieldVec {
    /// Creates a column of `len` default-valued elements, rounding the storage
    /// up to a whole number of packed words.
    fn zeros(len: usize) -> Self {
        let n_packed = len.div_ceil(K_ELEMENTS);
        let data = vec![PackedBaseField::default(); n_packed];
        Self { data, length: len }
    }

    /// Copies the first `length` elements into a plain vector, skipping padding.
    fn to_vec(&self) -> Vec<BaseField> {
        let elements = self.data.iter().flat_map(|packed| packed.0);
        elements.take(self.length).collect()
    }

    /// Returns the logical number of elements (not counting padding).
    fn len(&self) -> usize {
        self.length
    }
}

impl Index<usize> for BaseFieldVec {
    type Output = BaseField;

    /// Returns the element at flat position `index`.
    ///
    /// The index is only bounds-checked against the packed storage, not against
    /// the logical length, so a position inside the padding of the last word is
    /// readable.
    fn index(&self, index: usize) -> &Self::Output {
        let (word, lane) = (index / K_ELEMENTS, index % K_ELEMENTS);
        let packed = &self.data[word];
        &packed.0[lane]
    }
}

impl FromIterator<BaseField> for BaseFieldVec {
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut res: Vec<_> = (&mut chunks).map(PackedBaseField).collect();
let mut length = res.len() * K_ELEMENTS;

if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
length += remainder.len();
let pad_len = 16 - remainder.len();
let last = PackedBaseField(
remainder
.chain(std::iter::repeat(BaseField::zero()).take(pad_len))
.collect::<Vec<_>>()
.try_into()
.unwrap(),
);
res.push(last);
}
}

Self { data: res, length }
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::fields::{Col, Column};

    type B = AVX512Backend;

    /// Round-trips columns of every length below 100 through `from_iter`,
    /// `to_vec` and indexing, covering both full and partially filled words.
    #[test]
    fn test_column() {
        for n in 0..100 {
            let expected = (0..n).map(BaseField::from).collect::<Vec<_>>();
            let col = Col::<B, BaseField>::from_iter((0..n).map(BaseField::from));

            assert_eq!(col.to_vec(), expected);
            for pos in 0..n {
                assert_eq!(col[pos], BaseField::from(pos));
            }
        }
    }

    /// Checks the backend bit reversal against `bit_reverse_index` for
    /// power-of-two sizes on both sides of the vectorized-kernel threshold.
    #[test]
    fn test_bit_reverse() {
        for i in 1..16 {
            let len = 1 << i;
            let expected = (0..len)
                .map(|x| BaseField::from(utils::bit_reverse_index(x, i as u32)))
                .collect::<Vec<_>>();

            let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
            B::bit_reverse_column(&mut col);

            assert_eq!(col.to_vec(), expected);
        }
    }
}
2 changes: 1 addition & 1 deletion src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Backend for CPUBackend {}
impl<F: Field> FieldOps<F> for CPUBackend {
type Column = Vec<F>;

fn bit_reverse_column(column: Self::Column) -> Self::Column {
fn bit_reverse_column(column: &mut Self::Column) {
bit_reverse(column)
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/core/fields/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
};

use bytemuck::{Pod, Zeroable};

use super::ComplexConjugate;
use crate::impl_field;

pub const MODULUS_BITS: u32 = 31;
pub const N_BYTES_FELT: usize = 4;
pub const P: u32 = 2147483647; // 2 ** 31 - 1

#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)]
pub struct M31(u32);
pub type BaseField = M31;

Expand Down
2 changes: 1 addition & 1 deletion src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod qm31;

pub trait FieldOps<F: Field> {
type Column: Column<F>;
fn bit_reverse_column(column: Self::Column) -> Self::Column;
fn bit_reverse_column(column: &mut Self::Column);
}

pub type Col<B, F> = <B as FieldOps<F>>::Column;
Expand Down
13 changes: 8 additions & 5 deletions src/core/poly/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ impl<F: ExtensionOf<BaseField>, B: PolyOps<F>> CircleEvaluation<B, F> {
self.values[self.domain.find(point_index).expect("Not in domain")]
}

pub fn bit_reverse(self) -> CircleEvaluation<B, F, BitReversedOrder> {
CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values))
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, BitReversedOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}
}

Expand All @@ -281,8 +282,9 @@ impl<F: ExtensionOf<BaseField>> CPUCircleEvaluation<F> {
}

impl<B: PolyOps<F>, F: ExtensionOf<BaseField>> CircleEvaluation<B, F, BitReversedOrder> {
pub fn bit_reverse(self) -> CircleEvaluation<B, F, NaturalOrder> {
CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values))
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, NaturalOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}

pub fn get_at(&self, point_index: CirclePointIndex) -> F {
Expand Down Expand Up @@ -389,7 +391,8 @@ impl<F: ExtensionOf<BaseField>, B: PolyOps<F>> CirclePoly<B, F> {
#[cfg(test)]
impl<F: ExtensionOf<BaseField>> crate::core::backend::cpu::CPUCirclePoly<F> {
pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool {
let mut coeffs = crate::core::utils::bit_reverse(self.coeffs.clone());
let mut coeffs = self.coeffs.clone();
crate::core::utils::bit_reverse(&mut coeffs);
while coeffs.last() == Some(&F::zero()) {
coeffs.pop();
}
Expand Down
Loading
Loading