From a5dfae861cd4b01a48ea31278ff8345e768de27c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 5 Nov 2024 12:02:30 +0530 Subject: [PATCH] fix: fix `symbolic_type` for unwrapped array symbolics --- src/variable.jl | 8 ++++++++ test/symbolic_indexing_interface_trait.jl | 1 + 2 files changed, 9 insertions(+) diff --git a/src/variable.jl b/src/variable.jl index 15db8d35f..ba446a8e2 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -489,6 +489,14 @@ getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() +function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: Symbolic{S}} + ArraySymbolic() +end +# need this otherwise the `::Type{<:BasicSymbolic}` method in SymbolicUtils is +# more specific +function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: BasicSymbolic{S}} + ArraySymbolic() +end SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x)) diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl index 15d38e328..d25a08cfe 100644 --- a/test/symbolic_indexing_interface_trait.jl +++ b/test/symbolic_indexing_interface_trait.jl @@ -10,6 +10,7 @@ using SymbolicIndexingInterface @test symbolic_type(x) == ScalarSymbolic() @variables y[1:3] @test symbolic_type(y) == ArraySymbolic() +@test symbolic_type(Symbolics.unwrap(y)) == ArraySymbolic() @test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),)) @test symbolic_type(Symbolics.unwrap(y .* y)) == ArraySymbolic() @variables z(..)