Skip to content

Commit

Permalink
[NDTensors] Add SortedSets and new TagSet prototype (#1204)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Oct 4, 2023
1 parent ea0f602 commit a9eb3cf
Show file tree
Hide file tree
Showing 32 changed files with 1,127 additions and 6 deletions.
1 change: 1 addition & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Compat
using Dictionaries
using FLoops
using Folds
using InlineStrings
using Random
using LinearAlgebra
using StaticArrays
Expand All @@ -21,6 +22,10 @@ include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("SmallVectors/src/SmallVectors.jl")
using .SmallVectors
include("SortedSets/src/SortedSets.jl")
using .SortedSets
include("TagSets/src/TagSets.jl")
using .TagSets

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

Expand Down
9 changes: 9 additions & 0 deletions NDTensors/src/SmallVectors/src/BaseExt/insertstyle.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Trait determining the style of inserting into a structure
abstract type InsertStyle end
struct IsInsertable <: InsertStyle end
struct NotInsertable <: InsertStyle end
struct FastCopy <: InsertStyle end

# Assume is insertable
@inline InsertStyle(::Type) = IsInsertable()
@inline InsertStyle(x) = InsertStyle(typeof(x))
242 changes: 242 additions & 0 deletions NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Union two unique sorted collections into an
# output buffer, returning a unique sorted collection.

using Base: Ordering, ord, lt

function unionsortedunique!(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return unionsortedunique!(itr1, itr2, ord(lt, by, rev, order))
end

function unionsortedunique!(itr1, itr2, order::Ordering)
i1 = firstindex(itr1)
i2 = firstindex(itr2)
stop1 = lastindex(itr1)
stop2 = lastindex(itr2)
@inbounds while i1 stop1 && i2 stop2
item1 = itr1[i1]
item2 = itr2[i2]
if lt(order, item1, item2)
i1 += 1
elseif lt(order, item2, item1)
# TODO: Use `insertat!`?
resize!(itr1, length(itr1) + 1)
for j in length(itr1):-1:(i1 + 1)
itr1[j] = itr1[j - 1]
end
itr1[i1] = item2
i1 += 1
i2 += 1
stop1 += 1
else # They are equal
i1 += 1
i2 += 1
end
end
# TODO: Use `insertat!`?
resize!(itr1, length(itr1) + (stop2 - i2 + 1))
@inbounds for j2 in i2:stop2
itr1[i1] = itr2[j2]
i1 += 1
end
return itr1
end

function unionsortedunique(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return unionsortedunique(itr1, itr2, ord(lt, by, rev, order))
end

# Union two unique sorted collections into an
# output buffer, returning a unique sorted collection.
function unionsortedunique(itr1, itr2, order::Ordering)
out = thaw_type(itr1)(undef, length(itr1))
i1 = firstindex(itr1)
i2 = firstindex(itr2)
iout = firstindex(out)
stop1 = lastindex(itr1)
stop2 = lastindex(itr2)
stopout = lastindex(out)
@inbounds while i1 stop1 && i2 stop2
iout > stopout && resize!(out, iout)
item1 = itr1[i1]
item2 = itr2[i2]
if lt(order, item1, item2)
out[iout] = item1
iout += 1
i1 += 1
elseif lt(order, item2, item1)
out[iout] = item2
iout += 1
i2 += 1
else # They are equal
out[iout] = item1
iout += 1
i1 += 1
i2 += 1
end
end
# In case `out` was too long to begin with.
## resize!(out, iout - 1)
# TODO: Use `insertat!`?
r1 = i1:stop1
resize!(out, length(out) + length(r1))
@inbounds for j1 in r1
out[iout] = itr1[j1]
iout += 1
end
# TODO: Use `insertat!`?
r2 = i2:stop2
resize!(out, length(out) + length(r2))
@inbounds for j2 in r2
out[iout] = itr2[j2]
iout += 1
end
return freeze(out)
end

function setdiffsortedunique!(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return setdiffsortedunique!(itr1, itr2, ord(lt, by, rev, order))
end

function setdiffsortedunique!(itr1, itr2, order::Ordering)
i1 = firstindex(itr1)
i2 = firstindex(itr2)
stop1 = lastindex(itr1)
stop2 = lastindex(itr2)
@inbounds while i1 stop1 && i2 stop2
item1 = itr1[i1]
item2 = itr2[i2]
if lt(order, item1, item2)
i1 += 1
elseif lt(order, item2, item1)
i2 += 1
else # They are equal
# TODO: Use `deletate!`?
for j1 in i1:(length(itr1) - 1)
itr1[j1] = itr1[j1 + 1]
end
resize!(itr1, length(itr1) - 1)
stop1 = lastindex(itr1)
i2 += 1
end
end
return itr1
end

function setdiffsortedunique(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return setdiffsortedunique(itr1, itr2, ord(lt, by, rev, order))
end

function setdiffsortedunique(itr1, itr2, order::Ordering)
out = thaw_type(itr1)()
i1 = firstindex(itr1)
i2 = firstindex(itr2)
iout = firstindex(out)
stop1 = lastindex(itr1)
stop2 = lastindex(itr2)
stopout = lastindex(out)
@inbounds while i1 stop1 && i2 stop2
item1 = itr1[i1]
item2 = itr2[i2]
if lt(order, item1, item2)
iout > stopout && resize!(out, iout)
out[iout] = item1
iout += 1
i1 += 1
elseif lt(order, item2, item1)
i2 += 1
else # They are equal
i1 += 1
i2 += 1
end
end
resize!(out, iout - 1)
return freeze(out)
end

function intersectsortedunique!(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return intersectsortedunique!(itr1, itr2, ord(lt, by, rev, order))
end

function intersectsortedunique!(itr1, itr2, order::Ordering)
return error("Not implemented")
end

function intersectsortedunique(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return intersectsortedunique(itr1, itr2, ord(lt, by, rev, order))
end

function intersectsortedunique(itr1, itr2, order::Ordering)
return error("Not implemented")
end

function symdiffsortedunique!(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return symdiffsortedunique!(itr1, itr2, ord(lt, by, rev, order))
end

function symdiffsortedunique!(itr1, itr2, order::Ordering)
return error("Not implemented")
end

function symdiffsortedunique(
itr1,
itr2;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
)
return symdiffsortedunique(itr1, itr2, ord(lt, by, rev, order))
end

function symdiffsortedunique(itr1, itr2, order::Ordering)
return error("Not implemented")
end
6 changes: 6 additions & 0 deletions NDTensors/src/SmallVectors/src/BaseExt/thawfreeze.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
thaw(x) = copy(x)
freeze(x) = x

thaw_type(::Type{<:AbstractArray{<:Any,N}}, ::Type{T}) where {T,N} = Array{T,N}
thaw_type(x::AbstractArray, ::Type{T}) where {T} = thaw_type(typeof(x), T)
thaw_type(x::AbstractArray{T}) where {T} = thaw_type(typeof(x), T)
30 changes: 29 additions & 1 deletion NDTensors/src/SmallVectors/src/SmallVectors.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,44 @@
module SmallVectors
using StaticArrays

export SmallVector, MSmallVector, SubSmallVector
export AbstractSmallVector,
SmallVector,
MSmallVector,
SubSmallVector,
FastCopy,
InsertStyle,
IsInsertable,
NotInsertable,
insert,
delete,
thaw,
freeze,
maxlength,
unionsortedunique,
unionsortedunique!,
setdiffsortedunique,
setdiffsortedunique!,
intersectsortedunique,
intersectsortedunique!,
symdiffsortedunique,
symdiffsortedunique!,
thaw_type

struct NotImplemented <: Exception
msg::String
end
NotImplemented() = NotImplemented("Not implemented.")

include("BaseExt/insertstyle.jl")
include("BaseExt/thawfreeze.jl")
include("BaseExt/sortedunique.jl")
include("abstractarray/insert.jl")
include("abstractsmallvector/abstractsmallvector.jl")
include("abstractsmallvector/deque.jl")
include("msmallvector/msmallvector.jl")
include("smallvector/smallvector.jl")
include("smallvector/insertstyle.jl")
include("msmallvector/thawfreeze.jl")
include("smallvector/thawfreeze.jl")
include("subsmallvector/subsmallvector.jl")
end
2 changes: 2 additions & 0 deletions NDTensors/src/SmallVectors/src/abstractarray/insert.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SmallVectors.insert(a::Vector, index::Integer, item) = insert!(copy(a), index, item)
delete(d::AbstractDict, key) = delete!(copy(d), key)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ similar_type(vec::AbstractSmallVector) = typeof(vec)

# Required buffer interface
maxlength(vec::AbstractSmallVector) = length(buffer(vec))
maxlength(vectype::Type{<:AbstractSmallVector}) = error("Not implemented")

function thaw_type(vectype::Type{<:AbstractSmallVector}, ::Type{T}) where {T}
return MSmallVector{maxlength(vectype),T}
end
thaw_type(vectype::Type{<:AbstractSmallVector{T}}) where {T} = thaw_type(vectype, T)

# Required AbstractArray interface
Base.size(vec::AbstractSmallVector) = throw(NotImplemented())
Expand Down
14 changes: 13 additions & 1 deletion NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,20 @@ end
return vec
end

@inline function Base.deleteat!(
vec::AbstractSmallVector, indices::AbstractUnitRange{<:Integer}
)
f = first(indices)
n = length(indices)
circshift!(smallview(vec, f, lastindex(vec)), -n)
resize!(vec, length(vec) - n)
return vec
end

# Don't @inline, makes it slower.
function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer)
function StaticArrays.deleteat(
vec::AbstractSmallVector, index::Union{Integer,AbstractUnitRange{<:Integer}}
)
mvec = Base.copymutable(vec)
deleteat!(mvec, index)
return convert(similar_type(vec), mvec)
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ function MSmallVector(buffer::AbstractVector, len::Int)
return MSmallVector{length(buffer),eltype(buffer)}(buffer, len)
end

maxlength(::Type{<:MSmallVector{S}}) where {S} = S

# Empty constructor
(msmallvector_type::Type{MSmallVector{S,T}} where {S,T})() = msmallvector_type(undef, 0)

"""
`MSmallVector` constructor, uses `MVector` as a buffer.
```julia
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/SmallVectors/src/msmallvector/thawfreeze.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
thaw(vec::MSmallVector) = copy(vec)
freeze(vec::MSmallVector) = SmallVector(vec)
1 change: 1 addition & 0 deletions NDTensors/src/SmallVectors/src/smallvector/insertstyle.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
InsertStyle(::Type{<:SmallVector}) = FastCopy()
Loading

0 comments on commit a9eb3cf

Please sign in to comment.