Skip to content

Commit

Permalink
Allow dynamic constraint size in sparse poly commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Dec 11, 2024
1 parent 0044de9 commit 777b501
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 50 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
.DS_Store
*.rtk
*.ctk
zok_tests/benchmarks/poseidon_test/poseidon_const.zok

# Generated by Cargo
# will have compiled files and executables
Expand Down
6 changes: 3 additions & 3 deletions spartan_parallel/src/dense_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ impl<S: SpartanExtensionField> PolyEvalProof<S> {
// compute vector-matrix product between L and Z viewed as a matrix
let LZ = poly.bound(&L);

PolyEvalProof { v: LZ }
PolyEvalProof { v: Vec::new() }
}

pub fn verify(
Expand Down Expand Up @@ -440,7 +440,7 @@ impl<S: SpartanExtensionField> PolyEvalProof<S> {
// compute vector-matrix product between L and Z viewed as a matrix
let LZ = poly.bound(L);

proof_list.push(PolyEvalProof{ v: LZ });
proof_list.push(PolyEvalProof{ v: Vec::new() });
}

proof_list
Expand Down Expand Up @@ -555,7 +555,7 @@ impl<S: SpartanExtensionField> PolyEvalProof<S> {
let LZ = poly.bound(&L);
L_list.push(L);
R_list.push(R);
LZ_list.push(LZ);
LZ_list.push(Vec::new());
}
}

Expand Down
52 changes: 27 additions & 25 deletions spartan_parallel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,11 @@ pub struct SNARK<S: SpartanExtensionField> {
perm_root_r1cs_eval_proof: R1CSEvalProof<S>,

// Product proof for permutation
perm_poly_poly_list: Vec<S>,
proof_eval_perm_poly_prod_list: Vec<PolyEvalProof<S>>,
// perm_poly_poly_list: Vec<S>,
// proof_eval_perm_poly_prod_list: Vec<PolyEvalProof<S>>,

// shift_proof: ShiftProofs<S>,
io_proof: IOProofs<S>,
// io_proof: IOProofs<S>,
}

// Sort block_num_proofs and record where each entry is
Expand Down Expand Up @@ -649,7 +649,7 @@ impl<S: SpartanExtensionField> SNARK<S> {
}

// Computes proof size by commitment / non-commitment
fn compute_size(&self) -> (usize, usize, usize) {
fn compute_size(&self) -> (usize, usize, usize, usize) {
/*
let commit_size = bincode::serialize(&self.block_comm_vars_list).unwrap().len()
+ bincode::serialize(&self.exec_comm_inputs).unwrap().len()
Expand Down Expand Up @@ -683,29 +683,27 @@ impl<S: SpartanExtensionField> SNARK<S> {
+ bincode::serialize(&self.vir_mem_addr_comm_w3).unwrap().len()
+ bincode::serialize(&self.vir_mem_addr_comm_w3_shifted).unwrap().len();
*/
let dense_commit_size = 0;

let sparse_commit_size = bincode::serialize(&self.block_r1cs_eval_proof_list).unwrap().len()
+ bincode::serialize(&self.pairwise_check_r1cs_eval_proof).unwrap().len()
+ bincode::serialize(&self.perm_root_r1cs_eval_proof).unwrap().len()
+ bincode::serialize(&self.proof_eval_perm_poly_prod_list).unwrap().len();

let noncommit_size = bincode::serialize(&self.block_r1cs_sat_proof).unwrap().len()
let dense_commit_size = 0;

let block_proof_size = bincode::serialize(&self.block_r1cs_sat_proof).unwrap().len()
+ bincode::serialize(&self.block_inst_evals_bound_rp).unwrap().len()
+ bincode::serialize(&self.block_inst_evals_list).unwrap().len()
+ bincode::serialize(&self.block_r1cs_eval_proof_list).unwrap().len();

+ bincode::serialize(&self.pairwise_check_r1cs_sat_proof).unwrap().len()
let pairwise_proof_size = bincode::serialize(&self.pairwise_check_r1cs_sat_proof).unwrap().len()
+ bincode::serialize(&self.pairwise_check_inst_evals_bound_rp).unwrap().len()
+ bincode::serialize(&self.pairwise_check_inst_evals_list).unwrap().len()
+ bincode::serialize(&self.pairwise_check_r1cs_eval_proof).unwrap().len();

+ bincode::serialize(&self.perm_root_r1cs_sat_proof).unwrap().len()
let perm_proof_size = bincode::serialize(&self.perm_root_r1cs_sat_proof).unwrap().len()
+ bincode::serialize(&self.perm_root_inst_evals).unwrap().len()

+ bincode::serialize(&self.perm_poly_poly_list).unwrap().len()
+ bincode::serialize(&self.perm_root_r1cs_eval_proof).unwrap().len();
// + bincode::serialize(&self.perm_poly_poly_list).unwrap().len()
// + bincode::serialize(&self.proof_eval_perm_poly_prod_list).unwrap().len();

// + bincode::serialize(&self.shift_proof).unwrap().len()
+ bincode::serialize(&self.io_proof).unwrap().len();
(dense_commit_size, sparse_commit_size, noncommit_size)
// let io_proof_size = bincode::serialize(&self.io_proof).unwrap().len();
(dense_commit_size, block_proof_size, pairwise_proof_size, perm_proof_size)
}

/// A public computation to create a commitment to a list of R1CS instances
Expand Down Expand Up @@ -2233,6 +2231,7 @@ impl<S: SpartanExtensionField> SNARK<S> {
// --
// PERM_PRODUCT_PROOF
// --
/*
let timer_proof = Timer::new("Perm Product");
// Record the prod of exec, blocks, mem_block, & mem_addr
let (perm_poly_poly_list, proof_eval_perm_poly_prod_list) = {
Expand Down Expand Up @@ -2349,7 +2348,6 @@ impl<S: SpartanExtensionField> SNARK<S> {
shifted_polys.push(&vir_mem_addr_w3_shifted_prover.poly_w[0]);
header_len_list.push(6);
}
/*
let shift_proof = ShiftProofs::prove(
orig_polys,
shifted_polys,
Expand All @@ -2358,7 +2356,6 @@ impl<S: SpartanExtensionField> SNARK<S> {
&mut random_tape,
);
shift_proof
*/
};
timer_proof.stop();
Expand All @@ -2384,6 +2381,7 @@ impl<S: SpartanExtensionField> SNARK<S> {
&mut random_tape,
);
timer_proof.stop();
*/

timer_prove.stop();

Expand All @@ -2402,11 +2400,11 @@ impl<S: SpartanExtensionField> SNARK<S> {
perm_root_inst_evals,
perm_root_r1cs_eval_proof,

perm_poly_poly_list,
proof_eval_perm_poly_prod_list,
// perm_poly_poly_list,
// proof_eval_perm_poly_prod_list,

// shift_proof,
io_proof,
// io_proof,
}
}

Expand Down Expand Up @@ -2457,7 +2455,7 @@ impl<S: SpartanExtensionField> SNARK<S> {

transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
let (_, _, sumcheck_size) = self.compute_size();
let (_, block_size, pairwise_size, perm_size) = self.compute_size();
let meta_size =
// usize
19 * std::mem::size_of::<usize>() +
Expand Down Expand Up @@ -3174,6 +3172,7 @@ impl<S: SpartanExtensionField> SNARK<S> {
// --
// PERM_PRODUCT_PROOF
// --
/*
{
let timer_eval_opening = Timer::new("Perm Product");
// Verify prod of exec, blocks, mem_block, & mem_addr
Expand Down Expand Up @@ -3368,10 +3367,13 @@ impl<S: SpartanExtensionField> SNARK<S> {
transcript,
)?;
timer_proof.stop();
*/

timer_verify.stop();

println!("SUMCHECK SIZE: {} bytes", sumcheck_size);
println!("BLOCK SUMCHECK SIZE: {} bytes", block_size);
println!("PAIRWISE SUMCHECK SIZE: {} bytes", pairwise_size);
println!("PERM SUMCHECK SIZE: {} bytes", perm_size);
println!("META SIZE: {} bytes", meta_size);

Ok(())
Expand Down
45 changes: 39 additions & 6 deletions spartan_parallel/src/r1csinstance.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::{max, min};
use std::collections::HashMap;

use crate::scalar::SpartanExtensionField;
Expand All @@ -14,6 +15,7 @@ use super::sparse_mlpoly::{
};
use super::timer::Timer;
use flate2::{write::ZlibEncoder, Compression};
use std::iter::zip;
use merlin::Transcript;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -47,6 +49,8 @@ impl<S: SpartanExtensionField> AppendToTranscript for R1CSCommitment<S> {
}

pub struct R1CSDecommitment<S: SpartanExtensionField> {
num_cons: usize,
num_vars: usize,
dense: MultiSparseMatPolynomialAsDense<S>,
}

Expand Down Expand Up @@ -455,53 +459,66 @@ impl<S: SpartanExtensionField> R1CSInstance<S> {
let mut nnz_size: HashMap<usize, usize> = HashMap::new();
let mut label_map: Vec<Vec<usize>> = Vec::new();
let mut sparse_polys_list: Vec<Vec<&SparseMatPolynomial<S>>> = Vec::new();
let mut max_num_cons_list: Vec<usize> = Vec::new();

for i in 0..self.num_instances {
// A_list
let A_len = Self::next_power_of_eight(self.A_list[i].get_num_nz_entries());
if let Some(index) = nnz_size.get(&A_len) {
label_map[*index].push(3 * i);
sparse_polys_list[*index].push(&self.A_list[i]);
max_num_cons_list[*index] = max(max_num_cons_list[*index], self.num_cons[i]);
} else {
let next_label = nnz_size.len();
nnz_size.insert(A_len, next_label);
label_map.push(vec![3 * i]);
sparse_polys_list.push(vec![&self.A_list[i]]);
max_num_cons_list.push(self.num_cons[i]);
}
// B_list
let B_len = Self::next_power_of_eight(self.B_list[i].get_num_nz_entries());
if let Some(index) = nnz_size.get(&B_len) {
label_map[*index].push(3 * i + 1);
sparse_polys_list[*index].push(&self.B_list[i]);
max_num_cons_list[*index] = max(max_num_cons_list[*index], self.num_cons[i]);
} else {
let next_label = nnz_size.len();
nnz_size.insert(B_len, next_label);
label_map.push(vec![3 * i + 1]);
sparse_polys_list.push(vec![&self.B_list[i]]);
max_num_cons_list.push(self.num_cons[i]);
}
// C_list
let C_len = Self::next_power_of_eight(self.C_list[i].get_num_nz_entries());
if let Some(index) = nnz_size.get(&C_len) {
label_map[*index].push(3 * i + 2);
sparse_polys_list[*index].push(&self.C_list[i]);
max_num_cons_list[*index] = max(max_num_cons_list[*index], self.num_cons[i]);
} else {
let next_label = nnz_size.len();
nnz_size.insert(C_len, next_label);
label_map.push(vec![3 * i + 2]);
sparse_polys_list.push(vec![&self.C_list[i]]);
max_num_cons_list.push(self.num_cons[i]);
}
}

println!("nnz_size: {:?}", nnz_size);
let mut r1cs_comm_list = Vec::new();
let mut r1cs_decomm_list = Vec::new();
for sparse_polys in sparse_polys_list {
println!("NUM_SP: {}", sparse_polys_list.len());
for (sparse_polys, max_num_cons) in zip(sparse_polys_list, max_num_cons_list) {
let (comm, dense) = SparseMatPolynomial::multi_commit(&sparse_polys);
let r1cs_comm = R1CSCommitment {
num_cons: self.num_instances * self.max_num_cons,
num_cons: max_num_cons.next_power_of_two(),
num_vars: self.num_vars,
comm,
};
let r1cs_decomm = R1CSDecommitment { dense };
let r1cs_decomm = R1CSDecommitment {
num_cons: max_num_cons.next_power_of_two(),
num_vars: self.num_vars,
dense
};

r1cs_comm_list.push(r1cs_comm);
r1cs_decomm_list.push(r1cs_decomm);
Expand All @@ -526,7 +543,11 @@ impl<S: SpartanExtensionField> R1CSInstance<S> {
comm,
};

let r1cs_decomm = R1CSDecommitment { dense };
let r1cs_decomm = R1CSDecommitment {
num_cons: self.num_instances * self.max_num_cons,
num_vars: self.num_vars,
dense
};

(r1cs_comm, r1cs_decomm)
}
Expand All @@ -547,8 +568,15 @@ impl<S: SpartanExtensionField> R1CSEvalProof<S> {
random_tape: &mut RandomTape<S>,
) -> R1CSEvalProof<S> {
let timer = Timer::new("R1CSEvalProof::prove");
println!("RX_LEN: {}, RY_LEN: {}", rx.len(), ry.len());
println!("NUM_CONS: {}, NUM_VARS: {}", decomm.num_cons, decomm.num_vars);
let rx_header = rx[..rx.len() - min(rx.len(), decomm.num_cons.log_2())].iter().fold(
S::field_one(), |c, i| c * (S::field_one() - i.clone())
);
let rx_short = &rx[rx.len() - min(rx.len(), decomm.num_cons.log_2())..];
// let ry_short = &ry[..min(ry.len(), decomm.num_vars.log_2())];
let proof =
SparseMatPolyEvalProof::prove(&decomm.dense, rx, ry, evals, transcript, random_tape);
SparseMatPolyEvalProof::prove(&decomm.dense, rx_header, rx_short, ry, evals, transcript, random_tape);
timer.stop();

R1CSEvalProof { proof }
Expand All @@ -562,6 +590,11 @@ impl<S: SpartanExtensionField> R1CSEvalProof<S> {
evals: &Vec<S>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
self.proof.verify(&comm.comm, rx, ry, evals, transcript)
let rx_header = rx[..rx.len() - min(rx.len(), comm.num_cons.log_2())].iter().fold(
S::field_one(), |c, i| c * (S::field_one() - i.clone())
);
let rx_short = &rx[rx.len() - min(rx.len(), comm.num_cons.log_2())..];
// let ry_short = &ry[..min(ry.len(), comm.num_vars.log_2())];
self.proof.verify(&comm.comm, rx_header, rx_short, ry, evals, transcript)
}
}
Loading

0 comments on commit 777b501

Please sign in to comment.