Skip to content

Commit

Permalink
fix: fix maketerm for ArrayOp involving broadcasted symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 21, 2024
1 parent 6a195b4 commit d272593
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,13 @@ ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T}

function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m)
args = map(args) do arg
if iscall(arg) && operation(arg) == Ref && symbolic_type(only(arguments(arg))) == NotSymbolic()
return Ref(only(arguments(arg)))
if iscall(arg) && operation(arg) == Ref
inner = only(arguments(arg))
if symbolic_type(inner) == NotSymbolic()
return Ref(inner)
else
return inner
end
else
return arg
end
Expand Down
4 changes: 3 additions & 1 deletion test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ end
end

@testset "maketerm" begin
@variables A[1:5, 1:5] B[1:5, 1:5]
@variables A[1:5, 1:5] B[1:5, 1:5] C

T = unwrap(3A)
@test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), nothing))
T2 = unwrap(3B)
@test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], nothing))
T3 = unwrap(A .^ 2)
@test isequal(T3, Symbolics.maketerm(typeof(T3), operation(T3), arguments(T3), nothing))
T4 = unwrap(A .* C)
@test isequal(T4, Symbolics.maketerm(typeof(T4), operation(T4), arguments(T4), nothing))
end

getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
Expand Down

0 comments on commit d272593

Please sign in to comment.