Skip to content

Commit

Permalink
Merge pull request #3284 from AayushSabharwal/as/indexing-hotfix
Browse files Browse the repository at this point in the history
fix: fix `timeseries_parameter_index` for array symbolics
  • Loading branch information
ChrisRackauckas authored Dec 25, 2024
2 parents 4792360 + 1835a56 commit dad05e5
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ function flatten_equations(eqs)
error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar")
size(eq.lhs) == size(eq.rhs) ||
error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))")
return collect(eq.lhs) .~ collect(eq.rhs)
return vec(collect(eq.lhs) .~ collect(eq.rhs))
else
eq
end
Expand Down
4 changes: 3 additions & 1 deletion src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,16 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return nothing
end
sym = unwrap(sym)
idx = check_index_map(ic.discrete_idx, sym)
idx === nothing ||
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
iscall(sym) && operation(sym) == getindex || return nothing
args = arguments(sym)
idx = timeseries_parameter_index(ic, args[1])
idx === nothing && return nothing
ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size)
return ParameterTimeseriesIndex(
idx.timeseries_idx, (idx.parameter_idx..., args[2:end]...))
end

function check_index_map(idxmap, sym)
Expand Down
4 changes: 3 additions & 1 deletion test/if_lifting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, IfLifting, no_if_lift
@test operation(only(equations(ss2)).rhs) === ifelse

discvar = only(parameters(ss2))
prob2 = ODEProblem(ss2, [x => 0.0], (0.0, 5.0))
prob1 = ODEProblem(ss1, [ss1.x => 0.0], (0.0, 5.0))
sol1 = solve(prob1, Tsit5())
prob2 = ODEProblem(ss2, [ss2.x => 0.0], (0.0, 5.0))
sol2 = solve(prob2, Tsit5())
@test count(isapprox(pi), sol2.t) == 2
@test any(isapprox(pi), sol2.discretes[1].t)
Expand Down
11 changes: 11 additions & 0 deletions test/symbolic_indexing_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,14 @@ end
end
@test isempty(get_all_timeseries_indexes(sys, a))
end

@testset "`timeseries_parameter_index` on unwrapped scalarized timeseries parameter" begin
@variables x(t)[1:2]
@parameters p(t)[1:2, 1:2]
ev = [x[1] ~ 2.0] => [p ~ -ones(2, 2)]
@mtkbuild sys = ODESystem(D(x) ~ p * x, t; continuous_events = [ev])
p = ModelingToolkit.unwrap(p)
@test timeseries_parameter_index(sys, p) === ParameterTimeseriesIndex(1, (1, 1))
@test timeseries_parameter_index(sys, p[1, 1]) ===
ParameterTimeseriesIndex(1, (1, 1, 1, 1))
end

0 comments on commit dad05e5

Please sign in to comment.