diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 12fe0a5..de21a38 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -628,7 +628,8 @@ function _getp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) if is_timeseries_parameter(sys, p) ts_idx = timeseries_parameter_index(sys, p) - return GetParameterTimeseriesIndex(idx, ts_idx) + return GetParameterTimeseriesIndex( + GetParameterIndex(idx), GetParameterIndex(ts_idx)) else return GetParameterIndex(idx) end @@ -750,5 +751,5 @@ function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym) if is_parameter(indp, sym) return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false) end - error("$sym is not a valid parameter") + return setp_oop(indp, collect(sym)) end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 243f7f2..6e1d3d0 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -468,7 +468,7 @@ function _setsym_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym) return setsym_oop(indp, idx) elseif (idx = parameter_index(indp, sym)) !== nothing return FullSetter( - nothing, OOPSetter(indp, idx isa AbstractArray ? idx : (idx,), false)) + nothing, OOPSetter(indp, idx, false)) end return setsym_oop(indp, collect(sym)) end diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 8657d46..b4c2174 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -6,3 +6,4 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] SymbolicUtils = "3.2" +ModelingToolkit = "9.60" diff --git a/test/downstream/array_indexing.jl b/test/downstream/array_indexing.jl new file mode 100644 index 0000000..1fcf4fd --- /dev/null +++ b/test/downstream/array_indexing.jl @@ -0,0 +1,58 @@ +using ModelingToolkit, SymbolicIndexingInterface +using ModelingToolkit: t_nounits as t, D_nounits as D + +@variables x(t)[1:2] +@parameters p[1:2, 1:2] q(t)[1:2] r[1:2] + +ev = [x[1] ~ 2.0] => [q ~ -ones(2)] +@mtkbuild sys = ODESystem( + [D(x) ~ p * x + q + r], t, [x], [p, q, r...]; continuous_events = [ev]) +@test is_timeseries_parameter(sys, q) +@test !is_timeseries_parameter(sys, p) +@test !is_parameter(sys, r) +@test is_parameter(sys, r[1]) +@test is_parameter(sys, r[2]) + +prob = ODEProblem( + sys, [x => ones(2)], (0.0, 10.0), [p => ones(2, 2), q => ones(2), r => 2ones(2)]) +@test prob.ps[q] ≈ ones(2) +@test prob.ps[p] ≈ ones(2, 2) +@test prob.ps[r] ≈ 2ones(2) +@test prob.ps[p * q] ≈ 2ones(2) + +@test getu(sys, p)(prob) ≈ ones(2, 2) +@test getu(sys, r)(prob) ≈ 2ones(2) + +prob.ps[p] = 2ones(2, 2) +@test prob.ps[p] ≈ 2ones(2, 2) +prob.ps[q] = 2ones(2) +@test prob.ps[q] ≈ 2ones(2) +prob.ps[r] = ones(2) +@test prob.ps[r] ≈ ones(2) + +setter = setp_oop(sys, p) +newp = setter(prob, 3ones(2, 2)) +@test getp(sys, p)(newp) ≈ 3ones(2, 2) +setter = setp_oop(sys, r) +newp = setter(prob, 3ones(2)) +@test getp(sys, r)(newp) ≈ 3ones(2) + +setter = setsym_oop(sys, p) +_, newp = setter(prob, 3ones(2, 2)) +@test getp(sys, p)(newp) ≈ 3ones(2, 2) +setter = setsym_oop(sys, r) +_, newp = setter(prob, 3ones(2)) +@test getp(sys, r)(newp) ≈ 3ones(2) + +@test prob[x] ≈ ones(2) +prob[x] = 2ones(2) +@test prob[x] ≈ 2ones(2) + +setu(sys, p)(prob, 4ones(2, 2)) +@test prob.ps[p] ≈ 4ones(2, 2) +setu(sys, r)(prob, 4ones(2)) +@test prob.ps[r] ≈ 4ones(2) + +setter = setsym_oop(sys, x) +newu, newp = setter(prob, 3ones(2)) +@test getu(sys, x)(newu) ≈ 3ones(2) diff --git a/test/runtests.jl b/test/runtests.jl index 363d9fb..13550e2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,4 +58,7 @@ if GROUP == "All" || GROUP == "Downstream" @safetestset "remake_buffer with array symbolics test" begin @time include("downstream/remake_arrayvars.jl") end + @safetestset "array indexing" begin + @time include("downstream/array_indexing.jl") + end end