diff --git a/NDTensors/src/abstractarray/fill.jl b/NDTensors/src/abstractarray/fill.jl index 1739fe09cb..fa50d7e153 100644 --- a/NDTensors/src/abstractarray/fill.jl +++ b/NDTensors/src/abstractarray/fill.jl @@ -2,16 +2,18 @@ using .TypeParameterAccessors: set_ndims, specify_default_parameters, unwrap_arr ## TODO I modified this to accept any type and just match the output to the number of dims. # for example generic_zeros(Vector, 2,3) = Matrix{Float64}[0,0,0;0,0,0;] -function generic_randn( - arraytype::Type{<:AbstractArray}, dims...; rng=Random.default_rng() -) - arraytype_specified = set_ndims(specify_default_parameters(unwrap_array_type(arraytype)), length(dims)); +function generic_randn(arraytype::Type{<:AbstractArray}, dims...; rng=Random.default_rng()) + arraytype_specified = set_ndims( + specify_default_parameters(unwrap_array_type(arraytype)), length(dims) + ) data = similar(arraytype_specified, dims) return randn!(rng, data) end function generic_zeros(arraytype::Type{<:AbstractArray}, dims...) - arraytype_specified = set_ndims(specify_default_parameters(unwrap_array_type(arraytype)), length(dims)); + arraytype_specified = set_ndims( + specify_default_parameters(unwrap_array_type(arraytype)), length(dims) + ) ElT = eltype(arraytype_specified) return fill!(similar(arraytype_specified, dims), zero(ElT)) end diff --git a/NDTensors/src/abstractarray/ndims.jl b/NDTensors/src/abstractarray/ndims.jl index ac98c5faf3..846fb3b4ca 100644 --- a/NDTensors/src/abstractarray/ndims.jl +++ b/NDTensors/src/abstractarray/ndims.jl @@ -1,5 +1,6 @@ ## NDTensors.ndims (not imported from Base) -using .TypeParameterAccessors: TypeParameterAccessors, Self, type_parameter, set_ndims, set_parameter +using .TypeParameterAccessors: + TypeParameterAccessors, Self, type_parameter, set_ndims, set_parameter ndims((array::AbstractArray)) = ndims(typeof(array)) ndims(arraytype::Type{<:AbstractArray}) = type_parameter(arraytype, Base.ndims) diff --git a/NDTensors/src/abstractarray/similar.jl b/NDTensors/src/abstractarray/similar.jl index 97ad947e88..62bbf3f7bf 100644 --- a/NDTensors/src/abstractarray/similar.jl +++ b/NDTensors/src/abstractarray/similar.jl @@ -1,5 +1,6 @@ ## todo working on this still -using .TypeParameterAccessors: IsWrappedArray, set_eltype, specify_default_parameters, unwrap_array_type +using .TypeParameterAccessors: + IsWrappedArray, set_eltype, specify_default_parameters, unwrap_array_type ## Custom `NDTensors.similar` implementation. ## More extensive than `Base.similar`. @@ -107,9 +108,7 @@ function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple end ## Set eltype captures WrappedArray types as long as `position(::Type, ::typeof(parenttype))` is defined -function similartype( - arraytype::Type{<:AbstractArray}, eltype::Type -) +function similartype(arraytype::Type{<:AbstractArray}, eltype::Type) return set_eltype(arraytype, eltype) end diff --git a/NDTensors/src/dense/typeparameteraccessors.jl b/NDTensors/src/dense/typeparameteraccessors.jl index dc9a87a106..34db7710f0 100644 --- a/NDTensors/src/dense/typeparameteraccessors.jl +++ b/NDTensors/src/dense/typeparameteraccessors.jl @@ -1,5 +1,11 @@ using .TypeParameterAccessors: - TypeParameterAccessors, default_type_parameter, default_type_parameters, parenttype, position, unwrap_array_type, set_parenttype + TypeParameterAccessors, + default_type_parameter, + default_type_parameters, + parenttype, + position, + unwrap_array_type, + set_parenttype ## Dense datatype(storetype::Type{<:Dense}) = parenttype(storetype) @@ -29,4 +35,16 @@ function TypeParameterAccessors.position(::Type{<:Dense}, ::typeof(parenttype)) return TypeParameterAccessors.Position(2) end -TypeParameterAccessors.default_type_parameters(::Type{<:Dense}) = (default_type_parameter(Vector, eltype), Vector{default_type_parameter(Vector, eltype)}) \ No newline at end of file +function TypeParameterAccessors.default_type_parameters(::Type{<:Dense}) + return ( + default_type_parameter(Vector, eltype), Vector{default_type_parameter(Vector, eltype)} + ) +end + +function TypeParameterAccessors.position(::Type{<:DenseTensor}, ::typeof(Base.ndims)) + return TypeParameterAccessors.Position(2) +end + +function TypeParameterAccessors.set_ndims(type::Type{<:DenseTensor}, N) + return set_parameter(type, Base.ndims, N) +end diff --git a/NDTensors/src/diag/diagtensor.jl b/NDTensors/src/diag/diagtensor.jl index ba0a93a920..ed219ca1e0 100644 --- a/NDTensors/src/diag/diagtensor.jl +++ b/NDTensors/src/diag/diagtensor.jl @@ -1,4 +1,5 @@ using .DiagonalArrays: diaglength +using .TypeParameterAccessors: unwrap_array_type const DiagTensor{ElT,N,StoreT,IndsT} = Tensor{ElT,N,StoreT,IndsT} where {StoreT<:Diag} const NonuniformDiagTensor{ElT,N,StoreT,IndsT} = @@ -109,7 +110,7 @@ end # convert to Dense function dense(T::DiagTensor) - return dense(unwrap_type(T), T) + return dense(unwrap_array_type(T), T) end # CPU version @@ -124,7 +125,7 @@ end # GPU version function dense(::Type{<:AbstractArray}, T::DiagTensor) D_cpu = dense(Array, cpu(T)) - return adapt(unwrap_type(T), D_cpu) + return adapt(unwrap_array_type(T), D_cpu) end # UniformDiag version diff --git a/NDTensors/src/tensorstorage/default_storage.jl b/NDTensors/src/tensorstorage/default_storage.jl index fe522da7a1..5e1f7c5d55 100644 --- a/NDTensors/src/tensorstorage/default_storage.jl +++ b/NDTensors/src/tensorstorage/default_storage.jl @@ -1,3 +1,4 @@ +using .TypeParameterAccessors: specify_default_parameters ## This is a fil which specifies the default storage type provided some set of parameters ## The parameters are the element type and storage type default_datatype(eltype::Type=default_eltype()) = Vector{eltype} @@ -5,7 +6,7 @@ default_eltype() = Float64 ## TODO use multiple dispace to make this pick between dense and blocksparse function default_storagetype(datatype::Type{<:AbstractArray}, inds::Tuple) - datatype = specify_parameters(datatype) + datatype = specify_default_parameters(datatype) return Dense{eltype(datatype),datatype} end diff --git a/NDTensors/src/tensorstorage/set_types.jl b/NDTensors/src/tensorstorage/set_types.jl index 174c655e90..7b27d4d245 100644 --- a/NDTensors/src/tensorstorage/set_types.jl +++ b/NDTensors/src/tensorstorage/set_types.jl @@ -1,6 +1,5 @@ -using .SetParameters: - SetParameters, Position, get_parameters, specify_parameters, unspecify_parameters -function SetParameters.set_ndims(arraytype::Type{<:TensorStorage}, ndims::Int) +using .TypeParameterAccessors: TypeParameterAccessors +function TypeParameterAccessors.set_ndims(arraytype::Type{<:TensorStorage}, ndims::Int) # TODO: Change to this once `TensorStorage` types support wrapping # non-AbstractVector types. # return set_datatype(arraytype, set_ndims(datatype(arraytype), ndims))