diff --git a/Cargo.toml b/Cargo.toml index 70cb0e906..b5cfc34ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ ark-std = { version = "0.4.0", default-features = false } ark-serialize = { version = "0.4.2", default-features = false, features = [ "derive", ] } +parking_lot = { version = "0.12.1", optional = true } # ark-bls12-381 = { version = "^0.4.0", default-features = false, features = [ "curve" ] } criterion = { version = "0.3.1", features = ["html_reports"] } @@ -66,7 +67,7 @@ default = [ "ark-ff/asm", "multicore", ] -multicore = ["rayon"] +multicore = ["rayon", "parking_lot"] ark-msm = [] # run with arkworks MSM without small field element optimization [profile.release] diff --git a/src/subprotocols/sumcheck.rs b/src/subprotocols/sumcheck.rs index 15b5a114b..1fbc908b7 100644 --- a/src/subprotocols/sumcheck.rs +++ b/src/subprotocols/sumcheck.rs @@ -22,6 +22,9 @@ use crate::msm::VariableBaseMSM; #[cfg(feature = "multicore")] use rayon::prelude::*; +#[cfg(feature = "multicore")] +use parking_lot::Mutex; + impl SumcheckInstanceProof { #[tracing::instrument(skip_all, name = "Sumcheck.prove_batched")] pub fn prove_cubic_batched( @@ -165,16 +168,22 @@ impl SumcheckInstanceProof { for _round in 0..num_rounds { // Vector storing evaluations of combined polynomials g(x) = P_0(x) * ... P_{num_polys} (x) // for points {0, ..., |g(x)|} - let mut eval_points = vec![F::zero(); combined_degree + 1]; let mle_half = polys[0].len() / 2; // let mut accum = vec![vec![F::zero(); combined_degree + 1]; mle_half]; #[cfg(feature = "multicore")] - let iterator = (0..mle_half).into_par_iter(); + let (iterator, eval_points) = { + let iterator = (0..mle_half).into_par_iter(); + let eval_points = Mutex::new(vec![F::zero(); combined_degree + 1]); + (iterator, eval_points) + }; #[cfg(not(feature = "multicore"))] - let iterator = 0..mle_half; + let (iterator, mut eval_points) = { + let iterator = (0..mle_half).iter(); + let mut eval_points = vec![F::zero(); combined_degree + 1]; + }; let accum: Vec> = iterator .map(|poly_term_i| { @@ -217,12 +226,16 @@ impl SumcheckInstanceProof { }) .collect(); - // TODO(#31): Parallelize - for (poly_i, eval_point) in eval_points.iter_mut().enumerate() { - for mle in accum.iter().take(mle_half) { - *eval_point += mle[poly_i]; - } - } + (0..(combined_degree + 1)).into_par_iter().for_each(|poly_i| { + (0..mle_half).into_par_iter().for_each(|mle_i| { + #[cfg(feature = "multicore")] + let mut eval_points = eval_points.lock(); + eval_points[poly_i] += accum[mle_i][poly_i]; + }) + }); + + #[cfg(feature = "multicore")] + let eval_points = eval_points.into_inner(); let round_uni_poly = UniPoly::from_evals(&eval_points);