diff --git a/Project.toml b/Project.toml index dd3bfd159..d7d14695f 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" @@ -35,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7" @@ -64,7 +64,6 @@ MacroTools = "0.5" NaNMath = "0.3, 1" PrecompileTools = "1" RecipesBase = "1.1" -RecursiveArrayTools = "2" Reexport = "0.2, 1" ReferenceTests = "0.9" Requires = "1.1" @@ -73,6 +72,7 @@ SciMLBase = "1.8, 2" Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "1.1" +SymbolicIndexingInterface = "0.3" SymbolicUtils = "1.4" TreeViews = "0.3" julia = "1.6" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index f9565f4dc..e0afeaa75 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -35,6 +35,8 @@ using PrecompileTools using RuntimeGeneratedFunctions using SciMLBase, IfElse using MacroTools + + using SymbolicIndexingInterface end @reexport using SymbolicUtils RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/num.jl b/src/num.jl index 6646e8e6e..7244801a7 100644 --- a/src/num.jl +++ b/src/num.jl @@ -19,9 +19,6 @@ Num(x::Num) = x # ideally this should never be called (n::Num)(args...) = Num(value(n)(map(value,args)...)) value(x) = unwrap(x) -SciMLBase.issymbollike(::Num) = true -SciMLBase.issymbollike(::SymbolicUtils.Symbolic) = true - SymbolicUtils.@number_methods( Num, Num(f(value(a))), @@ -197,6 +194,3 @@ function Base.Docs.getdoc(x::Num) end Markdown.parse(join(strings, "\n\n ")) end - -using RecursiveArrayTools -RecursiveArrayTools.issymbollike(::Union{BasicSymbolic,Num}) = true diff --git a/src/variable.jl b/src/variable.jl index 19f101503..b2b3d3880 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -406,7 +406,16 @@ end getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) -getname(x, val=_fail) = _getname(unwrap(x), val) +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() + +SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x)) + +function SymbolicIndexingInterface.hasname(x::Symbolic) + issym(x) || !istree(x) || istree(x) && (issym(operation(x)) || operation(x) == getindex) +end + +SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index c36fc9eca..f1007dbd3 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -17,9 +17,9 @@ function set_where(subt, supert) Expr(:where, supert, Ts...) end -getname(x::Symbol) = x +SymbolicIndexingInterface.getname(x::Symbol) = x -function getname(x::Expr) +function SymbolicIndexingInterface.getname(x::Expr) @assert x.head == :curly return x.args[1] end diff --git a/test/overloads.jl b/test/overloads.jl index 474218554..a32b39c54 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -237,6 +237,3 @@ for f in [<, <=, >, >=, isless] end @test_nowarn binomial(t, 1) - -using RecursiveArrayTools -@test RecursiveArrayTools.issymbollike(t) diff --git a/test/runtests.jl b/test/runtests.jl index e73d5ffca..27d013a61 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,15 @@ if GROUP == "All" || GROUP == "Core" @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end end +if GROUP == "All" || GROUP == "Core" || GROUP == "SymbolicIndexingInterface" + @safetestset "SymbolicIndexingInterface Trait Test" begin + include("symbolic_indexing_interface_trait.jl") + end + @safetestset "SymbolicIndexingInterface Parameter Indexing Test" begin + include("symbolic_indexing_interface_parameter_indexing.jl") + end +end + if GROUP == "Downstream" activate_downstream_env() #@time @safetestset "ParameterizedFunctions MATLABDiffEq Regression Test" begin include("downstream/ParameterizedFunctions_MATLAB.jl") end diff --git a/test/symbolic_indexing_interface_parameter_indexing.jl b/test/symbolic_indexing_interface_parameter_indexing.jl new file mode 100644 index 000000000..a05831484 --- /dev/null +++ b/test/symbolic_indexing_interface_parameter_indexing.jl @@ -0,0 +1,23 @@ +using SymbolicIndexingInterface +using Symbolics + +struct FakeIntegrator{P} + p::P +end + +SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys +SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p + +@variables a[1:2] b +sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) +p = [1.0, 2.0, 3.0] +fi = FakeIntegrator(copy(p)) +for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))] + get = getp(sys, sym) + set! = setp(sys, sym) + true_value = i isa Tuple ? getindex.((p,), i) : p[i] + @test get(fi) == true_value + set!(fi, 0.5 .* i) + @test get(fi) == 0.5 .* i + set!(fi, true_value) +end diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl new file mode 100644 index 000000000..52d1579ae --- /dev/null +++ b/test/symbolic_indexing_interface_trait.jl @@ -0,0 +1,12 @@ +using Symbolics +using SymbolicUtils +using SymbolicIndexingInterface + +@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== + (ScalarSymbolic(),)) +@test symbolic_type(Symbolics.Arr) == ArraySymbolic() +@variables x +@test symbolic_type(x) == ScalarSymbolic() +@variables y[1:3] +@test symbolic_type(y) == ArraySymbolic() +@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))