From 28b9e5ac4846ca286560202f5f5c49140a7e0c52 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 23 Oct 2023 10:50:36 -0400 Subject: [PATCH] Fix issues for block sparse SVD on GPU --- .../ext/NDTensorsMetalExt/NDTensorsMetalExt.jl | 1 + NDTensors/ext/NDTensorsMetalExt/copyto.jl | 13 +++++++++++++ NDTensors/src/NDTensors.jl | 1 + NDTensors/src/abstractarray/copyto.jl | 13 +++++++++++++ NDTensors/src/blocksparse/blocksparsetensor.jl | 8 ++++---- NDTensors/src/blocksparse/linearalgebra.jl | 11 ++++------- NDTensors/src/dense/densetensor.jl | 7 +++---- NDTensors/src/imports.jl | 1 - NDTensors/src/linearalgebra/linearalgebra.jl | 1 - NDTensors/src/linearalgebra/svd.jl | 3 +-- src/mps/mps.jl | 2 -- 11 files changed, 40 insertions(+), 21 deletions(-) create mode 100644 NDTensors/ext/NDTensorsMetalExt/copyto.jl create mode 100644 NDTensors/src/abstractarray/copyto.jl diff --git a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl index ef54c8502d..64b0e4477e 100644 --- a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl +++ b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl @@ -17,5 +17,6 @@ include("adapt.jl") include("set_types.jl") include("indexing.jl") include("linearalgebra.jl") +include("copyto.jl") include("permutedims.jl") end diff --git a/NDTensors/ext/NDTensorsMetalExt/copyto.jl b/NDTensors/ext/NDTensorsMetalExt/copyto.jl new file mode 100644 index 0000000000..2b2022839f --- /dev/null +++ b/NDTensors/ext/NDTensorsMetalExt/copyto.jl @@ -0,0 +1,13 @@ +# Catches a bug in `copyto!` in Metal backend. +function NDTensors.copyto!( + ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::SubArray +) + return Base.copyto!(dest, copy(src)) +end + +# Catches a bug in `copyto!` in Metal backend. +function NDTensors.copyto!( + ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::Base.ReshapedArray +) + return NDTensors.copyto!(dest, parent(src)) +end diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index ec80dad1ca..20af697b8c 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -50,6 +50,7 @@ include("abstractarray/set_types.jl") include("abstractarray/to_shape.jl") include("abstractarray/similar.jl") include("abstractarray/ndims.jl") +include("abstractarray/copyto.jl") include("abstractarray/permutedims.jl") include("abstractarray/fill.jl") include("abstractarray/mul.jl") diff --git a/NDTensors/src/abstractarray/copyto.jl b/NDTensors/src/abstractarray/copyto.jl new file mode 100644 index 0000000000..3ed7b0bb5a --- /dev/null +++ b/NDTensors/src/abstractarray/copyto.jl @@ -0,0 +1,13 @@ +# NDTensors.copyto! +function copyto!(R::AbstractArray, T::AbstractArray) + copyto!(leaf_parenttype(R), R, leaf_parenttype(T), T) + return R +end + +# NDTensors.copyto! +function copyto!( + ::Type{<:AbstractArray}, R::AbstractArray, ::Type{<:AbstractArray}, T::AbstractArray +) + Base.copyto!(R, T) + return R +end diff --git a/NDTensors/src/blocksparse/blocksparsetensor.jl b/NDTensors/src/blocksparse/blocksparsetensor.jl index 71a97fe7d1..a591038512 100644 --- a/NDTensors/src/blocksparse/blocksparsetensor.jl +++ b/NDTensors/src/blocksparse/blocksparsetensor.jl @@ -177,15 +177,15 @@ function BlockSparseTensor( end function zeros( - ::BlockSparseTensor{ElT,N}, blockoffsets::BlockOffsets{N}, inds + tensor::BlockSparseTensor{ElT,N}, blockoffsets::BlockOffsets{N}, inds ) where {ElT,N} - return BlockSparseTensor(ElT, blockoffsets, inds) + return BlockSparseTensor(datatype(tensor), blockoffsets, inds) end function zeros( - ::Type{<:BlockSparseTensor{ElT,N}}, blockoffsets::BlockOffsets{N}, inds + tensortype::Type{<:BlockSparseTensor{ElT,N}}, blockoffsets::BlockOffsets{N}, inds ) where {ElT,N} - return BlockSparseTensor(ElT, blockoffsets, inds) + return BlockSparseTensor(datatype(tensortype), blockoffsets, inds) end function zeros(tensortype::Type{<:BlockSparseTensor}, inds) diff --git a/NDTensors/src/blocksparse/linearalgebra.jl b/NDTensors/src/blocksparse/linearalgebra.jl index 85522c311e..c7b2eb59b7 100644 --- a/NDTensors/src/blocksparse/linearalgebra.jl +++ b/NDTensors/src/blocksparse/linearalgebra.jl @@ -172,10 +172,9 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT} sU = right_arrow_sign(uind, blockU[2]) if sU == -1 - blockview(U, blockU) .= -Ub - else - blockview(U, blockU) .= Ub + Ub *= -1 end + copyto!(blockview(U, blockU), Ub) blockviewS = blockview(S, blockS) for i in 1:diaglength(Sb) @@ -193,12 +192,10 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT} end if (sV * sVP) == -1 - blockview(V, blockV) .= -Vb - else - blockview(V, blockV) .= Vb + Vb *= -1 end + copyto!(blockview(V, blockV), Vb) end - return U, S, V, Spectrum(d, truncerr) #end # @timeit_debug end diff --git a/NDTensors/src/dense/densetensor.jl b/NDTensors/src/dense/densetensor.jl index fafee27c73..f80dc99e8c 100644 --- a/NDTensors/src/dense/densetensor.jl +++ b/NDTensors/src/dense/densetensor.jl @@ -219,10 +219,9 @@ function permutedims!( return R end -function copyto!(R::DenseTensor{<:Number,N}, T::DenseTensor{<:Number,N}) where {N} - RA = array(R) - TA = array(T) - RA .= TA +# NDTensors.copyto! +function copyto!(R::DenseTensor, T::DenseTensor) + copyto!(array(R), array(T)) return R end diff --git a/NDTensors/src/imports.jl b/NDTensors/src/imports.jl index cfb2ad1b61..8925685297 100644 --- a/NDTensors/src/imports.jl +++ b/NDTensors/src/imports.jl @@ -17,7 +17,6 @@ import Base: convert, conj, copy, - copyto!, eachindex, eltype, empty, diff --git a/NDTensors/src/linearalgebra/linearalgebra.jl b/NDTensors/src/linearalgebra/linearalgebra.jl index 7f076f7c53..4b7c9a97e5 100644 --- a/NDTensors/src/linearalgebra/linearalgebra.jl +++ b/NDTensors/src/linearalgebra/linearalgebra.jl @@ -358,7 +358,6 @@ function eigen( #DM = DM[p] #VM = VM[:,p] - if truncate DM_cpu = NDTensors.cpu(DM) truncerr, _ = truncate!( diff --git a/NDTensors/src/linearalgebra/svd.jl b/NDTensors/src/linearalgebra/svd.jl index 47207fce6f..8dae0e9adc 100644 --- a/NDTensors/src/linearalgebra/svd.jl +++ b/NDTensors/src/linearalgebra/svd.jl @@ -35,8 +35,7 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int= V = M' * U V, R = qr_positive(V) - diagR = diag(R) - D[1:Nd] = diagR[1:Nd] + D[1:Nd] = diag(R)[1:Nd] (done, start) = checkSVDDone(D, thresh) diff --git a/src/mps/mps.jl b/src/mps/mps.jl index 9983dc9918..2a19b6c1e0 100644 --- a/src/mps/mps.jl +++ b/src/mps/mps.jl @@ -547,11 +547,9 @@ function replacebond!(M::MPS, b::Int, phi::ITensor; kwargs...) sbp1 = siteind(M, b + 1) indsMb = replaceind(indsMb, sb, sbp1) end - L, R, spec = factorize( phi, indsMb; which_decomp=which_decomp, tags=tags(linkind(M, b)), kwargs... ) - M[b] = L M[b + 1] = R if ortho == "left"