diff --git a/circom/tests/subcmps/subcmps2.circom b/circom/tests/subcmps/subcmps2.circom new file mode 100644 index 000000000..6f1668e7e --- /dev/null +++ b/circom/tests/subcmps/subcmps2.circom @@ -0,0 +1,43 @@ +pragma circom 2.0.6; +// REQUIRES: circom +// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck %s + +template Sum(n) { + signal input inp[n]; + signal output outp; + + var s = 0; + + for (var i = 0; i < n; i++) { + s += inp[i]; + } + + outp <== s; +} + +function nop(i) { + return i; +} + +template Caller() { + signal input inp[4]; + signal output outp; + + component s = Sum(4); + + for (var i = 0; i < 4; i++) { + s.inp[i] <== nop(inp[i]); + } + + outp <== s.outp; +} + +component main = Caller(); + +//CHECK-LABEL: define void @Caller_{{[0-9]+}}_run +//CHECK-SAME: ([0 x i256]* %0) +//CHECK: %[[CALL_VAL:call\.nop_[0-3]]] = call i256 @nop_{{[0-3]}}(i256* %6) +//CHECK: %[[SUBCMP_PTR:.*]] = getelementptr [1 x { [0 x i256]*, i32 }], [1 x { [0 x i256]*, i32 }]* %subcmps, i32 0, i32 0, i32 {{[0-3]}} +//CHECK: %[[SUBCMP:.*]] = load [0 x i256]*, [0 x i256]** %[[SUBCMP_PTR]] +//CHECK: %[[SUBCMP_INP:.*]] = getelementptr [0 x i256], [0 x i256]* %[[SUBCMP]], i32 0, i32 {{[1-4]}} +//CHECK: store i256 %[[CALL_VAL]], i256* %[[SUBCMP_INP]] diff --git a/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs b/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs index aa7fa6d30..949bbda55 100644 --- a/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs +++ b/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs @@ -22,38 +22,42 @@ impl DeterministicSubCmpInvokePass { replacements: Default::default(), } } -} - -impl InterpreterObserver for DeterministicSubCmpInvokePass { - fn on_value_bucket(&self, _bucket: &ValueBucket, _env: &Env) -> bool { - true - } - - fn on_load_bucket(&self, _bucket: &LoadBucket, _env: &Env) -> bool { - true - } - fn on_store_bucket(&self, bucket: &StoreBucket, env: &Env) -> bool { - let env = env.clone(); - let mem = self.memory.borrow(); - let interpreter = mem.build_interpreter(self); - // If the address of the subcomponent input information is unk, then - // If executing this store bucket would result in calling the subcomponent we replace it with Last + pub fn try_resolve_input_status(&self, address_type: &AddressType, env: &Env) { + // If the address of the subcomponent input information is unknown, then + // If executing this instruction would result in calling the subcomponent we replace it with Last // Will result in calling if the counter is at one because after the execution it will be 0 // If not replace with NoLast if let AddressType::SubcmpSignal { input_information: InputInformation::Input { status: StatusInput::Unknown }, cmp_address, .. - } = &bucket.dest_address_type + } = address_type { + let env = env.clone(); + let mem = self.memory.borrow(); + let interpreter = mem.build_interpreter(self); let (addr, env) = interpreter.execute_instruction(cmp_address, env, false); let addr = addr - .expect("cmp_address instruction in StoreBucket SubcmpSignal must produce a value!") + .expect("cmp_address instruction in SubcmpSignal must produce a value!") .get_u32(); let new_status = if env.subcmp_counter_equal_to(addr, 1) { Last } else { NoLast }; - self.replacements.borrow_mut().insert(bucket.dest_address_type.clone(), new_status); + self.replacements.borrow_mut().insert(address_type.clone(), new_status); } + } +} + +impl InterpreterObserver for DeterministicSubCmpInvokePass { + fn on_value_bucket(&self, _bucket: &ValueBucket, _env: &Env) -> bool { + true + } + + fn on_load_bucket(&self, _bucket: &LoadBucket, _env: &Env) -> bool { + true + } + + fn on_store_bucket(&self, bucket: &StoreBucket, env: &Env) -> bool { + self.try_resolve_input_status(&bucket.dest_address_type, env); true } @@ -89,8 +93,14 @@ impl InterpreterObserver for DeterministicSubCmpInvokePass { true } - fn on_call_bucket(&self, _bucket: &CallBucket, _env: &Env) -> bool { - true + fn on_call_bucket(&self, bucket: &CallBucket, env: &Env) -> bool { + match &bucket.return_info { + ReturnType::Intermediate {..} => true, + ReturnType::Final(data) => { + self.try_resolve_input_status(&data.dest_address_type, env); + true + }, + } } fn on_branch_bucket(&self, _bucket: &BranchBucket, _env: &Env) -> bool { diff --git a/compiler/src/intermediate_representation/store_bucket.rs b/compiler/src/intermediate_representation/store_bucket.rs index 67ef371d1..73982ebe4 100644 --- a/compiler/src/intermediate_representation/store_bucket.rs +++ b/compiler/src/intermediate_representation/store_bucket.rs @@ -189,7 +189,7 @@ impl StoreBucket{ match status { StatusInput::Last => { let run_fn = run_fn_name(sub_cmp_name.expect("Could not get the name of the subcomponent")); - // If we reach this point gep is the address of the subcomponent so we ca just reuse it + // If we reach this point gep is the address of the subcomponent so we can just reuse it let addr = cmp_address.produce_llvm_ir(producer).expect("The address of a subcomponent must yield a value!"); let subcmp = producer.template_ctx().load_subcmp_addr(producer, addr); create_call(producer, run_fn.as_str(), &[subcmp.into()]);