diff --git a/.github/workflows/benchmarks-pages.yaml b/.github/workflows/benchmarks-pages.yaml index 3672b6c2c..9d59fbb89 100644 --- a/.github/workflows/benchmarks-pages.yaml +++ b/.github/workflows/benchmarks-pages.yaml @@ -1,4 +1,4 @@ -name: +name: on: push: @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -29,7 +29,7 @@ jobs: - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: - tool: 'cargo' + tool: "cargo" output-file-path: output.txt github-token: ${{ secrets.GITHUB_TOKEN }} auto-push: true diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b9b72d33d..ae785e2ce 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - run: scripts/rust_fmt.sh --check @@ -36,7 +36,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: clippy - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - run: scripts/clippy.sh @@ -46,9 +46,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 doc + - run: cargo +nightly-2024-11-06 doc run-wasm32-wasi-tests: runs-on: ubuntu-latest @@ -56,7 +56,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 targets: wasm32-wasi - uses: taiki-e/install-action@v2 with: @@ -73,7 +73,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 targets: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 - uses: jetli/wasm-pack-action@v0.4.0 @@ -89,9 +89,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-11-06 test env: RUSTFLAGS: -C target-feature=+neon @@ -104,9 +104,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-11-06 test env: RUSTFLAGS: -C target-cpu=native -C target-feature=+${{ matrix.target-feature }} @@ -116,7 +116,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -142,7 +142,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - name: Run benchmark run: ./scripts/bench.sh --features="parallel" -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -168,9 +168,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-11-06 test run-slow-tests: runs-on: ubuntu-latest @@ -178,9 +178,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test --release --features="slow-tests" + - run: cargo +nightly-2024-11-06 test --release --features="slow-tests" run-tests-parallel: runs-on: ubuntu-latest @@ -188,9 +188,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test --features="parallel" + - run: cargo +nightly-2024-11-06 test --features="parallel" machete: runs-on: ubuntu-latest diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 504cd67bb..508e0f11b 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -12,14 +12,14 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov # TODO: Merge coverage reports for tests on different architectures. # - name: Generate code coverage - run: cargo +nightly-2024-01-04 llvm-cov --codecov --output-path codecov.json + run: cargo +nightly-2024-11-06 llvm-cov --codecov --output-path codecov.json env: RUSTFLAGS: "-C target-feature=+avx512f" - name: Upload coverage to Codecov diff --git a/Cargo.lock b/Cargo.lock index c14183025..b71aed62e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -397,12 +397,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "downcast-rs" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" - [[package]] name = "educe" version = "0.5.11" @@ -984,7 +978,6 @@ dependencies = [ "bytemuck", "cfg-if", "criterion", - "downcast-rs", "educe", "hex", "itertools 0.12.1", diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 508c86e18..21c6bd150 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -14,7 +14,6 @@ blake2.workspace = true blake3.workspace = true bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] } cfg-if = "1.0.0" -downcast-rs = "1.2" educe.workspace = true hex.workspace = true itertools.workspace = true @@ -56,9 +55,6 @@ nonstandard-style = "deny" rust-2018-idioms = "deny" unused = "deny" -[package.metadata.cargo-machete] -ignored = ["downcast-rs"] - [[bench]] harness = false name = "bit_rev" diff --git a/crates/prover/benches/fft.rs b/crates/prover/benches/fft.rs index 35841d7e8..cbb0c9e80 100644 --- a/crates/prover/benches/fft.rs +++ b/crates/prover/benches/fft.rs @@ -29,7 +29,7 @@ pub fn simd_ifft(c: &mut Criterion) { || values.clone().data, |mut data| unsafe { ifft( - transmute(data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(data.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(log_size as usize), ); @@ -58,7 +58,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft_vecwise_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(9), black_box(0), @@ -72,7 +72,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft3_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs[3..]), black_box(7), black_box(4), @@ -91,7 +91,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || transpose_values.clone().data, |mut values| unsafe { transpose_vecs( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(TRANSPOSE_LOG_SIZE as usize - 4), ) }, @@ -115,8 +115,10 @@ pub fn simd_rfft(c: &mut Criterion) { target.set_len(values.data.len()); fft( - black_box(transmute(values.data.as_ptr())), - transmute(target.as_mut_ptr()), + black_box(transmute::<*const PackedBaseField, *const u32>( + values.data.as_ptr(), + )), + transmute::<*mut PackedBaseField, *mut u32>(target.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(LOG_SIZE as usize), ) diff --git a/crates/prover/benches/merkle.rs b/crates/prover/benches/merkle.rs index c039be77e..9a63a3c38 100644 --- a/crates/prover/benches/merkle.rs +++ b/crates/prover/benches/merkle.rs @@ -21,7 +21,7 @@ fn bench_blake2s_merkle>(c: &mut Criterion, id let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS); group.throughput(Throughput::Elements(n_elements)); group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements)); - group.bench_function(&format!("{id} merkle"), |b| { + group.bench_function(format!("{id} merkle"), |b| { b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs)) }); } diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs index 9e5530bae..3ed0c6162 100644 --- a/crates/prover/src/constraint_framework/assert.rs +++ b/crates/prover/src/constraint_framework/assert.rs @@ -24,7 +24,7 @@ impl<'a> AssertEvaluator<'a> { } } } -impl<'a> EvalAtRow for AssertEvaluator<'a> { +impl EvalAtRow for AssertEvaluator<'_> { type F = BaseField; type EF = SecureField; diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 8c237a9c5..422cad2a9 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -65,9 +65,10 @@ impl TraceLocationAllocator { } /// A component defined solely in means of the constraints framework. +/// /// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for -/// the SIMD backend. -/// Note that the constraint framework only support components with columns of the same size. +/// the SIMD backend. Note that the constraint framework only supports components with columns of +/// the same size. pub trait FrameworkEval { fn log_size(&self) -> u32; diff --git a/crates/prover/src/constraint_framework/cpu_domain.rs b/crates/prover/src/constraint_framework/cpu_domain.rs index d8aad0b8b..e198b5df0 100644 --- a/crates/prover/src/constraint_framework/cpu_domain.rs +++ b/crates/prover/src/constraint_framework/cpu_domain.rs @@ -46,7 +46,7 @@ impl<'a> CpuDomainEvaluator<'a> { } } -impl<'a> EvalAtRow for CpuDomainEvaluator<'a> { +impl EvalAtRow for CpuDomainEvaluator<'_> { type F = BaseField; type EF = SecureField; diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 81f1acfcb..f56761ecc 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -252,7 +252,7 @@ pub struct LogupColGenerator<'a> { /// Numerator expressions (i.e. multiplicities) being generated for the current lookup. numerator: SecureColumnByCoords, } -impl<'a> LogupColGenerator<'a> { +impl LogupColGenerator<'_> { /// Write a fraction to the column at a row. pub fn write_frac( &mut self, diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 6c6f72f81..05a0ea969 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -29,7 +29,7 @@ impl<'a> PointEvaluator<'a> { } } } -impl<'a> EvalAtRow for PointEvaluator<'a> { +impl EvalAtRow for PointEvaluator<'_> { type F = SecureField; type EF = SecureField; diff --git a/crates/prover/src/constraint_framework/simd_domain.rs b/crates/prover/src/constraint_framework/simd_domain.rs index 0ae32ebd3..2aacbef75 100644 --- a/crates/prover/src/constraint_framework/simd_domain.rs +++ b/crates/prover/src/constraint_framework/simd_domain.rs @@ -51,7 +51,7 @@ impl<'a> SimdDomainEvaluator<'a> { } } } -impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { +impl EvalAtRow for SimdDomainEvaluator<'_> { type F = VeryPackedBaseField; type EF = VeryPackedSecureField; diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index 2e7010ae6..2803f9f2d 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -1,4 +1,5 @@ //! Accumulators for a random linear combination of circle polynomials. +//! //! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is //! defined as //! f(p) = sum_i alpha^{N-1-i} u_i(P). @@ -162,7 +163,7 @@ pub struct ColumnAccumulator<'a, B: Backend> { pub random_coeff_powers: Vec, pub col: &'a mut SecureColumnByCoords, } -impl<'a> ColumnAccumulator<'a, CpuBackend> { +impl ColumnAccumulator<'_, CpuBackend> { pub fn accumulate(&mut self, index: usize, evaluation: SecureField) { let val = self.col.at(index) + evaluation; self.col.set(index, val); diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 320ec35cb..0376f3c39 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -11,7 +11,7 @@ use crate::core::ColumnVec; pub struct Components<'a>(pub Vec<&'a dyn Component>); -impl<'a> Components<'a> { +impl Components<'_> { pub fn composition_log_degree_bound(&self) -> u32 { self.0 .iter() @@ -55,7 +55,7 @@ impl<'a> Components<'a> { pub struct ComponentProvers<'a, B: Backend>(pub Vec<&'a dyn ComponentProver>); -impl<'a, B: Backend> ComponentProvers<'a, B> { +impl ComponentProvers<'_, B> { pub fn components(&self) -> Components<'_> { Components(self.0.iter().map(|c| *c as &dyn Component).collect_vec()) } diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index 05d8d10f5..df0aef8bc 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -15,10 +15,10 @@ mod components; pub mod mask; /// Arithmetic Intermediate Representation (AIR). -/// An Air instance is assumed to already contain all the information needed to -/// evaluate the constraints. -/// For instance, all interaction elements are assumed to be present in it. -/// Therefore, an AIR is generated only after the initial trace commitment phase. +/// +/// An Air instance is assumed to already contain all the information needed to evaluate the +/// constraints. For instance, all interaction elements are assumed to be present in it. Therefore, +/// an AIR is generated only after the initial trace commitment phase. pub trait Air { fn components(&self) -> Vec<&dyn Component>; } diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index ae9ab6b65..9c1d60093 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -265,7 +265,7 @@ enum MleExpr<'a, F: Field> { Mle(&'a Mle), } -impl<'a, F: Field> Index for MleExpr<'a, F> { +impl Index for MleExpr<'_, F> { type Output = F; fn index(&self, index: usize) -> &F { diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index 17cc007e1..4b356bb79 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -75,10 +75,10 @@ pub fn accumulate_row_quotients( row_accumulator } -/// Precompute the complex conjugate line coefficients for each column in each sample batch. -/// Specifically, for the i-th (in a sample batch) column's numerator term -/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants: -/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). +/// Precomputes the complex conjugate line coefficients for each column in each sample batch. +/// +/// For the `i`-th (in a sample batch) column's numerator term `alpha^i * (c * F(p) - (a * p.y + +/// b))`, we precompute and return the constants: (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). pub fn column_line_coeffs( sample_batches: &[ColumnSampleBatch], random_coeff: SecureField, @@ -103,8 +103,9 @@ pub fn column_line_coeffs( .collect() } -/// Precompute the random coefficients used to linearly combine the batched quotients. -/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch), +/// Precomputes the random coefficients used to linearly combine the batched quotients. +/// +/// For each sample batch we compute random_coeff^(number of columns in the batch), /// which is used to linearly combine the batch with the next one. pub fn batch_random_coeffs( sample_batches: &[ColumnSampleBatch], diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index 1c2418d7d..210e736e5 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -165,7 +165,7 @@ mod tests { let res = bit_reverse16(values.data.try_into().unwrap()); - assert_eq!(res.map(PackedM31::to_array).flatten(), expected); + assert_eq!(res.map(PackedM31::to_array).as_flattened(), expected); } #[test] diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index 4f2297d19..3f4d46b8f 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -364,8 +364,12 @@ mod tests { let res_vectorized: [[u32; 8]; 16] = unsafe { transmute(untranspose_states(compress16( - transpose_states(transmute(states)), - transpose_msgs(transmute(msgs)), + transpose_states(transmute::, [u32x16; 8]>( + states, + )), + transpose_msgs(transmute::, [u32x16; 16]>( + msgs, + )), u32x16::splat(count_low), u32x16::splat(count_high), u32x16::splat(lastblock), diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index a20721a4f..61588ffe3 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -89,10 +89,7 @@ impl SimdBackend { // Generates twiddle steps for efficiently computing the twiddles. // steps[i] = t_i/(t_0*t_1*...*t_i-1). - fn twiddle_steps(mappings: &[F]) -> Vec - where - F: FieldExpOps, - { + fn twiddle_steps(mappings: &[F]) -> Vec { let mut denominators: Vec = vec![mappings[0]]; for i in 1..mappings.len() { @@ -159,7 +156,7 @@ impl PolyOps for SimdBackend { // Safe because [PackedBaseField] is aligned on 64 bytes. unsafe { ifft::ifft( - transmute(values.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.data.as_mut_ptr()), &twiddles, log_size as usize, ); @@ -269,8 +266,8 @@ impl PolyOps for SimdBackend { // FFT from the coefficients buffer to the values chunk. unsafe { rfft::fft( - transmute(poly.coeffs.data.as_ptr()), - transmute( + transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>( values[i << (fft_log_size - LOG_N_LANES) ..(i + 1) << (fft_log_size - LOG_N_LANES)] .as_mut_ptr(), diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index 07d7463ee..705fe713e 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -207,7 +207,7 @@ impl FromIterator for CM31Column { /// A mutable slice of a BaseColumn. pub struct BaseColumnMutSlice<'a>(pub &'a mut [PackedBaseField]); -impl<'a> BaseColumnMutSlice<'a> { +impl BaseColumnMutSlice<'_> { pub fn at(&self, index: usize) -> BaseField { self.0[index / N_LANES].to_array()[index % N_LANES] } @@ -323,7 +323,7 @@ impl FromIterator for SecureColumn { /// A mutable slice of a SecureColumnByCoords. pub struct SecureColumnByCoordsMutSlice<'a>(pub [BaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE]); -impl<'a> SecureColumnByCoordsMutSlice<'a> { +impl SecureColumnByCoordsMutSlice<'_> { /// # Safety /// /// `vec_index` must be a valid index. @@ -357,7 +357,7 @@ pub struct VeryPackedSecureColumnByCoordsMutSlice<'a>( pub [VeryPackedBaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE], ); -impl<'a> VeryPackedSecureColumnByCoordsMutSlice<'a> { +impl VeryPackedSecureColumnByCoordsMutSlice<'_> { /// # Safety /// /// `vec_index` must be a valid index. diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index 77b096d9c..b8ed0b668 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -598,7 +598,7 @@ mod tests { let mut res = values; unsafe { ifft3( - transmute(res.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -664,7 +664,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_ifft(domain, values.flatten())); + assert_eq!(res, ground_truth_ifft(domain, values.as_flattened())); } #[test] @@ -678,7 +678,7 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft_lower_with_vecwise( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -700,11 +700,14 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); } assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ba091b145..7f369c331 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -111,19 +111,19 @@ fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { // TODO: For architectures that when multiplying require doubling then the twiddles // should be precomputed as double. For other architectures, the twiddle should be // precomputed without doubling. - crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_neon(v, twiddle_dbl) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + crate::core::backend::simd::m31::mul_doubled_wasm(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx512(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx2(v, twiddle_dbl) } else { - crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_simd(v, twiddle_dbl) } } } diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index d28c8a00d..5e280f8ec 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -624,8 +624,8 @@ mod tests { let mut res = values; unsafe { fft3( - transmute(res.as_ptr()), - transmute(res.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -695,7 +695,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_fft(domain, values.flatten())); + assert_eq!(res, ground_truth_fft(domain, values.as_flattened())); } #[test] @@ -709,8 +709,8 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { fft_lower_with_vecwise( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -731,10 +731,13 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); fft( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index 8804a7015..3ced43459 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -1,5 +1,5 @@ use std::array; -use std::simd::u32x8; +use std::simd::{u32x16, u32x8}; use num_traits::Zero; @@ -38,14 +38,15 @@ impl FriOps for SimdBackend { let mut folded_values = SecureColumnByCoords::::zeros(1 << (log_size - 1)); for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { - let value = unsafe { - let twiddle_dbl: [u32; 16] = - array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); - let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s(); - let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); + let value = { + let twiddle_dbl = u32x16::from_array(array::from_fn(|i| unsafe { + *itwiddles.get_unchecked(vec_index * 16 + i) + })); + let val0 = unsafe { eval.values.packed_at(vec_index * 2) }.into_packed_m31s(); + let val1 = unsafe { eval.values.packed_at(vec_index * 2 + 1) }.into_packed_m31s(); let pairs: [_; 4] = array::from_fn(|i| { let (a, b) = val0[i].deinterleave(val1[i]); - simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) + simd_ibutterfly(a, b, twiddle_dbl) }); let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 017948dee..74d7f7c43 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -25,7 +25,7 @@ impl GkrOps for SimdBackend { } // Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector. - let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); + let (y_rem, y_last_chunk) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v)); assert_eq!(initial.len(), N_LANES); diff --git a/crates/prover/src/core/backend/simd/lookups/mle.rs b/crates/prover/src/core/backend/simd/lookups/mle.rs index 0e2fe73f7..07f175bbc 100644 --- a/crates/prover/src/core/backend/simd/lookups/mle.rs +++ b/crates/prover/src/core/backend/simd/lookups/mle.rs @@ -30,9 +30,8 @@ impl MleOps for SimdBackend { let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint); let res = zip(evals_at_0x, evals_at_1x) - .enumerate() // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - .map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| { + .map(|(&packed_eval_at_0iv, &packed_eval_at_1iv)| { fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv) }) .collect(); diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index f6291626b..babe7cad7 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -3,14 +3,13 @@ use std::mem::transmute; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ptr; use std::simd::cmp::SimdOrd; -use std::simd::{u32x16, Simd, Swizzle}; +use std::simd::{u32x16, Simd}; use bytemuck::{Pod, Zeroable}; use num_traits::{One, Zero}; use rand::distributions::{Distribution, Standard}; use super::qm31::PackedQM31; -use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; @@ -142,16 +141,16 @@ impl Mul for PackedM31 { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { - _mul_neon(self, rhs) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - _mul_wasm(self, rhs) + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + mul_neon(self, rhs) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + mul_wasm(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - _mul_avx512(self, rhs) + mul_avx512(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - _mul_avx2(self, rhs) + mul_avx2(self, rhs) } else { - _mul_simd(self, rhs) + mul_simd(self, rhs) } } } @@ -286,290 +285,299 @@ impl Sum for PackedM31 { } } -/// Returns `a * b`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { - use core::arch::aarch64::{int32x2_t, vqdmull_s32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} +cfg_if::cfg_if! { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + use core::arch::aarch64::{uint32x2_t, vmull_u32, int32x2_t, vqdmull_s32}; + use std::simd::u32x4; + + /// Returns `a * b`. + pub(crate) fn mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::aarch64::{uint32x2_t, vmull_u32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; -/// Returns `a * b`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_wasm(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_wasm(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; - use std::simd::u32x4; - - let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; - let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; - - let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; - let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; - let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; - let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; - let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; - let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; - let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; - let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; - - let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); - let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); - let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); - let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); - c0_even >>= 1; - c1_even >>= 1; - c2_even >>= 1; - c3_even >>= 1; - - let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; - let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; - - even + odd -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx512(a, b.0 + b.0) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; - - let a: __m512i = unsafe { transmute(a) }; - let b_double: __m512i = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = a; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { _mm512_srli_epi64(a, 32) }; - - let b_dbl_e = b_double; - let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; - - // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. - let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; - let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + /// Returns `a * b`. + pub(crate) fn mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx512(a, b.0 + b.0) + } -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx2(a, b.0 + b.0) -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { + let a: __m512i = unsafe { transmute(a) }; + let b_double: __m512i = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl_e = b_double; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; + + // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; - - let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; - let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a0_e = a0; - let a1_e = a1; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; - let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; - - let b0_dbl_e = b0_dbl; - let b1_dbl_e = b1_dbl; - let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; - let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; - - // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. - let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; - let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; - let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; - let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; - - let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; - let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_simd(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx2(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -/// -/// `b_double` should be in the range `[0, 2P]`. -pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { - const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = unsafe { transmute::<_, Simd>(a.0) & MASK_EVENS }; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; - - let b_dbl_e = unsafe { transmute::<_, Simd>(b_double) & MASK_EVENS }; - let b_dbl_o = unsafe { transmute::<_, Simd>(b_double) >> 32 }; - - // To compute prod = a * b start by multiplying - // a_e/o by b_dbl_e/o. - let prod_e_dbl = a_e * b_dbl_e; - let prod_o_dbl = a_o * b_dbl_o; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, - // prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - let mut prod_lows = InterleaveEvens::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - // Divide by 2: - prod_lows >>= 1; - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_highs = InterleaveOdds::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - - // prod_hs - |0|prod_o_h|0|prod_e_h| - PackedM31(prod_lows) + PackedM31(prod_highs) + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1]: [__m256i; 2] = unsafe { transmute::(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute::(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else { + use std::simd::Swizzle; + + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + pub(crate) fn mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_simd(a, b.0 + b.0) + } + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { + const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = + unsafe { transmute::, Simd>(a.0) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::>(a) >> 32 }; + + let b_dbl_e = unsafe { + transmute::, Simd>(b_double) & MASK_EVENS + }; + let b_dbl_o = + unsafe { transmute::, Simd>(b_double) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = InterleaveEvens::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = InterleaveOdds::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + PackedM31(prod_lows) + PackedM31(prod_highs) + } + } } #[cfg(test)] diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 553e540a5..ccda23343 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -286,13 +286,13 @@ mod tests { let e1: BaseColumn = (0..small_domain.size()) .map(|i| BaseField::from(2 * i)) .collect(); - let polys = vec![ + let polys = [ CircleEvaluation::::new(small_domain, e0) .interpolate(), CircleEvaluation::::new(small_domain, e1) .interpolate(), ]; - let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)]; + let columns = [polys[0].evaluate(domain), polys[1].evaluate(domain)]; let random_coeff = qm31!(1, 2, 3, 4); let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN); let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN); diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index b5cb9e986..b698f9961 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,29 +1,3 @@ -use std::simd::Swizzle; - -/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. -pub struct InterleaveEvens; - -impl Swizzle for InterleaveEvens { - const INDEX: [usize; N] = parity_interleave(false); -} - -/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. -pub struct InterleaveOdds; - -impl Swizzle for InterleaveOdds { - const INDEX: [usize; N] = parity_interleave(true); -} - -const fn parity_interleave(odd: bool) -> [usize; N] { - let mut res = [0; N]; - let mut i = 0; - while i < N { - res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; - i += 1; - } - res -} - // TODO(andrew): Examine usage of unsafe in SIMD FFT. pub struct UnsafeMut(pub *mut T); impl UnsafeMut { @@ -51,29 +25,60 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} -#[cfg(test)] -mod tests { - use std::simd::{u32x4, Swizzle}; - - use super::{InterleaveEvens, InterleaveOdds}; +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "wasm32", target_feature = "simd128") +)))] +pub mod swizzle { + use std::simd::Swizzle; + + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. + pub struct InterleaveEvens; + impl Swizzle for InterleaveEvens { + const INDEX: [usize; N] = parity_interleave(false); + } - #[test] - fn interleave_evens() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + /// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. + pub struct InterleaveOdds; - let res = InterleaveEvens::concat_swizzle(lo, hi); + impl Swizzle for InterleaveOdds { + const INDEX: [usize; N] = parity_interleave(true); + } - assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + const fn parity_interleave(odd: bool) -> [usize; N] { + let mut res = [0; N]; + let mut i = 0; + while i < N { + res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; + i += 1; + } + res } - #[test] - fn interleave_odds() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + #[cfg(test)] + mod tests { + use std::simd::{u32x4, Swizzle}; + + use super::{InterleaveEvens, InterleaveOdds}; + + #[test] + fn interleave_evens() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); + + let res = InterleaveEvens::concat_swizzle(lo, hi); + + assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + } + + #[test] + fn interleave_odds() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); - let res = InterleaveOdds::concat_swizzle(lo, hi); + let res = InterleaveOdds::concat_swizzle(lo, hi); - assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + } } } diff --git a/crates/prover/src/core/channel/blake2s.rs b/crates/prover/src/core/channel/blake2s.rs index 86565658b..632a23855 100644 --- a/crates/prover/src/core/channel/blake2s.rs +++ b/crates/prover/src/core/channel/blake2s.rs @@ -75,7 +75,7 @@ impl Channel for Blake2sChannel { let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0); // TODO(shahars) Channel should always finalize hash. - self.update_digest(unsafe { std::mem::transmute(res) }); + self.update_digest(unsafe { std::mem::transmute::<[u32; 8], Blake2sHash>(res) }); } fn draw_felt(&mut self) -> SecureField { diff --git a/crates/prover/src/core/constraints.rs b/crates/prover/src/core/constraints.rs index 31711d98e..f66c8d93d 100644 --- a/crates/prover/src/core/constraints.rs +++ b/crates/prover/src/core/constraints.rs @@ -90,11 +90,11 @@ pub fn complex_conjugate_line( / (point.complex_conjugate().y - point.y) } -/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically, -/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and -/// (conj(sample.y), conj(sample.value)). -/// Relies on the fact that every polynomial F over the base -/// field holds: F(p*) == F(p)* (* being the complex conjugate). +/// Evaluates the coefficients of a line between a point and its complex conjugate. +/// +/// Specifically, `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and +/// (conj(sample.y), conj(sample.value)). Relies on the fact that every polynomial F over the base +/// field holds: `F(p*) == F(p)*` (`*` being the complex conjugate). pub fn complex_conjugate_line_coeffs( sample: &PointSample, alpha: SecureField, diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index ce5d4bddb..108db501b 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -97,8 +97,7 @@ pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps /// Let `src` be the evaluation of a circle polynomial `f` on a /// [`CircleDomain`] `E`. This function computes evaluations of `f' = f0 /// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The - /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + - /// f'`. + /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + f'`. /// /// # Panics /// @@ -727,7 +726,7 @@ impl FriLayerVerifier { let mut all_subline_evals = Vec::new(); // Group queries by the subline they reside in. - for subline_queries in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { + for subline_queries in queries.chunk_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { let subline_start = (subline_queries[0] >> FOLD_STEP) << FOLD_STEP; let subline_end = subline_start + (1 << FOLD_STEP); @@ -798,7 +797,7 @@ impl, H: MerkleHasher> FriLayerProver { // Group queries by the subline they reside in. // TODO(andrew): Explain what a "subline" is at the top of the module. - for query_group in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { + for query_group in queries.chunk_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { let subline_start = (query_group[0] >> FOLD_STEP) << FOLD_STEP; let subline_end = subline_start + (1 << FOLD_STEP); diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 6e6ed2586..c2d6df1bd 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -299,7 +299,7 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { pub lambda: SecureField, } -impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { +impl MultivariatePolyOracle for GkrMultivariatePolyOracle<'_, B> { fn n_variables(&self) -> usize { self.input_layer.n_variables() - 1 } @@ -470,7 +470,7 @@ pub fn prove_batch( // Seed the channel with the layer masks. for (&instance, mask) in zip(&sumcheck_instances, &masks) { - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); layer_masks_by_instance[instance].push(mask.clone()); } diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162..69d8314f5 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -120,7 +120,7 @@ pub fn partially_verify_batch( for &instance in &sumcheck_instances { let n_unused = n_layers - instance_n_layers(instance); let mask = &layer_masks_by_instance[instance][layer - n_unused]; - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); } // Set the OOD evaluation point for layer above. diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index d9acf524b..1a551d1eb 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -1,4 +1,5 @@ //! Implements a FRI polynomial commitment scheme. +//! //! This is a protocol where the prover can commit on a set of polynomials and then prove their //! opening on a set of points. //! Note: This implementation is not really a polynomial commitment scheme, because we are not in diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index 58d347ee6..4d584d748 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -163,7 +163,7 @@ pub struct TreeBuilder<'a, 'b, B: BackendForChannel, MC: MerkleChannel> { commitment_scheme: &'a mut CommitmentSchemeProver<'b, B, MC>, polys: ColumnVec>, } -impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> { +impl, MC: MerkleChannel> TreeBuilder<'_, '_, B, MC> { pub fn extend_evals( &mut self, columns: impl IntoIterator>, diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index 959393faa..4cd707daa 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -104,7 +104,7 @@ impl CommitmentSchemeVerifier { }) .0 .into_iter() - .collect::>()?; + .collect::>()?; // Answer FRI queries. let samples = sampled_points diff --git a/crates/prover/src/core/poly/circle/canonic.rs b/crates/prover/src/core/poly/circle/canonic.rs index 837e648d9..4d068594d 100644 --- a/crates/prover/src/core/poly/circle/canonic.rs +++ b/crates/prover/src/core/poly/circle/canonic.rs @@ -2,12 +2,14 @@ use super::CircleDomain; use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; -/// A coset of the form G_{2n} + , where G_n is the generator of the -/// subgroup of order n. The ordering on this coset is G_2n + i * G_n. -/// These cosets can be used as a [CircleDomain], and be interpolated on. -/// Note that this changes the ordering on the coset to be like [CircleDomain], -/// which is G_2n + i * G_n/2 and then -G_2n -i * G_n/2. -/// For example, the Xs below are a canonic coset with n=8. +/// A coset of the form `G_{2n} + `, where `G_n` is the generator of the subgroup of order `n`. +/// +/// The ordering on this coset is `G_2n + i * G_n`. +/// These cosets can be used as a [`CircleDomain`], and be interpolated on. +/// Note that this changes the ordering on the coset to be like [`CircleDomain`], +/// which is `G_{2n} + i * G_{n/2}` and then `-G_{2n} -i * G_{n/2}`. +/// For example, the `X`s below are a canonic coset with `n=8`. +/// /// ```text /// X O X /// O O diff --git a/crates/prover/src/core/poly/circle/domain.rs b/crates/prover/src/core/poly/circle/domain.rs index fba2bc3fb..b6ced6018 100644 --- a/crates/prover/src/core/poly/circle/domain.rs +++ b/crates/prover/src/core/poly/circle/domain.rs @@ -10,8 +10,9 @@ use crate::core::fields::m31::BaseField; pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1; /// A valid domain for circle polynomial interpolation and evaluation. -/// Valid domains are a disjoint union of two conjugate cosets: +-C + . -/// The ordering defined on this domain is C + iG_n, and then -C - iG_n. +/// +/// Valid domains are a disjoint union of two conjugate cosets: `+-C + `. +/// The ordering defined on this domain is `C + iG_n`, and then `-C - iG_n`. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct CircleDomain { pub half_coset: Coset, diff --git a/crates/prover/src/core/poly/circle/evaluation.rs b/crates/prover/src/core/poly/circle/evaluation.rs index faa2b7284..6094df399 100644 --- a/crates/prover/src/core/poly/circle/evaluation.rs +++ b/crates/prover/src/core/poly/circle/evaluation.rs @@ -146,7 +146,7 @@ impl<'a, F: ExtensionOf> CosetSubEvaluation<'a, F> { } } -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { +impl> Index for CosetSubEvaluation<'_, F> { type Output = F; fn index(&self, index: isize) -> &Self::Output { @@ -156,7 +156,7 @@ impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { } } -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { +impl> Index for CosetSubEvaluation<'_, F> { type Output = F; fn index(&self, index: usize) -> &Self::Output { diff --git a/crates/prover/src/core/poly/twiddles.rs b/crates/prover/src/core/poly/twiddles.rs index f3b186376..2e172a2cf 100644 --- a/crates/prover/src/core/poly/twiddles.rs +++ b/crates/prover/src/core/poly/twiddles.rs @@ -2,6 +2,7 @@ use super::circle::PolyOps; use crate::core::circle::Coset; /// Precomputed twiddles for a specific coset tower. +/// /// A coset tower is every repeated doubling of a `root_coset`. /// The largest CircleDomain that can be ffted using these twiddles is one with `root_coset` as /// its `half_coset`. diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index 628679d7c..7cd76a6cc 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -24,7 +24,7 @@ pub struct PeekTakeWhile<'a, I: Iterator, P: FnMut(&I::Item) -> bool> { iter: &'a mut Peekable, predicate: P, } -impl<'a, I: Iterator, P: FnMut(&I::Item) -> bool> Iterator for PeekTakeWhile<'a, I, P> { +impl bool> Iterator for PeekTakeWhile<'_, I, P> { type Item = I::Item; fn next(&mut self) -> Option { diff --git a/crates/prover/src/core/vcs/blake2_merkle.rs b/crates/prover/src/core/vcs/blake2_merkle.rs index 293ed4ab3..f5a1f5762 100644 --- a/crates/prover/src/core/vcs/blake2_merkle.rs +++ b/crates/prover/src/core/vcs/blake2_merkle.rs @@ -20,7 +20,7 @@ impl MerkleHasher for Blake2sMerkleHasher { if let Some((left, right)) = children_hashes { state = compress( state, - unsafe { std::mem::transmute([left, right]) }, + unsafe { std::mem::transmute::<[Blake2sHash; 2], [u32; 16]>([left, right]) }, 0, 0, 0, @@ -33,9 +33,16 @@ impl MerkleHasher for Blake2sMerkleHasher { .copied() .chain(std::iter::repeat(BaseField::zero()).take(rem)); for chunk in padded_values.array_chunks::<16>() { - state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); + state = compress( + state, + unsafe { std::mem::transmute::<[BaseField; 16], [u32; 16]>(chunk) }, + 0, + 0, + 0, + 0, + ); } - state.map(|x| x.to_le_bytes()).flatten().into() + state.map(|x| x.to_le_bytes()).as_flattened().into() } } diff --git a/crates/prover/src/core/vcs/blake2s_ref.rs b/crates/prover/src/core/vcs/blake2s_ref.rs index ab32ea6d9..a47daa07c 100644 --- a/crates/prover/src/core/vcs/blake2s_ref.rs +++ b/crates/prover/src/core/vcs/blake2s_ref.rs @@ -1,4 +1,5 @@ //! An AVX512 implementation of the BLAKE2s compression function. +//! //! Based on . pub const IV: [u32; 8] = [ @@ -30,22 +31,22 @@ fn xor(a: u32, b: u32) -> u32 { #[inline(always)] fn rot16(x: u32) -> u32 { - (x >> 16) | (x << (32 - 16)) + x.rotate_right(16) } #[inline(always)] fn rot12(x: u32) -> u32 { - (x >> 12) | (x << (32 - 12)) + x.rotate_right(12) } #[inline(always)] fn rot8(x: u32) -> u32 { - (x >> 8) | (x << (32 - 8)) + x.rotate_right(8) } #[inline(always)] fn rot7(x: u32) -> u32 { - (x >> 7) | (x << (32 - 7)) + x.rotate_right(7) } #[inline(always)] diff --git a/crates/prover/src/core/vcs/ops.rs b/crates/prover/src/core/vcs/ops.rs index 14093e536..b40a91bef 100644 --- a/crates/prover/src/core/vcs/ops.rs +++ b/crates/prover/src/core/vcs/ops.rs @@ -6,13 +6,12 @@ use crate::core::backend::{Col, ColumnOps}; use crate::core::fields::m31::BaseField; use crate::core::vcs::hash::Hash; -/// A Merkle node hash is a hash of: -/// [left_child_hash, right_child_hash], column0_value, column1_value, ... -/// "[]" denotes optional values. +/// A Merkle node hash is a hash of: `[left_child_hash, right_child_hash], column0_value, +/// column1_value, ...` where `[]` denotes optional values. +/// /// The largest Merkle layer has no left and right child hashes. The rest of the layers have -/// children hashes. -/// At each layer, the tree may have multiple columns of the same length as the layer. -/// Each node in that layer contains one value from each column. +/// children hashes. At each layer, the tree may have multiple columns of the same length as the +/// layer. Each node in that layer contains one value from each column. pub trait MerkleHasher: Debug + Default + Clone { type Hash: Hash; /// Hashes a single Merkle node. See [MerkleHasher] for more details. diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index 6312de114..77c8bf2da 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -64,8 +64,7 @@ impl, H: MerkleHasher> MerkleProver { /// /// # Arguments /// - /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. + /// * `queries_per_log_size` - Maps a log_size to a vector of queries for columns of that size. /// * `columns` - A vector of references to columns. /// /// # Returns diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 0a907f58e..7e695efcd 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -27,9 +27,9 @@ impl MerkleVerifier { /// # Arguments /// /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. + /// log_size. /// * `queried_values` - A vector of vectors of queried values. For each column, there is a - /// vector of queried values to that column. + /// vector of queried values to that column. /// * `decommitment` - The decommitment object containing the witness and column values. /// /// # Errors diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index b5e415d14..6c76a99bb 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -17,7 +17,7 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> { pub round_lookup_elements: &'a RoundElements, pub logup: LogupAtRow, } -impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { +impl BlakeRoundEval<'_, E> { pub fn eval(mut self) -> E { let mut v: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); let input_v = v.clone(); diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index ceaf6a97a..d83252ee6 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -68,7 +68,7 @@ struct TraceGeneratorRow<'a> { vec_row: usize, xor_lookups_index: usize, } -impl<'a> TraceGeneratorRow<'a> { +impl TraceGeneratorRow<'_> { fn append_felt(&mut self, val: u32x16) { self.gen.trace[self.col_index].data[self.vec_row] = unsafe { PackedBaseField::from_simd_unchecked(val) }; diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index e91cb3661..8a413d4cf 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -13,8 +13,8 @@ pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BIT pub lookup_elements: &'a XorElements, pub logup: LogupAtRow, } -impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> - XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS> +impl + XorTableEval<'_, E, ELEM_BITS, EXPAND_BITS> { pub fn eval(mut self) -> E { // al, bl are the constant columns for the inputs: All pairs of elements in [0, diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 232ea9a4e..583a499fe 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -461,7 +461,7 @@ mod tests { for i in 0..16 { internal_matrix[i][i] += BaseField::from_u32_unchecked(1 << (i + 1)); } - let matrix = RowMajorMatrix::::new(internal_matrix.flatten().to_vec()); + let matrix = RowMajorMatrix::::new(internal_matrix.as_flattened().to_vec()); let expected_state = matrix.mul(state); apply_internal_round_matrix(&mut state); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 985bfaf51..c54b5765e 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -114,9 +114,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> MleEvalProverComponent<'twiddl } } -impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component - for MleEvalProverComponent<'twiddles, 'oracle, O> -{ +impl Component for MleEvalProverComponent<'_, '_, O> { fn n_constraints(&self) -> usize { self.eval_info().n_constraints } @@ -180,9 +178,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component } } -impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver - for MleEvalProverComponent<'twiddles, 'oracle, O> -{ +impl ComponentProver for MleEvalProverComponent<'_, '_, O> { fn evaluate_constraint_quotients_on_domain( &self, trace: &Trace<'_, SimdBackend>, @@ -317,7 +313,7 @@ impl<'oracle, O: MleCoeffColumnOracle> MleEvalVerifierComponent<'oracle, O> { } } -impl<'oracle, O: MleCoeffColumnOracle> Component for MleEvalVerifierComponent<'oracle, O> { +impl Component for MleEvalVerifierComponent<'_, O> { fn n_constraints(&self) -> usize { self.eval_info().n_constraints } diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 49adff0f3..47bc5fb50 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,22 +1,20 @@ #![allow(incomplete_features)] +#![cfg_attr( + all(target_arch = "x86_64", target_feature = "avx512f"), + feature(stdarch_x86_avx512) +)] #![feature( array_chunks, - array_methods, array_try_from_fn, assert_matches, exact_size_is_empty, generic_const_exprs, get_many_mut, int_roundings, - is_sorted, iter_array_chunks, - new_uninit, portable_simd, - slice_first_last_chunk, - slice_flatten, - slice_group_by, slice_ptr_get, - stdsimd + trait_upcasting )] pub mod constraint_framework; pub mod core; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a0f1a930e..85b121b33 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-01-04" +channel = "nightly-2024-11-06" diff --git a/scripts/bench.sh b/scripts/bench.sh index 416cfb36d..669564b42 100755 --- a/scripts/bench.sh +++ b/scripts/bench.sh @@ -2,4 +2,4 @@ # Can be used as a drop in replacement for `cargo bench`. # For example, `./scripts/bench.sh` will run all benchmarks. # or `./scripts/bench.sh M31` will run only the M31 benchmarks. -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=3" cargo bench $@ +RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo bench $@ diff --git a/scripts/clippy.sh b/scripts/clippy.sh index 8361cd25d..43f11957e 100755 --- a/scripts/clippy.sh +++ b/scripts/clippy.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ +cargo +nightly-2024-11-06 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ -D nonstandard-style -D rust-2018-idioms -D unused diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index e4223f999..9f95485b0 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 fmt --all -- "$@" +cargo +nightly-2024-11-06 fmt --all -- "$@" diff --git a/scripts/test_avx.sh b/scripts/test_avx.sh index d911a2479..f7755f61d 100755 --- a/scripts/test_avx.sh +++ b/scripts/test_avx.sh @@ -1,4 +1,4 @@ #!/bin/bash # Can be used as a drop in replacement for `cargo test` with avx512f flag on. # For example, `./scripts/test_avx.sh` will run all tests(not only avx). -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-01-04 test "$@" +RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-11-06 test "$@"