-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New Danksharding flow #17
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
[submodule "icicle"] | ||
path = icicle | ||
url = https://github.com/ingonyama-zk/icicle.git | ||
url = https://github.com/DmytroTym/icicle/ | ||
branch = new_api |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,19 @@ | ||
use std::time::Instant; | ||
|
||
use rustacuda::prelude::*; | ||
use icicle_utils::{field::Point, *}; | ||
|
||
use crate::{matrix::*, utils::*, *}; | ||
|
||
pub const FLOW_SIZE: usize = 1 << 12; //4096 //prod flow size | ||
pub const TEST_SIZE_DIV: usize = 1; //TODO: Prod size / test size for speedup | ||
pub const TEST_SIZE: usize = FLOW_SIZE / TEST_SIZE_DIV; //test flow size | ||
pub const M_POINTS: usize = TEST_SIZE; | ||
pub const LOG_TEST_SIZE_DIV: usize = 3; //TODO: Prod size / test size for speedup | ||
pub const TEST_SIZE_DIV: usize = 1 << LOG_TEST_SIZE_DIV; //TODO: Prod size / test size for speedup | ||
pub const M_POINTS: usize = FLOW_SIZE / TEST_SIZE_DIV; //test flow size | ||
pub const LOG_M_POINTS: usize = 12 - LOG_TEST_SIZE_DIV; | ||
pub const SRS_SIZE: usize = M_POINTS; | ||
pub const S_GROUP_SIZE: usize = 2 * M_POINTS; | ||
pub const N_ROWS: usize = 256 / TEST_SIZE_DIV; | ||
pub const LOG_N_ROWS: usize = 8 - LOG_TEST_SIZE_DIV; | ||
pub const FOLD_SIZE: usize = 512 / TEST_SIZE_DIV; | ||
|
||
//TODO: the casing is to match diagram | ||
|
@@ -252,15 +255,154 @@ pub fn main_flow() { | |
); | ||
|
||
assert_ne!(P[12][23], Point::zero()); //dummy check | ||
println!("success !!!",); | ||
println!("success !!!"); | ||
} | ||
|
||
#[allow(non_snake_case)] | ||
#[allow(non_upper_case_globals)] | ||
pub fn alternate_flow() { | ||
let D_in_host = get_debug_data_scalar_vec("D_in.csv"); | ||
let SRS_host = get_debug_data_points_proj_xy1_vec("SRS.csv", M_POINTS); | ||
//TODO: now S is preprocessed, copy preprocessing here | ||
let S_host = get_debug_data_points_proj_xy1_vec("S.csv", 2 * M_POINTS); | ||
|
||
const l: usize = 16; | ||
println!("loaded test data, processing..."); | ||
|
||
let pre_time = Instant::now(); | ||
// set up the device | ||
let _ctx = rustacuda::quick_init(); | ||
// build domains (i.e. compute twiddle factors) | ||
let mut interpolate_row_domain = build_domain(M_POINTS, LOG_M_POINTS, true); | ||
let mut evaluate_row_domain = build_domain(M_POINTS, LOG_M_POINTS, false); | ||
let mut interpolate_column_domain = build_domain(N_ROWS, LOG_N_ROWS, true); | ||
let mut evaluate_column_domain = build_domain(N_ROWS, LOG_N_ROWS, false); | ||
let mut interpolate_column_large_domain = build_domain(2 * N_ROWS, LOG_N_ROWS + 1, true); | ||
let mut evaluate_column_large_domain = build_domain(2 * N_ROWS, LOG_N_ROWS + 1, false); | ||
// build cosets (i.e. powers of roots of unity `w` and `v`) | ||
let mut row_coset = build_domain(M_POINTS, LOG_M_POINTS + 1, false); | ||
let mut column_coset = build_domain(N_ROWS, LOG_N_ROWS + 1, false); | ||
// transfer `D_in` into device memory | ||
let mut D_in = DeviceBuffer::from_slice(&D_in_host[..]).unwrap(); | ||
// transfer the SRS into device memory | ||
debug_assert!(SRS_host[0].to_ark_affine().is_on_curve()); | ||
let SRS_affine: Vec<_> = vec![SRS_host.iter().map(|p| p.to_xy_strip_z()).collect::<Vec<_>>(); N_ROWS].concat(); | ||
let mut SRS = DeviceBuffer::from_slice(&SRS_affine[..]).unwrap(); | ||
// transfer S into device memory after suitable bit-reversal | ||
let S_host_rbo = list_to_reverse_bit_order(&S_host[..])[..].chunks(l).map(|chunk| list_to_reverse_bit_order(chunk)).collect::<Vec<_>>().concat(); | ||
let S_affine: Vec<_> = vec![S_host_rbo.iter().map(|p| p.to_xy_strip_z()).collect::<Vec<_>>(); 2 * N_ROWS].concat(); | ||
let mut S = DeviceBuffer::from_slice(&S_affine[..]).unwrap(); | ||
|
||
println!("pre-computation {:0.3?}", pre_time.elapsed()); | ||
|
||
//C_rows = INTT_rows(D_in) | ||
reverse_order_scalars_batch(&mut D_in, N_ROWS); | ||
let mut C_rows = interpolate_scalars_batch(&mut D_in, &mut interpolate_row_domain, N_ROWS); | ||
|
||
println!("pre-branch {:0.3?}", pre_time.elapsed()); | ||
|
||
//////////////////////////////// | ||
println!("Branch 1"); | ||
//////////////////////////////// | ||
let br1_time = Instant::now(); | ||
|
||
// K0 = MSM_rows(C_rows) (256x1) | ||
let mut K0 = commit_batch(&mut SRS, &mut C_rows, N_ROWS); | ||
let mut K = vec![Point::zero(); 2 * N_ROWS]; | ||
K0.copy_to(&mut K[..N_ROWS]).unwrap(); | ||
println!("K0 {:0.3?}", br1_time.elapsed()); | ||
|
||
reverse_order_points(&mut K0); | ||
// B0 = ECINTT_col(K0) N_POINTS x 1 (256x1) | ||
let mut B0 = interpolate_points(&mut K0, &mut interpolate_column_domain); | ||
println!("B0 {:0.3?}", br1_time.elapsed()); | ||
|
||
// K1 = ECNTT_col(MUL_col(B0, [1 u u^2 ...])) N_POINTS x 1 (256x1) | ||
let mut K1 = evaluate_points_on_coset(&mut B0, &mut evaluate_column_domain, &mut column_coset); | ||
println!("K1 {:0.3?}", br1_time.elapsed()); | ||
reverse_order_points(&mut K1); | ||
|
||
// K = [K0, K1] // 2*N_POINTS x 1 (512x1 commitments) | ||
K1.copy_to(&mut K[N_ROWS..]).unwrap(); | ||
println!("K {:0.3?}", br1_time.elapsed()); | ||
|
||
assert_eq!(K, get_debug_data_points_proj_xy1_vec("K.csv", 2 * N_ROWS)); | ||
println!("Branch1 {:0.3?}", br1_time.elapsed()); | ||
|
||
//////////////////////////////// | ||
println!("Branch 2"); | ||
//////////////////////////////// | ||
let br2_time = Instant::now(); | ||
|
||
let mut D_rows = evaluate_scalars_on_coset_batch(&mut C_rows, &mut evaluate_row_domain, N_ROWS, &mut row_coset); | ||
println!("D_both {:0.3?}", br2_time.elapsed()); | ||
|
||
let mut D_transposed = unsafe { DeviceBuffer::uninitialized(2 * N_ROWS * M_POINTS).unwrap() }; | ||
transpose_scalar_matrix(&mut D_transposed.as_device_ptr(), &mut D_in, M_POINTS, N_ROWS); | ||
transpose_scalar_matrix(&mut D_transposed.as_device_ptr().wrapping_offset((N_ROWS * M_POINTS) as isize), &mut D_rows, M_POINTS, N_ROWS); | ||
|
||
let mut D = unsafe { DeviceBuffer::uninitialized(4 * N_ROWS * M_POINTS).unwrap() }; | ||
transpose_scalar_matrix(&mut D.as_device_ptr(), &mut D_transposed, N_ROWS, 2 * M_POINTS); | ||
|
||
reverse_order_scalars_batch(&mut D_transposed, 2 * M_POINTS); | ||
let mut C0 = interpolate_scalars_batch(&mut D_transposed, &mut interpolate_column_domain, 2 * M_POINTS); | ||
let mut D_cols = evaluate_scalars_on_coset_batch(&mut C0, &mut evaluate_column_domain, 2 * M_POINTS, &mut column_coset); | ||
reverse_order_scalars_batch(&mut D_cols, 2 * M_POINTS); | ||
|
||
transpose_scalar_matrix(&mut D.as_device_ptr().wrapping_offset((2 * N_ROWS * M_POINTS) as isize), &mut D_cols, N_ROWS, 2 * M_POINTS); | ||
|
||
let mut D_host_flat = vec![ScalarField::zero(); 4 * N_ROWS * M_POINTS]; | ||
D.copy_to(&mut D_host_flat[..]).unwrap(); | ||
let D_host = D_host_flat.chunks(2 * M_POINTS).collect::<Vec<_>>(); | ||
|
||
println!("Branch2 {:0.3?}", br2_time.elapsed()); | ||
debug_assert_eq!(D_host, get_debug_data_scalars("D.csv", 2 * N_ROWS, 2 * M_POINTS)); | ||
|
||
//////////////////////////////// | ||
println!("Branch 3"); | ||
//////////////////////////////// | ||
let br3_time = Instant::now(); | ||
|
||
//d1 = MSM_batch(D[i], [S], l) 1x8192 | ||
reverse_order_scalars_batch(&mut D, (4 * M_POINTS * N_ROWS) / l); | ||
let mut d1 = commit_batch(&mut S, &mut D, (4 * M_POINTS * N_ROWS) / l); | ||
|
||
//delta0 = ECINTT_row(d1) 1x512 | ||
let mut delta0 = interpolate_points_batch(&mut d1, &mut interpolate_column_large_domain, 2 * N_ROWS); | ||
|
||
// delta0 = delta0 << 256 1x512 | ||
shift_points_batch(&mut delta0, 2 * N_ROWS); | ||
|
||
//q[mu] = ECNTT_row(delta0) 1x512 | ||
let P = evaluate_points_batch(&mut delta0, &mut evaluate_column_large_domain, 2 * N_ROWS); | ||
let mut P_host_flat: Vec<Point> = (0..(4 * M_POINTS * N_ROWS) / l).map(|_| Point::zero()).collect(); | ||
P.copy_to(&mut P_host_flat[..]).unwrap(); | ||
let P_host = split_vec_to_matrix(&P_host_flat, 2 * N_ROWS).clone(); | ||
|
||
//final assertion | ||
debug_assert_eq!( | ||
P_host, | ||
get_debug_data_points_xy1("P.csv", 2 * N_ROWS, 2 * N_ROWS) | ||
); | ||
println!("final check"); | ||
|
||
println!("Branch3 {:0.3?}", br3_time.elapsed()); | ||
|
||
assert_ne!(P_host[12][23], Point::zero()); //dummy check | ||
println!("success !!!"); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::main_flow; | ||
use super::{main_flow, alternate_flow}; | ||
|
||
#[test] | ||
fn test_main_flow() { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer: not sure the current "main_flow" is worth keeping if the new API is correct and faster. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Author: We can remove the current main_flow if everyone's happy with the new one; I left it for the sake of easier comparison for now. |
||
main_flow(); | ||
} | ||
|
||
#[test] | ||
fn test_alternate_flow() { | ||
alternate_flow(); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Reviewer: `scalar_t`?
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Author: Right, changed to `scalar_field_t` — although why do we have two separate names for the same thing?