Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Sep 20, 2023
1 parent ca75862 commit 5597fe7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
8 changes: 6 additions & 2 deletions NDTensors/ext/NDTensorsCUDAExt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,9 @@ function Adapt.adapt_storage(arraytype::Type{<:CuArray}, xs::NDTensors.Unallocat
)
elt = get_parameter(arraytype_specified_3, Position(1))
N = get_parameter(arraytype_specified_3, Position(2))
return NDTensors.UnallocatedZeros{elt, N, CUDA.CuArray{elt, N, default_parameter(CuArray, Position(3))}}(size(xs))
end
return NDTensors.UnallocatedZeros{
elt,N,CUDA.CuArray{elt,N,default_parameter(CuArray, Position(3))}
}(
size(xs)
)
end
10 changes: 7 additions & 3 deletions NDTensors/ext/NDTensorsMetalExt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ function Adapt.adapt_storage(arraytype::Type{<:MtlArray}, xs::NDTensors.Unalloca
)
elt = get_parameter(arraytype_specified_3, Position(1))
N = get_parameter(arraytype_specified_3, Position(2))

return NDTensors.UnallocatedZeros{elt, N, Metal.MtlArray{elt, N, default_parameter(MtlArray, Position(3))}}(size(xs))
end

return NDTensors.UnallocatedZeros{
elt,N,Metal.MtlArray{elt,N,default_parameter(MtlArray, Position(3))}
}(
size(xs)
)
end
24 changes: 19 additions & 5 deletions NDTensors/src/zeros/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
import .SetParameters: set_parameter, nparameters, default_parameter

# `SetParameters.jl` overloads.
NDTensors.SetParameters.get_parameter(::Type{<:UnallocatedZeros{P1}}, ::Position{1}) where {P1} = P1
NDTensors.SetParameters.get_parameter(::Type{<:UnallocatedZeros{<:Any,P2}}, ::Position{2}) where {P2} = P2
NDTensors.SetParameters.get_parameter(::Type{<:UnallocatedZeros{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3
function NDTensors.SetParameters.get_parameter(
::Type{<:UnallocatedZeros{P1}}, ::Position{1}
) where {P1}
return P1
end
function NDTensors.SetParameters.get_parameter(
::Type{<:UnallocatedZeros{<:Any,P2}}, ::Position{2}
) where {P2}
return P2
end
function NDTensors.SetParameters.get_parameter(
::Type{<:UnallocatedZeros{<:Any,<:Any,P3}}, ::Position{3}
) where {P3}
return P3
end
function NDTensors.SetParameters.get_parameter(
::Type{<:UnallocatedZeros{<:Any,<:Any,<:Any,P4}}, ::Position{4}
) where {P4}
Expand Down Expand Up @@ -36,13 +48,15 @@ end
function set_parameter(
::Type{<:UnallocatedZeros{P1,P2,P3,<:Any}}, ::Position{4}, P4
) where {P1,P2,P3}
@show P4
@show P4
return UnallocatedZeros{P1,P2,P3,P4}
end

default_parameter(::Type{<:UnallocatedZeros}, ::Position{1}) = Float64
default_parameter(::Type{<:UnallocatedZeros}, ::Position{2}) = 1
default_parameter(::Type{<:UnallocatedZeros}, ::Position{3}) = Tuple{Base.OneTo{Int64}}
default_parameter(::Type{<:UnallocatedZeros}, ::Position{4}) = Vector{default_parameter(UnallocatedZeros, Position(1))}
function default_parameter(::Type{<:UnallocatedZeros}, ::Position{4})
return Vector{default_parameter(UnallocatedZeros, Position(1))}
end

nparameters(::Type{<:UnallocatedZeros}) = Val(4)

0 comments on commit 5597fe7

Please sign in to comment.