Skip to content

Commit

Permalink
Faster AVX256
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Feb 26, 2024
1 parent 7fe6133 commit e7624e7
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 132 deletions.
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ merging-iterator = "1.3.0"
criterion = { version = "0.5.1", features = ["html_reports"] }
rand = "0.8.3"

[lints.rust]
warnings = "deny"
future-incompatible = "deny"
nonstandard-style = "deny"
rust-2018-idioms = "deny"
unused = "deny"

[features]
avx512 = []

Expand Down
10 changes: 5 additions & 5 deletions benches/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ pub fn qm31_operations_bench(c: &mut criterion::Criterion) {

#[cfg(target_arch = "x86_64")]
pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) {
use stwo::core::fields::avx512_m31::{K_BLOCK_SIZE, M31AVX512, M512ONE};
use stwo::core::fields::avx512_m31::{K_BLOCK_SIZE, M31AVX512};
use stwo::platform;

if !platform::avx512_detected() {
Expand All @@ -140,11 +140,11 @@ pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) {
let mut rng = rand::thread_rng();
let mut elements: Vec<M31AVX512> = Vec::new();
let mut states: Vec<M31AVX512> =
vec![M31AVX512::from_m512_unchecked(M512ONE); N_STATE_ELEMENTS];
vec![M31AVX512::from_array([1.into(); K_BLOCK_SIZE]); N_STATE_ELEMENTS];

for _ in 0..(N_ELEMENTS / K_BLOCK_SIZE) {
elements.push(M31AVX512::from_slice(
&[get_random_m31_element(&mut rng); K_BLOCK_SIZE],
elements.push(M31AVX512::from_array(
[get_random_m31_element(&mut rng); K_BLOCK_SIZE],
));
}

Expand Down Expand Up @@ -200,4 +200,4 @@ criterion::criterion_group!(
cm31_operations_bench,
qm31_operations_bench
);
criterion::criterion_main!(field_comparison);
criterion::criterion_main!(field_comparison, m31_benches);
264 changes: 137 additions & 127 deletions src/core/fields/avx512_m31.rs
Original file line number Diff line number Diff line change
@@ -1,83 +1,43 @@
use core::arch::x86_64::{
__m256i, __m512i, _mm256_loadu_si256, _mm256_storeu_si256, _mm512_add_epi32, _mm512_add_epi64,
_mm512_and_epi64, _mm512_cvtepi64_epi32, _mm512_cvtepu32_epi64, _mm512_min_epu32,
_mm512_mul_epu32, _mm512_srli_epi64, _mm512_sub_epi32, _mm512_sub_epi64,
__m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epu32, _mm512_srli_epi64,
_mm512_sub_epi32,
};
use std::arch::x86_64::_mm512_permutex2var_epi32;
use std::fmt::Display;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use super::m31::{M31, MODULUS_BITS, P};
pub const K_BLOCK_SIZE: usize = 8;
pub const M512P: __m512i = unsafe { core::mem::transmute([P as u64; K_BLOCK_SIZE]) };
pub const M512ONE: __m512i = unsafe { core::mem::transmute([1u64; K_BLOCK_SIZE]) };
use super::m31::{M31, P};
/// Number of 32-bit words (one M31 element each) packed into a single 512-bit register.
pub const K_BLOCK_SIZE: usize = 16;
/// The modulus `P` repeated `K_BLOCK_SIZE` times and reinterpreted as a 512-bit register,
/// i.e. `P` broadcast to every word. The transmute only compiles when the array is exactly
/// as large as `__m512i`.
pub const M512P: __m512i = unsafe { core::mem::transmute([P; K_BLOCK_SIZE]) };

/// AVX512 implementation of M31.
///
/// Packs 16 M31 elements into a single 512-bit register, one element per 32-bit word.
/// Elements are kept in unreduced form: each word lies in the range `[0, P]`, so both `0`
/// and `P` may stand for the zero element. Use `reduce` to obtain values in `[0, P)`.
#[derive(Copy, Clone, Debug)]
pub struct M31AVX512(__m512i);

impl M31AVX512 {
/// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, 2*\[P\]), packed in
/// x, returns packed xi % \[P\].
/// If xi == 2*\[P\], then it reduces to \[P\].
/// Note that this function can be used for both reduced and unreduced
/// representations. [0, 2*\[P\]) -> [0, \[P\]), [0, 2*\[P\]] -> [0,
/// \[P\]].
#[inline(always)]
fn partial_reduce(x: __m512i) -> Self {
unsafe {
let x_minus_p = _mm512_sub_epi32(x, M512P);
Self(_mm512_min_epu32(x, x_minus_p))
}
}

/// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, \[P\]^2), packed in
/// x, returns packed xi % \[P\].
/// If xi == \[P\]^2, then it reduces to \[P\].
/// Note that this function can be used for both reduced and unreduced
/// representations. [0, \[P\]^2) -> [0, \[P\]), [0, \[P\]^2] -> [0,
/// \[P\]].
#[inline(always)]
fn reduce(x: __m512i) -> Self {
unsafe {
let x_plus_one: __m512i = _mm512_add_epi64(x, M512ONE);

// z_i = x_i // P (integer division).
let z: __m512i = _mm512_srli_epi64(
_mm512_add_epi64(_mm512_srli_epi64(x, MODULUS_BITS), x_plus_one),
MODULUS_BITS,
);
let result: __m512i = _mm512_add_epi64(x, z);
Self(_mm512_and_epi64(result, M512P))
}
}

pub fn from_slice(v: &[M31]) -> M31AVX512 {
unsafe {
Self(_mm512_cvtepu32_epi64(_mm256_loadu_si256(
v.as_ptr() as *const __m256i
)))
}
/// Packs `K_BLOCK_SIZE` M31 values into one register by reinterpreting the array's bytes.
/// The values are neither reduced nor validated. The transmute only compiles when
/// `[M31; K_BLOCK_SIZE]` and `__m512i` have the same size.
pub fn from_array(v: [M31; K_BLOCK_SIZE]) -> M31AVX512 {
// NOTE(review): presumably M31 is a transparent wrapper around a 32-bit word, so element
// i lands in word i -- confirm against super::m31.
unsafe { Self(std::mem::transmute(v)) }
}

/// Wraps a raw 512-bit register as-is. Nothing checks that every 32-bit word lies in
/// `[0, P]`; upholding that range is left to the caller.
pub fn from_m512_unchecked(x: __m512i) -> Self {
Self(x)
}

pub fn to_vec(self) -> Vec<M31> {
unsafe {
let mut v = Vec::with_capacity(K_BLOCK_SIZE);
_mm256_storeu_si256(
v.as_mut_ptr() as *mut __m256i,
_mm512_cvtepi64_epi32(self.0),
);
v.set_len(K_BLOCK_SIZE);
v
}
/// Returns the packed elements as an array. Each word is first reduced to `[0, P)` via
/// `reduce`, then the register's bytes are reinterpreted as `[M31; K_BLOCK_SIZE]`.
pub fn to_array(self) -> [M31; K_BLOCK_SIZE] {
// NOTE(review): assumes every reduced 32-bit word is a valid M31 bit pattern -- confirm
// against the M31 definition.
unsafe { std::mem::transmute(self.reduce()) }
}

/// Reduces each word in the 512-bit register to the range `[0, P)`, excluding P.
///
/// Computes `min(x, x - P)` per word, using a wrapping 32-bit subtraction and an unsigned
/// minimum. For `x` in `[0, P)` the difference wraps around to a value larger than `x`, so
/// `x` is kept; for `x == P` the difference is `0`, which is selected.
pub fn reduce(self) -> M31AVX512 {
Self(unsafe { _mm512_min_epu32(self.0, _mm512_sub_epi32(self.0, M512P)) })
}
}

impl Display for M31AVX512 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let v = self.to_vec();
let v = self.to_array();
for elem in v.iter() {
write!(f, "{} ", elem)?;
}
Expand All @@ -88,9 +48,18 @@ impl Display for M31AVX512 {
impl Add for M31AVX512 {
type Output = Self;

/// Adds two packed M31 elements word by word and reduces each sum to the range `[0, P]`.
/// Each input word is assumed to be in unreduced form, `[0, P]` including `P`; the output
/// words are in the same unreduced form.
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output {
unsafe { Self::partial_reduce(_mm512_add_epi64(self.0, rhs.0)) }
Self(unsafe {
// Add word by word. With both inputs in [0, P], each sum c is in [0, 2P].
let c = _mm512_add_epi32(self.0, rhs.0);
// Apply an unsigned min(c, c-P) to each word, where c-P is a wrapping subtraction.
// When c in [P,2P], then c-P in [0,P], which is smaller than c, so c-P is selected.
// When c in [0,P-1], then c-P wraps to [2^32-P,2^32-1], which is larger than c, so c
// is selected.
_mm512_min_epu32(c, _mm512_sub_epi32(c, M512P))
})
}
}

Expand All @@ -104,9 +73,70 @@ impl AddAssign for M31AVX512 {
impl Mul for M31AVX512 {
type Output = Self;

/// Computes the word-by-word product of two packed M31 elements.
/// Each input word is assumed to be in unreduced form, `[0, P]` including `P`.
/// Returned words are in unreduced form, `[0, P]` including `P`.
///
/// NOTE(review): the reasoning in the comments below assumes `P = 2^31 - 1`; `P` is
/// defined in `super::m31`, so confirm there.
#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output {
unsafe { Self::reduce(_mm512_mul_epu32(self.0, rhs.0)) }
/// Index operand for _mm512_permutex2var_epi32 that interleaves the even 32-bit words of
/// the first source with the even 32-bit words of the second source. Bit 4 (0b10000) of
/// each index selects the second source; the low 4 bits select the word.
const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000,
0b11000, 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110,
])
};
/// Index operand for _mm512_permutex2var_epi32 that interleaves the odd 32-bit words of
/// the first source with the odd 32-bit words of the second source. Bit 4 (0b10000) of
/// each index selects the second source; the low 4 bits select the word.
const ODDS_INTERLEAVE_ODDS: __m512i = unsafe {
core::mem::transmute([
0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001,
0b11001, 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111,
])
};

unsafe {
// _mm512_mul_epu32 multiplies only the low 32 bits of each 64-bit lane, so the even
// and odd 32-bit words are multiplied in two separate passes.

// Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
// the first operand (they are already in place).
let val0_e = self.0;
// Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
// the first operand.
let val0_o = _mm512_srli_epi64(self.0, 32);

// Double the second operand. The doubled product then splits exactly at the 32-bit
// word boundary: the upper word holds prod >> 31 and the lower word holds the low 31
// bits of prod shifted left by one. With inputs in [0, P], 2 * rhs still fits in 32
// bits.
let val1 = _mm512_add_epi32(rhs.0, rhs.0);
let val1_e = val1;
let val1_o = _mm512_srli_epi64(val1, 32);

// To compute prod = val0 * val1 start by multiplying
// val0_e/o by val1_e/o.
let prod_e_dbl = _mm512_mul_epu32(val0_e, val1_e);
let prod_o_dbl = _mm512_mul_epu32(val0_o, val1_o);

// Each multiplication result holds val0 * (2 * val1) as a 64-bit word.
// Each 64-bit word looks like this (most significant bit first):
// field widths - | 1 |    31    |    31    | 1 |
// prod_e_dbl   - | 0 | prod_e_h | prod_e_l | 0 |
// prod_o_dbl   - | 0 | prod_o_h | prod_o_l | 0 |

// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
// prod_ls - |prod_o_l|0|prod_e_l|0|

// Divide by 2. The lowest bit of every 32-bit word is 0, so the 64-bit shift moves
// no set bit across a word boundary:
let prod_ls = Self(_mm512_srli_epi64(prod_ls, 1));
// prod_ls - |0|prod_o_l|0|prod_e_l|

// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
let prod_hs = Self(_mm512_permutex2var_epi32(
prod_e_dbl,
ODDS_INTERLEAVE_ODDS,
prod_o_dbl,
));
// prod_hs - |0|prod_o_h|0|prod_e_h|

// prod = prod_h * 2^31 + prod_l and 2^31 = 1 (mod P), so prod = prod_h + prod_l
// (mod P). Both halves are at most P, which is the input range `add` expects.
Self::add(prod_ls, prod_hs)
}
}
}

Expand All @@ -122,22 +152,26 @@ impl Neg for M31AVX512 {

#[inline(always)]
fn neg(self) -> Self::Output {
unsafe { Self::partial_reduce(_mm512_sub_epi64(M512P, self.0)) }
Self(unsafe { _mm512_sub_epi32(M512P, self.0) })
}
}

/// Subtracts two packed M31 elements word by word and reduces each difference to the
/// range `[0, P]`.
/// Each input word is assumed to be in unreduced form, `[0, P]` including `P`.
impl Sub for M31AVX512 {
type Output = Self;

#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output {
unsafe {
let a_minus_b = _mm512_sub_epi32(self.0, rhs.0);
Self(_mm512_min_epu32(
a_minus_b,
_mm512_add_epi32(a_minus_b, M512P),
))
}
Self(unsafe {
// Subtract word by word. Each difference is in the range [-P, P]; the subtraction
// wraps, so a negative difference shows up as an unsigned value in [2^32-P,2^32-1].
let c = _mm512_sub_epi32(self.0, rhs.0);
// Apply an unsigned min(c, c+P) to each word, where c+P is a wrapping addition.
// When c in [0,P], then c+P in [P,2P], which is larger than c, so c is selected.
// When c in [2^32-P,2^32-1], then c+P wraps to [0,P-1], which is smaller than c, so
// c+P is selected.
_mm512_min_epu32(_mm512_add_epi32(c, M512P), c)
})
}
}

Expand All @@ -150,14 +184,12 @@ impl SubAssign for M31AVX512 {

#[cfg(test)]
mod tests {
use core::arch::x86_64::_mm512_loadu_epi64;

use rand::Rng;
use itertools::Itertools;

use super::{K_BLOCK_SIZE, M31AVX512};
use super::M31AVX512;
use crate::core::fields::m31::{M31, P};
use crate::core::fields::Field;
use crate::m31;

/// Tests field operations where field elements are in reduced form.
#[test]
Expand All @@ -166,66 +198,44 @@ mod tests {
return;
}

let values = [0, 1, 2, 10, (P - 1) / 2, (P + 1) / 2, P - 2, P - 1]
.map(M31::from_u32_unchecked)
.to_vec();
let avx_values = M31AVX512::from_slice(&values);
let values = [
0,
1,
2,
10,
(P - 1) / 2,
(P + 1) / 2,
P - 2,
P - 1,
0,
1,
2,
10,
(P - 1) / 2,
(P + 1) / 2,
P - 2,
P - 1,
]
.map(M31::from_u32_unchecked);
let avx_values = M31AVX512::from_array(values);

assert_eq!(
(avx_values + avx_values).to_vec(),
values.iter().map(|x| x.double()).collect::<Vec<_>>()
(avx_values + avx_values)
.to_array()
.into_iter()
.collect_vec(),
values.iter().map(|x| x.double()).collect_vec()
);
assert_eq!(
(avx_values * avx_values).to_vec(),
values.iter().map(|x| x.square()).collect::<Vec<_>>()
(avx_values * avx_values)
.to_array()
.into_iter()
.collect_vec(),
values.iter().map(|x| x.square()).collect_vec()
);
assert_eq!(
(-avx_values).to_vec(),
values.iter().map(|x| -*x).collect::<Vec<_>>()
);
}

/// Tests that reduce functions are correct.
#[test]
fn test_reduce() {
if !crate::platform::avx512_detected() {
return;
}
let mut rng = rand::thread_rng();

let const_values = [0, 1, (P + 1) / 2, P - 1, P, P + 1, 2 * P - 1, 2 * P];
let avx_const_values =
M31AVX512::from_slice(const_values.map(M31::from_u32_unchecked).as_ref());

// Tests partial reduce.
assert_eq!(
M31AVX512::partial_reduce(avx_const_values.0).to_vec(),
const_values
.iter()
.map(|x| m31!(if *x == 2 * P { P } else { x % P }))
.collect::<Vec<_>>()
);

// Generate random values in [0, P^2).
let rand_values = (0..K_BLOCK_SIZE)
.map(|_x| rng.gen::<u64>() % (P as u64).pow(2))
.collect::<Vec<u64>>();
let avx_rand_values = M31AVX512::from_m512_unchecked(unsafe {
_mm512_loadu_epi64(rand_values.as_ptr() as *const i64)
});

// Tests reduce.
assert_eq!(
M31AVX512::reduce(avx_const_values.0).to_vec(),
const_values.iter().map(|x| m31!(x % P)).collect::<Vec<_>>()
);

assert_eq!(
M31AVX512::reduce(avx_rand_values.0).to_vec(),
rand_values
.iter()
.map(|x| m31!((x % P as u64) as u32))
.collect::<Vec<_>>()
(-avx_values).to_array().into_iter().collect_vec(),
values.iter().map(|x| -*x).collect_vec()
);
}
}

0 comments on commit e7624e7

Please sign in to comment.