From 8ec6bffe96651a6db895e154e5af11e87921d367 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Thu, 28 Sep 2023 16:04:51 +0000 Subject: [PATCH] Implement Multi-return --- analysis/src/vm/inference.rs | 96 +++++++++++-------- asm_to_pil/src/lib.rs | 12 ++- asm_to_pil/src/vm_to_constrained.rs | 60 ++++++------ ast/src/asm_analysis/display.rs | 7 +- ast/src/asm_analysis/mod.rs | 17 +++- ast/src/parsed/asm.rs | 22 ++++- ast/src/parsed/display.rs | 15 ++- book/src/asm/functions.md | 12 ++- compiler/tests/asm.rs | 25 +++++ parser/src/powdr.lalrpop | 15 ++- test_data/asm/book/function.asm | 19 ++-- test_data/asm/multi_return.asm | 31 ++++++ ...eturn_wrong_assignment_register_length.asm | 23 +++++ ...ulti_return_wrong_assignment_registers.asm | 23 +++++ type_check/src/lib.rs | 24 ++++- 15 files changed, 302 insertions(+), 99 deletions(-) create mode 100644 test_data/asm/multi_return.asm create mode 100644 test_data/asm/multi_return_wrong_assignment_register_length.asm create mode 100644 test_data/asm/multi_return_wrong_assignment_registers.asm diff --git a/analysis/src/vm/inference.rs b/analysis/src/vm/inference.rs index 0c7f9cab10..6343baea06 100644 --- a/analysis/src/vm/inference.rs +++ b/analysis/src/vm/inference.rs @@ -1,7 +1,8 @@ //! Infer assignment registers in asm statements -use ast::asm_analysis::{ - AnalysisASMFile, AssignmentStatement, Expression, FunctionStatement, Machine, +use ast::{ + asm_analysis::{AnalysisASMFile, Expression, FunctionStatement, Machine}, + parsed::asm::AssignmentRegister, }; use number::FieldElement; @@ -33,47 +34,47 @@ fn infer_machine(mut machine: Machine) -> Result, for f in machine.callable.functions_mut() { for s in f.body.statements.iter_mut() { if let FunctionStatement::Assignment(a) = s { - let expr_reg = match &*a.rhs { + // Map function calls to the list of assignment registers and all other expressions to a list of None. + let expr_regs = match &*a.rhs { Expression::FunctionCall(c) => { let def = machine .instructions .iter() .find(|i| i.name == c.id) .unwrap(); - let output = { - let outputs = def.instruction.params.outputs.as_ref().unwrap(); - assert!(outputs.params.len() == 1); - &outputs.params[0] - }; - assert!(output.ty.is_none()); - Some(output.name.clone()) + + let outputs = def.instruction.params.outputs.clone().unwrap_or_default(); + + outputs + .params + .iter() + .map(|o| { + assert!(o.ty.is_none()); + AssignmentRegister::Register(o.name.clone()) + }) + .collect::>() } - _ => None, + _ => vec![AssignmentRegister::Wildcard; a.lhs_with_reg.len()], }; - match (&mut a.using_reg, expr_reg) { - (Some(using_reg), Some(expr_reg)) if *using_reg != expr_reg => { - errors.push(format!("Assignment register `{}` is incompatible with `{}`. Try replacing `<={}=` by `<==`.", using_reg, a.rhs, using_reg)); - } - (Some(_), _) => {} - (None, Some(expr_reg)) => { - // infer the assignment register to that of the rhs - a.using_reg = Some(expr_reg); - } - (None, None) => { - let hint = AssignmentStatement { - using_reg: Some( - machine - .registers - .iter() - .find(|r| r.ty.is_assignment()) - .unwrap() - .name - .clone(), - ), - ..a.clone() - }; - errors.push(format!("Impossible to infer the assignment register for `{a}`. Try using an assignment register like `{hint}`.")); + assert_eq!(expr_regs.len(), a.lhs_with_reg.len()); + + for ((w, reg), expr_reg) in a.lhs_with_reg.iter_mut().zip(expr_regs) { + match (®, expr_reg) { + ( + AssignmentRegister::Register(using_reg), + AssignmentRegister::Register(expr_reg), + ) if *using_reg != expr_reg => { + errors.push(format!("Assignment register `{}` is incompatible with `{}`. Try using `<==` with no explicit assignment registers.", using_reg, a.rhs)); + } + (AssignmentRegister::Register(_), _) => {} + (AssignmentRegister::Wildcard, AssignmentRegister::Register(expr_reg)) => { + // infer the assignment register to that of the rhs + *reg = AssignmentRegister::Register(expr_reg); + } + (AssignmentRegister::Wildcard, AssignmentRegister::Wildcard) => { + errors.push(format!("Impossible to infer the assignment register to write to register `{w}`")); + } } } } @@ -115,8 +116,8 @@ mod tests { let file = infer_str::(file).unwrap(); - if let FunctionStatement::Assignment(AssignmentStatement { using_reg, .. }) = file.machines - [&parse_absolute_path("Machine")] + if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = file + .machines[&parse_absolute_path("Machine")] .functions() .next() .unwrap() @@ -126,7 +127,10 @@ mod tests { .next() .unwrap() { - assert_eq!(*using_reg, Some("X".to_string())); + assert_eq!( + lhs_with_reg[0].1, + AssignmentRegister::Register("X".to_string()) + ); } else { panic!() }; @@ -151,8 +155,8 @@ mod tests { let file = infer_str::(file).unwrap(); - if let FunctionStatement::Assignment(AssignmentStatement { using_reg, .. }) = &file.machines - [&parse_absolute_path("Machine")] + if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = &file + .machines[&parse_absolute_path("Machine")] .functions() .next() .unwrap() @@ -162,7 +166,10 @@ mod tests { .next() .unwrap() { - assert_eq!(*using_reg, Some("X".to_string())); + assert_eq!( + lhs_with_reg[0].1, + AssignmentRegister::Register("X".to_string()) + ); } else { panic!() }; @@ -185,7 +192,7 @@ mod tests { } "#; - assert_eq!(infer_str::(file).unwrap_err(), vec!["Assignment register `Y` is incompatible with `foo()`. Try replacing `<=Y=` by `<==`."]); + assert_eq!(infer_str::(file).unwrap_err(), vec!["Assignment register `Y` is incompatible with `foo()`. Try using `<==` with no explicit assignment registers."]); } #[test] @@ -203,6 +210,11 @@ mod tests { } "#; - assert_eq!(infer_str::(file).unwrap_err(), vec!["Impossible to infer the assignment register for `A <== 1;`. Try using an assignment register like `A <=X= 1;`.".to_string()]); + assert_eq!( + infer_str::(file).unwrap_err(), + vec![ + "Impossible to infer the assignment register to write to register `A`".to_string() + ] + ); } } diff --git a/asm_to_pil/src/lib.rs b/asm_to_pil/src/lib.rs index aaa762bb78..b3d76238aa 100644 --- a/asm_to_pil/src/lib.rs +++ b/asm_to_pil/src/lib.rs @@ -28,7 +28,7 @@ pub mod utils { InstructionStatement, LabelStatement, RegisterDeclarationStatement, RegisterTy, }, parsed::{ - asm::{InstructionBody, MachineStatement, RegisterFlag}, + asm::{AssignmentRegister, InstructionBody, MachineStatement, RegisterFlag}, PilStatement, }, }; @@ -76,11 +76,15 @@ pub mod utils { .parse::(input) .unwrap() { - ast::parsed::asm::FunctionStatement::Assignment(start, lhs, using_reg, rhs) => { + ast::parsed::asm::FunctionStatement::Assignment(start, lhs, reg, rhs) => { AssignmentStatement { start, - lhs, - using_reg, + lhs_with_reg: { + let lhs_len = lhs.len(); + lhs.into_iter() + .zip(reg.unwrap_or(vec![AssignmentRegister::Wildcard; lhs_len])) + .collect() + }, rhs, } .into() diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index aa8961565d..1c5e5b9e30 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -241,15 +241,22 @@ impl ASMPILConverter { match statement { FunctionStatement::Assignment(AssignmentStatement { start, - lhs, - using_reg, + lhs_with_reg, rhs, - }) => match *rhs { - Expression::FunctionCall(c) => { - self.handle_functional_instruction(lhs, using_reg.unwrap(), c.id, c.arguments) + }) => { + let lhs_with_reg = lhs_with_reg + .into_iter() + // All assignment registers should be inferred at this point. + .map(|(lhs, reg)| (lhs, reg.unwrap())) + .collect(); + + match *rhs { + Expression::FunctionCall(c) => { + self.handle_functional_instruction(lhs_with_reg, c.id, c.arguments) + } + _ => self.handle_non_functional_assignment(start, lhs_with_reg, *rhs), } - _ => self.handle_assignment(start, lhs, using_reg, *rhs), - }, + } FunctionStatement::Instruction(InstructionStatement { instruction, inputs, @@ -436,22 +443,22 @@ impl ASMPILConverter { res } - fn handle_assignment( + fn handle_non_functional_assignment( &mut self, _start: usize, - write_regs: Vec, - assign_reg: Option, + lhs_with_reg: Vec<(String, String)>, value: Expression, ) -> CodeLine { - assert!(write_regs.len() <= 1); assert!( - assign_reg.is_some(), - "Implicit assign register not yet supported." + lhs_with_reg.len() == 1, + "Multi assignments are only implemented for function calls." ); - let assign_reg = assign_reg.unwrap(); + let (write_regs, assign_reg) = lhs_with_reg.into_iter().next().unwrap(); let value = self.process_assignment_value(value); CodeLine { - write_regs: [(assign_reg.clone(), write_regs)].into_iter().collect(), + write_regs: [(assign_reg.clone(), vec![write_regs])] + .into_iter() + .collect(), value: [(assign_reg, value)].into(), ..Default::default() } @@ -459,25 +466,24 @@ impl ASMPILConverter { fn handle_functional_instruction( &mut self, - write_regs: Vec, - assign_reg: String, + lhs_with_regs: Vec<(String, String)>, instr_name: String, - args: Vec>, + mut args: Vec>, ) -> CodeLine { - assert!(write_regs.len() == 1); let instr = &self .instructions .get(&instr_name) .unwrap_or_else(|| panic!("Instruction not found: {instr_name}")); - assert_eq!(instr.outputs.len(), 1); - let output = instr.outputs[0].clone(); - assert!( - output == assign_reg, - "The instruction {instr_name} uses the assignment register {output}, but the caller uses {assign_reg} to further process the value.", - ); + let output = instr.outputs.clone(); + + for (o, (_, r)) in output.iter().zip(lhs_with_regs.iter()) { + assert!( + o == r, + "The instruction {instr_name} uses the output register {o}, but the caller uses {r} to further process the value.", + ); + } - let mut args = args; - args.push(direct_reference(write_regs.first().unwrap().clone())); + args.extend(lhs_with_regs.iter().map(|(lhs, _)| direct_reference(lhs))); self.handle_instruction(instr_name, args) } diff --git a/ast/src/asm_analysis/display.rs b/ast/src/asm_analysis/display.rs index 5b275adc16..2a5b8b525e 100644 --- a/ast/src/asm_analysis/display.rs +++ b/ast/src/asm_analysis/display.rs @@ -96,11 +96,8 @@ impl Display for AssignmentStatement { write!( f, "{} <={}= {};", - self.lhs.join(", "), - self.using_reg - .as_ref() - .map(ToString::to_string) - .unwrap_or_default(), + self.lhs().format(", "), + self.assignment_registers().format(", "), self.rhs ) } diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index 91607c087e..0ee69d83cb 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -14,7 +14,9 @@ use num_bigint::BigUint; use number::FieldElement; use crate::parsed::{ - asm::{AbsoluteSymbolPath, CallableRef, InstructionBody, OperationId, Params}, + asm::{ + AbsoluteSymbolPath, AssignmentRegister, CallableRef, InstructionBody, OperationId, Params, + }, PilStatement, }; @@ -570,11 +572,20 @@ impl From> for FunctionStatement { #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct AssignmentStatement { pub start: usize, - pub lhs: Vec, - pub using_reg: Option, + pub lhs_with_reg: Vec<(String, AssignmentRegister)>, pub rhs: Box>, } +impl AssignmentStatement { + fn lhs(&self) -> impl Iterator { + self.lhs_with_reg.iter().map(|(lhs, _)| lhs) + } + + fn assignment_registers(&self) -> impl Iterator { + self.lhs_with_reg.iter().map(|(_, reg)| reg) + } +} + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct InstructionStatement { pub start: usize, diff --git a/ast/src/parsed/asm.rs b/ast/src/parsed/asm.rs index 2909d04517..104d6a094b 100644 --- a/ast/src/parsed/asm.rs +++ b/ast/src/parsed/asm.rs @@ -292,9 +292,29 @@ pub enum InstructionBody { CallableRef(CallableRef), } +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum AssignmentRegister { + Register(String), + Wildcard, +} + +impl AssignmentRegister { + pub fn unwrap(self) -> String { + match self { + AssignmentRegister::Register(r) => r, + AssignmentRegister::Wildcard => panic!("cannot unwrap wildcard"), + } + } +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub enum FunctionStatement { - Assignment(usize, Vec, Option, Box>), + Assignment( + usize, + Vec, + Option>, + Box>, + ), Instruction(usize, String, Vec>), Label(usize, String), DebugDirective(usize, DebugDirective), diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index de1f38a34f..8e1f34e261 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -177,6 +177,19 @@ impl Display for OperationId { } } +impl Display for AssignmentRegister { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "{}", + match self { + Self::Register(r) => r.to_string(), + Self::Wildcard => "_".to_string(), + } + ) + } +} + impl Display for FunctionStatement { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { @@ -186,7 +199,7 @@ impl Display for FunctionStatement { write_regs.join(", "), assignment_reg .as_ref() - .map(ToString::to_string) + .map(|s| s.iter().format(", ").to_string()) .unwrap_or_default(), expression ), diff --git a/book/src/asm/functions.md b/book/src/asm/functions.md index 14152d6b22..0629788b13 100644 --- a/book/src/asm/functions.md +++ b/book/src/asm/functions.md @@ -24,16 +24,22 @@ Labels allow referring to a location in a function by name. ### Assignments -Assignments allow setting the value of a write register to the value of an [expression](#expressions) using an assignment register. +Assignments allow setting the values of some write registers to the values of some expressions [expression](#expressions) using assignment registers. + +``` +{{#include ../../../test_data/asm/book/function.asm:literals}} +``` + +If the right-hand side of the assignment is an instruction, assignment registers can be inferred and are optional: ``` {{#include ../../../test_data/asm/book/function.asm:instruction}} ``` -One important requirement is for the assignment register of the assignment to be compatible with that of the expression. This is especially relevant for instructions: the assignment register of the instruction output must match that of the assignment. In this example, we use `Y` in the assignment as the output of `square` is `Y`: +This will be inferred to be the same as `A, B <=Y, Z= square_and_double(A);` from the definition of the instruction: ``` -{{#include ../../../test_data/asm/book/function.asm:square}} +{{#include ../../../test_data/asm/book/function.asm:square_and_double}} ``` ### Instructions diff --git a/compiler/tests/asm.rs b/compiler/tests/asm.rs index 5db08b8bce..ed62a390ec 100644 --- a/compiler/tests/asm.rs +++ b/compiler/tests/asm.rs @@ -187,6 +187,31 @@ fn test_multi_assign() { gen_estark_proof(f, slice_to_vec(&i)); } +#[test] +fn test_multi_return() { + let f = "multi_return.asm"; + let i = []; + verify_asm::(f, slice_to_vec(&i)); + gen_halo2_proof(f, slice_to_vec(&i)); + gen_estark_proof(f, Default::default()); +} + +#[test] +#[should_panic = "called `Result::unwrap()` on an `Err` value: [\"Assignment register `Z` is incompatible with `square_and_double(3)`. Try using `<==` with no explicit assignment registers.\", \"Assignment register `Y` is incompatible with `square_and_double(3)`. Try using `<==` with no explicit assignment registers.\"]"] +fn test_multi_return_wrong_assignment_registers() { + let f = "multi_return_wrong_assignment_registers.asm"; + let i = []; + verify_asm::(f, slice_to_vec(&i)); +} + +#[test] +#[should_panic = "Result::unwrap()` on an `Err` value: [\"Mismatched number of registers for assignment A, B <=Y= square_and_double(3);\"]"] +fn test_multi_return_wrong_assignment_register_length() { + let f = "multi_return_wrong_assignment_register_length.asm"; + let i = []; + verify_asm::(f, slice_to_vec(&i)); +} + #[test] fn test_bit_access() { let f = "bit_access.asm"; diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index fb744c120d..ea44064872 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -304,8 +304,19 @@ IdentifierList: Vec = { => vec![] } -AssignOperator: Option = { - "<=" "=" +AssignOperator: Option> = { + "<==" => None, + "<=" "=" => Some(<>) +} + +AssignmentRegisterList: Vec = { + "," )*> => { list.push(end); list }, + => vec![] +} + +AssignmentRegister: AssignmentRegister = { + => AssignmentRegister::Register(<>), + "_" => AssignmentRegister::Wildcard, } ReturnStatement: FunctionStatement = { diff --git a/test_data/asm/book/function.asm b/test_data/asm/book/function.asm index 852903c4a4..66731b678c 100644 --- a/test_data/asm/book/function.asm +++ b/test_data/asm/book/function.asm @@ -7,8 +7,10 @@ machine Machine { reg pc[@pc]; reg X[<=]; reg Y[<=]; + reg Z[<=]; reg CNT; reg A; + reg B; // an instruction to assert that a number is zero instr assert_zero X { @@ -25,12 +27,13 @@ machine Machine { pc' = XIsZero * l + (1 - XIsZero) * (pc + 1) } - // an instruction to return the square of an input - // ANCHOR: square - instr square X -> Y { - Y = X * X + // an instruction to return the square of an input as well as its double + // ANCHOR: square_and_double + instr square_and_double X -> Y, Z { + Y = X * X, + Z = 2 * X } - // ANCHOR_END: square + // ANCHOR_END: square_and_double function main { // initialise `A` to 2 @@ -48,9 +51,9 @@ machine Machine { // ANCHOR: read_register CNT <=X= CNT - 1; // ANCHOR_END: read_register - // square `A` + // get the square and the double of `A` // ANCHOR: instruction - A <== square(A); + A, B <== square_and_double(A); // ANCHOR_END: instruction // jump back to `start` jmp start; @@ -59,6 +62,8 @@ machine Machine { // ANCHOR: instruction_statement assert_zero A - ((2**2)**2)**2; // ANCHOR_END: instruction_statement + // check that `B == ((2**2)**2)*2` + assert_zero B - ((2**2)**2)*2; return; } diff --git a/test_data/asm/multi_return.asm b/test_data/asm/multi_return.asm new file mode 100644 index 0000000000..cfe64c5842 --- /dev/null +++ b/test_data/asm/multi_return.asm @@ -0,0 +1,31 @@ +machine MultiAssign { + degree 16; + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg Z[<=]; + reg A; + reg B; + + instr assert_eq X, Y { X = Y } + + instr square_and_double X -> Y, Z { + Y = X * X, + Z = 2 * X + } + + function main { + + // Different ways of expressing the same thing... + A, B <== square_and_double(3); + A, B <=Y,Z= square_and_double(3); + A, B <=Y,_= square_and_double(3); + A, B <=_,Z= square_and_double(3); + + assert_eq A, 9; + assert_eq B, 6; + + return; + } +} \ No newline at end of file diff --git a/test_data/asm/multi_return_wrong_assignment_register_length.asm b/test_data/asm/multi_return_wrong_assignment_register_length.asm new file mode 100644 index 0000000000..20e67196d6 --- /dev/null +++ b/test_data/asm/multi_return_wrong_assignment_register_length.asm @@ -0,0 +1,23 @@ +machine MultiAssign { + degree 16; + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg Z[<=]; + reg A; + reg B; + + instr square_and_double X -> Y, Z { + Y = X * X, + Z = 2 * X + } + + function main { + + // Should be using assignment registers Y, Z + A, B <=Y= square_and_double(3); + + return; + } +} \ No newline at end of file diff --git a/test_data/asm/multi_return_wrong_assignment_registers.asm b/test_data/asm/multi_return_wrong_assignment_registers.asm new file mode 100644 index 0000000000..09c744c0cf --- /dev/null +++ b/test_data/asm/multi_return_wrong_assignment_registers.asm @@ -0,0 +1,23 @@ +machine MultiAssign { + degree 16; + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg Z[<=]; + reg A; + reg B; + + instr square_and_double X -> Y, Z { + Y = X * X, + Z = 2 * X + } + + function main { + + // Should be using assignment registers Y, Z + A, B <=Z,Y= square_and_double(3); + + return; + } +} \ No newline at end of file diff --git a/type_check/src/lib.rs b/type_check/src/lib.rs index 6e464a7a75..71eb766f02 100644 --- a/type_check/src/lib.rs +++ b/type_check/src/lib.rs @@ -11,8 +11,9 @@ use ast::{ parsed::{ self, asm::{ - self, ASMModule, ASMProgram, AbsoluteSymbolPath, FunctionStatement, InstructionBody, - LinkDeclaration, MachineStatement, ModuleStatement, RegisterFlag, SymbolDefinition, + self, ASMModule, ASMProgram, AbsoluteSymbolPath, AssignmentRegister, FunctionStatement, + InstructionBody, LinkDeclaration, MachineStatement, ModuleStatement, RegisterFlag, + SymbolDefinition, }, }, }; @@ -98,13 +99,28 @@ impl TypeChecker { MachineStatement::FunctionDeclaration(start, name, params, statements) => { let mut function_statements = vec![]; for s in statements { + let statement_string = s.to_string(); match s { FunctionStatement::Assignment(start, lhs, using_reg, rhs) => { + if let Some(using_reg) = &using_reg { + if using_reg.len() != lhs.len() { + errors.push(format!( + "Mismatched number of registers for assignment {}", + statement_string + )); + } + } + let using_reg = using_reg.unwrap_or_else(|| { + vec![AssignmentRegister::Wildcard; lhs.len()] + }); + let lhs_with_reg = lhs + .into_iter() + .zip(using_reg.into_iter()) + .collect::>(); function_statements.push( AssignmentStatement { start, - lhs, - using_reg, + lhs_with_reg, rhs, } .into(),