Skip to content

Commit

Permalink
fix: make array hack also search callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 24, 2024
1 parent 36eb83f commit 5ec6c20
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
@set! sys.unknowns = unknowns

obs, subeqs, deps = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
sys, obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs
Expand Down Expand Up @@ -627,7 +627,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
not) we first count the number of times the scalarized form of each observed variable
occurs in observed equations (and unknowns if it's split).
"""
function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true)
function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true)
# HACK 1
# mapping of rhs to temporary CSE variable
# `f(...) => tmpvar` in above example
Expand Down Expand Up @@ -725,6 +725,11 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
for eq in neweqs
vars!(all_vars, eq.rhs)
end

# also count unscalarized variables used in callbacks
for ev in Iterators.flatten((continuous_events(sys), discrete_events(sys)))
vars!(all_vars, ev)
end
obs_arr_eqs = Equation[]
for (arrvar, cnt) in arr_obs_occurrences
cnt == length(arrvar) || continue
Expand Down

0 comments on commit 5ec6c20

Please sign in to comment.