Skip to content

Commit

Permalink
Fix issues for block sparse SVD on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 23, 2023
1 parent 59cccfc commit 28b9e5a
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 21 deletions.
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ include("adapt.jl")
include("set_types.jl")
include("indexing.jl")
include("linearalgebra.jl")
include("copyto.jl")
include("permutedims.jl")
end
13 changes: 13 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/copyto.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Catches a bug in `copyto!` in Metal backend.
function NDTensors.copyto!(
::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::SubArray
)
return Base.copyto!(dest, copy(src))
end

# Catches a bug in `copyto!` in Metal backend.
function NDTensors.copyto!(
::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::Base.ReshapedArray
)
return NDTensors.copyto!(dest, parent(src))
end
1 change: 1 addition & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ include("abstractarray/set_types.jl")
include("abstractarray/to_shape.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
include("abstractarray/copyto.jl")
include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
include("abstractarray/mul.jl")
Expand Down
13 changes: 13 additions & 0 deletions NDTensors/src/abstractarray/copyto.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# NDTensors.copyto!
function copyto!(R::AbstractArray, T::AbstractArray)
copyto!(leaf_parenttype(R), R, leaf_parenttype(T), T)
return R
end

# NDTensors.copyto!
function copyto!(
::Type{<:AbstractArray}, R::AbstractArray, ::Type{<:AbstractArray}, T::AbstractArray
)
Base.copyto!(R, T)
return R
end
8 changes: 4 additions & 4 deletions NDTensors/src/blocksparse/blocksparsetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ function BlockSparseTensor(
end

function zeros(
::BlockSparseTensor{ElT,N}, blockoffsets::BlockOffsets{N}, inds
tensor::BlockSparseTensor{ElT,N}, blockoffsets::BlockOffsets{N}, inds
) where {ElT,N}
return BlockSparseTensor(ElT, blockoffsets, inds)
return BlockSparseTensor(datatype(tensor), blockoffsets, inds)
end

function zeros(
::Type{<:BlockSparseTensor{ElT,N}}, blockoffsets::BlockOffsets{N}, inds
tensortype::Type{<:BlockSparseTensor{ElT,N}}, blockoffsets::BlockOffsets{N}, inds
) where {ElT,N}
return BlockSparseTensor(ElT, blockoffsets, inds)
return BlockSparseTensor(datatype(tensortype), blockoffsets, inds)
end

function zeros(tensortype::Type{<:BlockSparseTensor}, inds)
Expand Down
11 changes: 4 additions & 7 deletions NDTensors/src/blocksparse/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,9 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
sU = right_arrow_sign(uind, blockU[2])

if sU == -1
blockview(U, blockU) .= -Ub
else
blockview(U, blockU) .= Ub
Ub *= -1
end
copyto!(blockview(U, blockU), Ub)

blockviewS = blockview(S, blockS)
for i in 1:diaglength(Sb)
Expand All @@ -193,12 +192,10 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
end

if (sV * sVP) == -1
blockview(V, blockV) .= -Vb
else
blockview(V, blockV) .= Vb
Vb *= -1
end
copyto!(blockview(V, blockV), Vb)
end

return U, S, V, Spectrum(d, truncerr)
#end # @timeit_debug
end
Expand Down
7 changes: 3 additions & 4 deletions NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,9 @@ function permutedims!(
return R
end

function copyto!(R::DenseTensor{<:Number,N}, T::DenseTensor{<:Number,N}) where {N}
RA = array(R)
TA = array(T)
RA .= TA
# NDTensors.copyto!
function copyto!(R::DenseTensor, T::DenseTensor)
copyto!(array(R), array(T))
return R
end

Expand Down
1 change: 0 additions & 1 deletion NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import Base:
convert,
conj,
copy,
copyto!,
eachindex,
eltype,
empty,
Expand Down
1 change: 0 additions & 1 deletion NDTensors/src/linearalgebra/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ function eigen(
#DM = DM[p]
#VM = VM[:,p]


if truncate
DM_cpu = NDTensors.cpu(DM)
truncerr, _ = truncate!(
Expand Down
3 changes: 1 addition & 2 deletions NDTensors/src/linearalgebra/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int=
V = M' * U

V, R = qr_positive(V)
diagR = diag(R)
D[1:Nd] = diagR[1:Nd]
D[1:Nd] = diag(R)[1:Nd]

(done, start) = checkSVDDone(D, thresh)

Expand Down
2 changes: 0 additions & 2 deletions src/mps/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,9 @@ function replacebond!(M::MPS, b::Int, phi::ITensor; kwargs...)
sbp1 = siteind(M, b + 1)
indsMb = replaceind(indsMb, sb, sbp1)
end

L, R, spec = factorize(
phi, indsMb; which_decomp=which_decomp, tags=tags(linkind(M, b)), kwargs...
)

M[b] = L
M[b + 1] = R
if ortho == "left"
Expand Down

0 comments on commit 28b9e5a

Please sign in to comment.