Skip to content

Commit

Permalink
Fix bitwise rotate (#109)
Browse files Browse the repository at this point in the history
Co-authored-by: Xinding Wei <[email protected]>
  • Loading branch information
nyunyunyunyu and nyunyunyunyu authored Nov 3, 2023
1 parent c40a496 commit 262f5c5
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 10 deletions.
92 changes: 92 additions & 0 deletions halo2-base/src/gates/range/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,32 @@ pub trait RangeInstructions<F: ScalarField> {
self.gate().assert_bit(ctx, bit);
bit
}

/// Bitwise right rotate a by BIT bits. BIT and NUM_BITS must be determined at compile time.
///
/// Assumes 'a' is a NUM_BITS bit integer and 0 < NUM_BITS <= 128.
/// * `ctx`: [Context] to add the constraints to
/// * `a`: a [AssignedValue] value.
fn const_right_rotate<const BIT: usize, const NUM_BITS: usize>(
&self,
ctx: &mut Context<F>,
a: AssignedValue<F>,
) -> AssignedValue<F>
where
F: BigPrimeField;

/// Bitwise left rotate a by BIT bits. BIT and NUM_BITS must be determined at compile time.
///
/// Assumes 'a' is a NUM_BITS bit integer and 0 < NUM_BITS <= 128.
/// * `ctx`: [Context] to add the constraints to
/// * `a`: a [AssignedValue] value.
fn const_left_rotate<const BIT: usize, const NUM_BITS: usize>(
&self,
ctx: &mut Context<F>,
a: AssignedValue<F>,
) -> AssignedValue<F>
where
F: BigPrimeField;
}

/// # RangeChip
Expand Down Expand Up @@ -517,6 +543,41 @@ impl<F: ScalarField> RangeChip<F> {
}
last_limb
}

/// Bitwise right rotate a by <bit> bits. This function should never be called directly
/// because const bitwise rotation must be determined at compile time.
///
/// Assumes 'a' is a `num_bits` bit integer and `0 < num_bits <= F::CAPACITY`.
fn const_right_rotate_internal(
&self,
ctx: &mut Context<F>,
a: AssignedValue<F>,
bit: usize,
num_bits: usize,
) -> AssignedValue<F>
where
F: BigPrimeField,
{
assert!(0 < num_bits && num_bits <= F::CAPACITY as usize);
// Add a constrain a = l_witness << bit | r_wintess
let val = fe_to_biguint(a.value());
assert!(val.bits() <= num_bits as u64);
let (val_r, val_l) = val.div_mod_floor(&(BigUint::one() << bit));
let l_witness = ctx.load_witness(biguint_to_fe(&val_l));
let r_witness = ctx.load_witness(biguint_to_fe(&val_r));
let val_witness =
self.gate.mul_add(ctx, l_witness, Constant(self.gate.pow_of_two()[bit]), r_witness);
self.range_check(ctx, l_witness, num_bits - bit);
self.range_check(ctx, r_witness, bit);
ctx.constrain_equal(&a, &val_witness);
// Return (r_witness << (num_bits - bit)) | l_witness
self.gate.mul_add(
ctx,
r_witness,
Constant(self.gate.pow_of_two()[num_bits - bit]),
l_witness,
)
}
}

impl<F: ScalarField> RangeInstructions<F> for RangeChip<F> {
Expand Down Expand Up @@ -636,4 +697,35 @@ impl<F: ScalarField> RangeInstructions<F> for RangeChip<F> {
// last_limb will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b`
self.gate.is_zero(ctx, last_limb)
}

fn const_right_rotate<const BIT: usize, const NUM_BITS: usize>(
&self,
ctx: &mut Context<F>,
a: AssignedValue<F>,
) -> AssignedValue<F>
where
F: BigPrimeField,
{
let bit_to_shift = BIT % NUM_BITS;
if bit_to_shift == 0 {
return a;
};
self.const_right_rotate_internal(ctx, a, bit_to_shift, NUM_BITS)
}

fn const_left_rotate<const BIT: usize, const NUM_BITS: usize>(
&self,
ctx: &mut Context<F>,
a: AssignedValue<F>,
) -> AssignedValue<F>
where
F: BigPrimeField,
{
let bit_to_shift = BIT % NUM_BITS;
if bit_to_shift == 0 {
return a;
};
// left rotate by bit_to_shift == right rotate by (NUM_BITS - bit_to_shift)
self.const_right_rotate_internal(ctx, a, NUM_BITS - bit_to_shift, NUM_BITS)
}
}
55 changes: 45 additions & 10 deletions halo2-base/src/gates/tests/bitwise_rotate.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
gates::{
builder::{GateCircuitBuilder, GateThreadBuilder},
GateChip, GateInstructions,
builder::{GateThreadBuilder, RangeCircuitBuilder},
RangeChip, RangeInstructions,
},
halo2_proofs::{
plonk::keygen_pk,
Expand All @@ -13,6 +13,7 @@ use crate::{

use halo2_proofs_axiom::halo2curves::FieldExt;
use rand::rngs::OsRng;
use std::env;

use super::*;

Expand All @@ -25,19 +26,21 @@ fn test_bitwise_rotate_gen<const BIT: usize, const NUM_BITS: usize>(
expect_satisfied: bool,
) {
// first create proving and verifying key
let lookup_bits = 3;
env::set_var("LOOKUP_BITS", lookup_bits.to_string());
let mut builder = GateThreadBuilder::keygen();
let gate = GateChip::default();
let gate = RangeChip::<Fr>::default(lookup_bits);
let dummy_a = builder.main(0).load_witness(Fr::zero());
let result = if is_left {
gate.const_left_rotate_unsafe::<BIT, NUM_BITS>(builder.main(0), dummy_a)
gate.const_left_rotate::<BIT, NUM_BITS>(builder.main(0), dummy_a)
} else {
gate.const_right_rotate_unsafe::<BIT, NUM_BITS>(builder.main(0), dummy_a)
gate.const_right_rotate::<BIT, NUM_BITS>(builder.main(0), dummy_a)
};
// get the offsets of the indicator cells for later 'pranking'
let result_offsets = result.cell.unwrap().offset;
// set env vars
builder.config(k as usize, Some(9));
let circuit = GateCircuitBuilder::keygen(builder);
let circuit = RangeCircuitBuilder::keygen(builder);

let params = ParamsKZG::setup(k, OsRng);
// generate proving key
Expand All @@ -49,15 +52,16 @@ fn test_bitwise_rotate_gen<const BIT: usize, const NUM_BITS: usize>(

let gen_pf = || {
let mut builder = GateThreadBuilder::prover();
let gate = GateChip::default();
let gate = RangeChip::<Fr>::default(lookup_bits);
let a_witness = builder.main(0).load_witness(Fr::from_u128(a));
if is_left {
gate.const_left_rotate_unsafe::<BIT, NUM_BITS>(builder.main(0), a_witness)
gate.const_left_rotate::<BIT, NUM_BITS>(builder.main(0), a_witness);
} else {
gate.const_right_rotate_unsafe::<BIT, NUM_BITS>(builder.main(0), a_witness)
gate.const_right_rotate::<BIT, NUM_BITS>(builder.main(0), a_witness);
};
builder.main(0).advice[result_offsets] = Assigned::Trivial(Fr::from_u128(result_val));
let circuit = GateCircuitBuilder::prover(builder, vec![vec![]]); // no break points
builder.config(k as usize, Some(9));
let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points
gen_proof(&params, &pk, circuit)
};

Expand Down Expand Up @@ -89,3 +93,34 @@ fn test_bitwise_rotate() {
// 1u128 >> 5 != 2047
test_bitwise_rotate_gen::<5, 128>(8, false, 1, 2047, false);
}

#[test]
#[should_panic]
fn test_bitwise_rotate_zero_num_bits() {
let lookup_bits = 3;
let mut builder = GateThreadBuilder::keygen();
let gate = RangeChip::<Fr>::default(lookup_bits);
let dummy_a = builder.main(0).load_witness(Fr::zero());
gate.const_left_rotate::<1, 0>(builder.main(0), dummy_a);
}

#[test]
#[should_panic]
fn test_bitwise_rotate_too_large_num_bits() {
let lookup_bits = 3;
let mut builder = GateThreadBuilder::keygen();
let gate = RangeChip::<Fr>::default(lookup_bits);
let dummy_a = builder.main(0).load_witness(Fr::zero());
gate.const_left_rotate::<1, 200>(builder.main(0), dummy_a);
}

#[test]
#[should_panic]
fn test_bitwise_rotate_value_overflow() {
let lookup_bits = 3;
let mut builder = GateThreadBuilder::keygen();
let gate = RangeChip::<Fr>::default(lookup_bits);
// 1 << 128
let dummy_a = builder.main(0).load_witness(Fr::from_raw([0, 0, 1, 0]));
gate.const_left_rotate::<1, 128>(builder.main(0), dummy_a);
}

0 comments on commit 262f5c5

Please sign in to comment.