From b7119d6fd305ffbccad02d49389bc83343c1e579 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 9 Nov 2023 11:58:17 -0500 Subject: [PATCH] [NDTensorsMetalExt] Update for latest `Unwrap` (#1243) --- .../NDTensorsMetalExt/NDTensorsMetalExt.jl | 3 +- NDTensors/ext/NDTensorsMetalExt/adapt.jl | 2 + NDTensors/ext/NDTensorsMetalExt/copyto.jl | 24 ++++++--- NDTensors/ext/NDTensorsMetalExt/indexing.jl | 5 ++ .../ext/NDTensorsMetalExt/linearalgebra.jl | 9 ++-- .../ext/NDTensorsMetalExt/permutedims.jl | 2 +- NDTensors/src/Unwrap/src/Unwrap.jl | 2 + NDTensors/src/Unwrap/src/functions/adapt.jl | 8 +++ NDTensors/src/Unwrap/test/runtests.jl | 54 +++++++++++++++---- NDTensors/src/abstractarray/append.jl | 1 + NDTensors/src/blocksparse/linearalgebra.jl | 4 +- 11 files changed, 90 insertions(+), 24 deletions(-) create mode 100644 NDTensors/src/Unwrap/src/functions/adapt.jl diff --git a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl index 2a1d3ee5e5..36d72d7209 100644 --- a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl +++ b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl @@ -2,7 +2,7 @@ module NDTensorsMetalExt using Adapt using Functors -using LinearAlgebra: LinearAlgebra, Transpose, mul!, qr, eigen, svd +using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, mul!, qr, eigen, svd using NDTensors using NDTensors.SetParameters using NDTensors.Unwrap: qr_positive, ql_positive, ql @@ -22,4 +22,5 @@ include("copyto.jl") include("append.jl") include("permutedims.jl") include("mul.jl") + end diff --git a/NDTensors/ext/NDTensorsMetalExt/adapt.jl b/NDTensors/ext/NDTensorsMetalExt/adapt.jl index ade6758c52..1a8df5bd95 100644 --- a/NDTensors/ext/NDTensorsMetalExt/adapt.jl +++ b/NDTensors/ext/NDTensorsMetalExt/adapt.jl @@ -1,3 +1,5 @@ +NDTensors.cpu(e::Exposed{<:MtlArray}) = adapt(Array, e) + function mtl(xs; storage=DefaultStorageMode) return adapt(set_storagemode(MtlArray, storage), xs) end diff --git a/NDTensors/ext/NDTensorsMetalExt/copyto.jl b/NDTensors/ext/NDTensorsMetalExt/copyto.jl index 2b2022839f..ba111fb099 100644 --- a/NDTensors/ext/NDTensorsMetalExt/copyto.jl +++ b/NDTensors/ext/NDTensorsMetalExt/copyto.jl @@ -1,13 +1,23 @@ -# Catches a bug in `copyto!` in Metal backend. -function NDTensors.copyto!( - ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::SubArray +function Base.copy(src::Exposed{<:MtlArray,<:Base.ReshapedArray}) + return reshape(copy(parent(src)), size(unexpose(src))) +end + +function Base.copy( + src::Exposed{ + <:MtlArray,<:SubArray{<:Any,<:Any,<:Base.ReshapedArray{<:Any,<:Any,<:Adjoint}} + }, ) - return Base.copyto!(dest, copy(src)) + return copy(@view copy(expose(parent(src)))[parentindices(unexpose(src))...]) +end + +# Catches a bug in `copyto!` in Metal backend. +function Base.copyto!(dest::Exposed{<:MtlArray}, src::Exposed{<:MtlArray,<:SubArray}) + return copyto!(dest, expose(copy(src))) end # Catches a bug in `copyto!` in Metal backend. -function NDTensors.copyto!( - ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::Base.ReshapedArray +function Base.copyto!( + dest::Exposed{<:MtlArray}, src::Exposed{<:MtlArray,<:Base.ReshapedArray} ) - return NDTensors.copyto!(dest, parent(src)) + return copyto!(dest, expose(parent(src))) end diff --git a/NDTensors/ext/NDTensorsMetalExt/indexing.jl b/NDTensors/ext/NDTensorsMetalExt/indexing.jl index 71bead3e72..a682088c8a 100644 --- a/NDTensors/ext/NDTensorsMetalExt/indexing.jl +++ b/NDTensors/ext/NDTensorsMetalExt/indexing.jl @@ -6,3 +6,8 @@ function Base.setindex!(E::Exposed{<:MtlArray}, x::Number) Metal.@allowscalar unexpose(E)[] = x return unexpose(E) end + +# Shared with `CuArray`. Move to `NDTensorsGPUArraysCoreExt`? +function Base.getindex(E::Exposed{<:MtlArray,<:Adjoint}, i, j) + return (expose(parent(E))[j, i])' +end diff --git a/NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl b/NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl index 5297c69063..ce52993c42 100644 --- a/NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl +++ b/NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl @@ -23,8 +23,9 @@ function LinearAlgebra.eigen(A::Exposed{<:MtlMatrix}) end function LinearAlgebra.svd(A::Exposed{<:MtlMatrix}; kwargs...) - U, S, V = svd(expose(NDTensors.cpu(A)); kwargs...) - return adapt(unwrap_type(A), U), - adapt(set_ndims(unwrap_type(A), ndims(S)), S), - adapt(unwrap_type(A), V) + Ucpu, Scpu, Vcpu = svd(expose(NDTensors.cpu(A)); kwargs...) + U = adapt(unwrap_type(A), Ucpu) + S = adapt(set_ndims(unwrap_type(A), ndims(Scpu)), Scpu) + V = adapt(unwrap_type(A), Vcpu) + return U, S, V end diff --git a/NDTensors/ext/NDTensorsMetalExt/permutedims.jl b/NDTensors/ext/NDTensorsMetalExt/permutedims.jl index ad830d0e88..ad29610527 100644 --- a/NDTensors/ext/NDTensorsMetalExt/permutedims.jl +++ b/NDTensors/ext/NDTensorsMetalExt/permutedims.jl @@ -1,4 +1,4 @@ -function permutedims!( +function Base.permutedims!( Edest::Exposed{<:MtlArray,<:Base.ReshapedArray}, Esrc::Exposed{<:MtlArray}, perm ) Aperm = permutedims(Esrc, perm) diff --git a/NDTensors/src/Unwrap/src/Unwrap.jl b/NDTensors/src/Unwrap/src/Unwrap.jl index 00b6448852..9ecdf8d578 100644 --- a/NDTensors/src/Unwrap/src/Unwrap.jl +++ b/NDTensors/src/Unwrap/src/Unwrap.jl @@ -3,6 +3,7 @@ using SimpleTraits using LinearAlgebra using Base: ReshapedArray using StridedViews +using Adapt: Adapt, adapt, adapt_structure include("expose.jl") include("iswrappedarray.jl") @@ -16,6 +17,7 @@ include("functions/copyto.jl") include("functions/linearalgebra.jl") include("functions/mul.jl") include("functions/permutedims.jl") +include("functions/adapt.jl") export IsWrappedArray, is_wrapped_array, parenttype, unwrap_type, expose, Exposed, unexpose, cpu diff --git a/NDTensors/src/Unwrap/src/functions/adapt.jl b/NDTensors/src/Unwrap/src/functions/adapt.jl new file mode 100644 index 0000000000..6ebc8bf7d6 --- /dev/null +++ b/NDTensors/src/Unwrap/src/functions/adapt.jl @@ -0,0 +1,8 @@ +Adapt.adapt(to, x::Exposed) = adapt_structure(to, x) +Adapt.adapt_structure(to, x::Exposed) = adapt_structure(to, unexpose(x)) + +# https://github.com/JuliaGPU/Adapt.jl/pull/51 +# TODO: Remove once https://github.com/JuliaGPU/Adapt.jl/issues/71 is addressed. +function Adapt.adapt_structure(to, A::Exposed{<:Any,<:Hermitian}) + return Hermitian(adapt(to, parent(unexpose(A))), Symbol(unexpose(A).uplo)) +end diff --git a/NDTensors/src/Unwrap/test/runtests.jl b/NDTensors/src/Unwrap/test/runtests.jl index 0cbf4006c7..c4ff4fd825 100644 --- a/NDTensors/src/Unwrap/test/runtests.jl +++ b/NDTensors/src/Unwrap/test/runtests.jl @@ -5,7 +5,9 @@ using LinearAlgebra include("../../../test/device_list.jl") @testset "Testing Unwrap" for dev in devices_list(ARGS) - v = dev(Vector{Float64}(undef, 10)) + elt = Float32 + + v = dev(Vector{elt}(undef, 10)) vt = transpose(v) va = v' @@ -37,7 +39,7 @@ include("../../../test/device_list.jl") @test typeof(Et) == Exposed{m_type,LinearAlgebra.Transpose{e_type,m_type}} @test typeof(Ea) == Exposed{m_type,LinearAlgebra.Adjoint{e_type,m_type}} - o = dev(Vector{Float32})(undef, 1) + o = dev(Vector{elt})(undef, 1) expose(o)[] = 2 @test expose(o)[] == 2 @@ -56,8 +58,8 @@ include("../../../test/device_list.jl") q, r = Unwrap.qr_positive(expose(mp)) @test q * r ≈ mp - square = dev(rand(Float64, (10, 10))) - square = (square + transpose(square)) ./ 2.0 + square = dev(rand(elt, (10, 10))) + square = (square + transpose(square)) / 2 ## CUDA only supports Hermitian or Symmetric eigen decompositions ## So I symmetrize square and call symetric here l, U = eigen(expose(Symmetric(square))) @@ -66,25 +68,59 @@ include("../../../test/device_list.jl") U, S, V, = svd(expose(mp)) @test U * Diagonal(S) * V' ≈ mp - cm = dev(fill!(Matrix{Float64}(undef, (2, 2)), 0.0)) + cm = dev(fill!(Matrix{elt}(undef, (2, 2)), 0.0)) mul!(expose(cm), expose(mp), expose(mp'), 1.0, 0.0) @test cm ≈ mp * mp' @test permutedims(expose(mp), (2, 1)) == transpose(mp) - fill!(mt, 3.0) + fill!(mt, 3) permutedims!(expose(m), expose(mt), (2, 1)) - @test norm(m) == sqrt(3^2 * 10) + @test norm(m) ≈ sqrt(3^2 * 10) @test size(m) == (5, 2) permutedims!(expose(m), expose(mt), (2, 1), +) @test size(m) == (5, 2) - @test norm(m) == sqrt(6^2 * 10) + @test norm(m) ≈ sqrt(6^2 * 10) m = reshape(m, (5, 2, 1)) mt = fill!(similar(m), 3.0) m = permutedims(expose(m), (2, 1, 3)) @test size(m) == (2, 5, 1) permutedims!(expose(m), expose(mt), (2, 1, 3)) - @test norm(m) == sqrt(3^2 * 10) + @test norm(m) ≈ sqrt(3^2 * 10) permutedims!(expose(m), expose(mt), (2, 1, 3), -) @test norm(m) == 0 + + x = dev(rand(elt, 4, 4)) + y = dev(rand(elt, 4, 4)) + copyto!(expose(y), expose(x)) + @test y == x + + y = dev(rand(elt, 4, 4)) + x = Base.ReshapedArray(dev(rand(elt, 16)), (4, 4), ()) + copyto!(expose(y), expose(x)) + @test NDTensors.cpu(y) == NDTensors.cpu(x) + @test NDTensors.cpu(copy(expose(x))) == NDTensors.cpu(x) + + y = dev(rand(elt, 4, 4)) + x = @view dev(rand(elt, 8, 8))[1:4, 1:4] + copyto!(expose(y), expose(x)) + @test y == x + @test copy(x) == x + + y = dev(randn(elt, 16)) + x = reshape(dev(randn(elt, 4, 4))', 16) + copyto!(expose(y), expose(x)) + @test y == x + @test copy(x) == x + + y = dev(randn(elt, 8)) + x = @view reshape(dev(randn(elt, 8, 8))', 64)[1:8] + copyto!(expose(y), expose(x)) + @test y == x + @test copy(x) == x + + y = Base.ReshapedArray(dev(randn(elt, 16)), (4, 4), ()) + x = dev(randn(elt, 4, 4)) + permutedims!(expose(y), expose(x), (2, 1)) + @test NDTensors.cpu(y) == transpose(NDTensors.cpu(x)) end diff --git a/NDTensors/src/abstractarray/append.jl b/NDTensors/src/abstractarray/append.jl index bdfe0e38fc..830d72ab58 100644 --- a/NDTensors/src/abstractarray/append.jl +++ b/NDTensors/src/abstractarray/append.jl @@ -1,6 +1,7 @@ # NDTensors.append! # Used to circumvent issues with some GPU backends like Metal # not supporting `resize!`. +# TODO: Change this over to use `expose`. function append!!(collection, collections...) return append!!(unwrap_type(collection), collection, collections...) end diff --git a/NDTensors/src/blocksparse/linearalgebra.jl b/NDTensors/src/blocksparse/linearalgebra.jl index 25917b1992..f3a830ccf7 100644 --- a/NDTensors/src/blocksparse/linearalgebra.jl +++ b/NDTensors/src/blocksparse/linearalgebra.jl @@ -181,7 +181,7 @@ function svd( if sU == -1 Ub *= -1 end - copyto!(blockview(U, blockU), Ub) + copyto!(expose(blockview(U, blockU)), expose(Ub)) blockviewS = blockview(S, blockS) # TODO: Replace `data` with `diagview`. @@ -200,7 +200,7 @@ function svd( if (sV * sVP) == -1 Vb *= -1 end - copyto!(blockview(V, blockV), Vb) + copyto!(expose(blockview(V, blockV)), expose(Vb)) end return U, S, V, Spectrum(d, truncerr) end