Commit

[NDTensors] Fix CPU performance issue caused by bad mul dispatch (#1218)
kmp5VT authored Oct 25, 2023
1 parent cea22f9 commit bde7da7
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions NDTensors/src/abstractarray/iswrappedarray.jl
@@ -33,6 +33,7 @@ parenttype(::Type{<:UnitUpperTriangular{<:Any,P}}) where {P} = P
 parenttype(::Type{<:UnitLowerTriangular{<:Any,P}}) where {P} = P
 parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P
 parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P
+parenttype(::Type{<:StridedView{<:Any,<:Any,P}}) where {P} = P
 
 # For working with instances, not used by
 # `SimpleTraits.jl` traits dispatch.
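
The `parenttype` methods in this hunk each extract the wrapped storage type `P` from a wrapper's type parameters. The sketch below illustrates that pattern with standard-library wrappers only. It is an assumption-laden stand-in, not NDTensors code: the local `parenttype` definitions, the identity fallback, and the `unwrap_type` helper are all hypothetical and exist only for this example.

```julia
using LinearAlgebra

# Hypothetical local stand-ins for the pattern in the diff (not NDTensors' own function).
# Assumed fallback: a type that wraps nothing is its own parent type.
parenttype(::Type{T}) where {T} = T
parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P
parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P

# Hypothetical helper: peel wrapper types until the type stops changing.
function unwrap_type(T::Type)
  P = parenttype(T)
  return P === T ? T : unwrap_type(P)
end

d = Diagonal(rand(3))          # Diagonal wrapping a Vector{Float64}
v = view(rand(4, 4), 1:2, 1:2) # SubArray wrapping a Matrix{Float64}
nested = view(d, 1:2, 1:2)     # SubArray wrapping a Diagonal wrapping a Vector

parent_of_d = parenttype(typeof(d))
parent_of_v = parenttype(typeof(v))
storage_of_nested = unwrap_type(typeof(nested))
```

Adding a method for a new wrapper, as the commit does for `StridedView`, lets code of this shape see through that wrapper in the same way.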
5 changes: 4 additions & 1 deletion NDTensors/src/array/permutedims.jl
@@ -1,6 +1,9 @@
 # NDTensors.permutedims
 function permutedims(::Type{<:Array}, M, perm)
-  return @strided Base.permutedims(M, perm)
+  ## Creating Mperm here to evaluate the permutation and
+  ## avoid returning a Stridedview
+  @strided Mperm = Base.permutedims(M, perm)
+  return Mperm
 end
 
 # NDTensors.permutedims!
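
The comment in this hunk distinguishes returning a lazy view from returning an evaluated permutation. The sketch below shows that distinction using only Base. It is an analogy, not the commit's code: `PermutedDimsArray` stands in for the lazy `StridedView`, and `materialized_permutedims` is a hypothetical helper name introduced here. A lazy wrapper is a different type from `Array`, so later calls such as `*` can dispatch to different methods than they would for a plain `Array`.

```julia
# Base-only analogy: PermutedDimsArray plays the role of a lazy view.
M = reshape(collect(1.0:6.0), 2, 3)
perm = (2, 1)

lazy = PermutedDimsArray(M, perm) # wrapper around M, no data copied
eager = permutedims(M, perm)      # newly allocated Matrix{Float64}

# Hypothetical helper mirroring the shape of the fix:
# evaluate the permutation into a plain Array before returning.
function materialized_permutedims(M, perm)
  Mperm = collect(PermutedDimsArray(M, perm))
  return Mperm
end

result = materialized_permutedims(M, perm)

# The wrapper shares memory with M; the materialized copies do not.
M[1, 2] = 100.0
lazy_sees_update = lazy[2, 1] == 100.0
eager_unchanged = eager[2, 1] == 3.0
```

Materializing costs one allocation and copy, in exchange for downstream code receiving an ordinary `Array`.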