Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] UnallocatedArrays and UnspecifiedTypes #1213

Merged
merged 188 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 178 commits
Commits
Show all changes
188 commits
Select commit Hold shift + click to select a range
7453e04
Add FieldType module with Unallocated and Unspecified Types
kmp5VT Oct 18, 2023
af065e7
format
kmp5VT Oct 18, 2023
bd2a078
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Oct 18, 2023
5fa5fab
Merge branch 'main' into kmp5/feature/FieldTypes
mtfishman Oct 19, 2023
86872e1
Split up fieldtypes into Unallocated and Unspecified modules
kmp5VT Oct 19, 2023
071ae92
More UnallocatedFill implementation.
kmp5VT Oct 23, 2023
1a849c9
Start simplifying the UnallocatedZeros file
kmp5VT Oct 23, 2023
297aaa2
Fix typo
kmp5VT Oct 23, 2023
3fc526c
format
kmp5VT Oct 23, 2023
c9fb989
Update to the Unallocatedfill/zero code
kmp5VT Oct 23, 2023
e0e3652
format
kmp5VT Oct 23, 2023
7e58e30
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Oct 24, 2023
ae13d96
Updates to code. Still working on all of these
kmp5VT Oct 24, 2023
3839f88
Add some promote rules. Still working here
kmp5VT Oct 24, 2023
c1b834f
format
kmp5VT Oct 24, 2023
e9d8b06
Merge commit '66f5d391a6ef01834658ec9d5a9a8ba3d4b0f827' into kmp5/fea…
kmp5VT Nov 2, 2023
83bbd37
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 2, 2023
e2c59cb
Start working on Unallocated settype functions
kmp5VT Nov 2, 2023
145882e
format
kmp5VT Nov 2, 2023
8438d5d
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 8, 2023
b5ac4ce
Fix set_types
kmp5VT Nov 8, 2023
c136685
formatting
kmp5VT Nov 8, 2023
4d1b976
format
kmp5VT Nov 8, 2023
4e265ce
Update UnallocatedArrays
kmp5VT Nov 8, 2023
2b2a6b2
format
kmp5VT Nov 8, 2023
794d6ba
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 8, 2023
305c47b
CuArray name change, fix warning
kmp5VT Nov 8, 2023
1468c69
Some fixes
kmp5VT Nov 8, 2023
a25ec05
Add comment
kmp5VT Nov 8, 2023
a824fab
I don't need getindex function and `[I...]` still works
kmp5VT Nov 8, 2023
0fdb0a4
Convert is implemented wrong
kmp5VT Nov 8, 2023
fe051c5
Import setparameters
kmp5VT Nov 9, 2023
547fcd7
Working on set_parameter
kmp5VT Nov 9, 2023
cf1dff4
format
kmp5VT Nov 9, 2023
7963a54
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 9, 2023
568fe7d
Create set_alloctype constructor for UnallocatedArrays
kmp5VT Nov 9, 2023
52895c0
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Nov 9, 2023
8d0135d
Yes we should force Alloc to be the same dimensions as the `Fill`
kmp5VT Nov 9, 2023
9056599
Add set_alloctype and complex
kmp5VT Nov 9, 2023
c646598
format
kmp5VT Nov 9, 2023
2a9bc4f
We don't need to define norm
kmp5VT Nov 9, 2023
25b1ef4
Remove unecessary imports
kmp5VT Nov 9, 2023
c0de9dc
Remove set_eltype
kmp5VT Nov 9, 2023
2ad44e9
Remove line per matts comment
kmp5VT Nov 9, 2023
c56b3ac
Move set_alloctype constructors to their own files
kmp5VT Nov 9, 2023
fcd1997
Some updates
kmp5VT Nov 9, 2023
ab7e910
Add a unittest
kmp5VT Nov 9, 2023
e007458
format
kmp5VT Nov 9, 2023
bca8e77
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 9, 2023
df752fe
Remove @eval and fix the complex function
kmp5VT Nov 9, 2023
d254db8
Add complex testing
kmp5VT Nov 9, 2023
c4dc2af
format
kmp5VT Nov 9, 2023
4e6036b
Make a UnallocFillOrZero union variable
kmp5VT Nov 9, 2023
f87c2d3
format
kmp5VT Nov 9, 2023
aef5623
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 15, 2023
d635004
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 16, 2023
a0ceca4
Move abstractUnallocatedArray functions to file
kmp5VT Nov 20, 2023
8925888
Add UnallocatedArray.jl
kmp5VT Nov 20, 2023
2a694c0
Update set_types to use AbstractUnallocatedArrays
kmp5VT Nov 20, 2023
3027afa
Make other objects subtype of AbstractUnallocatedArray
kmp5VT Nov 20, 2023
85a9837
Add norm
kmp5VT Nov 20, 2023
3cbfdbe
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 30, 2023
b3ab1d6
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Nov 30, 2023
be31025
Move unallocatedArrays and UnspecifiedTypes to lib folder
kmp5VT Nov 30, 2023
7fe90f2
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 3, 2023
d945bd4
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Dec 3, 2023
d64f3de
Working on UnallocatedArrays
kmp5VT Dec 3, 2023
ac5cb59
Remove old file
kmp5VT Dec 4, 2023
834287a
Updates to UnallocatedArrays system
kmp5VT Dec 4, 2023
4c66224
Remove import to be more explicit
kmp5VT Dec 4, 2023
19b29df
Cleanup/fixes
kmp5VT Dec 4, 2023
ba869d3
Remove import from module
kmp5VT Dec 4, 2023
864c1be
Clean up
kmp5VT Dec 4, 2023
7fe9457
Add tests to NDTensors
kmp5VT Dec 4, 2023
01a828f
format
kmp5VT Dec 4, 2023
af20c8f
additional updates
kmp5VT Dec 4, 2023
3dc4514
Add FillArrays to NDTensors test
kmp5VT Dec 4, 2023
fdc440f
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 4, 2023
a3dca2b
Remove type piracy code
kmp5VT Dec 4, 2023
41d3f75
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Dec 4, 2023
eedc125
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 4, 2023
3fcbe11
Moved shared function to defaultunallocatedarray.jl
kmp5VT Dec 4, 2023
f92e4d9
format
kmp5VT Dec 4, 2023
71031a3
force itensors to update packages, using old FillArrays
kmp5VT Dec 4, 2023
9ea0ce5
change order
kmp5VT Dec 4, 2023
43933b9
Rename
kmp5VT Dec 4, 2023
9b963fd
rename
kmp5VT Dec 4, 2023
c4c4a74
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 6, 2023
e3ccfef
Define mult_zero for UnallocatedZeros and make unit tests
kmp5VT Dec 7, 2023
ec266f8
Define transpose and adjoint for unallocatedarrays
kmp5VT Dec 7, 2023
57d52e1
Define broadcasted_zeros
kmp5VT Dec 7, 2023
b6a51e7
Make kron_zero functions
kmp5VT Dec 7, 2023
0c4c92e
format
kmp5VT Dec 7, 2023
3c18596
fix typo
kmp5VT Dec 7, 2023
60ebf27
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 8, 2023
1dbf77f
Merge commit '60ebf274959a371f52588a9c500f00b873591262' into kmp5/fea…
kmp5VT Dec 8, 2023
57b9002
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 8, 2023
d47e50a
Remove update
kmp5VT Dec 8, 2023
b5d3609
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 8, 2023
82519ed
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Dec 8, 2023
7653d91
Add some functions to fix
kmp5VT Dec 8, 2023
0699e89
add fill functions
kmp5VT Dec 8, 2023
4a1cef2
overload kron_fill not kron_zeros
kmp5VT Dec 8, 2023
40af2cc
Update functions needed
kmp5VT Dec 8, 2023
1d107a6
Add more tests and sort them
kmp5VT Dec 8, 2023
e05d5ff
format
kmp5VT Dec 8, 2023
410a8df
Add more tests
kmp5VT Dec 8, 2023
5e3601e
format
kmp5VT Dec 8, 2023
3ce2ce8
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 11, 2023
4246757
Add more tests for UnallocatedFill
kmp5VT Dec 11, 2023
417e025
Move tests to braodcast
kmp5VT Dec 11, 2023
4e70bb3
Add more tests
kmp5VT Dec 11, 2023
52e21b0
Create a trival convert foreach UnallocatedArray
kmp5VT Dec 11, 2023
d874b1b
Remove promote_rules
kmp5VT Dec 12, 2023
611420b
Move using to each file
kmp5VT Dec 12, 2023
6011fbc
Create an unallocatedArray typedef
kmp5VT Dec 12, 2023
bcf89f1
Remove comment
kmp5VT Dec 12, 2023
a6b9ec7
Definitions to fix UnallocatedZeros addition
kmp5VT Dec 12, 2023
03655d0
Some additional fixes
kmp5VT Dec 12, 2023
e9ed6c5
Remove tensor tests
kmp5VT Dec 12, 2023
0c03fab
Test similar
kmp5VT Dec 12, 2023
2a04e0a
Remove NDTensor variables
kmp5VT Dec 12, 2023
24f185c
Remove comment
kmp5VT Dec 12, 2023
b64372d
Add comment
kmp5VT Dec 12, 2023
afebe2a
format
kmp5VT Dec 12, 2023
5bd1fd3
Update constructors and fix + for UnallocatedFill
kmp5VT Dec 12, 2023
a41a70c
Make only one constructor for now that requires an abstractarray type
kmp5VT Dec 13, 2023
4bd3028
Assue alloc defines constructor for type
kmp5VT Dec 13, 2023
cd42628
Remove comment
kmp5VT Dec 13, 2023
16daf49
remove comment
kmp5VT Dec 13, 2023
2693dfc
format
kmp5VT Dec 13, 2023
b7db76f
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 18, 2023
e7129b3
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 19, 2023
d149ef7
Remove `<:`
kmp5VT Dec 19, 2023
fab15ff
Update complex function to use `set_eltype`
kmp5VT Dec 19, 2023
baaf361
Some fixes
kmp5VT Dec 19, 2023
ea8735f
Update Base.:+
kmp5VT Dec 19, 2023
edf78cb
format
kmp5VT Dec 19, 2023
b33f82e
fix assert parentheses [no-ci]
kmp5VT Dec 19, 2023
e2c53d1
Use smallest number size and `iszero` function
kmp5VT Dec 19, 2023
1d6da8d
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Dec 19, 2023
5408031
Remove set_eltype function
kmp5VT Dec 22, 2023
199158b
Rename to unspecified_parameters and use in set_parameter functions.
kmp5VT Dec 22, 2023
d34f637
Add test for SetParameters of FillArray
kmp5VT Dec 22, 2023
fee954b
format
kmp5VT Dec 22, 2023
b74140c
Remove comments
kmp5VT Dec 22, 2023
5b5c057
format
kmp5VT Dec 22, 2023
8f0e361
Add SetParameter tests
kmp5VT Dec 22, 2023
85ece94
Overwrite Base.Braodcast.broadcasted for UnallocatedArrays to preserv…
kmp5VT Dec 22, 2023
aabd4e9
Clean up set types with Unspecify_parameters and UnallocatedArray
kmp5VT Dec 22, 2023
16afdb8
Make more tests for UnallocatedArrays
kmp5VT Dec 22, 2023
8f36ed4
Redefine allocate using similar
kmp5VT Dec 22, 2023
026a487
format
kmp5VT Dec 22, 2023
09fb302
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 22, 2023
68736f5
Update similar to have an allocate function based on AbstractArrays
kmp5VT Dec 22, 2023
a16027e
Set N
kmp5VT Dec 22, 2023
623af88
Move UnspecifiedTypes higher because its used in SetParameters
kmp5VT Dec 22, 2023
c18f1fb
To prepare for "SetParameters" to be a registered
kmp5VT Dec 22, 2023
e136946
format
kmp5VT Dec 22, 2023
b383f06
Add comments about set_eltype and set_ndim function
kmp5VT Jan 2, 2024
dcd06a8
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Jan 5, 2024
13c344c
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Jan 5, 2024
7de3788
Bump to jenkins to julia 1.10
kmp5VT Jan 5, 2024
2b48d3e
Remove type constraint in `set_alloctype`
kmp5VT Jan 8, 2024
108d8d9
No parenthasis around assert
kmp5VT Jan 8, 2024
ca5bb48
Define `similar(unallocated, eltype, size)` as is called by `similar(…
kmp5VT Jan 8, 2024
b63f603
Remove unecessary broadcast call
kmp5VT Jan 8, 2024
4df754c
Remove `::tuple` from `allocate`
kmp5VT Jan 8, 2024
9646e5d
loosen tollerance on a contract result
kmp5VT Jan 8, 2024
26d7699
Don't test if an element isn't 3 when constructing with similar. Test…
kmp5VT Jan 8, 2024
8ae8539
format
kmp5VT Jan 8, 2024
2f85238
Move Type based constructor out of UnallocatedX
kmp5VT Jan 9, 2024
209ce52
Add tests where parameters are unset
kmp5VT Jan 9, 2024
c538803
Fix a typo
kmp5VT Jan 9, 2024
35d01a5
Remove test and add todo
kmp5VT Jan 9, 2024
d0d27ba
format
kmp5VT Jan 9, 2024
abe185e
format
kmp5VT Jan 9, 2024
7a536ef
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Jan 9, 2024
c1f1410
Add broadcast functions for `UnallocatedZeros` to preserve type
kmp5VT Jan 10, 2024
bc62723
To ensure that FillArrays can be allocated, set ndims as well as eltype
kmp5VT Jan 10, 2024
a34337d
Update broadcast functions
kmp5VT Jan 10, 2024
777fd84
Add more tests to UnallocatedArrays
kmp5VT Jan 10, 2024
529c450
Remove inline `FillArrays.`
kmp5VT Jan 10, 2024
c62d09e
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Jan 11, 2024
a8f21fc
Use broadcast fill in `UnallocatedFill`
kmp5VT Jan 11, 2024
7bc4d0e
Replace f with z
kmp5VT Jan 11, 2024
f19e090
Comment out test which is only occasionally broken
kmp5VT Jan 11, 2024
f8ed013
format
kmp5VT Jan 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ for lib in [
:AlgorithmSelection,
:AllocateData,
:BaseExtensions,
:UnspecifiedTypes,
:SetParameters,
:BroadcastMapConversion,
:Unwrap,
Expand All @@ -37,6 +38,7 @@ for lib in [
:SmallVectors,
:SortedSets,
:TagSets,
:UnallocatedArrays,
]
include("lib/$(lib)/src/$(lib).jl")
@eval using .$lib: $lib
Expand Down
8 changes: 8 additions & 0 deletions NDTensors/src/lib/SetParameters/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[deps]
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"

[extensions]
SetParametersFillArraysExt = "FillArrays"

[compat]
PackageExtensionCompat = "1"
9 changes: 9 additions & 0 deletions NDTensors/src/lib/SetParameters/TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ default_parameter(::Type{<:AbstractArray}, ::Position{2}) = 1

nparameters(::Type{<:AbstractArray}) = Val(2)
```

https://github.com/ITensor/ITensors.jl/pull/1213/files#r1431708585

# Create generic set_eltype and set_ndims functions which can be defined on
# Array structures and use the `set_parameter`
```julia
set_eltype(T::Type, elt::Type) = Error("set_eltype is not defined for datatype $T")
set_eltype(T::Type{<:Array}, elt::Type) = set_parameter(T, Position{1}(), elt)
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
module SetParametersFillArraysExt
include("set_types.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using FillArrays: AbstractFill, Fill, Zeros
using NDTensors.SetParameters: SetParameters, Position
using NDTensors.UnspecifiedTypes: UnspecifiedZero

# `SetParameters.jl` overloads.
SetParameters.get_parameter(::Type{<:AbstractFill{P1}}, ::Position{1}) where {P1} = P1
SetParameters.get_parameter(::Type{<:AbstractFill{<:Any,P2}}, ::Position{2}) where {P2} = P2
function SetParameters.get_parameter(
::Type{<:AbstractFill{<:Any,<:Any,P3}}, ::Position{3}
) where {P3}
return P3
end

## Setting paramaters
# Set parameter 1
function SetParameters.set_parameter(T::Type{<:AbstractFill}, ::Position{1}, P1)
return unspecify_parameters(T){P1}
end
function SetParameters.set_parameter(
T::Type{<:AbstractFill{<:Any,P2}}, ::Position{1}, P1
) where {P2}
return unspecify_parameters(T){P1,P2}
end
function SetParameters.set_parameter(
T::Type{<:AbstractFill{<:Any,P2,P3}}, ::Position{1}, P1
) where {P2,P3}
return unspecify_parameters(T){P1,P2,P3}
end

# Set parameter 2
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1}}, ::Position{2}, P2
) where {P1}
return unspecify_parameters(T){P1,P2}
end
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1,P}}, ::Position{2}, P2
) where {P1,P}
return unspecify_parameters(T){P1,P2}
end
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1,P,P3}}, ::Position{2}, P2
) where {P1,P,P3}
return unspecify_parameters(T){P1,P2,P3}
end

# Set parameter 3
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1}}, ::Position{3}, P3
) where {P1}
return unspecify_parameters(T){P1,<:Any,P3}
end
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1,P2}}, ::Position{3}, P3
) where {P1,P2}
return unspecify_parameters(T){P1,P2,P3}
end
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1,P2,P}}, ::Position{3}, P3
) where {P1,P2,P}
return unspecify_parameters(T){P1,P2,P3}
end

## default parameters
function SetParameters.default_parameter(::Type{<:AbstractFill}, ::Position{1})
return UnspecifiedZero
end
SetParameters.default_parameter(::Type{<:AbstractFill}, ::Position{2}) = 0
SetParameters.default_parameter(::Type{<:AbstractFill}, ::Position{3}) = Tuple{}

SetParameters.nparameters(::Type{<:AbstractFill}) = Val(3)

## These helper functions take a UnallocatedArray type and
## remove all the parameters, this way all parameters can be set
## at once in the `set_parameter` functions above.
unspecify_parameters(::Type{<:Fill}) = Fill
unspecify_parameters(::Type{<:Zeros}) = Zeros
9 changes: 9 additions & 0 deletions NDTensors/src/lib/SetParameters/src/SetParameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ include("Base/val.jl")
include("Base/array.jl")
include("Base/subarray.jl")

## TODO when this is a full package utilize this function to
# # enable extensions
# using PackageExtensionCompat
# function __init__()
# @require_extensions
# end

include("../ext/SetParametersFillArraysExt/SetParametersFillArraysExt.jl")

export DefaultParameter,
DefaultParameters,
Position,
Expand Down
7 changes: 4 additions & 3 deletions NDTensors/src/lib/TensorAlgebra/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,10 @@ end
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol = default_rtol(
elt_dest
)
## Here we loosened the tolerance because of some floating point roundoff issue.
## with Float32 numbers
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
10 * default_rtol(elt_dest)
end
end
@testset "qr (eltype=$elt)" for elt in elts
Expand Down
4 changes: 4 additions & 0 deletions NDTensors/src/lib/UnallocatedArrays/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# UnallocatedArrays

A module defining a set of unallocated immutable lazy arrays which will be used to quickly construct
tensors and allocating as little data as possible.
10 changes: 10 additions & 0 deletions NDTensors/src/lib/UnallocatedArrays/src/UnallocatedArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module UnallocatedArrays
include("abstractfill/abstractfill.jl")

include("unallocatedfill.jl")
include("unallocatedzeros.jl")
include("abstractunallocatedarray.jl")
include("set_types.jl")

export UnallocatedFill, UnallocatedZeros, alloctype, set_alloctype, allocate
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using FillArrays: AbstractFill
using NDTensors.SetParameters: Position, get_parameter, set_parameters
## Here are functions specifically defined for UnallocatedArrays
## not implemented by FillArrays
## TODO this might need a more generic name maybe like compute unit
function alloctype(A::AbstractFill)
return A.alloc
end

## TODO this fails if the parameter is a type
function alloctype(Atype::Type{<:AbstractFill})
return get_parameter(Atype, Position{4}())
end

set_eltype(T::Type{<:AbstractFill}, elt::Type) = set_parameters(T, Position{1}(), elt)
set_ndims(T::Type{<:AbstractFill}, n) = set_parameters(T, Position{2}(), n)
set_axestype(T::Type{<:AbstractFill}, ax::Type) = set_parameters(T, Position{3}(), ax)
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using FillArrays: FillArrays, getindex_value
using NDTensors.SetParameters: set_parameters
using Adapt: adapt

const UnallocatedArray{ElT,N,AxesT,AllocT} = Union{
UnallocatedFill{ElT,N,AxesT,AllocT},UnallocatedZeros{ElT,N,AxesT,AllocT}
}

@inline Base.axes(A::UnallocatedArray) = axes(parent(A))
Base.size(A::UnallocatedArray) = size(parent(A))
function FillArrays.getindex_value(A::UnallocatedArray)
return getindex_value(parent(A))
end

function Base.complex(A::UnallocatedArray)
return complex(eltype(A)).(A)
end

function Base.transpose(a::UnallocatedArray)
return set_alloctype(transpose(parent(a)), alloctype(a))
end

function Base.adjoint(a::UnallocatedArray)
return set_alloctype(adjoint(parent(a)), alloctype(a))
end

function set_alloctype(T::Type{<:UnallocatedArray}, alloc::Type{<:AbstractArray})
return set_parameters(T, Position{4}(), alloc)
end

## This overloads the definition defined in `FillArrays.jl`
for STYPE in (:AbstractArray, :AbstractFill)
@eval begin
@inline $STYPE{T}(F::UnallocatedArray{T}) where {T} = F
@inline $STYPE{T,N}(F::UnallocatedArray{T,N}) where {T,N} = F
end
end

function allocate(f::UnallocatedArray)
a = similar(f)
fill!(a, getindex_value(f))
return a
end

function allocate(arraytype::Type{<:AbstractArray}, elt::Type, axes)
## TODO rewrite this using set_eltype and set_ndims functions
## currently these functions are defined in `NDTensors`
## In the future they should be defined in `SetParameters`
ArrayT = set_parameters(arraytype, Position{1}(), elt)
return similar(ArrayT, axes)
end

function Base.similar(f::UnallocatedArray, elt::Type, axes::Tuple{Int64,Vararg{Int64}})
return allocate(alloctype(f), elt, axes)
end

## TODO fix this because reshape loses alloctype
#FillArrays.reshape(a::Union{<:UnallocatedFill, <:UnallocatedZeros}, dims) = set_alloctype(reshape(parent(a), dims), allocate(a))

# function Adapt.adapt_storage(to::Type{<:AbstractArray}, x::Union{<:UnallocatedFill, <:UnallocatedZeros})
# return set_alloctype(parent(x), to)
# end

# function Adapt.adapt_storage(to::Type{<:Number}, x::)
45 changes: 45 additions & 0 deletions NDTensors/src/lib/UnallocatedArrays/src/set_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using NDTensors.SetParameters: SetParameters, Position
using NDTensors.UnspecifiedTypes: UnspecifiedArray, UnspecifiedNumber, UnspecifiedZero
# ## TODO make unit tests for all of these functions
## TODO All I need to do is overload AbstractFill functions with 4 parameters
# `SetParameters.jl` overloads.
function SetParameters.get_parameter(
::Type{<:UnallocatedArray{<:Any,<:Any,<:Any,P4}}, ::Position{4}
) where {P4}
return P4
end

# ## Setting paramaters
function SetParameters.set_parameter(
T::Type{<:UnallocatedArray{P,P2,P3,P4}}, ::Position{1}, P1
) where {P,P2,P3,P4}
return unspecify_parameters(T){P1,P2,P3,P4}
end

function SetParameters.set_parameter(
T::Type{<:UnallocatedArray{P1,P,P3,P4}}, ::Position{2}, P2
) where {P1,P,P3,P4}
return unspecify_parameters(T){P1,P2,P3,P4}
end

function SetParameters.set_parameter(
T::Type{<:UnallocatedArray{P1,P2,P,P4}}, ::Position{3}, P3
) where {P1,P2,P,P4}
return unspecify_parameters(T){P1,P2,P3,P4}
end

function SetParameters.set_parameter(
T::Type{<:UnallocatedArray{P1,P2,P3}}, ::Position{4}, P4
) where {P1,P2,P3}
return unspecify_parameters(T){P1,P2,P3,P4}
end

# ## default parameters
function SetParameters.default_parameter(::Type{<:UnallocatedArray}, ::Position{4})
return UnspecifiedArray{UnspecifiedNumber{UnspecifiedZero},0}
end

SetParameters.nparameters(::Type{<:UnallocatedArray}) = Val(4)

unspecify_parameters(::Type{<:UnallocatedFill}) = UnallocatedFill
unspecify_parameters(::Type{<:UnallocatedZeros}) = UnallocatedZeros
67 changes: 67 additions & 0 deletions NDTensors/src/lib/UnallocatedArrays/src/unallocatedfill.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using FillArrays: FillArrays, AbstractFill, Fill, broadcasted_fill, kron_fill, mult_fill
using NDTensors.SetParameters: Position, set_parameters

struct UnallocatedFill{ElT,N,Axes,Alloc} <: AbstractFill{ElT,N,Axes}
f::Fill{ElT,N,Axes}
alloc::Alloc
end

function UnallocatedFill{ElT,N,Axes}(f::Fill, alloc::Type) where {ElT,N,Axes}
return new{ElT,N,Axes,Type{alloc}}(f, alloc)
end

function UnallocatedFill{ElT,N}(f::Fill, alloc) where {ElT,N}
return UnallocatedFill{ElT,N,typeof(axes(f))}(f, alloc)
end

function UnallocatedFill{ElT}(f::Fill, alloc) where {ElT}
return UnallocatedFill{ElT,ndims(f)}(f, alloc)
end

set_alloctype(f::Fill, alloc::Type) = UnallocatedFill(f, alloc)

Base.parent(F::UnallocatedFill) = F.f

Base.convert(::Type{<:UnallocatedFill}, A::UnallocatedFill) = A

#############################################
# Arithmatic

# mult_fill(a, b, val, ax) = Fill(val, ax)
function FillArrays.mult_fill(a::UnallocatedFill, b, val, ax)
return UnallocatedFill(Fill(val, ax), alloctype(a))
end
FillArrays.mult_fill(a, b::UnallocatedFill, val, ax) = mult_fill(b, a, val, ax)
function FillArrays.mult_fill(a::UnallocatedFill, b::UnallocatedFill, val, ax)
@assert alloctype(a) == alloctype(b)
return UnallocatedFill(Fill(val, ax), alloctype(a))
end

function FillArrays.broadcasted_fill(f, a::UnallocatedFill, val, ax)
return UnallocatedFill(Fill(val, ax), alloctype(a))
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
end
function FillArrays.broadcasted_fill(f, a::UnallocatedFill, b::UnallocatedFill, val, ax)
@assert alloctype(a) == alloctype(b)
return UnallocatedFill(Fill(val, ax), alloctype(a))
end

function FillArrays.broadcasted_fill(f, a::UnallocatedFill, b, val, ax)
return UnallocatedFill(Fill(val, ax), alloctype(a))
end
function FillArrays.broadcasted_fill(f, a, b::UnallocatedFill, val, ax)
return broadcasted_fill(f, b, a, val, ax)
end

function FillArrays.kron_fill(a::UnallocatedFill, b::UnallocatedFill, val, ax)
@assert alloctype(a) == alloctype(b)
return UnallocatedFill(Fill(val, ax), alloctype(a))
end

Base.:+(A::UnallocatedFill, B::UnallocatedFill) = A .+ B

function Base.Broadcast.broadcasted(
::Base.Broadcast.DefaultArrayStyle, op, r::UnallocatedFill
)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
f = op.(parent(r))
return set_alloctype(f, set_parameters(alloctype(r), Position{1}(), eltype(f)))
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
end
Loading
Loading