Skip to content

Commit

Permalink
Add cpu recursion circuit (#883)
Browse files Browse the repository at this point in the history
  • Loading branch information
sai-deng authored Nov 27, 2023
1 parent 0edf3f0 commit d3db446
Show file tree
Hide file tree
Showing 13 changed files with 1,098 additions and 33 deletions.
25 changes: 24 additions & 1 deletion circuits/src/cpu/add.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
//! This module implements the constraints for the ADD operation.
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use starky::constraint_consumer::ConstraintConsumer;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

use super::columns::CpuState;

Expand All @@ -20,6 +24,25 @@ pub(crate) fn constraints<P: PackedField>(
yield_constr.constraint(lv.inst.ops.add * (lv.dst_value - added) * (lv.dst_value - wrapped));
}

pub(crate) fn constraints_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &CpuState<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let wrap_at = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32));
let added = builder.add_extension(lv.op1_value, lv.op2_value);
let wrapped = builder.sub_extension(added, wrap_at);
let dst_value_sub_added = builder.sub_extension(lv.dst_value, added);
let dst_value_sub_wrapped = builder.sub_extension(lv.dst_value, wrapped);
let dst_value_sub_added_mul_dst_value_sub_wrapped =
builder.mul_extension(dst_value_sub_added, dst_value_sub_wrapped);
let constr = builder.mul_extension(
lv.inst.ops.add,
dst_value_sub_added_mul_dst_value_sub_wrapped,
);
yield_constr.constraint(builder, constr);
}

#[cfg(test)]
#[allow(clippy::cast_possible_wrap)]
mod tests {
Expand Down
84 changes: 83 additions & 1 deletion circuits/src/cpu/bitwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@
//! x | y := (x + y + (x ^ y)) / 2
//! `
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use starky::constraint_consumer::ConstraintConsumer;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

use super::columns::CpuState;
use crate::xor::columns::XorView;
Expand All @@ -36,6 +40,12 @@ pub struct BinaryOp<P: PackedField> {
pub output: P,
}

pub struct BinaryOpExtensionTarget<const D: usize> {
pub input_a: ExtensionTarget<D>,
pub input_b: ExtensionTarget<D>,
pub output: ExtensionTarget<D>,
}

/// Re-usable gadget for AND constraints.
/// It has access to already constrained XOR evaluation and based on that
/// constrains the AND evaluation: `x & y := (x + y - xor(x,y)) / 2`
Expand All @@ -49,6 +59,21 @@ pub(crate) fn and_gadget<P: PackedField>(xor: &XorView<P>) -> BinaryOp<P> {
}
}

pub(crate) fn and_gadget_extension_targets<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
xor: &XorView<ExtensionTarget<D>>,
) -> BinaryOpExtensionTarget<D> {
let two = F::Extension::from_canonical_u64(2);
let two_inv = builder.constant_extension(two.inverse());
let a_add_b = builder.add_extension(xor.a, xor.b);
let a_add_b_sub_xor = builder.sub_extension(a_add_b, xor.out);
BinaryOpExtensionTarget {
input_a: xor.a,
input_b: xor.b,
output: builder.mul_extension(a_add_b_sub_xor, two_inv),
}
}

/// Re-usable gadget for OR constraints
/// It has access to already constrained XOR evaluation and based on that
/// constrains the OR evaluation: `x | y := (x + y + xor(x,y)) / 2`
Expand All @@ -62,6 +87,21 @@ pub(crate) fn or_gadget<P: PackedField>(xor: &XorView<P>) -> BinaryOp<P> {
}
}

pub(crate) fn or_gadget_extension_targets<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
xor: &XorView<ExtensionTarget<D>>,
) -> BinaryOpExtensionTarget<D> {
let two = F::Extension::from_canonical_u64(2);
let two_inv = builder.constant_extension(two.inverse());
let a_add_b = builder.add_extension(xor.a, xor.b);
let a_add_b_add_xor = builder.add_extension(a_add_b, xor.out);
BinaryOpExtensionTarget {
input_a: xor.a,
input_b: xor.b,
output: builder.mul_extension(a_add_b_add_xor, two_inv),
}
}

/// Re-usable gadget for XOR constraints
/// Constrains that the already constrained underlying XOR evaluation has been
/// done on the same inputs and produced the same output as this gadget.
Expand All @@ -74,6 +114,16 @@ pub(crate) fn xor_gadget<P: PackedField>(xor: &XorView<P>) -> BinaryOp<P> {
}
}

pub(crate) fn xor_gadget_extension_targets<const D: usize>(
xor: &XorView<ExtensionTarget<D>>,
) -> BinaryOpExtensionTarget<D> {
BinaryOpExtensionTarget {
input_a: xor.a,
input_b: xor.b,
output: xor.out,
}
}

/// Constraints for the AND, OR and XOR opcodes.
/// As each opcode has an associated selector, we use selectors to enable only
/// the correct opcode constraints. It can be that all selectors are not active,
Expand All @@ -100,6 +150,38 @@ pub(crate) fn constraints<P: PackedField>(
}
}

pub(crate) fn constraints_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &CpuState<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let op1 = lv.op1_value;
let op2 = lv.op2_value;
let dst = lv.dst_value;

for (selector, gadget) in [
(
lv.inst.ops.and,
and_gadget_extension_targets(builder, &lv.xor),
),
(
lv.inst.ops.or,
or_gadget_extension_targets(builder, &lv.xor),
),
(lv.inst.ops.xor, xor_gadget_extension_targets(&lv.xor)),
] {
let input_a = builder.sub_extension(gadget.input_a, op1);
let input_b = builder.sub_extension(gadget.input_b, op2);
let output = builder.sub_extension(gadget.output, dst);
let constr = builder.mul_extension(selector, input_a);
yield_constr.constraint(builder, constr);
let constr = builder.mul_extension(selector, input_b);
yield_constr.constraint(builder, constr);
let constr = builder.mul_extension(selector, output);
yield_constr.constraint(builder, constr);
}
}

#[cfg(test)]
#[allow(clippy::cast_possible_wrap)]
mod tests {
Expand Down
99 changes: 96 additions & 3 deletions circuits/src/cpu/branches.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
//! This module implements constraints for the branch operations.
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use starky::constraint_consumer::ConstraintConsumer;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

use super::columns::CpuState;
use crate::stark::utils::is_binary;
use super::columns::{signed_diff_extension_target, CpuState};
use crate::stark::utils::{is_binary, is_binary_ext_circuit};

/// Constraints for `less_than` and `normalised_diff`
/// For `less_than`:
Expand Down Expand Up @@ -46,6 +50,38 @@ pub(crate) fn comparison_constraints<P: PackedField>(
yield_constr.constraint(lt * (P::ONES - lv.normalised_diff));
}

pub(crate) fn comparison_constraints_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &CpuState<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let lt = lv.less_than;
is_binary_ext_circuit(builder, lt, yield_constr);

let one = builder.constant_extension(F::Extension::ONE);
let one_sub_lt = builder.sub_extension(one, lt);
let signed_diff = signed_diff_extension_target(builder, lv);
let abs_diff_sub_signed_diff = builder.sub_extension(lv.abs_diff, signed_diff);
let constr = builder.mul_extension(one_sub_lt, abs_diff_sub_signed_diff);
yield_constr.constraint(builder, constr);

let abs_diff_add_signed_diff = builder.add_extension(lv.abs_diff, signed_diff);
let constr = builder.mul_extension(lt, abs_diff_add_signed_diff);
yield_constr.constraint(builder, constr);

is_binary_ext_circuit(builder, lv.normalised_diff, yield_constr);
let one_sub_normalised_diff = builder.sub_extension(one, lv.normalised_diff);
let constr = builder.mul_extension(signed_diff, one_sub_normalised_diff);
yield_constr.constraint(builder, constr);

let signed_diff_mul_cmp_diff_inv = builder.mul_extension(signed_diff, lv.cmp_diff_inv);
let constr = builder.sub_extension(signed_diff_mul_cmp_diff_inv, lv.normalised_diff);
yield_constr.constraint(builder, constr);

let lt_mul_one_sub_normalised_diff = builder.mul_extension(lt, one_sub_normalised_diff);
yield_constr.constraint(builder, lt_mul_one_sub_normalised_diff);
}

/// Constraints for conditional branch operations
pub(crate) fn constraints<P: PackedField>(
lv: &CpuState<P>,
Expand Down Expand Up @@ -81,6 +117,63 @@ pub(crate) fn constraints<P: PackedField>(
yield_constr.constraint(ops.bne * (P::ONES - lv.normalised_diff) * (next_pc - bumped_pc));
}

pub(crate) fn constraints_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &CpuState<ExtensionTarget<D>>,
nv: &CpuState<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let ops = &lv.inst.ops;
let is_blt = ops.blt;
let is_bge = ops.bge;

let four = builder.constant_extension(F::Extension::from_noncanonical_u64(4));
let bumped_pc = builder.add_extension(lv.inst.pc, four);
let branched_pc = lv.inst.imm_value;
let next_pc = nv.inst.pc;

let lt = lv.less_than;

let is_blt_mul_lt = builder.mul_extension(is_blt, lt);
let next_pc_sub_branched_pc = builder.sub_extension(next_pc, branched_pc);
let constr = builder.mul_extension(is_blt_mul_lt, next_pc_sub_branched_pc);
yield_constr.constraint(builder, constr);

let one = builder.constant_extension(F::Extension::ONE);
let one_sub_lt = builder.sub_extension(one, lt);
let next_pc_sub_bumped_pc = builder.sub_extension(next_pc, bumped_pc);
let is_blt_mul_one_sub_lt = builder.mul_extension(is_blt, one_sub_lt);
let constr = builder.mul_extension(is_blt_mul_one_sub_lt, next_pc_sub_bumped_pc);
yield_constr.constraint(builder, constr);

let is_bge_mul_lt = builder.mul_extension(is_bge, lt);
let constr = builder.mul_extension(is_bge_mul_lt, next_pc_sub_bumped_pc);
yield_constr.constraint(builder, constr);

let is_bge_mul_one_sub_lt = builder.mul_extension(is_bge, one_sub_lt);
let constr = builder.mul_extension(is_bge_mul_one_sub_lt, next_pc_sub_branched_pc);
yield_constr.constraint(builder, constr);

let one_sub_normalised_diff = builder.sub_extension(one, lv.normalised_diff);
let is_beq_mul_one_sub_normalised_diff =
builder.mul_extension(ops.beq, one_sub_normalised_diff);
let constr = builder.mul_extension(is_beq_mul_one_sub_normalised_diff, next_pc_sub_branched_pc);
yield_constr.constraint(builder, constr);

let is_beq_mul_normalised_diff = builder.mul_extension(ops.beq, lv.normalised_diff);
let constr = builder.mul_extension(is_beq_mul_normalised_diff, next_pc_sub_bumped_pc);
yield_constr.constraint(builder, constr);

let is_bne_mul_normalised_diff = builder.mul_extension(ops.bne, lv.normalised_diff);
let constr = builder.mul_extension(is_bne_mul_normalised_diff, next_pc_sub_branched_pc);
yield_constr.constraint(builder, constr);

let is_bne_mul_one_sub_normalised_diff =
builder.mul_extension(ops.bne, one_sub_normalised_diff);
let constr = builder.mul_extension(is_bne_mul_one_sub_normalised_diff, next_pc_sub_bumped_pc);
yield_constr.constraint(builder, constr);
}

#[cfg(test)]
#[allow(clippy::cast_possible_wrap)]
mod tests {
Expand Down
53 changes: 53 additions & 0 deletions circuits/src/cpu/columns.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;

use crate::bitshift::columns::Bitshift;
use crate::columns_view::{columns_view_impl, make_col_map, NumberOfColumns};
use crate::cpu::stark::add_extension_vec;
use crate::cross_table_lookup::Column;
use crate::program::columns::ProgramRom;
use crate::stark::mozak_stark::{CpuTable, Table};
Expand Down Expand Up @@ -213,6 +218,45 @@ impl<T: PackedField> CpuState<T> {
pub fn signed_diff(&self) -> T { self.op1_full_range() - self.op2_full_range() }
}

pub fn rs2_value_extension_target<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
cpu: &CpuState<ExtensionTarget<D>>,
) -> ExtensionTarget<D> {
let mut rs2_value = builder.zero_extension();
for reg in 0..32 {
let rs2_select = builder.mul_extension(cpu.inst.rs2_select[reg], cpu.regs[reg]);
rs2_value = builder.add_extension(rs2_value, rs2_select);
}
rs2_value
}

pub fn op1_full_range_extension_target<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
cpu: &CpuState<ExtensionTarget<D>>,
) -> ExtensionTarget<D> {
let shifted_32 = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32));
let op1_sign_bit = builder.mul_extension(cpu.op1_sign_bit, shifted_32);
builder.sub_extension(cpu.op1_value, op1_sign_bit)
}

pub fn op2_full_range_extension_target<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
cpu: &CpuState<ExtensionTarget<D>>,
) -> ExtensionTarget<D> {
let shifted_32 = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32));
let op2_sign_bit = builder.mul_extension(cpu.op2_sign_bit, shifted_32);
builder.sub_extension(cpu.op2_value, op2_sign_bit)
}

pub fn signed_diff_extension_target<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
cpu: &CpuState<ExtensionTarget<D>>,
) -> ExtensionTarget<D> {
let op1_full_range = op1_full_range_extension_target(builder, cpu);
let op2_full_range = op2_full_range_extension_target(builder, cpu);
builder.sub_extension(op1_full_range, op2_full_range)
}

/// Expressions we need to range check
///
/// Currently, we only support expressions over the
Expand Down Expand Up @@ -385,6 +429,15 @@ impl<T: core::ops::Add<Output = T>> OpSelectors<T> {
pub fn is_mem_ops(self) -> T { self.sb + self.lb + self.sh + self.lh + self.sw + self.lw }
}

pub fn is_mem_op_extention_target<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
ops: &OpSelectors<ExtensionTarget<D>>,
) -> ExtensionTarget<D> {
add_extension_vec(builder, vec![
ops.sb, ops.lb, ops.sh, ops.lh, ops.sw, ops.lw,
])
}

/// Columns containing the data to be matched against `Bitshift` stark.
/// [`CpuTable`](crate::cross_table_lookup::CpuTable).
#[must_use]
Expand Down
Loading

0 comments on commit d3db446

Please sign in to comment.