Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Danksharding flow #17

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[submodule "icicle"]
path = icicle
url = https://github.com/ingonyama-zk/icicle.git
url = https://github.com/DmytroTym/icicle/
branch = new_api
4 changes: 3 additions & 1 deletion fast-danksharding/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/fast-danksharding"

[dependencies]
icicle-utils = { git = "https://github.com/ingonyama-zk/icicle.git" }
icicle-utils = { git = "https://github.com/DmytroTym/icicle.git", rev = "ae4e696" }
hex="0.4.3"
ark-std = "0.3.0"
ark-ff = "0.3.0"
ark-poly = "0.3.0"
ark-ec = { version = "0.3.0", features = [ "parallel" ] }
ark-bls12-381 = { version = "0.3.0", optional = true }
rustacuda = "0.1"
rustacuda_core = "0.1"

[build-dependencies]
cc = { version = "1.0", features = ["parallel"] }
Expand Down
78 changes: 77 additions & 1 deletion fast-danksharding/src/cuda/lib.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include "../../../icicle/icicle/curves/curve_config.cuh"
#include <cuda.h>

const int TILE_DIM = 32;
const int BLOCK_ROWS = 8;
const int MAX_THREAD_NUM = 256;

template <typename P>
void point_sum(P *h_outputs, P *h_inputs, unsigned nof_rows, unsigned nof_cols, unsigned l);

Expand Down Expand Up @@ -49,7 +53,7 @@ void point_sum(P* h_outputs, P* h_inputs, unsigned nof_rows, unsigned nof_cols,
cudaFree(d_outputs);
}

extern "C" int sum_of_points(projective_t *out, projective_t in[], size_t nof_rows, size_t nof_cols, size_t l, size_t device_id = 0)
extern "C" int sum_of_points(projective_t *out, projective_t in[], size_t nof_rows, size_t nof_cols, size_t l)
{
try
{
Expand All @@ -61,5 +65,77 @@ extern "C" int sum_of_points(projective_t *out, projective_t in[], size_t nof_ro
{
printf("error %s", ex.what()); // TODO: error code and message
// out->z = 0; //TODO: .set_infinity()
return -1;
}
}

template <typename T>
// In-place "shift left by half a row" over a batch of rows.
// The buffer is laid out as `nof_rows` rows of width `2 * nof_cols_div_2`.
// For every row, the second half of the row is moved into the first half and
// the second half is zeroed out (T::zero()). One thread handles one element
// of a half-row; total work is nof_rows * nof_cols_div_2 elements.
__global__ void shift_kernel(T *arr, unsigned nof_rows, unsigned nof_cols_div_2)
{
// printf("block id: %d; thread id: %d \n", blockIdx.x, threadIdx.x);
// Global element index; MAX_THREAD_NUM must equal the launch block size.
unsigned id = blockIdx.x * MAX_THREAD_NUM + threadIdx.x;

// Bounds guard: allows the launcher to round the grid size up.
if (id < nof_rows * nof_cols_div_2) {
unsigned col_id = id % nof_cols_div_2;
unsigned row_id = id / nof_cols_div_2;
// Destination: first half of row `row_id` (row stride is 2*nof_cols_div_2).
// Source: (2*row_id + 1)*nof_cols_div_2 == second half of the same row.
arr[row_id * 2 * nof_cols_div_2 + col_id] = arr[(2 * row_id + 1) * nof_cols_div_2 + col_id];
arr[(2 * row_id + 1) * nof_cols_div_2 + col_id] = T::zero();
}
}

// Host-side launcher for shift_kernel over `projective_t` elements.
// Shifts each of `nof_rows` rows (of width 2 * nof_cols_div_2) left by half a
// row and zeroes the vacated half, in place on device memory `arr`.
// Returns CUDA_SUCCESS (0) on launch, -1 on a caught std::runtime_error.
extern "C" int shift_batch(projective_t *arr, size_t nof_rows, size_t nof_cols_div_2)
{
  try
  {
    int thread_num = MAX_THREAD_NUM;
    // Fix: round the grid size up. The previous truncating division
    // `(nof_rows * nof_cols_div_2) / thread_num` silently skipped the tail
    // elements whenever the total was not a multiple of MAX_THREAD_NUM.
    // The kernel's `id < nof_rows * nof_cols_div_2` guard makes the
    // overshoot in the last block harmless.
    int block_num = (nof_rows * nof_cols_div_2 + thread_num - 1) / thread_num;
    shift_kernel <projective_t> <<<block_num, thread_num>>> (arr, nof_rows, nof_cols_div_2);

    return CUDA_SUCCESS;
  }
  catch (const std::runtime_error &ex)
  {
    printf("error %s", ex.what()); // TODO: error code and message
    return -1;
  }
}

// the shared-memory version of matrix transpose taken from here: https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
// Tiled, coalesced transpose: each TILE_DIM x TILE_DIM tile is staged through
// shared memory so both the read from `idata` and the write to `odata` are
// coalesced. Each thread block has TILE_DIM x BLOCK_ROWS threads, so each
// thread copies TILE_DIM / BLOCK_ROWS elements (the `j` loop).
// NOTE(review): there are no bounds checks — both matrix dimensions must be
// exact multiples of TILE_DIM, and the grid must tile the matrix exactly.
template <typename T>
__global__ void transpose_kernel(T *odata, const T *idata)
{
// +1 padding column avoids shared-memory bank conflicts on the transposed read.
__shared__ T tile[TILE_DIM][TILE_DIM+1];

int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
// Matrix dimensions are implied entirely by the launch grid:
// input row width = gridDim.x * TILE_DIM, input height = gridDim.y * TILE_DIM.
int width = gridDim.x * TILE_DIM;
int height = gridDim.y * TILE_DIM;

// Coalesced read of one tile of the input into shared memory.
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
tile[threadIdx.y+j][threadIdx.x] = idata[(y+j)*width + x];

__syncthreads();

x = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
y = blockIdx.x * TILE_DIM + threadIdx.y;

// Coalesced write of the transposed tile (note swapped threadIdx indices).
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
odata[(y+j)*height + x] = tile[threadIdx.x][threadIdx.y+j];
}

// Host-side launcher for the tiled shared-memory transpose over scalar
// field elements. `out` receives the transpose of `in`.
// Returns CUDA_SUCCESS (0) on launch, -1 on a caught std::runtime_error.
//
// NOTE(review): the kernel derives the input row width from gridDim.x, so
// `nof_rows` here is the dimension that acts as the input's row stride —
// the row/column naming is easy to misread; verify against the Rust caller.
// Both dimensions must be exact multiples of TILE_DIM (32): the kernel has
// no bounds checks, so a remainder would read/write out of bounds.
extern "C" int transpose_matrix(scalar_field_t *out, scalar_field_t *in, size_t nof_rows, size_t nof_cols)
{
  try
  {
    dim3 dimGrid(nof_rows / TILE_DIM, nof_cols / TILE_DIM, 1);
    dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
    // Fix (flagged in review): instantiate the kernel with scalar_field_t to
    // match the pointer types of `out`/`in`; it was `scalar_t`.
    transpose_kernel <scalar_field_t> <<<dimGrid, dimBlock>>> (out, in);

    return CUDA_SUCCESS;
  }
  catch (const std::runtime_error &ex)
  {
    printf("error %s", ex.what()); // TODO: error code and message
    return -1;
  }
}
152 changes: 147 additions & 5 deletions fast-danksharding/src/fast_danksharding.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use std::time::Instant;

use rustacuda::prelude::*;
use icicle_utils::{field::Point, *};

use crate::{matrix::*, utils::*, *};

pub const FLOW_SIZE: usize = 1 << 12; //4096 //prod flow size
pub const TEST_SIZE_DIV: usize = 1; //TODO: Prod size / test size for speedup
pub const TEST_SIZE: usize = FLOW_SIZE / TEST_SIZE_DIV; //test flow size
pub const M_POINTS: usize = TEST_SIZE;
pub const LOG_TEST_SIZE_DIV: usize = 3; //TODO: Prod size / test size for speedup
pub const TEST_SIZE_DIV: usize = 1 << LOG_TEST_SIZE_DIV; //TODO: Prod size / test size for speedup
pub const M_POINTS: usize = FLOW_SIZE / TEST_SIZE_DIV; //test flow size
pub const LOG_M_POINTS: usize = 12 - LOG_TEST_SIZE_DIV;
pub const SRS_SIZE: usize = M_POINTS;
pub const S_GROUP_SIZE: usize = 2 * M_POINTS;
pub const N_ROWS: usize = 256 / TEST_SIZE_DIV;
pub const LOG_N_ROWS: usize = 8 - LOG_TEST_SIZE_DIV;
pub const FOLD_SIZE: usize = 512 / TEST_SIZE_DIV;

//TODO: the casing is to match diagram
Expand Down Expand Up @@ -252,15 +255,154 @@ pub fn main_flow() {
);

assert_ne!(P[12][23], Point::zero()); //dummy check
println!("success !!!",);
println!("success !!!");
}

/// GPU-accelerated Danksharding flow using the new icicle API.
///
/// Loads reference test data from CSV fixtures, runs the three branches of
/// the flow on-device (row commitments K, extended data matrix D, proof
/// points P), and asserts each intermediate against precomputed expected
/// values. Panics (via unwrap/assert) on any device error or mismatch —
/// this is a test/benchmark driver, not library code.
#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
pub fn alternate_flow() {
// Reference inputs: input data matrix, SRS points and preprocessed S points.
let D_in_host = get_debug_data_scalar_vec("D_in.csv");
let SRS_host = get_debug_data_points_proj_xy1_vec("SRS.csv", M_POINTS);
//TODO: now S is preprocessed, copy preprocessing here
let S_host = get_debug_data_points_proj_xy1_vec("S.csv", 2 * M_POINTS);

// MSM bucket/window chunk size used when batching D against S.
const l: usize = 16;
println!("loaded test data, processing...");

let pre_time = Instant::now();
// set up the device
// NOTE(review): the Result from quick_init() is bound but never checked —
// a failed CUDA init is silently ignored here; consider unwrapping.
let _ctx = rustacuda::quick_init();
// build domains (i.e. compute twiddle factors)
// "interpolate" = inverse NTT direction, "evaluate" = forward.
let mut interpolate_row_domain = build_domain(M_POINTS, LOG_M_POINTS, true);
let mut evaluate_row_domain = build_domain(M_POINTS, LOG_M_POINTS, false);
let mut interpolate_column_domain = build_domain(N_ROWS, LOG_N_ROWS, true);
let mut evaluate_column_domain = build_domain(N_ROWS, LOG_N_ROWS, false);
let mut interpolate_column_large_domain = build_domain(2 * N_ROWS, LOG_N_ROWS + 1, true);
let mut evaluate_column_large_domain = build_domain(2 * N_ROWS, LOG_N_ROWS + 1, false);
// build cosets (i.e. powers of roots of unity `w` and `v`)
let mut row_coset = build_domain(M_POINTS, LOG_M_POINTS + 1, false);
let mut column_coset = build_domain(N_ROWS, LOG_N_ROWS + 1, false);
// transfer `D_in` into device memory
let mut D_in = DeviceBuffer::from_slice(&D_in_host[..]).unwrap();
// transfer the SRS into device memory
debug_assert!(SRS_host[0].to_ark_affine().is_on_curve());
// SRS is replicated N_ROWS times (one copy per batched MSM row), with the
// projective z coordinate stripped to affine (x, y).
let SRS_affine: Vec<_> = vec![SRS_host.iter().map(|p| p.to_xy_strip_z()).collect::<Vec<_>>(); N_ROWS].concat();
let mut SRS = DeviceBuffer::from_slice(&SRS_affine[..]).unwrap();
// transfer S into device memory after suitable bit-reversal
// Bit-reverse the whole list, then bit-reverse within each chunk of `l`.
let S_host_rbo = list_to_reverse_bit_order(&S_host[..])[..].chunks(l).map(|chunk| list_to_reverse_bit_order(chunk)).collect::<Vec<_>>().concat();
let S_affine: Vec<_> = vec![S_host_rbo.iter().map(|p| p.to_xy_strip_z()).collect::<Vec<_>>(); 2 * N_ROWS].concat();
let mut S = DeviceBuffer::from_slice(&S_affine[..]).unwrap();

println!("pre-computation {:0.3?}", pre_time.elapsed());

//C_rows = INTT_rows(D_in)
// Rows must be in reverse-bit order before the batched inverse NTT.
reverse_order_scalars_batch(&mut D_in, N_ROWS);
let mut C_rows = interpolate_scalars_batch(&mut D_in, &mut interpolate_row_domain, N_ROWS);

println!("pre-branch {:0.3?}", pre_time.elapsed());

////////////////////////////////
println!("Branch 1");
////////////////////////////////
let br1_time = Instant::now();

// K0 = MSM_rows(C_rows) (256x1)
let mut K0 = commit_batch(&mut SRS, &mut C_rows, N_ROWS);
let mut K = vec![Point::zero(); 2 * N_ROWS];
// First half of K is the direct commitments K0.
K0.copy_to(&mut K[..N_ROWS]).unwrap();
println!("K0 {:0.3?}", br1_time.elapsed());

reverse_order_points(&mut K0);
// B0 = ECINTT_col(K0) N_POINTS x 1 (256x1)
let mut B0 = interpolate_points(&mut K0, &mut interpolate_column_domain);
println!("B0 {:0.3?}", br1_time.elapsed());

// K1 = ECNTT_col(MUL_col(B0, [1 u u^2 ...])) N_POINTS x 1 (256x1)
let mut K1 = evaluate_points_on_coset(&mut B0, &mut evaluate_column_domain, &mut column_coset);
println!("K1 {:0.3?}", br1_time.elapsed());
reverse_order_points(&mut K1);

// K = [K0, K1] // 2*N_POINTS x 1 (512x1 commitments)
K1.copy_to(&mut K[N_ROWS..]).unwrap();
println!("K {:0.3?}", br1_time.elapsed());

// Check the commitments against the precomputed reference.
assert_eq!(K, get_debug_data_points_proj_xy1_vec("K.csv", 2 * N_ROWS));
println!("Branch1 {:0.3?}", br1_time.elapsed());

////////////////////////////////
println!("Branch 2");
////////////////////////////////
let br2_time = Instant::now();

// Row-wise coset evaluation extends the data horizontally.
let mut D_rows = evaluate_scalars_on_coset_batch(&mut C_rows, &mut evaluate_row_domain, N_ROWS, &mut row_coset);
println!("D_both {:0.3?}", br2_time.elapsed());

// Assemble the top half of D in transposed (column-major) layout:
// original data first, coset-evaluated rows appended after an offset.
let mut D_transposed = unsafe { DeviceBuffer::uninitialized(2 * N_ROWS * M_POINTS).unwrap() };
transpose_scalar_matrix(&mut D_transposed.as_device_ptr(), &mut D_in, M_POINTS, N_ROWS);
transpose_scalar_matrix(&mut D_transposed.as_device_ptr().wrapping_offset((N_ROWS * M_POINTS) as isize), &mut D_rows, M_POINTS, N_ROWS);

// D holds the full 2*N_ROWS x 2*M_POINTS extended matrix (row-major).
let mut D = unsafe { DeviceBuffer::uninitialized(4 * N_ROWS * M_POINTS).unwrap() };
// Transpose back: top half of D is the row-extended data.
transpose_scalar_matrix(&mut D.as_device_ptr(), &mut D_transposed, N_ROWS, 2 * M_POINTS);

// Column-wise extension: INTT each (length-N_ROWS) column, then evaluate
// on the column coset to produce the bottom half of D.
reverse_order_scalars_batch(&mut D_transposed, 2 * M_POINTS);
let mut C0 = interpolate_scalars_batch(&mut D_transposed, &mut interpolate_column_domain, 2 * M_POINTS);
let mut D_cols = evaluate_scalars_on_coset_batch(&mut C0, &mut evaluate_column_domain, 2 * M_POINTS, &mut column_coset);
reverse_order_scalars_batch(&mut D_cols, 2 * M_POINTS);

// Bottom half of D: transpose the column-extended data into place.
transpose_scalar_matrix(&mut D.as_device_ptr().wrapping_offset((2 * N_ROWS * M_POINTS) as isize), &mut D_cols, N_ROWS, 2 * M_POINTS);

// Copy D back to the host and reshape into 2*N_ROWS rows for the check.
let mut D_host_flat = vec![ScalarField::zero(); 4 * N_ROWS * M_POINTS];
D.copy_to(&mut D_host_flat[..]).unwrap();
let D_host = D_host_flat.chunks(2 * M_POINTS).collect::<Vec<_>>();

println!("Branch2 {:0.3?}", br2_time.elapsed());
debug_assert_eq!(D_host, get_debug_data_scalars("D.csv", 2 * N_ROWS, 2 * M_POINTS));

////////////////////////////////
println!("Branch 3");
////////////////////////////////
let br3_time = Instant::now();

//d1 = MSM_batch(D[i], [S], l) 1x8192
// One MSM per chunk of `l` scalars; (4 * M_POINTS * N_ROWS) / l results.
reverse_order_scalars_batch(&mut D, (4 * M_POINTS * N_ROWS) / l);
let mut d1 = commit_batch(&mut S, &mut D, (4 * M_POINTS * N_ROWS) / l);

//delta0 = ECINTT_row(d1) 1x512
let mut delta0 = interpolate_points_batch(&mut d1, &mut interpolate_column_large_domain, 2 * N_ROWS);

// delta0 = delta0 << 256 1x512
shift_points_batch(&mut delta0, 2 * N_ROWS);

//q[mu] = ECNTT_row(delta0) 1x512
let P = evaluate_points_batch(&mut delta0, &mut evaluate_column_large_domain, 2 * N_ROWS);
let mut P_host_flat: Vec<Point> = (0..(4 * M_POINTS * N_ROWS) / l).map(|_| Point::zero()).collect();
P.copy_to(&mut P_host_flat[..]).unwrap();
let P_host = split_vec_to_matrix(&P_host_flat, 2 * N_ROWS).clone();

//final assertion
debug_assert_eq!(
P_host,
get_debug_data_points_xy1("P.csv", 2 * N_ROWS, 2 * N_ROWS)
);
println!("final check");

println!("Branch3 {:0.3?}", br3_time.elapsed());

assert_ne!(P_host[12][23], Point::zero()); //dummy check
println!("success !!!");
}

#[cfg(test)]
mod tests {
    use super::{alternate_flow, main_flow};

    // NOTE(review): `main_flow` is kept alongside `alternate_flow` only for
    // easier comparison while the new API is evaluated; remove it once the
    // new flow is accepted as correct and faster.
    #[test]
    fn test_main_flow() {
        main_flow();
    }

    #[test]
    fn test_alternate_flow() {
        alternate_flow();
    }
}
Loading