Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VAN-503] Convert to/from i1 as needed #47

Merged
merged 5 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions circom/tests/type_conversions/bool_1.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

function binop_comp(a, b) {
return a > b;
}

template A(x) {
signal input in;
signal output out;

out <-- binop_comp(in, x);
}

component main = A(5);

//CHECK-LABEL: define i256 @binop_comp_{{[0-9]+}}
//CHECK-SAME: (i256* %0)
//CHECK: %call.fr_gt = call i1 @fr_gt(i256 %{{[0-9]+}}, i256 %{{[0-9]+}})
//CHECK: %[[RET:[0-9]+]] = zext i1 %call.fr_gt to i256
//CHECK: ret i256 %[[RET]]
31 changes: 31 additions & 0 deletions circom/tests/type_conversions/bool_2.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

function binop_bool(a, b) {
return a || b;
}

template A(x) {
signal input in;
signal output out;

var temp;
if (binop_bool(in, x)) {
temp = 1;
} else {
temp = 0;
}
out <-- temp;

//Essentially equivalent code:
// out <-- binop_bool(in, x);
}

component main = A(555);

//CHECK-LABEL: define i256 @binop_bool_{{[0-9]+}}
//CHECK-SAME: (i256* %0)
//CHECK: %call.fr_logic_or = call i1 @fr_logic_or(i1 %{{[0-9]+}}, i1 %{{[0-9]+}})
//CHECK: %[[RET:[0-9]+]] = zext i1 %call.fr_logic_or to i256
//CHECK: ret i256 %[[RET]]
25 changes: 25 additions & 0 deletions circom/tests/type_conversions/bool_3.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

template A(x) {
signal input in;
signal output out;

var z = 0;
if (in || x) {
z = 1;
}
out <-- z;
}

component main = A(99);

//CHECK-LABEL: define void @A_{{[0-9]+}}_run
//CHECK-SAME: ([0 x i256]* %0)
//CHECK: branch{{[0-9]+}}:
//CHECK: %[[VAL_PTR:[0-9]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 1
//CHECK: %[[VAL:[0-9]+]] = load i256, i256* %[[VAL_PTR]]
//CHECK: %[[BOOL:[0-9]+]] = icmp ne i256 %[[VAL]], 0
//CHECK: %call.fr_logic_or = call i1 @fr_logic_or(i1 %[[BOOL]], i1 true)
iangneal marked this conversation as resolved.
Show resolved Hide resolved
//CHECK: br i1 %call.fr_logic_or, label %if.then, label %if.else
27 changes: 27 additions & 0 deletions circom/tests/type_conversions/bool_4.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

function binop_bool_array(a, b) {
var arr[10];
for (var i = 0; i < 10; i++) {
arr[i] = a[i] || b[i];
}
return arr;
}

template A() {
signal input in1[10];
signal input in2[10];
signal output out[10];

out <-- binop_bool_array(in1, in2);
}

component main = A();

//CHECK-LABEL: define void @binop_bool_array_{{[0-9]+}}
//CHECK-SAME: (i256* %0)
//CHECK: %call.fr_logic_or = call i1 @fr_logic_or(i1 %{{[0-9]+}}, i1 %{{[0-9]+}})
//CHECK: %[[VAL:[0-9]+]] = zext i1 %call.fr_logic_or to i256
//CHECK: store i256 %[[VAL]], i256* %{{[0-9]+}}
106 changes: 94 additions & 12 deletions code_producers/src/llvm_elements/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use inkwell::basic_block::BasicBlock;
use inkwell::IntPredicate::{EQ, NE, SLT, SGT, SLE, SGE};
use inkwell::types::{AnyTypeEnum, PointerType};
use inkwell::values::{AnyValue, AnyValueEnum, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FunctionValue, InstructionOpcode, InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue};
use inkwell::types::{AnyTypeEnum, PointerType, IntType};
use inkwell::values::{
AnyValue, AnyValueEnum, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FunctionValue,
InstructionOpcode, InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue,
};
use crate::llvm_elements::{LLVMIRProducer};
use crate::llvm_elements::fr::{FR_MUL_FN_NAME, FR_LT_FN_NAME};
use crate::llvm_elements::functions::create_bb;
use crate::llvm_elements::types::{bigint_type, i32_type};

use super::types::bool_type;

// bigint abv;
// if (rhs < 0)
// abv = -rhs;
Expand Down Expand Up @@ -437,6 +442,50 @@ pub fn create_bit_xor<'a, T: IntMathValue<'a>>(
create_bit_xor_with_name(producer, lhs, rhs, "")
}

pub fn ensure_bool_with_name<'a, T: IntMathValue<'a>>(
producer: &dyn LLVMIRProducer<'a>,
val: T,
name: &str,
) -> AnyValueEnum<'a> {
let int_val = val.as_basic_value_enum().into_int_value();
if int_val.get_type() != bool_type(producer) {
create_neq_with_name(producer, int_val, bigint_type(producer).const_zero(), name)
} else {
val.as_any_value_enum()
}
}

pub fn ensure_bool<'a, T: IntMathValue<'a>>(
producer: &dyn LLVMIRProducer<'a>,
val: T,
) -> AnyValueEnum<'a> {
ensure_bool_with_name(producer, val, "")
}

pub fn ensure_int_type_match<'a>(
producer: &dyn LLVMIRProducer<'a>,
val: IntValue<'a>,
ty: IntType<'a>,
) -> IntValue<'a> {
if val.get_type() == ty {
// No conversion needed
val
} else if val.get_type() == bool_type(producer) {
// Zero extend
producer.llvm().builder.build_int_z_extend(val, ty, "")
} else if ty == bool_type(producer) {
// Convert to bool
ensure_bool(producer, val).into_int_value()
} else {
panic!(
"Unhandled int conversion of value '{:?}': {:?} to {:?} not supported!",
val,
val.get_type(),
ty
)
}
}

pub fn create_logic_and_with_name<'a, T: IntMathValue<'a>>(
producer: &dyn LLVMIRProducer<'a>,
lhs: T,
Expand Down Expand Up @@ -493,7 +542,10 @@ pub fn create_store<'a>(
) -> AnyValueEnum<'a> {
match value {
AnyValueEnum::ArrayValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::IntValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::IntValue(v) => {
let store_ty = ptr.get_type().get_element_type().into_int_type();
producer.llvm().builder.build_store(ptr, ensure_int_type_match(producer, v, store_ty))
}
AnyValueEnum::FloatValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::PointerValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::StructValue(v) => producer.llvm().builder.build_store(ptr, v),
Expand All @@ -511,7 +563,26 @@ pub fn create_return<'a, V: BasicValue<'a>>(
producer: &dyn LLVMIRProducer<'a>,
val: V,
) -> AnyValueEnum<'a> {
producer.llvm().builder.build_return(Some(&val)).as_any_value_enum()
let f = producer
.llvm()
.builder
.get_insert_block()
.expect("no current block!")
.get_parent()
.expect("no current function!");
let ret_ty =
f.get_type().get_return_type().expect("non-void function should have a return type!");
let ret_val = if ret_ty.is_int_type() {
ensure_int_type_match(
producer,
val.as_basic_value_enum().into_int_value(),
ret_ty.into_int_type(),
)
.as_basic_value_enum()
} else {
val.as_basic_value_enum()
};
producer.llvm().builder.build_return(Some(&ret_val)).as_any_value_enum()
}

pub fn create_br<'a>(producer: &dyn LLVMIRProducer<'a>, bb: BasicBlock<'a>) -> AnyValueEnum<'a> {
Expand All @@ -535,10 +606,27 @@ pub fn create_call<'a>(
arguments: &[BasicMetadataValueEnum<'a>],
) -> AnyValueEnum<'a> {
let f = find_function(producer, name);
let params = f.get_params();
let checked_arguments: Vec<BasicMetadataValueEnum<'a>> = arguments
.into_iter()
.zip(params.into_iter())
.map(|(arg, param)| {
if arg.is_int_value() && param.is_int_value() {
ensure_int_type_match(
producer,
arg.into_int_value(),
param.get_type().into_int_type(),
)
.into()
} else {
arg.clone()
}
})
.collect();
producer
.llvm()
.builder
.build_call(f, arguments, format!("call.{}", name).as_str())
.build_call(f, &checked_arguments, format!("call.{}", name).as_str())
.as_any_value_enum()
}

Expand All @@ -548,13 +636,7 @@ pub fn create_conditional_branch<'a>(
then_block: BasicBlock<'a>,
else_block: BasicBlock<'a>,
) -> AnyValueEnum<'a> {
let comparison_type = comparison.get_type();
let bool_ty = producer.llvm().module.get_context().bool_type();
let bool_comparison = if comparison_type != bool_ty {
create_neq(producer, comparison, comparison_type.const_zero()).into_int_value()
} else {
comparison
};
let bool_comparison = ensure_bool(producer, comparison).into_int_value();
producer
.llvm()
.builder
Expand Down