diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 59ad3c7a30..a3efc87e8a 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -152,10 +152,23 @@ fn collect_required_symbols<'a, T: FieldElement>( .values() .map(|p| SymbolReference::from(&p.polynomial.name)), ); + + let poly_ref_to_id = pil_file + .definitions + .values() + .filter_map(|(symbol, _)| matches!(symbol.kind, SymbolKind::Poly(_)).then_some(symbol)) + .flat_map(|symbol| symbol.array_elements()) + .collect::>(); + for fun in &pil_file.prover_functions { for e in fun.all_children() { if let Expression::Reference(_, Reference::Poly(poly_ref)) = e { - required_names.insert(SymbolReference::from(poly_ref)); + let symbol_ref = match poly_ref_to_id.get(&poly_ref.name) { + Some(poly_id) => poly_id_to_definition_name[poly_id].into(), + None => SymbolReference::from(poly_ref), + }; + + required_names.insert(symbol_ref); } } } diff --git a/pilopt/tests/optimizer.rs b/pilopt/tests/optimizer.rs index 5b66e79b80..6a9ebf2edf 100644 --- a/pilopt/tests/optimizer.rs +++ b/pilopt/tests/optimizer.rs @@ -283,9 +283,7 @@ fn enum_ref_by_trait() { } #[test] -#[should_panic = "Symbol not found: N::x[0]"] fn handle_array_references_in_prover_functions() { - // Reproduces https://github.com/powdr-labs/powdr/issues/2051 let input = r#"namespace N(8); col witness x[1]; @@ -300,5 +298,16 @@ fn handle_array_references_in_prover_functions() { } }; "#; - optimize(analyze_string::(input).unwrap()).to_string(); + let expectation = r#"namespace N(8); + col witness x[1]; + N::x[0]' = N::x[0] + 1; + { + let intermediate = N::x[0] + 1_expr; + query |i| { + let _: expr = intermediate; + } + }; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); }