From cf22cc25ce6ff3d01387df8e7d709b511de854ad Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sun, 25 Feb 2024 18:03:35 +0200 Subject: [PATCH] Faster AVX256 --- Cargo.toml | 7 ++ benches/field.rs | 10 +- benches/fri.rs | 2 +- src/core/fields/avx512_m31.rs | 220 ++++++++++++++-------------------- 4 files changed, 106 insertions(+), 133 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 32019cceb..f3a60797d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,13 @@ merging-iterator = "1.3.0" criterion = { version = "0.5.1", features = ["html_reports"] } rand = "0.8.3" +[lints.rust] +warnings = "deny" +future-incompatible = "deny" +nonstandard-style = "deny" +rust-2018-idioms = "deny" +unused = "deny" + [features] avx512 = [] diff --git a/benches/field.rs b/benches/field.rs index f01f4bf75..8f99a4cfe 100644 --- a/benches/field.rs +++ b/benches/field.rs @@ -130,7 +130,7 @@ pub fn qm31_operations_bench(c: &mut criterion::Criterion) { #[cfg(target_arch = "x86_64")] pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) { - use stwo::core::fields::avx512_m31::{K_BLOCK_SIZE, M31AVX512, M512ONE}; + use stwo::core::fields::avx512_m31::{K_BLOCK_SIZE, M31AVX512}; use stwo::platform; if !platform::avx512_detected() { @@ -140,11 +140,11 @@ pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) { let mut rng = rand::thread_rng(); let mut elements: Vec = Vec::new(); let mut states: Vec = - vec![M31AVX512::from_m512_unchecked(M512ONE); N_STATE_ELEMENTS]; + vec![M31AVX512::from_array([1.into(); K_BLOCK_SIZE]); N_STATE_ELEMENTS]; for _ in 0..(N_ELEMENTS / K_BLOCK_SIZE) { - elements.push(M31AVX512::from_slice( - &[get_random_m31_element(&mut rng); K_BLOCK_SIZE], + elements.push(M31AVX512::from_array( + [get_random_m31_element(&mut rng); K_BLOCK_SIZE], )); } @@ -200,4 +200,4 @@ criterion::criterion_group!( cm31_operations_bench, qm31_operations_bench ); -criterion::criterion_main!(field_comparison); +criterion::criterion_main!(field_comparison, m31_benches); diff --git a/benches/fri.rs b/benches/fri.rs index 24eec5751..cee0c76f9 100644 --- a/benches/fri.rs +++ b/benches/fri.rs @@ -6,7 +6,7 @@ use stwo::core::poly::line::{LineDomain, LineEvaluation}; fn folding_benchmark(c: &mut Criterion) { const LOG_SIZE: u32 = 12; - let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE).coset()); + let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); let evals = LineEvaluation::new( domain, vec![BaseField::from_u32_unchecked(712837213); 1 << LOG_SIZE], diff --git a/src/core/fields/avx512_m31.rs b/src/core/fields/avx512_m31.rs index 083b30a65..79fc2b690 100644 --- a/src/core/fields/avx512_m31.rs +++ b/src/core/fields/avx512_m31.rs @@ -1,83 +1,42 @@ use core::arch::x86_64::{ - __m256i, __m512i, _mm256_loadu_si256, _mm256_storeu_si256, _mm512_add_epi32, _mm512_add_epi64, - _mm512_and_epi64, _mm512_cvtepi64_epi32, _mm512_cvtepu32_epi64, _mm512_min_epu32, - _mm512_mul_epu32, _mm512_srli_epi64, _mm512_sub_epi32, _mm512_sub_epi64, + __m512i, _mm512_add_epi32, _mm512_add_epi64, _mm512_min_epu32, _mm512_mul_epu32, + _mm512_srli_epi64, _mm512_sub_epi32, }; +use std::arch::x86_64::_mm512_permutex2var_epi32; use std::fmt::Display; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use super::m31::{M31, MODULUS_BITS, P}; -pub const K_BLOCK_SIZE: usize = 8; -pub const M512P: __m512i = unsafe { core::mem::transmute([P as u64; K_BLOCK_SIZE]) }; -pub const M512ONE: __m512i = unsafe { core::mem::transmute([1u64; K_BLOCK_SIZE]) }; +use super::m31::{M31, P}; +pub const K_BLOCK_SIZE: usize = 16; +pub const M512P: __m512i = unsafe { core::mem::transmute([P; K_BLOCK_SIZE]) }; +/// AVX512 implementation of M31. +/// Stores 16 M31 elements in a single 512-bit register. +/// Each M31 element is unreduced in the range [0, P]. #[derive(Copy, Clone, Debug)] pub struct M31AVX512(__m512i); impl M31AVX512 { - /// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, 2*\[P\]), packed in - /// x, returns packed xi % \[P\]. - /// If xi == 2*\[P\], then it reduces to \[P\]. - /// Note that this function can be used for both reduced and unreduced - /// representations. [0, 2*\[P\]) -> [0, \[P\]), [0, 2*\[P\]] -> [0, - /// \[P\]]. - #[inline(always)] - fn partial_reduce(x: __m512i) -> Self { - unsafe { - let x_minus_p = _mm512_sub_epi32(x, M512P); - Self(_mm512_min_epu32(x, x_minus_p)) - } - } - - /// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, \[P\]^2), packed in - /// x, returns packed xi % \[P\]. - /// If xi == \[P\]^2, then it reduces to \[P\]. - /// Note that this function can be used for both reduced and unreduced - /// representations. [0, \[P\]^2) -> [0, \[P\]), [0, \[P\]^2] -> [0, - /// \[P\]]. - #[inline(always)] - fn reduce(x: __m512i) -> Self { - unsafe { - let x_plus_one: __m512i = _mm512_add_epi64(x, M512ONE); - - // z_i = x_i // P (integer division). - let z: __m512i = _mm512_srli_epi64( - _mm512_add_epi64(_mm512_srli_epi64(x, MODULUS_BITS), x_plus_one), - MODULUS_BITS, - ); - let result: __m512i = _mm512_add_epi64(x, z); - Self(_mm512_and_epi64(result, M512P)) - } - } - - pub fn from_slice(v: &[M31]) -> M31AVX512 { - unsafe { - Self(_mm512_cvtepu32_epi64(_mm256_loadu_si256( - v.as_ptr() as *const __m256i - ))) - } + pub fn from_array(v: [M31; K_BLOCK_SIZE]) -> M31AVX512 { + unsafe { Self(std::mem::transmute(v)) } } pub fn from_m512_unchecked(x: __m512i) -> Self { Self(x) } - pub fn to_vec(self) -> Vec { - unsafe { - let mut v = Vec::with_capacity(K_BLOCK_SIZE); - _mm256_storeu_si256( - v.as_mut_ptr() as *mut __m256i, - _mm512_cvtepi64_epi32(self.0), - ); - v.set_len(K_BLOCK_SIZE); - v - } + pub fn to_array(self) -> [M31; K_BLOCK_SIZE] { + unsafe { std::mem::transmute(self.reduce()) } + } + + pub fn reduce(self) -> M31AVX512 { + Self(unsafe { _mm512_min_epu32(self.0, _mm512_sub_epi32(self.0, M512P)) }) } } impl Display for M31AVX512 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let v = self.to_vec(); + let v = self.to_array(); for elem in v.iter() { write!(f, "{} ", elem)?; } @@ -90,7 +49,10 @@ impl Add for M31AVX512 { #[inline(always)] fn add(self, rhs: Self) -> Self::Output { - unsafe { Self::partial_reduce(_mm512_add_epi64(self.0, rhs.0)) } + Self(unsafe { + let x = _mm512_add_epi64(self.0, rhs.0); + _mm512_min_epu32(x, _mm512_sub_epi32(x, M512P)) + }) } } @@ -106,7 +68,38 @@ impl Mul for M31AVX512 { #[inline(always)] fn mul(self, rhs: Self) -> Self::Output { - unsafe { Self::reduce(_mm512_mul_epu32(self.0, rhs.0)) } + const L: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, + 0b11000, 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110, + ]) + }; + const H: __m512i = unsafe { + core::mem::transmute([ + 0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, + 0b11001, 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111, + ]) + }; + const P: __m512i = unsafe { core::mem::transmute([(1u32 << 31) - 1; 16]) }; + + unsafe { + let val0 = self.0; + let val1 = _mm512_add_epi32(rhs.0, rhs.0); + let val1_e = val1; + let val0_e = val0; + let val1_o = _mm512_srli_epi64(val1, 32); + let val0_o = _mm512_srli_epi64(val0, 32); + let m_e_dbl = _mm512_mul_epu32(val1_e, val0_e); + let m_o_dbl = _mm512_mul_epu32(val1_o, val0_o); + + let rm_l = _mm512_srli_epi64(_mm512_permutex2var_epi32(m_e_dbl, L, m_o_dbl), 1); + let rm_h = _mm512_permutex2var_epi32(m_e_dbl, H, m_o_dbl); + + let rm = _mm512_add_epi32(rm_l, rm_h); + let rm_m_p = _mm512_sub_epi32(rm, P); + + Self(_mm512_min_epu32(rm, rm_m_p)) + } } } @@ -122,7 +115,7 @@ impl Neg for M31AVX512 { #[inline(always)] fn neg(self) -> Self::Output { - unsafe { Self::partial_reduce(_mm512_sub_epi64(M512P, self.0)) } + Self(unsafe { _mm512_sub_epi32(M512P, self.0) }) } } @@ -131,13 +124,10 @@ impl Sub for M31AVX512 { #[inline(always)] fn sub(self, rhs: Self) -> Self::Output { - unsafe { - let a_minus_b = _mm512_sub_epi32(self.0, rhs.0); - Self(_mm512_min_epu32( - a_minus_b, - _mm512_add_epi32(a_minus_b, M512P), - )) - } + Self(unsafe { + let x = _mm512_sub_epi32(self.0, rhs.0); + _mm512_min_epu32(x, _mm512_add_epi32(x, M512P)) + }) } } @@ -150,14 +140,12 @@ impl SubAssign for M31AVX512 { #[cfg(test)] mod tests { - use core::arch::x86_64::_mm512_loadu_epi64; - use rand::Rng; + use itertools::Itertools; - use super::{K_BLOCK_SIZE, M31AVX512}; + use super::M31AVX512; use crate::core::fields::m31::{M31, P}; use crate::core::fields::Field; - use crate::m31; /// Tests field operations where field elements are in reduced form. #[test] @@ -166,66 +154,44 @@ mod tests { return; } - let values = [0, 1, 2, 10, (P - 1) / 2, (P + 1) / 2, P - 2, P - 1] - .map(M31::from_u32_unchecked) - .to_vec(); - let avx_values = M31AVX512::from_slice(&values); + let values = [ + 0, + 1, + 2, + 10, + (P - 1) / 2, + (P + 1) / 2, + P - 2, + P - 1, + 0, + 1, + 2, + 10, + (P - 1) / 2, + (P + 1) / 2, + P - 2, + P - 1, + ] + .map(M31::from_u32_unchecked); + let avx_values = M31AVX512::from_array(values); assert_eq!( - (avx_values + avx_values).to_vec(), - values.iter().map(|x| x.double()).collect::>() + (avx_values + avx_values) + .to_array() + .into_iter() + .collect_vec(), + values.iter().map(|x| x.double()).collect_vec() ); assert_eq!( - (avx_values * avx_values).to_vec(), - values.iter().map(|x| x.square()).collect::>() + (avx_values * avx_values) + .to_array() + .into_iter() + .collect_vec(), + values.iter().map(|x| x.square()).collect_vec() ); assert_eq!( - (-avx_values).to_vec(), - values.iter().map(|x| -*x).collect::>() - ); - } - - /// Tests that reduce functions are correct. - #[test] - fn test_reduce() { - if !crate::platform::avx512_detected() { - return; - } - let mut rng = rand::thread_rng(); - - let const_values = [0, 1, (P + 1) / 2, P - 1, P, P + 1, 2 * P - 1, 2 * P]; - let avx_const_values = - M31AVX512::from_slice(const_values.map(M31::from_u32_unchecked).as_ref()); - - // Tests partial reduce. - assert_eq!( - M31AVX512::partial_reduce(avx_const_values.0).to_vec(), - const_values - .iter() - .map(|x| m31!(if *x == 2 * P { P } else { x % P })) - .collect::>() - ); - - // Generate random values in [0, P^2). - let rand_values = (0..K_BLOCK_SIZE) - .map(|_x| rng.gen::() % (P as u64).pow(2)) - .collect::>(); - let avx_rand_values = M31AVX512::from_m512_unchecked(unsafe { - _mm512_loadu_epi64(rand_values.as_ptr() as *const i64) - }); - - // Tests reduce. - assert_eq!( - M31AVX512::reduce(avx_const_values.0).to_vec(), - const_values.iter().map(|x| m31!(x % P)).collect::>() - ); - - assert_eq!( - M31AVX512::reduce(avx_rand_values.0).to_vec(), - rand_values - .iter() - .map(|x| m31!((x % P as u64) as u32)) - .collect::>() + (-avx_values).to_array().into_iter().collect_vec(), + values.iter().map(|x| -*x).collect_vec() ); } }