Skip to content

Commit

Permalink
fix bug in preprocessing gpu g1s
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-a16z committed Nov 19, 2024
1 parent 705070a commit 120f46b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
29 changes: 25 additions & 4 deletions jolt-core/src/msm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ impl MsmType {
MsmType::Large => 256,
}
}

fn use_icicle(&self) -> bool {
match self {
MsmType::Zero | MsmType::One | MsmType::Small => false,
#[cfg(feature = "icicle")]
MsmType::Medium | MsmType::Large => true,
#[cfg(not(feature = "icicle"))]
_ => false,
}
}
}

type TrackedScalar<'a, P: Pairing> = (usize, &'a [P::ScalarField]);
Expand All @@ -62,6 +72,7 @@ pub trait VariableBaseMSM: ScalarMul + Icicle {
) -> Result<Self, usize> {
#[cfg(not(feature = "icicle"))]
assert!(gpu_bases.is_none());
assert_eq!(bases.len(), gpu_bases.map_or(bases.len(), |b| b.len()));

(bases.len() == scalars.len())
.then(|| {
Expand All @@ -84,7 +95,7 @@ pub trait VariableBaseMSM: ScalarMul + Icicle {
}
MsmType::Medium => {
// TODO(sagar) caching this as "use_icicle = use_icicle" seems to cause a massive slowdown
if use_icicle() {
if use_icicle() && msm_type.use_icicle() {
#[cfg(feature = "icicle")]
{
let mut backup = vec![];
Expand Down Expand Up @@ -114,7 +125,7 @@ pub trait VariableBaseMSM: ScalarMul + Icicle {
}
}
MsmType::Large => {
if use_icicle() {
if use_icicle() && msm_type.use_icicle() {
#[cfg(feature = "icicle")]
{
let mut backup = vec![];
Expand Down Expand Up @@ -160,6 +171,16 @@ pub trait VariableBaseMSM: ScalarMul + Icicle {
assert!(scalar_batches.iter().all(|s| s.len() <= bases.len()));
#[cfg(not(feature = "icicle"))]
assert!(gpu_bases.is_none());
assert_eq!(bases.len(), gpu_bases.map_or(bases.len(), |b| b.len()));

if !use_icicle() {
let span = tracing::span!(tracing::Level::INFO, "batch_msm_cpu_only");
let _guard = span.enter();
return scalar_batches
.into_par_iter()
.map(|scalars| Self::msm(&bases[..scalars.len()], None, scalars).unwrap())
.collect();
}

let slice_bit_size = 256 * scalar_batches[0].len() * 3;
let slices_at_a_time = total_memory_bits() / slice_bit_size;
Expand Down Expand Up @@ -190,7 +211,7 @@ pub trait VariableBaseMSM: ScalarMul + Icicle {
let scalars = scalar_batches[*i];

Check warning on line 211 in jolt-core/src/msm/mod.rs

View workflow job for this annotation

GitHub Actions / fmt

Diff in /home/runner/work/jolt/jolt/jolt-core/src/msm/mod.rs
(
*i,
Self::msm(&bases[..scalars.len()], gpu_bases, scalar_batches[*i])
Self::msm(&bases[..scalars.len()], None, scalar_batches[*i])
.unwrap(),
)
})
Expand Down Expand Up @@ -233,7 +254,7 @@ pub trait VariableBaseMSM: ScalarMul + Icicle {
*i,
Self::msm(
&bases[..scalars.len()],
gpu_bases,
None,
scalar_batches[*i],
)
.unwrap(),
Expand Down
5 changes: 4 additions & 1 deletion jolt-core/src/poly/commitment/kzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ where
}

pub fn gpu_g1(&self) -> Option<&[GpuBaseType<P::G1>]> {
self.srs.gpu_g1.as_deref()
self.srs
.gpu_g1
.as_ref()
.map(|gpu_g1| &gpu_g1[self.offset..self.offset + self.supported_size])
}
}

Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/poly/commitment/zeromorph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ where
let quotient_max_len = quotient_slices.iter().map(|s| s.len()).max().unwrap();
let q_comms: Vec<P::G1> = <P::G1 as VariableBaseMSM>::batch_msm(
&pp.commit_pp.g1_powers()[..quotient_max_len],
pp.commit_pp.gpu_g1(),
pp.commit_pp.gpu_g1().map(|g| &g[..quotient_max_len]),
&quotient_slices,
);
let q_k_com: Vec<P::G1Affine> = q_comms.iter().map(|q| q.into_affine()).collect();
Expand Down

0 comments on commit 120f46b

Please sign in to comment.