diff --git a/Cargo.toml b/Cargo.toml index 38f933ef6..27ab07093 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,13 @@ merging-iterator = "1.3.0" criterion = { version = "0.5.1", features = ["html_reports"] } rand = "0.8.3" +[lints.rust] +warnings = "deny" +future-incompatible = "deny" +nonstandard-style = "deny" +rust-2018-idioms = "deny" +unused = "deny" + [features] avx512 = [] diff --git a/benches/field.rs b/benches/field.rs index f01f4bf75..8f99a4cfe 100644 --- a/benches/field.rs +++ b/benches/field.rs @@ -130,7 +130,7 @@ pub fn qm31_operations_bench(c: &mut criterion::Criterion) { #[cfg(target_arch = "x86_64")] pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) { - use stwo::core::fields::avx512_m31::{K_BLOCK_SIZE, M31AVX512, M512ONE}; + use stwo::core::fields::avx512_m31::{K_BLOCK_SIZE, M31AVX512}; use stwo::platform; if !platform::avx512_detected() { @@ -140,11 +140,11 @@ pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) { let mut rng = rand::thread_rng(); let mut elements: Vec = Vec::new(); let mut states: Vec = - vec![M31AVX512::from_m512_unchecked(M512ONE); N_STATE_ELEMENTS]; + vec![M31AVX512::from_array([1.into(); K_BLOCK_SIZE]); N_STATE_ELEMENTS]; for _ in 0..(N_ELEMENTS / K_BLOCK_SIZE) { - elements.push(M31AVX512::from_slice( - &[get_random_m31_element(&mut rng); K_BLOCK_SIZE], + elements.push(M31AVX512::from_array( + [get_random_m31_element(&mut rng); K_BLOCK_SIZE], )); } @@ -200,4 +200,4 @@ criterion::criterion_group!( cm31_operations_bench, qm31_operations_bench ); -criterion::criterion_main!(field_comparison); +criterion::criterion_main!(field_comparison, m31_benches); diff --git a/src/core/fields/avx512_m31.rs b/src/core/fields/avx512_m31.rs index 083b30a65..d4e40d87d 100644 --- a/src/core/fields/avx512_m31.rs +++ b/src/core/fields/avx512_m31.rs @@ -1,83 +1,43 @@ use core::arch::x86_64::{ - __m256i, __m512i, _mm256_loadu_si256, _mm256_storeu_si256, _mm512_add_epi32, _mm512_add_epi64, - _mm512_and_epi64, 
_mm512_cvtepi64_epi32, _mm512_cvtepu32_epi64, _mm512_min_epu32, - _mm512_mul_epu32, _mm512_srli_epi64, _mm512_sub_epi32, _mm512_sub_epi64, + __m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epu32, _mm512_srli_epi64, + _mm512_sub_epi32, }; +use std::arch::x86_64::_mm512_permutex2var_epi32; use std::fmt::Display; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use super::m31::{M31, MODULUS_BITS, P}; -pub const K_BLOCK_SIZE: usize = 8; -pub const M512P: __m512i = unsafe { core::mem::transmute([P as u64; K_BLOCK_SIZE]) }; -pub const M512ONE: __m512i = unsafe { core::mem::transmute([1u64; K_BLOCK_SIZE]) }; +use super::m31::{M31, P}; +pub const K_BLOCK_SIZE: usize = 16; +pub const M512P: __m512i = unsafe { core::mem::transmute([P; K_BLOCK_SIZE]) }; +/// AVX512 implementation of M31. +/// Stores 16 M31 elements in a single 512-bit register. +/// Each M31 element is unreduced in the range [0, P]. #[derive(Copy, Clone, Debug)] pub struct M31AVX512(__m512i); impl M31AVX512 { - /// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, 2*\[P\]), packed in - /// x, returns packed xi % \[P\]. - /// If xi == 2*\[P\], then it reduces to \[P\]. - /// Note that this function can be used for both reduced and unreduced - /// representations. [0, 2*\[P\]) -> [0, \[P\]), [0, 2*\[P\]] -> [0, - /// \[P\]]. - #[inline(always)] - fn partial_reduce(x: __m512i) -> Self { - unsafe { - let x_minus_p = _mm512_sub_epi32(x, M512P); - Self(_mm512_min_epu32(x, x_minus_p)) - } - } - - /// Given x1,...,x\[K_BLOCK_SIZE\] values, each in [0, \[P\]^2), packed in - /// x, returns packed xi % \[P\]. - /// If xi == \[P\]^2, then it reduces to \[P\]. - /// Note that this function can be used for both reduced and unreduced - /// representations. [0, \[P\]^2) -> [0, \[P\]), [0, \[P\]^2] -> [0, - /// \[P\]]. - #[inline(always)] - fn reduce(x: __m512i) -> Self { - unsafe { - let x_plus_one: __m512i = _mm512_add_epi64(x, M512ONE); - - // z_i = x_i // P (integer division). 
- let z: __m512i = _mm512_srli_epi64( - _mm512_add_epi64(_mm512_srli_epi64(x, MODULUS_BITS), x_plus_one), - MODULUS_BITS, - ); - let result: __m512i = _mm512_add_epi64(x, z); - Self(_mm512_and_epi64(result, M512P)) - } - } - - pub fn from_slice(v: &[M31]) -> M31AVX512 { - unsafe { - Self(_mm512_cvtepu32_epi64(_mm256_loadu_si256( - v.as_ptr() as *const __m256i - ))) - } + pub fn from_array(v: [M31; K_BLOCK_SIZE]) -> M31AVX512 { + unsafe { Self(std::mem::transmute(v)) } } pub fn from_m512_unchecked(x: __m512i) -> Self { Self(x) } - pub fn to_vec(self) -> Vec { - unsafe { - let mut v = Vec::with_capacity(K_BLOCK_SIZE); - _mm256_storeu_si256( - v.as_mut_ptr() as *mut __m256i, - _mm512_cvtepi64_epi32(self.0), - ); - v.set_len(K_BLOCK_SIZE); - v - } + pub fn to_array(self) -> [M31; K_BLOCK_SIZE] { + unsafe { std::mem::transmute(self.reduce()) } + } + + /// Reduces each word in the 512-bit register to the range `[0, P)`, excluding P. + pub fn reduce(self) -> M31AVX512 { + Self(unsafe { _mm512_min_epu32(self.0, _mm512_sub_epi32(self.0, M512P)) }) } } impl Display for M31AVX512 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let v = self.to_vec(); + let v = self.to_array(); for elem in v.iter() { write!(f, "{} ", elem)?; } @@ -88,9 +48,18 @@ impl Display for M31AVX512 { impl Add for M31AVX512 { type Output = Self; + /// Adds two packed M31 elements, and reduces the result to the range [0,P]. + /// Each value is assumed to be in unreduced form, [0, P] including P. #[inline(always)] fn add(self, rhs: Self) -> Self::Output { - unsafe { Self::partial_reduce(_mm512_add_epi64(self.0, rhs.0)) } + Self(unsafe { + // Add word by word. Each word is in the range [0, 2P]. + let c = _mm512_add_epi32(self.0, rhs.0); + // Apply min(c, c-P) to each word. + // When c in [P,2P], then c-P in [0,P] which is always less than [P,2P]. + // When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1]. 
+ _mm512_min_epu32(c, _mm512_sub_epi32(c, M512P)) + }) } } @@ -104,9 +73,70 @@ impl AddAssign for M31AVX512 { impl Mul for M31AVX512 { type Output = Self; + /// Computes the product of two packed M31 elements. + /// Each value is assumed to be in unreduced form, [0, P] including P. + /// Returned values are in unreduced form, [0, P] including P. #[inline(always)] fn mul(self, rhs: Self) -> Self::Output { - unsafe { Self::reduce(_mm512_mul_epu32(self.0, rhs.0)) } + /// An input to _mm512_permutex2var_epi32, used to interleave the even words of a + /// with the even words of b. + const EVENS_INTERLEAVE_EVENS: __m512i = unsafe { + core::mem::transmute([ + 0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, + 0b11000, 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110, + ]) + }; + /// An input to _mm512_permutex2var_epi32, used to interleave the odd words of a + /// with the odd words of b. + const ODDS_INTERLEAVE_ODDS: __m512i = unsafe { + core::mem::transmute([ + 0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, + 0b11001, 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111, + ]) + }; + + unsafe { + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let val0_e = self.0; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let val0_o = _mm512_srli_epi64(self.0, 32); + + // Double the second operand. + let val1 = _mm512_add_epi32(rhs.0, rhs.0); + let val1_e = val1; + let val1_o = _mm512_srli_epi64(val1, 32); + + // To compute prod = val0 * val1 start by multiplying + // val0_e/o by val1_e/o. + let prod_e_dbl = _mm512_mul_epu32(val0_e, val1_e); + let prod_o_dbl = _mm512_mul_epu32(val0_o, val1_o); + + // Each multiplication result holds val0 * val1 (val1 already doubled) as a 64-bit word.
+ // Each 64-bit word looks like this: + // Field widths (bits): |1|31|31|1| + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + + // Divide by 2: + let prod_ls = Self(_mm512_srli_epi64(prod_ls, 1)); + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_hs = Self(_mm512_permutex2var_epi32( + prod_e_dbl, + ODDS_INTERLEAVE_ODDS, + prod_o_dbl, + )); + // prod_hs - |0|prod_o_h|0|prod_e_h| + + Self::add(prod_ls, prod_hs) + } } } @@ -122,22 +152,26 @@ impl Neg for M31AVX512 { #[inline(always)] fn neg(self) -> Self::Output { - unsafe { Self::partial_reduce(_mm512_sub_epi64(M512P, self.0)) } + Self(unsafe { _mm512_sub_epi32(M512P, self.0) }) } } +/// Subtracts two packed M31 elements, and reduces the result to the range [0,P]. +/// Each value is assumed to be in unreduced form, [0, P] including P. impl Sub for M31AVX512 { type Output = Self; #[inline(always)] fn sub(self, rhs: Self) -> Self::Output { - unsafe { - let a_minus_b = _mm512_sub_epi32(self.0, rhs.0); - Self(_mm512_min_epu32( - a_minus_b, - _mm512_add_epi32(a_minus_b, M512P), - )) - } + Self(unsafe { + // Subtract word by word. Each word is in the range [-P, P]. + let c = _mm512_sub_epi32(self.0, rhs.0); + // Apply min(c, c+P) to each word. + // When c in [0,P], then c+P in [P,2P] which is always greater than [0,P]. + // When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than + // [2^32-P,2^32-1].
+ _mm512_min_epu32(_mm512_add_epi32(c, M512P), c) + }) } } @@ -150,14 +184,12 @@ impl SubAssign for M31AVX512 { #[cfg(test)] mod tests { - use core::arch::x86_64::_mm512_loadu_epi64; - use rand::Rng; + use itertools::Itertools; - use super::{K_BLOCK_SIZE, M31AVX512}; + use super::M31AVX512; use crate::core::fields::m31::{M31, P}; use crate::core::fields::Field; - use crate::m31; /// Tests field operations where field elements are in reduced form. #[test] @@ -166,66 +198,44 @@ mod tests { return; } - let values = [0, 1, 2, 10, (P - 1) / 2, (P + 1) / 2, P - 2, P - 1] - .map(M31::from_u32_unchecked) - .to_vec(); - let avx_values = M31AVX512::from_slice(&values); + let values = [ + 0, + 1, + 2, + 10, + (P - 1) / 2, + (P + 1) / 2, + P - 2, + P - 1, + 0, + 1, + 2, + 10, + (P - 1) / 2, + (P + 1) / 2, + P - 2, + P - 1, + ] + .map(M31::from_u32_unchecked); + let avx_values = M31AVX512::from_array(values); assert_eq!( - (avx_values + avx_values).to_vec(), - values.iter().map(|x| x.double()).collect::>() + (avx_values + avx_values) + .to_array() + .into_iter() + .collect_vec(), + values.iter().map(|x| x.double()).collect_vec() ); assert_eq!( - (avx_values * avx_values).to_vec(), - values.iter().map(|x| x.square()).collect::>() + (avx_values * avx_values) + .to_array() + .into_iter() + .collect_vec(), + values.iter().map(|x| x.square()).collect_vec() ); assert_eq!( - (-avx_values).to_vec(), - values.iter().map(|x| -*x).collect::>() - ); - } - - /// Tests that reduce functions are correct. - #[test] - fn test_reduce() { - if !crate::platform::avx512_detected() { - return; - } - let mut rng = rand::thread_rng(); - - let const_values = [0, 1, (P + 1) / 2, P - 1, P, P + 1, 2 * P - 1, 2 * P]; - let avx_const_values = - M31AVX512::from_slice(const_values.map(M31::from_u32_unchecked).as_ref()); - - // Tests partial reduce. 
- assert_eq!( - M31AVX512::partial_reduce(avx_const_values.0).to_vec(), - const_values - .iter() - .map(|x| m31!(if *x == 2 * P { P } else { x % P })) - .collect::>() - ); - - // Generate random values in [0, P^2). - let rand_values = (0..K_BLOCK_SIZE) - .map(|_x| rng.gen::() % (P as u64).pow(2)) - .collect::>(); - let avx_rand_values = M31AVX512::from_m512_unchecked(unsafe { - _mm512_loadu_epi64(rand_values.as_ptr() as *const i64) - }); - - // Tests reduce. - assert_eq!( - M31AVX512::reduce(avx_const_values.0).to_vec(), - const_values.iter().map(|x| m31!(x % P)).collect::>() - ); - - assert_eq!( - M31AVX512::reduce(avx_rand_values.0).to_vec(), - rand_values - .iter() - .map(|x| m31!((x % P as u64) as u32)) - .collect::>() + (-avx_values).to_array().into_iter().collect_vec(), + values.iter().map(|x| -*x).collect_vec() ); } }