Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Remove SetParameters and replace with TypeParameterAccessors #1353

Merged
merged 23 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module NDTensorsCUDAExt

using NDTensors
using NDTensors.SetParameters
using NDTensors.Expose
using Adapt
using Functors
Expand Down
1 change: 0 additions & 1 deletion NDTensors/ext/NDTensorsCUDAExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import NDTensors: similartype
import NDTensors:
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!, iscu
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter
52 changes: 6 additions & 46 deletions NDTensors/ext/NDTensorsCUDAExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,15 @@
# `SetParameters.jl` overloads.
get_parameter(::Type{<:CuArray{P1}}, ::Position{1}) where {P1} = P1
get_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
get_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3

# Set parameter 1
set_parameter(::Type{<:CuArray}, ::Position{1}, P1) = CuArray{P1}
set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = CuArray{P1,P2}
function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
return CuArray{P1,<:Any,P3}
end
function set_parameter(::Type{<:CuArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
return CuArray{P1,P2,P3}
end

# Set parameter 2
set_parameter(::Type{<:CuArray}, ::Position{2}, P2) = CuArray{<:Any,P2}
set_parameter(::Type{<:CuArray{P1}}, ::Position{2}, P2) where {P1} = CuArray{P1,P2}
function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
return CuArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:CuArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
return CuArray{P1,P2,P3}
end

# Set parameter 3
set_parameter(::Type{<:CuArray}, ::Position{3}, P3) = CuArray{<:Any,<:Any,P3}
set_parameter(::Type{<:CuArray{P1}}, ::Position{3}, P3) where {P1} = CuArray{P1,<:Any,P3}
function set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
return CuArray{<:Any,P2,P3}
end
set_parameter(::Type{<:CuArray{P1,P2}}, ::Position{3}, P3) where {P1,P2} = CuArray{P1,P2,P3}

default_parameter(::Type{<:CuArray}, ::Position{1}) = Float64
default_parameter(::Type{<:CuArray}, ::Position{2}) = 1
default_parameter(::Type{<:CuArray}, ::Position{3}) = Mem.DeviceBuffer

nparameters(::Type{<:CuArray}) = Val(3)

SetParameters.unspecify_parameters(::Type{<:CuArray}) = CuArray

using NDTensors.TypeParameterAccessors: TypeParameterAccessors
# TypeParameterAccessors definitions
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
using NDTensors.GPUArraysCoreExtensions: storagemode
## TODO remove TypeParameterAccessors when SetParameters is removed
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype))
return TypeParameterAccessors.Position(1)
return Position(1)
end
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(Base.ndims))
return TypeParameterAccessors.Position(2)
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(ndims))
return Position(2)
end
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(storagemode))
return TypeParameterAccessors.Position(3)
return Position(3)
end

function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
Expand Down
1 change: 0 additions & 1 deletion NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using Adapt
using Functors
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, mul!, qr, eigen, svd
using NDTensors
using NDTensors.SetParameters
using NDTensors.Expose: qr_positive, ql_positive, ql

using Metal
Expand Down
7 changes: 4 additions & 3 deletions NDTensors/ext/NDTensorsMetalExt/adapt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using NDTensors.MetalExtensions: MetalExtensions
using NDTensors.GPUArraysCoreExtensions: GPUArraysCoreExtensions
using NDTensors.GPUArraysCoreExtensions: GPUArraysCoreExtensions, set_storagemode
using NDTensors.TypeParameterAccessors: specify_type_parameters, type_parameters

GPUArraysCoreExtensions.cpu(e::Exposed{<:MtlArray}) = adapt(Array, e)

Expand All @@ -10,7 +11,7 @@ end
# More general than the version in Metal.jl
## TODO Rewrite this using a custom `MtlArrayAdaptor` which will be written in `MetalExtensions`.
function Adapt.adapt_storage(arraytype::Type{<:MtlArray}, xs::AbstractArray)
params = get_parameters(xs)
arraytype_specified = specify_parameters(arraytype, params...)
params = type_parameters(xs)
arraytype_specified = specify_type_parameters(arraytype, params)
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
return isbitstype(typeof(xs)) ? xs : convert(arraytype_specified, xs)
end
2 changes: 0 additions & 2 deletions NDTensors/ext/NDTensorsMetalExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter

using NDTensors.Expose: Exposed, unexpose, expose
using Metal: DefaultStorageMode
using NDTensors: adapt
17 changes: 14 additions & 3 deletions NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NDTensors.TypeParameterAccessors: unwrap_array_type
using NDTensors.TypeParameterAccessors:
set_type_parameters, type_parameters, unwrap_array_type

function LinearAlgebra.qr(A::Exposed{<:MtlMatrix})
Q, R = qr(expose(NDTensors.cpu(A)))
Expand All @@ -21,15 +22,25 @@ end

function LinearAlgebra.eigen(A::Exposed{<:MtlMatrix})
Dcpu, Ucpu = eigen(expose(NDTensors.cpu(A)))
D = adapt(set_ndims(set_eltype(unwrap_array_type(A), eltype(Dcpu)), ndims(Dcpu)), Dcpu)
D = adapt(
set_type_parameters(
unwrap_array_type(A), (eltype, ndims), type_parameters(Dcpu, (eltype, ndims))
),
Dcpu,
)
U = adapt(unwrap_array_type(A), Ucpu)
return D, U
end

function LinearAlgebra.svd(A::Exposed{<:MtlMatrix}; kwargs...)
Ucpu, Scpu, Vcpu = svd(expose(NDTensors.cpu(A)); kwargs...)
U = adapt(unwrap_array_type(A), Ucpu)
S = adapt(set_ndims(set_eltype(unwrap_array_type(A), eltype(Scpu)), ndims(Scpu)), Scpu)
S = adapt(
set_type_parameters(
unwrap_array_type(A), (eltype, ndims), type_parameters(Scpu, (eltype, ndims))
),
Scpu,
)
V = adapt(unwrap_array_type(A), Vcpu)
return U, S, V
end
56 changes: 6 additions & 50 deletions NDTensors/ext/NDTensorsMetalExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,17 @@
# `SetParameters.jl` overloads.
get_parameter(::Type{<:MtlArray{P1}}, ::Position{1}) where {P1} = P1
get_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
get_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3
# `TypeParameterAccessors.jl` definitions.

# Set parameter 1
set_parameter(::Type{<:MtlArray}, ::Position{1}, P1) = MtlArray{P1}
set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = MtlArray{P1,P2}
function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
return MtlArray{P1,<:Any,P3}
end
function set_parameter(::Type{<:MtlArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
return MtlArray{P1,P2,P3}
end

# Set parameter 2
set_parameter(::Type{<:MtlArray}, ::Position{2}, P2) = MtlArray{<:Any,P2}
set_parameter(::Type{<:MtlArray{P1}}, ::Position{2}, P2) where {P1} = MtlArray{P1,P2}
function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
return MtlArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:MtlArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
return MtlArray{P1,P2,P3}
end

# Set parameter 3
set_parameter(::Type{<:MtlArray}, ::Position{3}, P3) = MtlArray{<:Any,<:Any,P3}
set_parameter(::Type{<:MtlArray{P1}}, ::Position{3}, P3) where {P1} = MtlArray{P1,<:Any,P3}
function set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
return MtlArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:MtlArray{P1,P2}}, ::Position{3}, P3) where {P1,P2}
return MtlArray{P1,P2,P3}
end

default_parameter(::Type{<:MtlArray}, ::Position{1}) = Float32
default_parameter(::Type{<:MtlArray}, ::Position{2}) = 1
default_parameter(::Type{<:MtlArray}, ::Position{3}) = Metal.DefaultStorageMode

nparameters(::Type{<:MtlArray}) = Val(3)

using NDTensors.TypeParameterAccessors: TypeParameterAccessors
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position, set_type_parameter
using NDTensors.GPUArraysCoreExtensions: storagemode
# Metal-specific type parameter setting
function set_storagemode(arraytype::Type{<:MtlArray}, param)
return TypeParameterAccessors.set_type_parameter(arraytype, storagemode, param)
end

SetParameters.unspecify_parameters(::Type{<:MtlArray}) = MtlArray
## TODO remove TypeParameterAccessors when SetParameters is removed
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype))
return TypeParameterAccessors.Position(1)
return Position(1)
end
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(Base.ndims))
return TypeParameterAccessors.Position(2)
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(ndims))
return Position(2)
end
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(storagemode))
return TypeParameterAccessors.Position(3)
return Position(3)
end

function TypeParameterAccessors.default_type_parameters(::Type{<:MtlArray})
Expand Down
2 changes: 0 additions & 2 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ include("abstractarray/set_types.jl")
include("abstractarray/to_shape.jl")
include("abstractarray/iscu.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
include("abstractarray/mul.jl")
include("abstractarray/append.jl")
include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
include("array/set_types.jl")
include("array/permutedims.jl")
include("array/mul.jl")
include("tupletools.jl")
Expand Down
11 changes: 3 additions & 8 deletions NDTensors/src/abstractarray/fill.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
using .SetParameters: DefaultParameters, specify_parameters
using .TypeParameterAccessors: unwrap_array_type
using .TypeParameterAccessors: unwrap_array_type, specify_default_type_parameters

function generic_randn(
arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng()
)
arraytype_specified = specify_parameters(
unwrap_array_type(arraytype), DefaultParameters()
)
arraytype_specified = specify_default_type_parameters(unwrap_array_type(arraytype))
data = similar(arraytype_specified, dim)
return randn!(rng, data)
end

function generic_zeros(arraytype::Type{<:AbstractArray}, dims...)
arraytype_specified = specify_parameters(
unwrap_array_type(arraytype), DefaultParameters()
)
arraytype_specified = specify_default_type_parameters(unwrap_array_type(arraytype))
ElT = eltype(arraytype_specified)
return fill!(similar(arraytype_specified, dims...), zero(ElT))
end
10 changes: 0 additions & 10 deletions NDTensors/src/abstractarray/ndims.jl

This file was deleted.

4 changes: 2 additions & 2 deletions NDTensors/src/abstractarray/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .SetParameters: set_ndims
using .TypeParameterAccessors: TypeParameterAccessors, set_ndims
"""
# Do we still want to define things like this?
TODO: Use `Accessors.jl` notation:
Expand All @@ -11,7 +11,7 @@ TODO: Use `Accessors.jl` notation:
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function SetParameters.set_ndims(numbertype::Type{<:Number}, ndims)
function TypeParameterAccessors.set_ndims(numbertype::Type{<:Number}, ndims)
return numbertype
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
end

Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type, set_eltype

## Custom `NDTensors.similar` implementation.
## More extensive than `Base.similar`.
Expand Down
5 changes: 3 additions & 2 deletions NDTensors/src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ double_precision(x) = fmap(x -> adapt(double_precision(eltype(x)), x), x)
# Used to adapt `EmptyStorage` types
#

using .TypeParameterAccessors: specify_type_parameter, specify_type_parameters
function adapt_storagetype(to::Type{<:AbstractVector}, x::Type{<:TensorStorage})
return set_datatype(x, specify_parameters(to, eltype(x)))
return set_datatype(x, specify_type_parameter(to, eltype, eltype(x)))
end

function adapt_storagetype(to::Type{<:AbstractArray}, x::Type{<:TensorStorage})
return set_datatype(x, specify_parameters(set_ndims(to, 1), eltype(x)))
return set_datatype(x, specify_type_parameter(to, (ndims, eltype), (1, eltype(x))))
end
12 changes: 0 additions & 12 deletions NDTensors/src/array/set_types.jl

This file was deleted.

6 changes: 3 additions & 3 deletions NDTensors/src/blocksparse/blockoffsets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ const BlockOffsets{N} = Dictionary{Block{N},Int}

BlockOffset(block::Block{N}, offset::Int) where {N} = BlockOffset{N}(block, offset)

ndims(::Blocks{N}) where {N} = N
ndims(::BlockOffset{N}) where {N} = N
ndims(::BlockOffsets{N}) where {N} = N
Base.ndims(::Blocks{N}) where {N} = N
Base.ndims(::BlockOffset{N}) where {N} = N
Base.ndims(::BlockOffsets{N}) where {N} = N

blocktype(bofs::BlockOffsets) = keytype(bofs)

Expand Down
6 changes: 3 additions & 3 deletions NDTensors/src/blocksparse/blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ end
# TODO: Implement as `fieldtype(storagetype, :data)`.
datatype(::Type{<:BlockSparse{<:Any,DataT}}) where {DataT} = DataT
# TODO: Implement as `ndims(blockoffsetstype(storagetype))`.
ndims(storagetype::Type{<:BlockSparse{<:Any,<:Any,N}}) where {N} = N
Base.ndims(storagetype::Type{<:BlockSparse{<:Any,<:Any,N}}) where {N} = N
# TODO: Implement as `fieldtype(storagetype, :blockoffsets)`.
blockoffsetstype(storagetype::Type{<:BlockSparse}) = BlockOffsets{ndims(storagetype)}

function set_datatype(storagetype::Type{<:BlockSparse}, datatype::Type{<:AbstractVector})
return BlockSparse{eltype(datatype),datatype,ndims(storagetype)}
end

function SetParameters.set_ndims(storagetype::Type{<:BlockSparse}, ndims::Int)
function TypeParameterAccessors.set_ndims(storagetype::Type{<:BlockSparse}, ndims::Int)
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
return BlockSparse{eltype(storagetype),datatype(storagetype),ndims}
end

Expand Down Expand Up @@ -112,7 +112,7 @@ Base.real(::Type{BlockSparse{T}}) where {T} = BlockSparse{real(T)}

complex(::Type{BlockSparse{T}}) where {T} = BlockSparse{complex(T)}

ndims(::BlockSparse{T,V,N}) where {T,V,N} = N
Base.ndims(::BlockSparse{T,V,N}) where {T,V,N} = N

eltype(::BlockSparse{T}) where {T} = eltype(T)
# This is necessary since for some reason inference doesn't work
Expand Down
27 changes: 4 additions & 23 deletions NDTensors/src/dense/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using .SetParameters:
SetParameters, Position, get_parameters, specify_parameters, unspecify_parameters
using .TypeParameterAccessors: TypeParameterAccessors, parenttype
using .TypeParameterAccessors: TypeParameterAccessors, Position, parenttype

function set_datatype(storagetype::Type{<:Dense}, datatype::Type{<:AbstractVector})
return Dense{eltype(datatype),datatype}
Expand All @@ -12,23 +10,6 @@ function set_datatype(storagetype::Type{<:Dense}, datatype::Type{<:AbstractArray
)
end

SetParameters.unspecify_parameters(::Type{<:Dense}) = Dense

SetParameters.parenttype_position(::Type{<:Dense}) = Position(2)
SetParameters.nparameters(::Type{<:Dense}) = Val(2)
SetParameters.get_parameter(::Type{<:Dense{P1}}, ::Position{1}) where {P1} = P1
SetParameters.get_parameter(::Type{<:Dense{<:Any,P2}}, ::Position{2}) where {P2} = P2
SetParameters.default_parameter(::Type{<:Dense}, ::Position{1}) = Float64
SetParameters.default_parameter(::Type{<:Dense}, ::Position{2}) = Vector

SetParameters.set_parameter(::Type{<:Dense}, ::Position{1}, P1) = Dense{P1}
function SetParameters.set_parameter(
::Type{<:Dense{<:Any,P2}}, ::Position{1}, P1
) where {P2}
return Dense{P1,P2}
end

SetParameters.set_parameter(::Type{<:Dense}, ::Position{2}, P2) = Dense{<:Any,P2}
function SetParameters.set_parameter(::Type{<:Dense{P1}}, ::Position{2}, P2) where {P1}
return Dense{P1,P2}
end
TypeParameterAccessors.default_type_parameters(::Type{<:Dense}) = (Float64, Vector)
TypeParameterAccessors.position(::Type{<:Dense}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:Dense}, ::typeof(parenttype)) = Position(2)
6 changes: 4 additions & 2 deletions NDTensors/src/diag/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
function SetParameters.set_eltype(storagetype::Type{<:UniformDiag}, eltype::Type)
using .TypeParameterAccessors: TypeParameterAccessors

function TypeParameterAccessors.set_eltype(storagetype::Type{<:UniformDiag}, eltype::Type)
return Diag{eltype,eltype}
end

function SetParameters.set_eltype(
function TypeParameterAccessors.set_eltype(
storagetype::Type{<:NonuniformDiag}, eltype::Type{<:AbstractArray}
)
return Diag{eltype,similartype(storagetype, eltype)}
Expand Down
Loading
Loading