Skip to content

Commit

Permalink
Faster AVX256
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Feb 25, 2024
1 parent ff47970 commit 6659b9f
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 128 deletions.
2 changes: 1 addition & 1 deletion benches/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) {
vec![M31AVX512::from_m512_unchecked(M512ONE); N_STATE_ELEMENTS];

for _ in 0..(N_ELEMENTS / K_BLOCK_SIZE) {
elements.push(M31AVX512::from_slice(
elements.push(M31AVX512::from_array(
&[get_random_m31_element(&mut rng); K_BLOCK_SIZE],
));
}
Expand Down
220 changes: 93 additions & 127 deletions src/core/fields/avx512_m31.rs
Original file line number Diff line number Diff line change
@@ -1,83 +1,42 @@
use core::arch::x86_64::{
__m256i, __m512i, _mm256_loadu_si256, _mm256_storeu_si256, _mm512_add_epi32, _mm512_add_epi64,
_mm512_and_epi64, _mm512_cvtepi64_epi32, _mm512_cvtepu32_epi64, _mm512_min_epu32,
_mm512_mul_epu32, _mm512_srli_epi64, _mm512_sub_epi32, _mm512_sub_epi64,
__m512i, _mm512_add_epi32, _mm512_add_epi64, _mm512_min_epu32, _mm512_mul_epu32,
_mm512_srli_epi64, _mm512_sub_epi32,
};
use std::arch::x86_64::_mm512_permutex2var_epi32;
use std::fmt::Display;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use super::m31::{M31, MODULUS_BITS, P};
pub const K_BLOCK_SIZE: usize = 8;
pub const M512P: __m512i = unsafe { core::mem::transmute([P as u64; K_BLOCK_SIZE]) };
pub const M512ONE: __m512i = unsafe { core::mem::transmute([1u64; K_BLOCK_SIZE]) };
use super::m31::{M31, P};
pub const K_BLOCK_SIZE: usize = 16;
pub const M512P: __m512i = unsafe { core::mem::transmute([P; K_BLOCK_SIZE]) };

/// AVX512 implementation of M31.
/// Stores 16 M31 elements in a single 512-bit register.
/// Each M31 element is unreduced in the range [0, P].
#[derive(Copy, Clone, Debug)]
pub struct M31AVX512(__m512i);

impl M31AVX512 {
/// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, 2*\[P\]), packed in
/// x, returns packed xi % \[P\].
/// If xi == 2*\[P\], then it reduces to \[P\].
/// Note that this function can be used for both reduced and unreduced
/// representations. [0, 2*\[P\]) -> [0, \[P\]), [0, 2*\[P\]] -> [0,
/// \[P\]].
#[inline(always)]
fn partial_reduce(x: __m512i) -> Self {
unsafe {
let x_minus_p = _mm512_sub_epi32(x, M512P);
Self(_mm512_min_epu32(x, x_minus_p))
}
}

/// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, \[P\]^2), packed in
/// x, returns packed xi % \[P\].
/// If xi == \[P\]^2, then it reduces to \[P\].
/// Note that this function can be used for both reduced and unreduced
/// representations. [0, \[P\]^2) -> [0, \[P\]), [0, \[P\]^2] -> [0,
/// \[P\]].
#[inline(always)]
fn reduce(x: __m512i) -> Self {
unsafe {
let x_plus_one: __m512i = _mm512_add_epi64(x, M512ONE);

// z_i = x_i // P (integer division).
let z: __m512i = _mm512_srli_epi64(
_mm512_add_epi64(_mm512_srli_epi64(x, MODULUS_BITS), x_plus_one),
MODULUS_BITS,
);
let result: __m512i = _mm512_add_epi64(x, z);
Self(_mm512_and_epi64(result, M512P))
}
}

pub fn from_slice(v: &[M31]) -> M31AVX512 {
unsafe {
Self(_mm512_cvtepu32_epi64(_mm256_loadu_si256(
v.as_ptr() as *const __m256i
)))
}
pub fn from_array(v: [M31; K_BLOCK_SIZE]) -> M31AVX512 {
unsafe { Self(std::mem::transmute(v)) }
}

pub fn from_m512_unchecked(x: __m512i) -> Self {
Self(x)
}

pub fn to_vec(self) -> Vec<M31> {
unsafe {
let mut v = Vec::with_capacity(K_BLOCK_SIZE);
_mm256_storeu_si256(
v.as_mut_ptr() as *mut __m256i,
_mm512_cvtepi64_epi32(self.0),
);
v.set_len(K_BLOCK_SIZE);
v
}
pub fn to_array(self) -> [M31; K_BLOCK_SIZE] {
unsafe { std::mem::transmute(self.reduce()) }
}

pub fn reduce(self) -> M31AVX512 {
Self(unsafe { _mm512_min_epu32(self.0, _mm512_sub_epi32(self.0, M512P)) })
}
}

impl Display for M31AVX512 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let v = self.to_vec();
let v = self.to_array();
for elem in v.iter() {
write!(f, "{} ", elem)?;
}
Expand All @@ -90,7 +49,10 @@ impl Add for M31AVX512 {

#[inline(always)]
fn add(self, rhs: Self) -> Self::Output {
unsafe { Self::partial_reduce(_mm512_add_epi64(self.0, rhs.0)) }
Self(unsafe {
let x = _mm512_add_epi64(self.0, rhs.0);
_mm512_min_epu32(x, _mm512_sub_epi32(x, M512P))
})
}
}

Expand All @@ -106,7 +68,38 @@ impl Mul for M31AVX512 {

#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output {
unsafe { Self::reduce(_mm512_mul_epu32(self.0, rhs.0)) }
const L: __m512i = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000,
0b11000, 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110,
])
};
const H: __m512i = unsafe {
core::mem::transmute([
0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001,
0b11001, 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111,
])
};
const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) };

unsafe {
let val0 = self.0;
let val1 = _mm512_add_epi32(rhs.0, rhs.0);
let val1_e = val1;
let val0_e = val0;
let val1_o = _mm512_srli_epi64(val1, 32);
let val0_o = _mm512_srli_epi64(val0, 32);
let m_e_dbl = _mm512_mul_epu32(val1_e, val0_e);
let m_o_dbl = _mm512_mul_epu32(val1_o, val0_o);

let rm_l = _mm512_srli_epi64(_mm512_permutex2var_epi32(m_e_dbl, L, m_o_dbl), 1);
let rm_h = _mm512_permutex2var_epi32(m_e_dbl, H, m_o_dbl);

let rm = _mm512_add_epi32(rm_l, rm_h);
let rm_m_p = _mm512_sub_epi32(rm, P);

Self(_mm512_min_epu32(rm, rm_m_p))
}
}
}

Expand All @@ -122,7 +115,7 @@ impl Neg for M31AVX512 {

#[inline(always)]
fn neg(self) -> Self::Output {
unsafe { Self::partial_reduce(_mm512_sub_epi64(M512P, self.0)) }
Self(unsafe { _mm512_sub_epi32(M512P, self.0) })
}
}

Expand All @@ -131,13 +124,10 @@ impl Sub for M31AVX512 {

#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output {
unsafe {
let a_minus_b = _mm512_sub_epi32(self.0, rhs.0);
Self(_mm512_min_epu32(
a_minus_b,
_mm512_add_epi32(a_minus_b, M512P),
))
}
Self(unsafe {
let x = _mm512_sub_epi32(self.0, rhs.0);
_mm512_min_epu32(x, _mm512_add_epi32(x, M512P))
})
}
}

Expand All @@ -150,14 +140,12 @@ impl SubAssign for M31AVX512 {

#[cfg(test)]
mod tests {
use core::arch::x86_64::_mm512_loadu_epi64;

use rand::Rng;
use itertools::Itertools;

use super::{K_BLOCK_SIZE, M31AVX512};
use super::M31AVX512;
use crate::core::fields::m31::{M31, P};
use crate::core::fields::Field;
use crate::m31;

/// Tests field operations where field elements are in reduced form.
#[test]
Expand All @@ -166,66 +154,44 @@ mod tests {
return;
}

let values = [0, 1, 2, 10, (P - 1) / 2, (P + 1) / 2, P - 2, P - 1]
.map(M31::from_u32_unchecked)
.to_vec();
let avx_values = M31AVX512::from_slice(&values);
let values = [
0,
1,
2,
10,
(P - 1) / 2,
(P + 1) / 2,
P - 2,
P - 1,
0,
1,
2,
10,
(P - 1) / 2,
(P + 1) / 2,
P - 2,
P - 1,
]
.map(M31::from_u32_unchecked);
let avx_values = M31AVX512::from_array(values);

assert_eq!(
(avx_values + avx_values).to_vec(),
values.iter().map(|x| x.double()).collect::<Vec<_>>()
(avx_values + avx_values)
.to_array()
.into_iter()
.collect_vec(),
values.iter().map(|x| x.double()).collect_vec()
);
assert_eq!(
(avx_values * avx_values).to_vec(),
values.iter().map(|x| x.square()).collect::<Vec<_>>()
(avx_values * avx_values)
.to_array()
.into_iter()
.collect_vec(),
values.iter().map(|x| x.square()).collect_vec()
);
assert_eq!(
(-avx_values).to_vec(),
values.iter().map(|x| -*x).collect::<Vec<_>>()
);
}

/// Tests that reduce functions are correct.
#[test]
fn test_reduce() {
if !crate::platform::avx512_detected() {
return;
}
let mut rng = rand::thread_rng();

let const_values = [0, 1, (P + 1) / 2, P - 1, P, P + 1, 2 * P - 1, 2 * P];
let avx_const_values =
M31AVX512::from_slice(const_values.map(M31::from_u32_unchecked).as_ref());

// Tests partial reduce.
assert_eq!(
M31AVX512::partial_reduce(avx_const_values.0).to_vec(),
const_values
.iter()
.map(|x| m31!(if *x == 2 * P { P } else { x % P }))
.collect::<Vec<_>>()
);

// Generate random values in [0, P^2).
let rand_values = (0..K_BLOCK_SIZE)
.map(|_x| rng.gen::<u64>() % (P as u64).pow(2))
.collect::<Vec<u64>>();
let avx_rand_values = M31AVX512::from_m512_unchecked(unsafe {
_mm512_loadu_epi64(rand_values.as_ptr() as *const i64)
});

// Tests reduce.
assert_eq!(
M31AVX512::reduce(avx_const_values.0).to_vec(),
const_values.iter().map(|x| m31!(x % P)).collect::<Vec<_>>()
);

assert_eq!(
M31AVX512::reduce(avx_rand_values.0).to_vec(),
rand_values
.iter()
.map(|x| m31!((x % P as u64) as u32))
.collect::<Vec<_>>()
(-avx_values).to_array().into_iter().collect_vec(),
values.iter().map(|x| -*x).collect_vec()
);
}
}

0 comments on commit 6659b9f

Please sign in to comment.