From 5cb4ffd5f125971ce7055c99714294152d8bf647 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 9 Oct 2023 09:18:52 +0000 Subject: [PATCH 1/6] feat: add support for functions This commit extends the syntax and semantics for functions in AirScript: * You can now define functions that take and produce values, using the `fn` keyword, see below for an example. * You can now call value-producing functions in expression position, i.e. as operands to a binary operator, as an iterable for use by comprehensions, etc. * Functions are strongly typed, and are type checked to ensure that both the function itself, and all call sites, are correctly typed. Example: ``` fn madd3(a: felt[3], b: felt) -> felt { let d = [c * b for c in a] return sum(d) } ``` In this example: * Two parameters are bound, `a` and `b`, a vector of 3 felts, and a felt, respectively * The return type is declared to be felt * The body of the function can be arbitrarily complex, i.e. you can define variables, comprehensions, etc. * Not illustrated here, but all of the usual global bindings (i.e. trace columns, public inputs, random values, etc.) are in scope and can be referenced. Things you cannot do with functions: * A function _must_ return a value, i.e. it cannot have an empty body * A function _may not_ contain constraints * You may call functions within a function, but recursion is not supported, i.e. 
an error will be raised if a call is made to a function which is already on the call stack --- air-script/tests/codegen/masm.rs | 25 + air-script/tests/codegen/winterfell.rs | 26 + .../tests/functions/functions_complex.air | 45 ++ .../tests/functions/functions_complex.masm | 166 ++++ .../tests/functions/functions_complex.rs | 91 +++ .../tests/functions/functions_simple.air | 90 +++ .../tests/functions/functions_simple.masm | 166 ++++ .../tests/functions/functions_simple.rs | 98 +++ .../functions/inlined_functions_simple.air | 52 ++ codegen/masm/tests/test_aux.rs | 4 +- codegen/masm/tests/test_boundary.rs | 8 +- codegen/masm/tests/test_constants.rs | 2 +- codegen/masm/tests/test_divisor.rs | 6 +- codegen/masm/tests/test_periodic.rs | 4 +- ir/src/passes/translate.rs | 10 +- parser/src/ast/declarations.rs | 51 ++ parser/src/ast/display.rs | 19 +- parser/src/ast/errors.rs | 26 + parser/src/ast/expression.rs | 15 + parser/src/ast/mod.rs | 34 +- parser/src/ast/module.rs | 22 + parser/src/ast/statement.rs | 27 +- parser/src/ast/types.rs | 7 +- parser/src/ast/visit.rs | 42 + parser/src/lexer/mod.rs | 17 +- parser/src/lexer/tests/functions.rs | 87 ++ parser/src/lexer/tests/mod.rs | 1 + parser/src/parser/grammar.lalrpop | 60 ++ parser/src/parser/tests/functions.rs | 539 +++++++++++++ parser/src/parser/tests/inlining.rs | 193 ++++- parser/src/parser/tests/mod.rs | 7 + parser/src/sema/errors.rs | 5 +- parser/src/sema/semantic_analysis.rs | 64 ++ parser/src/transforms/constant_propagation.rs | 5 + parser/src/transforms/inlining.rs | 754 +++++++++++++++--- 35 files changed, 2604 insertions(+), 164 deletions(-) create mode 100644 air-script/tests/functions/functions_complex.air create mode 100644 air-script/tests/functions/functions_complex.masm create mode 100644 air-script/tests/functions/functions_complex.rs create mode 100644 air-script/tests/functions/functions_simple.air create mode 100644 air-script/tests/functions/functions_simple.masm create mode 100644 
air-script/tests/functions/functions_simple.rs create mode 100644 air-script/tests/functions/inlined_functions_simple.air create mode 100644 parser/src/lexer/tests/functions.rs create mode 100644 parser/src/parser/tests/functions.rs diff --git a/air-script/tests/codegen/masm.rs b/air-script/tests/codegen/masm.rs index 983c2b60..b82414a0 100644 --- a/air-script/tests/codegen/masm.rs +++ b/air-script/tests/codegen/masm.rs @@ -84,6 +84,31 @@ fn evaluators() { expected.assert_eq(&generated_masm); } +#[test] +fn functions() { + let generated_masm = Test::new("tests/functions/functions_simple.air".to_string()) + .transpile(Target::Masm) + .unwrap(); + + let expected = expect_file!["../functions/functions_simple.masm"]; + expected.assert_eq(&generated_masm); + + // make sure that the constraints generated using inlined functions are the same as the ones + // generated using regular functions + let generated_masm = Test::new("tests/functions/inlined_functions_simple.air".to_string()) + .transpile(Target::Masm) + .unwrap(); + let expected = expect_file!["../functions/functions_simple.masm"]; + expected.assert_eq(&generated_masm); + + // let generated_masm = Test::new("tests/functions/functions_complex.air".to_string()) + // .transpile(Target::Masm) + // .unwrap(); + + // let expected = expect_file!["../functions/functions_complex.masm"]; + // expected.assert_eq(&generated_masm); +} + #[test] fn variables() { let generated_masm = Test::new("tests/variables/variables.air".to_string()) diff --git a/air-script/tests/codegen/winterfell.rs b/air-script/tests/codegen/winterfell.rs index caad525f..99117a58 100644 --- a/air-script/tests/codegen/winterfell.rs +++ b/air-script/tests/codegen/winterfell.rs @@ -84,6 +84,32 @@ fn evaluators() { expected.assert_eq(&generated_air); } +#[test] +fn functions() { + let generated_air = Test::new("tests/functions/functions_simple.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = 
expect_file!["../functions/functions_simple.rs"]; + expected.assert_eq(&generated_air); + + // make sure that the constraints generated using inlined functions are the same as the ones + // generated using regular functions + let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../functions/functions_simple.rs"]; + expected.assert_eq(&generated_air); + + // let generated_air = Test::new("tests/functions/functions_complex.air".to_string()) + // .transpile(Target::Winterfell) + // .unwrap(); + + // let expected = expect_file!["../functions/functions_complex.rs"]; + // expected.assert_eq(&generated_air); +} + #[test] fn variables() { let generated_air = Test::new("tests/variables/variables.air".to_string()) diff --git a/air-script/tests/functions/functions_complex.air b/air-script/tests/functions/functions_complex.air new file mode 100644 index 00000000..b72c0437 --- /dev/null +++ b/air-script/tests/functions/functions_complex.air @@ -0,0 +1,45 @@ +def FunctionsAir + +fn get_multiplicity_flags(s0: felt, s1: felt) -> felt[4] { + return [!s0 & !s1, s0 & !s1, !s0 & s1, s0 & s1] +} + +fn fold_vec(a: felt[12]) -> felt { + return sum([x for x in a]) +} + +fn fold_scalar_and_vec(a: felt, b: felt[12]) -> felt { + let m = fold_vec(b) + let n = m + 1 + let o = n * 2 + return o +} + +trace_columns { + main: [t, s0, s1, v, b[12]] + aux: [b_range] +} + +public_inputs { + stack_inputs: [16] +} + +random_values { + alpha: [16] +} + +boundary_constraints { + enf v.first = 0 +} + +integrity_constraints { + # let val = $alpha[0] + v + let f = get_multiplicity_flags(s0, s1) + let z = v^4 * f[3] + v^2 * f[2] + v * f[1] + f[0] + # let folded_value = fold_scalar_and_vec(v, b) + # enf b_range' = b_range * (z * t - t + 1) + enf b_range' = b_range * 2 + # let y = fold_scalar_and_vec(v, b) + # let c = fold_scalar_and_vec(t, b) + # enf v' = y +} diff --git 
a/air-script/tests/functions/functions_complex.masm b/air-script/tests/functions/functions_complex.masm new file mode 100644 index 00000000..a145bd1b --- /dev/null +++ b/air-script/tests/functions/functions_complex.masm @@ -0,0 +1,166 @@ +# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use. +# +# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors +# +# Input: [...] +# Output: [...] +proc.cache_z_exp + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, ...] + # Exponentiate z trace_len times + mem_load.4294903307 neg + # => [count, z_1, z_0, ...] where count = -log2(trace_len) + dup.0 neq.0 + while.true + movdn.2 dup.1 dup.1 ext2mul + # => [(e_1, e_0)^n, i, ...] + movup.2 add.1 dup.0 neq.0 + # => [b, i+1, (e_1, e_0)^n, ...] + end # END while + push.0 mem_storew.500000100 # z^trace_len + # => [0, 0, (z_1, z_0)^trace_len, ...] + dropw # Clean stack +end # END PROC cache_z_exp + +# Procedure to compute the exemption points. +# +# Input: [...] +# Output: [g^{-2}, g^{-1}, ...] +proc.get_exemptions_points + mem_load.4294799999 + # => [g, ...] + push.1 swap div + # => [g^{-1}, ...] + dup.0 dup.0 mul + # => [g^{-2}, g^{-1}, ...] +end # END PROC get_exemptions_points + +# Procedure to compute the integrity constraint divisor. +# +# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))` +# Procedure `cache_z_exp` must have been called prior to this. +# +# Input: [...] +# Output: [divisor_1, divisor_0, ...] +proc.compute_integrity_constraint_divisor + padw mem_loadw.500000100 drop drop # load z^trace_len + # Comments below use zt = `z^trace_len` + # => [zt_1, zt_0, ...] + push.1 push.0 ext2sub + # => [zt_1-1, zt_0-1, ...] + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, zt_1-1, zt_0-1, ...] 
+ exec.get_exemptions_points + # => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor + dup.3 dup.3 movup.3 push.0 ext2sub + # => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + movup.4 movup.4 movup.4 push.0 ext2sub + # => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...] + ext2mul + # => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...] + ext2div + # => [divisor_1, divisor_0, ...] +end # END PROC compute_integrity_constraint_divisor + +# Procedure to evaluate numerators of all integrity constraints. +# +# All the 1 main and 1 auxiliary constraints are evaluated. +# The result of each evaluation is kept on the stack, with the top of the stack +# containing the evaluations for the auxiliary trace (if any) followed by the main trace. +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation. 
+# This procedure pushes 2 quadratic extension field elements to the stack +proc.compute_integrity_constraints + # integrity constraint 0 for main + padw mem_loadw.4294900003 drop drop padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900008 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900009 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900010 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900011 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900012 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900013 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900014 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900015 movdn.3 movdn.3 drop drop ext2add push.1 push.0 ext2add push.2 push.0 ext2mul ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 0 for aux + padw mem_loadw.4294900072 drop drop padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add + # push the accumulator to the stack + push.1 movdn.2 push.0 movdn.2 + # => [b1, b0, r1, r0, ...] + # square 2 times + dup.1 dup.1 ext2mul dup.1 dup.1 ext2mul + # multiply + dup.1 dup.1 movdn.5 movdn.5 + # => [b1, b0, r1, r0, b1, b0, ...] (4 cycles) + ext2mul movdn.3 movdn.3 + # => [b1, b0, r1', r0', ...] (5 cycles) + # clean stack + drop drop + # => [r1, r0, ...] 
(2 cycles) + padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add + # push the accumulator to the stack + push.1 movdn.2 push.0 movdn.2 + # => [b1, b0, r1, r0, ...] + # square 1 times + dup.1 dup.1 ext2mul + # multiply + dup.1 dup.1 movdn.5 movdn.5 + # => [b1, b0, r1, r0, b1, b0, ...] (4 cycles) + ext2mul movdn.3 movdn.3 + # => [b1, b0, r1', r0', ...] (5 cycles) + # clean stack + drop drop + # => [r1, r0, ...] (2 cycles) + push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul ext2add padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2mul ext2add push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 ext2add ext2mul ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 drop drop ext2mul +end # END PROC compute_integrity_constraints + +# Procedure to evaluate the boundary constraint numerator for the first row of the main trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] 
+# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_main_first + # boundary constraint 0 for main + padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900201 movdn.3 movdn.3 drop drop ext2mul +end # END PROC compute_boundary_constraints_main_first + +# Procedure to evaluate all integrity constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_integrity_constraints + exec.compute_integrity_constraints + # Numerator of the transition constraint polynomial + ext2add ext2add + # Divisor of the transition constraint polynomial + exec.compute_integrity_constraint_divisor + ext2div # divide the numerator by the divisor +end # END PROC evaluate_integrity_constraints + +# Procedure to evaluate all boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_boundary_constraints + exec.compute_boundary_constraints_main_first + # => [(first1, first0), ...] + # Compute the denominator for domain FirstRow + padw mem_loadw.4294903304 drop drop # load z + push.1 push.0 ext2sub + # Compute numerator/denominator for first row + ext2div +end # END PROC evaluate_boundary_constraints + +# Procedure to evaluate the integrity and boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] 
+export.evaluate_constraints + exec.cache_z_exp + exec.evaluate_integrity_constraints + exec.evaluate_boundary_constraints + ext2add +end # END PROC evaluate_constraints + diff --git a/air-script/tests/functions/functions_complex.rs b/air-script/tests/functions/functions_complex.rs new file mode 100644 index 00000000..4f4883f5 --- /dev/null +++ b/air-script/tests/functions/functions_complex.rs @@ -0,0 +1,91 @@ +use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement}; +use winter_utils::collections::Vec; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + stack_inputs: [Felt; 16], +} + +impl PublicInputs { + pub fn new(stack_inputs: [Felt; 16]) -> Self { + Self { stack_inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + target.write(self.stack_inputs.as_slice()); + } +} + +pub struct FunctionsAir { + context: AirContext, + stack_inputs: [Felt; 16], +} + +impl FunctionsAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for FunctionsAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(1)]; + let aux_degrees = vec![TransitionConstraintDegree::new(8)]; + let num_main_assertions = 1; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, stack_inputs: public_inputs.stack_inputs } + } + + fn 
get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(3, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxTraceRandElements) -> Vec> { + let mut result = Vec::new(); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_next[3] - (main_current[4] + main_current[5] + main_current[6] + main_current[7] + main_current[8] + main_current[9] + main_current[10] + main_current[11] + main_current[12] + main_current[13] + main_current[14] + main_current[15] + E::ONE) * E::from(2_u64); + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + result[0] = aux_next[0] - aux_current[0] * (((aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(4_u64)) * E::from(main_current[1]) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(2_u64)) * (E::ONE - E::from(main_current[1])) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])) * E::from(main_current[1]) * (E::ONE - E::from(main_current[2])) + (E::ONE - E::from(main_current[1])) * (E::ONE - E::from(main_current[2]))) * E::from(main_current[0]) - E::from(main_current[0]) + E::ONE); + } +} \ No newline at end of file diff --git a/air-script/tests/functions/functions_simple.air b/air-script/tests/functions/functions_simple.air new 
file mode 100644 index 00000000..2f93227b --- /dev/null +++ b/air-script/tests/functions/functions_simple.air @@ -0,0 +1,90 @@ +def FunctionsAir + +fn fold_sum(a: felt[4]) -> felt { + return a[0] + a[1] + a[2] + a[3] +} + +fn fold_vec(a: felt[4]) -> felt { + let m = a[0] * a[1] + let n = m * a[2] + let o = n * a[3] + return o +} + +fn cube(base: felt) -> felt { + return base^3 +} + +fn cube_vec(base: felt[4]) -> felt[4] { + let cubed_vec = [x^3 for x in base] + return cubed_vec +} + +fn func_return(a: felt[4]) -> felt { + return fold_sum(a) +} + +fn func_func_return(a: felt[4]) -> felt { + return fold_sum(a) * fold_vec(a) +} + +fn bin_return(a: felt[4]) -> felt { + return fold_sum(a) * 4 +} + +trace_columns { + main: [t, s0, s1, v, b[4]] + aux: [b_range] +} + +public_inputs { + stack_inputs: [16] +} + +random_values { + alpha: [16] +} + +boundary_constraints { + enf v.first = 0 +} + +integrity_constraints { + # -------- function call is assigned to a variable and used in a binary expression ------------ + + # binary expression invloving scalar expressions + let simple_expression = t * v + enf simple_expression = 1 + + # binary expression involving one function call + let folded_vec = fold_vec(b) * v + enf folded_vec = 1 + + # binary expression involving two function calls + let complex_fold = fold_sum(b) * fold_vec(b) + enf complex_fold = 1 + + + # -------- function calls used in constraint ------------ + enf fold_vec(b) = 1 + enf t * fold_vec(b) = 1 + enf s0 + fold_sum(b) * fold_vec(b) = 1 + + # -------- functions with function calls as return statements ------------ + enf func_return(b) = 1 + enf func_func_return(b) = 1 + enf bin_return(b) = 1 + + # -------- different types of arguments in a function call ------------ + + # function call with a function call as an argument + # enf fold_vec(cube_vec(b)) = 1 + + # function call as value in list comprehension + # let folded_vec = sum([cube(x) for x in b]) + # enf t * folded_vec = 1 + + # function call as iterable in 
list comprehension + # let folded_vec = sum([x + 1 for x in cube_vec(b)]) + # enf t * folded_vec = 1 +} diff --git a/air-script/tests/functions/functions_simple.masm b/air-script/tests/functions/functions_simple.masm new file mode 100644 index 00000000..7058aee3 --- /dev/null +++ b/air-script/tests/functions/functions_simple.masm @@ -0,0 +1,166 @@ +# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use. +# +# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors +# +# Input: [...] +# Output: [...] +proc.cache_z_exp + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, ...] + # Exponentiate z trace_len times + mem_load.4294903307 neg + # => [count, z_1, z_0, ...] where count = -log2(trace_len) + dup.0 neq.0 + while.true + movdn.2 dup.1 dup.1 ext2mul + # => [(e_1, e_0)^n, i, ...] + movup.2 add.1 dup.0 neq.0 + # => [b, i+1, (e_1, e_0)^n, ...] + end # END while + push.0 mem_storew.500000100 # z^trace_len + # => [0, 0, (z_1, z_0)^trace_len, ...] + dropw # Clean stack +end # END PROC cache_z_exp + +# Procedure to compute the exemption points. +# +# Input: [...] +# Output: [g^{-2}, g^{-1}, ...] +proc.get_exemptions_points + mem_load.4294799999 + # => [g, ...] + push.1 swap div + # => [g^{-1}, ...] + dup.0 dup.0 mul + # => [g^{-2}, g^{-1}, ...] +end # END PROC get_exemptions_points + +# Procedure to compute the integrity constraint divisor. +# +# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))` +# Procedure `cache_z_exp` must have been called prior to this. +# +# Input: [...] +# Output: [divisor_1, divisor_0, ...] +proc.compute_integrity_constraint_divisor + padw mem_loadw.500000100 drop drop # load z^trace_len + # Comments below use zt = `z^trace_len` + # => [zt_1, zt_0, ...] + push.1 push.0 ext2sub + # => [zt_1-1, zt_0-1, ...] 
+ padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, zt_1-1, zt_0-1, ...] + exec.get_exemptions_points + # => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor + dup.3 dup.3 movup.3 push.0 ext2sub + # => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + movup.4 movup.4 movup.4 push.0 ext2sub + # => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...] + ext2mul + # => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...] + ext2div + # => [divisor_1, divisor_0, ...] +end # END PROC compute_integrity_constraint_divisor + +# Procedure to evaluate numerators of all integrity constraints. +# +# All the 9 main and 0 auxiliary constraints are evaluated. +# The result of each evaluation is kept on the stack, with the top of the stack +# containing the evaluations for the auxiliary trace (if any) followed by the main trace. +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation. 
+# This procedure pushes 9 quadratic extension field elements to the stack +proc.compute_integrity_constraints + # integrity constraint 0 for main + padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2mul push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 1 for main + padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2mul push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 drop drop ext2mul + # integrity constraint 2 for main + padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2mul ext2mul push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900201 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 3 for main + padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2mul push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900201 drop drop ext2mul + # integrity constraint 4 for main + padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw 
mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2mul ext2mul push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900202 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 5 for main + padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2mul ext2mul ext2add push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900202 drop drop ext2mul + # integrity constraint 6 for main + padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2add push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900203 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 7 for main + padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2mul ext2mul push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900203 drop 
drop ext2mul + # integrity constraint 8 for main + padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2add push.4 push.0 ext2mul push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900204 movdn.3 movdn.3 drop drop ext2mul +end # END PROC compute_integrity_constraints + +# Procedure to evaluate the boundary constraint numerator for the first row of the main trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_main_first + # boundary constraint 0 for main + padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900204 drop drop ext2mul +end # END PROC compute_boundary_constraints_main_first + +# Procedure to evaluate all integrity constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_integrity_constraints + exec.compute_integrity_constraints + # Numerator of the transition constraint polynomial + ext2add ext2add ext2add ext2add ext2add ext2add ext2add ext2add ext2add + # Divisor of the transition constraint polynomial + exec.compute_integrity_constraint_divisor + ext2div # divide the numerator by the divisor +end # END PROC evaluate_integrity_constraints + +# Procedure to evaluate all boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_boundary_constraints + exec.compute_boundary_constraints_main_first + # => [(first1, first0), ...] 
+ # Compute the denominator for domain FirstRow + padw mem_loadw.4294903304 drop drop # load z + push.1 push.0 ext2sub + # Compute numerator/denominator for first row + ext2div +end # END PROC evaluate_boundary_constraints + +# Procedure to evaluate the integrity and boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +export.evaluate_constraints + exec.cache_z_exp + exec.evaluate_integrity_constraints + exec.evaluate_boundary_constraints + ext2add +end # END PROC evaluate_constraints + diff --git a/air-script/tests/functions/functions_simple.rs b/air-script/tests/functions/functions_simple.rs new file mode 100644 index 00000000..242dc697 --- /dev/null +++ b/air-script/tests/functions/functions_simple.rs @@ -0,0 +1,98 @@ +use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement}; +use winter_utils::collections::Vec; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + stack_inputs: [Felt; 16], +} + +impl PublicInputs { + pub fn new(stack_inputs: [Felt; 16]) -> Self { + Self { stack_inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + target.write(self.stack_inputs.as_slice()); + } +} + +pub struct FunctionsAir { + context: AirContext, + stack_inputs: [Felt; 16], +} + +impl FunctionsAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for FunctionsAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(5), 
TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(1)]; + let aux_degrees = vec![]; + let num_main_assertions = 1; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, stack_inputs: public_inputs.stack_inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(3, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxTraceRandElements) -> Vec> { + let mut result = Vec::new(); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_current[0] * main_current[3] - E::ONE; + result[1] = main_current[4] * main_current[5] * main_current[6] * main_current[7] * main_current[3] - E::ONE; + result[2] = (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * main_current[4] * main_current[5] * main_current[6] * main_current[7] - E::ONE; + result[3] = main_current[4] * main_current[5] * main_current[6] * main_current[7] - E::ONE; + result[4] = main_current[0] * main_current[4] * main_current[5] * main_current[6] * main_current[7] - E::ONE; + result[5] = main_current[1] + (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * main_current[4] * main_current[5] * main_current[6] * main_current[7] - E::ONE; + result[6] = main_current[4] + main_current[5] + main_current[6] + main_current[7] - E::ONE; + result[7] = (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * main_current[4] * 
main_current[5] * main_current[6] * main_current[7] - E::ONE; + result[8] = (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * E::from(4_u64) - E::ONE; + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + } +} \ No newline at end of file diff --git a/air-script/tests/functions/inlined_functions_simple.air b/air-script/tests/functions/inlined_functions_simple.air new file mode 100644 index 00000000..4fcde262 --- /dev/null +++ b/air-script/tests/functions/inlined_functions_simple.air @@ -0,0 +1,52 @@ +# This file is added as a sanity check to make sure the constraints generated with inlined +# functions are the same as those with functions +def FunctionsAir + +trace_columns { + main: [t, s0, s1, v, b[4]] + aux: [b_range] +} + +public_inputs { + stack_inputs: [16] +} + +random_values { + alpha: [16] +} + +boundary_constraints { + enf v.first = 0 +} + +integrity_constraints { + # -------- function call is assigned to a variable and used in a binary expression ------------ + + # binary expression involving scalar expressions + let simple_expression = t * v + enf simple_expression = 1 + + # binary expression involving one function call + + # fold_vec function body where o is the return value of the function + let m = b[0] * b[1] + let n = m * b[2] + let o = n * b[3] + + let folded_vec = o * v + enf folded_vec = 1 + + # binary expression involving two function calls + let complex_fold = (b[0] + b[1] + b[2] + b[3]) * o + enf complex_fold = 1 + + # function calls used in constraints + enf o = 1 + enf t * o = 1 + enf s0 + (b[0] + b[1] + b[2] + b[3]) * o = 1 + + # -------- functions with function calls as 
return statements ------------ + enf b[0] + b[1] + b[2] + b[3] = 1 + enf (b[0] + b[1] + b[2] + b[3]) * o = 1 + enf (b[0] + b[1] + b[2] + b[3]) * 4 = 1 +} diff --git a/codegen/masm/tests/test_aux.rs b/codegen/masm/tests/test_aux.rs index 93062042..2455165e 100644 --- a/codegen/masm/tests/test_aux.rs +++ b/codegen/masm/tests/test_aux.rs @@ -40,7 +40,7 @@ fn test_simple_aux() { let trace_len = 2u64.pow(4); let one = QuadExtension::new(Felt::new(1), Felt::ZERO); - let z = one.clone(); + let z = one; let a = QuadExtension::new(Felt::new(3), Felt::ZERO); let b = QuadExtension::new(Felt::new(7), Felt::ZERO); let a_prime = a; @@ -60,7 +60,7 @@ fn test_simple_aux() { descriptor: "aux_trace", }, Data { - data: to_stack_order(&vec![one; 6]), + data: to_stack_order(&[one; 6]), address: constants::COMPOSITION_COEF_ADDRESS, descriptor: "composition_coefficients", }, diff --git a/codegen/masm/tests/test_boundary.rs b/codegen/masm/tests/test_boundary.rs index c8c4f849..5a18856f 100644 --- a/codegen/masm/tests/test_boundary.rs +++ b/codegen/masm/tests/test_boundary.rs @@ -38,18 +38,18 @@ fn test_simple_boundary() { let trace_len = 32u64; let one = QuadExtension::ONE; - let z = one.clone(); + let z = one; let a = QuadExtension::new(Felt::new(514229), Felt::ZERO); let b = QuadExtension::new(Felt::new(317811), Felt::ZERO); let len = QuadExtension::new(Felt::new(27), Felt::ZERO); let a_prime = QuadExtension::new(Felt::new(514229 + 317811), Felt::ZERO); - let b_prime = a.clone(); + let b_prime = a; let code = test_code( code, vec![ Data { - data: to_stack_order(&[a, a_prime, b, b_prime, len.clone(), len.clone()]), + data: to_stack_order(&[a, a_prime, b, b_prime, len, len]), address: constants::OOD_FRAME_ADDRESS, descriptor: "main_trace", }, @@ -151,7 +151,7 @@ fn test_complex_boundary() { let trace_len = 32u64; let one = QuadExtension::new(Felt::new(1), Felt::ZERO); - let z = one.clone(); + let z = one; let public_inputs = [ // stack_inputs diff --git 
a/codegen/masm/tests/test_constants.rs b/codegen/masm/tests/test_constants.rs index f551ca8a..4d81ac90 100644 --- a/codegen/masm/tests/test_constants.rs +++ b/codegen/masm/tests/test_constants.rs @@ -66,7 +66,7 @@ fn test_constants() { descriptor: "aux_trace", }, Data { - data: to_stack_order(&vec![one; 3]), + data: to_stack_order(&[one; 3]), address: constants::COMPOSITION_COEF_ADDRESS, descriptor: "composition_coefficients", }, diff --git a/codegen/masm/tests/test_divisor.rs b/codegen/masm/tests/test_divisor.rs index 0b1b73e9..2d196b93 100644 --- a/codegen/masm/tests/test_divisor.rs +++ b/codegen/masm/tests/test_divisor.rs @@ -50,7 +50,7 @@ fn test_integrity_divisor() { descriptor: "main_trace", }, Data { - data: to_stack_order(&vec![one; 2]), + data: to_stack_order(&[one; 2]), address: constants::COMPOSITION_COEF_ADDRESS, descriptor: "composition_coefficients", }, @@ -149,7 +149,7 @@ fn test_boundary_divisor() { descriptor: "aux_trace", }, Data { - data: to_stack_order(&vec![one; 5]), + data: to_stack_order(&[one; 5]), address: constants::COMPOSITION_COEF_ADDRESS, descriptor: "composition_coefficients", }, @@ -268,7 +268,7 @@ fn test_mixed_boundary_divisor() { descriptor: "aux_trace", }, Data { - data: to_stack_order(&vec![one; 5]), + data: to_stack_order(&[one; 5]), address: constants::COMPOSITION_COEF_ADDRESS, descriptor: "composition_coefficients", }, diff --git a/codegen/masm/tests/test_periodic.rs b/codegen/masm/tests/test_periodic.rs index 6e8388f2..3a77dfa2 100644 --- a/codegen/masm/tests/test_periodic.rs +++ b/codegen/masm/tests/test_periodic.rs @@ -55,7 +55,7 @@ fn test_simple_periodic() { descriptor: "aux_trace", }, Data { - data: to_stack_order(&vec![one; 1]), + data: to_stack_order(&[one; 1]), address: constants::COMPOSITION_COEF_ADDRESS, descriptor: "composition_coefficients", }, @@ -150,7 +150,7 @@ fn test_multiple_periodic() { descriptor: "aux_trace", }, Data { - data: to_stack_order(&vec![one; 3]), + data: to_stack_order(&[one; 3]), address: 
constants::COMPOSITION_COEF_ADDRESS, descriptor: "composition_coefficients", }, diff --git a/ir/src/passes/translate.rs b/ir/src/passes/translate.rs index 694c9694..61ce38b4 100644 --- a/ir/src/passes/translate.rs +++ b/ir/src/passes/translate.rs @@ -23,6 +23,8 @@ impl<'p> Pass for AstToAir<'p> { type Error = CompileError; fn run<'a>(&mut self, program: Self::Input<'a>) -> Result, Self::Error> { + dbg!(&program); + let mut air = Air::new(program.name); let random_values = program.random_values; @@ -47,8 +49,8 @@ impl<'p> Pass for AstToAir<'p> { builder.build_boundary_constraint(bc)?; } - for bc in integrity_constraints.iter() { - builder.build_integrity_constraint(bc)?; + for ic in integrity_constraints.iter() { + builder.build_integrity_constraint(ic)?; } Ok(air) @@ -98,8 +100,8 @@ impl<'a> AirBuilder<'a> { } } - fn build_integrity_constraint(&mut self, bc: &ast::Statement) -> Result<(), CompileError> { - match bc { + fn build_integrity_constraint(&mut self, ic: &ast::Statement) -> Result<(), CompileError> { + match ic { ast::Statement::Enforce(ast::ScalarExpr::Binary(ast::BinaryExpr { op: ast::BinaryOp::Eq, ref lhs, diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index 10c0c018..cbb9aaac 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -41,6 +41,10 @@ pub enum Declaration { /// /// Evaluator functions can be defined in any module of the program EvaluatorFunction(EvaluatorFunction), + /// A pure function definition + /// + /// Pure functions can be defined in any module of the program + Function(Function), /// A `periodic_columns` section declaration /// /// This may appear any number of times in the program, and may be declared in any module. @@ -525,3 +529,50 @@ impl PartialEq for EvaluatorFunction { self.name == other.name && self.params == other.params && self.body == other.body } } + +/// Functions take a group of expressions as parameters and return a value. 
+/// +/// The result value of a function may be a felt, vector, or a matrix. +/// +/// NOTE: Functions do not take trace bindings as parameters. +#[derive(Debug, Clone, Spanned)] +pub struct Function { + #[span] + pub span: SourceSpan, + pub name: Identifier, + pub params: Vec<(Identifier, Type)>, + pub return_type: Type, + pub body: Vec, +} +impl Function { + /// Creates a new function. + pub const fn new( + span: SourceSpan, + name: Identifier, + params: Vec<(Identifier, Type)>, + return_type: Type, + body: Vec, + ) -> Self { + Self { + span, + name, + params, + return_type, + body, + } + } + + pub fn param_types(&self) -> Vec { + self.params.iter().map(|(_, ty)| *ty).collect::>() + } +} + +impl Eq for Function {} +impl PartialEq for Function { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.params == other.params + && self.return_type == other.return_type + && self.body == other.body + } +} diff --git a/parser/src/ast/display.rs b/parser/src/ast/display.rs index 2172795a..d8ee0de6 100644 --- a/parser/src/ast/display.rs +++ b/parser/src/ast/display.rs @@ -10,7 +10,7 @@ impl fmt::Display for DisplayBracketed { } } -/// Displays a slice of items surrounded by brackets, e.g. `[foo]` +/// Displays a slice of items surrounded by brackets, e.g. `[foo, bar]` pub struct DisplayList<'a, T>(pub &'a [T]); impl<'a, T: fmt::Display> fmt::Display for DisplayList<'a, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -26,7 +26,7 @@ impl fmt::Display for DisplayParenthesized { } } -/// Displays a slice of items surrounded by parentheses, e.g. `(foo)` +/// Displays a slice of items surrounded by parentheses, e.g. 
`(foo, bar)` pub struct DisplayTuple<'a, T>(pub &'a [T]); impl<'a, T: fmt::Display> fmt::Display for DisplayTuple<'a, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -34,6 +34,19 @@ impl<'a, T: fmt::Display> fmt::Display for DisplayTuple<'a, T> { } } +/// Displays a slice of items with their types surrounded by parentheses, +/// e.g. `(foo: felt, bar: felt[12])` +pub struct DisplayTypedTuple<'a, V, T>(pub &'a [(V, T)]); +impl<'a, V: fmt::Display, T: fmt::Display> fmt::Display for DisplayTypedTuple<'a, V, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "({})", + DisplayCsv::new(self.0.iter().map(|(v, t)| format!("{}: {}", v, t))) + ) + } +} + /// Displays one or more items separated by commas, e.g. `foo, bar` pub struct DisplayCsv(Cell>); impl DisplayCsv @@ -96,7 +109,7 @@ impl<'a> fmt::Display for DisplayStatement<'a> { Statement::EnforceAll(ref expr) => { write!(f, "enf {}", expr) } - Statement::Expr(ref expr) => write!(f, "{}", expr), + Statement::Expr(ref expr) => write!(f, "return {}", expr), } } } diff --git a/parser/src/ast/errors.rs b/parser/src/ast/errors.rs index 9f7abcac..be0e6750 100644 --- a/parser/src/ast/errors.rs +++ b/parser/src/ast/errors.rs @@ -49,3 +49,29 @@ impl ToDiagnostic for InvalidExprError { } } } + +/// Represents an invalid type for use in a `BindingType` context +#[derive(Debug, thiserror::Error)] +pub enum InvalidTypeError { + #[error("expected iterable to be a vector")] + NonVectorIterable(SourceSpan), +} +impl Eq for InvalidTypeError {} +impl PartialEq for InvalidTypeError { + fn eq(&self, other: &Self) -> bool { + core::mem::discriminant(self) == core::mem::discriminant(other) + } +} +impl ToDiagnostic for InvalidTypeError { + fn to_diagnostic(self) -> Diagnostic { + let message = format!("{}", &self); + match self { + Self::NonVectorIterable(span) => Diagnostic::error() + .with_message("invalid type") + .with_labels(vec![ + Label::primary(span.source_id(), span).with_message(message) + 
]) + .with_notes(vec!["Only vectors can be used as iterables".to_string()]), + } + } +} diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index e115d7f0..d763328d 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -445,6 +445,15 @@ impl ScalarExpr { matches!(self, Self::Const(_)) } + /// Returns true if this scalar expression could expand to a block, e.g. due to a function call being inlined. + pub fn has_block_like_expansion(&self) -> bool { + match self { + Self::Binary(ref expr) => expr.has_block_like_expansion(), + Self::Call(_) => true, + _ => false, + } + } + /// Returns the resolved type of this expression, if known. /// /// Returns `Ok(Some)` if the type could be resolved without conflict. @@ -527,6 +536,12 @@ impl BinaryExpr { rhs: Box::new(rhs), } } + + /// Returns true if this binary expression could expand to a block, e.g. due to a function call being inlined. + #[inline] + pub fn has_block_like_expansion(&self) -> bool { + self.lhs.has_block_like_expansion() || self.rhs.has_block_like_expansion() + } } impl Eq for BinaryExpr {} impl PartialEq for BinaryExpr { diff --git a/parser/src/ast/mod.rs b/parser/src/ast/mod.rs index 5df1cf70..b02b952c 100644 --- a/parser/src/ast/mod.rs +++ b/parser/src/ast/mod.rs @@ -71,6 +71,8 @@ pub struct Program { pub constants: BTreeMap, /// The set of used evaluator functions referenced in this program. pub evaluators: BTreeMap, + /// The set of used pure functions referenced in this program. + pub functions: BTreeMap, /// The set of used periodic columns referenced in this program. 
pub periodic_columns: BTreeMap, /// The set of public inputs defined in the root module @@ -115,6 +117,7 @@ impl Program { name, constants: Default::default(), evaluators: Default::default(), + functions: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -265,7 +268,12 @@ impl Program { .entry(referenced) .or_insert_with(|| referenced_module.evaluators[&id].clone()); } - DependencyType::Function => unimplemented!(), + DependencyType::Function => { + program + .functions + .entry(referenced) + .or_insert_with(|| referenced_module.functions[&id].clone()); + } DependencyType::PeriodicColumn => { program .periodic_columns @@ -288,6 +296,7 @@ impl PartialEq for Program { self.name == other.name && self.constants == other.constants && self.evaluators == other.evaluators + && self.functions == other.functions && self.periodic_columns == other.periodic_columns && self.public_inputs == other.public_inputs && self.random_values == other.random_values @@ -384,6 +393,29 @@ impl fmt::Display for Program { f.write_str("\n")?; } + for (qid, function) in self.functions.iter() { + f.write_str("fn ")?; + if qid.module == self.name { + writeln!( + f, + "{}{}", + &qid.item, + DisplayTypedTuple(function.params.as_slice()) + )?; + } else { + writeln!( + f, + "{}{}", + qid, + DisplayTypedTuple(function.params.as_slice()) + )?; + } + + for statement in function.body.iter() { + writeln!(f, "{}", statement.display(1))?; + } + } + Ok(()) } } diff --git a/parser/src/ast/module.rs b/parser/src/ast/module.rs index 393f93db..c9363ca9 100644 --- a/parser/src/ast/module.rs +++ b/parser/src/ast/module.rs @@ -54,6 +54,7 @@ pub struct Module { pub imports: BTreeMap, pub constants: BTreeMap, pub evaluators: BTreeMap, + pub functions: BTreeMap, pub periodic_columns: BTreeMap, pub public_inputs: BTreeMap, pub random_values: Option, @@ -79,6 +80,7 @@ impl Module { imports: Default::default(), constants: Default::default(), evaluators: 
Default::default(), + functions: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -121,6 +123,9 @@ impl Module { Declaration::EvaluatorFunction(evaluator) => { module.declare_evaluator(diagnostics, &mut names, evaluator)?; } + Declaration::Function(function) => { + module.declare_function(diagnostics, &mut names, function)?; + } Declaration::PeriodicColumns(mut columns) => { for column in columns.drain(..) { module.declare_periodic_column(diagnostics, &mut names, column)?; @@ -395,6 +400,22 @@ impl Module { Ok(()) } + fn declare_function( + &mut self, + diagnostics: &DiagnosticsHandler, + names: &mut HashSet, + function: Function, + ) -> Result<(), SemanticAnalysisError> { + if let Some(prev) = names.replace(NamespacedIdentifier::Function(function.name)) { + conflicting_declaration(diagnostics, "function", prev.span(), function.name.span()); + return Err(SemanticAnalysisError::NameConflict(function.name.span())); + } + + self.functions.insert(function.name, function); + + Ok(()) + } + fn declare_periodic_column( &mut self, diagnostics: &DiagnosticsHandler, @@ -621,6 +642,7 @@ impl PartialEq for Module { && self.imports == other.imports && self.constants == other.constants && self.evaluators == other.evaluators + && self.functions == other.functions && self.periodic_columns == other.periodic_columns && self.public_inputs == other.public_inputs && self.random_values == other.random_values diff --git a/parser/src/ast/statement.rs b/parser/src/ast/statement.rs index 3f0594cd..5b5cd367 100644 --- a/parser/src/ast/statement.rs +++ b/parser/src/ast/statement.rs @@ -127,12 +127,29 @@ impl Let { } } + /// Return the type of the overall `let` expression. + /// + /// A `let` with an empty body, or with a body that terminates with a non-expression statement + /// has no type (or rather, one could consider the type it returns to be of "void" or "unit" type). 
+ /// + /// For `let` statements with a non-empty body that terminates with an expression, the `let` can + /// be used in expression position, producing the value of the terminating expression in its body, + /// and having the same type as that value. pub fn ty(&self) -> Option { - self.body.last().and_then(|stmt| match stmt { - Statement::Let(ref nested) => nested.ty(), - Statement::Expr(ref expr) => expr.ty(), - _ => None, - }) + let mut last = self.body.last(); + while let Some(stmt) = last.take() { + match stmt { + Statement::Let(ref let_expr) => { + last = let_expr.body.last(); + } + Statement::Expr(ref expr) => return expr.ty(), + Statement::Enforce(_) | Statement::EnforceIf(_, _) | Statement::EnforceAll(_) => { + break + } + } + } + + None } } impl Eq for Let {} diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index a5cdfd4f..96620db8 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -72,9 +72,9 @@ impl Type { impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Self::Felt => f.write_str("field element"), - Self::Vector(n) => write!(f, "vector of length {}", n), - Self::Matrix(rows, cols) => write!(f, "matrix of {} rows and {} columns", rows, cols), + Self::Felt => f.write_str("felt"), + Self::Vector(n) => write!(f, "felt[{}]", n), + Self::Matrix(rows, cols) => write!(f, "felt[{}, {}]", rows, cols), } } } @@ -86,7 +86,6 @@ pub enum FunctionType { /// a complex type signature due to the nature of trace bindings Evaluator(Vec), /// A standard function with one or more inputs, and a result - #[allow(dead_code)] Function(Vec, Type), } impl FunctionType { diff --git a/parser/src/ast/visit.rs b/parser/src/ast/visit.rs index d879ecaf..89765c5e 100644 --- a/parser/src/ast/visit.rs +++ b/parser/src/ast/visit.rs @@ -122,6 +122,9 @@ pub trait VisitMut { ) -> ControlFlow { visit_mut_evaluator_function(self, expr) } + fn visit_mut_function(&mut self, expr: &mut ast::Function) -> ControlFlow { 
+ visit_mut_function(self, expr) + } fn visit_mut_periodic_column(&mut self, expr: &mut ast::PeriodicColumn) -> ControlFlow { visit_mut_periodic_column(self, expr) } @@ -223,6 +226,12 @@ pub trait VisitMut { fn visit_mut_identifier(&mut self, expr: &mut ast::Identifier) -> ControlFlow { visit_mut_identifier(self, expr) } + fn visit_mut_typed_identifier( + &mut self, + expr: &mut (ast::Identifier, ast::Type), + ) -> ControlFlow { + visit_mut_typed_identifier(self, expr) + } } impl<'a, V, T> VisitMut for &'a mut V @@ -244,6 +253,9 @@ where ) -> ControlFlow { (**self).visit_mut_evaluator_function(expr) } + fn visit_mut_function(&mut self, expr: &mut ast::Function) -> ControlFlow { + (**self).visit_mut_function(expr) + } fn visit_mut_periodic_column(&mut self, expr: &mut ast::PeriodicColumn) -> ControlFlow { (**self).visit_mut_periodic_column(expr) } @@ -344,6 +356,12 @@ where fn visit_mut_identifier(&mut self, expr: &mut ast::Identifier) -> ControlFlow { (**self).visit_mut_identifier(expr) } + fn visit_mut_typed_identifier( + &mut self, + expr: &mut (ast::Identifier, ast::Type), + ) -> ControlFlow { + (**self).visit_mut_typed_identifier(expr) + } } pub fn visit_mut_module(visitor: &mut V, module: &mut ast::Module) -> ControlFlow @@ -359,6 +377,9 @@ where for evaluator in module.evaluators.values_mut() { visitor.visit_mut_evaluator_function(evaluator)?; } + for function in module.functions.values_mut() { + visitor.visit_mut_function(function)?; + } for column in module.periodic_columns.values_mut() { visitor.visit_mut_periodic_column(column)?; } @@ -439,6 +460,17 @@ where visitor.visit_mut_statement_block(&mut expr.body) } +pub fn visit_mut_function(visitor: &mut V, expr: &mut ast::Function) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + visitor.visit_mut_identifier(&mut expr.name)?; + for param in expr.params.iter_mut() { + visitor.visit_mut_typed_identifier(param)?; + } + visitor.visit_mut_statement_block(&mut expr.body) +} + pub fn 
visit_mut_evaluator_trace_segment( visitor: &mut V, expr: &mut ast::TraceSegment, @@ -661,3 +693,13 @@ where { ControlFlow::Continue(()) } + +pub fn visit_mut_typed_identifier( + _visitor: &mut V, + _expr: &mut (ast::Identifier, ast::Type), +) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + ControlFlow::Continue(()) +} diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index 6771316f..fe89b499 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -113,6 +113,8 @@ pub enum Token { RandomValues, /// Keyword to declare the evaluator function section in the AIR constraints module. Ev, + /// Keyword to declare the function section in the AIR constraints module. + Fn, // BOUNDARY CONSTRAINT KEYWORDS // -------------------------------------------------------------------------------------------- @@ -137,9 +139,11 @@ pub enum Token { // -------------------------------------------------------------------------------------------- /// Keyword to signify that a constraint needs to be enforced Enf, + Return, Match, Case, When, + Felt, // PUNCTUATION // -------------------------------------------------------------------------------------------- @@ -163,6 +167,7 @@ pub enum Token { Ampersand, Bar, Bang, + Arrow, } impl Token { pub fn from_keyword_or_ident(s: &str) -> Self { @@ -179,6 +184,8 @@ impl Token { "periodic_columns" => Self::PeriodicColumns, "random_values" => Self::RandomValues, "ev" => Self::Ev, + "fn" => Self::Fn, + "felt" => Self::Felt, "boundary_constraints" => Self::BoundaryConstraints, "integrity_constraints" => Self::IntegrityConstraints, "first" => Self::First, @@ -186,6 +193,7 @@ impl Token { "for" => Self::For, "in" => Self::In, "enf" => Self::Enf, + "return" => Self::Return, "match" => Self::Match, "case" => Self::Case, "when" => Self::When, @@ -249,6 +257,8 @@ impl fmt::Display for Token { Self::PeriodicColumns => write!(f, "periodic_columns"), Self::RandomValues => write!(f, "random_values"), Self::Ev => write!(f, "ev"), + 
Self::Fn => write!(f, "fn"), + Self::Felt => write!(f, "felt"), Self::BoundaryConstraints => write!(f, "boundary_constraints"), Self::First => write!(f, "first"), Self::Last => write!(f, "last"), @@ -256,6 +266,7 @@ impl fmt::Display for Token { Self::For => write!(f, "for"), Self::In => write!(f, "in"), Self::Enf => write!(f, "enf"), + Self::Return => write!(f, "return"), Self::Match => write!(f, "match"), Self::Case => write!(f, "case"), Self::When => write!(f, "when"), @@ -279,6 +290,7 @@ impl fmt::Display for Token { Self::Ampersand => write!(f, "&"), Self::Bar => write!(f, "|"), Self::Bang => write!(f, "!"), + Self::Arrow => write!(f, "->"), } } } @@ -492,7 +504,10 @@ where '}' => pop!(self, Token::RBrace), '=' => pop!(self, Token::Equal), '+' => pop!(self, Token::Plus), - '-' => pop!(self, Token::Minus), + '-' => match self.peek() { + '>' => pop2!(self, Token::Arrow), + _ => pop!(self, Token::Minus), + }, '*' => pop!(self, Token::Star), '^' => pop!(self, Token::Caret), '&' => pop!(self, Token::Ampersand), diff --git a/parser/src/lexer/tests/functions.rs b/parser/src/lexer/tests/functions.rs new file mode 100644 index 00000000..a8302ab0 --- /dev/null +++ b/parser/src/lexer/tests/functions.rs @@ -0,0 +1,87 @@ +use super::{expect_valid_tokenization, Symbol, Token}; + +// FUNCTION VALID TOKENIZATION +// ================================================================================================ + +#[test] +fn fn_with_scalars() { + let source = "fn fn_name(a: felt, b: felt) -> felt { + return a + b + }"; + + let tokens = [ + Token::Fn, + Token::FunctionIdent(Symbol::intern("fn_name")), + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Colon, + Token::Felt, + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::Colon, + Token::Felt, + Token::RParen, + Token::Arrow, + Token::Felt, + Token::LBrace, + Token::Return, + Token::Ident(Symbol::intern("a")), + Token::Plus, + Token::Ident(Symbol::intern("b")), + Token::RBrace, + ]; + + 
expect_valid_tokenization(source, tokens.to_vec()); +} + +#[test] +fn fn_with_vectors() { + let source = "fn fn_name(a: felt[12], b: felt[12]) -> felt[12] { + return [x + y for x, y in (a, b)] + }"; + + let tokens = [ + Token::Fn, + Token::FunctionIdent(Symbol::intern("fn_name")), + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Colon, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::Colon, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::RParen, + Token::Arrow, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::LBrace, + Token::Return, + Token::LBracket, + Token::Ident(Symbol::intern("x")), + Token::Plus, + Token::Ident(Symbol::intern("y")), + Token::For, + Token::Ident(Symbol::intern("x")), + Token::Comma, + Token::Ident(Symbol::intern("y")), + Token::In, + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::RParen, + Token::RBracket, + Token::RBrace, + ]; + + expect_valid_tokenization(source, tokens.to_vec()); +} diff --git a/parser/src/lexer/tests/mod.rs b/parser/src/lexer/tests/mod.rs index 0426e68a..a312c007 100644 --- a/parser/src/lexer/tests/mod.rs +++ b/parser/src/lexer/tests/mod.rs @@ -6,6 +6,7 @@ mod arithmetic_ops; mod boundary_constraints; mod constants; mod evaluator_functions; +mod functions; mod identifiers; mod list_comprehension; mod modules; diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index a488424c..c54077b1 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -75,6 +75,7 @@ Declaration: Declaration = { PeriodicColumns => Declaration::PeriodicColumns(<>), RandomValues => Declaration::RandomValues(<>), EvaluatorFunction => Declaration::EvaluatorFunction(<>), + Function => Declaration::Function(<>), => Declaration::Trace(Span::new(span!(l, r), trace)), => 
Declaration::PublicInputs(<>), => Declaration::BoundaryConstraints(<>), @@ -256,6 +257,42 @@ EvaluatorSegmentBindings: (SourceSpan, Vec>) = { "[" "]" => (span!(l, r), vec![]), } +// FUNCTIONS +// ================================================================================================ + +Function: Function = { + "fn" "(" ")" "->" "{" "}" + => Function::new(span!(l, r), name, params, ty, body) +} + +FunctionBindings: Vec<(Identifier, Type)> = { + > => params, +} + +FunctionBinding: (Identifier, Type) = { + ":" => (name, ty), +} + +FunctionBindingType: Type = { + "felt" => Type::Felt, + "felt" => Type::Vector(size as usize), + "felt" "[" "," "]" => Type::Matrix(row_size as usize, col_size as usize), +} + +FunctionBody: Vec = { + // TODO: validate + =>? { + if stmts.len() > 1 { + diagnostics.diagnostic(Severity::Error) + .with_message("invalid function definition") + .with_primary_label(span!(l, r), "function should have 0 or more let statements followed by a return statement.") + .emit(); + return Err(ParseError::Failed.into()); + } + Ok(stmts) + }, +} + // BOUNDARY CONSTRAINTS // ================================================================================================ @@ -287,6 +324,7 @@ StatementBlock: Vec = { stmts }, , + => vec![Statement::Expr(<>)], } Let: Let = { @@ -305,6 +343,10 @@ ConstraintStatement: Vec = { "enf" => vec![<>], } +ReturnStatement: Expr = { + "return" => expr, +} + MatchArm: Statement = { "case" ":" => { let generated_name = format!("%{}", *next_var); @@ -516,6 +558,11 @@ Iterable: Expr = { => Expr::SymbolAccess(SymbolAccess::new(ident.span(), ident, AccessType::Default, 0)), => Expr::Range(Span::new(span!(l, r), range)), "[" "]" => Expr::SymbolAccess(SymbolAccess::new(span!(l, r), ident, AccessType::Slice(range), 0)), + => if let ScalarExpr::Call(call) = function_call { + Expr::Call(call) + } else { + unreachable!() + } } Range: Range = { @@ -533,6 +580,15 @@ Matrix: Vec> = { Vector>, } +Tuple: Vec = { + "(" "," )*> ")" => { + 
let mut v = v; + v.insert(0, v2); + v.insert(0, v1); + v + } +}; + Size: u64 = { "[" "]" => <> } @@ -591,10 +647,13 @@ extern { "last" => Token::Last, "integrity_constraints" => Token::IntegrityConstraints, "ev" => Token::Ev, + "fn" => Token::Fn, "enf" => Token::Enf, + "return" => Token::Return, "match" => Token::Match, "case" => Token::Case, "when" => Token::When, + "felt" => Token::Felt, "'" => Token::Quote, "=" => Token::Equal, "+" => Token::Plus, @@ -615,5 +674,6 @@ extern { "}" => Token::RBrace, "." => Token::Dot, ".." => Token::DotDot, + "->" => Token::Arrow, } } diff --git a/parser/src/parser/tests/functions.rs b/parser/src/parser/tests/functions.rs new file mode 100644 index 00000000..befa6411 --- /dev/null +++ b/parser/src/parser/tests/functions.rs @@ -0,0 +1,539 @@ +use miden_diagnostics::{SourceSpan, Span}; + +use crate::ast::*; + +use super::ParseTest; + +// PURE FUNCTIONS +// ================================================================================================ + +#[test] +fn fn_def_with_scalars() { + let source = " + mod test + + fn fn_with_scalars(a: felt, b: felt) -> felt { + return a + b + }"; + + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.functions.insert( + ident!(fn_with_scalars), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fn_with_scalars), + vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], + Type::Felt, + vec![return_!(expr!(add!(access!(a), access!(b))))], + ), + ); + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_def_with_vectors() { + let source = " + mod test + + fn fn_with_vectors(a: felt[12], b: felt[12]) -> felt[12] { + return [x + y for (x, y) in (a, b)] + }"; + + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.functions.insert( + ident!(fn_with_vectors), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fn_with_vectors), + vec![(ident!(a), Type::Vector(12)), 
(ident!(b), Type::Vector(12))], + Type::Vector(12), + vec![return_!(expr!( + lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => + add!(access!(x), access!(y))) + ))], + ), + ); + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_use_scalars_and_vectors() { + let source = " + def root + + public_inputs { + stack_inputs: [16] + } + + trace_columns { + main: [a, b[12]] + } + + fn fn_with_scalars_and_vectors(a: felt, b: felt[12]) -> felt { + return sum([a + x for x in b]) + } + + boundary_constraints { + enf a.first = 0 + } + + integrity_constraints { + enf a' = fn_with_scalars_and_vectors(a, b) + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + + expected.functions.insert( + ident!(fn_with_scalars_and_vectors), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fn_with_scalars_and_vectors), + vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], + Type::Felt, + vec![return_!(expr!(call!(sum(expr!( + lc!(((x, expr!(access!(b)))) => add!(access!(a), access!(x))) + )))))], + ), + ); + + expected + .trace_columns + .push(trace_segment!(0, "$main", [(a, 1), (b, 12)])); + + expected.public_inputs.insert( + ident!(stack_inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16), + ); + + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!( + access!(a, 1), + call!(fn_with_scalars_and_vectors( + expr!(access!(a)), + expr!(access!(b)) + )) + ))], + )); + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_call_in_fn() { + let source = " + def root + + public_inputs { + stack_inputs: [16] + } + + trace_columns { + main: [a, b[12]] + } + + fn fold_vec(a: felt[12]) -> felt { + return sum([x for x in a]) + } + + fn fold_scalar_and_vec(a: felt, b: felt[12]) -> felt { + 
return a + fold_vec(b) + } + + boundary_constraints { + enf a.first = 0 + } + + integrity_constraints { + enf a' = fold_scalar_and_vec(a, b) + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + + expected.functions.insert( + ident!(fold_vec), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fold_vec), + vec![(ident!(a), Type::Vector(12))], + Type::Felt, + vec![return_!(expr!(call!(sum(expr!( + lc!(((x, expr!(access!(a)))) => access!(x)) + )))))], + ), + ); + + expected.functions.insert( + ident!(fold_scalar_and_vec), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fold_scalar_and_vec), + vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], + Type::Felt, + vec![return_!(expr!(add!( + access!(a), + call!(fold_vec(expr!(access!(b)))) + )))], + ), + ); + + expected + .trace_columns + .push(trace_segment!(0, "$main", [(a, 1), (b, 12)])); + + expected.public_inputs.insert( + ident!(stack_inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16), + ); + + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!( + access!(a, 1), + call!(fold_scalar_and_vec(expr!(access!(a)), expr!(access!(b)))) + ))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_call_in_ev() { + let source = " + def root + + public_inputs { + stack_inputs: [16] + } + + trace_columns { + main: [a, b[12]] + } + + fn fold_vec(a: felt[12]) -> felt { + return sum([x for x in a]) + } + + fn fold_scalar_and_vec(a: felt, b: felt[12]) -> felt { + return a + fold_vec(b) + } + + ev evaluator([a, b[12]]) { + enf a' = fold_scalar_and_vec(a, b) + } + + boundary_constraints { + enf a.first = 0 + } + + integrity_constraints { + enf evaluator(a, b) + }"; + + let mut expected = Module::new(ModuleType::Root, 
SourceSpan::UNKNOWN, ident!(root)); + + expected.functions.insert( + ident!(fold_vec), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fold_vec), + vec![(ident!(a), Type::Vector(12))], + Type::Felt, + vec![return_!(expr!(call!(sum(expr!( + lc!(((x, expr!(access!(a)))) => access!(x)) + )))))], + ), + ); + + expected.functions.insert( + ident!(fold_scalar_and_vec), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fold_scalar_and_vec), + vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], + Type::Felt, + vec![return_!(expr!(add!( + access!(a), + call!(fold_vec(expr!(access!(b)))) + )))], + ), + ); + + expected.evaluators.insert( + ident!(evaluator), + EvaluatorFunction::new( + SourceSpan::UNKNOWN, + ident!(evaluator), + vec![trace_segment!(0, "%0", [(a, 1), (b, 12)])], + vec![enforce!(eq!( + access!(a, 1), + call!(fold_scalar_and_vec(expr!(access!(a)), expr!(access!(b)))) + ))], + ), + ); + + expected + .trace_columns + .push(trace_segment!(0, "$main", [(a, 1), (b, 12)])); + + expected.public_inputs.insert( + ident!(stack_inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16), + ); + + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(call!(evaluator( + expr!(access!(a)), + expr!(access!(b)) + )))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_as_lc_iterables() { + let source = " + def root + + public_inputs { + stack_inputs: [16] + } + + trace_columns { + main: [a[12], b[12]] + } + + fn operation(a: felt, b: felt) -> felt { + let x = a^b + 1 + return b^x + } + + boundary_constraints { + enf a.first = 0 + } + + integrity_constraints { + enf a' = sum([operation(x, y) for (x, y) in (a, b)]) + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + + 
expected.functions.insert( + ident!(operation), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(operation), + vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], + Type::Felt, + vec![ + let_!(x = expr!(add!(exp!(access!(a), access!(b)), int!(1))) => + return_!(expr!(exp!(access!(b), access!(x))))), + ], + ), + ); + + expected + .trace_columns + .push(trace_segment!(0, "$main", [(a, 12), (b, 12)])); + + expected.public_inputs.insert( + ident!(stack_inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16), + ); + + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!( + access!(a, 1), + call!(sum(expr!( + lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => call!(operation( + expr!(access!(x)), + expr!(access!(y)) + )) + ) + ))) + ))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_call_in_binary_ops() { + let source = " + def root + + public_inputs { + stack_inputs: [16] + } + + trace_columns { + main: [a[12], b[12]] + } + + fn operation(a: felt[12], b: felt[12]) -> felt { + return sum([x + y for (x, y) in (a, b)]) + } + + boundary_constraints { + enf a[0].first = 0 + } + + integrity_constraints { + enf a[0]' = a[0] * operation(a, b) + enf b[0]' = b[0] * operation(a, b) + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + + expected.functions.insert( + ident!(operation), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(operation), + vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], + Type::Felt, + vec![return_!(expr!(call!(sum(expr!( + lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!( + access!(x), + access!(y) + )) + )))))], + ), + ); + + expected + .trace_columns + .push(trace_segment!(0, "$main", [(a, 12), (b, 12)])); + + 
expected.public_inputs.insert( + ident!(stack_inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16), + ); + + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!( + bounded_access!(a[0], Boundary::First), + int!(0) + ))], + )); + + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![ + enforce!(eq!( + access!(a[0], 1), + mul!( + access!(a[0], 0), + call!(operation(expr!(access!(a)), expr!(access!(b)))) + ) + )), + enforce!(eq!( + access!(b[0], 1), + mul!( + access!(b[0], 0), + call!(operation(expr!(access!(a)), expr!(access!(b)))) + ) + )), + ], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_call_in_vector_def() { + let source = " + def root + + public_inputs { + stack_inputs: [16] + } + + trace_columns { + main: [a[12], b[12]] + } + + fn operation(a: felt[12], b: felt[12]) -> felt[12] { + return [x + y for (x, y) in (a, b)] + } + + boundary_constraints { + enf a[0].first = 0 + } + + integrity_constraints { + let d = [a[0] * operation(a, b), b[0] * operation(a, b)] + enf a[0]' = d[0] + enf b[0]' = d[1] + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + + expected.functions.insert( + ident!(operation), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(operation), + vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], + Type::Vector(12), + vec![return_!(expr!( + lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!( + access!(x), + access!(y) + )) + ))], + ), + ); + + expected + .trace_columns + .push(trace_segment!(0, "$main", [(a, 12), (b, 12)])); + + expected.public_inputs.insert( + ident!(stack_inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16), + ); + + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!( + bounded_access!(a[0], Boundary::First), + int!(0) + ))], + )); + + expected.integrity_constraints 
= Some(Span::new( + SourceSpan::UNKNOWN, + vec![let_!( + d = vector!( + mul!(access!(a[0]), call!(operation(expr!(access!(a)), expr!(access!(b))))), + mul!(access!(b[0]), call!(operation(expr!(access!(a)), expr!(access!(b)))))) + => + enforce!(eq!(access!(a[0], 1), access!(d[0]))), + enforce!(eq!(access!(b[0], 1), access!(d[1]))))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} diff --git a/parser/src/parser/tests/inlining.rs b/parser/src/parser/tests/inlining.rs index 0fad7aa7..5289ed08 100644 --- a/parser/src/parser/tests/inlining.rs +++ b/parser/src/parser/tests/inlining.rs @@ -1155,6 +1155,121 @@ fn test_inlining_constraints_with_folded_comprehensions_in_evaluator() { assert_eq!(program, expected); } +#[test] +fn test_inlining_with_function_call_as_binary_operand() { + let root = r#" + def root + + trace_columns { + main: [clk, a, b[4], c] + } + + public_inputs { + inputs: [0] + } + + integrity_constraints { + let complex_fold = fold_sum(b) * fold_vec(b) + enf complex_fold = 1 + } + + boundary_constraints { + enf clk.first = 0 + } + + fn fold_sum(a: felt[4]) -> felt { + return a[0] + a[1] + a[2] + a[3] + } + + fn fold_vec(a: felt[4]) -> felt { + let m = a[0] * a[1] + let n = m * a[2] + let o = n * a[3] + return o + } + "#; + + let test = ParseTest::new(); + let program = match test.parse_program(root) { + Err(err) => { + test.diagnostics.emit(err); + panic!("expected parsing to succeed, see diagnostics for details"); + } + Ok(ast) => ast, + }; + + let mut pipeline = + ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); + let program = pipeline.run(program).unwrap(); + + let mut expected = Program::new(ident!(root)); + expected.trace_columns.push(trace_segment!( + 0, + "$main", + [(clk, 1), (a, 1), (b, 4), (c, 1)] + )); + expected.public_inputs.insert( + ident!(inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(inputs), 0), + ); + expected.functions.insert( + function_ident!(root, fold_sum), + 
Function::new( + SourceSpan::UNKNOWN, + ident!(fold_sum), + vec![(ident!(a), Type::Vector(4))], + Type::Felt, + vec![return_!(expr!(add!( + add!( + add!(access!(a[0], Type::Felt), access!(a[1], Type::Felt)), + access!(a[2], Type::Felt) + ), + access!(a[3], Type::Felt) + )))], + ), + ); + expected.functions.insert( + function_ident!(root, fold_vec), + Function::new( + SourceSpan::UNKNOWN, + ident!(fold_vec), + vec![(ident!(a), Type::Vector(4))], + Type::Felt, + vec![ + let_!("m" = expr!(mul!(access!(a[0], Type::Felt), access!(a[1], Type::Felt))) + => let_!("n" = expr!(mul!(access!(m, Type::Felt), access!(a[2], Type::Felt))) + => let_!("o" = expr!(mul!(access!(n, Type::Felt), access!(a[3], Type::Felt))) + => return_!(expr!(access!(o, Type::Felt))) + ))), + ], + ), + ); + // The sole boundary constraint is already minimal + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(clk, Boundary::First, Type::Felt), + int!(0) + ))); + // With constant propagation and inlining done + // + // let %0 = b[0] + b[1] + b[2] + b[3] + // let m = b[0] * b[1] + // let n = m * b[2] + // let o = n * b[3] + // let %1 = o + // let complex_fold = %0 * %1 + // enf complex_fold = 1 + expected.integrity_constraints.push( + let_!("%0" = expr!(add!(add!(add!(access!(b[0], Type::Felt), access!(b[1], Type::Felt)), access!(b[2], Type::Felt)), access!(b[3], Type::Felt))) + => let_!(m = expr!(mul!(access!(b[0], Type::Felt), access!(b[1], Type::Felt))) + => let_!(n = expr!(mul!(access!(m, Type::Felt), access!(b[2], Type::Felt))) + => let_!(o = expr!(mul!(access!(n, Type::Felt), access!(b[3], Type::Felt))) + => let_!(complex_fold = expr!(mul!(access!("%0", Type::Felt), access!(o, Type::Felt))) + => enforce!(eq!(access!(complex_fold, Type::Felt), int!(1)))))))) + ); + + assert_eq!(program, expected); +} + /// This test originally reproduced the bug in air-script#340, but as of this commit /// that bug is fixed. 
This test remains not to prevent regressions necessarily, but /// to add a more realistic test case to our test suite, and potentially catch bugs @@ -1265,51 +1380,45 @@ fn test_repro_issue340() { .integrity_constraints .push(enforce!(eq!(exp!(access.clone(), int!(2)), access.clone()))); } - let word_sum = (1..32) - .into_iter() - .fold(access!("%lc0", Type::Felt), |acc, i| { - let access = ScalarExpr::SymbolAccess(SymbolAccess { - span: miden_diagnostics::SourceSpan::UNKNOWN, - name: ResolvableIdentifier::Local(Identifier::new( - miden_diagnostics::SourceSpan::UNKNOWN, - crate::Symbol::intern(format!("%lc{}", i)), - )), - access_type: AccessType::Default, - offset: 0, - ty: Some(Type::Felt), - }); - add!(acc, access) + let word_sum = (1..32).fold(access!("%lc0", Type::Felt), |acc, i| { + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(format!("%lc{}", i)), + )), + access_type: AccessType::Default, + offset: 0, + ty: Some(Type::Felt), }); - let high_bit_sum = (33..53) - .into_iter() - .fold(access!("%lc32", Type::Felt), |acc, i| { - let access = ScalarExpr::SymbolAccess(SymbolAccess { - span: miden_diagnostics::SourceSpan::UNKNOWN, - name: ResolvableIdentifier::Local(Identifier::new( - miden_diagnostics::SourceSpan::UNKNOWN, - crate::Symbol::intern(format!("%lc{}", i)), - )), - access_type: AccessType::Default, - offset: 0, - ty: Some(Type::Felt), - }); - add!(acc, access) + add!(acc, access) + }); + let high_bit_sum = (33..53).fold(access!("%lc32", Type::Felt), |acc, i| { + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(format!("%lc{}", i)), + )), + access_type: AccessType::Default, + offset: 0, + ty: Some(Type::Felt), }); - 
let low_bit_sum = (54..64) - .into_iter() - .fold(access!("%lc53", Type::Felt), |acc, i| { - let access = ScalarExpr::SymbolAccess(SymbolAccess { - span: miden_diagnostics::SourceSpan::UNKNOWN, - name: ResolvableIdentifier::Local(Identifier::new( - miden_diagnostics::SourceSpan::UNKNOWN, - crate::Symbol::intern(format!("%lc{}", i)), - )), - access_type: AccessType::Default, - offset: 0, - ty: Some(Type::Felt), - }); - add!(acc, access) + add!(acc, access) + }); + let low_bit_sum = (54..64).fold(access!("%lc53", Type::Felt), |acc, i| { + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(format!("%lc{}", i)), + )), + access_type: AccessType::Default, + offset: 0, + ty: Some(Type::Felt), }); + add!(acc, access) + }); let low_bit_sum_body = let_!(low_bit_sum = expr!(low_bit_sum) => enforce!(eq!(access!(immediate, Type::Felt), add!(access!(low_bit_sum, Type::Felt), access!(high_bit_sum, Type::Felt))), when access!(s, Type::Felt)), enforce!(eq!(access!(instruction_bits[31], Type::Felt), int!(1)), when access!(s, Type::Felt))); diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index 035d5237..2d88454b 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -448,6 +448,12 @@ macro_rules! let_ { }; } +macro_rules! return_ { + ($value:expr) => { + Statement::Expr($value) + }; +} + macro_rules! 
enforce { ($expr:expr) => { Statement::Enforce($expr) @@ -606,6 +612,7 @@ mod calls; mod constant_propagation; mod constants; mod evaluators; +mod functions; mod identifiers; mod inlining; mod integrity_constraints; diff --git a/parser/src/sema/errors.rs b/parser/src/sema/errors.rs index 50d56f63..ab67866a 100644 --- a/parser/src/sema/errors.rs +++ b/parser/src/sema/errors.rs @@ -1,6 +1,6 @@ use miden_diagnostics::{Diagnostic, Label, SourceSpan, Spanned, ToDiagnostic}; -use crate::ast::{Identifier, InvalidExprError, ModuleId}; +use crate::ast::{Identifier, InvalidExprError, InvalidTypeError, ModuleId}; /// Represents the various module validation errors we might encounter during semantic analysis. #[derive(Debug, thiserror::Error)] @@ -31,6 +31,8 @@ pub enum SemanticAnalysisError { ImportFailed(SourceSpan), #[error(transparent)] InvalidExpr(#[from] InvalidExprError), + #[error(transparent)] + InvalidType(#[from] InvalidTypeError), #[error("module is invalid, see diagnostics for details")] Invalid, } @@ -89,6 +91,7 @@ impl ToDiagnostic for SemanticAnalysisError { .with_labels(vec![Label::primary(span.source_id(), span) .with_message("failed import occurred here")]), Self::InvalidExpr(err) => err.to_diagnostic(), + Self::InvalidType(err) => err.to_diagnostic(), Self::Invalid => Diagnostic::error().with_message("module is invalid, see diagnostics for details"), } } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 3bae4c32..40de33cb 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -253,6 +253,23 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { ); } + for (function_name, function) in module.functions.iter() { + let namespaced_name = NamespacedIdentifier::Function(*function_name); + if let Some((prev, _)) = self.imported.get_key_value(&namespaced_name) { + self.declaration_import_conflict(namespaced_name.span(), prev.span())?; + } + assert_eq!( + self.locals.insert( + 
namespaced_name, + BindingType::Function(FunctionType::Function( + function.param_types(), + function.return_type + )) + ), + None + ); + } + // Next, we add any periodic columns to the set of local bindings. // // These _can_ conflict with globally defined names, but are guaranteed not to conflict @@ -287,6 +304,10 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { self.visit_mut_evaluator_function(evaluator)?; } + for function in module.functions.values_mut() { + self.visit_mut_function(function)?; + } + if let Some(boundary_constraints) = module.boundary_constraints.as_mut() { if !boundary_constraints.is_empty() { self.visit_mut_boundary_constraints(boundary_constraints)?; @@ -363,6 +384,48 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { ControlFlow::Continue(()) } + fn visit_mut_function( + &mut self, + function: &mut Function, + ) -> ControlFlow { + // constraints are not allowed in pure functions + self.constraint_mode = ConstraintMode::None; + + // Start a new lexical scope + self.locals.enter(); + + // Track referenced imports in a new context, as we want to update the dependency graph + // for this function using only those imports referenced from this function body + let referenced = mem::take(&mut self.referenced); + + // Add the set of parameters to the current scope, check for conflicts + for (param, param_type) in function.params.iter_mut() { + let namespaced_name = NamespacedIdentifier::Binding(*param); + self.locals + .insert(namespaced_name, BindingType::Local(*param_type)); + } + + // Visit all of the statements in the body + self.visit_mut_statement_block(&mut function.body)?; + + // Update the dependency graph for this function + let current_item = QualifiedIdentifier::new( + self.current_module.unwrap(), + NamespacedIdentifier::Function(function.name), + ); + for (referenced_item, ref_type) in self.referenced.iter() { + let referenced_item = self.deps.add_node(*referenced_item); + self.deps.add_edge(current_item, referenced_item, *ref_type); + } + 
+ // Restore the original references metadata + self.referenced = referenced; + // Restore the original lexical scope + self.locals.exit(); + + ControlFlow::Continue(()) + } + fn visit_mut_boundary_constraints( &mut self, body: &mut Vec, @@ -598,6 +661,7 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { } // TODO: When we have non-evaluator functions, we must fetch the type in its signature here, // and store it as the type of the Call expression + expr.ty = fty.result(); } } else { self.has_type_errors = true; diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index ffb27629..6f77c315 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -69,6 +69,11 @@ impl<'a> ConstantPropagation<'a> { self.visit_mut_evaluator_function(evaluator)?; } + // Visit all of the functions + for function in program.functions.values_mut() { + self.visit_mut_function(function)?; + } + // Visit all of the constraints self.visit_mut_boundary_constraints(&mut program.boundary_constraints)?; self.visit_mut_integrity_constraints(&mut program.integrity_constraints) diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index b8a34383..4ee69400 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -1,10 +1,11 @@ use std::{ collections::{BTreeMap, HashMap, HashSet, VecDeque}, ops::ControlFlow, + vec, }; use air_pass::Pass; -use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Span, Spanned}; +use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; use crate::{ ast::{visit::VisitMut, *}, @@ -19,12 +20,12 @@ use super::constant_propagation; /// * Monomorphizing and inlining evaluators/functions at their call sites /// * Unrolling constraint comprehensions into a sequence of scalar constraints /// * Unrolling list comprehensions into a tree of `let` statements which end in -/// a vector 
expression (the implicit result of the tree). Each iteration of the -/// unrolled comprehension is reified as a value and bound to a variable so that -/// other transformations may refer to it directly. +/// a vector expression (the implicit result of the tree). Each iteration of the +/// unrolled comprehension is reified as a value and bound to a variable so that +/// other transformations may refer to it directly. /// * Rewriting aliases of top-level declarations to refer to those declarations directly /// * Removing let-bound variables which are unused, which is also used to clean up -/// after the aliasing rewrite mentioned above. +/// after the aliasing rewrite mentioned above. /// /// The trickiest transformation comes with inlining the body of evaluators at their /// call sites, as evaluator parameter lists can arbitrarily destructure/regroup columns @@ -75,13 +76,23 @@ pub struct Inlining<'a> { imported: HashMap, /// All evaluator functions in the program evaluators: HashMap, + /// All pure functions in the program + functions: HashMap, /// A set of identifiers for which accesses should be rewritten. /// /// When an identifier is in this set, it means it is a local alias for a trace column, /// and should be rewritten based on the current `BindingType` associated with the alias /// identifier in `bindings`. rewrites: HashSet, + /// The call stack during expansion of a function call. + /// + /// Each time we begin to expand a call, we check if it is already present on the call + /// stack, and if so, raise a diagnostic due to infinite recursion. If not, the callee + /// is pushed on the stack while we expand its body. When we finish expanding the body + /// of the callee, we pop it off this stack, and proceed as usual. 
+ call_stack: Vec, in_comprehension_constraint: bool, + next_ident_lc: usize, next_ident: usize, } impl<'p> Pass for Inlining<'p> { @@ -97,6 +108,12 @@ impl<'p> Pass for Inlining<'p> { .map(|(k, v)| (*k, v.clone())) .collect(); + self.functions = program + .functions + .iter() + .map(|(k, v)| (*k, v.clone())) + .collect(); + // We'll be referencing the trace configuration during inlining, so keep a copy of it self.trace = program.trace_columns.clone(); // Same with the random values @@ -187,8 +204,11 @@ impl<'a> Inlining<'a> { let_bound: Default::default(), imported: Default::default(), evaluators: Default::default(), + functions: Default::default(), rewrites: Default::default(), in_comprehension_constraint: false, + call_stack: vec![], + next_ident_lc: 0, next_ident: 0, } } @@ -197,10 +217,16 @@ impl<'a> Inlining<'a> { /// /// This is only used when expanding list comprehensions, so we use a special prefix for /// these generated identifiers to make it clear what they were expanded from. 
- fn next_ident(&mut self, span: SourceSpan) -> Identifier { + fn get_next_ident_lc(&mut self, span: SourceSpan) -> Identifier { + let id = self.next_ident_lc; + self.next_ident_lc += 1; + Identifier::new(span, crate::Symbol::intern(format!("%lc{}", id))) + } + + fn get_next_ident(&mut self, span: SourceSpan) -> Identifier { let id = self.next_ident; self.next_ident += 1; - Identifier::new(span, crate::Symbol::intern(format!("%lc{}", id))) + Identifier::new(span, crate::Symbol::intern(format!("%{}", id))) } /// Inline/expand all of the statements in the `boundary_constraints` section @@ -320,9 +346,47 @@ impl<'a> Inlining<'a> { /// by replacing it with the result of expanding its body fn expand_let(&mut self, expr: Let) -> Result, SemanticAnalysisError> { let span = expr.span(); - let name = expr.name; + let mut name = expr.name; let body = expr.body; + // When expanding a `let` that was inlined at a function callsite, we must ensure that any + // let-bound variables introduced do not shadow bindings for the remaining statements in + // the body of the block at which the function is inlined. For example, consider the following: + // + // fn foo(a: felt) -> felt { + // let b = a * a + // return b + // } + // + // integrity_constraints { + // let b = col[0] + // enf foo(b) = 1 + // enf b = 0 + // } + // + // If the call to `foo` is naively inlined, we will end up with: + // + // integrity_constraints { + // let b = col[0] + // let b = b * b + // enf b = 1 + // enf b = 0 + // } + // + // As you can see, this has the effect of breaking the last constraint, by changing the + // definition bound to `b` at that point in the program. + // + // To solve this, we check if we are currently expanding a `let` being inlined as part of + // a function call, and if so, we generate new variable names that will replace the originals. 
+ if !self.call_stack.is_empty() { + name = self.get_next_ident(span); + let binding_ty = self + .expr_binding_type(&expr.value) + .expect("unexpected undefined variable"); + self.rewrites.insert(name); + self.bindings.insert(name, binding_ty); + } + // Visit the let-bound expression first, since it determines how the rest of the process goes let mut statements = match expr.value { // When expanding a call in this context, we're expecting a single @@ -337,6 +401,9 @@ impl<'a> Inlining<'a> { // // The rules for expansion are the same. Expr::ListComprehension(lc) => self.expand_comprehension(lc)?, + // The operands of a binary expression can contain function calls, so we must ensure + // that we expand the operands as needed, and then proceed with expanding the let. + Expr::Binary(expr) => self.expand_binary_expr(expr)?, // Other expressions we visit just to expand rewrites mut value => { self.rewrite_expr(&mut value)?; @@ -431,7 +498,116 @@ impl<'a> Inlining<'a> { other => unimplemented!("unhandled builtin: {}", other), } } else { - todo!("pure functions are not implemented yet") + self.expand_function_callsite(call) + } + } + + fn maybe_expand_scalar_expr( + &mut self, + mut expr: Box, + ) -> Result, Box>, SemanticAnalysisError> { + match *expr { + ScalarExpr::Binary(expr) if expr.has_block_like_expansion() => { + self.expand_binary_expr(expr).map(Ok) + } + ScalarExpr::Call(lhs) => self.expand_call(lhs).map(Ok), + _ => { + self.rewrite_scalar_expr(&mut expr)?; + Ok(Err(expr)) + } + } + } + + fn expand_binary_expr( + &mut self, + expr: BinaryExpr, + ) -> Result, SemanticAnalysisError> { + let span = expr.span(); + let op = expr.op; + let lhs = self.maybe_expand_scalar_expr(expr.lhs)?; + let rhs = self.maybe_expand_scalar_expr(expr.rhs)?; + + match (lhs, rhs) { + (Err(lhs), Err(rhs)) => Ok(vec![Statement::Expr(Expr::Binary(BinaryExpr { + span, + op, + lhs, + rhs, + }))]), + (Err(lhs), Ok(mut rhs)) => { + with_let_result(self, &mut rhs, |_, value| { + let value = 
core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + + Ok(Some(Statement::Expr(Expr::Binary(BinaryExpr { + span, + op, + lhs, + rhs: Box::new(value.try_into()?), + })))) + })?; + + Ok(rhs) + } + (Ok(mut lhs), Err(rhs)) => { + with_let_result(self, &mut lhs, |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + + Ok(Some(Statement::Expr(Expr::Binary(BinaryExpr { + span, + op, + lhs: Box::new(value.try_into()?), + rhs, + })))) + })?; + + Ok(lhs) + } + (Ok(mut lhs), Ok(mut rhs)) => { + let name = self.get_next_ident(span); + let ty = match lhs + .last() + .expect("unexpected empty expansion for scalar expression") + { + Statement::Expr(ref expr) => expr.ty(), + Statement::Let(ref expr) => expr.ty(), + _ => unreachable!(), + }; + + with_let_result(self, &mut rhs, |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + + Ok(Some(Statement::Expr(Expr::Binary(BinaryExpr::new( + span, + op, + ScalarExpr::SymbolAccess(SymbolAccess { + span, + name: ResolvableIdentifier::Local(name), + access_type: AccessType::Default, + offset: 0, + ty, + }), + value.try_into()?, + ))))) + })?; + + with_let_result(self, &mut lhs, move |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + Ok(Some(Statement::Let(Let::new(span, name, value, rhs)))) + })?; + Ok(lhs) + } } } @@ -523,18 +699,90 @@ impl<'a> Inlining<'a> { match constraint { ScalarExpr::Binary(BinaryExpr { op: BinaryOp::Eq, - mut lhs, - mut rhs, + lhs, + rhs, span, }) => { - self.rewrite_scalar_expr(lhs.as_mut())?; - self.rewrite_scalar_expr(rhs.as_mut())?; - Ok(vec![Statement::Enforce(ScalarExpr::Binary(BinaryExpr { - op: BinaryOp::Eq, - lhs, - rhs, - span, - }))]) + let lhs = self.maybe_expand_scalar_expr(lhs)?; + let rhs = self.maybe_expand_scalar_expr(rhs)?; + + match (lhs, rhs) { + (Err(lhs), 
Err(rhs)) => { + Ok(vec![Statement::Enforce(ScalarExpr::Binary(BinaryExpr { + span, + op: BinaryOp::Eq, + lhs, + rhs, + }))]) + } + (Err(lhs), Ok(mut rhs)) => { + with_let_result(self, &mut rhs, |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + + Ok(Some(Statement::Enforce(ScalarExpr::Binary(BinaryExpr { + span, + op: BinaryOp::Eq, + lhs, + rhs: Box::new(value.try_into()?), + })))) + })?; + + Ok(rhs) + } + (Ok(mut lhs), Err(rhs)) => { + with_let_result(self, &mut lhs, |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + + Ok(Some(Statement::Enforce(ScalarExpr::Binary(BinaryExpr { + span, + op: BinaryOp::Eq, + lhs: Box::new(value.try_into()?), + rhs, + })))) + })?; + + Ok(lhs) + } + (Ok(mut lhs), Ok(mut rhs)) => { + let name = self.get_next_ident(span); + + with_let_result(self, &mut rhs, |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + + Ok(Some(Statement::Enforce(ScalarExpr::Binary( + BinaryExpr::new( + span, + BinaryOp::Eq, + ScalarExpr::SymbolAccess(SymbolAccess::new( + span, + name, + AccessType::Default, + 0, + )), + value.try_into()?, + ), + )))) + })?; + + with_let_result(self, &mut lhs, move |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + Ok(Some(Statement::Let(Let::new(span, name, value, rhs)))) + })?; + Ok(lhs) + } + } } invalid => unreachable!("unexpected constraint node: {:#?}", invalid), } @@ -636,10 +884,51 @@ impl<'a> Inlining<'a> { /// * A flat list of constraint statements fn expand_comprehension( &mut self, - expr: ListComprehension, + mut expr: ListComprehension, ) -> Result, SemanticAnalysisError> { - // This is the vector containing the expansion result - let mut statements = vec![]; + // Lift any function calls in iterable position out of the comprehension, + // 
binding the result of those calls via `let`. Rewrite the iterable as + // a symbol access to the newly-bound variable. + // + // NOTE: The actual expansion of the lifted iterables occurs after we expand + // the comprehension, so that we can place the expanded comprehension in the + // body of the final let + let mut lifted_bindings = vec![]; + let mut lifted = vec![]; + for param in expr.iterables.iter_mut() { + if !matches!(param, Expr::Call(_)) { + continue; + } + + let span = param.span(); + let name = self.get_next_ident(span); + let ty = match param { + Expr::Call(Call { callee, .. }) => { + let callee = callee + .resolved() + .expect("callee should have been resolved by now"); + self.functions[&callee].return_type + } + _ => unsafe { core::hint::unreachable_unchecked() }, + }; + let param = core::mem::replace( + param, + Expr::SymbolAccess(SymbolAccess { + span, + name: ResolvableIdentifier::Local(name), + access_type: AccessType::Default, + offset: 0, + ty: Some(ty), + }), + ); + match param { + Expr::Call(call) => { + lifted_bindings.push((name, BindingType::Local(ty))); + lifted.push((name, call)); + } + _ => unsafe { core::hint::unreachable_unchecked() }, + } + } // Get the number of iterations in this comprehension let Type::Vector(num_iterations) = expr.ty.unwrap() else { @@ -647,83 +936,161 @@ impl<'a> Inlining<'a> { }; // Step the iterables for each iteration, giving each it's own lexical scope + let mut statement_groups = vec![]; for i in 0..num_iterations { self.bindings.enter(); - let mut expansion = self.expand_comprehension_iteration(&expr, i)?; - statements.append(&mut expansion); + // Ensure any lifted iterables are in scope for the expansion of this iteration + for (name, binding_ty) in lifted_bindings.iter() { + self.bindings.insert(*name, binding_ty.clone()); + } + let expansion = self.expand_comprehension_iteration(&expr, i)?; + // An expansion can be empty if a constraint selector with a constant selector expression + // evaluates to 
false (allowing us to elide the constraint for that iteration entirely). + if !expansion.is_empty() { + statement_groups.push(expansion); + } self.bindings.exit(); } - // If we're in a constraint comprehension, we're already fully expanded - if self.in_comprehension_constraint { - return Ok(statements); - } - - // Otherwise, this is a list comprehension, which means the current expansion - // is a flat list of statements, one for each element of the unrolled comprehension. + // At this point, we have one or more statement groups, representing the expansions + // of each iteration of the comprehension. Additionally, we may have a set of lifted + // iterables which we need to bind (and expand) "around" the expansion of the comprehension + // itself. + // + // In short, we must take this list of statement groups, and flatten/treeify it. Once + // a let binding is introduced into scope, all subsequent statements must occur in the body + // of that let, forming a tree. Consecutive statements which introduce no new bindings do + // not require any nesting, resulting in the groups containing those statements being flattened. // - // We need to convert this into a nested tree of `let` statements that bind each - // element of the comprehension to a variable, and at the bottom construct a vector - // of all the elements to return as the result of the tree. + // Lastly, whether this is a list or constraint comprehension determines if we will also be + // constructing a vector from the values produced by each iteration, and returning it as the + // result of the comprehension itself. 
let span = expr.span(); - // Generate a new variable name for each element in the comprehension - let mut symbols = statements - .iter() - .map(|_| self.next_ident(span)) - .collect::>(); - // Generate the list of elements for the vector which is to be the result of the let-tree - let vars = statements - .iter() - .zip(symbols.iter().copied()) - .map(|(stmt, name)| { - // The type of these statements must be known by now - let ty = match stmt { - Statement::Expr(value) => value.ty(), - Statement::Let(nested) => nested.ty(), - stmt => unreachable!( - "unexpected statement type in comprehension body: {}", - stmt.display(0) - ), - }; - Expr::SymbolAccess(SymbolAccess { - span, - name: ResolvableIdentifier::Local(name), - access_type: AccessType::Default, - offset: 0, - ty, + if self.in_comprehension_constraint { + let mut result = vec![]; + + // If for some reason, we've been able to eliminate all constraints for this comprehension, + // then return an empty block, since we need not emit any code at all if unused. + if statement_groups.is_empty() { + return Ok(result); + } + + // Each group is presumed to already be flattened/treeified + for mut group in statement_groups { + // The first group is simply flattened + if result.is_empty() { + result.append(&mut group); + continue; + } + // Ensure that all statements preceded by a let-bound variable, are nested in the + // body of that let. + match group.pop().unwrap() { + Statement::Let(mut let_expr) => { + with_let_innermost_block(self, &mut let_expr.body, |_, block| { + block.append(&mut result); + + Ok(()) + })?; + result.append(&mut group); + result.push(Statement::Let(let_expr)); + } + stmt => { + result.append(&mut group); + result.push(stmt); + } + } + } + Ok(result) + } else { + // For list comprehensions, we must emit a let tree that binds each iteration, + // and ensure that the expansion of the iteration itself is properly nested so + // that the lexical scope of all bound variables is correct. 
This is more complex + // than the constraint comprehension case, as we must emit a single expression + // representing the entire expansion of the comprehension as an aggregate, whereas + // constraints produce no results. + + // Generate a new variable name for each element in the comprehension + let symbols = statement_groups + .iter() + .map(|_| self.get_next_ident_lc(span)) + .collect::>(); + // Generate the list of elements for the vector which is to be the result of the let-tree + let vars = statement_groups + .iter() + .zip(symbols.iter().copied()) + .map(|(group, name)| { + // The type of these statements must be known by now + let ty = match group.last().unwrap() { + Statement::Expr(value) => value.ty(), + Statement::Let(nested) => nested.ty(), + stmt => unreachable!( + "unexpected statement type in comprehension body: {}", + stmt.display(0) + ), + }; + Expr::SymbolAccess(SymbolAccess { + span, + name: ResolvableIdentifier::Local(name), + access_type: AccessType::Default, + offset: 0, + ty, + }) }) - }) - .collect(); - // Construct the let tree by visiting the statements bottom-up - let acc = Statement::Expr(Expr::Vector(Span::new(span, vars))); - let result = statements.drain(..).zip(symbols.drain(..)).try_rfold( - acc, - |acc, (mut stmt, name)| { - match stmt { - // If the current statement is an expression, it represents the value of this - // element of the comprehension, and we must generate a let to bind it, using - // the accumulator expression as the body - Statement::Expr(value) => { - Ok(Statement::Let(Let::new(span, name, value, vec![acc]))) + .collect(); + // Construct the let tree by visiting the statements bottom-up + let acc = vec![Statement::Expr(Expr::Vector(Span::new(span, vars)))]; + let expanded = statement_groups.into_iter().zip(symbols).try_rfold( + acc, + |acc, (mut group, name)| { + match group.pop().unwrap() { + // If the current statement is an expression, it represents the value of this + // iteration of the comprehension, and 
we must generate a let to bind it, using + // the accumulator expression as the body + Statement::Expr(expr) => { + group.push(Statement::Let(Let::new(span, name, expr, acc))); + } + // If the current statement is a `let`-tree, we need to generate a new `let` at + // the bottom of the tree, which binds the result expression as the value of the + // generated `let`, and uses the accumulator as the body + Statement::Let(mut wrapper) => { + with_let_result(self, &mut wrapper.body, move |_, value| { + let value = core::mem::replace( + value, + Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), + ); + Ok(Some(Statement::Let(Let::new(span, name, value, acc)))) + })?; + group.push(Statement::Let(wrapper)); + } + _ => unreachable!(), + } + Ok::<_, SemanticAnalysisError>(group) + }, + )?; + // Lastly, construct the let tree for the lifted iterables, placing the expanded + // comprehension at the bottom of that tree. + lifted.into_iter().try_rfold(expanded, |acc, (name, call)| { + let span = call.span(); + let mut preamble = self.expand_call(call)?; + match preamble.pop().unwrap() { + Statement::Expr(expr) => { + preamble.push(Statement::Let(Let::new(span, name, expr, acc))); } - // If the current statement is a `let`-tree, we need to generate a new `let` at - // the bottom of the tree, which binds the result expression as the value of the - // generated `let`, and uses the accumualtor expression as the body - Statement::Let(ref mut wrapper) => { + Statement::Let(mut wrapper) => { with_let_result(self, &mut wrapper.body, move |_, value| { let value = core::mem::replace( value, Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), ); - Ok(Some(Statement::Let(Let::new(span, name, value, vec![acc])))) - }) - .map(|_| stmt) + Ok(Some(Statement::Let(Let::new(span, name, value, acc)))) + })?; + preamble.push(Statement::Let(wrapper)); } _ => unreachable!(), } - }, - )?; - Ok(vec![result]) + Ok(preamble) + }) + } } fn expand_comprehension_iteration( @@ -800,13 +1167,10 @@ 
impl<'a> Inlining<'a> { self.bindings.insert(binding, binding_ty); Expr::SymbolAccess(current_access) } - // TODO: Currently, calls cannot be used as iterables, because we don't have pure functions - // which can produce aggregates. However, when those are added, we may want to add support - // for that here. This branch is set up to raise an appropriate panic if we forget to do so. - Expr::Call(_) => unimplemented!("calls to functions as iterables"), - // Binary expressions are scalar, so cannot be used as iterables, and we don't (currently) - // support nested comprehensions, so it is never possible to observe these expression types here - Expr::Binary(_) | Expr::ListComprehension(_) => unreachable!(), + // Binary expressions are scalar, so cannot be used as iterables, and we don't + // (currently) support nested comprehensions, so it is never possible to observe + // these expression types here. Calls should have been lifted prior to expansion. + Expr::Call(_) | Expr::Binary(_) | Expr::ListComprehension(_) => unreachable!(), }; bound_values.insert(binding, abstract_value); } @@ -958,7 +1322,7 @@ impl<'a> Inlining<'a> { // NOTE: We create a new nested scope for the parameters in order to avoid conflicting // with the root declarations eval_bindings.enter(); - self.populate_rewrites( + self.populate_evaluator_rewrites( &mut eval_bindings, call.args.as_slice(), evaluator.params.as_slice(), @@ -976,10 +1340,142 @@ impl<'a> Inlining<'a> { Ok(evaluator.body) } + /// This function handles inlining pure function calls. 
+ fn expand_function_callsite( + &mut self, + call: Call, + ) -> Result, SemanticAnalysisError> { + self.bindings.enter(); + // The callee is guaranteed to be resolved and exist at this point + let callee = call + .callee + .resolved() + .expect("callee should have been resolved by now"); + + if self.call_stack.contains(&callee) { + let ifd = self + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid recursive function call") + .with_primary_label(call.span, "recursion occurs due to this function call"); + self.call_stack + .iter() + .rev() + .fold(ifd, |ifd, caller| { + ifd.with_secondary_label(caller.span(), "which was called from") + }) + .emit(); + return Err(SemanticAnalysisError::Invalid); + } else { + self.call_stack.push(callee); + } + + // We clone the function here as we will be modifying the body during the + // inlining process, and we must not modify the original + let mut function = self.functions.get(&callee).unwrap().clone(); + + // This will be the initial set of bindings visible within the function body + // + // This is distinct from `self.bindings` at this point, because the function doesn't + // inherit the caller's scope, it has an entirely new one. + let mut function_bindings = LexicalScope::default(); + + // Add all referenced (and thus imported) items from the function module + // + // NOTE: This will include constants, periodic columns, and other functions + for (qid, binding_ty) in self.imported.iter() { + if qid.module == callee.module { + function_bindings.insert(*qid.as_ref(), binding_ty.clone()); + } + } + + // Add random values, trace columns, and other root declarations to the set of + // bindings visible in the function body, _if_ the function is defined in the + // root module. 
+ let is_function_in_root = callee.module == self.root; + if is_function_in_root { + if let Some(rv) = self.random_values.as_ref() { + function_bindings.insert( + rv.name, + BindingType::RandomValue(RandBinding::new( + rv.name.span(), + rv.name, + rv.size, + 0, + Type::Vector(rv.size), + )), + ); + for binding in rv.bindings.iter().copied() { + function_bindings.insert(binding.name, BindingType::RandomValue(binding)); + } + } + + for segment in self.trace.iter() { + function_bindings.insert( + segment.name, + BindingType::TraceColumn(TraceBinding { + span: segment.name.span(), + segment: segment.id, + name: Some(segment.name), + offset: 0, + size: segment.size, + ty: Type::Vector(segment.size), + }), + ); + for binding in segment.bindings.iter().copied() { + function_bindings.insert( + binding.name.unwrap(), + BindingType::TraceColumn(TraceBinding { + span: segment.name.span(), + segment: segment.id, + name: binding.name, + offset: binding.offset, + size: binding.size, + ty: binding.ty, + }), + ); + } + } + + for input in self.public_inputs.values() { + function_bindings.insert( + input.name, + BindingType::PublicInput(Type::Vector(input.size)), + ); + } + } + + // Match call arguments to function parameters, populating the set of rewrites + // which should be performed on the inlined function body. 
+ // + // NOTE: We create a new nested scope for the parameters in order to avoid conflicting + // with the root declarations + function_bindings.enter(); + self.populate_function_rewrites( + &mut function_bindings, + call.args.as_slice(), + function.params.as_slice(), + ); + + // While we're inlining the body, use the set of function bindings we built above + let prev_bindings = core::mem::replace(&mut self.bindings, function_bindings); + + // Expand the function body into a block of statements + self.expand_statement_block(&mut function.body)?; + + // Restore the caller's bindings before we leave + self.bindings = prev_bindings; + + // We're done expanding this call, so remove it from the call stack + self.call_stack.pop(); + + Ok(function.body) + } + /// Populate the set of access rewrites, as well as the initial set of bindings to use when inlining an evaluator function. /// /// This is done by resolving the arguments provided by the call to the evaluator, with the parameter list of the evaluator itself. - fn populate_rewrites( + fn populate_evaluator_rewrites( &mut self, eval_bindings: &mut LexicalScope, args: &[Expr], @@ -1164,6 +1660,25 @@ impl<'a> Inlining<'a> { } } + fn populate_function_rewrites( + &mut self, + function_bindings: &mut LexicalScope, + args: &[Expr], + params: &[(Identifier, Type)], + ) { + // Reset the rewrites set + self.rewrites.clear(); + + for (arg, (param_name, param_ty)) in args.iter().zip(params.iter()) { + // We can safely assume that there is a binding type available here, + // otherwise the semantic analysis pass missed something + let binding_ty = self.expr_binding_type(arg).unwrap(); + debug_assert_eq!(binding_ty.ty(), Some(*param_ty), "unexpected type mismatch"); + self.rewrites.insert(*param_name); + function_bindings.insert(*param_name, binding_ty); + } + } + /// Returns a new [SymbolAccess] which should be used in place of `access` in the current scope. 
/// /// This function should only be called on accesses which have a trace column/param [BindingType], @@ -1536,3 +2051,60 @@ where Ok(()) } + +fn with_let_innermost_block( + inliner: &mut Inlining, + entry: &mut Vec, + callback: F, +) -> Result<(), SemanticAnalysisError> +where + F: FnOnce(&mut Inlining, &mut Vec) -> Result<(), SemanticAnalysisError>, +{ + // Preserve the original lexical scope to be restored on exit + let prev = inliner.bindings.clone(); + + // SAFETY: We must use a raw pointer here because the Rust compiler is not able to + // see that we only ever use the mutable reference once, and that the reference + // is never aliased. + // + // Both of these guarantees are in fact upheld here however, as each iteration of the loop + // is either the last iteration (when we use the mutable reference to mutate the end of the + // bottom-most block), or a traversal to the last child of the current let expression. + // We never alias the mutable reference, and in fact immediately convert back to a mutable + // reference inside the loop to ensure that within the loop body we have some degree of + // compiler-assisted checking of that invariant. + let mut current_block = Some(entry as *mut Vec); + while let Some(parent_block) = current_block.take() { + // SAFETY: We convert the pointer back to a mutable reference here before + // we do anything else to ensure the usual aliasing rules are enforced. + // + // It is further guaranteed that this reference is never improperly aliased + // across iterations, as each iteration is visiting a child of the previous + // iteration's node, i.e. what we're doing here is equivalent to holding a + // mutable reference and using it to mutate a field in a deeply nested struct. 
+ let parent_block = unsafe { &mut *parent_block }; + // A block is guaranteed to always have at least one statement here + if let Some(Statement::Let(ref mut let_expr)) = parent_block.last_mut() { + // Register this binding + let binding_ty = inliner.expr_binding_type(&let_expr.value).unwrap(); + inliner.bindings.insert(let_expr.name, binding_ty); + // Set up the next iteration + current_block = Some(&mut let_expr.body as *mut Vec); + continue; + } + // When we hit a block whose last statement is an expression, which + // must also be the bottom-most block of this tree. + match callback(inliner, parent_block) { + Ok(_) => break, + Err(err) => { + inliner.bindings = prev; + return Err(err); + } + } + } + + // Restore the original lexical scope + inliner.bindings = prev; + + Ok(()) +} From b3c64404678ba7293a4c2b00e6f400c98c1fff38 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Thu, 20 Jun 2024 18:59:43 -0400 Subject: [PATCH 2/6] feat: support more expressive syntax in ast This commit does two things: 1. Modifies the AST to allow `let` in expression position (scalar or otherwise). 2. Refactors the constant propagation and inlining passes to properly handle occurrences of `let` in expression position, and make use of this new capability to simplify inlining of certain syntax nodes. In particular, the inliner now makes liberal use of this new flexibility in the AST, in order to expand syntax nodes in expression position. Such nodes, with the introduction of functions, can have arbitrarily complex expansions, and with this change those expansions can now be done in-place, rather than attempting to "lift" expressions that may produce block-like expansions into the nearest containing statement block, requiring expensive let-tree rewrites. In fact, it became clear during the testing and implementation of functions that without the ability to expand the syntax tree in this manner, it would be virtually impossible to correctly inline a full AirScript program. 
For example, consider the following: ``` trace_columns { main: [clk, a, b[4]] } fn fold_vec(a: felt[4]) -> felt { let m = a[0] * a[1] let n = m * a[2] let o = n * a[3] return o } integrity_constraints { let o = a let m = fold_vec(b) enf o = m } ``` After inlining (the old way), we would get (commentary inline): ``` integrity_constraints { # This binding will be shadowed by the binding of the same name from # the inlined body of `fold_vec` let o = a # This `m` is the one inlined from `fold_vec` let m = b[0] * b[1] let n = m * b[2] # This `o` is the one inlined from `fold_vec`, and now shadows the `o` # bound in the original `integrity_constraints` let o = n * b[3] # The inliner moved the original binding of `m` into the innermost # `let` body, so that it can bind the value "returned" from # `fold_vec`, as expected. Because of this, it shadows the `m` that # was inlined from `fold_vec`, and no one is the wiser because the # semantics of the original code are preserved let m = o # This constraint is now incorrect, as the binding, `o`, we intended # to constrain has been shadowed by a different `o`. enf o = m } ``` To summarize, the original inliner needed to split the current block at the statement being expanded/inlined, and move code in two directions: "lifting" inlined statements into the current block before the split, and "lowering" the statements after the split by placing them in the innermost `let` block. This was necessary so that the result of the inlined function (or more generally, any expression with a block-like expansion, e.g. comprehensions) could be bound to the name used in the original source code, with all of the necessary bindings in scope so that the expression that was bound would correctly evaluate during codegen. As we can see, the result of this is that an expanded/inlined syntax node could introduce bindings that would shadow other bindings in scope, and change the behavior of the resulting program (as demonstrated above). 
This commit allows bindings to be introduced anywhere that an expression is valid. This has the effect of no longer requiring code motion just to support let-bound variables in an inlined/expanded expression. This simplifies the inliner quite a bit. --- air-script/tests/codegen/masm.rs | 18 +- air-script/tests/codegen/winterfell.rs | 18 +- .../tests/functions/functions_complex.air | 8 +- .../tests/functions/functions_complex.masm | 6 +- .../tests/functions/functions_complex.rs | 2 +- ir/src/passes/translate.rs | 348 ++++----- parser/Cargo.toml | 3 +- parser/src/ast/display.rs | 67 +- parser/src/ast/errors.rs | 20 +- parser/src/ast/expression.rs | 91 ++- parser/src/ast/statement.rs | 18 + parser/src/ast/visit.rs | 2 + parser/src/parser/tests/inlining.rs | 211 +++--- parser/src/parser/tests/mod.rs | 12 + parser/src/sema/scope.rs | 44 +- parser/src/sema/semantic_analysis.rs | 8 + parser/src/transforms/constant_propagation.rs | 183 ++++- parser/src/transforms/inlining.rs | 670 +++++++----------- 18 files changed, 949 insertions(+), 780 deletions(-) diff --git a/air-script/tests/codegen/masm.rs b/air-script/tests/codegen/masm.rs index b82414a0..d952e494 100644 --- a/air-script/tests/codegen/masm.rs +++ b/air-script/tests/codegen/masm.rs @@ -85,14 +85,17 @@ fn evaluators() { } #[test] -fn functions() { +fn functions_simple() { let generated_masm = Test::new("tests/functions/functions_simple.air".to_string()) .transpile(Target::Masm) .unwrap(); let expected = expect_file!["../functions/functions_simple.masm"]; expected.assert_eq(&generated_masm); +} +#[test] +fn functions_simple_inlined() { // make sure that the constraints generated using inlined functions are the same as the ones // generated using regular functions let generated_masm = Test::new("tests/functions/inlined_functions_simple.air".to_string()) @@ -100,13 +103,16 @@ fn functions() { .unwrap(); let expected = expect_file!["../functions/functions_simple.masm"]; expected.assert_eq(&generated_masm); +} - // let 
generated_masm = Test::new("tests/functions/functions_complex.air".to_string()) - // .transpile(Target::Masm) - // .unwrap(); +#[test] +fn functions_complex() { + let generated_masm = Test::new("tests/functions/functions_complex.air".to_string()) + .transpile(Target::Masm) + .unwrap(); - // let expected = expect_file!["../functions/functions_complex.masm"]; - // expected.assert_eq(&generated_masm); + let expected = expect_file!["../functions/functions_complex.masm"]; + expected.assert_eq(&generated_masm); } #[test] diff --git a/air-script/tests/codegen/winterfell.rs b/air-script/tests/codegen/winterfell.rs index 99117a58..3d2874e0 100644 --- a/air-script/tests/codegen/winterfell.rs +++ b/air-script/tests/codegen/winterfell.rs @@ -85,14 +85,17 @@ fn evaluators() { } #[test] -fn functions() { +fn functions_simple() { let generated_air = Test::new("tests/functions/functions_simple.air".to_string()) .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../functions/functions_simple.rs"]; expected.assert_eq(&generated_air); +} +#[test] +fn functions_simple_inlined() { // make sure that the constraints generated using inlined functions are the same as the ones // generated using regular functions let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string()) @@ -101,13 +104,16 @@ fn functions() { let expected = expect_file!["../functions/functions_simple.rs"]; expected.assert_eq(&generated_air); +} - // let generated_air = Test::new("tests/functions/functions_complex.air".to_string()) - // .transpile(Target::Winterfell) - // .unwrap(); +#[test] +fn functions_complex() { + let generated_air = Test::new("tests/functions/functions_complex.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); - // let expected = expect_file!["../functions/functions_complex.rs"]; - // expected.assert_eq(&generated_air); + let expected = expect_file!["../functions/functions_complex.rs"]; + expected.assert_eq(&generated_air); } #[test] diff 
--git a/air-script/tests/functions/functions_complex.air b/air-script/tests/functions/functions_complex.air index b72c0437..f4b0607f 100644 --- a/air-script/tests/functions/functions_complex.air +++ b/air-script/tests/functions/functions_complex.air @@ -37,9 +37,9 @@ integrity_constraints { let f = get_multiplicity_flags(s0, s1) let z = v^4 * f[3] + v^2 * f[2] + v * f[1] + f[0] # let folded_value = fold_scalar_and_vec(v, b) - # enf b_range' = b_range * (z * t - t + 1) - enf b_range' = b_range * 2 - # let y = fold_scalar_and_vec(v, b) + enf b_range' = b_range * (z * t - t + 1) + # enf b_range' = b_range * 2 + let y = fold_scalar_and_vec(v, b) # let c = fold_scalar_and_vec(t, b) - # enf v' = y + enf v' = y } diff --git a/air-script/tests/functions/functions_complex.masm b/air-script/tests/functions/functions_complex.masm index a145bd1b..93dd497d 100644 --- a/air-script/tests/functions/functions_complex.masm +++ b/air-script/tests/functions/functions_complex.masm @@ -79,7 +79,7 @@ proc.compute_integrity_constraints # Multiply by the composition coefficient padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul # integrity constraint 0 for aux - padw mem_loadw.4294900072 drop drop padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add + padw mem_loadw.4294900072 drop drop padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop # push the accumulator to the stack push.1 movdn.2 push.0 movdn.2 # => [b1, b0, r1, r0, ...] @@ -93,7 +93,7 @@ proc.compute_integrity_constraints # clean stack drop drop # => [r1, r0, ...] 
(2 cycles) - padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add + padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop # push the accumulator to the stack push.1 movdn.2 push.0 movdn.2 # => [b1, b0, r1, r0, ...] @@ -107,7 +107,7 @@ proc.compute_integrity_constraints # clean stack drop drop # => [r1, r0, ...] (2 cycles) - push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul ext2add padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2mul ext2add push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 ext2add ext2mul ext2sub + push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul ext2add padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2mul ext2add push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2sub 
push.1 push.0 ext2add ext2mul ext2sub # Multiply by the composition coefficient padw mem_loadw.4294900200 drop drop ext2mul end # END PROC compute_integrity_constraints diff --git a/air-script/tests/functions/functions_complex.rs b/air-script/tests/functions/functions_complex.rs index 4f4883f5..baa2cad5 100644 --- a/air-script/tests/functions/functions_complex.rs +++ b/air-script/tests/functions/functions_complex.rs @@ -86,6 +86,6 @@ impl Air for FunctionsAir { let main_next = main_frame.next(); let aux_current = aux_frame.current(); let aux_next = aux_frame.next(); - result[0] = aux_next[0] - aux_current[0] * (((aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(4_u64)) * E::from(main_current[1]) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(2_u64)) * (E::ONE - E::from(main_current[1])) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])) * E::from(main_current[1]) * (E::ONE - E::from(main_current[2])) + (E::ONE - E::from(main_current[1])) * (E::ONE - E::from(main_current[2]))) * E::from(main_current[0]) - E::from(main_current[0]) + E::ONE); + result[0] = aux_next[0] - aux_current[0] * ((E::from(main_current[3]).exp(E::PositiveInteger::from(4_u64)) * E::from(main_current[1]) * E::from(main_current[2]) + E::from(main_current[3]).exp(E::PositiveInteger::from(2_u64)) * (E::ONE - E::from(main_current[1])) * E::from(main_current[2]) + E::from(main_current[3]) * E::from(main_current[1]) * (E::ONE - E::from(main_current[2])) + (E::ONE - E::from(main_current[1])) * (E::ONE - E::from(main_current[2]))) * E::from(main_current[0]) - E::from(main_current[0]) + E::ONE); } } \ No newline at end of file diff --git a/ir/src/passes/translate.rs b/ir/src/passes/translate.rs index 61ce38b4..28876523 100644 --- a/ir/src/passes/translate.rs +++ b/ir/src/passes/translate.rs @@ -1,6 +1,4 @@ -use 
std::collections::HashMap; - -use air_parser::ast; +use air_parser::{ast, LexicalScope}; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, Severity, Span, Spanned}; @@ -23,8 +21,6 @@ impl<'p> Pass for AstToAir<'p> { type Error = CompileError; fn run<'a>(&mut self, program: Self::Input<'a>) -> Result, Self::Error> { - dbg!(&program); - let mut air = Air::new(program.name); let random_values = program.random_values; @@ -57,7 +53,7 @@ impl<'p> Pass for AstToAir<'p> { } } -#[derive(Clone)] +#[derive(Debug, Clone)] enum MemoizedBinding { /// The binding was reduced to a node in the graph Scalar(NodeIndex), @@ -72,7 +68,7 @@ struct AirBuilder<'a> { air: &'a mut Air, random_values: Option, trace_columns: Vec, - bindings: HashMap, + bindings: LexicalScope, } impl<'a> AirBuilder<'a> { fn build_boundary_constraint(&mut self, bc: &ast::Statement) -> Result<(), CompileError> { @@ -142,145 +138,13 @@ impl<'a> AirBuilder<'a> { where F: FnMut(&mut AirBuilder, &ast::Statement) -> Result<(), CompileError>, { - let prev = self.bindings.clone(); - match expr.value { - ast::Expr::Const(ref constant) => match &constant.item { - ast::ConstantExpr::Scalar(value) => { - let value = self.insert_constant(*value); - self.bindings - .insert(expr.name, MemoizedBinding::Scalar(value)); - } - ast::ConstantExpr::Vector(values) => { - let values = self.insert_constants(values.as_slice()); - self.bindings - .insert(expr.name, MemoizedBinding::Vector(values)); - } - ast::ConstantExpr::Matrix(values) => { - let values = values - .iter() - .map(|vs| self.insert_constants(vs.as_slice())) - .collect(); - self.bindings - .insert(expr.name, MemoizedBinding::Matrix(values)); - } - }, - ast::Expr::Range(ref values) => { - let values = values - .item - .clone() - .map(|v| self.insert_constant(v as u64)) - .collect(); - self.bindings - .insert(expr.name, MemoizedBinding::Vector(values)); - } - ast::Expr::Vector(ref values) => match values[0].ty().unwrap() { - ast::Type::Felt => { - let mut nodes 
= vec![]; - for value in values.iter().cloned() { - let value = value.try_into().unwrap(); - nodes.push(self.insert_scalar_expr(&value)); - } - self.bindings - .insert(expr.name, MemoizedBinding::Vector(nodes)); - } - ast::Type::Vector(n) => { - let mut nodes = vec![]; - for row in values.iter().cloned() { - match row { - ast::Expr::Const(Span { - item: ast::ConstantExpr::Vector(vs), - .. - }) => { - nodes.push(self.insert_constants(vs.as_slice())); - } - ast::Expr::SymbolAccess(access) => { - let mut cols = vec![]; - for i in 0..n { - let access = ast::ScalarExpr::SymbolAccess( - access.access(AccessType::Index(i)).unwrap(), - ); - let node = self.insert_scalar_expr(&access); - cols.push(node); - } - nodes.push(cols); - } - ast::Expr::Vector(ref elems) => { - let mut cols = vec![]; - for elem in elems.iter().cloned() { - let elem: ast::ScalarExpr = elem.try_into().unwrap(); - let node = self.insert_scalar_expr(&elem); - cols.push(node); - } - nodes.push(cols); - } - _ => unreachable!(), - } - } - self.bindings - .insert(expr.name, MemoizedBinding::Matrix(nodes)); - } - _ => unreachable!(), - }, - ast::Expr::Matrix(ref values) => { - let values = values - .iter() - .map(|vs| vs.iter().map(|v| self.insert_scalar_expr(v)).collect()) - .collect(); - self.bindings - .insert(expr.name, MemoizedBinding::Matrix(values)); - } - ast::Expr::Binary(ref bexpr) => { - let value = self.insert_binary_expr(bexpr); - self.bindings - .insert(expr.name, MemoizedBinding::Scalar(value)); - } - ast::Expr::SymbolAccess(ref access) => { - match self.bindings.get(access.name.as_ref()) { - None => { - // Must be a reference to a declaration - let value = self.insert_symbol_access(access); - self.bindings - .insert(expr.name, MemoizedBinding::Scalar(value)); - } - Some(MemoizedBinding::Scalar(node)) => { - assert_eq!(access.access_type, AccessType::Default); - self.bindings - .insert(expr.name, MemoizedBinding::Scalar(*node)); - } - Some(MemoizedBinding::Vector(nodes)) => { - let value = 
match &access.access_type { - AccessType::Default => MemoizedBinding::Vector(nodes.clone()), - AccessType::Index(idx) => MemoizedBinding::Scalar(nodes[*idx]), - AccessType::Slice(range) => { - MemoizedBinding::Vector(nodes[range.start..range.end].to_vec()) - } - AccessType::Matrix(_, _) => unreachable!(), - }; - self.bindings.insert(expr.name, value); - } - Some(MemoizedBinding::Matrix(nodes)) => { - let value = match &access.access_type { - AccessType::Default => MemoizedBinding::Matrix(nodes.clone()), - AccessType::Index(idx) => MemoizedBinding::Vector(nodes[*idx].clone()), - AccessType::Slice(range) => { - MemoizedBinding::Matrix(nodes[range.start..range.end].to_vec()) - } - AccessType::Matrix(row, col) => { - MemoizedBinding::Scalar(nodes[*row][*col]) - } - }; - self.bindings.insert(expr.name, value); - } - } - } - ast::Expr::Call(_) | ast::Expr::ListComprehension(_) => unreachable!(), - } - - for statement in expr.body.iter() { - statement_builder(self, statement)?; + let bound = self.eval_expr(&expr.value)?; + self.bindings.enter(); + self.bindings.insert(expr.name, bound); + for stmt in expr.body.iter() { + statement_builder(self, stmt)?; } - - self.bindings = prev; + self.bindings.exit(); Ok(()) } @@ -328,7 +192,7 @@ impl<'a> AirBuilder<'a> { let lhs = self.insert_op(Operation::Value(Value::TraceAccess(trace_access))); // Insert the right-hand expression into the graph - let rhs = self.insert_scalar_expr(rhs); + let rhs = self.insert_scalar_expr(rhs)?; // Compare the inferred trace segment and domain of the operands let domain = access.boundary.into(); { @@ -374,9 +238,12 @@ impl<'a> AirBuilder<'a> { rhs: &ast::ScalarExpr, condition: Option<&ast::ScalarExpr>, ) -> Result<(), CompileError> { - let lhs = self.insert_scalar_expr(lhs); - let rhs = self.insert_scalar_expr(rhs); - let condition = condition.as_ref().map(|cond| self.insert_scalar_expr(cond)); + let lhs = self.insert_scalar_expr(lhs)?; + let rhs = self.insert_scalar_expr(rhs)?; + let condition = 
match condition { + Some(cond) => Some(self.insert_scalar_expr(cond)?), + None => None, + }; let root = self.merge_equal_exprs(lhs, rhs, condition); // Get the trace segment and domain of the constraint. // @@ -407,34 +274,197 @@ impl<'a> AirBuilder<'a> { } } - fn insert_scalar_expr(&mut self, expr: &ast::ScalarExpr) -> NodeIndex { + fn eval_let_expr(&mut self, expr: &ast::Let) -> Result { + let mut next_let = Some(expr); + let snapshot = self.bindings.clone(); + loop { + let let_expr = next_let.take().expect("invalid empty let body"); + let bound = self.eval_expr(&let_expr.value)?; + self.bindings.enter(); + self.bindings.insert(let_expr.name, bound); + match let_expr.body.last().unwrap() { + ast::Statement::Let(ref inner_let) => { + next_let = Some(inner_let); + } + ast::Statement::Expr(ref expr) => { + let value = self.eval_expr(expr); + self.bindings = snapshot; + break value; + } + ast::Statement::Enforce(_) + | ast::Statement::EnforceIf(_, _) + | ast::Statement::EnforceAll(_) => { + unreachable!() + } + } + } + } + + fn eval_expr(&mut self, expr: &ast::Expr) -> Result { + match expr { + ast::Expr::Const(ref constant) => match &constant.item { + ast::ConstantExpr::Scalar(value) => { + let value = self.insert_constant(*value); + Ok(MemoizedBinding::Scalar(value)) + } + ast::ConstantExpr::Vector(values) => { + let values = self.insert_constants(values.as_slice()); + Ok(MemoizedBinding::Vector(values)) + } + ast::ConstantExpr::Matrix(values) => { + let values = values + .iter() + .map(|vs| self.insert_constants(vs.as_slice())) + .collect(); + Ok(MemoizedBinding::Matrix(values)) + } + }, + ast::Expr::Range(ref values) => { + let values = values + .item + .clone() + .map(|v| self.insert_constant(v as u64)) + .collect(); + Ok(MemoizedBinding::Vector(values)) + } + ast::Expr::Vector(ref values) => match values[0].ty().unwrap() { + ast::Type::Felt => { + let mut nodes = vec![]; + for value in values.iter().cloned() { + let value = value.try_into().unwrap(); + 
nodes.push(self.insert_scalar_expr(&value)?); + } + Ok(MemoizedBinding::Vector(nodes)) + } + ast::Type::Vector(n) => { + let mut nodes = vec![]; + for row in values.iter().cloned() { + match row { + ast::Expr::Const(Span { + item: ast::ConstantExpr::Vector(vs), + .. + }) => { + nodes.push(self.insert_constants(vs.as_slice())); + } + ast::Expr::SymbolAccess(access) => { + let mut cols = vec![]; + for i in 0..n { + let access = ast::ScalarExpr::SymbolAccess( + access.access(AccessType::Index(i)).unwrap(), + ); + let node = self.insert_scalar_expr(&access)?; + cols.push(node); + } + nodes.push(cols); + } + ast::Expr::Vector(ref elems) => { + let mut cols = vec![]; + for elem in elems.iter().cloned() { + let elem: ast::ScalarExpr = elem.try_into().unwrap(); + let node = self.insert_scalar_expr(&elem)?; + cols.push(node); + } + nodes.push(cols); + } + _ => unreachable!(), + } + } + Ok(MemoizedBinding::Matrix(nodes)) + } + _ => unreachable!(), + }, + ast::Expr::Matrix(ref values) => { + let mut rows = Vec::with_capacity(values.len()); + for vs in values.iter() { + let mut cols = Vec::with_capacity(vs.len()); + for value in vs { + cols.push(self.insert_scalar_expr(value)?); + } + rows.push(cols); + } + Ok(MemoizedBinding::Matrix(rows)) + } + ast::Expr::Binary(ref bexpr) => { + let value = self.insert_binary_expr(bexpr)?; + Ok(MemoizedBinding::Scalar(value)) + } + ast::Expr::SymbolAccess(ref access) => { + match self.bindings.get(access.name.as_ref()) { + None => { + // Must be a reference to a declaration + let value = self.insert_symbol_access(access); + Ok(MemoizedBinding::Scalar(value)) + } + Some(MemoizedBinding::Scalar(node)) => { + assert_eq!(access.access_type, AccessType::Default); + Ok(MemoizedBinding::Scalar(*node)) + } + Some(MemoizedBinding::Vector(nodes)) => { + let value = match &access.access_type { + AccessType::Default => MemoizedBinding::Vector(nodes.clone()), + AccessType::Index(idx) => MemoizedBinding::Scalar(nodes[*idx]), + AccessType::Slice(range) => 
{ + MemoizedBinding::Vector(nodes[range.start..range.end].to_vec()) + } + AccessType::Matrix(_, _) => unreachable!(), + }; + Ok(value) + } + Some(MemoizedBinding::Matrix(nodes)) => { + let value = match &access.access_type { + AccessType::Default => MemoizedBinding::Matrix(nodes.clone()), + AccessType::Index(idx) => MemoizedBinding::Vector(nodes[*idx].clone()), + AccessType::Slice(range) => { + MemoizedBinding::Matrix(nodes[range.start..range.end].to_vec()) + } + AccessType::Matrix(row, col) => { + MemoizedBinding::Scalar(nodes[*row][*col]) + } + }; + Ok(value) + } + } + } + ast::Expr::Let(ref let_expr) => self.eval_let_expr(let_expr), + // These node types should not exist at this point + ast::Expr::Call(_) | ast::Expr::ListComprehension(_) => unreachable!(), + } + } + + fn insert_scalar_expr(&mut self, expr: &ast::ScalarExpr) -> Result { match expr { ast::ScalarExpr::Const(value) => { - self.insert_op(Operation::Value(Value::Constant(value.item))) + Ok(self.insert_op(Operation::Value(Value::Constant(value.item)))) } - ast::ScalarExpr::SymbolAccess(access) => self.insert_symbol_access(access), + ast::ScalarExpr::SymbolAccess(access) => Ok(self.insert_symbol_access(access)), ast::ScalarExpr::Binary(expr) => self.insert_binary_expr(expr), + ast::ScalarExpr::Let(ref let_expr) => match self.eval_let_expr(let_expr)? 
{ + MemoizedBinding::Scalar(node) => Ok(node), + invalid => { + panic!("expected scalar expression to produce scalar value, got: {invalid:?}") + } + }, ast::ScalarExpr::Call(_) | ast::ScalarExpr::BoundedSymbolAccess(_) => unreachable!(), } } - fn insert_binary_expr(&mut self, expr: &ast::BinaryExpr) -> NodeIndex { + fn insert_binary_expr(&mut self, expr: &ast::BinaryExpr) -> Result { if expr.op == ast::BinaryOp::Exp { - let lhs = self.insert_scalar_expr(expr.lhs.as_ref()); + let lhs = self.insert_scalar_expr(expr.lhs.as_ref())?; let ast::ScalarExpr::Const(rhs) = expr.rhs.as_ref() else { unreachable!(); }; - return self.insert_op(Operation::Exp(lhs, rhs.item as usize)); + return Ok(self.insert_op(Operation::Exp(lhs, rhs.item as usize))); } - let lhs = self.insert_scalar_expr(expr.lhs.as_ref()); - let rhs = self.insert_scalar_expr(expr.rhs.as_ref()); - match expr.op { + let lhs = self.insert_scalar_expr(expr.lhs.as_ref())?; + let rhs = self.insert_scalar_expr(expr.rhs.as_ref())?; + Ok(match expr.op { ast::BinaryOp::Add => self.insert_op(Operation::Add(lhs, rhs)), ast::BinaryOp::Sub => self.insert_op(Operation::Sub(lhs, rhs)), ast::BinaryOp::Mul => self.insert_op(Operation::Mul(lhs, rhs)), _ => unreachable!(), - } + }) } fn insert_symbol_access(&mut self, access: &ast::SymbolAccess) -> NodeIndex { diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 79429a76..3a3d08f5 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -16,9 +16,10 @@ lalrpop = { version = "0.20", default-features = false } [dependencies] air-pass = { package = "air-pass", path = "../pass", version = "0.1" } +either = "1.12" miden-diagnostics = "0.1" miden-parsing = "0.1" -lalrpop-util="0.20" +lalrpop-util = "0.20" lazy_static = "1.4" petgraph = "0.6" regex = "1" diff --git a/parser/src/ast/display.rs b/parser/src/ast/display.rs index d8ee0de6..38e63551 100644 --- a/parser/src/ast/display.rs +++ b/parser/src/ast/display.rs @@ -94,11 +94,12 @@ impl<'a> fmt::Display for DisplayStatement<'a> { 
self.write_indent(f)?; match self.statement { Statement::Let(ref expr) => { - writeln!(f, "let {} = {}", expr.name, expr.value)?; - for statement in expr.body.iter() { - writeln!(f, "{}", statement.display(self.indent))?; - } - Ok(()) + let display = DisplayLet { + let_expr: expr, + indent: self.indent, + in_expr_position: false, + }; + write!(f, "{display}") } Statement::Enforce(ref expr) => { write!(f, "enf {}", expr) @@ -113,3 +114,59 @@ impl<'a> fmt::Display for DisplayStatement<'a> { } } } + +pub struct DisplayLet<'a> { + pub let_expr: &'a super::Let, + pub indent: usize, + pub in_expr_position: bool, +} +impl DisplayLet<'_> { + const INDENT: &'static str = " "; + + fn write_indent(&self, f: &mut fmt::Formatter) -> fmt::Result { + for _ in 0..self.indent { + f.write_str(Self::INDENT)?; + } + Ok(()) + } +} +impl<'a> fmt::Display for DisplayLet<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use core::fmt::Write; + + self.write_indent(f)?; + match &self.let_expr.value { + super::Expr::Let(value) => { + writeln!(f, "let {} = {{", self.let_expr.name)?; + let display = DisplayLet { + let_expr: value, + indent: self.indent + 1, + in_expr_position: true, + }; + writeln!(f, "{display}")?; + self.write_indent(f)?; + if self.in_expr_position { + f.write_str("} in {\n")?; + } else { + f.write_str("}\n")?; + } + } + value => { + write!(f, "let {} = {}", self.let_expr.name, value)?; + if self.in_expr_position { + f.write_str(" in {\n")?; + } else { + f.write_char('\n')?; + } + } + } + for stmt in self.let_expr.body.iter() { + writeln!(f, "{}", stmt.display(self.indent + 1))?; + } + if self.in_expr_position { + self.write_indent(f)?; + f.write_char('}')?; + } + Ok(()) + } +} diff --git a/parser/src/ast/errors.rs b/parser/src/ast/errors.rs index be0e6750..0d8aab66 100644 --- a/parser/src/ast/errors.rs +++ b/parser/src/ast/errors.rs @@ -11,6 +11,10 @@ pub enum InvalidExprError { BoundedSymbolAccess(SourceSpan), #[error("expected scalar expression")] 
InvalidScalarExpr(SourceSpan), + #[error("invalid let in expression position: body produces no value, or the type of that value is unknown")] + InvalidLetExpr(SourceSpan), + #[error("syntax does not represent a valid expression")] + NotAnExpr(SourceSpan), } impl Eq for InvalidExprError {} impl PartialEq for InvalidExprError { @@ -22,11 +26,6 @@ impl ToDiagnostic for InvalidExprError { fn to_diagnostic(self) -> Diagnostic { let message = format!("{}", &self); match self { - Self::InvalidExponent(span) => Diagnostic::error() - .with_message("invalid expression") - .with_labels(vec![ - Label::primary(span.source_id(), span).with_message(message) - ]), Self::NonConstantExponent(span) => Diagnostic::error() .with_message("invalid expression") .with_labels(vec![ @@ -36,12 +35,11 @@ impl ToDiagnostic for InvalidExprError { "Only constant powers are supported with the exponentiation operator currently" .to_string(), ]), - Self::BoundedSymbolAccess(span) => Diagnostic::error() - .with_message("invalid expression") - .with_labels(vec![ - Label::primary(span.source_id(), span).with_message(message) - ]), - Self::InvalidScalarExpr(span) => Diagnostic::error() + Self::InvalidExponent(span) + | Self::BoundedSymbolAccess(span) + | Self::InvalidScalarExpr(span) + | Self::InvalidLetExpr(span) + | Self::NotAnExpr(span) => Diagnostic::error() .with_message("invalid expression") .with_labels(vec![ Label::primary(span.source_id(), span).with_message(message) diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index d763328d..aa28d0f1 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -296,6 +296,11 @@ pub enum Expr { Call(Call), /// A generator expression which produces a vector or matrix of values ListComprehension(ListComprehension), + /// A `let` expression, used to bind temporaries in expression position during compilation. 
+ /// + /// NOTE: The AirScript syntax only permits `let` in statement position, so this variant + /// is only present in the AST as the result of an explicit transformation. + Let(Box), } impl Expr { /// Returns true if this expression is constant @@ -325,6 +330,7 @@ impl Expr { Self::Binary(_) => Some(Type::Felt), Self::Call(ref call) => call.ty, Self::ListComprehension(ref lc) => lc.ty, + Self::Let(ref let_expr) => let_expr.ty(), } } } @@ -341,6 +347,7 @@ impl fmt::Debug for Expr { Self::ListComprehension(ref expr) => { f.debug_tuple("ListComprehension").field(expr).finish() } + Self::Let(ref let_expr) => write!(f, "{let_expr:#?}"), } } } @@ -360,10 +367,18 @@ impl fmt::Display for Expr { } f.write_str("]") } - Self::ListComprehension(ref expr) => write!(f, "{}", DisplayBracketed(expr)), Self::SymbolAccess(ref expr) => write!(f, "{}", expr), Self::Binary(ref expr) => write!(f, "{}", expr), Self::Call(ref expr) => write!(f, "{}", expr), + Self::ListComprehension(ref expr) => write!(f, "{}", DisplayBracketed(expr)), + Self::Let(ref let_expr) => { + let display = DisplayLet { + let_expr, + indent: 0, + in_expr_position: true, + }; + write!(f, "{display}") + } } } } @@ -391,22 +406,45 @@ impl From for Expr { Self::ListComprehension(expr) } } +impl TryFrom for Expr { + type Error = InvalidExprError; + + fn try_from(expr: Let) -> Result { + if expr.ty().is_some() { + Ok(Self::Let(Box::new(expr))) + } else { + Err(InvalidExprError::InvalidLetExpr(expr.span())) + } + } +} impl TryFrom for Expr { type Error = InvalidExprError; #[inline] fn try_from(expr: ScalarExpr) -> Result { match expr { - ScalarExpr::Const(spanned) => Ok(Expr::Const(Span::new( + ScalarExpr::Const(spanned) => Ok(Self::Const(Span::new( spanned.span(), ConstantExpr::Scalar(spanned.item), ))), - ScalarExpr::SymbolAccess(access) => Ok(Expr::SymbolAccess(access)), - ScalarExpr::Binary(expr) => Ok(Expr::Binary(expr)), - ScalarExpr::Call(expr) => Ok(Expr::Call(expr)), + ScalarExpr::SymbolAccess(access) => 
Ok(Self::SymbolAccess(access)), + ScalarExpr::Binary(expr) => Ok(Self::Binary(expr)), + ScalarExpr::Call(expr) => Ok(Self::Call(expr)), ScalarExpr::BoundedSymbolAccess(_) => { Err(InvalidExprError::BoundedSymbolAccess(expr.span())) } + ScalarExpr::Let(expr) => Ok(Self::Let(expr)), + } + } +} +impl TryFrom for Expr { + type Error = InvalidExprError; + + fn try_from(stmt: Statement) -> Result { + match stmt { + Statement::Let(let_expr) => Ok(Self::Let(Box::new(let_expr))), + Statement::Expr(expr) => Ok(expr), + _ => Err(InvalidExprError::NotAnExpr(stmt.span())), } } } @@ -430,7 +468,7 @@ pub enum ScalarExpr { Binary(BinaryExpr), /// A call to a pure function or evaluator /// - /// NOTE: This is only a valid expression when one of the following hold: + /// NOTE: This is only a valid scalar expression when one of the following hold: /// /// 1. The call is the top-level expression of a constraint, and is to an evaluator function /// 2. The call is not the top-level expression of a constraint, and is to a pure function @@ -438,6 +476,12 @@ pub enum ScalarExpr { /// /// If neither of the above are true, the call is invalid in a `ScalarExpr` context Call(Call), + /// An expression that binds a local variable to a temporary value during evaluation. + /// + /// NOTE: This is only a valid scalar expression during the inlining phase, when we expand + /// binary expressions or function calls to a block of statements, and only when the result + /// of evaluating the `let` produces a valid scalar expression. 
+ Let(Box), } impl ScalarExpr { /// Returns true if this is a constant value @@ -449,7 +493,7 @@ impl ScalarExpr { pub fn has_block_like_expansion(&self) -> bool { match self { Self::Binary(ref expr) => expr.has_block_like_expansion(), - Self::Call(_) => true, + Self::Call(_) | Self::Let(_) => true, _ => false, } } @@ -471,6 +515,7 @@ impl ScalarExpr { _ => Err(expr.span()), }, Self::Call(ref expr) => Ok(expr.ty), + Self::Let(ref expr) => Ok(expr.ty()), } } } @@ -489,10 +534,33 @@ impl TryFrom for ScalarExpr { Expr::SymbolAccess(sym) => Ok(Self::SymbolAccess(sym)), Expr::Binary(bin) => Ok(Self::Binary(bin)), Expr::Call(call) => Ok(Self::Call(call)), + Expr::Let(let_expr) => { + if let_expr.ty().is_none() { + Err(InvalidExprError::InvalidScalarExpr(let_expr.span())) + } else { + Ok(Self::Let(let_expr)) + } + } invalid => Err(InvalidExprError::InvalidScalarExpr(invalid.span())), } } } +impl TryFrom for ScalarExpr { + type Error = InvalidExprError; + + fn try_from(stmt: Statement) -> Result { + match stmt { + Statement::Let(let_expr) => Self::try_from(Expr::Let(Box::new(let_expr))), + Statement::Expr(expr) => Self::try_from(expr), + stmt => Err(InvalidExprError::InvalidScalarExpr(stmt.span())), + } + } +} +impl From for ScalarExpr { + fn from(value: u64) -> Self { + Self::Const(Span::new(SourceSpan::UNKNOWN, value)) + } +} impl fmt::Debug for ScalarExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -503,6 +571,7 @@ impl fmt::Debug for ScalarExpr { } Self::Binary(ref expr) => f.debug_tuple("Binary").field(expr).finish(), Self::Call(ref expr) => f.debug_tuple("Call").field(expr).finish(), + Self::Let(ref expr) => write!(f, "{:#?}", expr), } } } @@ -514,6 +583,14 @@ impl fmt::Display for ScalarExpr { Self::BoundedSymbolAccess(ref expr) => write!(f, "{}.{}", &expr.column, &expr.boundary), Self::Binary(ref expr) => write!(f, "{}", expr), Self::Call(ref call) => write!(f, "{}", call), + Self::Let(ref let_expr) => { + let display = DisplayLet { + 
let_expr, + indent: 0, + in_expr_position: true, + }; + write!(f, "{display}") + } } } } diff --git a/parser/src/ast/statement.rs b/parser/src/ast/statement.rs index 5b5cd367..3e13723c 100644 --- a/parser/src/ast/statement.rs +++ b/parser/src/ast/statement.rs @@ -82,6 +82,24 @@ impl Statement { } } } +impl From for Statement { + fn from(expr: Expr) -> Self { + match expr { + Expr::Let(let_expr) => Self::Let(*let_expr), + expr => Self::Expr(expr), + } + } +} +impl TryFrom for Statement { + type Error = (); + + fn try_from(expr: ScalarExpr) -> Result { + match expr { + ScalarExpr::Let(let_expr) => Ok(Self::Let(*let_expr)), + expr => Expr::try_from(expr).map_err(|_| ()).map(Self::Expr), + } + } +} /// A `let` statement binds `name` to the value of `expr` in `body`. #[derive(Clone, Spanned)] diff --git a/parser/src/ast/visit.rs b/parser/src/ast/visit.rs index 89765c5e..2bb33f0e 100644 --- a/parser/src/ast/visit.rs +++ b/parser/src/ast/visit.rs @@ -601,6 +601,7 @@ where ast::Expr::Binary(ref mut expr) => visitor.visit_mut_binary_expr(expr), ast::Expr::Call(ref mut expr) => visitor.visit_mut_call(expr), ast::Expr::ListComprehension(ref mut expr) => visitor.visit_mut_list_comprehension(expr), + ast::Expr::Let(ref mut expr) => visitor.visit_mut_let(expr), } } @@ -616,6 +617,7 @@ where } ast::ScalarExpr::Binary(ref mut expr) => visitor.visit_mut_binary_expr(expr), ast::ScalarExpr::Call(ref mut expr) => visitor.visit_mut_call(expr), + ast::ScalarExpr::Let(ref mut expr) => visitor.visit_mut_let(expr), } } diff --git a/parser/src/parser/tests/inlining.rs b/parser/src/parser/tests/inlining.rs index 5289ed08..08dc692b 100644 --- a/parser/src/parser/tests/inlining.rs +++ b/parser/src/parser/tests/inlining.rs @@ -1120,22 +1120,32 @@ fn test_inlining_constraints_with_folded_comprehensions_in_evaluator() { int!(0) ))); // When constant propagation and inlining is done, integrity_constraints should look like: - // let lc%0 = b[2]^7 - // let lc%1 = b[3]^7 - // let y = lc%0 + lc%1 - // 
let lc%2 = b[2]^7 - // let lc%3 = b[3]^7 - // let z = lc%2 + lc%3 + // let y = + // let %lc0 = b[2]^7 + // let %lc1 = b[3]^7 + // %lc0 + %lc1 + // in + // let z = + // let %lc2 = b[2]^7 + // let %lc3 = b[3]^7 + // %lc2 * %lc3 + // in // enf b[1] = y + z expected .integrity_constraints - .push(let_!("%lc0" = expr!(exp!(access!(b[2], Type::Felt), int!(7))) + .push(let_!(y = expr!( + let_!("%lc0" = expr!(exp!(access!(b[2], Type::Felt), int!(7))) => let_!("%lc1" = expr!(exp!(access!(b[3], Type::Felt), int!(7))) - => let_!(y = expr!(add!(access!("%lc0", Type::Felt), access!("%lc1", Type::Felt))) - => let_!("%lc2" = expr!(exp!(access!(b[2], Type::Felt), int!(7))) - => let_!("%lc3" = expr!(exp!(access!(b[3], Type::Felt), int!(7))) - => let_!(z = expr!(mul!(access!("%lc2", Type::Felt), access!("%lc3", Type::Felt))) - => enforce!(eq!(access!(b[1], Type::Felt), add!(access!(y, Type::Felt), access!(z, Type::Felt))))))))))); + => statement!(add!(access!("%lc0", Type::Felt), access!("%lc1", Type::Felt))))) + ) => + let_!(z = expr!( + let_!("%lc2" = expr!(exp!(access!(b[2], Type::Felt), int!(7))) + => let_!("%lc3" = expr!(exp!(access!(b[3], Type::Felt), int!(7))) + => statement!(mul!(access!("%lc2", Type::Felt), access!("%lc3", Type::Felt))))) + ) => + enforce!(eq!(access!(b[1], Type::Felt), add!(access!(y, Type::Felt), access!(z, Type::Felt)))) + ) + )); // The evaluator definition is never modified by constant propagation or inlining let body = vec![ let_!(y = expr!(call!(sum(expr!(lc!(((col, expr!(access!(ys, Type::Vector(2))))) => exp!(access!(col, Type::Felt), int!(7))))))) @@ -1251,20 +1261,19 @@ fn test_inlining_with_function_call_as_binary_operand() { ))); // With constant propagation and inlining done // - // let %0 = b[0] + b[1] + b[2] + b[3] - // let m = b[0] * b[1] - // let n = m * b[2] - // let o = n * b[3] - // let %1 = o - // let complex_fold = %0 * %1 + // let complex_fold = + // (b[0] + b[1] + b[2] + b[3]) * + // (let m = b[0] * b[1] + // let n = m * b[2] + // 
let o = n * b[3] in o) // enf complex_fold = 1 expected.integrity_constraints.push( - let_!("%0" = expr!(add!(add!(add!(access!(b[0], Type::Felt), access!(b[1], Type::Felt)), access!(b[2], Type::Felt)), access!(b[3], Type::Felt))) - => let_!(m = expr!(mul!(access!(b[0], Type::Felt), access!(b[1], Type::Felt))) + let_!(complex_fold = expr!(mul!( + add!(add!(add!(access!(b[0], Type::Felt), access!(b[1], Type::Felt)), access!(b[2], Type::Felt)), access!(b[3], Type::Felt)), + scalar!(let_!(m = expr!(mul!(access!(b[0], Type::Felt), access!(b[1], Type::Felt))) => let_!(n = expr!(mul!(access!(m, Type::Felt), access!(b[2], Type::Felt))) - => let_!(o = expr!(mul!(access!(n, Type::Felt), access!(b[3], Type::Felt))) - => let_!(complex_fold = expr!(mul!(access!("%0", Type::Felt), access!(o, Type::Felt))) - => enforce!(eq!(access!(complex_fold, Type::Felt), int!(1)))))))) + => let_!(o = expr!(mul!(access!(n, Type::Felt), access!(b[3], Type::Felt))) => return_!(expr!(access!(o, Type::Felt))))))) + )) => enforce!(eq!(access!(complex_fold, Type::Felt), int!(1)))) ); assert_eq!(program, expected); @@ -1419,84 +1428,88 @@ fn test_repro_issue340() { }); add!(acc, access) }); - let low_bit_sum_body = let_!(low_bit_sum = expr!(low_bit_sum) => - enforce!(eq!(access!(immediate, Type::Felt), add!(access!(low_bit_sum, Type::Felt), access!(high_bit_sum, Type::Felt))), when access!(s, Type::Felt)), - enforce!(eq!(access!(instruction_bits[31], Type::Felt), int!(1)), when access!(s, Type::Felt))); - let high_bit_sum_body = let_!(high_bit_sum = expr!(high_bit_sum) - => let_!("%lc53" = expr!(mul!(access!(instruction_bits[20], Type::Felt), int!(1))) - => let_!("%lc54" = expr!(mul!(access!(instruction_bits[21], Type::Felt), int!(2))) - => let_!("%lc55" = expr!(mul!(access!(instruction_bits[22], Type::Felt), int!(4))) - => let_!("%lc56" = expr!(mul!(access!(instruction_bits[23], Type::Felt), int!(8))) - => let_!("%lc57" = expr!(mul!(access!(instruction_bits[24], Type::Felt), int!(16))) - => 
let_!("%lc58" = expr!(mul!(access!(instruction_bits[25], Type::Felt), int!(32))) - => let_!("%lc59" = expr!(mul!(access!(instruction_bits[26], Type::Felt), int!(64))) - => let_!("%lc60" = expr!(mul!(access!(instruction_bits[27], Type::Felt), int!(128))) - => let_!("%lc61" = expr!(mul!(access!(instruction_bits[28], Type::Felt), int!(256))) - => let_!("%lc62" = expr!(mul!(access!(instruction_bits[29], Type::Felt), int!(512))) - => let_!("%lc63" = expr!(mul!(access!(instruction_bits[30], Type::Felt), int!(1024))) - => low_bit_sum_body)))))))))))); - let word_sum_body = let_!(word_sum = expr!(word_sum) - => enforce!(eq!(access!(instruction_word, Type::Felt), access!(word_sum, Type::Felt))), - let_!("%lc32" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2048))) - => let_!("%lc33" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(4096))) - => let_!("%lc34" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(8192))) - => let_!("%lc35" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(16384))) - => let_!("%lc36" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(32768))) - => let_!("%lc37" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(65536))) - => let_!("%lc38" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(131072))) - => let_!("%lc39" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(262144))) - => let_!("%lc40" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(524288))) - => let_!("%lc41" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(1048576))) - => let_!("%lc42" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2097152))) - => let_!("%lc43" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(4194304))) - => let_!("%lc44" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(8388608))) - => let_!("%lc45" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(16777216))) - => let_!("%lc46" = 
expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(33554432))) - => let_!("%lc47" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(67108864))) - => let_!("%lc48" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(134217728))) - => let_!("%lc49" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(268435456))) - => let_!("%lc50" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(536870912))) - => let_!("%lc51" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(1073741824))) - => let_!("%lc52" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2147483648))) - => high_bit_sum_body))))))))))))))))))))), - enforce!(eq!(access!(immediate, Type::Felt), int!(0)), when not!(access!(s, Type::Felt))) + let instruction_bits = ident!(instruction_bits); + let low_bit_sum_value_expr = (53..64).rfold( + Statement::Expr(Expr::try_from(low_bit_sum).unwrap()), + |acc, i| { + let literal = 2u64.pow((i as u32) - 53); + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: Default::default(), + name: ResolvableIdentifier::Local(instruction_bits), + access_type: AccessType::Index(i - 33), + offset: 0, + ty: Some(Type::Felt), + }); + Statement::Let(Let { + span: SourceSpan::default(), + name: Identifier::new( + SourceSpan::default(), + crate::Symbol::intern(format!("%lc{}", i)), + ), + value: expr!(mul!(access, scalar!(literal))), + body: vec![acc], + }) + }, ); - - expected.integrity_constraints.push( - let_!("%lc0" = expr!(mul!(int!(1), access!(instruction_bits[0], Type::Felt))) - => let_!("%lc1" = expr!(mul!(int!(2), access!(instruction_bits[1], Type::Felt))) - => let_!("%lc2" = expr!(mul!(int!(4), access!(instruction_bits[2], Type::Felt))) - => let_!("%lc3" = expr!(mul!(int!(8), access!(instruction_bits[3], Type::Felt))) - => let_!("%lc4" = expr!(mul!(int!(16), access!(instruction_bits[4], Type::Felt))) - => let_!("%lc5" = expr!(mul!(int!(32), access!(instruction_bits[5], Type::Felt))) - => let_!("%lc6" = 
expr!(mul!(int!(64), access!(instruction_bits[6], Type::Felt))) - => let_!("%lc7" = expr!(mul!(int!(128), access!(instruction_bits[7], Type::Felt))) - => let_!("%lc8" = expr!(mul!(int!(256), access!(instruction_bits[8], Type::Felt))) - => let_!("%lc9" = expr!(mul!(int!(512), access!(instruction_bits[9], Type::Felt))) - => let_!("%lc10" = expr!(mul!(int!(1024), access!(instruction_bits[10], Type::Felt))) - => let_!("%lc11" = expr!(mul!(int!(2048), access!(instruction_bits[11], Type::Felt))) - => let_!("%lc12" = expr!(mul!(int!(4096), access!(instruction_bits[12], Type::Felt))) - => let_!("%lc13" = expr!(mul!(int!(8192), access!(instruction_bits[13], Type::Felt))) - => let_!("%lc14" = expr!(mul!(int!(16384), access!(instruction_bits[14], Type::Felt))) - => let_!("%lc15" = expr!(mul!(int!(32768), access!(instruction_bits[15], Type::Felt))) - => let_!("%lc16" = expr!(mul!(int!(65536), access!(instruction_bits[16], Type::Felt))) - => let_!("%lc17" = expr!(mul!(int!(131072), access!(instruction_bits[17], Type::Felt))) - => let_!("%lc18" = expr!(mul!(int!(262144), access!(instruction_bits[18], Type::Felt))) - => let_!("%lc19" = expr!(mul!(int!(524288), access!(instruction_bits[19], Type::Felt))) - => let_!("%lc20" = expr!(mul!(int!(1048576), access!(instruction_bits[20], Type::Felt))) - => let_!("%lc21" = expr!(mul!(int!(2097152), access!(instruction_bits[21], Type::Felt))) - => let_!("%lc22" = expr!(mul!(int!(4194304), access!(instruction_bits[22], Type::Felt))) - => let_!("%lc23" = expr!(mul!(int!(8388608), access!(instruction_bits[23], Type::Felt))) - => let_!("%lc24" = expr!(mul!(int!(16777216), access!(instruction_bits[24], Type::Felt))) - => let_!("%lc25" = expr!(mul!(int!(33554432), access!(instruction_bits[25], Type::Felt))) - => let_!("%lc26" = expr!(mul!(int!(67108864), access!(instruction_bits[26], Type::Felt))) - => let_!("%lc27" = expr!(mul!(int!(134217728), access!(instruction_bits[27], Type::Felt))) - => let_!("%lc28" = expr!(mul!(int!(268435456), 
access!(instruction_bits[28], Type::Felt))) - => let_!("%lc29" = expr!(mul!(int!(536870912), access!(instruction_bits[29], Type::Felt))) - => let_!("%lc30" = expr!(mul!(int!(1073741824), access!(instruction_bits[30], Type::Felt))) - => let_!("%lc31" = expr!(mul!(int!(2147483648), access!(instruction_bits[31], Type::Felt))) - => word_sum_body)))))))))))))))))))))))))))))))), + let low_bit_sum_value_expr = Expr::try_from(low_bit_sum_value_expr).unwrap(); + let high_bit_sum_value_expr = (11..32).rfold( + Statement::Expr(Expr::try_from(high_bit_sum).unwrap()), + |acc, i| { + let literal = 2u64.pow(i); + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: Default::default(), + name: ResolvableIdentifier::Local(instruction_bits), + access_type: AccessType::Index(31), + offset: 0, + ty: Some(Type::Felt), + }); + Statement::Let(Let { + span: SourceSpan::default(), + name: Identifier::new( + SourceSpan::default(), + crate::Symbol::intern(format!("%lc{}", i + 21)), + ), + value: expr!(mul!(access, scalar!(literal))), + body: vec![acc], + }) + }, ); + let high_bit_sum_value_expr = Expr::try_from(high_bit_sum_value_expr).unwrap(); + let word_sum_value_expr = (0..32).rfold( + Statement::Expr(Expr::try_from(word_sum).unwrap()), + |acc, i| { + let literal = 2u64.pow(i as u32); + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: Default::default(), + name: ResolvableIdentifier::Local(instruction_bits), + access_type: AccessType::Index(i), + offset: 0, + ty: Some(Type::Felt), + }); + Statement::Let(Let { + span: SourceSpan::default(), + name: Identifier::new( + SourceSpan::default(), + crate::Symbol::intern(format!("%lc{}", i)), + ), + value: expr!(mul!(scalar!(literal), access)), + body: vec![acc], + }) + }, + ); + let word_sum_value_expr = Expr::try_from(word_sum_value_expr).unwrap(); + let word_sum_body = let_!(word_sum = word_sum_value_expr + => enforce!(eq!(access!(instruction_word, Type::Felt), access!(word_sum, Type::Felt))), + let_!(high_bit_sum = 
high_bit_sum_value_expr + => let_!(low_bit_sum = low_bit_sum_value_expr + => enforce!(eq!(access!(immediate, Type::Felt), add!(access!(low_bit_sum, Type::Felt), access!(high_bit_sum, Type::Felt))), when access!(s, Type::Felt)), + enforce!(eq!(access!(instruction_bits[31], Type::Felt), int!(1)), when access!(s, Type::Felt)) + )), + enforce!(eq!(access!(immediate, Type::Felt), int!(0)), when not!(access!(s, Type::Felt))) + ); + + expected.integrity_constraints.push(word_sum_body); + // The evaluator definition is never modified by constant propagation or inlining let body = vec![ let_!(sign_bit = expr!(access!(instruction_bits[31], Type::Felt)) diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index 2d88454b..d570c011 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -275,6 +275,18 @@ macro_rules! expr { }; } +macro_rules! scalar { + ($expr:expr) => { + ScalarExpr::try_from($expr).unwrap() + }; +} + +macro_rules! statement { + ($expr:expr) => { + Statement::try_from($expr).unwrap() + }; +} + macro_rules! slice { ($name:ident, $range:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { diff --git a/parser/src/sema/scope.rs b/parser/src/sema/scope.rs index 7a98ecc5..2390ba24 100644 --- a/parser/src/sema/scope.rs +++ b/parser/src/sema/scope.rs @@ -3,6 +3,7 @@ use std::{ collections::HashMap, hash::Hash, ops::{Index, IndexMut}, + rc::Rc, }; /// A simple type alias for a boxed `HashMap` to aid in readability of the code below @@ -21,7 +22,6 @@ pub type Env = Box>; /// When searching for keys, the search begins in the current scope, and searches upwards /// in the scope tree until either the root is reached and the search terminates, or the /// key is found in some intervening scope. 
-#[derive(Clone)] pub enum LexicalScope { /// An empty scope, this is the default state in which all [LexicalScope] start Empty, @@ -29,7 +29,20 @@ pub enum LexicalScope { Root(Env), /// Represents a (possibly empty) nested scope, as a tuple of the parent scope and /// the environment of the current scope. - Nested(Box>, Env), + Nested(Rc>, Env), +} +impl Clone for LexicalScope +where + K: Clone, + V: Clone, +{ + fn clone(&self) -> Self { + match self { + Self::Empty => Self::Empty, + Self::Root(scope) => Self::Root(scope.clone()), + Self::Nested(parent, scope) => Self::Nested(Rc::clone(parent), scope.clone()), + } + } } impl Default for LexicalScope { fn default() -> Self { @@ -45,23 +58,26 @@ impl LexicalScope { Self::Nested(parent, env) => env.is_empty() && parent.is_empty(), } } +} +impl LexicalScope +where + K: Clone, + V: Clone, +{ + // The methods in this impl require `K: Clone` and `V: Clone`, since `exit` may clone a shared parent scope /// Enters a new, nested lexical scope pub fn enter(&mut self) { - let moved = Box::new(core::mem::take(self)); + let moved = Rc::new(core::mem::take(self)); *self = Self::Nested(moved, Env::default()); } /// Exits the current lexical scope pub fn exit(&mut self) { - match self { - Self::Empty => (), - Self::Root(_env) => { - *self = Self::Empty; - } - Self::Nested(ref mut parent, _) => { - let moved = core::mem::take(parent.as_mut()); - *self = moved; + match core::mem::replace(self, Self::Empty) { + Self::Empty | Self::Root(_) => (), + Self::Nested(parent, _) => { + *self = Rc::unwrap_or_clone(parent); } } } @@ -108,9 +124,9 @@ where match self { Self::Empty => None, Self::Root(ref mut env) => env.get_mut(key), - Self::Nested(ref mut parent, ref mut env) => { - env.get_mut(key).or_else(|| parent.get_mut(key)) - } + Self::Nested(ref mut parent, ref mut env) => env + .get_mut(key) + .or_else(|| Rc::get_mut(parent).and_then(|p| p.get_mut(key))), } } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 40de33cb..0788a03b 100644 ---
a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -1562,6 +1562,14 @@ impl<'a> SemanticAnalysis<'a> { } } } + Expr::Let(ref expr) => { + self.diagnostics + .diagnostic(Severity::Bug) + .with_message("invalid expression") + .with_primary_label(expr.span(), "let expressions are not valid here") + .emit(); + Err(InvalidAccessError::InvalidBinding) + } } } diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index 6f77c315..cc08998e 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -4,6 +4,7 @@ use std::{ }; use air_pass::Pass; +use either::Either::{self, Left, Right}; use miden_diagnostics::{DiagnosticsHandler, Span, Spanned}; use crate::{ @@ -93,6 +94,79 @@ impl<'a> ConstantPropagation<'a> { // If both operands are constant, fold try_fold_binary_expr(expr).map_err(SemanticAnalysisError::InvalidExpr) } + + /// When folding a `let`, one of the following can occur: + /// + /// * The let-bound variable is non-constant, so the entire let must remain, but we + /// can constant-propagate as much of the bound expression and body as possible. + /// * The let-bound variable is constant, so once we have constant propagated the body, + /// the let is no longer needed, and one of the following happens: + /// * The `let` terminates with a constant expression, so the entire `let` is replaced + /// with that expression. + /// * The `let` terminates with a non-constant expression, or a constraint, so we inline + /// the let body into the containing block. In the non-constant expression case, we + /// replace the `let` with the last expression in the returned block, since in expression + /// position, we may not have a statement block to inline into. 
+ fn try_fold_let_expr( + &mut self, + expr: &mut Let, + ) -> Result>, Vec>, SemanticAnalysisError> { + // Visit the binding expression first + if let ControlFlow::Break(err) = self.visit_mut_expr(&mut expr.value) { + return Err(err); + } + + // Enter a new lexical scope + let prev_live = core::mem::take(&mut self.live); + self.local.enter(); + // If the value is constant, record it in our bindings map + let is_constant = expr.value.is_constant(); + if is_constant { + match expr.value { + Expr::Const(ref value) => { + self.local.insert(expr.name, value.clone()); + } + Expr::Range(ref range) => { + let vector = range.item.clone().map(|i| i as u64).collect(); + self.local.insert( + expr.name, + Span::new(range.span(), ConstantExpr::Vector(vector)), + ); + } + _ => unreachable!(), + } + } + + // Visit the let body + if let ControlFlow::Break(err) = self.visit_mut_statement_block(&mut expr.body) { + return Err(err); + } + + // If this let is constant, then the binding is no longer + // used in the body after constant propagation, so we can + // fold away the let entirely + let is_live = self.live.contains(&expr.name); + let result = if is_constant && !is_live { + match expr.body.last().unwrap() { + Statement::Expr(Expr::Const(ref const_value)) => { + Left(Some(Span::new(expr.span(), const_value.item.clone()))) + } + _ => Right(core::mem::take(&mut expr.body)), + } + } else { + Left(None) + }; + + // Propagate liveness from the body of the let to its parent scope + let mut live = core::mem::take(&mut self.live); + live.remove(&expr.name); + self.live = &prev_live | &live; + + // Restore the previous scope + self.local.exit(); + + Ok(result) + } } impl<'a> VisitMut for ConstantPropagation<'a> { /// Fold constant expressions @@ -166,6 +240,47 @@ impl<'a> VisitMut for ConstantPropagation<'a> { ScalarExpr::Call(ref mut call) => self.visit_mut_call(call), // This cannot be constant folded ScalarExpr::BoundedSymbolAccess(_) => ControlFlow::Continue(()), + // A let that 
evaluates to a constant value can be folded to the constant value + ScalarExpr::Let(ref mut let_expr) => { + match self.try_fold_let_expr(let_expr) { + Ok(Left(Some(const_expr))) => { + let span = const_expr.span(); + match const_expr.item { + ConstantExpr::Scalar(value) => { + *expr = ScalarExpr::Const(Span::new(span, value)); + } + _ => { + self.diagnostics.diagnostic(miden_diagnostics::Severity::Error) + .with_message("invalid scalar expression") + .with_primary_label(span, "expected scalar value, but this expression evaluates to an aggregate type") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + } + } + } + Ok(Left(None)) => (), + Ok(Right(mut block)) => match block.pop().unwrap() { + Statement::Let(inner_expr) => { + *let_expr.as_mut() = inner_expr; + } + Statement::Expr(inner_expr) => { + match ScalarExpr::try_from(inner_expr) + .map_err(SemanticAnalysisError::InvalidExpr) + { + Ok(scalar_expr) => { + *expr = scalar_expr; + } + Err(err) => return ControlFlow::Break(err), + } + } + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) => unreachable!(), + }, + Err(err) => return ControlFlow::Break(err), + } + ControlFlow::Continue(()) + } } } @@ -458,6 +573,27 @@ impl<'a> VisitMut for ConstantPropagation<'a> { *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(folded))); ControlFlow::Continue(()) } + Expr::Let(ref mut let_expr) => { + match self.try_fold_let_expr(let_expr) { + Ok(Left(Some(const_expr))) => { + *expr = Expr::Const(Span::new(span, const_expr.item)); + } + Ok(Left(None)) => (), + Ok(Right(mut block)) => match block.pop().unwrap() { + Statement::Let(inner_expr) => { + *let_expr.as_mut() = inner_expr; + } + Statement::Expr(inner_expr) => { + *expr = inner_expr; + } + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) => unreachable!(), + }, + Err(err) => return ControlFlow::Break(err), + } + ControlFlow::Continue(()) + } } } @@ -479,47 +615,16 @@ impl<'a> 
VisitMut for ConstantPropagation<'a> { num_statements - 1, "let is not in tail position of block" ); - // Visit the binding expression first - self.visit_mut_expr(&mut expr.value)?; - // Enter a new lexical scope - let prev_live = core::mem::take(&mut self.live); - self.local.enter(); - // If the value is constant, record it in our bindings map - let is_constant = expr.value.is_constant(); - if is_constant { - match expr.value { - Expr::Const(ref value) => { - self.local.insert(expr.name, value.clone()); - } - Expr::Range(ref range) => { - let vector = range.item.clone().map(|i| i as u64).collect(); - self.local.insert( - expr.name, - Span::new(range.span(), ConstantExpr::Vector(vector)), - ); - } - _ => unreachable!(), + match self.try_fold_let_expr(expr) { + Ok(Left(Some(const_expr))) => { + buffer.push(Statement::Expr(Expr::Const(const_expr))); } + Ok(Left(None)) => (), + Ok(Right(mut block)) => { + buffer.append(&mut block); + } + Err(err) => return ControlFlow::Break(err), } - - // Visit the let body - self.visit_mut_statement_block(&mut expr.body)?; - - // If this let is constant, then the binding is no longer - // used in the body after constant propagation, flatten its - // body into the current block. 
- let is_live = self.live.contains(&expr.name); - if is_constant && !is_live { - buffer.append(&mut expr.body); - } - - // Propagate liveness from the body of the let to its parent scope - let mut live = core::mem::take(&mut self.live); - live.remove(&expr.name); - self.live = &prev_live | &live; - - // Restore the previous scope - self.local.exit(); } Statement::Enforce(ref mut expr) => { self.visit_mut_enforce(expr)?; diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index 4ee69400..5f793b3a 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -317,12 +317,47 @@ impl<'a> Inlining<'a> { } Ok(statements) } - // Expression statements are introduced during inlining, and are always already expanded, - // but they are recursively visited to apply rewrites - Statement::Expr(mut expr) => { - self.rewrite_expr(&mut expr)?; - Ok(vec![Statement::Expr(expr)]) + // Expressions containing function calls require expansion via inlining, otherwise + // all other expression types are introduced during inlining and are thus already expanded, + // but we must still visit them to apply rewrites. + Statement::Expr(expr) => match self.expand_expr(expr)?
{ + Expr::Let(let_expr) => Ok(vec![Statement::Let(*let_expr)]), + expr => Ok(vec![Statement::Expr(expr)]), + }, + } + } + + fn expand_expr(&mut self, expr: Expr) -> Result { + match expr { + Expr::Vector(mut elements) => { + let elems = Vec::with_capacity(elements.len()); + for elem in core::mem::replace(&mut elements.item, elems) { + elements.push(self.expand_expr(elem)?); + } + Ok(Expr::Vector(elements)) + } + Expr::Matrix(mut rows) => { + for row in rows.iter_mut() { + let cols = Vec::with_capacity(row.len()); + for col in core::mem::replace(row, cols) { + row.push(self.expand_scalar_expr(col)?); + } + } + Ok(Expr::Matrix(rows)) + } + Expr::Binary(expr) => self.expand_binary_expr(expr), + Expr::Call(expr) => self.expand_call(expr), + Expr::ListComprehension(expr) => { + let mut block = self.expand_comprehension(expr)?; + assert_eq!(block.len(), 1); + Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr) + } + Expr::Let(expr) => { + let mut block = self.expand_let(*expr)?; + assert_eq!(block.len(), 1); + Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr) } + expr @ (Expr::Const(_) | Expr::Range(_) | Expr::SymbolAccess(_)) => Ok(expr), } } @@ -346,102 +381,46 @@ impl<'a> Inlining<'a> { /// by replacing it with the result of expanding its body fn expand_let(&mut self, expr: Let) -> Result, SemanticAnalysisError> { let span = expr.span(); - let mut name = expr.name; + let name = expr.name; let body = expr.body; - // When expanding a `let` that was inlined at a function callsite, we must ensure that any - // let-bound variables introduced do not shadow bindings for the remaining statements in - // the body of the block at which the function is inlined. 
For example, consider the following: - // - // fn foo(a: felt) -> felt { - // let b = a * a - // b - // } - // - // integrity_constraints { - // let b = col[0] - // enf foo(b) = 1 - // enf b = 0 - // } - // - // If the call to `foo` is naively inlined, we will end up with: - // - // integrity_constraints { - // let b = col[0] - // let b = b * b - // enf b = 1 - // enf b = 0 - // } - // - // As you can see, this has the effect of breaking the last constraint, by changing the - // definition bound to `b` at that point in the program. - // - // To solve this, we check if we are currently expanding a `let` being inlined as part of - // a function call, and if so, we generate new variable names that will replace the originals. - if !self.call_stack.is_empty() { - name = self.get_next_ident(span); - let binding_ty = self - .expr_binding_type(&expr.value) - .expect("unexpected undefined variable"); - self.rewrites.insert(name); - self.bindings.insert(name, binding_ty); - } - // Visit the let-bound expression first, since it determines how the rest of the process goes - let mut statements = match expr.value { + let value = match expr.value { // When expanding a call in this context, we're expecting a single // statement of either `Expr` or `Let` type, as calls to pure functions // can never contain constraints. - // - // In the case where a `Let` is produced, we'll sink the current - // let to the end of its body, so that it appears that the current - // let came after the expansion point. Expr::Call(call) => self.expand_call(call)?, // Same as above, but for list comprehensions. // // The rules for expansion are the same. 
- Expr::ListComprehension(lc) => self.expand_comprehension(lc)?, + Expr::ListComprehension(lc) => { + let mut expanded = self.expand_comprehension(lc)?; + match expanded.pop().unwrap() { + Statement::Let(let_expr) => Expr::Let(Box::new(let_expr)), + Statement::Expr(expr) => expr, + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) => unreachable!(), + } + } // The operands of a binary expression can contain function calls, so we must ensure // that we expand the operands as needed, and then proceed with expanding the let. Expr::Binary(expr) => self.expand_binary_expr(expr)?, // Other expressions we visit just to expand rewrites - mut value => { - self.rewrite_expr(&mut value)?; - vec![Statement::Expr(value)] + mut expr => { + self.rewrite_expr(&mut expr)?; + expr } }; - // Since the let-bound expression may have expanded to a nested `let` tree, - // ultimately terminating a value expression of some kind, it is necessary to - // push down the current let to the bottom of that tree, replacing the value - // of the current let with whatever expression is at the bottom of the tree - // wrapped in a `Statement::Expr`. We then visit the tree top-down normally, - // knowing that all of the let-bound expressions are simple values. - // - // In short, we perform two visits of the tree - once to nest the current let - // at the bottom of the tree, and a second time to perform expansion/inlining/rewrites - // on the tree. 
- with_let_result(self, &mut statements, move |_, value| { - // Steal the result value, and replace it with a dummy - // - // The dummy expression will be replaced with the let we're constructing - // when this function returns - let value = - core::mem::replace(value, Expr::Const(Span::new(span, ConstantExpr::Scalar(0)))); - Ok(Some(Statement::Let(Let::new(span, name, value, body)))) - })?; - - // The last statement in the current block of statements _must_ be a `let` here - match statements.pop().unwrap() { - Statement::Let(current_let) => { - // This is where we visit the tree to perform any final transformations on it - let mut expanded = self.expand_let_tree(current_let)?; - // Whatever that expanded to gets appended to the current block, which is returned to the caller - statements.append(&mut expanded); - Ok(statements) - } - ref invalid => panic!("expected let, got {:#?}", invalid), - } + let expr = Let { + span, + name, + value, + body, + }; + + self.expand_let_tree(expr) } /// This is only expected to be called on a let tree which is guaranteed to only have @@ -484,7 +463,7 @@ impl<'a> Inlining<'a> { } /// Expand a call to a pure function (including builtin list folding functions) - fn expand_call(&mut self, mut call: Call) -> Result, SemanticAnalysisError> { + fn expand_call(&mut self, mut call: Call) -> Result { if call.is_builtin() { match call.callee.as_ref().name() { symbols::Sum => { @@ -502,127 +481,45 @@ impl<'a> Inlining<'a> { } } - fn maybe_expand_scalar_expr( + fn expand_scalar_expr( &mut self, - mut expr: Box, - ) -> Result, Box>, SemanticAnalysisError> { - match *expr { + expr: ScalarExpr, + ) -> Result { + match expr { ScalarExpr::Binary(expr) if expr.has_block_like_expansion() => { - self.expand_binary_expr(expr).map(Ok) + self.expand_binary_expr(expr).and_then(|expr| { + ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr) + }) } - ScalarExpr::Call(lhs) => self.expand_call(lhs).map(Ok), - _ => { + ScalarExpr::Call(lhs) => 
self.expand_call(lhs).and_then(|expr| { + ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr) + }), + mut expr => { self.rewrite_scalar_expr(&mut expr)?; - Ok(Err(expr)) + Ok(expr) } } } - fn expand_binary_expr( - &mut self, - expr: BinaryExpr, - ) -> Result, SemanticAnalysisError> { + fn expand_binary_expr(&mut self, expr: BinaryExpr) -> Result { let span = expr.span(); let op = expr.op; - let lhs = self.maybe_expand_scalar_expr(expr.lhs)?; - let rhs = self.maybe_expand_scalar_expr(expr.rhs)?; - - match (lhs, rhs) { - (Err(lhs), Err(rhs)) => Ok(vec![Statement::Expr(Expr::Binary(BinaryExpr { - span, - op, - lhs, - rhs, - }))]), - (Err(lhs), Ok(mut rhs)) => { - with_let_result(self, &mut rhs, |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - - Ok(Some(Statement::Expr(Expr::Binary(BinaryExpr { - span, - op, - lhs, - rhs: Box::new(value.try_into()?), - })))) - })?; - - Ok(rhs) - } - (Ok(mut lhs), Err(rhs)) => { - with_let_result(self, &mut lhs, |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - - Ok(Some(Statement::Expr(Expr::Binary(BinaryExpr { - span, - op, - lhs: Box::new(value.try_into()?), - rhs, - })))) - })?; - - Ok(lhs) - } - (Ok(mut lhs), Ok(mut rhs)) => { - let name = self.get_next_ident(span); - let ty = match lhs - .last() - .expect("unexpected empty expansion for scalar expression") - { - Statement::Expr(ref expr) => expr.ty(), - Statement::Let(ref expr) => expr.ty(), - _ => unreachable!(), - }; - - with_let_result(self, &mut rhs, |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - - Ok(Some(Statement::Expr(Expr::Binary(BinaryExpr::new( - span, - op, - ScalarExpr::SymbolAccess(SymbolAccess { - span, - name: ResolvableIdentifier::Local(name), - access_type: AccessType::Default, - offset: 0, - ty, - }), - value.try_into()?, - ))))) - 
})?; - - with_let_result(self, &mut lhs, move |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - Ok(Some(Statement::Let(Let::new(span, name, value, rhs)))) - })?; - Ok(lhs) - } - } + let lhs = self.expand_scalar_expr(*expr.lhs)?; + let rhs = self.expand_scalar_expr(*expr.rhs)?; + + Ok(Expr::Binary(BinaryExpr { + span, + op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + })) } /// Expand a list folding operation (e.g. sum/prod) over an expression of aggregate type into an equivalent expression tree - fn expand_fold( - &mut self, - op: BinaryOp, - mut list: Expr, - ) -> Result, SemanticAnalysisError> { + fn expand_fold(&mut self, op: BinaryOp, mut list: Expr) -> Result { let span = list.span(); match list { - Expr::Vector(ref mut elems) => { - let folded = self.expand_vector_fold(span, op, elems)?; - Ok(vec![Statement::Expr(folded)]) - } + Expr::Vector(ref mut elems) => self.expand_vector_fold(span, op, elems), Expr::ListComprehension(lc) => { // Expand the comprehension, but ensure we don't treat it like a comprehension constraint let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, false); @@ -641,7 +538,13 @@ impl<'a> Inlining<'a> { _ => unreachable!(), } })?; - Ok(expanded) + match expanded.pop().unwrap() { + Statement::Expr(expr) => Ok(expr), + Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))), + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) => unreachable!(), + } } Expr::SymbolAccess(ref access) => { match self.let_bound.get(access.name.as_ref()).cloned() { @@ -654,8 +557,7 @@ impl<'a> Inlining<'a> { access.access(AccessType::Index(i)).unwrap(), )); } - let folded = self.expand_vector_fold(span, op, &mut vector)?; - Ok(vec![Statement::Expr(folded)]) + self.expand_vector_fold(span, op, &mut vector) } Ok(_) | Err(_) => unimplemented!(), }, @@ -703,86 +605,15 @@ impl<'a> Inlining<'a> { rhs, span, }) => { - let lhs = 
self.maybe_expand_scalar_expr(lhs)?; - let rhs = self.maybe_expand_scalar_expr(rhs)?; - - match (lhs, rhs) { - (Err(lhs), Err(rhs)) => { - Ok(vec![Statement::Enforce(ScalarExpr::Binary(BinaryExpr { - span, - op: BinaryOp::Eq, - lhs, - rhs, - }))]) - } - (Err(lhs), Ok(mut rhs)) => { - with_let_result(self, &mut rhs, |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - - Ok(Some(Statement::Enforce(ScalarExpr::Binary(BinaryExpr { - span, - op: BinaryOp::Eq, - lhs, - rhs: Box::new(value.try_into()?), - })))) - })?; - - Ok(rhs) - } - (Ok(mut lhs), Err(rhs)) => { - with_let_result(self, &mut lhs, |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); + let lhs = self.expand_scalar_expr(*lhs)?; + let rhs = self.expand_scalar_expr(*rhs)?; - Ok(Some(Statement::Enforce(ScalarExpr::Binary(BinaryExpr { - span, - op: BinaryOp::Eq, - lhs: Box::new(value.try_into()?), - rhs, - })))) - })?; - - Ok(lhs) - } - (Ok(mut lhs), Ok(mut rhs)) => { - let name = self.get_next_ident(span); - - with_let_result(self, &mut rhs, |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - - Ok(Some(Statement::Enforce(ScalarExpr::Binary( - BinaryExpr::new( - span, - BinaryOp::Eq, - ScalarExpr::SymbolAccess(SymbolAccess::new( - span, - name, - AccessType::Default, - 0, - )), - value.try_into()?, - ), - )))) - })?; - - with_let_result(self, &mut lhs, move |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - Ok(Some(Statement::Let(Let::new(span, name, value, rhs)))) - })?; - Ok(lhs) - } - } + Ok(vec![Statement::Enforce(ScalarExpr::Binary(BinaryExpr { + span, + op: BinaryOp::Eq, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }))]) } invalid => unreachable!("unexpected constraint node: {:#?}", invalid), } @@ -824,6 +655,23 @@ impl<'a> Inlining<'a> { 
self.rewrite_expr(expr)?; } } + Expr::Let(ref mut let_expr) => { + let mut next = Some(let_expr.as_mut()); + while let Some(next_let) = next.take() { + self.rewrite_expr(&mut next_let.value)?; + match next_let.body.last_mut().unwrap() { + Statement::Let(ref mut inner) => { + next = Some(inner); + } + Statement::Expr(ref mut expr) => { + self.rewrite_expr(expr)?; + } + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) => unreachable!(), + } + } + } } Ok(()) } @@ -863,6 +711,24 @@ impl<'a> Inlining<'a> { } Ok(()) } + ScalarExpr::Let(ref mut let_expr) => { + let mut next = Some(let_expr.as_mut()); + while let Some(next_let) = next.take() { + self.rewrite_expr(&mut next_let.value)?; + match next_let.body.last_mut().unwrap() { + Statement::Let(ref mut inner) => { + next = Some(inner); + } + Statement::Expr(ref mut expr) => { + self.rewrite_expr(expr)?; + } + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) => unreachable!(), + } + } + Ok(()) + } } } @@ -967,40 +833,7 @@ impl<'a> Inlining<'a> { // result of the comprehension itself. let span = expr.span(); if self.in_comprehension_constraint { - let mut result = vec![]; - - // If for some reason, we've been able to eliminate all constraints for this comprehension, - // then return an empty block, since we need not emit any code at all if unused. - if statement_groups.is_empty() { - return Ok(result); - } - - // Each group is presumed to already be flattened/treeified - for mut group in statement_groups { - // The first group is simply flattened - if result.is_empty() { - result.append(&mut group); - continue; - } - // Ensure that all statements preceded by a let-bound variable, are nested in the - // body of that let. 
- match group.pop().unwrap() { - Statement::Let(mut let_expr) => { - with_let_innermost_block(self, &mut let_expr.body, |_, block| { - block.append(&mut result); - - Ok(()) - })?; - result.append(&mut group); - result.push(Statement::Let(let_expr)); - } - stmt => { - result.append(&mut group); - result.push(stmt); - } - } - } - Ok(result) + Ok(statement_groups.into_iter().flatten().collect()) } else { // For list comprehensions, we must emit a let tree that binds each iteration, // and ensure that the expansion of the iteration itself is properly nested so @@ -1071,12 +904,8 @@ impl<'a> Inlining<'a> { // comprehension at the bottom of that tree. lifted.into_iter().try_rfold(expanded, |acc, (name, call)| { let span = call.span(); - let mut preamble = self.expand_call(call)?; - match preamble.pop().unwrap() { - Statement::Expr(expr) => { - preamble.push(Statement::Let(Let::new(span, name, expr, acc))); - } - Statement::Let(mut wrapper) => { + match self.expand_call(call)? { + Expr::Let(mut wrapper) => { with_let_result(self, &mut wrapper.body, move |_, value| { let value = core::mem::replace( value, @@ -1084,11 +913,10 @@ impl<'a> Inlining<'a> { ); Ok(Some(Statement::Let(Let::new(span, name, value, acc)))) })?; - preamble.push(Statement::Let(wrapper)); + Ok(vec![Statement::Let(*wrapper)]) } - _ => unreachable!(), + expr => Ok(vec![Statement::Let(Let::new(span, name, expr, acc))]), } - Ok(preamble) }) } } @@ -1170,7 +998,9 @@ impl<'a> Inlining<'a> { // Binary expressions are scalar, so cannot be used as iterables, and we don't // (currently) support nested comprehensions, so it is never possible to observe // these expression types here. Calls should have been lifted prior to expansion. 
- Expr::Call(_) | Expr::Binary(_) | Expr::ListComprehension(_) => unreachable!(), + Expr::Call(_) | Expr::Binary(_) | Expr::ListComprehension(_) | Expr::Let(_) => { + unreachable!() + } }; bound_values.insert(binding, abstract_value); } @@ -1340,11 +1170,8 @@ impl<'a> Inlining<'a> { Ok(evaluator.body) } - /// This function handles inlining pure function calls. - fn expand_function_callsite( - &mut self, - call: Call, - ) -> Result, SemanticAnalysisError> { + /// This function handles inlining pure function calls, which must produce an expression + fn expand_function_callsite(&mut self, call: Call) -> Result { self.bindings.enter(); // The callee is guaranteed to be resolved and exist at this point let callee = call @@ -1469,7 +1296,13 @@ impl<'a> Inlining<'a> { // We're done expanding this call, so remove it from the call stack self.call_stack.pop(); - Ok(function.body) + match function.body.pop().unwrap() { + Statement::Expr(expr) => Ok(expr), + Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))), + Statement::Enforce(_) | Statement::EnforceIf(_, _) | Statement::EnforceAll(_) => { + panic!("unexpected constraint in function body") + } + } } /// Populate the set of access rewrites, as well as the initial set of bindings to use when inlining an evaluator function. 
@@ -1728,56 +1561,93 @@ impl<'a> Inlining<'a> { } } - /// Returns the effective [BindingType] of the given expression fn expr_binding_type(&self, expr: &Expr) -> Result { - match expr { - Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), - Expr::Range(range) => Ok(BindingType::Local(Type::Vector(range.end - range.start))), - Expr::Vector(ref elems) => match elems[0].ty() { - None | Some(Type::Felt) => { - let mut binding_tys = Vec::with_capacity(elems.len()); - for elem in elems.iter() { - binding_tys.push(self.expr_binding_type(elem)?); - } - Ok(BindingType::Vector(binding_tys)) - } - Some(Type::Vector(cols)) => { - let rows = elems.len(); - Ok(BindingType::Local(Type::Matrix(rows, cols))) + let mut bindings = self.bindings.clone(); + eval_expr_binding_type(expr, &mut bindings, &self.imported) + } + + /// Returns the effective [BindingType] of the value produced by the given access + fn access_binding_type(&self, expr: &SymbolAccess) -> Result { + eval_access_binding_type(expr, &self.bindings, &self.imported) + } +} + +/// Returns the effective [BindingType] of the given expression +fn eval_expr_binding_type( + expr: &Expr, + bindings: &mut LexicalScope, + imported: &HashMap, +) -> Result { + match expr { + Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), + Expr::Range(range) => Ok(BindingType::Local(Type::Vector(range.end - range.start))), + Expr::Vector(ref elems) => match elems[0].ty() { + None | Some(Type::Felt) => { + let mut binding_tys = Vec::with_capacity(elems.len()); + for elem in elems.iter() { + binding_tys.push(eval_expr_binding_type(elem, bindings, imported)?); } - Some(_) => unreachable!(), - }, - Expr::Matrix(expr) => { - let rows = expr.len(); - let columns = expr[0].len(); - Ok(BindingType::Local(Type::Matrix(rows, columns))) + Ok(BindingType::Vector(binding_tys)) } - Expr::SymbolAccess(ref access) => self.access_binding_type(access), - Expr::Call(Call { ty: None, .. 
}) => Err(InvalidAccessError::InvalidBinding), - Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), - Expr::Binary(_) => Ok(BindingType::Local(Type::Felt)), - Expr::ListComprehension(ref lc) => { - // The types of all iterables must be the same, so the type of - // the comprehension is given by the type of the iterables. We - // just pick the first iterable to tell us the type - self.expr_binding_type(&lc.iterables[0]) + Some(Type::Vector(cols)) => { + let rows = elems.len(); + Ok(BindingType::Local(Type::Matrix(rows, cols))) } + Some(_) => unreachable!(), + }, + Expr::Matrix(expr) => { + let rows = expr.len(); + let columns = expr[0].len(); + Ok(BindingType::Local(Type::Matrix(rows, columns))) } + Expr::SymbolAccess(ref access) => eval_access_binding_type(access, bindings, imported), + Expr::Call(Call { ty: None, .. }) => Err(InvalidAccessError::InvalidBinding), + Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), + Expr::Binary(_) => Ok(BindingType::Local(Type::Felt)), + Expr::ListComprehension(ref lc) => { + // The types of all iterables must be the same, so the type of + // the comprehension is given by the type of the iterables. 
We + // just pick the first iterable to tell us the type + eval_expr_binding_type(&lc.iterables[0], bindings, imported) + } + Expr::Let(ref let_expr) => eval_let_binding_ty(let_expr, bindings, imported), } +} - /// Returns the effective [BindingType] of the value produced by the given access - fn access_binding_type(&self, expr: &SymbolAccess) -> Result { - let binding_ty = self - .bindings - .get(expr.name.as_ref()) - .or_else(|| match expr.name { - ResolvableIdentifier::Resolved(qid) => self.imported.get(&qid), - _ => None, - }) - .ok_or(InvalidAccessError::UndefinedVariable) - .clone()?; - binding_ty.access(expr.access_type.clone()) - } +/// Returns the effective [BindingType] of the value produced by the given access +fn eval_access_binding_type( + expr: &SymbolAccess, + bindings: &LexicalScope, + imported: &HashMap, +) -> Result { + let binding_ty = bindings + .get(expr.name.as_ref()) + .or_else(|| match expr.name { + ResolvableIdentifier::Resolved(qid) => imported.get(&qid), + _ => None, + }) + .ok_or(InvalidAccessError::UndefinedVariable) + .clone()?; + binding_ty.access(expr.access_type.clone()) +} + +fn eval_let_binding_ty( + let_expr: &Let, + bindings: &mut LexicalScope, + imported: &HashMap, +) -> Result { + let variable_ty = eval_expr_binding_type(&let_expr.value, bindings, imported)?; + bindings.enter(); + bindings.insert(let_expr.name, variable_ty); + let binding_ty = match let_expr.body.last().unwrap() { + Statement::Let(ref inner_let) => eval_let_binding_ty(inner_let, bindings, imported)?, + Statement::Expr(ref expr) => eval_expr_binding_type(expr, bindings, imported)?, + Statement::Enforce(_) | Statement::EnforceIf(_, _) | Statement::EnforceAll(_) => { + unreachable!() + } + }; + bindings.exit(); + Ok(binding_ty) } /// This visitor is used to rewrite uses of iterable bindings within a comprehension body, @@ -1870,8 +1740,11 @@ impl<'a> RewriteIterableBindingsVisitor<'a> { Some(ScalarExpr::SymbolAccess(new_access)) } // These types of expressions 
will never be observed in this context, as they are - // not valid iterable elements. - Some(Expr::Call(_) | Expr::Binary(_) | Expr::ListComprehension(_)) => unreachable!(), + // not valid iterable expressions (except calls, but those are lifted prior to rewrite + // so that their use in this context is always a symbol access). + Some(Expr::Call(_) | Expr::Binary(_) | Expr::ListComprehension(_) | Expr::Let(_)) => { + unreachable!() + } None => None, }; ControlFlow::Continue(result) @@ -1923,6 +1796,10 @@ impl<'a> VisitMut for RewriteIterableBindingsVisitor<'a> } ControlFlow::Continue(()) } + // We rewrite comprehension bodies before they are expanded, so it should never be + // the case that we encounter a let here, as they can only be introduced in scalar + // expression position as a result of inlining/expansion + ScalarExpr::Let(_) => unreachable!(), } } } @@ -2051,60 +1928,3 @@ where Ok(()) } - -fn with_let_innermost_block( - inliner: &mut Inlining, - entry: &mut Vec, - callback: F, -) -> Result<(), SemanticAnalysisError> -where - F: FnOnce(&mut Inlining, &mut Vec) -> Result<(), SemanticAnalysisError>, -{ - // Preserve the original lexical scope to be restored on exit - let prev = inliner.bindings.clone(); - - // SAFETY: We must use a raw pointer here because the Rust compiler is not able to - // see that we only ever use the mutable reference once, and that the reference - // is never aliased. - // - // Both of these guarantees are in fact upheld here however, as each iteration of the loop - // is either the last iteration (when we use the mutable reference to mutate the end of the - // bottom-most block), or a traversal to the last child of the current let expression. - // We never alias the mutable reference, and in fact immediately convert back to a mutable - // reference inside the loop to ensure that within the loop body we have some degree of - // compiler-assisted checking of that invariant. 
- let mut current_block = Some(entry as *mut Vec); - while let Some(parent_block) = current_block.take() { - // SAFETY: We convert the pointer back to a mutable reference here before - // we do anything else to ensure the usual aliasing rules are enforced. - // - // It is further guaranteed that this reference is never improperly aliased - // across iterations, as each iteration is visiting a child of the previous - // iteration's node, i.e. what we're doing here is equivalent to holding a - // mutable reference and using it to mutate a field in a deeply nested struct. - let parent_block = unsafe { &mut *parent_block }; - // A block is guaranteed to always have at least one statement here - if let Some(Statement::Let(ref mut let_expr)) = parent_block.last_mut() { - // Register this binding - let binding_ty = inliner.expr_binding_type(&let_expr.value).unwrap(); - inliner.bindings.insert(let_expr.name, binding_ty); - // Set up the next iteration - current_block = Some(&mut let_expr.body as *mut Vec); - continue; - } - // When we hit a block whose last statement is an expression, which - // must also be the bottom-most block of this tree. 
- match callback(inliner, parent_block) { - Ok(_) => break, - Err(err) => { - inliner.bindings = prev; - return Err(err); - } - } - } - - // Restore the original lexical scope - inliner.bindings = prev; - - Ok(()) -} From 8423d0d9f9dcac99e3b3c17f56e652cc5bc8c266 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 21 Jun 2024 13:29:45 -0400 Subject: [PATCH 3/6] ci: update to 1.78, use explicit toolchain --- .github/workflows/ci.yml | 2 +- Cargo.toml | 16 ++++++++++------ air-script/Cargo.toml | 6 +++--- codegen/masm/Cargo.toml | 8 +++++--- codegen/winterfell/Cargo.toml | 4 ++-- ir/Cargo.toml | 4 ++-- parser/Cargo.toml | 4 ++-- rust-toolchain.toml | 5 +++++ 8 files changed, 30 insertions(+), 19 deletions(-) create mode 100644 rust-toolchain.toml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4924f7ce..6e6834a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: branches: - main pull_request: - types: [opened, repoened, synchronize] + types: [opened, reopened, synchronize] jobs: test: diff --git a/Cargo.toml b/Cargo.toml index 40815df4..2e94e473 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,14 @@ [workspace] members = [ - "air-script", - "parser", - "pass", - "ir", - "codegen/masm", - "codegen/winterfell", + "air-script", + "parser", + "pass", + "ir", + "codegen/masm", + "codegen/winterfell", ] resolver = "2" + +[workspace.package] +edition = "2021" +rust-version = "1.78" diff --git a/air-script/Cargo.toml b/air-script/Cargo.toml index 1c706924..84856227 100644 --- a/air-script/Cargo.toml +++ b/air-script/Cargo.toml @@ -9,8 +9,8 @@ repository = "https://github.com/0xPolygonMiden/air-script" documentation = "https://0xpolygonmiden.github.io/air-script/" categories = ["compilers", "cryptography"] keywords = ["air", "stark", "zero-knowledge", "zkp"] -edition = "2021" -rust-version = "1.67" +edition.workspace = true +rust-version.workspace = true [[bin]] name = "airc" @@ -22,7 +22,7 @@ air-parser 
= { package = "air-parser", path = "../parser", version = "0.4" } air-pass = { package = "air-pass", path = "../pass", version = "0.1" } air-codegen-masm = { package = "air-codegen-masm", path = "../codegen/masm", version = "0.4" } air-codegen-winter = { package = "air-codegen-winter", path = "../codegen/winterfell", version = "0.4" } -clap = {version = "4.2", features = ["derive"] } +clap = { version = "4.2", features = ["derive"] } env_logger = "0.10" log = { version = "0.4", default-features = false } miden-diagnostics = "0.1" diff --git a/codegen/masm/Cargo.toml b/codegen/masm/Cargo.toml index 78712b97..896deea6 100644 --- a/codegen/masm/Cargo.toml +++ b/codegen/masm/Cargo.toml @@ -8,8 +8,8 @@ license = "MIT" repository = "https://github.com/0xPolygonMiden/air-script" categories = ["compilers", "cryptography"] keywords = ["air", "stark", "winterfell", "zero-knowledge", "zkp"] -edition = "2021" -rust-version = "1.67" +edition.workspace = true +rust-version.workspace = true [dependencies] air-ir = { package = "air-ir", path = "../../ir", version = "0.4" } @@ -22,6 +22,8 @@ winter-math = { package = "winter-math", version = "0.6", default-features = fal air-parser = { path = "../../parser" } air-pass = { path = "../../pass" } miden-assembly = { package = "miden-assembly", version = "0.6", default-features = false } -miden-processor = { package = "miden-processor", version = "0.6", features = ["internals"], default-features = false } +miden-processor = { package = "miden-processor", version = "0.6", features = [ + "internals", +], default-features = false } miden-diagnostics = "0.1" winter-air = { package = "winter-air", version = "0.6", default-features = false } diff --git a/codegen/winterfell/Cargo.toml b/codegen/winterfell/Cargo.toml index 6b734107..63cdcadf 100644 --- a/codegen/winterfell/Cargo.toml +++ b/codegen/winterfell/Cargo.toml @@ -8,8 +8,8 @@ license = "MIT" repository = "https://github.com/0xPolygonMiden/air-script" categories = ["compilers", 
"cryptography"] keywords = ["air", "stark", "winterfell", "zero-knowledge", "zkp"] -edition = "2021" -rust-version = "1.67" +edition.workspace = true +rust-version.workspace = true [dependencies] air-ir = { package = "air-ir", path = "../../ir", version = "0.4" } diff --git a/ir/Cargo.toml b/ir/Cargo.toml index be68ec6b..27901fd3 100644 --- a/ir/Cargo.toml +++ b/ir/Cargo.toml @@ -8,8 +8,8 @@ license = "MIT" repository = "https://github.com/0xPolygonMiden/air-script" categories = ["compilers", "cryptography"] keywords = ["air", "stark", "zero-knowledge", "zkp"] -edition = "2021" -rust-version = "1.67" +rust-version.workspace = true +edition.workspace = true [dependencies] air-parser = { package = "air-parser", path = "../parser", version = "0.4" } diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 3a3d08f5..2b971bd3 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -8,8 +8,8 @@ license = "MIT" repository = "https://github.com/0xPolygonMiden/air-script" categories = ["compilers", "cryptography", "parser-implementations"] keywords = ["air", "stark", "zero-knowledge", "zkp"] -edition = "2021" -rust-version = "1.67" +rust-version.workspace = true +edition.workspace = true [build-dependencies] lalrpop = { version = "0.20", default-features = false } diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..f2b38786 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,5 @@ +[toolchain] +channel = "1.78" +components = ["rustfmt", "rust-src", "clippy"] +targets = ["wasm32-unknown-unknown"] +profile = "minimal" From 7dcb4da589d5eed693d64f28e7255fb5ac7f5aef Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 21 Jun 2024 13:30:14 -0400 Subject: [PATCH 4/6] fix: a couple clippy warnings --- parser/src/sema/semantic_analysis.rs | 2 +- parser/src/transforms/inlining.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 
0788a03b..37861305 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -1386,7 +1386,7 @@ impl<'a> SemanticAnalysis<'a> { // and we will have already validated the reference let (import_id, module_id) = self.imported.get_key_value(&id).unwrap(); let module = self.library.get(module_id).unwrap(); - if module.evaluators.get(&id.id()).is_none() { + if !module.evaluators.contains_key(&id.id()) { self.invalid_constraint(id.span(), "calls in constraints must be to evaluator functions") .with_secondary_label(import_id.span(), "the function imported here is not an evaluator") .emit(); diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index 5f793b3a..871ee6a6 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -115,11 +115,11 @@ impl<'p> Pass for Inlining<'p> { .collect(); // We'll be referencing the trace configuration during inlining, so keep a copy of it - self.trace = program.trace_columns.clone(); + self.trace.clone_from(&program.trace_columns); // Same with the random values - self.random_values = program.random_values.clone(); + self.random_values.clone_from(&program.random_values); // And the public inputs - self.public_inputs = program.public_inputs.clone(); + self.public_inputs.clone_from(&program.public_inputs); // Add all of the local bindings visible in the root module, except for // constants and periodic columns, which by this point have been rewritten From 7f12780aaba74bcae27aa3e31caead6a4b3d25fc Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 9 Aug 2024 02:19:59 -0400 Subject: [PATCH 5/6] chore: remove leftover todo --- parser/src/parser/grammar.lalrpop | 1 - 1 file changed, 1 deletion(-) diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index c54077b1..0e2dbf19 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -280,7 +280,6 @@ FunctionBindingType: Type = { } 
FunctionBody: Vec = { - // TODO: validate =>? { if stmts.len() > 1 { diagnostics.diagnostic(Severity::Error) From 8a9bece511832ff803d92ca896a5bbd8933686a7 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 9 Aug 2024 02:29:59 -0400 Subject: [PATCH 6/6] chore: fix new clippy warnings in 1.80 --- parser/src/ast/expression.rs | 6 +++--- parser/src/ast/module.rs | 4 ++-- parser/src/transforms/inlining.rs | 24 ++++++++++++------------ 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index aa28d0f1..9f16c013 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -472,7 +472,7 @@ pub enum ScalarExpr { /// /// 1. The call is the top-level expression of a constraint, and is to an evaluator function /// 2. The call is not the top-level expression of a constraint, and is to a pure function - /// that produces a scalar value type. + /// that produces a scalar value type. /// /// If neither of the above are true, the call is invalid in a `ScalarExpr` context Call(Call), @@ -1138,8 +1138,8 @@ pub struct Call { /// /// * Calls to evaluators produce no value, and thus have no type /// * When parsed, the callee has not yet been resolved, so we don't know the - /// type of the function being called. During semantic analysis, the callee is - /// resolved and this field is set to the result type of that function. + /// type of the function being called. During semantic analysis, the callee is + /// resolved and this field is set to the result type of that function. 
pub ty: Option, } impl Call { diff --git a/parser/src/ast/module.rs b/parser/src/ast/module.rs index c9363ca9..2d306b36 100644 --- a/parser/src/ast/module.rs +++ b/parser/src/ast/module.rs @@ -35,8 +35,8 @@ pub enum ModuleType { /// /// * Fields which are only allowed in root modules are empty/unset in library modules /// * Fields which must be present in root modules are guaranteed to be present in a root module -/// * It is guaranteed that at least one boundary constraint and one integrity constraint are present -/// in a root module +/// * It is guaranteed that at least one boundary constraint and one integrity constraint are +/// present in a root module /// * No duplicate module-level declarations were present /// * All globally-visible declarations are unique /// diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index 871ee6a6..f4221517 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -364,21 +364,21 @@ impl<'a> Inlining<'a> { /// Let expressions are expanded using the following rules: /// /// * The let-bound expression is expanded first. If it expands to a statement block and - /// not an expression, the block is inlined in place of the let being expanded, and the - /// rest of the expansion takes place at the end of the block; replacing the last statement - /// in the block. If the last statement in the block was an expression, it is treated as - /// the let-bound value. If the last statement in the block was another `let` however, then - /// we recursively walk down the let tree until we reach the bottom, which must always be - /// an expression statement. + /// not an expression, the block is inlined in place of the let being expanded, and the + /// rest of the expansion takes place at the end of the block; replacing the last statement + /// in the block. If the last statement in the block was an expression, it is treated as + /// the let-bound value. 
If the last statement in the block was another `let` however, then + /// we recursively walk down the let tree until we reach the bottom, which must always be + /// an expression statement. /// /// * The body is expanded in-place after the previous step has been completed. /// /// * If a let-bound variable is an alias for a declaration, we replace all uses - /// of the variable with direct references to the declaration, making the let-bound variable - /// dead + /// of the variable with direct references to the declaration, making the let-bound + /// variable dead /// /// * If a let-bound variable is dead (i.e. has no references), then the let is elided, - /// by replacing it with the result of expanding its body + /// by replacing it with the result of expanding its body fn expand_let(&mut self, expr: Let) -> Result, SemanticAnalysisError> { let span = expr.span(); let name = expr.name; @@ -744,9 +744,9 @@ impl<'a> Inlining<'a> { /// the expansion is, respectively: /// /// * A tree of let statements (using generated variables), where each let binds the value of a - /// single iteration of the comprehension. The body of the final let, and thus the effective value - /// of the entire tree, is a vector containing all of the bindings in the evaluation order of the - /// comprehension. + /// single iteration of the comprehension. The body of the final let, and thus the effective + /// value of the entire tree, is a vector containing all of the bindings in the evaluation + /// order of the comprehension. /// * A flat list of constraint statements fn expand_comprehension( &mut self,