Skip to content

Commit

Permalink
Fix adapt functions for UnallocatedZeros
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Sep 20, 2023
1 parent 2f38cef commit ca75862
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
15 changes: 15 additions & 0 deletions NDTensors/ext/NDTensorsCUDAExt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,18 @@ function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::AbstractArray)
BufT = buffertype(adaptor)
return isbits(xs) ? xs : CuArray{ElT,1,BufT}(xs)
end

function Adapt.adapt_storage(arraytype::Type{<:CuArray}, xs::NDTensors.UnallocatedZeros)
arraytype_specified_1 = set_unspecified_parameters(
arraytype, Position(1), get_parameter(xs, Position(1))
)
arraytype_specified_2 = set_unspecified_parameters(
arraytype_specified_1, Position(2), get_parameter(xs, Position(2))
)
arraytype_specified_3 = set_unspecified_parameters(
arraytype_specified_2, Position(3), get_parameter(xs, Position(3))
)
elt = get_parameter(arraytype_specified_3, Position(1))
N = get_parameter(arraytype_specified_3, Position(2))
return NDTensors.UnallocatedZeros{elt, N, CUDA.CuArray{elt, N, default_parameter(CuArray, Position(3))}}(size(xs))
end
17 changes: 17 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,22 @@ function Adapt.adapt_storage(arraytype::Type{<:MtlArray}, xs::AbstractArray)
arraytype_specified_3 = set_unspecified_parameters(
arraytype_specified_2, Position(3), get_parameter(xs, Position(3))
)
@show convert(arraytype_specified_3, xs)
return isbitstype(typeof(xs)) ? xs : convert(arraytype_specified_3, xs)
end

function Adapt.adapt_storage(arraytype::Type{<:MtlArray}, xs::NDTensors.UnallocatedZeros)
arraytype_specified_1 = set_unspecified_parameters(
arraytype, Position(1), get_parameter(xs, Position(1))
)
arraytype_specified_2 = set_unspecified_parameters(
arraytype_specified_1, Position(2), get_parameter(xs, Position(2))
)
arraytype_specified_3 = set_unspecified_parameters(
arraytype_specified_2, Position(3), get_parameter(xs, Position(3))
)
elt = get_parameter(arraytype_specified_3, Position(1))
N = get_parameter(arraytype_specified_3, Position(2))

return NDTensors.UnallocatedZeros{elt, N, Metal.MtlArray{elt, N, default_parameter(MtlArray, Position(3))}}(size(xs))
end
11 changes: 6 additions & 5 deletions NDTensors/src/zeros/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import .SetParameters: set_parameter, nparameters, default_parameter

# `SetParameters.jl` overloads.
get_parameter(::Type{<:UnallocatedZeros{P1}}, ::Position{1}) where {P1} = P1
get_parameter(::Type{<:UnallocatedZeros{<:Any,P2}}, ::Position{2}) where {P2} = P2
get_parameter(::Type{<:UnallocatedZeros{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3
function get_parameter(
NDTensors.SetParameters.get_parameter(::Type{<:UnallocatedZeros{P1}}, ::Position{1}) where {P1} = P1
NDTensors.SetParameters.get_parameter(::Type{<:UnallocatedZeros{<:Any,P2}}, ::Position{2}) where {P2} = P2
NDTensors.SetParameters.get_parameter(::Type{<:UnallocatedZeros{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3
function NDTensors.SetParameters.get_parameter(
::Type{<:UnallocatedZeros{<:Any,<:Any,<:Any,P4}}, ::Position{4}
) where {P4}
return P4
Expand Down Expand Up @@ -36,12 +36,13 @@ end
function set_parameter(
::Type{<:UnallocatedZeros{P1,P2,P3,<:Any}}, ::Position{4}, P4
) where {P1,P2,P3}
@show P4
return UnallocatedZeros{P1,P2,P3,P4}
end

default_parameter(::Type{<:UnallocatedZeros}, ::Position{1}) = Float64
default_parameter(::Type{<:UnallocatedZeros}, ::Position{2}) = 1
default_parameter(::Type{<:UnallocatedZeros}, ::Position{3}) = Tuple{Base.OneTo{Int64}}
default_parameter(::Type{<:UnallocatedZeros}, ::Position{4}) = Vector{Float64}
default_parameter(::Type{<:UnallocatedZeros}, ::Position{4}) = Vector{default_parameter(UnallocatedZeros, Position(1))}

nparameters(::Type{<:UnallocatedZeros}) = Val(4)

0 comments on commit ca75862

Please sign in to comment.