Skip to content

Commit

Permalink
Add conversions for boolean <=> bigint type as needed
Browse files Browse the repository at this point in the history
  • Loading branch information
iangneal committed Aug 18, 2023
1 parent 0e67892 commit b747f33
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 14 deletions.
22 changes: 22 additions & 0 deletions circom/tests/type_conversions/bool_1.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

function binop_comp(a, b) {
return a > b;
}

template A(x) {
signal input in;
signal output out;

out <-- binop_comp(in, x);
}

component main = A(5);

//CHECK-LABEL: define i256 @binop_comp_{{[0-9]+}}
//CHECK-SAME: (i256* %0)
//CHECK: %call.fr_gt = call i1 @fr_gt(i256 %{{[0-9]+}}, i256 %{{[0-9]+}})
//CHECK: %[[RET:[0-9]+]] = zext i1 %call.fr_gt to i256
//CHECK: ret i256 %[[RET]]
31 changes: 31 additions & 0 deletions circom/tests/type_conversions/bool_2.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

function binop_bool(a, b) {
return a || b;
}

template A(x) {
signal input in;
signal output out;

var temp;
if (binop_bool(in, x)) {
temp = 1;
} else {
temp = 0;
}
out <-- temp;

//Essentially equivalent code:
// out <-- binop_bool(in, x);
}

component main = A(555);

//CHECK-LABEL: define i256 @binop_bool_{{[0-9]+}}
//CHECK-SAME: (i256* %0)
//CHECK: %call.fr_logic_or = call i1 @fr_logic_or(i1 %{{[0-9]+}}, i1 %{{[0-9]+}})
//CHECK: %[[RET:[0-9]+]] = zext i1 %call.fr_logic_or to i256
//CHECK: ret i256 %[[RET]]
24 changes: 24 additions & 0 deletions circom/tests/type_conversions/bool_3.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

template A(x) {
signal input in;
signal output out;

var z = 0;
if (in || x) {
z = 1;
}
out <-- z;
}

component main = A(99);

//CHECK-LABEL: define void @A_{{[0-9]+}}_run
//CHECK-SAME: ([0 x i256]* %0)
//CHECK: branch{{[0-9]+}}:
//CHECK: %[[VAL_PTR:[0-9]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 1
//CHECK: %[[VAL:[0-9]+]] = load i256, i256* %[[VAL_PTR]]
//CHECK: %[[BOOL:[0-9]+]] = icmp ne i256 %[[VAL]], 0
//CHECK: %call.fr_logic_or = call i1 @fr_logic_or(i1 %[[BOOL]], i1 true)
27 changes: 27 additions & 0 deletions circom/tests/type_conversions/bool_4.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
pragma circom 2.0.0;
// REQUIRES: circom
// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck --enable-var-scope %s

function binop_bool_array(a, b) {
var arr[10];
for (var i = 0; i < 10; i++) {
arr[i] = a[i] || b[i];
}
return arr;
}

template A() {
signal input in1[10];
signal input in2[10];
signal output out[10];

out <-- binop_bool_array(in1, in2);
}

component main = A();

//CHECK-LABEL: define void @binop_bool_array_{{[0-9]+}}
//CHECK-SAME: (i256* %0)
//CHECK: %call.fr_logic_or = call i1 @fr_logic_or(i1 %{{[0-9]+}}, i1 %{{[0-9]+}})
//CHECK: %[[VAL:[0-9]+]] = zext i1 %call.fr_logic_or to i256
//CHECK: store i256 %[[VAL]], i256* %{{[0-9]+}}
115 changes: 101 additions & 14 deletions code_producers/src/llvm_elements/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use inkwell::basic_block::BasicBlock;
use inkwell::IntPredicate::{EQ, NE, SLT, SGT, SLE, SGE};
use inkwell::types::{AnyTypeEnum, PointerType};
use inkwell::values::{AnyValue, AnyValueEnum, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FunctionValue, InstructionOpcode, InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue};
use inkwell::types::{AnyTypeEnum, PointerType, IntType};
use inkwell::values::{
AnyValue, AnyValueEnum, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FunctionValue,
InstructionOpcode, InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue,
};
use crate::llvm_elements::{LLVMIRProducer};
use crate::llvm_elements::fr::{FR_MUL_FN_NAME, FR_LT_FN_NAME};
use crate::llvm_elements::functions::create_bb;
use crate::llvm_elements::types::{bigint_type, i32_type};

use super::types::bool_type;

// bigint abv;
// if (rhs < 0)
// abv = -rhs;
Expand Down Expand Up @@ -437,13 +442,73 @@ pub fn create_bit_xor<'a, T: IntMathValue<'a>>(
create_bit_xor_with_name(producer, lhs, rhs, "")
}

pub fn ensure_bool_with_name<'a, T: IntMathValue<'a>>(
producer: &dyn LLVMIRProducer<'a>,
val: T,
name: &str,
) -> AnyValueEnum<'a> {
if val.as_basic_value_enum().into_int_value().get_type() != bool_type(producer) {
create_neq_with_name(
producer,
val.as_basic_value_enum().into_int_value(),
bigint_type(producer).const_zero(),
name,
)
} else {
val.as_any_value_enum()
}
}

pub fn ensure_bool<'a, T: IntMathValue<'a>>(
producer: &dyn LLVMIRProducer<'a>,
val: T,
) -> AnyValueEnum<'a> {
ensure_bool_with_name(producer, val, "")
}

pub fn ensure_int_type_match<'a>(
producer: &dyn LLVMIRProducer<'a>,
val: IntValue<'a>,
ty: IntType<'a>,
) -> IntValue<'a> {
if val.get_type() == ty {
// No conversion needed
val
} else if val.get_type() == bool_type(producer) {
// Zero extend
producer.llvm().builder.build_int_z_extend(val, ty, "")
} else if ty == bool_type(producer) {
// Convert to bool
ensure_bool(producer, val).into_int_value()
} else {
panic!(
"Unhandled int conversion of value '{:?}': {:?} to {:?} not supported!",
val,
val.get_type(),
ty
)
}
}

macro_rules! conditional_name {
($fmt: expr, $name: expr) => {{
if $name.is_empty() { $name.to_string() } else { format!($fmt, $name) }.as_ref()
}};
}

pub fn create_logic_and_with_name<'a, T: IntMathValue<'a>>(
producer: &dyn LLVMIRProducer<'a>,
lhs: T,
rhs: T,
name: &str,
) -> AnyValueEnum<'a> {
producer.llvm().builder.build_and(lhs, rhs, name).as_any_value_enum()
let bool_lhs =
ensure_bool_with_name(producer, lhs, conditional_name!("bool_cast_lhs.{}", name))
.into_int_value();
let bool_rhs =
ensure_bool_with_name(producer, rhs, conditional_name!("bool_cast_rhs.{}", name))
.into_int_value();
producer.llvm().builder.build_and(bool_lhs, bool_rhs, name).as_any_value_enum()
}

pub fn create_logic_and<'a, T: IntMathValue<'a>>(
Expand All @@ -460,7 +525,13 @@ pub fn create_logic_or_with_name<'a, T: IntMathValue<'a>>(
rhs: T,
name: &str,
) -> AnyValueEnum<'a> {
producer.llvm().builder.build_or(lhs, rhs, name).as_any_value_enum()
let bool_lhs =
ensure_bool_with_name(producer, lhs, conditional_name!("bool_cast_lhs.{}", name))
.into_int_value();
let bool_rhs =
ensure_bool_with_name(producer, rhs, conditional_name!("bool_cast_rhs.{}", name))
.into_int_value();
producer.llvm().builder.build_or(bool_lhs, bool_rhs, name).as_any_value_enum()
}

pub fn create_logic_or<'a, T: IntMathValue<'a>>(
Expand All @@ -476,7 +547,9 @@ pub fn create_logic_not_with_name<'a, T: IntMathValue<'a>>(
val: T,
name: &str,
) -> AnyValueEnum<'a> {
producer.llvm().builder.build_not(val, name).as_any_value_enum()
let bool_val = ensure_bool_with_name(producer, val, conditional_name!("bool_cast.{}", name))
.into_int_value();
producer.llvm().builder.build_not(bool_val, name).as_any_value_enum()
}

pub fn create_logic_not<'a, T: IntMathValue<'a>>(
Expand All @@ -493,7 +566,10 @@ pub fn create_store<'a>(
) -> AnyValueEnum<'a> {
match value {
AnyValueEnum::ArrayValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::IntValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::IntValue(v) => {
let store_ty = ptr.get_type().get_element_type().into_int_type();
producer.llvm().builder.build_store(ptr, ensure_int_type_match(producer, v, store_ty))
}
AnyValueEnum::FloatValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::PointerValue(v) => producer.llvm().builder.build_store(ptr, v),
AnyValueEnum::StructValue(v) => producer.llvm().builder.build_store(ptr, v),
Expand Down Expand Up @@ -535,10 +611,27 @@ pub fn create_call<'a>(
arguments: &[BasicMetadataValueEnum<'a>],
) -> AnyValueEnum<'a> {
let f = find_function(producer, name);
let params = f.get_params();
let checked_arguments: Vec<BasicMetadataValueEnum<'a>> = arguments
.into_iter()
.zip(params.into_iter())
.map(|(arg, param)| {
if arg.is_int_value() && param.is_int_value() {
ensure_int_type_match(
producer,
arg.into_int_value(),
param.get_type().into_int_type(),
)
.into()
} else {
arg.clone()
}
})
.collect();
producer
.llvm()
.builder
.build_call(f, arguments, format!("call.{}", name).as_str())
.build_call(f, &checked_arguments, format!("call.{}", name).as_str())
.as_any_value_enum()
}

Expand All @@ -548,13 +641,7 @@ pub fn create_conditional_branch<'a>(
then_block: BasicBlock<'a>,
else_block: BasicBlock<'a>,
) -> AnyValueEnum<'a> {
let comparison_type = comparison.get_type();
let bool_ty = producer.llvm().module.get_context().bool_type();
let bool_comparison = if comparison_type != bool_ty {
create_neq(producer, comparison, comparison_type.const_zero()).into_int_value()
} else {
comparison
};
let bool_comparison = ensure_bool(producer, comparison).into_int_value();
producer
.llvm()
.builder
Expand Down

0 comments on commit b747f33

Please sign in to comment.