Skip to content

Commit

Permalink
rewrite hyperkzg batch_commit and cleanup redundant batch msm implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-a16z committed Nov 18, 2024
1 parent db5a9d7 commit a9659df
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 475 deletions.
5 changes: 1 addition & 4 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,7 @@ impl<F: JoltField> JoltPolynomials<F> {
commitments.init_final_values().len(),
);

let span = tracing::span!(
tracing::Level::INFO,
"commit::commit_instructions_final_cts"
);
let span = tracing::span!(tracing::Level::INFO, "commit::commit_bytecode.t_final");
let _guard = span.enter();
commitments.bytecode.t_final =
PCS::commit(&self.bytecode.t_final, &preprocessing.generators);
Expand Down
112 changes: 14 additions & 98 deletions jolt-core/src/msm/icicle/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ pub trait Icicle: ScalarMul {
pub fn icicle_msm<V: VariableBaseMSM>(
bases: &[GpuBaseType<V>],
scalars: &[V::ScalarField],
bit_size: i32,
bit_size: usize,
) -> V {
assert!(scalars.len() <= bases.len());

let mut bases_slice = DeviceVec::<GpuBaseType<V>>::device_malloc(bases.len()).unwrap();

let span = tracing::span!(tracing::Level::INFO, "convert_scalars");
Expand Down Expand Up @@ -94,14 +96,14 @@ pub fn icicle_msm<V: VariableBaseMSM>(
cfg.stream_handle = IcicleStreamHandle::from(&stream);
cfg.is_async = false;
cfg.are_scalars_montgomery_form = true;
cfg.bitsize = bit_size;
cfg.bitsize = bit_size as i32;

let span = tracing::span!(tracing::Level::INFO, "gpu_msm");
let _guard = span.enter();

msm(
&scalars_slice[..],
&bases_slice[..],
&bases_slice[..scalars.len()],
&cfg,
&mut msm_result[..],
)
Expand Down Expand Up @@ -129,14 +131,14 @@ pub fn icicle_msm<V: VariableBaseMSM>(
pub fn icicle_batch_msm<V: VariableBaseMSM>(
bases: &[GpuBaseType<V>],
scalar_batches: &[&[V::ScalarField]],
bit_size: i32,
bit_size: usize,
) -> Vec<V> {
let len = bases.len();
let bases_len = bases.len();
let batch_size = scalar_batches.len();
assert!(scalar_batches.iter().all(|s| s.len() == len));
assert!(scalar_batches.iter().all(|s| s.len() <= bases_len));

let mut stream = IcicleStream::create().unwrap();
let mut bases_slice = DeviceVec::<GpuBaseType<V>>::device_malloc(len).unwrap();
let mut bases_slice = DeviceVec::<GpuBaseType<V>>::device_malloc(bases_len).unwrap();
let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu");
let _guard = span.enter();
bases_slice
Expand All @@ -149,101 +151,15 @@ pub fn icicle_batch_msm<V: VariableBaseMSM>(
let mut msm_host_results = vec![Projective::<V::C>::zero(); batch_size];

for (batch_i, scalars) in scalar_batches.iter().enumerate() {
let scalars_len = scalars.len();
let span = tracing::span!(tracing::Level::INFO, "convert_scalars");
let _guard = span.enter();
let mut scalars_slice =
DeviceVec::<<<V as Icicle>::C as Curve>::ScalarField>::device_malloc_async(
len, &stream,
)
.unwrap();
let scalars_mont = unsafe {
&*(&scalars[..] as *const _ as *const [<<V as Icicle>::C as Curve>::ScalarField])
};
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_scalars_to_gpu");
let _guard = span.enter();
scalars_slice
.copy_from_host_async(HostSlice::from_slice(scalars_mont), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut cfg = MSMConfig::default();
cfg.stream_handle = IcicleStreamHandle::from(&stream);
cfg.is_async = true;
cfg.are_scalars_montgomery_form = true;
cfg.are_bases_montgomery_form = false;
cfg.bitsize = bit_size;

let span = tracing::span!(tracing::Level::INFO, "msm_gpu");
let _guard = span.enter();
msm(
&scalars_slice[..],
&bases_slice[..],
&cfg,
&mut msm_result[..],
)
.unwrap();
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_msm_result");
let _guard = span.enter();
msm_result
.copy_to_host_async(
HostSlice::from_mut_slice(&mut msm_host_results[batch_i..(batch_i + 1)]),
&stream,
)
.unwrap();
drop(_guard);
drop(span);
}

stream.synchronize().unwrap();

stream.destroy().unwrap();
msm_host_results
.into_iter()
.map(|res| V::to_ark_projective(&res))
.collect()
}

#[tracing::instrument(skip_all, name = "icicle_batch_msm")]
/// MSM which allows scalar_batches of non-uniform size
pub fn icicle_variable_batch_msm<V: VariableBaseMSM>(
bases: &[GpuBaseType<V>],
scalar_batches: &[&[V::ScalarField]],
bit_size: i32,
) -> Vec<V> {
let base_len = bases.len();
let batch_size = scalar_batches.len();
assert!(scalar_batches.iter().all(|s| s.len() <= base_len));

let mut stream = IcicleStream::create().unwrap();
let mut bases_slice = DeviceVec::<GpuBaseType<V>>::device_malloc(base_len).unwrap();
let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu");
let _guard = span.enter();
bases_slice
.copy_from_host_async(HostSlice::from_slice(bases), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut msm_result = DeviceVec::<Projective<V::C>>::device_malloc(1).unwrap();
let mut msm_host_results = vec![Projective::<V::C>::zero(); batch_size];

for (batch_i, scalars) in scalar_batches.iter().enumerate() {
let mut scalars_slice =
DeviceVec::<<<V as Icicle>::C as Curve>::ScalarField>::device_malloc_async(
scalars.len(),
scalars_len,
&stream,
)
.unwrap();

let span = tracing::span!(tracing::Level::INFO, "convert_scalars");
let _guard = span.enter();
let scalars_mont = unsafe {
&*(&scalars[..] as *const _ as *const [<<V as Icicle>::C as Curve>::ScalarField])
};
Expand All @@ -263,21 +179,21 @@ pub fn icicle_variable_batch_msm<V: VariableBaseMSM>(
cfg.is_async = true;
cfg.are_scalars_montgomery_form = true;
cfg.are_bases_montgomery_form = false;
cfg.bitsize = bit_size;
cfg.bitsize = bit_size as i32;

let span = tracing::span!(tracing::Level::INFO, "msm_gpu");
let _guard = span.enter();
msm(
&scalars_slice[..],
&bases_slice[..scalars.len()],
&bases_slice[..scalars_len],
&cfg,
&mut msm_result[..],
)
.unwrap();
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_result");
let span = tracing::span!(tracing::Level::INFO, "copy_msm_result");
let _guard = span.enter();
msm_result
.copy_to_host_async(
Expand Down
Loading

0 comments on commit a9659df

Please sign in to comment.