Skip to content

Commit

Permalink
Add array constraint generation
Browse files Browse the repository at this point in the history
  • Loading branch information
iangneal committed Oct 10, 2023
1 parent 54f504c commit 06ed0a0
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 22 deletions.
73 changes: 73 additions & 0 deletions circom/tests/subcmps/array_copy_constraints.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck %s

template Sum(n) {
signal input inp[n];
signal output outp;

var acc = 0;
for (var i = 0; i < n; i++) {
acc += inp[i];
}

outp <== acc;
}

template Caller(n) {
signal input inp[n];
signal outp;

component op = Sum(n);
op.inp <== inp;

outp <== op.outp;
}

component main = Caller(5);

//CHECK-LABEL: define void @Caller_{{[0-9]+}}_run
//CHECK-SAME: ([0 x i256]* %0)
//CHECK: store{{[0-9]+}}: ; preds = %create_cmp{{[0-9]+}}
//CHECK: %[[SUBCMP_PTR:[0-9]+]] = getelementptr [1 x { [0 x i256]*, i32 }], [1 x { [0 x i256]*, i32 }]* %subcmps, i32 0, i32 0, i32 0
//CHECK: %[[SUBCMP_INP_ARR:[0-9]+]] = load [0 x i256]*, [0 x i256]** %[[SUBCMP_PTR]]
//CHECK: %[[SUBCMP_INP_PTR:[0-9]+]] = getelementptr [0 x i256], [0 x i256]* %[[SUBCMP_INP_ARR]], i32 0, i32 1
//CHECK: %[[INP_PTR:[0-9]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 0
//CHECK: call void @fr_copy_n(i256* %[[INP_PTR]], i256* %[[SUBCMP_INP_PTR]], i32 5)
//CHECK: %decrement.counter = sub i32 %load.subcmp.counter, 5
//CHECK: call void @Sum_{{[0-9]+}}_run([0 x i256]* %{{[0-9]+}})

//CHECK: %[[INP_PTR_0:[0-9]+]] = getelementptr i256, i256* %[[INP_PTR]], i32 0
//CHECK: %[[INP_0:[0-9]+]] = load i256, i256* %[[INP_PTR_0]]
//CHECK: %[[SUBCMP_INP_PTR_0:[0-9]+]] = getelementptr i256, i256* %[[SUBCMP_INP_PTR]], i32 0
//CHECK: %[[SUBCMP_INP_0:[0-9]+]] = load i256, i256* %[[SUBCMP_INP_PTR_0]]
//CHECK: %constraint_0 = alloca i1
//CHECK: call void @__constraint_values(i256 %[[INP_0]], i256 %[[SUBCMP_INP_0]], i1* %constraint_0)

//CHECK: %[[INP_PTR_1:[0-9]+]] = getelementptr i256, i256* %[[INP_PTR]], i32 1
//CHECK: %[[INP_1:[0-9]+]] = load i256, i256* %[[INP_PTR_1]]
//CHECK: %[[SUBCMP_INP_PTR_1:[0-9]+]] = getelementptr i256, i256* %[[SUBCMP_INP_PTR]], i32 1
//CHECK: %[[SUBCMP_INP_1:[0-9]+]] = load i256, i256* %[[SUBCMP_INP_PTR_1]]
//CHECK: %constraint_1 = alloca i1
//CHECK: call void @__constraint_values(i256 %[[INP_1]], i256 %[[SUBCMP_INP_1]], i1* %constraint_1)

//CHECK: %[[INP_PTR_2:[0-9]+]] = getelementptr i256, i256* %[[INP_PTR]], i32 2
//CHECK: %[[INP_2:[0-9]+]] = load i256, i256* %[[INP_PTR_2]]
//CHECK: %[[SUBCMP_INP_PTR_2:[0-9]+]] = getelementptr i256, i256* %[[SUBCMP_INP_PTR]], i32 2
//CHECK: %[[SUBCMP_INP_2:[0-9]+]] = load i256, i256* %[[SUBCMP_INP_PTR_2]]
//CHECK: %constraint_2 = alloca i1
//CHECK: call void @__constraint_values(i256 %[[INP_2]], i256 %[[SUBCMP_INP_2]], i1* %constraint_2)

//CHECK: %[[INP_PTR_3:[0-9]+]] = getelementptr i256, i256* %[[INP_PTR]], i32 3
//CHECK: %[[INP_3:[0-9]+]] = load i256, i256* %[[INP_PTR_3]]
//CHECK: %[[SUBCMP_INP_PTR_3:[0-9]+]] = getelementptr i256, i256* %[[SUBCMP_INP_PTR]], i32 3
//CHECK: %[[SUBCMP_INP_3:[0-9]+]] = load i256, i256* %[[SUBCMP_INP_PTR_3]]
//CHECK: %constraint_3 = alloca i1
//CHECK: call void @__constraint_values(i256 %[[INP_3]], i256 %[[SUBCMP_INP_3]], i1* %constraint_3)

//CHECK: %[[INP_PTR_4:[0-9]+]] = getelementptr i256, i256* %[[INP_PTR]], i32 4
//CHECK: %[[INP_4:[0-9]+]] = load i256, i256* %[[INP_PTR_4]]
//CHECK: %[[SUBCMP_INP_PTR_4:[0-9]+]] = getelementptr i256, i256* %[[SUBCMP_INP_PTR]], i32 4
//CHECK: %[[SUBCMP_INP_4:[0-9]+]] = load i256, i256* %[[SUBCMP_INP_PTR_4]]
//CHECK: %constraint_4 = alloca i1
//CHECK: call void @__constraint_values(i256 %[[INP_4]], i256 %[[SUBCMP_INP_4]], i1* %constraint_4)
8 changes: 6 additions & 2 deletions code_producers/src/llvm_elements/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ impl<'a> TopLevelLLVMIRProducer<'a> {
pub type LLVMAdapter<'a> = &'a Rc<RefCell<LLVM<'a>>>;
pub type BigIntType<'a> = IntType<'a>; // i256

pub fn new_constraint<'a>(producer: &dyn LLVMIRProducer<'a>) -> AnyValueEnum<'a> {
let alloca = create_alloca(producer, bool_type(producer).into(), "constraint");
pub fn new_constraint_with_name<'a>(producer: &dyn LLVMIRProducer<'a>, name: &str) -> AnyValueEnum<'a> {
let alloca = create_alloca(producer, bool_type(producer).into(), name);
let s = producer.context().metadata_string("constraint");
let kind = producer.context().get_kind_id("constraint");
let node = producer.context().metadata_node(&[s.into()]);
Expand All @@ -169,6 +169,10 @@ pub fn new_constraint<'a>(producer: &dyn LLVMIRProducer<'a>) -> AnyValueEnum<'a>
alloca
}

pub fn new_constraint<'a>(producer: &dyn LLVMIRProducer<'a>) -> AnyValueEnum<'a> {
new_constraint_with_name(producer, "constraint")
}

#[inline]
pub fn any_value_wraps_basic_value(v: AnyValueEnum) -> bool {
match BasicValueEnum::try_from(v) {
Expand Down
87 changes: 70 additions & 17 deletions compiler/src/intermediate_representation/constraint_bucket.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use code_producers::c_elements::CProducer;
use code_producers::llvm_elements::types::bigint_type;
use code_producers::llvm_elements::values::create_literal_u32;
use code_producers::llvm_elements::{
LLVMInstruction, new_constraint, to_basic_metadata_enum, LLVMIRProducer,
LLVMInstruction, new_constraint, to_basic_metadata_enum, LLVMIRProducer, AnyType, new_constraint_with_name,
};
use code_producers::llvm_elements::instructions::{create_call, create_load, get_instruction_arg};
use code_producers::llvm_elements::instructions::{create_call, create_load, get_instruction_arg, create_gep};
use code_producers::llvm_elements::stdlib::{CONSTRAINT_VALUE_FN_NAME, CONSTRAINT_VALUES_FN_NAME};
use code_producers::wasm_elements::WASMProducer;
use crate::intermediate_representation::{Instruction, InstructionPointer, SExp, ToSExp, UpdateId};
Expand Down Expand Up @@ -112,21 +114,72 @@ impl WriteLLVMIR for ConstraintBucket {
const ASSERT_IDX: u32 = 0;

match self {
ConstraintBucket::Substitution(_) => {
let lhs = get_instruction_arg(prev.into_instruction_value(), STORE_DST_IDX);
let rhs_ptr = get_instruction_arg(prev.into_instruction_value(), STORE_SRC_IDX);
let rhs = create_load(producer, rhs_ptr.into_pointer_value());
let constr = new_constraint(producer);
let call = create_call(
producer,
CONSTRAINT_VALUES_FN_NAME,
&[
to_basic_metadata_enum(lhs),
to_basic_metadata_enum(rhs),
to_basic_metadata_enum(constr),
],
);
Some(call)
ConstraintBucket::Substitution(i) => {
let size = match i.as_ref() {
Instruction::Value(_) => todo!(),
Instruction::Load(_) => todo!(),
Instruction::Store(b) => b.context.size,
Instruction::Compute(_) => todo!(),
Instruction::Call(b) => {
for arg_ty in &b.argument_types {
if arg_ty.size > 1 {
todo!("not yet handling call arg array logic");
}
assert_ne!(0, arg_ty.size, "size should be non-zero");
}
1
},
Instruction::Branch(_) => todo!(),
Instruction::Return(_) => todo!(),
Instruction::Assert(_) => todo!(),
Instruction::Log(_) => todo!(),
Instruction::Loop(_) => todo!(),
Instruction::CreateCmp(_) => todo!(),
Instruction::Constraint(_) => todo!(),
Instruction::Block(_) => todo!(),
Instruction::Nop(_) => todo!(),
};
assert_ne!(0, size, "must have non-zero size");
if size == 1 {
let lhs = get_instruction_arg(prev.into_instruction_value(), STORE_DST_IDX);
assert_eq!(bigint_type(producer).as_any_type_enum(), lhs.get_type(), "wrong type");
let rhs_ptr = get_instruction_arg(prev.into_instruction_value(), STORE_SRC_IDX);
let rhs = create_load(producer, rhs_ptr.into_pointer_value());
let constr = new_constraint(producer);
let call = create_call(
producer,
CONSTRAINT_VALUES_FN_NAME,
&[
to_basic_metadata_enum(lhs),
to_basic_metadata_enum(rhs),
to_basic_metadata_enum(constr),
],
);
Some(call)
} else {
let lhs_ptr = get_instruction_arg(prev.into_instruction_value(), STORE_DST_IDX).into_pointer_value();
assert_eq!(bigint_type(producer).ptr_type(Default::default()), lhs_ptr.get_type(), "wrong type");
let rhs_ptr = get_instruction_arg(prev.into_instruction_value(), STORE_SRC_IDX).into_pointer_value();

let constraint_calls: Vec<_> = (0..size).map(|i| {
let idx = create_literal_u32(producer, i as u64);
let lhs = create_load(producer, create_gep(producer, lhs_ptr, &[idx]).into_pointer_value());
let rhs = create_load(producer, create_gep(producer, rhs_ptr, &[idx]).into_pointer_value());
let constr = new_constraint_with_name(producer, format!("constraint_{}", i).as_str());
create_call(
producer,
CONSTRAINT_VALUES_FN_NAME,
&[
to_basic_metadata_enum(lhs),
to_basic_metadata_enum(rhs),
to_basic_metadata_enum(constr),
],
)
}).collect();

Some(constraint_calls.last().expect("must have >1 value!").clone())
}

}
ConstraintBucket::Equality(_) => {
let bool = get_instruction_arg(prev.into_instruction_value(), ASSERT_IDX);
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/intermediate_representation/store_bucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ impl StoreBucket{
// If we have bounds for an unknown index, we will get the base address and let the function check the bounds
let store = match &bounded_fn {
Some(name) => {
assert_eq!(1, context.size, "unhandled array store");
let arr_ptr = match &dest_address_type {
AddressType::Variable => producer.body_ctx().get_variable_array(producer),
AddressType::Signal => producer.template_ctx().get_signal_array(producer),
Expand Down Expand Up @@ -168,13 +169,12 @@ impl StoreBucket{
}
};

// If we have a subcomponent storage decrement the counter
// If we have a subcomponent storage decrement the counter by the size of the store (i.e., context.size)
if let AddressType::SubcmpSignal { cmp_address, .. } = &dest_address_type {
let addr = cmp_address.produce_llvm_ir(producer).expect("The address of a subcomponent must yield a value!");
let counter = producer.template_ctx().load_subcmp_counter(producer, addr);
let value = create_load_with_name(producer, counter, "load.subcmp.counter");
let new_value = create_sub_with_name(producer, value.into_int_value(), create_literal_u32(producer, 1), "decrement.counter");
assert_eq!(1, context.size, "unhandled array store");
let new_value = create_sub_with_name(producer, value.into_int_value(), create_literal_u32(producer, context.size as u64), "decrement.counter");
create_store(producer, counter, new_value);
}

Expand Down

0 comments on commit 06ed0a0

Please sign in to comment.