From 127dec572b44a9d98515f7316ba5e94c7b6f9359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gast=C3=B3n=20Zanitti?= Date: Mon, 9 Dec 2024 07:03:41 -0300 Subject: [PATCH] powdr-asmopt: remove unused submachines, instructions, registers (#2143) Solves #682 --- Cargo.toml | 2 + asmopt/Cargo.toml | 14 ++ asmopt/src/lib.rs | 276 +++++++++++++++++++++ asmopt/tests/optimizer.rs | 360 ++++++++++++++++++++++++++++ ast/src/asm_analysis/mod.rs | 15 +- pilopt/src/lib.rs | 2 +- pilopt/src/referenced_symbols.rs | 257 +++++++++++++++++++- pipeline/Cargo.toml | 1 + pipeline/src/pipeline.rs | 28 ++- test_data/asm/book/declarations.asm | 3 +- 10 files changed, 944 insertions(+), 14 deletions(-) create mode 100644 asmopt/Cargo.toml create mode 100644 asmopt/src/lib.rs create mode 100644 asmopt/tests/optimizer.rs diff --git a/Cargo.toml b/Cargo.toml index fa798ea267..bdaa7d2c41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "pilopt", "plonky3", "asm-to-pil", + "asmopt", "backend", "ast", "analysis", @@ -49,6 +50,7 @@ powdr-ast = { path = "./ast", version = "0.1.3" } powdr-asm-to-pil = { path = "./asm-to-pil", version = "0.1.3" } powdr-isa-utils = { path = "./isa-utils", version = "0.1.3" } powdr-analysis = { path = "./analysis", version = "0.1.3" } +powdr-asmopt = { path = "./asmopt", version = "0.1.3" } powdr-backend = { path = "./backend", version = "0.1.3" } powdr-backend-utils = { path = "./backend-utils", version = "0.1.3" } powdr-executor = { path = "./executor", version = "0.1.3" } diff --git a/asmopt/Cargo.toml b/asmopt/Cargo.toml new file mode 100644 index 0000000000..998c554f49 --- /dev/null +++ b/asmopt/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "powdr-asmopt" +version.workspace = true +edition.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +powdr-ast.workspace = true +powdr-analysis.workspace = true +powdr-importer.workspace = true +powdr-pilopt.workspace = true +powdr-parser.workspace = true \ No newline at end of file diff --git a/asmopt/src/lib.rs b/asmopt/src/lib.rs new file mode 100644 index 0000000000..4b3ea135a8 --- /dev/null +++ b/asmopt/src/lib.rs @@ -0,0 +1,276 @@ +use std::collections::{HashMap, HashSet}; +use std::iter::once; + +use powdr_ast::parsed::asm::parse_absolute_path; +use powdr_ast::{ + asm_analysis::{AnalysisASMFile, Machine}, + parsed::{asm::AbsoluteSymbolPath, NamespacedPolynomialReference}, +}; +use powdr_pilopt::referenced_symbols::ReferencedSymbols; + +type Expression = powdr_ast::asm_analysis::Expression; + +const MAIN_MACHINE_STR: &str = "::Main"; +const PC_REGISTER: &str = "pc"; + +pub fn optimize(mut analyzed_asm: AnalysisASMFile) -> AnalysisASMFile { + // Optimizations assume the existence of a Main machine as an entry point. + // If it doesn't exist, return the ASM as-is to prevent removing all machines, + // which would break some examples. + let main_machine_path = parse_absolute_path(MAIN_MACHINE_STR); + if analyzed_asm + .machines() + .all(|(path, _)| path != main_machine_path) + { + return analyzed_asm; + } + + asm_remove_unreferenced_machines(&mut analyzed_asm); + asm_remove_unused_machine_components(&mut analyzed_asm); + asm_remove_unreferenced_machines(&mut analyzed_asm); + + analyzed_asm +} + +/// Remove all machines that are not referenced in any other machine. +/// This function traverses the dependency graph starting from ::Main to identify all reachable machines. +fn asm_remove_unreferenced_machines(asm_file: &mut AnalysisASMFile) { + let deps = build_machine_dependencies(asm_file); + let all_machines = collect_all_dependent_machines(&deps, MAIN_MACHINE_STR) + .into_iter() + .collect::>(); + asm_file.modules.iter_mut().for_each(|(path, module)| { + let machines_in_module = machines_in_module(&all_machines, path); + module.retain_machines(machines_in_module); + }); +} + +/// Analyzes each machine and successively removes unnecessary components: +/// 1. Removes declarations of instructions that are never used. +/// 2. Removes instances of submachines that are never used, including those that became unused in the previous step. +/// 3. Removes unused registers. +fn asm_remove_unused_machine_components(asm_file: &mut AnalysisASMFile) { + for (_, machine) in asm_file.machines_mut() { + let submachine_to_decl: HashMap = machine + .submachines + .iter() + .map(|sub| (sub.name.clone(), sub.ty.to_string())) + .collect(); + + let symbols_in_callable: HashSet = machine_callable_body_symbols(machine).collect(); + + machine_remove_unused_instructions(machine, &symbols_in_callable); + machine_remove_unused_submachines(machine, &symbols_in_callable, &submachine_to_decl); + machine_remove_unused_registers(machine, &submachine_to_decl); + } +} + +fn machine_remove_unused_registers( + machine: &mut Machine, + submachine_to_decl: &HashMap, +) { + let used_symbols: HashSet<_> = once(PC_REGISTER.to_string()) + .chain(machine_callable_body_symbols(machine)) + .chain(machine_in_links(machine, submachine_to_decl)) + .chain(machine_instructions_symbols(machine)) + .chain(machine_links_symbols(machine)) + .collect(); + + machine + .registers + .retain(|reg| used_symbols.contains(®.name)); +} + +fn machine_remove_unused_submachines( + machine: &mut Machine, + symbols: &HashSet, + submachine_to_decl: &HashMap, +) { + let visited_submachines = machine + .instructions + .iter() + .filter(|ins| symbols.contains(&ins.name)) + .flat_map(|ins| { + ins.instruction + .links + .iter() + .filter_map(|link| submachine_to_decl.get(&link.link.instance)) + }) + .cloned(); + + let used_submachines: HashSet<_> = visited_submachines + .chain(machine_in_links(machine, submachine_to_decl)) + .chain(machine_in_args(machine, submachine_to_decl)) + .chain(symbols.iter().cloned()) + .collect(); + + machine + .submachines + .retain(|sub| used_submachines.contains(&sub.ty.to_string())); +} + +fn machine_remove_unused_instructions(machine: &mut Machine, symbols: &HashSet) { + machine + .instructions + .retain(|ins| symbols.contains(&ins.name)); +} + +/// Retrieves all machines defined within a specific module, relative to the given module path. +/// +/// This function filters the provided set of all machine paths to include only those machines +/// that are defined within the module specified by `path`. It then strips the module path prefix from each +/// machine path to return the machine names relative to that module. +fn machines_in_module( + all_machines: &HashSet, + path: &AbsoluteSymbolPath, +) -> HashSet { + let path_str = path.to_string(); + let path_prefix = if path_str == "::" { + "::".to_string() + } else { + format!("{}{}", path_str, "::") + }; + + all_machines + .iter() + .filter(|machine_path| machine_path.starts_with(&path_prefix)) + .map(|machine_path| { + machine_path + .strip_prefix(&path_prefix) + .unwrap_or(machine_path) + .to_string() + }) + .collect() +} + +/// Creates a mapping between machine names and sets of paths for their instantiated submachines. +fn build_machine_dependencies(asm_file: &AnalysisASMFile) -> HashMap> { + let mut dependencies = HashMap::new(); + + for (path, machine) in asm_file.machines() { + let submachine_to_decl: HashMap = machine + .submachines + .iter() + .map(|sub| (sub.name.clone(), sub.ty.to_string())) + .collect(); + + let submachine_names = dependencies_by_machine(machine, submachine_to_decl); + dependencies.insert(path.to_string(), submachine_names); + } + + dependencies +} + +/// This function analyzes a given `Machine` and gathers all the submachines it depends on. +/// Dependencies are collected from various components of the machine: +/// +/// 1. Instantiated Submachines: Submachines that are directly instantiated within the machine. +/// 2. Submachine Arguments: Submachines referenced in the arguments of the instantiated submachines. +/// 3. Parameters: Submachines specified in the machine's parameters. +/// 4. Links: Submachines that are used in links within the machine. +fn dependencies_by_machine( + machine: &Machine, + submachine_to_decl: HashMap, +) -> HashSet { + let submachine_names: HashSet = machine + .submachines + .iter() + .map(|sub| sub.ty.to_string()) + .chain(machine.submachines.iter().flat_map(|sub| { + sub.args.iter().filter_map(|expr| { + expr_to_ref(expr).and_then(|ref_name| submachine_to_decl.get(&ref_name).cloned()) + }) + })) + .chain( + machine + .params + .0 + .iter() + .map(|param| param.ty.as_ref().unwrap().to_string()), + ) + .chain( + machine + .links + .iter() + .filter_map(|ld| submachine_to_decl.get(&ld.to.instance)) + .cloned(), + ) + .collect(); + submachine_names +} + +fn expr_to_ref(expr: &Expression) -> Option { + match expr { + Expression::Reference(_, NamespacedPolynomialReference { path, .. }) => { + Some(path.to_string()) + } + Expression::PublicReference(_, pref) => Some(pref.clone()), + _ => None, + } +} + +fn collect_all_dependent_machines( + dependencies: &HashMap>, + start: &str, +) -> HashSet { + let mut result = HashSet::new(); + let mut to_visit = vec![start.to_string()]; + let mut visited = HashSet::new(); + + while let Some(machine) = to_visit.pop() { + if visited.insert(machine.clone()) { + result.insert(machine.clone()); + + if let Some(submachines) = dependencies.get(&machine) { + to_visit.extend(submachines.iter().cloned()); + } + } + } + + result +} + +fn machine_callable_body_symbols(machine: &Machine) -> impl Iterator + '_ { + machine.callable.function_definitions().flat_map(|def| { + def.symbols() + .map(|s| s.name.to_string()) + .collect::>() + }) +} + +fn machine_instructions_symbols(machine: &Machine) -> impl Iterator + '_ { + machine + .instructions + .iter() + .flat_map(|ins| ins.symbols().map(|s| s.name.to_string())) +} + +fn machine_links_symbols(machine: &Machine) -> impl Iterator + '_ { + machine + .links + .iter() + .flat_map(|ld| ld.symbols().map(|s| s.name.to_string())) +} + +fn machine_in_args<'a>( + machine: &'a Machine, + submachine_to_decl: &'a HashMap, +) -> impl Iterator + 'a { + machine + .submachines + .iter() + .flat_map(|sm| sm.args.iter().filter_map(expr_to_ref)) + .filter_map(|ref_name| submachine_to_decl.get(&ref_name)) + .cloned() +} + +fn machine_in_links<'a>( + machine: &'a Machine, + submachine_to_decl: &'a HashMap, +) -> impl Iterator + 'a { + machine + .links + .iter() + .filter_map(move |ld| submachine_to_decl.get(&ld.to.instance)) + .cloned() +} diff --git a/asmopt/tests/optimizer.rs b/asmopt/tests/optimizer.rs new file mode 100644 index 0000000000..48c98815f3 --- /dev/null +++ b/asmopt/tests/optimizer.rs @@ -0,0 +1,360 @@ +use powdr_analysis::analyze; +use powdr_asmopt::optimize; +use powdr_parser::parse_asm; + +#[test] +fn remove_unused_machine() { + let input = r#" + machine Main with degree: 8 { + reg pc[@pc]; + reg X[<=]; + reg A; + + instr assert_eq X, A { X = A } + + function main { + assert_eq 1, 1; + return; + } + } + + // This machine should be removed since it's never used + machine Unused with degree: 8 { + reg pc[@pc]; + col witness w; + w = w * w; + } + "#; + + let expectation = r#"machine Main with degree: 8 { + reg pc[@pc]; + reg X[<=]; + reg A; + instr assert_eq X, A{ X = A } + function main { + assert_eq 1, 1; + // END BATCH Unimplemented + return; + // END BATCH + } +} +"#; + + let parsed = parse_asm(None, input).unwrap(); + let analyzed = analyze(parsed).unwrap(); + let optimized = optimize(analyzed).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn remove_unused_instruction_and_machine() { + let input = r#" + machine Main with degree: 8 { + Helper helper; + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg A; + + // This instruction is never used and should be removed + // which will also remove Helper machine since it's the only usage + instr unused X -> Y link ~> Z = helper.double(X); + instr assert_eq X, A { X = A } + + function main { + assert_eq 1, 1; + return; + } + } + + machine Helper with degree: 8 { + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + + function double x: field -> field { + return x + x; + } + } + "#; + + let expectation = r#"machine Main with degree: 8 { + reg pc[@pc]; + reg X[<=]; + reg A; + instr assert_eq X, A{ X = A } + function main { + assert_eq 1, 1; + // END BATCH Unimplemented + return; + // END BATCH + } +} +"#; + + let parsed = parse_asm(None, input).unwrap(); + let analyzed = analyze(parsed).unwrap(); + let optimized = optimize(analyzed).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn keep_machine_with_multiple_references() { + let input = r#" + machine Main with degree: 8 { + Helper helper; + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg A; + + // Two different instructions using the same machine + instr double X -> Y link => Y = helper.double(X); + instr triple X -> Y link => Y = helper.triple(X); + + function main { + // Only using one instruction + A <== double(2); + return; + } + } + + machine Helper with degree: 8 { + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + + function double x: field -> field { return x + x; } + function triple x: field -> field { return x + x + x; } + } + "#; + + let expectation = r#"machine Main with degree: 8 { + ::Helper helper + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg A; + instr double X -> Y link => Y = helper.double(X){ } + function main { + A <=Y= double(2); + // END BATCH Unimplemented + return; + // END BATCH + } +} +machine Helper with degree: 8 { + reg pc[@pc]; + function double x: field -> field { + return x + x; + // END BATCH + } + function triple x: field -> field { + return x + x + x; + // END BATCH + } +} +"#; + + let parsed = parse_asm(None, input).unwrap(); + let analyzed = analyze(parsed).unwrap(); + let optimized = optimize(analyzed).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn keep_machine_parameters() { + let input = r#" + machine Main with degree: 8 { + Required required; + ParamMachine sub(required); + Unused unused; + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg A; + + instr compute X -> Y link => Y = sub.compute(X); + + function main { + A <== compute(1); + return; + } + } + + machine ParamMachine(mem: Required) with degree: 8 { + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + + function compute x: field -> field { + return x + x; + } + } + + machine Required with + latch: latch, + operation_id: operation_id + { + operation compute<0> x -> y; + + col fixed latch = [1]*; + col witness operation_id; + col witness x; + col witness y; + + y = x + x; + } + + machine Unused with degree: 8 { + reg pc[@pc]; + col witness w; + w = w * w; + } + "#; + + let expectation = r#"machine Main with degree: 8 { + ::Required required + ::ParamMachine sub(required) + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg A; + instr compute X -> Y link => Y = sub.compute(X){ } + function main { + A <=Y= compute(1); + // END BATCH Unimplemented + return; + // END BATCH + } +} +machine ParamMachine with degree: 8 { + reg pc[@pc]; + function compute x: field -> field { + return x + x; + // END BATCH + } +} +machine Required with , latch: latch, operation_id: operation_id { + operation compute<0> x -> y; + pol constant latch = [1]*; + pol commit operation_id; + pol commit x; + pol commit y; + y = x + x; +} +"#; + + let parsed = parse_asm(None, input).unwrap(); + let analyzed = analyze(parsed).unwrap(); + let optimized = optimize(analyzed).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn remove_unused_registers() { + let input = r#" + machine Main with degree: 8 { + Helper helper; + reg pc[@pc]; + reg Y[<=]; + reg Z[<=]; + reg A; + + instr compute X -> A link => X = helper.compute(X); + + function main { + A <== compute(5); + return; + } + } + + machine Helper with degree: 8 { + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + + function compute x: field -> field { + return x + 1; + } + } + "#; + + let expectation = r#"machine Main with degree: 8 { + ::Helper helper + reg pc[@pc]; + reg A; + instr compute X -> A link => X = helper.compute(X){ } + function main { + A <=A= compute(5); + // END BATCH Unimplemented + return; + // END BATCH + } +} +machine Helper with degree: 8 { + reg pc[@pc]; + function compute x: field -> field { + return x + 1; + // END BATCH + } +} +"#; + + let parsed = parse_asm(None, input).unwrap(); + let analyzed = analyze(parsed).unwrap(); + let optimized = optimize(analyzed).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn keep_linked_submachine() { + let input = r#" + machine Main with degree: 8 { + Helper helper; + reg pc[@pc]; + reg X[<=]; + + link => X = helper.check(X); + + function main { + return; + } + } + + machine Helper with degree: 8 { + reg pc[@pc]; + + function check x: field -> field { + return x + x; + } + } + "#; + + let expectation = r#"machine Main with degree: 8 { + ::Helper helper + reg pc[@pc]; + reg X[<=]; + function main { + return; + // END BATCH + } + link => X = helper.check(X); +} +machine Helper with degree: 8 { + reg pc[@pc]; + function check x: field -> field { + return x + x; + // END BATCH + } +} +"#; + + let parsed = parse_asm(None, input).unwrap(); + let analyzed = analyze(parsed).unwrap(); + let optimized = optimize(analyzed).to_string(); + assert_eq!(optimized, expectation); +} diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index ada87eb8c9..df82e6bdaa 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -3,7 +3,7 @@ mod display; use std::{ collections::{ btree_map::{IntoIter, Iter, IterMut}, - BTreeMap, BTreeSet, + BTreeMap, BTreeSet, HashSet, }, iter::{once, repeat}, ops::ControlFlow, @@ -882,6 +882,19 @@ impl Module { self.ordering.push(StatementReference::Module(name)); } + /// Retains only the machines with the specified names. + /// Ordering is preserved. + pub fn retain_machines(&mut self, names: HashSet) { + self.machines.retain(|key, _| names.contains(key)); + self.ordering.retain(|statement| { + if let StatementReference::MachineDeclaration(decl_name) = statement { + names.contains(decl_name) + } else { + true + } + }); + } + pub fn into_inner( self, ) -> ( diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 541c008d0e..5dc63b9814 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -16,7 +16,7 @@ use powdr_ast::parsed::visitor::{AllChildren, Children, ExpressionVisitable}; use powdr_ast::parsed::Number; use powdr_number::{BigUint, FieldElement}; -mod referenced_symbols; +pub mod referenced_symbols; use referenced_symbols::{ReferencedSymbols, SymbolReference}; diff --git a/pilopt/src/referenced_symbols.rs b/pilopt/src/referenced_symbols.rs index b7ee422017..0f0bf45048 100644 --- a/pilopt/src/referenced_symbols.rs +++ b/pilopt/src/referenced_symbols.rs @@ -4,11 +4,20 @@ use powdr_ast::{ analyzed::{ Expression, FunctionValueDefinition, PolynomialReference, Reference, TypedExpression, }, + asm_analysis::{ + AssignmentStatement, Expression as ExpressionASM, FunctionBody, FunctionDefinitionRef, + FunctionStatement, FunctionSymbol, InstructionDefinitionStatement, InstructionStatement, + LinkDefinition, Return, + }, parsed::{ - asm::SymbolPath, + asm::{ + AssignmentRegister, CallableRef, Instruction, InstructionBody, LinkDeclaration, Param, + Params, SymbolPath, + }, types::Type, visitor::{AllChildren, Children}, - EnumDeclaration, StructDeclaration, TraitImplementation, TypeDeclaration, + EnumDeclaration, FunctionDefinition, NamespacedPolynomialReference, PilStatement, + StructDeclaration, TraitImplementation, TypeDeclaration, }, }; @@ -20,7 +29,7 @@ pub trait ReferencedSymbols { fn symbols(&self) -> Box> + '_>; } -#[derive(Clone, Hash, Ord, PartialOrd, Eq, PartialEq)] +#[derive(Clone, Hash, Ord, PartialOrd, Eq, PartialEq, Debug)] pub struct SymbolReference<'a> { pub name: Cow<'a, str>, pub type_args: Option<&'a Vec>, @@ -59,6 +68,15 @@ impl<'a> From<&'a PolynomialReference> for SymbolReference<'a> { } } +impl<'a> From<&'a NamespacedPolynomialReference> for SymbolReference<'a> { + fn from(poly: &'a NamespacedPolynomialReference) -> Self { + SymbolReference { + name: poly.path.to_string().into(), + type_args: None, + } + } +} + impl ReferencedSymbols for FunctionValueDefinition { fn symbols(&self) -> Box> + '_> { match self { @@ -85,7 +103,7 @@ impl ReferencedSymbols for FunctionValueDefinition { } } -impl ReferencedSymbols for TraitImplementation { +impl ReferencedSymbols for TraitImplementation { fn symbols(&self) -> Box> + '_> { Box::new( once(SymbolReference::from(&self.name)) @@ -95,7 +113,7 @@ impl ReferencedSymbols for TraitImplementation { } } -impl ReferencedSymbols for TypeDeclaration { +impl ReferencedSymbols for TypeDeclaration { fn symbols(&self) -> Box> + '_> { match self { TypeDeclaration::Enum(enum_decl) => enum_decl.symbols(), @@ -104,7 +122,7 @@ impl ReferencedSymbols for TypeDeclaration { } } -impl ReferencedSymbols for EnumDeclaration { +impl ReferencedSymbols for EnumDeclaration { fn symbols(&self) -> Box> + '_> { Box::new( self.variants @@ -116,7 +134,7 @@ impl ReferencedSymbols for EnumDeclaration { } } -impl ReferencedSymbols for StructDeclaration { +impl ReferencedSymbols for StructDeclaration { fn symbols(&self) -> Box> + '_> { Box::new(self.fields.iter().flat_map(|named| named.ty.symbols())) } @@ -149,8 +167,231 @@ fn symbols_in_expression( } } -impl ReferencedSymbols for Type { +fn symbols_in_expression_asm( + e: &ExpressionASM, +) -> Option> + '_>> { + match e { + ExpressionASM::PublicReference(_, name) => { + Some(Box::new(once(SymbolReference::from(name)))) + } + ExpressionASM::Reference(_, pr @ NamespacedPolynomialReference { type_args, .. }) => { + let type_iter = type_args + .iter() + .flat_map(|t| t.iter()) + .flat_map(|t| t.symbols()); + + Some(Box::new(type_iter.chain(once(SymbolReference::from(pr))))) + } + _ => None, + } +} + +impl ReferencedSymbols for Type { fn symbols(&self) -> Box> + '_> { Box::new(self.contained_named_types().map(SymbolReference::from)) } } + +impl ReferencedSymbols for InstructionDefinitionStatement { + fn symbols(&self) -> Box> + '_> { + Box::new(once(SymbolReference::from(&self.name)).chain(self.instruction.symbols())) + } +} + +impl ReferencedSymbols for Instruction { + fn symbols(&self) -> Box> + '_> { + Box::new( + self.links + .iter() + .flat_map(|l| l.symbols()) + .chain(self.body.symbols()), + ) + } +} + +impl ReferencedSymbols for Params { + fn symbols(&self) -> Box> + '_> { + Box::new( + self.inputs + .iter() + .flat_map(|p| p.symbols()) + .chain(self.outputs.iter().flat_map(|p| p.symbols())), + ) + } +} + +impl ReferencedSymbols for Param { + fn symbols(&self) -> Box> + '_> { + Box::new( + once(SymbolReference::from(&self.name)) + .chain(self.ty.as_ref().map(SymbolReference::from)), + ) + } +} + +impl ReferencedSymbols for LinkDeclaration { + fn symbols(&self) -> Box> + '_> { + Box::new(self.flag.symbols().chain(self.link.symbols())) + } +} + +impl ReferencedSymbols for CallableRef { + fn symbols(&self) -> Box> + '_> { + Box::new( + once(SymbolReference::from(&self.instance)) + .chain(once(SymbolReference::from(&self.callable))) + .chain(self.params.symbols()), + ) + } +} + +impl ReferencedSymbols for LinkDefinition { + fn symbols(&self) -> Box> + '_> { + Box::new( + self.link_flag + .symbols() + .chain(self.instr_flag.iter().flat_map(|f| f.symbols())) + .chain(self.to.symbols()), + ) + } +} + +impl ReferencedSymbols for FunctionDefinitionRef<'_> { + fn symbols(&self) -> Box> + '_> { + Box::new(once(SymbolReference::from(self.name)).chain(self.function.symbols())) + } +} + +impl ReferencedSymbols for FunctionSymbol { + fn symbols(&self) -> Box> + '_> { + Box::new(self.body.symbols().chain(self.params.symbols())) + } +} + +impl ReferencedSymbols for InstructionBody { + fn symbols(&self) -> Box> + '_> { + Box::new(self.0.iter().flat_map(|e| e.symbols())) + } +} + +impl ReferencedSymbols for PilStatement { + fn symbols(&self) -> Box> + '_> { + match self { + PilStatement::Include(_, _) => Box::new(std::iter::empty()), + PilStatement::Namespace(_, _, _) => Box::new(std::iter::empty()), + PilStatement::LetStatement(_, name, type_scheme, expression) => Box::new( + type_scheme + .iter() + .flat_map(|ts| ts.ty.symbols()) + .chain(expression.iter().flat_map(|e| e.symbols())) + .chain(once(SymbolReference::from(name))), + ), + PilStatement::PolynomialDefinition(_, polynomial_name, expression) => Box::new( + expression + .symbols() + .chain(std::iter::once(SymbolReference::from( + &polynomial_name.name, + ))), + ), + PilStatement::PublicDeclaration( + _, + _, + namespaced_polynomial_reference, + expression, + expression1, + ) => Box::new(Box::new( + once(SymbolReference::from(namespaced_polynomial_reference)) + .chain(expression.iter().flat_map(|e| e.symbols())) + .chain(expression1.symbols()), + )), + PilStatement::PolynomialConstantDefinition(_, _, function_definition) => { + function_definition.symbols() + } + PilStatement::PolynomialCommitDeclaration(_, _, _, function_definition) => { + Box::new(function_definition.iter().flat_map(|f| f.symbols())) + } + PilStatement::EnumDeclaration(_, enum_declaration) => enum_declaration.symbols(), + PilStatement::StructDeclaration(_, struct_declaration) => struct_declaration.symbols(), + PilStatement::TraitImplementation(_, trait_implementation) => { + trait_implementation.symbols() + } + PilStatement::TraitDeclaration(_, _) => Box::new(std::iter::empty()), + PilStatement::Expression(_, expression) => Box::new( + expression + .all_children() + .flat_map(symbols_in_expression_asm) + .flatten(), + ), + } + } +} + +impl ReferencedSymbols for FunctionDefinition { + fn symbols(&self) -> Box> + '_> { + match self { + FunctionDefinition::TypeDeclaration(type_declaration) => type_declaration.symbols(), + FunctionDefinition::Array(..) + | FunctionDefinition::Expression(..) + | FunctionDefinition::TraitDeclaration(..) => { + Box::new(self.children().flat_map(|e| e.symbols())) + } + } + } +} + +impl ReferencedSymbols for FunctionBody { + fn symbols(&self) -> Box> + '_> { + Box::new(self.statements.iter().flat_map(|e| e.symbols())) + } +} + +impl ReferencedSymbols for FunctionStatement { + fn symbols(&self) -> Box> + '_> { + match self { + FunctionStatement::Assignment(a) => a.symbols(), + FunctionStatement::Instruction(i) => i.symbols(), + FunctionStatement::Return(r) => r.symbols(), + _ => Box::new(std::iter::empty()), + } + } +} + +impl ReferencedSymbols for AssignmentStatement { + fn symbols(&self) -> Box> + '_> { + Box::new( + self.lhs_with_reg + .iter() + .flat_map(|(n, reg)| { + let name_ref = Some(SymbolReference::from(n)); + let reg_ref = match reg { + AssignmentRegister::Register(name) => Some(SymbolReference::from(name)), + AssignmentRegister::Wildcard => None, + }; + [name_ref, reg_ref].into_iter().flatten() + }) + .chain(self.rhs.as_ref().symbols()), + ) + } +} + +impl ReferencedSymbols for Return { + fn symbols(&self) -> Box> + '_> { + Box::new(self.values.iter().flat_map(|expr| expr.symbols())) + } +} + +impl ReferencedSymbols for InstructionStatement { + fn symbols(&self) -> Box> + '_> { + Box::new(once(SymbolReference::from(&self.instruction))) + } +} + +impl ReferencedSymbols for ExpressionASM { + fn symbols(&self) -> Box> + '_> { + Box::new( + self.all_children() + .flat_map(symbols_in_expression_asm) + .flatten(), + ) + } +} diff --git a/pipeline/Cargo.toml b/pipeline/Cargo.toml index e1c0e41603..45f85e1006 100644 --- a/pipeline/Cargo.toml +++ b/pipeline/Cargo.toml @@ -21,6 +21,7 @@ estark-starky-simd = ["powdr-backend/estark-starky-simd"] [dependencies] powdr-airgen.workspace = true powdr-analysis.workspace = true +powdr-asmopt.workspace = true powdr-asm-to-pil.workspace = true powdr-ast.workspace = true powdr-backend.workspace = true diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index af64b8a022..5cf852ba92 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -55,6 +55,8 @@ pub struct Artifacts { /// The analyzed .asm file: Assignment registers are inferred, instructions /// are batched and some properties are checked. analyzed_asm: Option, + /// The optimized version of the analyzed ASM file. + optimized_asm: Option, /// A machine collection that only contains constrained machines. constrained_machine_collection: Option, /// The airgen graph, i.e. a collection of constrained machines with resolved @@ -156,6 +158,7 @@ impl Clone for Artifacts { parsed_asm_file: self.parsed_asm_file.clone(), resolved_module_tree: self.resolved_module_tree.clone(), analyzed_asm: self.analyzed_asm.clone(), + optimized_asm: self.optimized_asm.clone(), constrained_machine_collection: self.constrained_machine_collection.clone(), linked_machine_graph: self.linked_machine_graph.clone(), parsed_pil_file: self.parsed_pil_file.clone(), @@ -786,14 +789,33 @@ impl Pipeline { Ok(self.artifact.analyzed_asm.as_ref().unwrap()) } + pub fn compute_optimized_asm(&mut self) -> Result<&AnalysisASMFile, Vec> { + if let Some(ref optimized_asm) = self.artifact.optimized_asm { + return Ok(optimized_asm); + } + + self.compute_analyzed_asm()?; + let analyzed_asm = self.artifact.analyzed_asm.take().unwrap(); + + self.log("Optimizing asm..."); + let optimized = powdr_asmopt::optimize(analyzed_asm); + self.artifact.optimized_asm = Some(optimized); + + Ok(self.artifact.optimized_asm.as_ref().unwrap()) + } + + pub fn optimized_asm(&self) -> Result<&AnalysisASMFile, Vec> { + Ok(self.artifact.optimized_asm.as_ref().unwrap()) + } + pub fn compute_constrained_machine_collection( &mut self, ) -> Result<&AnalysisASMFile, Vec> { if self.artifact.constrained_machine_collection.is_none() { self.artifact.constrained_machine_collection = Some({ - self.compute_analyzed_asm()?; - let analyzed_asm = self.artifact.analyzed_asm.take().unwrap(); - powdr_asm_to_pil::compile::(analyzed_asm) + self.compute_optimized_asm()?; + let optimized_asm = self.artifact.optimized_asm.take().unwrap(); + powdr_asm_to_pil::compile::(optimized_asm) }); } diff --git a/test_data/asm/book/declarations.asm b/test_data/asm/book/declarations.asm index 75593153b2..3710f19307 100644 --- a/test_data/asm/book/declarations.asm +++ b/test_data/asm/book/declarations.asm @@ -28,9 +28,10 @@ machine Main with degree: 4 { utils::constrain_incremented_by(x, 0); // We define an instruction that uses a complicated way to increment a register. - instr incr_a { A = utils::incremented(A) } + instr incr_a { A' = utils::incremented(A) } function main { + incr_a; return; } } \ No newline at end of file