From 4819acf626e912039f28f6693861d8ed9a60ea9e Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Fri, 15 Mar 2024 17:49:16 +0200 Subject: [PATCH] Clean up fri quotients --- src/core/backend/cpu/quotients.rs | 13 ++--- src/core/commitment_scheme/quotients.rs | 64 ++++++++++--------------- 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/src/core/backend/cpu/quotients.rs b/src/core/backend/cpu/quotients.rs index c2c550cba..a40e0d2d2 100644 --- a/src/core/backend/cpu/quotients.rs +++ b/src/core/backend/cpu/quotients.rs @@ -9,7 +9,7 @@ use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumn; use crate::core::fields::{ComplexConjugate, FieldExpOps}; -use crate::core::poly::circle::{CircleDomain, CircleEvaluation}; +use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::utils::bit_reverse_index; @@ -19,8 +19,8 @@ impl QuotientOps for CPUBackend { columns: &[&CircleEvaluation], random_coeff: SecureField, sample_batches: &[ColumnSampleBatch], - ) -> SecureColumn { - let mut res = SecureColumn::zeros(domain.size()); + ) -> SecureEvaluation { + let mut values = SecureColumn::zeros(domain.size()); let column_constants = column_constants(sample_batches, random_coeff); for row in 0..domain.size() { @@ -34,9 +34,9 @@ impl QuotientOps for CPUBackend { random_coeff, domain_point, ); - res.set(row, row_value); + values.set(row, row_value); } - res + SecureEvaluation { domain, values } } } @@ -128,7 +128,8 @@ mod tests { }], ); let quot_poly_base_field = - CPUCircleEvaluation::new(eval_domain, quot_eval.columns[0].clone()).interpolate(); + CPUCircleEvaluation::new(eval_domain, quot_eval.values.columns[0].clone()) + .interpolate(); assert!(quot_poly_base_field.is_in_fft_space(LOG_SIZE)); } } diff --git a/src/core/commitment_scheme/quotients.rs b/src/core/commitment_scheme/quotients.rs index 5f4b1ef29..f4f7b2a97 100644 --- a/src/core/commitment_scheme/quotients.rs +++ b/src/core/commitment_scheme/quotients.rs @@ -4,18 +4,15 @@ use std::iter::zip; use itertools::{izip, multiunzip, Itertools}; -use crate::core::backend::cpu::quotients::{accumulate_row_quotients, column_constants}; -use crate::core::backend::Backend; +use crate::core::backend::{Backend, CPUBackend}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumn; use crate::core::fri::SparseCircleEvaluation; use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::prover::VerificationError; use crate::core::queries::SparseSubCircleDomain; -use crate::core::utils::bit_reverse_index; pub trait QuotientOps: Backend { /// Accumulates the quotients of the columns at the given domain. @@ -29,7 +26,7 @@ pub trait QuotientOps: Backend { columns: &[&CircleEvaluation], random_coeff: SecureField, sample_batches: &[ColumnSampleBatch], - ) -> SecureColumn; + ) -> SecureEvaluation; } /// A batch of column samplings at a point. @@ -85,8 +82,7 @@ pub fn compute_fri_quotients( let domain = CanonicCoset::new(log_size).circle_domain(); // TODO: slice. let sample_batches = ColumnSampleBatch::new_vec(&samples); - let values = B::accumulate_quotients(domain, &columns, random_coeff, &sample_batches); - SecureEvaluation { domain, values } + B::accumulate_quotients(domain, &columns, random_coeff, &sample_batches) }) .collect() } @@ -124,8 +120,7 @@ pub fn fri_answers_for_log_size( queried_values_per_column: &[&Vec], ) -> Result { let commitment_domain = CanonicCoset::new(log_size).circle_domain(); - let sample_batches = ColumnSampleBatch::new_vec(samples); - let column_constants = column_constants(&sample_batches, random_coeff); + let batched_samples = ColumnSampleBatch::new_vec(samples); for queried_values in queried_values_per_column { if queried_values.len() != query_domain.flatten().len() { return Err(VerificationError::InvalidStructure( @@ -138,36 +133,27 @@ pub fn fri_answers_for_log_size( .map(|q| q.iter()) .collect_vec(); - let mut evals = Vec::new(); - for subdomain in query_domain.iter() { - let domain = subdomain.to_circle_domain(&commitment_domain); - let mut column_evals = Vec::new(); - for queried_values in queried_values_per_column.iter_mut() { - let eval = CircleEvaluation::new( - domain, - queried_values.take(domain.size()).copied().collect_vec(), - ); - column_evals.push(eval); - } - // TODO(spapini): bit reverse iterator. - let mut values = Vec::new(); - for row in 0..domain.size() { - let domain_point = domain.at(bit_reverse_index(row, log_size)); - let value = accumulate_row_quotients( - &sample_batches, - &column_evals.iter().collect_vec(), - &column_constants, - row, - random_coeff, - domain_point, - ); - values.push(value); - } - let eval = CircleEvaluation::new(domain, values); - evals.push(eval); - } - - let res = SparseCircleEvaluation::new(evals); + let res = SparseCircleEvaluation::new( + query_domain + .iter() + .map(|subdomain| { + let domain = subdomain.to_circle_domain(&commitment_domain); + let column_evals = queried_values_per_column + .iter_mut() + .map(|q| { + CircleEvaluation::new(domain, q.take(domain.size()).copied().collect_vec()) + }) + .collect_vec(); + CPUBackend::accumulate_quotients( + domain, + &column_evals.iter().collect_vec(), + random_coeff, + &batched_samples, + ) + .to_cpu() + }) + .collect(), + ); if !queried_values_per_column.iter().all(|x| x.is_empty()) { return Err(VerificationError::InvalidStructure( "Too many queried values".to_string(),