Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Remove SetParameters and replace with TypeParameterAccessors #1353

Merged
merged 23 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module NDTensorsCUDAExt

using NDTensors
using NDTensors.SetParameters
using NDTensors.Expose
using Adapt
using Functors
Expand Down
1 change: 0 additions & 1 deletion NDTensors/ext/NDTensorsCUDAExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import NDTensors: similartype
import NDTensors:
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!, iscu
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter
54 changes: 9 additions & 45 deletions NDTensors/ext/NDTensorsCUDAExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,21 @@
# `SetParameters.jl` overloads.
get_parameter(::Type{<:CuArray{P1}}, ::Position{1}) where {P1} = P1
get_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
get_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3

# Set parameter 1
set_parameter(::Type{<:CuArray}, ::Position{1}, P1) = CuArray{P1}
set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = CuArray{P1,P2}
function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
return CuArray{P1,<:Any,P3}
end
function set_parameter(::Type{<:CuArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
return CuArray{P1,P2,P3}
end

# Set parameter 2
set_parameter(::Type{<:CuArray}, ::Position{2}, P2) = CuArray{<:Any,P2}
set_parameter(::Type{<:CuArray{P1}}, ::Position{2}, P2) where {P1} = CuArray{P1,P2}
function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
return CuArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:CuArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
return CuArray{P1,P2,P3}
end

# Set parameter 3
set_parameter(::Type{<:CuArray}, ::Position{3}, P3) = CuArray{<:Any,<:Any,P3}
set_parameter(::Type{<:CuArray{P1}}, ::Position{3}, P3) where {P1} = CuArray{P1,<:Any,P3}
function set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
return CuArray{<:Any,P2,P3}
end
set_parameter(::Type{<:CuArray{P1,P2}}, ::Position{3}, P3) where {P1,P2} = CuArray{P1,P2,P3}

default_parameter(::Type{<:CuArray}, ::Position{1}) = Float64
default_parameter(::Type{<:CuArray}, ::Position{2}) = 1
default_parameter(::Type{<:CuArray}, ::Position{3}) = Mem.DeviceBuffer

nparameters(::Type{<:CuArray}) = Val(3)

SetParameters.unspecify_parameters(::Type{<:CuArray}) = CuArray

using NDTensors.TypeParameterAccessors: TypeParameterAccessors
# TypeParameterAccessors definitions
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position, set_type_parameter
using NDTensors.GPUArraysCoreExtensions: storagemode
## TODO remove TypeParameterAccessors when SetParameters is removed
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype))
return TypeParameterAccessors.Position(1)
return Position(1)
end
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(Base.ndims))
return TypeParameterAccessors.Position(2)
return Position(2)
end
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(storagemode))
return TypeParameterAccessors.Position(3)
return Position(3)
end

function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
return (Float64, 1, CUDA.Mem.DeviceBuffer)
end

function TypeParameterAccessors.set_ndims(type::Type{<:CuArray}, param)
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
return set_type_parameter(type, ndims, param)
end
5 changes: 3 additions & 2 deletions NDTensors/ext/NDTensorsMetalExt/adapt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using NDTensors.MetalExtensions: MetalExtensions
using NDTensors.GPUArraysCoreExtensions: GPUArraysCoreExtensions
using NDTensors.TypeParameterAccessors: specify_type_parameters, type_parameters

GPUArraysCoreExtensions.cpu(e::Exposed{<:MtlArray}) = adapt(Array, e)

Expand All @@ -10,7 +11,7 @@ end
# More general than the version in Metal.jl
## TODO Rewrite this using a custom `MtlArrayAdaptor` which will be written in `MetalExtensions`.
function Adapt.adapt_storage(arraytype::Type{<:MtlArray}, xs::AbstractArray)
params = get_parameters(xs)
arraytype_specified = specify_parameters(arraytype, params...)
params = type_parameters(xs)
arraytype_specified = specify_type_parameters(arraytype, params)
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
return isbitstype(typeof(xs)) ? xs : convert(arraytype_specified, xs)
end
2 changes: 0 additions & 2 deletions NDTensors/ext/NDTensorsMetalExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter

using NDTensors.Expose: Exposed, unexpose, expose
using Metal: DefaultStorageMode
using NDTensors: adapt
53 changes: 9 additions & 44 deletions NDTensors/ext/NDTensorsMetalExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,6 @@
# `SetParameters.jl` overloads.
get_parameter(::Type{<:MtlArray{P1}}, ::Position{1}) where {P1} = P1
get_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
get_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3
# `TypeParameterAccessors.jl` definitions.

# Set parameter 1
set_parameter(::Type{<:MtlArray}, ::Position{1}, P1) = MtlArray{P1}
set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = MtlArray{P1,P2}
function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
return MtlArray{P1,<:Any,P3}
end
function set_parameter(::Type{<:MtlArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
return MtlArray{P1,P2,P3}
end

# Set parameter 2
set_parameter(::Type{<:MtlArray}, ::Position{2}, P2) = MtlArray{<:Any,P2}
set_parameter(::Type{<:MtlArray{P1}}, ::Position{2}, P2) where {P1} = MtlArray{P1,P2}
function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
return MtlArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:MtlArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
return MtlArray{P1,P2,P3}
end

# Set parameter 3
set_parameter(::Type{<:MtlArray}, ::Position{3}, P3) = MtlArray{<:Any,<:Any,P3}
set_parameter(::Type{<:MtlArray{P1}}, ::Position{3}, P3) where {P1} = MtlArray{P1,<:Any,P3}
function set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
return MtlArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:MtlArray{P1,P2}}, ::Position{3}, P3) where {P1,P2}
return MtlArray{P1,P2,P3}
end

default_parameter(::Type{<:MtlArray}, ::Position{1}) = Float32
default_parameter(::Type{<:MtlArray}, ::Position{2}) = 1
default_parameter(::Type{<:MtlArray}, ::Position{3}) = Metal.DefaultStorageMode

nparameters(::Type{<:MtlArray}) = Val(3)

using NDTensors.TypeParameterAccessors: TypeParameterAccessors
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position, set_type_parameter
using NDTensors.GPUArraysCoreExtensions: storagemode
# Metal-specific type parameter setting
function set_storagemode(arraytype::Type{<:MtlArray}, param)
Expand All @@ -49,15 +10,19 @@ end
SetParameters.unspecify_parameters(::Type{<:MtlArray}) = MtlArray
## TODO remove TypeParameterAccessors when SetParameters is removed
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype))
return TypeParameterAccessors.Position(1)
return Position(1)
end
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(Base.ndims))
return TypeParameterAccessors.Position(2)
return Position(2)
end
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(storagemode))
return TypeParameterAccessors.Position(3)
return Position(3)
end

function TypeParameterAccessors.default_type_parameters(::Type{<:MtlArray})
return (Float32, 1, Metal.DefaultStorageMode)
end

function TypeParameterAccessors.set_ndims(type::Type{<:CuArray}, param)
return set_type_parameter(type, ndims, param)
end
1 change: 0 additions & 1 deletion NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ include("abstractarray/mul.jl")
include("abstractarray/append.jl")
include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
include("array/set_types.jl")
include("array/permutedims.jl")
include("array/mul.jl")
include("tupletools.jl")
Expand Down
11 changes: 3 additions & 8 deletions NDTensors/src/abstractarray/fill.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
using .SetParameters: DefaultParameters, specify_parameters
using .TypeParameterAccessors: unwrap_array_type
using .TypeParameterAccessors: unwrap_array_type, specify_default_type_parameters

function generic_randn(
arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng()
)
arraytype_specified = specify_parameters(
unwrap_array_type(arraytype), DefaultParameters()
)
arraytype_specified = specify_default_type_parameters(unwrap_array_type(arraytype))
data = similar(arraytype_specified, dim)
return randn!(rng, data)
end

function generic_zeros(arraytype::Type{<:AbstractArray}, dims...)
arraytype_specified = specify_parameters(
unwrap_array_type(arraytype), DefaultParameters()
)
arraytype_specified = specify_default_type_parameters(unwrap_array_type(arraytype))
ElT = eltype(arraytype_specified)
return fill!(similar(arraytype_specified, dims...), zero(ElT))
end
4 changes: 2 additions & 2 deletions NDTensors/src/abstractarray/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .SetParameters: set_ndims
using .TypeParameterAccessors: TypeParameterAccessors, set_ndims
"""
# Do we still want to define things like this?
TODO: Use `Accessors.jl` notation:
Expand All @@ -11,7 +11,7 @@ TODO: Use `Accessors.jl` notation:
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function SetParameters.set_ndims(numbertype::Type{<:Number}, ndims)
function TypeParameterAccessors.set_ndims(numbertype::Type{<:Number}, ndims)
return numbertype
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
end

Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type, set_eltype

## Custom `NDTensors.similar` implementation.
## More extensive than `Base.similar`.
Expand Down
5 changes: 3 additions & 2 deletions NDTensors/src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ double_precision(x) = fmap(x -> adapt(double_precision(eltype(x)), x), x)
# Used to adapt `EmptyStorage` types
#

using .TypeParameterAccessors: specify_type_parameter, specify_type_parameters
function adapt_storagetype(to::Type{<:AbstractVector}, x::Type{<:TensorStorage})
return set_datatype(x, specify_parameters(to, eltype(x)))
return set_datatype(x, specify_type_parameter(to, eltype, eltype(x)))
end

function adapt_storagetype(to::Type{<:AbstractArray}, x::Type{<:TensorStorage})
return set_datatype(x, specify_parameters(set_ndims(to, 1), eltype(x)))
return set_datatype(x, specify_type_parameter(to, (ndims, eltype), 1, eltype(x)))
end
12 changes: 0 additions & 12 deletions NDTensors/src/array/set_types.jl

This file was deleted.

2 changes: 1 addition & 1 deletion NDTensors/src/blocksparse/blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function set_datatype(storagetype::Type{<:BlockSparse}, datatype::Type{<:Abstrac
return BlockSparse{eltype(datatype),datatype,ndims(storagetype)}
end

function SetParameters.set_ndims(storagetype::Type{<:BlockSparse}, ndims::Int)
function TypeParameterAccessors.set_ndims(storagetype::Type{<:BlockSparse}, ndims::Int)
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
return BlockSparse{eltype(storagetype),datatype(storagetype),ndims}
end

Expand Down
27 changes: 4 additions & 23 deletions NDTensors/src/dense/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using .SetParameters:
SetParameters, Position, get_parameters, specify_parameters, unspecify_parameters
using .TypeParameterAccessors: TypeParameterAccessors, parenttype
using .TypeParameterAccessors: TypeParameterAccessors, Position, parenttype

function set_datatype(storagetype::Type{<:Dense}, datatype::Type{<:AbstractVector})
return Dense{eltype(datatype),datatype}
Expand All @@ -12,23 +10,6 @@ function set_datatype(storagetype::Type{<:Dense}, datatype::Type{<:AbstractArray
)
end

SetParameters.unspecify_parameters(::Type{<:Dense}) = Dense

SetParameters.parenttype_position(::Type{<:Dense}) = Position(2)
SetParameters.nparameters(::Type{<:Dense}) = Val(2)
SetParameters.get_parameter(::Type{<:Dense{P1}}, ::Position{1}) where {P1} = P1
SetParameters.get_parameter(::Type{<:Dense{<:Any,P2}}, ::Position{2}) where {P2} = P2
SetParameters.default_parameter(::Type{<:Dense}, ::Position{1}) = Float64
SetParameters.default_parameter(::Type{<:Dense}, ::Position{2}) = Vector

SetParameters.set_parameter(::Type{<:Dense}, ::Position{1}, P1) = Dense{P1}
function SetParameters.set_parameter(
::Type{<:Dense{<:Any,P2}}, ::Position{1}, P1
) where {P2}
return Dense{P1,P2}
end

SetParameters.set_parameter(::Type{<:Dense}, ::Position{2}, P2) = Dense{<:Any,P2}
function SetParameters.set_parameter(::Type{<:Dense{P1}}, ::Position{2}, P2) where {P1}
return Dense{P1,P2}
end
TypeParameterAccessors.default_type_parameters(::Type{<:Dense}) = (Float64, Vector)
TypeParameterAccessors.position(::Type{<:Dense}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:Dense}, ::typeof(parenttype)) = Position(2)
6 changes: 4 additions & 2 deletions NDTensors/src/diag/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
function SetParameters.set_eltype(storagetype::Type{<:UniformDiag}, eltype::Type)
using .TypeParameterAccessors: TypeParameterAccessors

function TypeParameterAccessors.set_eltype(storagetype::Type{<:UniformDiag}, eltype::Type)
return Diag{eltype,eltype}
end

function SetParameters.set_eltype(
function TypeParameterAccessors.set_eltype(
storagetype::Type{<:NonuniformDiag}, eltype::Type{<:AbstractArray}
)
return Diag{eltype,similartype(storagetype, eltype)}
Expand Down
1 change: 0 additions & 1 deletion NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ for lib in [
:CUDAExtensions,
:MetalExtensions,
:Expose,
:SetParameters,
:BroadcastMapConversion,
:RankFactorization,
:Sectors,
Expand Down
2 changes: 0 additions & 2 deletions NDTensors/src/lib/SetParameters/.JuliaFormatter.toml

This file was deleted.

8 changes: 0 additions & 8 deletions NDTensors/src/lib/SetParameters/Project.toml

This file was deleted.

3 changes: 0 additions & 3 deletions NDTensors/src/lib/SetParameters/README.md

This file was deleted.

38 changes: 0 additions & 38 deletions NDTensors/src/lib/SetParameters/TODO.md

This file was deleted.

This file was deleted.

Loading
Loading