Skip to content

Commit

Permalink
Updates to get dense working
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Feb 28, 2024
1 parent 591e547 commit 8b8182d
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 18 deletions.
12 changes: 7 additions & 5 deletions NDTensors/src/abstractarray/fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ using .TypeParameterAccessors: set_ndims, specify_default_parameters, unwrap_arr

## TODO I modified this to accept any type and just match the output to the number of dims.
# for example generic_zeros(Vector, 2,3) = Matrix{Float64}[0,0,0;0,0,0;]
function generic_randn(
arraytype::Type{<:AbstractArray}, dims...; rng=Random.default_rng()
)
arraytype_specified = set_ndims(specify_default_parameters(unwrap_array_type(arraytype)), length(dims));
function generic_randn(arraytype::Type{<:AbstractArray}, dims...; rng=Random.default_rng())
arraytype_specified = set_ndims(
specify_default_parameters(unwrap_array_type(arraytype)), length(dims)
)
data = similar(arraytype_specified, dims)
return randn!(rng, data)
end

function generic_zeros(arraytype::Type{<:AbstractArray}, dims...)
arraytype_specified = set_ndims(specify_default_parameters(unwrap_array_type(arraytype)), length(dims));
arraytype_specified = set_ndims(
specify_default_parameters(unwrap_array_type(arraytype)), length(dims)
)
ElT = eltype(arraytype_specified)
return fill!(similar(arraytype_specified, dims), zero(ElT))
end
3 changes: 2 additions & 1 deletion NDTensors/src/abstractarray/ndims.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## NDTensors.ndims (not imported from Base)
using .TypeParameterAccessors: TypeParameterAccessors, Self, type_parameter, set_ndims, set_parameter
using .TypeParameterAccessors:
TypeParameterAccessors, Self, type_parameter, set_ndims, set_parameter

ndims((array::AbstractArray)) = ndims(typeof(array))
ndims(arraytype::Type{<:AbstractArray}) = type_parameter(arraytype, Base.ndims)
Expand Down
7 changes: 3 additions & 4 deletions NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## todo working on this still
using .TypeParameterAccessors: IsWrappedArray, set_eltype, specify_default_parameters, unwrap_array_type
using .TypeParameterAccessors:
IsWrappedArray, set_eltype, specify_default_parameters, unwrap_array_type

## Custom `NDTensors.similar` implementation.
## More extensive than `Base.similar`.
Expand Down Expand Up @@ -107,9 +108,7 @@ function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple
end

## Set eltype captures WrappedArray types as long as `position(::Type, ::typeof(parenttype))` is defined
function similartype(
arraytype::Type{<:AbstractArray}, eltype::Type
)
function similartype(arraytype::Type{<:AbstractArray}, eltype::Type)
return set_eltype(arraytype, eltype)
end

Expand Down
22 changes: 20 additions & 2 deletions NDTensors/src/dense/typeparameteraccessors.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
using .TypeParameterAccessors:
TypeParameterAccessors, default_type_parameter, default_type_parameters, parenttype, position, unwrap_array_type, set_parenttype
TypeParameterAccessors,
default_type_parameter,
default_type_parameters,
parenttype,
position,
unwrap_array_type,
set_parenttype

## Dense
datatype(storetype::Type{<:Dense}) = parenttype(storetype)
Expand Down Expand Up @@ -29,4 +35,16 @@ function TypeParameterAccessors.position(::Type{<:Dense}, ::typeof(parenttype))
return TypeParameterAccessors.Position(2)
end

TypeParameterAccessors.default_type_parameters(::Type{<:Dense}) = (default_type_parameter(Vector, eltype), Vector{default_type_parameter(Vector, eltype)})
function TypeParameterAccessors.default_type_parameters(::Type{<:Dense})
return (
default_type_parameter(Vector, eltype), Vector{default_type_parameter(Vector, eltype)}
)
end

function TypeParameterAccessors.position(::Type{<:DenseTensor}, ::typeof(Base.ndims))
return TypeParameterAccessors.Position(2)
end

function TypeParameterAccessors.set_ndims(type::Type{<:DenseTensor}, N)
return set_parameter(type, Base.ndims, N)
end
5 changes: 3 additions & 2 deletions NDTensors/src/diag/diagtensor.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using .DiagonalArrays: diaglength
using .TypeParameterAccessors: unwrap_array_type

const DiagTensor{ElT,N,StoreT,IndsT} = Tensor{ElT,N,StoreT,IndsT} where {StoreT<:Diag}
const NonuniformDiagTensor{ElT,N,StoreT,IndsT} =
Expand Down Expand Up @@ -109,7 +110,7 @@ end

# convert to Dense
function dense(T::DiagTensor)
return dense(unwrap_type(T), T)
return dense(unwrap_array_type(T), T)
end

# CPU version
Expand All @@ -124,7 +125,7 @@ end
# GPU version
function dense(::Type{<:AbstractArray}, T::DiagTensor)
D_cpu = dense(Array, cpu(T))
return adapt(unwrap_type(T), D_cpu)
return adapt(unwrap_array_type(T), D_cpu)
end

# UniformDiag version
Expand Down
3 changes: 2 additions & 1 deletion NDTensors/src/tensorstorage/default_storage.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using .TypeParameterAccessors: specify_default_parameters
## This is a fil which specifies the default storage type provided some set of parameters
## The parameters are the element type and storage type
default_datatype(eltype::Type=default_eltype()) = Vector{eltype}
default_eltype() = Float64

## TODO use multiple dispace to make this pick between dense and blocksparse
function default_storagetype(datatype::Type{<:AbstractArray}, inds::Tuple)
datatype = specify_parameters(datatype)
datatype = specify_default_parameters(datatype)
return Dense{eltype(datatype),datatype}
end

Expand Down
5 changes: 2 additions & 3 deletions NDTensors/src/tensorstorage/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using .SetParameters:
SetParameters, Position, get_parameters, specify_parameters, unspecify_parameters
function SetParameters.set_ndims(arraytype::Type{<:TensorStorage}, ndims::Int)
using .TypeParameterAccessors: TypeParameterAccessors
function TypeParameterAccessors.set_ndims(arraytype::Type{<:TensorStorage}, ndims::Int)
# TODO: Change to this once `TensorStorage` types support wrapping
# non-AbstractVector types.
# return set_datatype(arraytype, set_ndims(datatype(arraytype), ndims))
Expand Down

0 comments on commit 8b8182d

Please sign in to comment.