Skip to content

Commit

Permalink
Add inner_unwrap for array ops
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Apr 25, 2024
1 parent 56e7f47 commit e2b85a4
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,10 @@ function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int)
wrap(unwrap(A)[j, i])
end

inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x
function _matmul(A, B)
A = inner_unwrap(A)
B = inner_unwrap(B)
@syms i::Int j::Int k::Int
if isadjointvec(A)
op = operation(A.term)
Expand All @@ -295,6 +298,8 @@ end
@wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B)

function _matvec(A, b)
A = inner_unwrap(A)
b = inner_unwrap(b)
@syms i::Int k::Int
sym_res = @arrayop (i,) A[i, k] * b[k] term=(A*b)
if isdot(A, b)
Expand Down

0 comments on commit e2b85a4

Please sign in to comment.