Skip to content

Commit

Permalink
Remove ndims from NDTensors and replace in TypeParameterAcessesors
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Mar 15, 2024
1 parent 8d6b806 commit 78233a0
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 30 deletions.
4 changes: 0 additions & 4 deletions NDTensors/ext/NDTensorsCUDAExt/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,3 @@ end
function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
return (Float64, 1, CUDA.Mem.DeviceBuffer)
end

function TypeParameterAccessors.set_ndims(type::Type{<:CuArray}, param)
return set_type_parameter(type, ndims, param)
end
4 changes: 0 additions & 4 deletions NDTensors/ext/NDTensorsMetalExt/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,3 @@ end
function TypeParameterAccessors.default_type_parameters(::Type{<:MtlArray})
return (Float32, 1, Metal.DefaultStorageMode)
end

function TypeParameterAccessors.set_ndims(type::Type{<:MtlArray}, param)
return set_type_parameter(type, ndims, param)
end
1 change: 0 additions & 1 deletion NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ include("abstractarray/set_types.jl")
include("abstractarray/to_shape.jl")
include("abstractarray/iscu.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
include("abstractarray/mul.jl")
include("abstractarray/append.jl")
include("abstractarray/permutedims.jl")
Expand Down
10 changes: 0 additions & 10 deletions NDTensors/src/abstractarray/ndims.jl

This file was deleted.

6 changes: 3 additions & 3 deletions NDTensors/src/blocksparse/blockoffsets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ const BlockOffsets{N} = Dictionary{Block{N},Int}

BlockOffset(block::Block{N}, offset::Int) where {N} = BlockOffset{N}(block, offset)

ndims(::Blocks{N}) where {N} = N
ndims(::BlockOffset{N}) where {N} = N
ndims(::BlockOffsets{N}) where {N} = N
TypeParameterAccessors.ndims(::Blocks{N}) where {N} = N
TypeParameterAccessors.ndims(::BlockOffset{N}) where {N} = N
TypeParameterAccessors.ndims(::BlockOffsets{N}) where {N} = N

blocktype(bofs::BlockOffsets) = keytype(bofs)

Expand Down
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
# TODO: Implement as `fieldtype(storagetype, :data)`.
datatype(::Type{<:BlockSparse{<:Any,DataT}}) where {DataT} = DataT
# TODO: Implement as `ndims(blockoffsetstype(storagetype))`.
ndims(storagetype::Type{<:BlockSparse{<:Any,<:Any,N}}) where {N} = N
TypeParameterAccessors.ndims(storagetype::Type{<:BlockSparse{<:Any,<:Any,N}}) where {N} = N
# TODO: Implement as `fieldtype(storagetype, :blockoffsets)`.
blockoffsetstype(storagetype::Type{<:BlockSparse}) = BlockOffsets{ndims(storagetype)}

Expand Down Expand Up @@ -112,7 +112,7 @@ Base.real(::Type{BlockSparse{T}}) where {T} = BlockSparse{real(T)}

complex(::Type{BlockSparse{T}}) where {T} = BlockSparse{complex(T)}

ndims(::BlockSparse{T,V,N}) where {T,V,N} = N
TypeParameterAccessors.ndims(::BlockSparse{T,V,N}) where {T,V,N} = N

eltype(::BlockSparse{T}) where {T} = eltype(T)
# This is necessary since for some reason inference doesn't work
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/empty/EmptyTensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fulltype(T::Tensor) = fulltype(typeof(T))
# Needed for correct `NDTensors.ndims` definitions, for
# example `EmptyStorage` that wraps a `BlockSparse` which
# can have non-unity dimensions.
function ndims(storagetype::Type{<:EmptyStorage})
function TypeParameterAccessors.ndims(storagetype::Type{<:EmptyStorage})
return ndims(fulltype(storagetype))
end

Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ using Base.Threads: @spawn
using .CUDAExtensions: cu
using .MetalExtensions: mtl
using .GPUArraysCoreExtensions: cpu
## Adding this here for now so that ndims in `NDTensors`` uses `TypeParameterAccessors`
using .TypeParameterAccessors: ndims

import Base:
# Types
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include("specify_parameters.jl")
include("default_parameters.jl")
include("base/abstractarray.jl")
include("base/array.jl")
include("base/ndims.jl")
include("base/linearalgebra.jl")
include("base/stridedviews.jl")
end
4 changes: 0 additions & 4 deletions NDTensors/src/lib/TypeParameterAccessors/src/base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,3 @@ position(::Type{<:Array}, ::typeof(eltype)) = Position(1)
position(::Type{<:Array}, ::typeof(ndims)) = Position(2)

default_type_parameters(::Type{<:Array}) = (Float64, 1)

function set_ndims(type::Type{<:Array}, param)
return set_type_parameter(type, ndims, param)
end
17 changes: 17 additions & 0 deletions NDTensors/src/lib/TypeParameterAccessors/src/base/ndims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## NDTensors.ndims (not imported from Base)

## TODO So here I try to use the new type_parameters system for `NDTensors.ndims`
## But if `ndims` is not defined for a type, I revert to using Base.ndims
function TypeParameterAccessors.ndims(array)
try
type_parameter(array, Base.ndims)
catch
Base.ndims(array)
end
end

# ## In house patch to deal issue of calling ndims with an Array of unspecified eltype
# ## https://github.com/JuliaLang/julia/pull/40682
# if VERSION < v"1.7"
# TypeParameterAccessors.ndims(::Type{<:AbstractArray{<:Any,N}}) where {N} = N
# end
3 changes: 2 additions & 1 deletion NDTensors/src/tensor/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ function randomTensor(StoreT::Type{<:TensorStorage}, inds::Tuple)
end
## End Random Tensor

ndims(::Type{<:Tensor{<:Any,N}}) where {N} = N
## Potentially it would be better to define `TypeParameterAccessors.positoin(::Type{<:Tensor}, ::typeof(ndims)) = Position(2)` ?
TypeParameterAccessors.ndims(::Type{<:Tensor{<:Any,N}}) where {N} = N

# Like `Base.to_shape` but more general, can return
# `Index`, etc. Customize for an array/tensor
Expand Down

0 comments on commit 78233a0

Please sign in to comment.