From e25129226ec7774414d624d839591aaf2d480f79 Mon Sep 17 00:00:00 2001 From: Chris Sosnin Date: Wed, 20 Sep 2023 20:02:47 +0400 Subject: [PATCH 1/3] make nonnative gadget params configurable --- src/fields/nonnative/allocated_field_var.rs | 157 ++++++++----------- src/fields/nonnative/allocated_mul_result.rs | 68 ++++---- src/fields/nonnative/field_var.rs | 140 +++++++++-------- src/fields/nonnative/mul_result.rs | 36 +++-- src/fields/nonnative/params.rs | 30 ++++ src/fields/nonnative/reduce.rs | 44 +++--- src/groups/curves/short_weierstrass/mod.rs | 5 +- tests/to_constraint_field_test.rs | 4 +- 8 files changed, 256 insertions(+), 228 deletions(-) diff --git a/src/fields/nonnative/allocated_field_var.rs b/src/fields/nonnative/allocated_field_var.rs index aadbe1a3..839e8d78 100644 --- a/src/fields/nonnative/allocated_field_var.rs +++ b/src/fields/nonnative/allocated_field_var.rs @@ -1,5 +1,5 @@ use super::{ - params::{get_params, OptimizationType}, + params::{DefaultParams, OptimizationType, Params}, reduce::{bigint_to_basefield, limbs_to_bigint, Reducer}, AllocatedNonNativeFieldMulResultVar, }; @@ -22,7 +22,11 @@ use ark_std::{ /// The allocated version of `NonNativeFieldVar` (introduced below) #[derive(Debug)] #[must_use] -pub struct AllocatedNonNativeFieldVar { +pub struct AllocatedNonNativeFieldVar< + TargetField: PrimeField, + BaseField: PrimeField, + P: Params = DefaultParams, +> { /// Constraint system reference pub cs: ConstraintSystemRef, /// The limbs, each of which is a BaseField gadget. @@ -36,10 +40,12 @@ pub struct AllocatedNonNativeFieldVar, + #[doc(hidden)] + pub params_phantom: PhantomData

, } -impl - AllocatedNonNativeFieldVar +impl + AllocatedNonNativeFieldVar { /// Return cs pub fn cs(&self) -> ConstraintSystemRef { @@ -51,11 +57,7 @@ impl limbs: Vec, optimization_type: OptimizationType, ) -> TargetField { - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - optimization_type, - ); + let params =

::get::(optimization_type); let mut base_repr: ::BigInt = TargetField::one().into_bigint(); @@ -122,6 +124,7 @@ impl num_of_additions_over_normal_form: BaseField::zero(), is_in_the_normal_form: true, target_phantom: PhantomData, + params_phantom: PhantomData, }) } @@ -154,9 +157,10 @@ impl .add(&BaseField::one()), is_in_the_normal_form: false, target_phantom: PhantomData, + params_phantom: PhantomData, }; - Reducer::::post_add_reduce(&mut res)?; + Reducer::::post_add_reduce(&mut res)?; Ok(res) } @@ -178,9 +182,10 @@ impl .add(&BaseField::one()), is_in_the_normal_form: false, target_phantom: PhantomData, + params_phantom: PhantomData, }; - Reducer::::post_add_reduce(&mut res)?; + Reducer::::post_add_reduce(&mut res)?; Ok(res) } @@ -190,11 +195,7 @@ impl pub fn sub_without_reduce(&self, other: &Self) -> R1CSResult { assert_eq!(self.get_optimization_type(), other.get_optimization_type()); - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - self.get_optimization_type(), - ); + let params =

::get::(self.get_optimization_type()); // Step 1: reduce the `other` if needed let mut surfeit = overhead!(other.num_of_additions_over_normal_form + BaseField::one()) + 1; @@ -251,7 +252,7 @@ impl } } - let result = AllocatedNonNativeFieldVar:: { + let result = AllocatedNonNativeFieldVar:: { cs: self.cs(), limbs, num_of_additions_over_normal_form: self.num_of_additions_over_normal_form @@ -259,6 +260,7 @@ impl + (other.num_of_additions_over_normal_form + BaseField::one()), is_in_the_normal_form: false, target_phantom: PhantomData, + params_phantom: PhantomData, }; Ok(result) @@ -270,7 +272,7 @@ impl assert_eq!(self.get_optimization_type(), other.get_optimization_type()); let mut result = self.sub_without_reduce(other)?; - Reducer::::post_add_reduce(&mut result)?; + Reducer::::post_add_reduce(&mut result)?; Ok(result) } @@ -326,11 +328,7 @@ impl elem: &::BigInt, optimization_type: OptimizationType, ) -> R1CSResult> { - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - optimization_type, - ); + let params =

::get::(optimization_type); // push the lower limbs first let mut limbs: Vec = Vec::new(); @@ -358,19 +356,18 @@ impl pub fn mul_without_reduce( &self, other: &Self, - ) -> R1CSResult> { + ) -> R1CSResult> { assert_eq!(self.get_optimization_type(), other.get_optimization_type()); - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - self.get_optimization_type(), - ); + let params =

::get::(self.get_optimization_type()); // Step 1: reduce `self` and `other` if neceessary let mut self_reduced = self.clone(); let mut other_reduced = other.clone(); - Reducer::::pre_mul_reduce(&mut self_reduced, &mut other_reduced)?; + Reducer::::pre_mul_reduce( + &mut self_reduced, + &mut other_reduced, + )?; let mut prod_limbs = Vec::new(); if self.get_optimization_type() == OptimizationType::Weight { @@ -441,6 +438,7 @@ impl + BaseField::one()) * (other_reduced.num_of_additions_over_normal_form + BaseField::one()), target_phantom: PhantomData, + params_phantom: PhantomData, }) } @@ -455,15 +453,11 @@ impl ) -> R1CSResult<()> { assert_eq!(self.get_optimization_type(), other.get_optimization_type()); - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - self.get_optimization_type(), - ); + let params =

::get::(self.get_optimization_type()); // Get p let p_representations = - AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( + AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( &::MODULUS, self.get_optimization_type() )?; @@ -473,12 +467,13 @@ impl for limb in p_representations.iter() { p_gadget_limbs.push(FpVar::::Constant(*limb)); } - let p_gadget = AllocatedNonNativeFieldVar:: { + let p_gadget = AllocatedNonNativeFieldVar:: { cs: self.cs(), limbs: p_gadget_limbs, num_of_additions_over_normal_form: BaseField::one(), is_in_the_normal_form: false, target_phantom: PhantomData, + params_phantom: PhantomData, }; // Get delta = self - other @@ -499,7 +494,7 @@ impl })?; let surfeit = overhead!(delta.num_of_additions_over_normal_form + BaseField::one()) + 1; - Reducer::::limb_to_bits(&k_gadget, surfeit)?; + Reducer::::limb_to_bits(&k_gadget, surfeit)?; // Compute k * p let mut kp_gadget_limbs = Vec::new(); @@ -508,7 +503,7 @@ impl } // Enforce delta = kp - Reducer::::group_and_check_equality( + Reducer::::group_and_check_equality( surfeit, params.bits_per_limb, params.bits_per_limb, @@ -589,6 +584,7 @@ impl num_of_additions_over_normal_form, is_in_the_normal_form: mode != AllocationMode::Witness, target_phantom: PhantomData, + params_phantom: PhantomData, }) } @@ -607,22 +603,18 @@ impl OptimizationGoal::Constraints => OptimizationType::Constraints, OptimizationGoal::Weight => OptimizationType::Weight, }; - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - optimization_type, - ); + let params =

::get::(optimization_type); let mut bits = Vec::new(); for limb in self.limbs.iter().rev().take(params.num_limbs - 1) { bits.extend( - Reducer::::limb_to_bits(limb, params.bits_per_limb)? + Reducer::::limb_to_bits(limb, params.bits_per_limb)? .into_iter() .rev(), ); } bits.extend( - Reducer::::limb_to_bits( + Reducer::::limb_to_bits( &self.limbs[0], TargetField::MODULUS_BIT_SIZE as usize - (params.num_limbs - 1) * params.bits_per_limb, @@ -650,26 +642,22 @@ impl } } -impl ToBitsGadget - for AllocatedNonNativeFieldVar +impl ToBitsGadget + for AllocatedNonNativeFieldVar { #[tracing::instrument(target = "r1cs")] fn to_bits_le(&self) -> R1CSResult>> { - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - self.get_optimization_type(), - ); + let params =

::get::(self.get_optimization_type()); // Reduce to the normal form // Though, a malicious prover can make it slightly larger than p let mut self_normal = self.clone(); - Reducer::::pre_eq_reduce(&mut self_normal)?; + Reducer::::pre_eq_reduce(&mut self_normal)?; // Therefore, we convert it to bits and enforce that it is in the field let mut bits = Vec::>::new(); for limb in self_normal.limbs.iter() { - bits.extend_from_slice(&Reducer::::limb_to_bits( + bits.extend_from_slice(&Reducer::::limb_to_bits( &limb, params.bits_per_limb, )?); @@ -690,8 +678,8 @@ impl ToBitsGadget } } -impl ToBytesGadget - for AllocatedNonNativeFieldVar +impl ToBytesGadget + for AllocatedNonNativeFieldVar { #[tracing::instrument(target = "r1cs")] fn to_bytes(&self) -> R1CSResult>> { @@ -706,8 +694,8 @@ impl ToBytesGadget } } -impl CondSelectGadget - for AllocatedNonNativeFieldVar +impl CondSelectGadget + for AllocatedNonNativeFieldVar { #[tracing::instrument(target = "r1cs")] fn conditionally_select( @@ -736,12 +724,13 @@ impl CondSelectGadget is_in_the_normal_form: true_value.is_in_the_normal_form && false_value.is_in_the_normal_form, target_phantom: PhantomData, + params_phantom: PhantomData, }) } } -impl TwoBitLookupGadget - for AllocatedNonNativeFieldVar +impl TwoBitLookupGadget + for AllocatedNonNativeFieldVar { type TableConstant = TargetField; @@ -761,11 +750,7 @@ impl TwoBitLookupGadget OptimizationType::Weight, }; - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - optimization_type, - ); + let params =

::get::(optimization_type); let mut limbs_constants = Vec::new(); for _ in 0..params.num_limbs { limbs_constants.push(Vec::new()); @@ -773,7 +758,7 @@ impl TwoBitLookupGadget::get_limbs_representations( + AllocatedNonNativeFieldVar::::get_limbs_representations( constant, optimization_type, )?; @@ -788,18 +773,20 @@ impl TwoBitLookupGadget::two_bit_lookup(bits, limbs_constant)?); } - Ok(AllocatedNonNativeFieldVar:: { + Ok(AllocatedNonNativeFieldVar:: { cs, limbs, num_of_additions_over_normal_form: BaseField::zero(), is_in_the_normal_form: true, target_phantom: PhantomData, + params_phantom: PhantomData, }) } } -impl ThreeBitCondNegLookupGadget - for AllocatedNonNativeFieldVar +impl + ThreeBitCondNegLookupGadget + for AllocatedNonNativeFieldVar { type TableConstant = TargetField; @@ -820,11 +807,7 @@ impl ThreeBitCondNegLookupGadget OptimizationGoal::Weight => OptimizationType::Weight, }; - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - optimization_type, - ); + let params =

::get::(optimization_type); let mut limbs_constants = Vec::new(); for _ in 0..params.num_limbs { @@ -833,7 +816,7 @@ impl ThreeBitCondNegLookupGadget for constant in constants.iter() { let representations = - AllocatedNonNativeFieldVar::::get_limbs_representations( + AllocatedNonNativeFieldVar::::get_limbs_representations( constant, optimization_type, )?; @@ -852,18 +835,19 @@ impl ThreeBitCondNegLookupGadget )?); } - Ok(AllocatedNonNativeFieldVar:: { + Ok(AllocatedNonNativeFieldVar:: { cs, limbs, num_of_additions_over_normal_form: BaseField::zero(), is_in_the_normal_form: true, target_phantom: PhantomData, + params_phantom: PhantomData, }) } } -impl AllocVar - for AllocatedNonNativeFieldVar +impl AllocVar + for AllocatedNonNativeFieldVar { fn new_variable>( cs: impl Into>, @@ -880,8 +864,8 @@ impl AllocVar ToConstraintFieldGadget - for AllocatedNonNativeFieldVar +impl ToConstraintFieldGadget + for AllocatedNonNativeFieldVar { fn to_constraint_field(&self) -> R1CSResult>> { // provide a unique representation of the nonnative variable @@ -889,11 +873,7 @@ impl ToConstraintFieldGadget::get::(OptimizationType::Weight); // step 3: assemble the limbs let mut limbs = bits @@ -918,8 +898,8 @@ impl ToConstraintFieldGadget Clone - for AllocatedNonNativeFieldVar +impl Clone + for AllocatedNonNativeFieldVar { fn clone(&self) -> Self { AllocatedNonNativeFieldVar { @@ -928,6 +908,7 @@ impl Clone num_of_additions_over_normal_form: self.num_of_additions_over_normal_form, is_in_the_normal_form: self.is_in_the_normal_form, target_phantom: PhantomData, + params_phantom: PhantomData, } } } diff --git a/src/fields/nonnative/allocated_mul_result.rs b/src/fields/nonnative/allocated_mul_result.rs index 07c74daf..900f99e0 100644 --- a/src/fields/nonnative/allocated_mul_result.rs +++ b/src/fields/nonnative/allocated_mul_result.rs @@ -1,5 +1,5 @@ use super::{ - params::{get_params, OptimizationType}, + params::{DefaultParams, OptimizationType, Params}, reduce::{bigint_to_basefield, limbs_to_bigint, Reducer}, AllocatedNonNativeFieldVar, }; @@ -15,7 +15,11 @@ use num_bigint::BigUint; /// The allocated form of `NonNativeFieldMulResultVar` (introduced below) #[derive(Debug)] #[must_use] -pub struct AllocatedNonNativeFieldMulResultVar { +pub struct AllocatedNonNativeFieldMulResultVar< + TargetField: PrimeField, + BaseField: PrimeField, + P: Params = DefaultParams, +> { /// Constraint system reference pub cs: ConstraintSystemRef, /// Limbs of the intermediate representations @@ -24,18 +28,16 @@ pub struct AllocatedNonNativeFieldMulResultVar, + #[doc(hidden)] + pub params_phantom: PhantomData

, } -impl - From<&AllocatedNonNativeFieldVar> - for AllocatedNonNativeFieldMulResultVar +impl + From<&AllocatedNonNativeFieldVar> + for AllocatedNonNativeFieldMulResultVar { - fn from(src: &AllocatedNonNativeFieldVar) -> Self { - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - src.get_optimization_type(), - ); + fn from(src: &AllocatedNonNativeFieldVar) -> Self { + let params =

::get::(src.get_optimization_type()); let mut limbs = src.limbs.clone(); limbs.reverse(); @@ -49,12 +51,13 @@ impl limbs, prod_of_num_of_additions, target_phantom: PhantomData, + params_phantom: PhantomData, } } } -impl - AllocatedNonNativeFieldMulResultVar +impl + AllocatedNonNativeFieldMulResultVar { /// Get the CS pub fn cs(&self) -> ConstraintSystemRef { @@ -63,14 +66,10 @@ impl /// Get the value of the multiplication result pub fn value(&self) -> R1CSResult { - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - self.get_optimization_type(), - ); + let params =

::get::(self.get_optimization_type()); let p_representations = - AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( + AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( &::MODULUS, self.get_optimization_type() )?; @@ -88,16 +87,12 @@ impl /// Constraints for reducing the result of a multiplication mod p, to get an /// original representation. - pub fn reduce(&self) -> R1CSResult> { - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - self.get_optimization_type(), - ); + pub fn reduce(&self) -> R1CSResult> { + let params =

::get::(self.get_optimization_type()); // Step 1: get p let p_representations = - AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( + AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( &::MODULUS, self.get_optimization_type() )?; @@ -107,12 +102,13 @@ impl for limb in p_representations.iter() { p_gadget_limbs.push(FpVar::::new_constant(self.cs(), limb)?); } - let p_gadget = AllocatedNonNativeFieldVar:: { + let p_gadget = AllocatedNonNativeFieldVar:: { cs: self.cs(), limbs: p_gadget_limbs, num_of_additions_over_normal_form: BaseField::one(), is_in_the_normal_form: false, target_phantom: PhantomData, + params_phantom: PhantomData, }; // Step 2: compute surfeit @@ -171,26 +167,23 @@ impl limbs }; - let k_gadget = AllocatedNonNativeFieldVar:: { + let k_gadget = AllocatedNonNativeFieldVar:: { cs: self.cs(), limbs: k_limbs, num_of_additions_over_normal_form: self.prod_of_num_of_additions, is_in_the_normal_form: false, target_phantom: PhantomData, + params_phantom: PhantomData, }; let cs = self.cs(); - let r_gadget = AllocatedNonNativeFieldVar::::new_witness( + let r_gadget = AllocatedNonNativeFieldVar::::new_witness( ns!(cs, "r"), || Ok(self.value()?), )?; - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - self.get_optimization_type(), - ); + let params =

::get::(self.get_optimization_type()); // Step 1: reduce `self` and `other` if neceessary let mut prod_limbs = Vec::new(); @@ -213,6 +206,7 @@ impl + BaseField::one()) * (k_gadget.num_of_additions_over_normal_form + BaseField::one()), target_phantom: PhantomData, + params_phantom: PhantomData, }; let kp_plus_r_limbs_len = kp_plus_r_gadget.limbs.len(); @@ -220,7 +214,7 @@ impl kp_plus_r_gadget.limbs[kp_plus_r_limbs_len - 1 - i] += limb; } - Reducer::::group_and_check_equality( + Reducer::::group_and_check_equality( surfeit, 2 * params.bits_per_limb, params.bits_per_limb, @@ -249,6 +243,7 @@ impl prod_of_num_of_additions: self.prod_of_num_of_additions + other.prod_of_num_of_additions, target_phantom: PhantomData, + params_phantom: PhantomData, }) } @@ -256,7 +251,7 @@ impl #[tracing::instrument(target = "r1cs")] pub fn add_constant(&self, other: &TargetField) -> R1CSResult { let mut other_limbs = - AllocatedNonNativeFieldVar::::get_limbs_representations( + AllocatedNonNativeFieldVar::::get_limbs_representations( other, self.get_optimization_type(), )?; @@ -279,6 +274,7 @@ impl limbs: new_limbs, prod_of_num_of_additions: self.prod_of_num_of_additions + BaseField::one(), target_phantom: PhantomData, + params_phantom: PhantomData, }) } diff --git a/src/fields/nonnative/field_var.rs b/src/fields/nonnative/field_var.rs index 2879e636..bb69efd0 100644 --- a/src/fields/nonnative/field_var.rs +++ b/src/fields/nonnative/field_var.rs @@ -1,4 +1,7 @@ -use super::{params::OptimizationType, AllocatedNonNativeFieldVar, NonNativeFieldMulResultVar}; +use super::{ + params::{DefaultParams, OptimizationType, Params}, + AllocatedNonNativeFieldVar, NonNativeFieldMulResultVar, +}; use crate::{ boolean::Boolean, fields::{fp::FpVar, FieldVar}, @@ -15,17 +18,22 @@ use ark_std::{ /// A gadget for representing non-native (`TargetField`) field elements over the /// constraint field (`BaseField`). -#[derive(Clone, Debug)] +#[derive(Derivative)] +#[derivative(Debug, Clone(bound = "TargetField: Clone, BaseField: Clone"))] #[must_use] -pub enum NonNativeFieldVar { +pub enum NonNativeFieldVar< + TargetField: PrimeField, + BaseField: PrimeField, + P: Params = DefaultParams, +> { /// Constant Constant(TargetField), /// Allocated gadget - Var(AllocatedNonNativeFieldVar), + Var(AllocatedNonNativeFieldVar), } -impl PartialEq - for NonNativeFieldVar +impl PartialEq + for NonNativeFieldVar { fn eq(&self, other: &Self) -> bool { self.value() @@ -34,21 +42,21 @@ impl PartialEq } } -impl Eq - for NonNativeFieldVar +impl Eq + for NonNativeFieldVar { } -impl Hash - for NonNativeFieldVar +impl Hash + for NonNativeFieldVar { fn hash(&self, state: &mut H) { self.value().unwrap_or_default().hash(state); } } -impl R1CSVar - for NonNativeFieldVar +impl R1CSVar + for NonNativeFieldVar { type Value = TargetField; @@ -67,8 +75,8 @@ impl R1CSVar } } -impl From> - for NonNativeFieldVar +impl From> + for NonNativeFieldVar { fn from(other: Boolean) -> Self { if let Boolean::Constant(b) = other { @@ -82,28 +90,28 @@ impl From> } } -impl - From> - for NonNativeFieldVar +impl + From> + for NonNativeFieldVar { - fn from(other: AllocatedNonNativeFieldVar) -> Self { + fn from(other: AllocatedNonNativeFieldVar) -> Self { Self::Var(other) } } -impl<'a, TargetField: PrimeField, BaseField: PrimeField> FieldOpsBounds<'a, TargetField, Self> - for NonNativeFieldVar +impl<'a, TargetField: PrimeField, BaseField: PrimeField, P: Params> + FieldOpsBounds<'a, TargetField, Self> for NonNativeFieldVar { } -impl<'a, TargetField: PrimeField, BaseField: PrimeField> - FieldOpsBounds<'a, TargetField, NonNativeFieldVar> - for &'a NonNativeFieldVar +impl<'a, TargetField: PrimeField, BaseField: PrimeField, P: Params> + FieldOpsBounds<'a, TargetField, NonNativeFieldVar> + for &'a NonNativeFieldVar { } -impl FieldVar - for NonNativeFieldVar +impl FieldVar + for NonNativeFieldVar { fn zero() -> Self { Self::Constant(TargetField::zero()) @@ -147,13 +155,13 @@ impl FieldVar, + NonNativeFieldVar, TargetField, Add, add, AddAssign, add_assign, - |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { + |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { use NonNativeFieldVar::*; match (this, other) { (Constant(c1), Constant(c2)) => Constant(*c1 + c2), @@ -161,18 +169,18 @@ impl_bounded_ops!( (Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()), } }, - |this: &'a NonNativeFieldVar, other: TargetField| { this + &NonNativeFieldVar::Constant(other) }, - (TargetField: PrimeField, BaseField: PrimeField), + |this: &'a NonNativeFieldVar, other: TargetField| { this + &NonNativeFieldVar::Constant(other) }, + (TargetField: PrimeField, BaseField: PrimeField, P: Params), ); impl_bounded_ops!( - NonNativeFieldVar, + NonNativeFieldVar, TargetField, Sub, sub, SubAssign, sub_assign, - |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { + |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { use NonNativeFieldVar::*; match (this, other) { (Constant(c1), Constant(c2)) => Constant(*c1 - c2), @@ -181,20 +189,20 @@ impl_bounded_ops!( (Var(v1), Var(v2)) => Var(v1.sub(v2).unwrap()), } }, - |this: &'a NonNativeFieldVar, other: TargetField| { + |this: &'a NonNativeFieldVar, other: TargetField| { this - &NonNativeFieldVar::Constant(other) }, - (TargetField: PrimeField, BaseField: PrimeField), + (TargetField: PrimeField, BaseField: PrimeField, P: Params), ); impl_bounded_ops!( - NonNativeFieldVar, + NonNativeFieldVar, TargetField, Mul, mul, MulAssign, mul_assign, - |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { + |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { use NonNativeFieldVar::*; match (this, other) { (Constant(c1), Constant(c2)) => Constant(*c1 * c2), @@ -202,21 +210,21 @@ impl_bounded_ops!( (Var(v1), Var(v2)) => Var(v1.mul(v2).unwrap()), } }, - |this: &'a NonNativeFieldVar, other: TargetField| { + |this: &'a NonNativeFieldVar, other: TargetField| { if other.is_zero() { NonNativeFieldVar::zero() } else { this * &NonNativeFieldVar::Constant(other) } }, - (TargetField: PrimeField, BaseField: PrimeField), + (TargetField: PrimeField, BaseField: PrimeField, P: Params), ); /// ************************************************************************* /// ************************************************************************* -impl EqGadget - for NonNativeFieldVar +impl EqGadget + for NonNativeFieldVar { #[tracing::instrument(target = "r1cs")] fn is_eq(&self, other: &Self) -> R1CSResult> { @@ -280,8 +288,8 @@ impl EqGadget } } -impl ToBitsGadget - for NonNativeFieldVar +impl ToBitsGadget + for NonNativeFieldVar { #[tracing::instrument(target = "r1cs")] fn to_bits_le(&self) -> R1CSResult>> { @@ -304,8 +312,8 @@ impl ToBitsGadget } } -impl ToBytesGadget - for NonNativeFieldVar +impl ToBytesGadget + for NonNativeFieldVar { /// Outputs the unique byte decomposition of `self` in *little-endian* /// form. @@ -331,8 +339,8 @@ impl ToBytesGadget } } -impl CondSelectGadget - for NonNativeFieldVar +impl CondSelectGadget + for NonNativeFieldVar { #[tracing::instrument(target = "r1cs")] fn conditionally_select( @@ -361,8 +369,8 @@ impl CondSelectGadget /// Uses two bits to perform a lookup into a table /// `b` is little-endian: `b[0]` is LSB. -impl TwoBitLookupGadget - for NonNativeFieldVar +impl TwoBitLookupGadget + for NonNativeFieldVar { type TableConstant = TargetField; @@ -383,8 +391,8 @@ impl TwoBitLookupGadget ThreeBitCondNegLookupGadget - for NonNativeFieldVar +impl + ThreeBitCondNegLookupGadget for NonNativeFieldVar { type TableConstant = TargetField; @@ -418,8 +426,8 @@ impl ThreeBitCondNegLookupGadget } } -impl AllocVar - for NonNativeFieldVar +impl AllocVar + for NonNativeFieldVar { fn new_variable>( cs: impl Into>, @@ -437,8 +445,8 @@ impl AllocVar ToConstraintFieldGadget - for NonNativeFieldVar +impl ToConstraintFieldGadget + for NonNativeFieldVar { #[tracing::instrument(target = "r1cs")] fn to_constraint_field(&self) -> R1CSResult>> { @@ -447,31 +455,35 @@ impl ToConstraintFieldGadget Ok(AllocatedNonNativeFieldVar::get_limbs_representations( - c, - OptimizationType::Weight, - )? - .into_iter() - .map(FpVar::constant) - .collect()), + Self::Constant(c) => Ok( + AllocatedNonNativeFieldVar::::get_limbs_representations( + c, + OptimizationType::Weight, + )? + .into_iter() + .map(FpVar::constant) + .collect(), + ), Self::Var(v) => v.to_constraint_field(), } } } -impl NonNativeFieldVar { +impl + NonNativeFieldVar +{ /// The `mul_without_reduce` for `NonNativeFieldVar` #[tracing::instrument(target = "r1cs")] pub fn mul_without_reduce( &self, other: &Self, - ) -> R1CSResult> { + ) -> R1CSResult> { match self { Self::Constant(c) => match other { Self::Constant(other_c) => Ok(NonNativeFieldMulResultVar::Constant(*c * other_c)), Self::Var(other_v) => { let self_v = - AllocatedNonNativeFieldVar::::new_constant( + AllocatedNonNativeFieldVar::::new_constant( self.cs(), c, )?; @@ -483,7 +495,7 @@ impl NonNativeFieldVar { let other_v = match other { Self::Constant(other_c) => { - AllocatedNonNativeFieldVar::::new_constant( + AllocatedNonNativeFieldVar::::new_constant( self.cs(), other_c, )? diff --git a/src/fields/nonnative/mul_result.rs b/src/fields/nonnative/mul_result.rs index a04e8d6e..91f090dd 100644 --- a/src/fields/nonnative/mul_result.rs +++ b/src/fields/nonnative/mul_result.rs @@ -1,4 +1,7 @@ -use super::{AllocatedNonNativeFieldMulResultVar, NonNativeFieldVar}; +use super::{ + params::{DefaultParams, Params}, + AllocatedNonNativeFieldMulResultVar, NonNativeFieldVar, +}; use ark_ff::PrimeField; use ark_relations::r1cs::Result as R1CSResult; @@ -12,15 +15,19 @@ use ark_relations::r1cs::Result as R1CSResult; /// This may help cut the number of reduce operations. #[derive(Debug)] #[must_use] -pub enum NonNativeFieldMulResultVar { +pub enum NonNativeFieldMulResultVar< + TargetField: PrimeField, + BaseField: PrimeField, + P: Params = DefaultParams, +> { /// as a constant Constant(TargetField), /// as an allocated gadget - Var(AllocatedNonNativeFieldMulResultVar), + Var(AllocatedNonNativeFieldMulResultVar), } -impl - NonNativeFieldMulResultVar +impl + NonNativeFieldMulResultVar { /// Create a zero `NonNativeFieldMulResultVar` (used for additions) pub fn zero() -> Self { @@ -34,7 +41,7 @@ impl /// Reduce the `NonNativeFieldMulResultVar` back to NonNativeFieldVar #[tracing::instrument(target = "r1cs")] - pub fn reduce(&self) -> R1CSResult> { + pub fn reduce(&self) -> R1CSResult> { match self { Self::Constant(c) => Ok(NonNativeFieldVar::Constant(*c)), Self::Var(v) => Ok(NonNativeFieldVar::Var(v.reduce()?)), @@ -42,17 +49,18 @@ impl } } -impl - From<&NonNativeFieldVar> - for NonNativeFieldMulResultVar +impl + From<&NonNativeFieldVar> + for NonNativeFieldMulResultVar { - fn from(src: &NonNativeFieldVar) -> Self { + fn from(src: &NonNativeFieldVar) -> Self { match src { NonNativeFieldVar::Constant(c) => NonNativeFieldMulResultVar::Constant(*c), NonNativeFieldVar::Var(v) => { NonNativeFieldMulResultVar::Var(AllocatedNonNativeFieldMulResultVar::< TargetField, BaseField, + P, >::from(v)) }, } @@ -60,13 +68,13 @@ impl } impl_bounded_ops!( - NonNativeFieldMulResultVar, + NonNativeFieldMulResultVar, TargetField, Add, add, AddAssign, add_assign, - |this: &'a NonNativeFieldMulResultVar, other: &'a NonNativeFieldMulResultVar| { + |this: &'a NonNativeFieldMulResultVar, other: &'a NonNativeFieldMulResultVar| { use NonNativeFieldMulResultVar::*; match (this, other) { (Constant(c1), Constant(c2)) => Constant(*c1 + c2), @@ -74,6 +82,6 @@ impl_bounded_ops!( (Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()), } }, - |this: &'a NonNativeFieldMulResultVar, other: TargetField| { this + &NonNativeFieldMulResultVar::Constant(other) }, - (TargetField: PrimeField, BaseField: PrimeField), + |this: &'a NonNativeFieldMulResultVar, other: TargetField| { this + &NonNativeFieldMulResultVar::Constant(other) }, + (TargetField: PrimeField, BaseField: PrimeField, P: Params), ); diff --git a/src/fields/nonnative/params.rs b/src/fields/nonnative/params.rs index c7ad48b0..9f29804d 100644 --- a/src/fields/nonnative/params.rs +++ b/src/fields/nonnative/params.rs @@ -1,5 +1,35 @@ +use ark_ff::PrimeField; +use ark_std::fmt; + use super::NonNativeFieldConfig; +/// Type that provides non native arithmetic configuration for a given pair of fields +/// and optimization goal. +pub trait Params: fmt::Debug + 'static { + /// Provide non native parameters. + /// + /// This function should be pure -- return the same config for the same input. + fn get( + optimization_type: OptimizationType, + ) -> NonNativeFieldConfig; +} + +/// Default parameters implementation based on [`find_parameters`]. +#[derive(Debug)] +pub struct DefaultParams; + +impl Params for DefaultParams { + fn get( + optimization_type: OptimizationType, + ) -> NonNativeFieldConfig { + get_params( + TargetField::MODULUS_BIT_SIZE as usize, + BaseField::MODULUS_BIT_SIZE as usize, + optimization_type, + ) + } +} + /// Obtain the parameters from a `ConstraintSystem`'s cache or generate a new /// one #[must_use] diff --git a/src/fields/nonnative/reduce.rs b/src/fields/nonnative/reduce.rs index 31e5fe7a..a0157eca 100644 --- a/src/fields/nonnative/reduce.rs +++ b/src/fields/nonnative/reduce.rs @@ -1,4 +1,4 @@ -use super::{overhead, params::get_params, AllocatedNonNativeFieldVar}; +use super::{overhead, params::Params, AllocatedNonNativeFieldVar}; use crate::{ alloc::AllocVar, boolean::Boolean, @@ -56,12 +56,13 @@ pub fn bigint_to_basefield(bigint: &BigUint) -> BaseField } /// the collections of methods for reducing the presentations -pub struct Reducer { +pub struct Reducer { pub target_phantom: PhantomData, pub base_phantom: PhantomData, + pub params_phantom: PhantomData

, } -impl Reducer { +impl Reducer { /// convert limbs to bits (take at most `BaseField::MODULUS_BIT_SIZE as /// usize - 1` bits) This implementation would be more efficient than /// the original `to_bits` or `to_non_unique_bits` since we enforce that @@ -119,11 +120,13 @@ impl Reducer) -> R1CSResult<()> { - let new_elem = - AllocatedNonNativeFieldVar::new_witness(ns!(elem.cs(), "normal_form"), || { - Ok(elem.value().unwrap_or_default()) - })?; + pub fn reduce( + elem: &mut AllocatedNonNativeFieldVar, + ) -> R1CSResult<()> { + let new_elem = AllocatedNonNativeFieldVar::::new_witness( + ns!(elem.cs(), "normal_form"), + || Ok(elem.value().unwrap_or_default()), + )?; elem.conditional_enforce_equal(&new_elem, &Boolean::TRUE)?; *elem = new_elem; @@ -133,13 +136,9 @@ impl Reducer, + elem: &mut AllocatedNonNativeFieldVar, ) -> R1CSResult<()> { - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - elem.get_optimization_type(), - ); + let params =

::get::(elem.get_optimization_type()); let surfeit = overhead!(elem.num_of_additions_over_normal_form + BaseField::one()) + 1; if BaseField::MODULUS_BIT_SIZE as usize > 2 * params.bits_per_limb + surfeit + 1 { @@ -153,19 +152,15 @@ impl Reducer, - elem_other: &mut AllocatedNonNativeFieldVar, + elem: &mut AllocatedNonNativeFieldVar, + elem_other: &mut AllocatedNonNativeFieldVar, ) -> R1CSResult<()> { assert_eq!( elem.get_optimization_type(), elem_other.get_optimization_type() ); - let params = get_params( - TargetField::MODULUS_BIT_SIZE as usize, - BaseField::MODULUS_BIT_SIZE as usize, - elem.get_optimization_type(), - ); + let params =

::get::(elem.get_optimization_type()); if 2 * params.bits_per_limb + ark_std::log2(params.num_limbs) as usize > BaseField::MODULUS_BIT_SIZE as usize - 1 @@ -204,7 +199,7 @@ impl Reducer, + elem: &mut AllocatedNonNativeFieldVar, ) -> R1CSResult<()> { if elem.is_in_the_normal_form { return Ok(()); @@ -333,7 +328,10 @@ impl Reducer::limb_to_bits(&carry, surfeit + bits_per_limb)?; + Reducer::::limb_to_bits( + &carry, + surfeit + bits_per_limb, + )?; } } diff --git a/src/groups/curves/short_weierstrass/mod.rs b/src/groups/curves/short_weierstrass/mod.rs index c21822fe..37854238 100644 --- a/src/groups/curves/short_weierstrass/mod.rs +++ b/src/groups/curves/short_weierstrass/mod.rs @@ -1018,7 +1018,10 @@ mod test_sw_curve { ProjectiveVar::>::new_input(cs.clone(), || { Ok(point_out) })?; - let scalar = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar))?; + let scalar = + NonNativeFieldVar::::new_input(cs.clone(), || { + Ok(scalar) + })?; let mul = point_in.scalar_mul_le(scalar.to_bits_le().unwrap().iter())?; diff --git a/tests/to_constraint_field_test.rs b/tests/to_constraint_field_test.rs index d0a5ac02..de897cf6 100644 --- a/tests/to_constraint_field_test.rs +++ b/tests/to_constraint_field_test.rs @@ -10,8 +10,8 @@ fn to_constraint_field_test() { let cs = ConstraintSystem::::new_ref(); - let a = NonNativeFieldVar::Constant(F::from(12u8)); - let b = NonNativeFieldVar::new_input(cs.clone(), || Ok(F::from(6u8))).unwrap(); + let a = NonNativeFieldVar::::Constant(F::from(12u8)); + let b = NonNativeFieldVar::::new_input(cs.clone(), || Ok(F::from(6u8))).unwrap(); let b2 = &b + &b; From 41b25bf86c002f749ee9aedf816c7e8e93957933 Mon Sep 17 00:00:00 2001 From: Chris Sosnin Date: Wed, 20 Sep 2023 20:15:11 +0400 Subject: [PATCH 2/3] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b9d5cf67..22516d78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - [\#117](https://github.com/arkworks-rs/r1cs-std/pull/117) Fix result of `precomputed_base_scalar_mul_le` to not discard previous value. - [\#124](https://github.com/arkworks-rs/r1cs-std/pull/124) Fix `scalar_mul_le` constraints unsatisfiability when short Weierstrass point is zero. - [\#127](https://github.com/arkworks-rs/r1cs-std/pull/127) Convert `NonNativeFieldVar` constants to little-endian bytes instead of big-endian (`ToBytesGadget`). +- [\#129](https://github.com/arkworks-rs/r1cs-std/pull/129) Make the number of limbs and bit width configurable for `NonNativeFieldVar`. ### Breaking changes From ca738f83b85e774bea9cdd34c9a272963e067f45 Mon Sep 17 00:00:00 2001 From: Kristian Sosnin <48099298+slumber@users.noreply.github.com> Date: Fri, 29 Sep 2023 18:19:30 +0300 Subject: [PATCH 3/3] Update src/fields/nonnative/params.rs Co-authored-by: Pratyush Mishra --- src/fields/nonnative/params.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fields/nonnative/params.rs b/src/fields/nonnative/params.rs index 9f29804d..5838440f 100644 --- a/src/fields/nonnative/params.rs +++ b/src/fields/nonnative/params.rs @@ -14,7 +14,7 @@ pub trait Params: fmt::Debug + 'static { ) -> NonNativeFieldConfig; } -/// Default parameters implementation based on [`find_parameters`]. +/// Default parameters implementation based on [`get_params`]. #[derive(Debug)] pub struct DefaultParams;