From a9659df67af16974b9a8c04133527fc98cb043d4 Mon Sep 17 00:00:00 2001 From: Sagar Dhawan Date: Mon, 18 Nov 2024 15:58:11 -0800 Subject: [PATCH] rewrite hyperkzg batch_commit and cleanup redundant batch msm implementations --- jolt-core/src/jolt/vm/mod.rs | 5 +- jolt-core/src/msm/icicle/adapter.rs | 112 +----- jolt-core/src/msm/mod.rs | 423 ++++++--------------- jolt-core/src/poly/commitment/hyperkzg.rs | 71 +--- jolt-core/src/poly/commitment/kzg.rs | 97 +++++ jolt-core/src/poly/commitment/zeromorph.rs | 2 +- 6 files changed, 235 insertions(+), 475 deletions(-) diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 377ea65cd..1fbe412a0 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -260,10 +260,7 @@ impl JoltPolynomials { commitments.init_final_values().len(), ); - let span = tracing::span!( - tracing::Level::INFO, - "commit::commit_instructions_final_cts" - ); + let span = tracing::span!(tracing::Level::INFO, "commit::commit_bytecode.t_final"); let _guard = span.enter(); commitments.bytecode.t_final = PCS::commit(&self.bytecode.t_final, &preprocessing.generators); diff --git a/jolt-core/src/msm/icicle/adapter.rs b/jolt-core/src/msm/icicle/adapter.rs index bf736ecdd..dd27f7f04 100644 --- a/jolt-core/src/msm/icicle/adapter.rs +++ b/jolt-core/src/msm/icicle/adapter.rs @@ -60,8 +60,10 @@ pub trait Icicle: ScalarMul { pub fn icicle_msm( bases: &[GpuBaseType], scalars: &[V::ScalarField], - bit_size: i32, + bit_size: usize, ) -> V { + assert!(scalars.len() <= bases.len()); + let mut bases_slice = DeviceVec::>::device_malloc(bases.len()).unwrap(); let span = tracing::span!(tracing::Level::INFO, "convert_scalars"); @@ -94,14 +96,14 @@ pub fn icicle_msm( cfg.stream_handle = IcicleStreamHandle::from(&stream); cfg.is_async = false; cfg.are_scalars_montgomery_form = true; - cfg.bitsize = bit_size; + cfg.bitsize = bit_size as i32; let span = tracing::span!(tracing::Level::INFO, "gpu_msm"); let _guard = span.enter(); msm( &scalars_slice[..], - &bases_slice[..], + &bases_slice[..scalars.len()], &cfg, &mut msm_result[..], ) @@ -129,14 +131,14 @@ pub fn icicle_msm( pub fn icicle_batch_msm( bases: &[GpuBaseType], scalar_batches: &[&[V::ScalarField]], - bit_size: i32, + bit_size: usize, ) -> Vec { - let len = bases.len(); + let bases_len = bases.len(); let batch_size = scalar_batches.len(); - assert!(scalar_batches.iter().all(|s| s.len() == len)); + assert!(scalar_batches.iter().all(|s| s.len() <= bases_len)); let mut stream = IcicleStream::create().unwrap(); - let mut bases_slice = DeviceVec::>::device_malloc(len).unwrap(); + let mut bases_slice = DeviceVec::>::device_malloc(bases_len).unwrap(); let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu"); let _guard = span.enter(); bases_slice @@ -149,101 +151,15 @@ pub fn icicle_batch_msm( let mut msm_host_results = vec![Projective::::zero(); batch_size]; for (batch_i, scalars) in scalar_batches.iter().enumerate() { + let scalars_len = scalars.len(); let span = tracing::span!(tracing::Level::INFO, "convert_scalars"); let _guard = span.enter(); let mut scalars_slice = DeviceVec::<<::C as Curve>::ScalarField>::device_malloc_async( - len, &stream, - ) - .unwrap(); - let scalars_mont = unsafe { - &*(&scalars[..] 
as *const _ as *const [<::C as Curve>::ScalarField]) - }; - drop(_guard); - drop(span); - - let span = tracing::span!(tracing::Level::INFO, "copy_scalars_to_gpu"); - let _guard = span.enter(); - scalars_slice - .copy_from_host_async(HostSlice::from_slice(scalars_mont), &stream) - .unwrap(); - drop(_guard); - drop(span); - - let mut cfg = MSMConfig::default(); - cfg.stream_handle = IcicleStreamHandle::from(&stream); - cfg.is_async = true; - cfg.are_scalars_montgomery_form = true; - cfg.are_bases_montgomery_form = false; - cfg.bitsize = bit_size; - - let span = tracing::span!(tracing::Level::INFO, "msm_gpu"); - let _guard = span.enter(); - msm( - &scalars_slice[..], - &bases_slice[..], - &cfg, - &mut msm_result[..], - ) - .unwrap(); - drop(_guard); - drop(span); - - let span = tracing::span!(tracing::Level::INFO, "copy_msm_result"); - let _guard = span.enter(); - msm_result - .copy_to_host_async( - HostSlice::from_mut_slice(&mut msm_host_results[batch_i..(batch_i + 1)]), - &stream, - ) - .unwrap(); - drop(_guard); - drop(span); - } - - stream.synchronize().unwrap(); - - stream.destroy().unwrap(); - msm_host_results - .into_iter() - .map(|res| V::to_ark_projective(&res)) - .collect() -} - -#[tracing::instrument(skip_all, name = "icicle_batch_msm")] -/// MSM which allows scalar_batches of non-uniform size -pub fn icicle_variable_batch_msm( - bases: &[GpuBaseType], - scalar_batches: &[&[V::ScalarField]], - bit_size: i32, -) -> Vec { - let base_len = bases.len(); - let batch_size = scalar_batches.len(); - assert!(scalar_batches.iter().all(|s| s.len() <= base_len)); - - let mut stream = IcicleStream::create().unwrap(); - let mut bases_slice = DeviceVec::>::device_malloc(base_len).unwrap(); - let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu"); - let _guard = span.enter(); - bases_slice - .copy_from_host_async(HostSlice::from_slice(bases), &stream) - .unwrap(); - drop(_guard); - drop(span); - - let mut msm_result = DeviceVec::>::device_malloc(1).unwrap(); - let mut msm_host_results = vec![Projective::::zero(); batch_size]; - - for (batch_i, scalars) in scalar_batches.iter().enumerate() { - let mut scalars_slice = - DeviceVec::<<::C as Curve>::ScalarField>::device_malloc_async( - scalars.len(), + scalars_len, &stream, ) .unwrap(); - - let span = tracing::span!(tracing::Level::INFO, "convert_scalars"); - let _guard = span.enter(); let scalars_mont = unsafe { &*(&scalars[..] 
as *const _ as *const [<::C as Curve>::ScalarField]) }; @@ -263,13 +179,13 @@ pub fn icicle_variable_batch_msm( cfg.is_async = true; cfg.are_scalars_montgomery_form = true; cfg.are_bases_montgomery_form = false; - cfg.bitsize = bit_size; + cfg.bitsize = bit_size as i32; let span = tracing::span!(tracing::Level::INFO, "msm_gpu"); let _guard = span.enter(); msm( &scalars_slice[..], - &bases_slice[..scalars.len()], + &bases_slice[..scalars_len], &cfg, &mut msm_result[..], ) @@ -277,7 +193,7 @@ pub fn icicle_variable_batch_msm( drop(_guard); drop(span); - let span = tracing::span!(tracing::Level::INFO, "copy_result"); + let span = tracing::span!(tracing::Level::INFO, "copy_msm_result"); let _guard = span.enter(); msm_result .copy_to_host_async( diff --git a/jolt-core/src/msm/mod.rs b/jolt-core/src/msm/mod.rs index f1ce7a1b6..0b182724c 100644 --- a/jolt-core/src/msm/mod.rs +++ b/jolt-core/src/msm/mod.rs @@ -36,47 +36,21 @@ impl MsmType { _ => MsmType::Large, } } + + fn max_num_bits(&self) -> usize { + match self { + MsmType::Zero => 0, + MsmType::One => 1, + MsmType::Small => 10, + MsmType::Medium => 64, + MsmType::Large => 256, + } + } } -type TrackedScalar<'a, P: Pairing> = (usize, &'a [P::ScalarField]); +type TrackedScalar<'a, P: Pairing> = (usize, &'a [P::ScalarField]); pub type ScalarGroups<'a, P: Pairing> = (MsmType, Vec>); -#[tracing::instrument(skip_all)] -pub fn group_scalars_by_msm_type<'a, P: Pairing>( - scalars: &[&'a [P::ScalarField]], -) -> Vec> { - // Group scalars by their max number of bits and keep track of their original indices - let mut grouped: Vec>> = vec![Vec::new(); 5]; - - // Process slices in parallel - scalars - .par_iter() - .enumerate() - .map(|(idx, scalar_slice)| { - // TODO(sagar) should par iter this? - let max_num_bits = scalar_slice - .iter() - .map(|s| s.into_bigint().num_bits()) - .max() - .unwrap(); - let msm_type = MsmType::from_u32(max_num_bits); - (msm_type, idx, scalar_slice) - }) - .collect::>() - .into_iter() - .for_each(|(msm_type_idx, original_idx, scalar_slice)| { - grouped[msm_type_idx as usize].push((original_idx, scalar_slice)); - }); - - // Convert grouped Vec into the desired output format - grouped - .into_iter() - .enumerate() - .filter(|(_, group)| !group.is_empty()) // Ignore empty groups - .map(|(idx, group)| (MsmType::from_u32(idx as u32), group)) - .collect() -} - /// Copy of ark_ec::VariableBaseMSM with minor modifications to speed up /// known small element sized MSMs. 
pub trait VariableBaseMSM: ScalarMul + Icicle { @@ -96,55 +70,20 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { .map(|s| s.into_bigint().num_bits()) .max() .unwrap(); + let msm_type = MsmType::from_u32(max_num_bits); - match max_num_bits { - 0 => Self::zero(), - 1 => { - if use_icicle() { - #[cfg(feature = "icicle")] - { - let mut backup = vec![]; - let gpu_bases = gpu_bases.unwrap_or_else(|| { - backup = Self::get_gpu_bases(bases); - &backup - }); - return icicle_msm::(gpu_bases, scalars, 1); - } - #[cfg(not(feature = "icicle"))] - { - unreachable!( - "icicle_init must not return true without the icicle feature" - ); - } - } - + match msm_type { + MsmType::Zero => Self::zero(), + MsmType::One => { let scalars_u64 = &map_field_elements_to_u64::(scalars); msm_binary(bases, scalars_u64) } - 2..=10 => { - // TODO(sagar) caching this as "use_icicle = use_icicle" seems to cause a massive slowdown - if use_icicle() { - #[cfg(feature = "icicle")] - { - let mut backup = vec![]; - let gpu_bases = gpu_bases.unwrap_or_else(|| { - backup = Self::get_gpu_bases(bases); - &backup - }); - return icicle_msm::(gpu_bases, scalars, 10); - } - #[cfg(not(feature = "icicle"))] - { - unreachable!( - "icicle_init must not return true without the icicle feature" - ); - } - } - + MsmType::Small => { let scalars_u64 = &map_field_elements_to_u64::(scalars); - msm_small(bases, scalars_u64, max_num_bits as usize) + msm_small(bases, scalars_u64, msm_type.max_num_bits()) } - 11..=64 => { + MsmType::Medium => { + // TODO(sagar) caching this as "use_icicle = use_icicle" seems to cause a massive slowdown if use_icicle() { #[cfg(feature = "icicle")] { @@ -153,7 +92,11 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { backup = Self::get_gpu_bases(bases); &backup }); - return icicle_msm::(gpu_bases, scalars, 64); + return icicle_msm::( + gpu_bases, + scalars, + msm_type.max_num_bits(), + ); } #[cfg(not(feature = "icicle"))] { @@ -165,12 +108,12 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { let scalars_u64 = &map_field_elements_to_u64::(scalars); if Self::NEGATION_IS_CHEAP { - msm_u64_wnaf(bases, scalars_u64, max_num_bits as usize) + msm_u64_wnaf(bases, scalars_u64, msm_type.max_num_bits()) } else { - msm_u64(bases, scalars_u64, max_num_bits as usize) + msm_u64(bases, scalars_u64, msm_type.max_num_bits()) } } - _ => { + MsmType::Large => { if use_icicle() { #[cfg(feature = "icicle")] { @@ -179,7 +122,11 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { backup = Self::get_gpu_bases(bases); &backup }); - return icicle_msm::(gpu_bases, scalars, 256); + return icicle_msm::( + gpu_bases, + scalars, + msm_type.max_num_bits(), + ); } #[cfg(not(feature = "icicle"))] { @@ -194,9 +141,9 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { .map(|s| s.into_bigint()) .collect::>(); if Self::NEGATION_IS_CHEAP { - msm_bigint_wnaf(bases, &scalars, max_num_bits as usize) + msm_bigint_wnaf(bases, &scalars, msm_type.max_num_bits()) } else { - msm_bigint(bases, &scalars, max_num_bits as usize) + msm_bigint(bases, &scalars, msm_type.max_num_bits()) } } } @@ -204,108 +151,66 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { .ok_or_else(|| bases.len().min(scalars.len())) } - #[cfg(feature = "icicle")] - #[tracing::instrument(skip_all)] - fn get_gpu_bases(bases: &[Self::MulBase]) -> Vec> { - bases - .par_iter() - .map(|base| ::from_ark_affine(base)) - .collect() - } - #[tracing::instrument(skip_all)] fn batch_msm( bases: &[Self::MulBase], gpu_bases: Option<&[GpuBaseType]>, - scalars: &[&[Self::ScalarField]], + scalar_batches: 
&[&[Self::ScalarField]], ) -> Vec { - assert!(scalars.iter().all(|s| s.len() == scalars[0].len())); - assert_eq!(bases.len(), scalars[0].len()); + assert!(scalar_batches.iter().all(|s| s.len() <= bases.len())); #[cfg(not(feature = "icicle"))] assert!(gpu_bases.is_none()); - #[cfg(feature = "icicle")] - let mut backup: Vec> = vec![]; - #[cfg(feature = "icicle")] - let gpu_bases = if use_icicle() { - gpu_bases.unwrap_or_else(|| { - backup = Self::get_gpu_bases(bases); - &backup - }) - } else { - &backup - }; - - let slice_bit_size = 256 * scalars[0].len() * 3; + let slice_bit_size = 256 * scalar_batches[0].len() * 3; let slices_at_a_time = total_memory_bits() / slice_bit_size; - let mut telemetry = Vec::new(); - - for (i, scalar_slice) in scalars.iter().enumerate() { + let mut grouped_scalar_indices = vec![Vec::new(); 5]; + for (i, scalar_slice) in scalar_batches.iter().enumerate() { let max_num_bits = scalar_slice .par_iter() .map(|s| s.into_bigint().num_bits()) .max() .unwrap(); - let msm_type = match max_num_bits { - 0 => MsmType::Zero, - 1 => MsmType::One, - 2..=10 => MsmType::Small, - 11..=64 => MsmType::Medium, - _ => MsmType::Large, - }; - - telemetry.push((i, msm_type)); + let msm_type = MsmType::from_u32(max_num_bits); + grouped_scalar_indices[msm_type as usize].push(i); } - let mut results = vec![Self::zero(); scalars.len()]; + let mut results = vec![Self::zero(); scalar_batches.len()]; - let run_msm = |indices: Vec, msm_type: MsmType, results: &mut Vec| { + let run_msm = |indices: &[usize], msm_type: MsmType, results: &mut Vec| { let partial_results: Vec<(usize, Self)> = match msm_type { - MsmType::Zero => indices.into_par_iter().map(|i| (i, Self::zero())).collect(), - - MsmType::One => { + MsmType::Zero => indices + .into_par_iter() + .map(|i| (*i, Self::zero())) + .collect(), + MsmType::One | MsmType::Small => indices + .into_par_iter() + .map(|i| { + let scalars = scalar_batches[*i]; + ( + *i, + Self::msm(&bases[..scalars.len()], gpu_bases, scalar_batches[*i]) + .unwrap(), + ) + }) + .collect(), + MsmType::Medium | MsmType::Large => { if use_icicle() { #[cfg(feature = "icicle")] { + let mut backup = vec![]; + let gpu_bases = gpu_bases.unwrap_or_else(|| { + backup = Self::get_gpu_bases(bases); + &backup + }); let scalar_batches: Vec<&[Self::ScalarField]> = - indices.iter().map(|i| scalars[*i]).collect(); - let batch_results = - icicle_batch_msm::(gpu_bases, &scalar_batches, 1); - assert_eq!(batch_results.len(), scalar_batches.len()); - batch_results - .into_iter() - .enumerate() - .map(|(batch_index, result)| (indices[batch_index], result)) - .collect() - } - #[cfg(not(feature = "icicle"))] - { - unreachable!( - "icicle_init must not return true without the icicle feature" + indices.iter().map(|i| scalar_batches[*i]).collect(); + let batch_results = icicle_batch_msm::( + gpu_bases, + &scalar_batches, + msm_type.max_num_bits(), ); - } - } else { - indices - .into_par_iter() - .map(|i| { - let scalars = scalars[i]; - let scalars_u64 = &map_field_elements_to_u64::(scalars); - (i, msm_binary(bases, scalars_u64)) - }) - .collect() - } - } - - MsmType::Small => { - if use_icicle() { - #[cfg(feature = "icicle")] - { - let scalar_batches: Vec<&[Self::ScalarField]> = - indices.iter().map(|i| scalars[*i]).collect(); - let batch_results = - icicle_batch_msm::(gpu_bases, &scalar_batches, 10); assert_eq!(batch_results.len(), scalar_batches.len()); batch_results .into_iter() @@ -323,88 +228,16 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { indices .into_par_iter() .map(|i| { - let scalars 
= scalars[i]; - let scalars_u64 = &map_field_elements_to_u64::(scalars); - (i, msm_small(bases, scalars_u64, 10)) - }) - .collect() - } - } - - MsmType::Medium => { - if use_icicle() { - #[cfg(feature = "icicle")] - { - let scalar_batches: Vec<&[Self::ScalarField]> = - indices.iter().map(|i| scalars[*i]).collect(); - let batch_results = - icicle_batch_msm::(gpu_bases, &scalar_batches, 64); - assert_eq!(batch_results.len(), scalar_batches.len()); - batch_results - .into_iter() - .enumerate() - .map(|(batch_index, result)| (indices[batch_index], result)) - .collect() - } - #[cfg(not(feature = "icicle"))] - { - unreachable!( - "icicle_init must not return true without the icicle feature" - ); - } - } else { - indices - .into_par_iter() - .map(|i| { - let scalars = scalars[i]; - let scalars_u64 = &map_field_elements_to_u64::(scalars); - let result = if Self::NEGATION_IS_CHEAP { - msm_u64_wnaf(bases, scalars_u64, 64) - } else { - msm_u64(bases, scalars_u64, 64) - }; - (i, result) - }) - .collect() - } - } - - MsmType::Large => { - if use_icicle() { - #[cfg(feature = "icicle")] - { - let scalar_batches: Vec<&[Self::ScalarField]> = - indices.iter().map(|i| scalars[*i]).collect(); - let batch_results = - icicle_batch_msm::(gpu_bases, &scalar_batches, 256); - assert_eq!(batch_results.len(), scalar_batches.len()); - batch_results - .into_iter() - .enumerate() - .map(|(batch_index, result)| (indices[batch_index], result)) - .collect() - } - #[cfg(not(feature = "icicle"))] - { - unreachable!( - "icicle_init must not return true without the icicle feature" - ); - } - } else { - indices - .into_par_iter() - .map(|i| { - let scalars = scalars[i]; - let scalars_bigint = scalars - .par_iter() - .map(|s| s.into_bigint()) - .collect::>(); - let result: Self = if Self::NEGATION_IS_CHEAP { - msm_bigint_wnaf(bases, &scalars_bigint, 256) - } else { - msm_bigint(bases, &scalars_bigint, 256) - }; - (i, result) + let scalars = scalar_batches[*i]; + ( + *i, + Self::msm( + &bases[..scalars.len()], + gpu_bases, + scalar_batches[*i], + ) + .unwrap(), + ) }) .collect() } @@ -416,78 +249,44 @@ pub trait VariableBaseMSM: ScalarMul + Icicle { } }; - let mut zero_indices = Vec::new(); - let mut one_indices = Vec::new(); - let mut small_indices = Vec::new(); - let mut medium_indices = Vec::new(); - let mut large_indices = Vec::new(); - - for (i, msm_type) in telemetry { - match msm_type { - MsmType::Zero => zero_indices.push(i), - MsmType::One => one_indices.push(i), - MsmType::Small => small_indices.push(i), - MsmType::Medium => medium_indices.push(i), - MsmType::Large => large_indices.push(i), - } - } - - run_msm(zero_indices, MsmType::Zero, &mut results); - run_msm(one_indices, MsmType::One, &mut results); - run_msm(small_indices, MsmType::Small, &mut results); - - { - let span = tracing::span!(tracing::Level::INFO, "medium_indices"); - let _guard = span.enter(); - medium_indices.chunks(slices_at_a_time).for_each(|chunk| { - run_msm(chunk.to_vec(), MsmType::Medium, &mut results); + let span = tracing::span!(tracing::Level::INFO, "smaller_indices"); + let _guard = span.enter(); + run_msm(&grouped_scalar_indices[0], MsmType::Zero, &mut results); + run_msm(&grouped_scalar_indices[1], MsmType::One, &mut results); + run_msm(&grouped_scalar_indices[2], MsmType::Small, &mut results); + drop(_guard); + drop(span); + + let span = tracing::span!(tracing::Level::INFO, "medium_indices"); + let _guard = span.enter(); + grouped_scalar_indices[3] + .chunks(slices_at_a_time) + .for_each(|chunk| { + run_msm(chunk, MsmType::Medium, &mut 
results); }); - drop(_guard); - } - - { - let span = tracing::span!(tracing::Level::INFO, "large_indices"); - let _guard = span.enter(); - large_indices.chunks(slices_at_a_time).for_each(|chunk| { - run_msm(chunk.to_vec(), MsmType::Large, &mut results); + drop(_guard); + drop(span); + + let span = tracing::span!(tracing::Level::INFO, "large_indices"); + let _guard = span.enter(); + grouped_scalar_indices[4] + .chunks(slices_at_a_time) + .for_each(|chunk| { + run_msm(chunk, MsmType::Large, &mut results); }); - drop(_guard); - } + drop(_guard); + drop(span); results } - #[tracing::instrument(skip_all, name = "variable_batch_msm")] - fn variable_batch_msm( - bases: &[Self::MulBase], - gpu_bases: Option<&[GpuBaseType]>, - scalar_batches: &[&[Self::ScalarField]], - ) -> Vec { - assert!(scalar_batches.iter().all(|s| s.len() <= bases.len())); - #[cfg(not(feature = "icicle"))] - assert!(gpu_bases.is_none()); - - if use_icicle() { - #[cfg(feature = "icicle")] - { - //TODO(sagar) this is faster without the variable batch msm - let mut backup = vec![]; - let gpu_bases = gpu_bases.unwrap_or_else(|| { - backup = Self::get_gpu_bases(bases); - &backup - }); - icicle_variable_batch_msm::(gpu_bases, scalar_batches, 256) - } - #[cfg(not(feature = "icicle"))] - { - unreachable!("icicle_init must not return true without the icicle feature"); - } - } else { - scalar_batches - .par_iter() - .map(|scalars| Self::msm(&bases[0..scalars.len()], None, scalars).unwrap()) - .collect() - } + #[cfg(feature = "icicle")] + #[tracing::instrument(skip_all)] + fn get_gpu_bases(bases: &[Self::MulBase]) -> Vec> { + bases + .par_iter() + .map(|base| ::from_ark_affine(base)) + .collect() } } diff --git a/jolt-core/src/poly/commitment/hyperkzg.rs b/jolt-core/src/poly/commitment/hyperkzg.rs index aa101399b..ca78a90d9 100644 --- a/jolt-core/src/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/poly/commitment/hyperkzg.rs @@ -9,16 +9,15 @@ //! and within the KZG commitment scheme implementation itself). use super::{ commitment_scheme::{BatchType, CommitmentScheme}, - kzg, kzg::{KZGProverKey, KZGVerifierKey, UnivariateKZG}, }; use crate::field::JoltField; -use crate::msm::MsmType; use crate::poly::commitment::commitment_scheme::CommitShape; +use crate::poly::commitment::kzg::CommitMode; use crate::utils::mul_0_1_optimized; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::utils::transcript::Transcript; -use crate::{field, into_optimal_iter, msm, optimal_iter}; +use crate::{field, into_optimal_iter}; use crate::{ msm::{Icicle, VariableBaseMSM}, poly::{commitment::kzg::SRS, dense_mlpoly::DensePolynomial, unipoly::UniPoly}, @@ -567,64 +566,16 @@ where gens: &Self::Setup, batch_type: BatchType, ) -> Vec { - // TODO(sagar): How to use unsafe_allocate_zero_vec here? - let mut commitments = vec![Self::Commitment::default(); evals.len()]; - // first group the evals by num_bits for msms - let res = msm::group_scalars_by_msm_type::
<P>
(evals) - .iter() - .map(|(msm_type, evals)| { - let process_evals = |original_idx: usize, evals: &&[Self::Field]| { - assert!( - gens.0.kzg_pk.g1_powers().len() >= evals.len(), - "COMMIT KEY LENGTH ERROR {}, {}", - gens.0.kzg_pk.g1_powers().len(), - evals.len() - ); - let commitment = match batch_type { - BatchType::GrandProduct => HyperKZGCommitment( - UnivariateKZG::commit_slice_with_mode( - &gens.0.kzg_pk, - evals, - kzg::CommitMode::GrandProduct, - ) - .unwrap(), - ), - _ => HyperKZGCommitment( - UnivariateKZG::commit_slice(&gens.0.kzg_pk, evals).unwrap(), - ), - }; - (original_idx, commitment) - }; - - // depending on the type of msm, we will use different batching strategies - match msm_type { - MsmType::Zero | MsmType::One => { - // always run these in parallel - they are the fastest on CPU alone - evals - .par_iter() - .map(|(original_idx, evals)| process_evals(*original_idx, evals)) - .collect::>() - } - MsmType::Small | MsmType::Medium | MsmType::Large => { - // automatically choose between parallel and sequential iteration for larger MSMS - // This is done so that icicle's GPU performance remains strong - optimal_iter!(evals) - .map(|(original_idx, evals)| process_evals(*original_idx, evals)) - .collect::>() - } - } - }) - .collect::>(); - - let span = trace_span!("flatten_results"); - let _guard = span.enter(); - res.into_iter().flatten().for_each(|(idx, commitment)| { - commitments[idx] = commitment; - }); - drop(_guard); - drop(span); + let mode = match batch_type { + BatchType::GrandProduct => CommitMode::GrandProduct, + _ => CommitMode::Default, + }; - commitments + UnivariateKZG::commit_batch_with_mode(&gens.0.kzg_pk, evals, mode) + .unwrap() + .into_par_iter() + .map(|c| HyperKZGCommitment(c)) + .collect() } fn commit_slice(evals: &[Self::Field], setup: &Self::Setup) -> Self::Commitment { diff --git a/jolt-core/src/poly/commitment/kzg.rs b/jolt-core/src/poly/commitment/kzg.rs index 27a9d8c4e..0a3dd46b2 100644 --- a/jolt-core/src/poly/commitment/kzg.rs +++ b/jolt-core/src/poly/commitment/kzg.rs @@ -1,5 +1,6 @@ use crate::field::JoltField; use crate::msm::{GpuBaseType, Icicle, VariableBaseMSM}; +use crate::optimal_iter; use crate::poly::unipoly::UniPoly; use crate::utils::errors::ProofVerifyError; use ark_ec::scalar_mul::fixed_base::FixedBase; @@ -189,6 +190,102 @@ where P::ScalarField: JoltField, P::G1: Icicle, { + #[tracing::instrument(skip_all, name = "KZG::commit_batch")] + pub fn commit_batch( + pk: &KZGProverKey
<P>
, + coeffs: &[&[P::ScalarField]], + ) -> Result, ProofVerifyError> { + Self::commit_batch_with_mode(pk, coeffs, CommitMode::Default) + } + + #[tracing::instrument(skip_all, name = "KZG::commit_batch_with_mode")] + pub fn commit_batch_with_mode( + pk: &KZGProverKey
<P>
, + batches: &[&[P::ScalarField]], + mode: CommitMode, + ) -> Result, ProofVerifyError> { + if let Some(invalid) = batches + .iter() + .find(|coeffs| coeffs.len() > pk.g1_powers().len()) + { + return Err(ProofVerifyError::KeyLengthError( + pk.g1_powers().len(), + invalid.len(), + )); + } + + let g1_powers = &pk.g1_powers(); + let gpu_g1 = pk.gpu_g1().map(|g| &g); + + match mode { + CommitMode::Default => { + let commitments = + ::batch_msm(g1_powers, gpu_g1, batches); + Ok(commitments.into_iter().map(|c| c.into_affine()).collect()) + } + CommitMode::GrandProduct => { + // Commit to the non-1 coefficients first then combine them with the G commitment (all-1s vector) in the SRS + let (non_one_coeffs, (non_one_bases, non_one_gpu_bases)): ( + Vec<_>, + (Vec<_>, Vec<_>), + ) = batches + .par_iter() + .map(|coeff| { + let (coeffs, (bases, gpu_bases)): (Vec<_>, (Vec<_>, Vec<_>)) = coeff + .par_iter() + .enumerate() + .filter_map(|(i, coeff)| { + if *coeff != P::ScalarField::one() { + let gpu_base = gpu_g1.map(|g| g[i]); + // Subtract 1 from the coeff because we already have a commitment to the all the 1s + Some((*coeff - P::ScalarField::one(), (g1_powers[i], gpu_base))) + } else { + None + } + }) + .unzip(); + let gpu_bases: Option> = gpu_bases.into_iter().collect(); + (coeffs, (bases, gpu_bases)) + }) + .unzip(); + + // Perform MSM for the non-1 coefficients + assert_eq!(non_one_bases.len(), non_one_coeffs.len()); + let commitments = optimal_iter!(non_one_coeffs) + .enumerate() + .map(|(i, coeffs)| { + let non_one_commitment = if !coeffs.is_empty() { + ::msm( + &non_one_bases[i], + non_one_gpu_bases[i].as_deref(), + coeffs, + ) + .unwrap() + } else { + P::G1::zero() + }; + + // find the right precomputed g_product to use + let num_powers = (coeffs.len() as f64).log2(); + assert_ne!( + num_powers.fract(), + 0.0, + "Invalid key length: {}", + coeffs.len() + ); + let num_powers = num_powers.floor() as usize; + + // Combine G * H: Multiply the precomputed G commitment with the non-1 commitment (H) + let final_commitment = pk.srs.g_products[num_powers] + non_one_commitment; + final_commitment.into_affine() + }) + .collect(); + + Ok(commitments) + } + } + } + #[tracing::instrument(skip_all, name = "KZG::commit_offset")] pub fn commit_offset( pk: &KZGProverKey
<P>
,
diff --git a/jolt-core/src/poly/commitment/zeromorph.rs b/jolt-core/src/poly/commitment/zeromorph.rs
index 64f3edd51..8c9b7e855 100644
--- a/jolt-core/src/poly/commitment/zeromorph.rs
+++ b/jolt-core/src/poly/commitment/zeromorph.rs
@@ -292,7 +292,7 @@ where
         let quotient_slices: Vec<&[P::ScalarField]> =
             quotients.iter().map(|q| q.coeffs.as_slice()).collect();
         let quotient_max_len = quotient_slices.iter().map(|s| s.len()).max().unwrap();
-        let q_comms: Vec<P::G1> = P::G1::variable_batch_msm(
+        let q_comms: Vec<P::G1> = <P::G1 as VariableBaseMSM>::batch_msm(
             &pp.commit_pp.g1_powers()[..quotient_max_len],
             pp.commit_pp.gpu_g1(),
             &quotient_slices,