Skip to content

Commit

Permalink
consistent tracing for icicle wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-a16z committed Nov 15, 2024
1 parent adf16c9 commit c68b873
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions jolt-core/src/msm/icicle/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,26 @@ pub fn icicle_msm<V: VariableBaseMSM + Icicle>(
drop(span);

let mut stream = IcicleStream::create().unwrap();

let span = tracing::span!(tracing::Level::INFO, "copy_to_gpu");
let _guard = span.enter();
bases_slice
.copy_from_host_async(HostSlice::from_slice(bases), &stream)
.unwrap();
scalars_slice
.copy_from_host_async(HostSlice::from_slice(scalars_mont), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut msm_result = DeviceVec::<Projective<V::C>>::device_malloc(1).unwrap();
let mut cfg = MSMConfig::default();
cfg.stream_handle = IcicleStreamHandle::from(&stream);
cfg.is_async = false;
cfg.are_scalars_montgomery_form = true;
cfg.bitsize = bit_size;

let span = tracing::span!(tracing::Level::INFO, "msm");
let span = tracing::span!(tracing::Level::INFO, "gpu_msm");
let _guard = span.enter();

msm(
Expand All @@ -103,15 +109,14 @@ pub fn icicle_msm<V: VariableBaseMSM + Icicle>(

drop(_guard);
drop(span);

let mut msm_host_result = [Projective::<V::C>::zero(); 1];

let span = tracing::span!(tracing::Level::INFO, "copy_msm_result");
let _guard = span.enter();

msm_result
.copy_to_host(HostSlice::from_mut_slice(&mut msm_host_result[..]))
.unwrap();

drop(_guard);
drop(span);

Expand All @@ -132,15 +137,19 @@ pub fn icicle_batch_msm<V: VariableBaseMSM + Icicle>(

let mut stream = IcicleStream::create().unwrap();
let mut bases_slice = DeviceVec::<Affine<V::C>>::device_malloc(len).unwrap();
let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu");
let _guard = span.enter();
bases_slice
.copy_from_host_async(HostSlice::from_slice(bases), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut msm_result = DeviceVec::<Projective<V::C>>::device_malloc(1).unwrap();
let mut msm_host_results = vec![Projective::<V::C>::zero(); batch_size];

for (batch_i, scalars) in scalar_batches.iter().enumerate() {
let span = tracing::span!(tracing::Level::INFO, "msm_gpu");
let span = tracing::span!(tracing::Level::INFO, "convert_scalars");
let _guard = span.enter();
let mut scalars_slice =
DeviceVec::<<<V as Icicle>::C as Curve>::ScalarField>::device_malloc_async(
Expand All @@ -150,9 +159,16 @@ pub fn icicle_batch_msm<V: VariableBaseMSM + Icicle>(
let scalars_mont = unsafe {
&*(&scalars[..] as *const _ as *const [<<V as Icicle>::C as Curve>::ScalarField])
};
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_scalars_to_gpu");
let _guard = span.enter();
scalars_slice
.copy_from_host_async(HostSlice::from_slice(scalars_mont), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut cfg = MSMConfig::default();
cfg.stream_handle = IcicleStreamHandle::from(&stream);
Expand All @@ -161,20 +177,28 @@ pub fn icicle_batch_msm<V: VariableBaseMSM + Icicle>(
cfg.are_bases_montgomery_form = false;
cfg.bitsize = bit_size;

let span = tracing::span!(tracing::Level::INFO, "msm_gpu");
let _guard = span.enter();
msm(
&scalars_slice[..],
&bases_slice[..],
&cfg,
&mut msm_result[..],
)
.unwrap();
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_msm_result");
let _guard = span.enter();
msm_result
.copy_to_host_async(
HostSlice::from_mut_slice(&mut msm_host_results[batch_i..(batch_i + 1)]),
&stream,
)
.unwrap();
drop(_guard);
drop(span);
}

stream.synchronize().unwrap();
Expand All @@ -199,9 +223,13 @@ pub fn icicle_variable_batch_msm<V: VariableBaseMSM + Icicle>(

let mut stream = IcicleStream::create().unwrap();
let mut bases_slice = DeviceVec::<Affine<V::C>>::device_malloc(base_len).unwrap();
let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu");
let _guard = span.enter();
bases_slice
.copy_from_host_async(HostSlice::from_slice(bases), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut msm_result = DeviceVec::<Projective<V::C>>::device_malloc(1).unwrap();
let mut msm_host_results = vec![Projective::<V::C>::zero(); batch_size];
Expand All @@ -213,12 +241,22 @@ pub fn icicle_variable_batch_msm<V: VariableBaseMSM + Icicle>(
&stream,
)
.unwrap();

let span = tracing::span!(tracing::Level::INFO, "convert_scalars");
let _guard = span.enter();
let scalars_mont = unsafe {
&*(&scalars[..] as *const _ as *const [<<V as Icicle>::C as Curve>::ScalarField])
};
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_scalars_to_gpu");
let _guard = span.enter();
scalars_slice
.copy_from_host_async(HostSlice::from_slice(scalars_mont), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut cfg = MSMConfig::default();
cfg.stream_handle = IcicleStreamHandle::from(&stream);
Expand All @@ -227,20 +265,28 @@ pub fn icicle_variable_batch_msm<V: VariableBaseMSM + Icicle>(
cfg.are_bases_montgomery_form = false;
cfg.bitsize = bit_size;

let span = tracing::span!(tracing::Level::INFO, "msm_gpu");
let _guard = span.enter();
msm(
&scalars_slice[..],
&bases_slice[..scalars.len()],
&cfg,
&mut msm_result[..],
)
.unwrap();
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_result");
let _guard = span.enter();
msm_result
.copy_to_host_async(
HostSlice::from_mut_slice(&mut msm_host_results[batch_i..(batch_i + 1)]),
&stream,
)
.unwrap();
drop(_guard);
drop(span);
}

stream.synchronize().unwrap();
Expand Down

0 comments on commit c68b873

Please sign in to comment.