diff --git a/Cargo.lock b/Cargo.lock index 92ae452f9..ebd4efefc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -135,6 +135,27 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "blake2" version = "0.10.6" @@ -450,6 +471,28 @@ dependencies = [ "log", ] +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "generic-array" version = "0.14.7" @@ -560,6 +603,18 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "log" version = "0.4.21" @@ -617,6 +672,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -692,6 +748,32 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c2511913b88df1637da85cc8d96ec8e43a3f8bb8ccb71ee1ac240d6f3df58d" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags", + "lazy_static", + "num-traits", + "rand", + "rand_chacha", + "rand_xorshift", + "regex-syntax 0.8.3", + "rusty-fork", + "tempfile", + "unarray", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.36" @@ -707,6 +789,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ + "libc", "rand_chacha", "rand_core", ] @@ -726,6 +809,18 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core", +] [[package]] name = "rayon" @@ -810,6 +905,31 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.17" @@ -963,6 +1083,7 @@ dependencies = [ "hex", "itertools 0.12.1", "num-traits", + "proptest", "rand", "rayon", "serde", @@ -1002,6 +1123,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys", +] + [[package]] name = "test-log" version = "0.2.15" @@ -1131,6 +1264,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-ident" version = "1.0.12" @@ -1149,6 +1288,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + [[package]] name = "walkdir" version = "2.5.0" diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 381e912ea..abf486357 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -27,6 +27,7 @@ rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } [dev-dependencies] +proptest = "1.0" aligned = "0.4.2" test-log = { version = "0.2.15", features = ["trace"] } tracing-subscriber = "0.3.18" diff --git a/crates/prover/proptest-regressions/examples/range_check/mod.txt b/crates/prover/proptest-regressions/examples/range_check/mod.txt new file mode 100644 index 000000000..629150366 --- /dev/null +++ b/crates/prover/proptest-regressions/examples/range_check/mod.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc e1bb1f72c891764d16309d8fc8cfd7a23f6d31324a661979c6c684a2cc29724c # shrinks to invalid_value = 32768 diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index 45e2f4186..7c9fa069c 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,4 +1,5 @@ pub mod fibonacci; pub mod poseidon; +pub mod range_check; pub mod wide_fibonacci; pub mod xor; diff --git a/crates/prover/src/examples/range_check/air.rs b/crates/prover/src/examples/range_check/air.rs new file mode 100644 index 000000000..47f8c6589 --- /dev/null +++ b/crates/prover/src/examples/range_check/air.rs @@ -0,0 +1,118 @@ +use super::component::{RangeCheckComponent, RangeCheckInput, RangeCheckTraceGenerator}; +use crate::core::air::{Air, AirProver, Component, ComponentProver}; +use crate::core::backend::CpuBackend; +use crate::core::channel::Blake2sChannel; +use crate::core::fields::m31::BaseField; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::prover::VerificationError; +use crate::core::{ColumnVec, InteractionElements, LookupValues}; +use crate::trace_generation::registry::ComponentGenerationRegistry; +use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; + +pub struct RangeCheckAirGenerator { + pub registry: ComponentGenerationRegistry, +} + +impl RangeCheckAirGenerator { + pub fn new(inputs: &RangeCheckInput) -> Self { + let mut component_generator = RangeCheckTraceGenerator::new(); + component_generator.add_inputs(inputs); + let mut registry = ComponentGenerationRegistry::default(); + registry.register("range_check", component_generator); + Self { registry } + } +} + +impl AirTraceVerifier for RangeCheckAirGenerator { + fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { + InteractionElements::default() + } +} + +impl AirTraceGenerator for RangeCheckAirGenerator { + fn write_trace(&mut self) -> Vec> { + RangeCheckTraceGenerator::write_trace("range_check", &mut self.registry) + } + + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> Vec> { + vec![] + } + + fn to_air_prover(&self) -> impl AirProver { + let component_generator = self + .registry + .get_generator::("range_check"); + RangeCheckAir { + component: component_generator.component(), + } + } + + fn composition_log_degree_bound(&self) -> u32 { + let component_generator = self + .registry + .get_generator::("range_check"); + assert!( + component_generator.inputs_set(), + "range_check input not set." + ); + component_generator + .component() + .max_constraint_log_degree_bound() + } +} + +#[derive(Clone)] +pub struct RangeCheckAir { + pub component: RangeCheckComponent, +} + +impl RangeCheckAir { + pub fn new(component: RangeCheckComponent) -> Self { + Self { component } + } +} + +impl Air for RangeCheckAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.component] + } + + fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> { + Ok(()) + } +} + +impl AirTraceVerifier for RangeCheckAir { + fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { + InteractionElements::default() + } +} + +impl AirTraceGenerator for RangeCheckAir { + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> Vec> { + vec![] + } + + fn to_air_prover(&self) -> impl AirProver { + self.clone() + } + + fn composition_log_degree_bound(&self) -> u32 { + self.component.max_constraint_log_degree_bound() + } +} + +impl AirProver for RangeCheckAir { + fn prover_components(&self) -> Vec<&dyn ComponentProver> { + vec![&self.component] + } +} diff --git a/crates/prover/src/examples/range_check/component.rs b/crates/prover/src/examples/range_check/component.rs new file mode 100644 index 000000000..342d5f0ec --- /dev/null +++ b/crates/prover/src/examples/range_check/component.rs @@ -0,0 +1,231 @@ +use std::ops::Div; + +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::mask::shifted_mask_points; +use crate::core::air::{Component, ComponentProver, ComponentTrace}; +use crate::core::backend::CpuBackend; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::constraints::point_vanishing; +use crate::core::fields::m31::{BaseField, M31}; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::ExtensionOf; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::bit_reverse_index; +use crate::core::{ColumnVec, InteractionElements, LookupValues}; +use crate::trace_generation::registry::ComponentGenerationRegistry; +use crate::trace_generation::{ComponentGen, ComponentTraceGenerator, BASE_TRACE}; + +#[derive(Clone)] +pub struct RangeCheckComponent { + pub log_size: u32, + pub value: BaseField, +} + +impl RangeCheckComponent { + pub fn new(log_size: u32, value: BaseField) -> Self { + Self { log_size, value } + } + + /// Evaluates the step constraint quotient polynomial on a single point. + fn step_constraint_eval_quotient_by_mask>( + &self, + point: CirclePoint, + mask: &[F; 16], + ) -> F { + let two = F::one().double(); + let constraint_zero_domain = Coset::subgroup(self.log_size); + // Check if value at mask[0] equals value represented by 15 next bits little endian. + // If value can be represented with 15 bits, it means that it is in range of 0..2^15 + let constraint_value = mask[0] + - mask[1..] + .iter() + .enumerate() + .map(|(i, &val)| val * two.pow(i as u128)) + .sum::(); + let num = constraint_value; + // Apply this step constrain on first row + let denom = point_vanishing(constraint_zero_domain.at(0).into_ef(), point); + num / denom + } + + /// Evaluates the boundary constraint quotient polynomial on a single point. + fn boundary_constraint_eval_quotient_by_mask>( + &self, + point: CirclePoint, + mask: &[F; 1], + ) -> F { + let constraint_zero_domain = Coset::subgroup(self.log_size); + // Check if mask value is a binary 0 | 1 + let constraint_value = mask[0].square() - mask[0]; + let num = constraint_value; + // Apply this boundary constrain on 1..16 trace rows + let denom = (1..16) + .map(|i| point_vanishing(constraint_zero_domain.at(i), point)) + .product::(); + num / denom + } +} + +impl Component for RangeCheckComponent { + fn n_constraints(&self) -> usize { + 2 + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + // Step constraint is of degree 2. + self.log_size + 1 + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + TreeVec::new(vec![vec![self.log_size]]) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + TreeVec::new(vec![shifted_mask_points( + &vec![(0..16).collect::>()], + &[CanonicCoset::new(self.log_size)], + point, + )]) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, + _lookup_values: &LookupValues, + ) { + evaluation_accumulator.accumulate( + self.step_constraint_eval_quotient_by_mask(point, &mask[0][0][..].try_into().unwrap()), + ); + evaluation_accumulator.accumulate(self.boundary_constraint_eval_quotient_by_mask( + point, + &mask[0][0][..1].try_into().unwrap(), + )); + } +} + +#[derive(Copy, Clone)] +pub struct RangeCheckInput { + pub log_size: u32, + pub value: BaseField, +} + +#[derive(Clone)] +pub struct RangeCheckTraceGenerator { + input: Option, +} + +impl ComponentGen for RangeCheckTraceGenerator {} + +impl RangeCheckTraceGenerator { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { input: None } + } + + pub fn inputs_set(&self) -> bool { + self.input.is_some() + } +} + +impl ComponentTraceGenerator for RangeCheckTraceGenerator { + type Component = RangeCheckComponent; + type Inputs = RangeCheckInput; + + fn add_inputs(&mut self, inputs: &Self::Inputs) { + assert!(!self.inputs_set(), "range_check input already set."); + self.input = Some(*inputs); + } + + fn write_trace( + component_id: &str, + registry: &mut ComponentGenerationRegistry, + ) -> ColumnVec> { + let trace_generator = registry.get_generator_mut::(component_id); + assert!(trace_generator.inputs_set(), "range_check input not set."); + let trace_domain = CanonicCoset::new(trace_generator.input.unwrap().log_size); + let mut trace = Vec::with_capacity(trace_domain.size()); + + if let Some(input) = trace_generator.input { + // Push the value to the trace. + trace.push(input.value); + + // Fill trace with binary representation of value. + let mut value_bits = input.value.0; + for _ in 0..15 { + trace.push(M31::from(value_bits & 0x1)); + value_bits >>= 1; + } + } + + // Returns as a CircleEvaluation. + vec![CircleEvaluation::new_canonical_ordered(trace_domain, trace)] + } + + fn write_interaction_trace( + &self, + _trace: &ColumnVec<&CircleEvaluation>, + _elements: &InteractionElements, + ) -> ColumnVec> { + vec![] + } + + fn component(&self) -> Self::Component { + assert!(self.inputs_set(), "range_check input not set."); + RangeCheckComponent::new(self.input.unwrap().log_size, self.input.unwrap().value) + } +} + +impl ComponentProver for RangeCheckComponent { + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &ComponentTrace<'_, CpuBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, + _lookup_values: &LookupValues, + ) { + let poly = &trace.polys[BASE_TRACE][0]; + let trace_domain = CanonicCoset::new(self.log_size); + let trace_eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain(); + let trace_eval = poly.evaluate(trace_eval_domain).bit_reverse(); + + // Step constraint. + let constraint_log_degree_bound = trace_domain.log_size() + 1; + let [mut accum] = evaluation_accumulator.columns([(constraint_log_degree_bound, 2)]); + let constraint_eval_domain = trace_eval_domain; + for (off, point_coset) in [ + (0, constraint_eval_domain.half_coset), + ( + constraint_eval_domain.half_coset.size(), + constraint_eval_domain.half_coset.conjugate(), + ), + ] { + let eval = trace_eval.fetch_eval_on_coset(point_coset.shift(trace_domain.index_at(0))); + let mul = trace_domain.step_size().div(point_coset.step_size); + for (i, point) in point_coset.iter().enumerate() { + let mask: [M31; 16] = (0..16) + .map(|j| eval[i as isize + j * mul]) + .collect::>() + .try_into() + .unwrap(); + + let mut res = self.boundary_constraint_eval_quotient_by_mask(point, &[mask[0]]) + * accum.random_coeff_powers[0]; + res += self.step_constraint_eval_quotient_by_mask(point, &mask) + * accum.random_coeff_powers[1]; + accum.accumulate(bit_reverse_index(i + off, constraint_log_degree_bound), res); + } + } + } + + fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + LookupValues::default() + } +} diff --git a/crates/prover/src/examples/range_check/mod.rs b/crates/prover/src/examples/range_check/mod.rs new file mode 100644 index 000000000..e5fd405c3 --- /dev/null +++ b/crates/prover/src/examples/range_check/mod.rs @@ -0,0 +1,98 @@ +use air::RangeCheckAir; + +use self::component::RangeCheckComponent; +use crate::core::backend::cpu::CpuCircleEvaluation; +use crate::core::channel::{Blake2sChannel, Channel}; +use crate::core::fields::m31::{BaseField, M31}; +use crate::core::fields::IntoSlice; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::prover::{ProvingError, StarkProof, VerificationError}; +use crate::core::vcs::blake2_hash::Blake2sHasher; +use crate::core::vcs::hasher::Hasher; +use crate::trace_generation::{commit_and_prove, commit_and_verify}; + +pub mod air; +mod component; + +#[derive(Clone)] +pub struct RangeCheck { + pub air: RangeCheckAir, +} + +impl RangeCheck { + pub fn new(log_size: u32, value: BaseField) -> Self { + let component = RangeCheckComponent::new(log_size, value); + Self { + air: RangeCheckAir::new(component), + } + } + + pub fn get_trace(&self) -> CpuCircleEvaluation { + // Trace. + let trace_domain = CanonicCoset::new(self.air.component.log_size); + let mut trace = Vec::with_capacity(trace_domain.size()); + + // Push the value to the trace. + trace.push(self.air.component.value); + + // Fill trace with binary representation of value. + let mut value_bits = self.air.component.value.0; + for _ in 0..15 { + trace.push(M31::from(value_bits & 0x1)); + value_bits >>= 1; + } + + // Returns as a CircleEvaluation. + CircleEvaluation::new_canonical_ordered(trace_domain, trace) + } + + pub fn prove(&self) -> Result { + let trace = self.get_trace(); + let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[self + .air + .component + .value]))); + commit_and_prove(&self.air, channel, vec![trace]) + } + + pub fn verify(&self, proof: StarkProof) -> Result<(), VerificationError> { + let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[self + .air + .component + .value]))); + commit_and_verify(proof, &self.air, channel) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::RangeCheck; + use crate::m31; + + const RANGE_CHECK_LOG_SIZE: u32 = 4; + + proptest! { + #![proptest_config(ProptestConfig { + cases: 50, // Number of test cases to generate + .. ProptestConfig::default() + })] + + #[test] + fn test_range_check_prove(valid_value in 0..32768_u32) { + let range_check = RangeCheck::new(RANGE_CHECK_LOG_SIZE, m31!(valid_value)); + let proof = range_check.prove().unwrap(); + range_check.verify(proof).unwrap(); + } + + #[test] + #[should_panic] + fn test_range_check_prove_overflow(invalid_value in 32768..u32::MAX) { + let range_check = RangeCheck::new(RANGE_CHECK_LOG_SIZE, m31!(invalid_value)); + let proof = range_check.prove().unwrap(); + range_check.verify(proof).unwrap(); + } + } +}