diff --git a/Cargo.lock b/Cargo.lock index dc7329416..1173a1b69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -84,6 +84,26 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +[[package]] +name = "bytemuck" +version = "1.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "cast" version = "0.3.0" @@ -614,6 +634,7 @@ version = "0.1.1" dependencies = [ "blake2", "blake3", + "bytemuck", "criterion", "hex", "itertools 0.12.0", diff --git a/Cargo.toml b/Cargo.toml index 38f933ef6..8b01b3893 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ itertools = "0.12.0" num-traits = "0.2.17" thiserror = "1.0.56" merging-iterator = "1.3.0" +bytemuck = { version = "1.14.3", features = ["derive"] } [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports"] } diff --git a/benches/bit_rev.rs b/benches/bit_rev.rs index 85978659e..720a61220 100644 --- a/benches/bit_rev.rs +++ b/benches/bit_rev.rs @@ -13,7 +13,7 @@ pub fn cpu_bit_rev(c: &mut criterion::Criterion) { c.bench_function("cpu bit_rev", |b| { b.iter(|| { - data = stwo::core::utils::bit_reverse(std::mem::take(&mut data)); + stwo::core::utils::bit_reverse(&mut data); }) }); } diff --git a/src/core/backend/avx512/bit_reverse.rs b/src/core/backend/avx512/bit_reverse.rs index adad9e287..140f62807 100644 --- a/src/core/backend/avx512/bit_reverse.rs +++ b/src/core/backend/avx512/bit_reverse.rs @@ -1,16 +1,15 @@ use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32}; -use crate::core::fields::m31::BaseField; +use super::PackedBaseField; use crate::core::utils::bit_reverse_index; const VEC_BITS: u32 = 4; const W_BITS: u32 = 3; -const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; +pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; -// TODO(spapini): Use PackedBaseField type. /// Bit reverses packed M31 values. -/// Given an array A[0..2^n), computes B[i] = A[bit_reverse(i)]. -pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) { +/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`. +pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { assert!(data.len().is_power_of_two()); assert!(data.len().ilog2() >= MIN_LOG_SIZE); @@ -74,7 +73,7 @@ pub fn bit_reverse_m31(data: &mut [[BaseField; 16]]) { } /// Bit reverses 16 packed M31 values. -fn bit_reverse16(data: [[BaseField; 16]; 16]) -> [[BaseField; 16]; 16] { +fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) }; // L is an input to _mm512_permutex2var_epi32, and it is used to // interleave the first half of a with the first half of b. @@ -159,7 +158,8 @@ mod tests { let data: Vec<_> = (0..SIZE as u32) .map(BaseField::from_u32_unchecked) .collect(); - let expected = bit_reverse(data.clone()); + let mut expected = data.clone(); + bit_reverse(&mut expected); let mut data: Vec<_> = data.into_iter().array_chunks::<16>().collect(); let expected: Vec<_> = expected.into_iter().array_chunks::<16>().collect(); diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs index 0e2440cee..3a1739d52 100644 --- a/src/core/backend/avx512/mod.rs +++ b/src/core/backend/avx512/mod.rs @@ -1 +1,121 @@ pub mod bit_reverse; + +use std::ops::Index; + +use bytemuck::checked::cast_slice_mut; +use num_traits::Zero; + +use self::bit_reverse::bit_reverse_m31; +use crate::core::fields::m31::BaseField; +use crate::core::fields::{Column, FieldOps}; +use crate::core::utils; + +#[derive(Copy, Clone, Debug)] +pub struct AVX512Backend; + +// BaseField. +// TODO(spapini): Unite with the M31AVX512 type. +pub const K_ELEMENTS: usize = 16; +type PackedBaseField = [BaseField; K_ELEMENTS]; +#[derive(Clone, Debug)] +pub struct BaseFieldVec { + data: Vec, + length: usize, +} +impl FieldOps for AVX512Backend { + type Column = BaseFieldVec; + + fn bit_reverse_column(column: &mut Self::Column) { + // Fallback to cpu bit_reverse. + if column.data.len().ilog2() < bit_reverse::MIN_LOG_SIZE { + let data: &mut [BaseField] = cast_slice_mut(&mut column.data[..]); + utils::bit_reverse(&mut data[..column.length]); + return; + } + bit_reverse_m31(&mut column.data); + } +} + +impl Column for BaseFieldVec { + fn zeros(len: usize) -> Self { + Self { + data: vec![PackedBaseField::default(); len.div_ceil(K_ELEMENTS)], + length: len, + } + } + fn to_vec(&self) -> Vec { + self.data + .iter() + .flatten() + .copied() + .take(self.length) + .collect() + } + fn len(&self) -> usize { + self.length + } +} + +impl Index for BaseFieldVec { + type Output = BaseField; + fn index(&self, index: usize) -> &Self::Output { + &self.data[index / K_ELEMENTS][index % K_ELEMENTS] + } +} + +impl FromIterator for BaseFieldVec { + fn from_iter>(iter: I) -> Self { + let mut chunks = iter.into_iter().array_chunks(); + let mut res: Vec<_> = (&mut chunks).collect(); + let mut length = res.len() * K_ELEMENTS; + + if let Some(remainder) = chunks.into_remainder() { + if !remainder.is_empty() { + length += remainder.len(); + let pad_len = 16 - remainder.len(); + let last: PackedBaseField = remainder + .chain(std::iter::repeat(BaseField::zero()).take(pad_len)) + .collect::>() + .try_into() + .unwrap(); + res.push(last); + } + } + + Self { data: res, length } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::fields::{Col, Column}; + + type B = AVX512Backend; + + #[test] + fn test_column() { + for i in 0..100 { + let col = Col::::from_iter((0..i).map(BaseField::from)); + assert_eq!( + col.to_vec(), + (0..i).map(BaseField::from).collect::>() + ); + } + } + + #[test] + fn test_bit_reverse() { + for i in 1..16 { + let len = 1 << i; + let mut col = Col::::from_iter((0..len).map(BaseField::from)); + B::bit_reverse_column(&mut col); + assert_eq!( + col.to_vec(), + (0..len) + .map(|x| BaseField::from(utils::bit_reverse_index(x, i as u32))) + .collect::>() + ); + } + } +} diff --git a/src/core/backend/cpu/mod.rs b/src/core/backend/cpu/mod.rs index b340a6140..730078357 100644 --- a/src/core/backend/cpu/mod.rs +++ b/src/core/backend/cpu/mod.rs @@ -18,7 +18,7 @@ impl Backend for CPUBackend {} impl FieldOps for CPUBackend { type Column = Vec; - fn bit_reverse_column(column: Self::Column) -> Self::Column { + fn bit_reverse_column(column: &mut Self::Column) { bit_reverse(column) } } diff --git a/src/core/fields/m31.rs b/src/core/fields/m31.rs index 78040843f..3f34afa73 100644 --- a/src/core/fields/m31.rs +++ b/src/core/fields/m31.rs @@ -3,6 +3,8 @@ use std::ops::{ Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, }; +use bytemuck::{Pod, Zeroable}; + use super::ComplexConjugate; use crate::impl_field; @@ -10,7 +12,8 @@ pub const MODULUS_BITS: u32 = 31; pub const N_BYTES_FELT: usize = 4; pub const P: u32 = 2147483647; // 2 ** 31 - 1 -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Pod, Zeroable)] pub struct M31(u32); pub type BaseField = M31; diff --git a/src/core/fields/mod.rs b/src/core/fields/mod.rs index fe6649d57..05fc9233b 100644 --- a/src/core/fields/mod.rs +++ b/src/core/fields/mod.rs @@ -11,7 +11,7 @@ pub mod qm31; pub trait FieldOps { type Column: Column; - fn bit_reverse_column(column: Self::Column) -> Self::Column; + fn bit_reverse_column(column: &mut Self::Column); } pub type Col = >::Column; diff --git a/src/core/poly/circle.rs b/src/core/poly/circle.rs index ba5042abd..b5a36a734 100644 --- a/src/core/poly/circle.rs +++ b/src/core/poly/circle.rs @@ -254,8 +254,9 @@ impl, B: PolyOps> CircleEvaluation { self.values[self.domain.find(point_index).expect("Not in domain")] } - pub fn bit_reverse(self) -> CircleEvaluation { - CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values)) + pub fn bit_reverse(mut self) -> CircleEvaluation { + B::bit_reverse_column(&mut self.values); + CircleEvaluation::new(self.domain, self.values) } } @@ -281,8 +282,9 @@ impl> CPUCircleEvaluation { } impl, F: ExtensionOf> CircleEvaluation { - pub fn bit_reverse(self) -> CircleEvaluation { - CircleEvaluation::new(self.domain, B::bit_reverse_column(self.values)) + pub fn bit_reverse(mut self) -> CircleEvaluation { + B::bit_reverse_column(&mut self.values); + CircleEvaluation::new(self.domain, self.values) } pub fn get_at(&self, point_index: CirclePointIndex) -> F { @@ -389,7 +391,8 @@ impl, B: PolyOps> CirclePoly { #[cfg(test)] impl> crate::core::backend::cpu::CPUCirclePoly { pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool { - let mut coeffs = crate::core::utils::bit_reverse(self.coeffs.clone()); + let mut coeffs = self.coeffs.clone(); + crate::core::utils::bit_reverse(&mut coeffs); while coeffs.last() == Some(&F::zero()) { coeffs.pop(); } diff --git a/src/core/poly/line.rs b/src/core/poly/line.rs index 7d270e422..c76e8f3d2 100644 --- a/src/core/poly/line.rs +++ b/src/core/poly/line.rs @@ -156,8 +156,9 @@ impl> LinePoly { } /// Returns the polynomial's coefficients in their natural order. - pub fn into_ordered_coefficients(self) -> Vec { - bit_reverse(self.coeffs) + pub fn into_ordered_coefficients(mut self) -> Vec { + bit_reverse(&mut self.coeffs); + self.coeffs } /// Creates a new line polynomial from coefficients in their natural order. @@ -165,8 +166,9 @@ impl> LinePoly { /// # Panics /// /// Panics if the number of coefficients is not a power of two. - pub fn from_ordered_coefficients(coeffs: Vec) -> Self { - Self::new(bit_reverse(coeffs)) + pub fn from_ordered_coefficients(mut coeffs: Vec) -> Self { + bit_reverse(&mut coeffs); + Self::new(coeffs) } } @@ -242,9 +244,10 @@ impl> CPULineEvaluation { } impl, F: Field> LineEvaluation { - pub fn bit_reverse(self) -> LineEvaluation { + pub fn bit_reverse(mut self) -> LineEvaluation { + B::bit_reverse_column(&mut self.values); LineEvaluation { - values: B::bit_reverse_column(self.values), + values: self.values, domain: self.domain, _eval_order: PhantomData, } @@ -252,9 +255,10 @@ impl, F: Field> LineEvaluation { } impl, F: Field> LineEvaluation { - pub fn bit_reverse(self) -> LineEvaluation { + pub fn bit_reverse(mut self) -> LineEvaluation { + B::bit_reverse_column(&mut self.values); LineEvaluation { - values: B::bit_reverse_column(self.values), + values: self.values, domain: self.domain, _eval_order: PhantomData, } diff --git a/src/core/queries.rs b/src/core/queries.rs index 673536644..9b371cbb3 100644 --- a/src/core/queries.rs +++ b/src/core/queries.rs @@ -158,13 +158,13 @@ mod tests { pub fn test_folded_queries() { let log_domain_size = 7; let domain = CanonicCoset::new(log_domain_size).circle_domain(); - let values = domain.iter().collect::>(); - let values = bit_reverse(values.clone()); + let mut values = domain.iter().collect::>(); + bit_reverse(&mut values); let log_folded_domain_size = 5; let folded_domain = CanonicCoset::new(log_folded_domain_size).circle_domain(); - let folded_values = folded_domain.iter().collect::>(); - let folded_values = bit_reverse(folded_values.clone()); + let mut folded_values = folded_domain.iter().collect::>(); + bit_reverse(&mut folded_values); // Generate all possible queries. let queries = Queries { @@ -192,8 +192,8 @@ mod tests { let channel = &mut Blake2sChannel::new(Blake2sHash::default()); let log_domain_size = 7; let domain = CanonicCoset::new(log_domain_size).circle_domain(); - let values = domain.iter().collect::>(); - let values = bit_reverse(values.clone()); + let mut values = domain.iter().collect::>(); + bit_reverse(&mut values); // Test random queries one by one because the conjugate queries are sorted. for _ in 0..100 { diff --git a/src/core/utils.rs b/src/core/utils.rs index b0b03ef4e..c781e1ae3 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -13,24 +13,23 @@ pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize { i.reverse_bits() >> (usize::BITS - log_size) } -/// Performs a naive bit-reversal permutation. +/// Performs a naive bit-reversal permutation inplace. /// /// # Panics /// /// Panics if the length of the slice is not a power of two. -// TODO(AlonH): Consider benchmarking this function. // TODO: Implement cache friendly implementation. -pub fn bit_reverse>(mut v: U) -> U { - let n = v.as_mut().len(); +// TODO(spapini): Move this to the cpu backend. +pub fn bit_reverse(v: &mut [T]) { + let n = v.len(); assert!(n.is_power_of_two()); let log_n = n.ilog2(); for i in 0..n { let j = bit_reverse_index(i, log_n); if j > i { - v.as_mut().swap(i, j); + v.swap(i, j); } } - v } #[cfg(test)] @@ -39,15 +38,15 @@ mod tests { #[test] fn bit_reverse_works() { - assert_eq!( - bit_reverse([0, 1, 2, 3, 4, 5, 6, 7]), - [0, 4, 2, 6, 1, 5, 3, 7] - ); + let mut data = [0, 1, 2, 3, 4, 5, 6, 7]; + bit_reverse(&mut data); + assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]); } #[test] #[should_panic] fn bit_reverse_non_power_of_two_size_fails() { - bit_reverse([0, 1, 2, 3, 4, 5]); + let mut data = [0, 1, 2, 3, 4, 5]; + bit_reverse(&mut data); } } diff --git a/src/fibonacci/mod.rs b/src/fibonacci/mod.rs index 9cdb065b5..eb8f9621f 100644 --- a/src/fibonacci/mod.rs +++ b/src/fibonacci/mod.rs @@ -367,8 +367,8 @@ mod tests { fn test_sparse_circle_points() { let log_domain_size = 7; let domain = CanonicCoset::new(log_domain_size).circle_domain(); - let domain_points = domain.iter().collect_vec(); - let trace_commitment_points = bit_reverse(domain_points); + let mut trace_commitment_points = domain.iter().collect_vec(); + bit_reverse(&mut trace_commitment_points); // Generate queries. let trace_queries = Queries { @@ -391,7 +391,8 @@ mod tests { .collect_vec(); let circle_domain = sub_circle_domain.to_circle_domain(&domain); // Bit reverse the domain points to match the order of the opened points. - let domain_points = bit_reverse(circle_domain.iter().collect_vec()); + let mut domain_points = circle_domain.iter().collect_vec(); + bit_reverse(&mut domain_points); assert_eq!(points, domain_points); } }