Skip to content

Commit

Permalink
fix: fix symbolic_type for unwrapped array symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 5, 2024
1 parent 1305b7e commit a5dfae8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,14 @@ getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val)

SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic()
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic()
function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: Symbolic{S}}
ArraySymbolic()
end
# need this otherwise the `::Type{<:BasicSymbolic}` method in SymbolicUtils is
# more specific
function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: BasicSymbolic{S}}
ArraySymbolic()
end

SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x))

Expand Down
1 change: 1 addition & 0 deletions test/symbolic_indexing_interface_trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using SymbolicIndexingInterface
@test symbolic_type(x) == ScalarSymbolic()
@variables y[1:3]
@test symbolic_type(y) == ArraySymbolic()
@test symbolic_type(Symbolics.unwrap(y)) == ArraySymbolic()
@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))
@test symbolic_type(Symbolics.unwrap(y .* y)) == ArraySymbolic()
@variables z(..)
Expand Down

0 comments on commit a5dfae8

Please sign in to comment.