Skip to content

Commit

Permalink
Merge pull request #485 from powdr-labs/multi-return
Browse files Browse the repository at this point in the history
Implement Multi-return
  • Loading branch information
Leo authored Sep 29, 2023
2 parents 62fb60d + 8ec6bff commit 3e615a6
Show file tree
Hide file tree
Showing 15 changed files with 302 additions and 99 deletions.
96 changes: 54 additions & 42 deletions analysis/src/vm/inference.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Infer assignment registers in asm statements
use ast::asm_analysis::{
AnalysisASMFile, AssignmentStatement, Expression, FunctionStatement, Machine,
use ast::{
asm_analysis::{AnalysisASMFile, Expression, FunctionStatement, Machine},
parsed::asm::AssignmentRegister,
};
use number::FieldElement;

Expand Down Expand Up @@ -33,47 +34,47 @@ fn infer_machine<T: FieldElement>(mut machine: Machine<T>) -> Result<Machine<T>,
for f in machine.callable.functions_mut() {
for s in f.body.statements.iter_mut() {
if let FunctionStatement::Assignment(a) = s {
let expr_reg = match &*a.rhs {
// Map function calls to the list of assignment registers and all other expressions to a list of None.
let expr_regs = match &*a.rhs {
Expression::FunctionCall(c) => {
let def = machine
.instructions
.iter()
.find(|i| i.name == c.id)
.unwrap();
let output = {
let outputs = def.instruction.params.outputs.as_ref().unwrap();
assert!(outputs.params.len() == 1);
&outputs.params[0]
};
assert!(output.ty.is_none());
Some(output.name.clone())

let outputs = def.instruction.params.outputs.clone().unwrap_or_default();

outputs
.params
.iter()
.map(|o| {
assert!(o.ty.is_none());
AssignmentRegister::Register(o.name.clone())
})
.collect::<Vec<_>>()
}
_ => None,
_ => vec![AssignmentRegister::Wildcard; a.lhs_with_reg.len()],
};

match (&mut a.using_reg, expr_reg) {
(Some(using_reg), Some(expr_reg)) if *using_reg != expr_reg => {
errors.push(format!("Assignment register `{}` is incompatible with `{}`. Try replacing `<={}=` by `<==`.", using_reg, a.rhs, using_reg));
}
(Some(_), _) => {}
(None, Some(expr_reg)) => {
// infer the assignment register to that of the rhs
a.using_reg = Some(expr_reg);
}
(None, None) => {
let hint = AssignmentStatement {
using_reg: Some(
machine
.registers
.iter()
.find(|r| r.ty.is_assignment())
.unwrap()
.name
.clone(),
),
..a.clone()
};
errors.push(format!("Impossible to infer the assignment register for `{a}`. Try using an assignment register like `{hint}`."));
assert_eq!(expr_regs.len(), a.lhs_with_reg.len());

for ((w, reg), expr_reg) in a.lhs_with_reg.iter_mut().zip(expr_regs) {
match (&reg, expr_reg) {
(
AssignmentRegister::Register(using_reg),
AssignmentRegister::Register(expr_reg),
) if *using_reg != expr_reg => {
errors.push(format!("Assignment register `{}` is incompatible with `{}`. Try using `<==` with no explicit assignment registers.", using_reg, a.rhs));
}
(AssignmentRegister::Register(_), _) => {}
(AssignmentRegister::Wildcard, AssignmentRegister::Register(expr_reg)) => {
// infer the assignment register to that of the rhs
*reg = AssignmentRegister::Register(expr_reg);
}
(AssignmentRegister::Wildcard, AssignmentRegister::Wildcard) => {
errors.push(format!("Impossible to infer the assignment register to write to register `{w}`"));
}
}
}
}
Expand Down Expand Up @@ -115,8 +116,8 @@ mod tests {

let file = infer_str::<Bn254Field>(file).unwrap();

if let FunctionStatement::Assignment(AssignmentStatement { using_reg, .. }) = file.machines
[&parse_absolute_path("Machine")]
if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = file
.machines[&parse_absolute_path("Machine")]
.functions()
.next()
.unwrap()
Expand All @@ -126,7 +127,10 @@ mod tests {
.next()
.unwrap()
{
assert_eq!(*using_reg, Some("X".to_string()));
assert_eq!(
lhs_with_reg[0].1,
AssignmentRegister::Register("X".to_string())
);
} else {
panic!()
};
Expand All @@ -151,8 +155,8 @@ mod tests {

let file = infer_str::<Bn254Field>(file).unwrap();

if let FunctionStatement::Assignment(AssignmentStatement { using_reg, .. }) = &file.machines
[&parse_absolute_path("Machine")]
if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = &file
.machines[&parse_absolute_path("Machine")]
.functions()
.next()
.unwrap()
Expand All @@ -162,7 +166,10 @@ mod tests {
.next()
.unwrap()
{
assert_eq!(*using_reg, Some("X".to_string()));
assert_eq!(
lhs_with_reg[0].1,
AssignmentRegister::Register("X".to_string())
);
} else {
panic!()
};
Expand All @@ -185,7 +192,7 @@ mod tests {
}
"#;

assert_eq!(infer_str::<Bn254Field>(file).unwrap_err(), vec!["Assignment register `Y` is incompatible with `foo()`. Try replacing `<=Y=` by `<==`."]);
assert_eq!(infer_str::<Bn254Field>(file).unwrap_err(), vec!["Assignment register `Y` is incompatible with `foo()`. Try using `<==` with no explicit assignment registers."]);
}

#[test]
Expand All @@ -203,6 +210,11 @@ mod tests {
}
"#;

assert_eq!(infer_str::<Bn254Field>(file).unwrap_err(), vec!["Impossible to infer the assignment register for `A <== 1;`. Try using an assignment register like `A <=X= 1;`.".to_string()]);
assert_eq!(
infer_str::<Bn254Field>(file).unwrap_err(),
vec![
"Impossible to infer the assignment register to write to register `A`".to_string()
]
);
}
}
12 changes: 8 additions & 4 deletions asm_to_pil/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub mod utils {
InstructionStatement, LabelStatement, RegisterDeclarationStatement, RegisterTy,
},
parsed::{
asm::{InstructionBody, MachineStatement, RegisterFlag},
asm::{AssignmentRegister, InstructionBody, MachineStatement, RegisterFlag},
PilStatement,
},
};
Expand Down Expand Up @@ -76,11 +76,15 @@ pub mod utils {
.parse::<T>(input)
.unwrap()
{
ast::parsed::asm::FunctionStatement::Assignment(start, lhs, using_reg, rhs) => {
ast::parsed::asm::FunctionStatement::Assignment(start, lhs, reg, rhs) => {
AssignmentStatement {
start,
lhs,
using_reg,
lhs_with_reg: {
let lhs_len = lhs.len();
lhs.into_iter()
.zip(reg.unwrap_or(vec![AssignmentRegister::Wildcard; lhs_len]))
.collect()
},
rhs,
}
.into()
Expand Down
60 changes: 33 additions & 27 deletions asm_to_pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,22 @@ impl<T: FieldElement> ASMPILConverter<T> {
match statement {
FunctionStatement::Assignment(AssignmentStatement {
start,
lhs,
using_reg,
lhs_with_reg,
rhs,
}) => match *rhs {
Expression::FunctionCall(c) => {
self.handle_functional_instruction(lhs, using_reg.unwrap(), c.id, c.arguments)
}) => {
let lhs_with_reg = lhs_with_reg
.into_iter()
// All assignment registers should be inferred at this point.
.map(|(lhs, reg)| (lhs, reg.unwrap()))
.collect();

match *rhs {
Expression::FunctionCall(c) => {
self.handle_functional_instruction(lhs_with_reg, c.id, c.arguments)
}
_ => self.handle_non_functional_assignment(start, lhs_with_reg, *rhs),
}
_ => self.handle_assignment(start, lhs, using_reg, *rhs),
},
}
FunctionStatement::Instruction(InstructionStatement {
instruction,
inputs,
Expand Down Expand Up @@ -436,48 +443,47 @@ impl<T: FieldElement> ASMPILConverter<T> {
res
}

fn handle_assignment(
fn handle_non_functional_assignment(
&mut self,
_start: usize,
write_regs: Vec<String>,
assign_reg: Option<String>,
lhs_with_reg: Vec<(String, String)>,
value: Expression<T>,
) -> CodeLine<T> {
assert!(write_regs.len() <= 1);
assert!(
assign_reg.is_some(),
"Implicit assign register not yet supported."
lhs_with_reg.len() == 1,
"Multi assignments are only implemented for function calls."
);
let assign_reg = assign_reg.unwrap();
let (write_regs, assign_reg) = lhs_with_reg.into_iter().next().unwrap();
let value = self.process_assignment_value(value);
CodeLine {
write_regs: [(assign_reg.clone(), write_regs)].into_iter().collect(),
write_regs: [(assign_reg.clone(), vec![write_regs])]
.into_iter()
.collect(),
value: [(assign_reg, value)].into(),
..Default::default()
}
}

fn handle_functional_instruction(
&mut self,
write_regs: Vec<String>,
assign_reg: String,
lhs_with_regs: Vec<(String, String)>,
instr_name: String,
args: Vec<Expression<T>>,
mut args: Vec<Expression<T>>,
) -> CodeLine<T> {
assert!(write_regs.len() == 1);
let instr = &self
.instructions
.get(&instr_name)
.unwrap_or_else(|| panic!("Instruction not found: {instr_name}"));
assert_eq!(instr.outputs.len(), 1);
let output = instr.outputs[0].clone();
assert!(
output == assign_reg,
"The instruction {instr_name} uses the assignment register {output}, but the caller uses {assign_reg} to further process the value.",
);
let output = instr.outputs.clone();

for (o, (_, r)) in output.iter().zip(lhs_with_regs.iter()) {
assert!(
o == r,
"The instruction {instr_name} uses the output register {o}, but the caller uses {r} to further process the value.",
);
}

let mut args = args;
args.push(direct_reference(write_regs.first().unwrap().clone()));
args.extend(lhs_with_regs.iter().map(|(lhs, _)| direct_reference(lhs)));
self.handle_instruction(instr_name, args)
}

Expand Down
7 changes: 2 additions & 5 deletions ast/src/asm_analysis/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,8 @@ impl<T: Display> Display for AssignmentStatement<T> {
write!(
f,
"{} <={}= {};",
self.lhs.join(", "),
self.using_reg
.as_ref()
.map(ToString::to_string)
.unwrap_or_default(),
self.lhs().format(", "),
self.assignment_registers().format(", "),
self.rhs
)
}
Expand Down
17 changes: 14 additions & 3 deletions ast/src/asm_analysis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use num_bigint::BigUint;
use number::FieldElement;

use crate::parsed::{
asm::{AbsoluteSymbolPath, CallableRef, InstructionBody, OperationId, Params},
asm::{
AbsoluteSymbolPath, AssignmentRegister, CallableRef, InstructionBody, OperationId, Params,
},
PilStatement,
};

Expand Down Expand Up @@ -570,11 +572,20 @@ impl<T> From<Return<T>> for FunctionStatement<T> {
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct AssignmentStatement<T> {
pub start: usize,
pub lhs: Vec<String>,
pub using_reg: Option<String>,
pub lhs_with_reg: Vec<(String, AssignmentRegister)>,
pub rhs: Box<Expression<T>>,
}

impl<T> AssignmentStatement<T> {
fn lhs(&self) -> impl Iterator<Item = &String> {
self.lhs_with_reg.iter().map(|(lhs, _)| lhs)
}

fn assignment_registers(&self) -> impl Iterator<Item = &AssignmentRegister> {
self.lhs_with_reg.iter().map(|(_, reg)| reg)
}
}

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct InstructionStatement<T> {
pub start: usize,
Expand Down
22 changes: 21 additions & 1 deletion ast/src/parsed/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,29 @@ pub enum InstructionBody<T> {
CallableRef(CallableRef),
}

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum AssignmentRegister {
Register(String),
Wildcard,
}

impl AssignmentRegister {
pub fn unwrap(self) -> String {
match self {
AssignmentRegister::Register(r) => r,
AssignmentRegister::Wildcard => panic!("cannot unwrap wildcard"),
}
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum FunctionStatement<T> {
Assignment(usize, Vec<String>, Option<String>, Box<Expression<T>>),
Assignment(
usize,
Vec<String>,
Option<Vec<AssignmentRegister>>,
Box<Expression<T>>,
),
Instruction(usize, String, Vec<Expression<T>>),
Label(usize, String),
DebugDirective(usize, DebugDirective),
Expand Down
15 changes: 14 additions & 1 deletion ast/src/parsed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,19 @@ impl<T: Display> Display for OperationId<T> {
}
}

impl Display for AssignmentRegister {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(
f,
"{}",
match self {
Self::Register(r) => r.to_string(),
Self::Wildcard => "_".to_string(),
}
)
}
}

impl<T: Display> Display for FunctionStatement<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
Expand All @@ -186,7 +199,7 @@ impl<T: Display> Display for FunctionStatement<T> {
write_regs.join(", "),
assignment_reg
.as_ref()
.map(ToString::to_string)
.map(|s| s.iter().format(", ").to_string())
.unwrap_or_default(),
expression
),
Expand Down
Loading

0 comments on commit 3e615a6

Please sign in to comment.