diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index c5d46e7592..252966c7a6 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -11,6 +11,7 @@ FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index 38d5f53a05..62cd081654 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -6,6 +6,7 @@ using Compat using Dictionaries using FLoops using Folds +using InlineStrings using Random using LinearAlgebra using StaticArrays @@ -21,6 +22,10 @@ include("SetParameters/src/SetParameters.jl") using .SetParameters include("SmallVectors/src/SmallVectors.jl") using .SmallVectors +include("SortedSets/src/SortedSets.jl") +using .SortedSets +include("TagSets/src/TagSets.jl") +using .TagSets using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo diff --git a/NDTensors/src/SmallVectors/src/BaseExt/insertstyle.jl b/NDTensors/src/SmallVectors/src/BaseExt/insertstyle.jl new file mode 100644 index 0000000000..c5008fef87 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/BaseExt/insertstyle.jl @@ -0,0 +1,9 @@ +# Trait determining the style of inserting into a structure +abstract type InsertStyle end +struct IsInsertable <: InsertStyle end +struct NotInsertable <: InsertStyle end +struct FastCopy <: InsertStyle end + +# Assume is insertable +@inline InsertStyle(::Type) = IsInsertable() +@inline InsertStyle(x) = InsertStyle(typeof(x)) diff --git a/NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl b/NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl new file mode 100644 index 0000000000..34fa0ff4c9 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl @@ -0,0 +1,242 @@ +# Union two unique sorted collections into an +# output buffer, returning a unique sorted collection. + +using Base: Ordering, ord, lt + +function unionsortedunique!( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return unionsortedunique!(itr1, itr2, ord(lt, by, rev, order)) +end + +function unionsortedunique!(itr1, itr2, order::Ordering) + i1 = firstindex(itr1) + i2 = firstindex(itr2) + stop1 = lastindex(itr1) + stop2 = lastindex(itr2) + @inbounds while i1 ≤ stop1 && i2 ≤ stop2 + item1 = itr1[i1] + item2 = itr2[i2] + if lt(order, item1, item2) + i1 += 1 + elseif lt(order, item2, item1) + # TODO: Use `insertat!`? + resize!(itr1, length(itr1) + 1) + for j in length(itr1):-1:(i1 + 1) + itr1[j] = itr1[j - 1] + end + itr1[i1] = item2 + i1 += 1 + i2 += 1 + stop1 += 1 + else # They are equal + i1 += 1 + i2 += 1 + end + end + # TODO: Use `insertat!`? + resize!(itr1, length(itr1) + (stop2 - i2 + 1)) + @inbounds for j2 in i2:stop2 + itr1[i1] = itr2[j2] + i1 += 1 + end + return itr1 +end + +function unionsortedunique( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return unionsortedunique(itr1, itr2, ord(lt, by, rev, order)) +end + +# Union two unique sorted collections into an +# output buffer, returning a unique sorted collection. +function unionsortedunique(itr1, itr2, order::Ordering) + out = thaw_type(itr1)(undef, length(itr1)) + i1 = firstindex(itr1) + i2 = firstindex(itr2) + iout = firstindex(out) + stop1 = lastindex(itr1) + stop2 = lastindex(itr2) + stopout = lastindex(out) + @inbounds while i1 ≤ stop1 && i2 ≤ stop2 + iout > stopout && resize!(out, iout) + item1 = itr1[i1] + item2 = itr2[i2] + if lt(order, item1, item2) + out[iout] = item1 + iout += 1 + i1 += 1 + elseif lt(order, item2, item1) + out[iout] = item2 + iout += 1 + i2 += 1 + else # They are equal + out[iout] = item1 + iout += 1 + i1 += 1 + i2 += 1 + end + end + # In case `out` was too long to begin with. + ## resize!(out, iout - 1) + # TODO: Use `insertat!`? + r1 = i1:stop1 + resize!(out, length(out) + length(r1)) + @inbounds for j1 in r1 + out[iout] = itr1[j1] + iout += 1 + end + # TODO: Use `insertat!`? + r2 = i2:stop2 + resize!(out, length(out) + length(r2)) + @inbounds for j2 in r2 + out[iout] = itr2[j2] + iout += 1 + end + return freeze(out) +end + +function setdiffsortedunique!( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return setdiffsortedunique!(itr1, itr2, ord(lt, by, rev, order)) +end + +function setdiffsortedunique!(itr1, itr2, order::Ordering) + i1 = firstindex(itr1) + i2 = firstindex(itr2) + stop1 = lastindex(itr1) + stop2 = lastindex(itr2) + @inbounds while i1 ≤ stop1 && i2 ≤ stop2 + item1 = itr1[i1] + item2 = itr2[i2] + if lt(order, item1, item2) + i1 += 1 + elseif lt(order, item2, item1) + i2 += 1 + else # They are equal + # TODO: Use `deletate!`? + for j1 in i1:(length(itr1) - 1) + itr1[j1] = itr1[j1 + 1] + end + resize!(itr1, length(itr1) - 1) + stop1 = lastindex(itr1) + i2 += 1 + end + end + return itr1 +end + +function setdiffsortedunique( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return setdiffsortedunique(itr1, itr2, ord(lt, by, rev, order)) +end + +function setdiffsortedunique(itr1, itr2, order::Ordering) + out = thaw_type(itr1)() + i1 = firstindex(itr1) + i2 = firstindex(itr2) + iout = firstindex(out) + stop1 = lastindex(itr1) + stop2 = lastindex(itr2) + stopout = lastindex(out) + @inbounds while i1 ≤ stop1 && i2 ≤ stop2 + item1 = itr1[i1] + item2 = itr2[i2] + if lt(order, item1, item2) + iout > stopout && resize!(out, iout) + out[iout] = item1 + iout += 1 + i1 += 1 + elseif lt(order, item2, item1) + i2 += 1 + else # They are equal + i1 += 1 + i2 += 1 + end + end + resize!(out, iout - 1) + return freeze(out) +end + +function intersectsortedunique!( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return intersectsortedunique!(itr1, itr2, ord(lt, by, rev, order)) +end + +function intersectsortedunique!(itr1, itr2, order::Ordering) + return error("Not implemented") +end + +function intersectsortedunique( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return intersectsortedunique(itr1, itr2, ord(lt, by, rev, order)) +end + +function intersectsortedunique(itr1, itr2, order::Ordering) + return error("Not implemented") +end + +function symdiffsortedunique!( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return symdiffsortedunique!(itr1, itr2, ord(lt, by, rev, order)) +end + +function symdiffsortedunique!(itr1, itr2, order::Ordering) + return error("Not implemented") +end + +function symdiffsortedunique( + itr1, + itr2; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return symdiffsortedunique(itr1, itr2, ord(lt, by, rev, order)) +end + +function symdiffsortedunique(itr1, itr2, order::Ordering) + return error("Not implemented") +end diff --git a/NDTensors/src/SmallVectors/src/BaseExt/thawfreeze.jl b/NDTensors/src/SmallVectors/src/BaseExt/thawfreeze.jl new file mode 100644 index 0000000000..7f6cd037a5 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/BaseExt/thawfreeze.jl @@ -0,0 +1,6 @@ +thaw(x) = copy(x) +freeze(x) = x + +thaw_type(::Type{<:AbstractArray{<:Any,N}}, ::Type{T}) where {T,N} = Array{T,N} +thaw_type(x::AbstractArray, ::Type{T}) where {T} = thaw_type(typeof(x), T) +thaw_type(x::AbstractArray{T}) where {T} = thaw_type(typeof(x), T) diff --git a/NDTensors/src/SmallVectors/src/SmallVectors.jl b/NDTensors/src/SmallVectors/src/SmallVectors.jl index e3f7083795..75b5ec200c 100644 --- a/NDTensors/src/SmallVectors/src/SmallVectors.jl +++ b/NDTensors/src/SmallVectors/src/SmallVectors.jl @@ -1,16 +1,44 @@ module SmallVectors using StaticArrays -export SmallVector, MSmallVector, SubSmallVector +export AbstractSmallVector, + SmallVector, + MSmallVector, + SubSmallVector, + FastCopy, + InsertStyle, + IsInsertable, + NotInsertable, + insert, + delete, + thaw, + freeze, + maxlength, + unionsortedunique, + unionsortedunique!, + setdiffsortedunique, + setdiffsortedunique!, + intersectsortedunique, + intersectsortedunique!, + symdiffsortedunique, + symdiffsortedunique!, + thaw_type struct NotImplemented <: Exception msg::String end NotImplemented() = NotImplemented("Not implemented.") +include("BaseExt/insertstyle.jl") +include("BaseExt/thawfreeze.jl") +include("BaseExt/sortedunique.jl") +include("abstractarray/insert.jl") include("abstractsmallvector/abstractsmallvector.jl") include("abstractsmallvector/deque.jl") include("msmallvector/msmallvector.jl") include("smallvector/smallvector.jl") +include("smallvector/insertstyle.jl") +include("msmallvector/thawfreeze.jl") +include("smallvector/thawfreeze.jl") include("subsmallvector/subsmallvector.jl") end diff --git a/NDTensors/src/SmallVectors/src/abstractarray/insert.jl b/NDTensors/src/SmallVectors/src/abstractarray/insert.jl new file mode 100644 index 0000000000..3a864aabe6 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/abstractarray/insert.jl @@ -0,0 +1,2 @@ +SmallVectors.insert(a::Vector, index::Integer, item) = insert!(copy(a), index, item) +delete(d::AbstractDict, key) = delete!(copy(d), key) diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl index 014f0e0d9b..382928489f 100644 --- a/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl @@ -10,6 +10,12 @@ similar_type(vec::AbstractSmallVector) = typeof(vec) # Required buffer interface maxlength(vec::AbstractSmallVector) = length(buffer(vec)) +maxlength(vectype::Type{<:AbstractSmallVector}) = error("Not implemented") + +function thaw_type(vectype::Type{<:AbstractSmallVector}, ::Type{T}) where {T} + return MSmallVector{maxlength(vectype),T} +end +thaw_type(vectype::Type{<:AbstractSmallVector{T}}) where {T} = thaw_type(vectype, T) # Required AbstractArray interface Base.size(vec::AbstractSmallVector) = throw(NotImplemented()) diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl index 96c1e80640..79fdd5cdda 100644 --- a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl @@ -140,8 +140,20 @@ end return vec end +@inline function Base.deleteat!( + vec::AbstractSmallVector, indices::AbstractUnitRange{<:Integer} +) + f = first(indices) + n = length(indices) + circshift!(smallview(vec, f, lastindex(vec)), -n) + resize!(vec, length(vec) - n) + return vec +end + # Don't @inline, makes it slower. -function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer) +function StaticArrays.deleteat( + vec::AbstractSmallVector, index::Union{Integer,AbstractUnitRange{<:Integer}} +) mvec = Base.copymutable(vec) deleteat!(mvec, index) return convert(similar_type(vec), mvec) diff --git a/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl b/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl index 434d1e4d48..6d45801f93 100644 --- a/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl +++ b/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl @@ -16,6 +16,11 @@ function MSmallVector(buffer::AbstractVector, len::Int) return MSmallVector{length(buffer),eltype(buffer)}(buffer, len) end +maxlength(::Type{<:MSmallVector{S}}) where {S} = S + +# Empty constructor +(msmallvector_type::Type{MSmallVector{S,T}} where {S,T})() = msmallvector_type(undef, 0) + """ `MSmallVector` constructor, uses `MVector` as a buffer. ```julia diff --git a/NDTensors/src/SmallVectors/src/msmallvector/thawfreeze.jl b/NDTensors/src/SmallVectors/src/msmallvector/thawfreeze.jl new file mode 100644 index 0000000000..563f20dd18 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/msmallvector/thawfreeze.jl @@ -0,0 +1,2 @@ +thaw(vec::MSmallVector) = copy(vec) +freeze(vec::MSmallVector) = SmallVector(vec) diff --git a/NDTensors/src/SmallVectors/src/smallvector/insertstyle.jl b/NDTensors/src/SmallVectors/src/smallvector/insertstyle.jl new file mode 100644 index 0000000000..027d1eb356 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/smallvector/insertstyle.jl @@ -0,0 +1 @@ +InsertStyle(::Type{<:SmallVector}) = FastCopy() diff --git a/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl b/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl index 3976480f47..4b99e76cd6 100644 --- a/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl +++ b/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl @@ -11,6 +11,8 @@ end @inline setbuffer(vec::SmallVector, buffer) = SmallVector(buffer, vec.length) @inline setlength(vec::SmallVector, length) = SmallVector(vec.buffer, length) +maxlength(::Type{<:SmallVector{S}}) where {S} = S + # Constructors function SmallVector{S}(buffer::AbstractVector, len::Int) where {S} return SmallVector{S,eltype(buffer)}(buffer, len) @@ -27,7 +29,16 @@ SmallVector{10}(SA[1, 2, 3]) ``` """ function SmallVector{S,T}(vec::AbstractVector) where {S,T} - return SmallVector{S,T}(MSmallVector{S,T}(vec)) + # TODO: This is a bit slower, but simpler. Check if this + # gets faster in newer Julia versions. + # return SmallVector{S,T}(MSmallVector{S,T}(vec)) + length(vec) > S && error("Data is too long for `SmallVector`.") + msvec = MVector{S,T}(undef) + @inbounds for i in eachindex(vec) + msvec[i] = vec[i] + end + svec = SVector(msvec) + return SmallVector{S,T}(svec, length(vec)) end # Special optimization codepath for `MSmallVector` # to avoid a copy. diff --git a/NDTensors/src/SmallVectors/src/smallvector/thawfreeze.jl b/NDTensors/src/SmallVectors/src/smallvector/thawfreeze.jl new file mode 100644 index 0000000000..077e7d539e --- /dev/null +++ b/NDTensors/src/SmallVectors/src/smallvector/thawfreeze.jl @@ -0,0 +1,2 @@ +thaw(vec::SmallVector) = MSmallVector(vec) +freeze(vec::SmallVector) = vec diff --git a/NDTensors/src/SortedSets/src/BaseExt/sorted.jl b/NDTensors/src/SortedSets/src/BaseExt/sorted.jl new file mode 100644 index 0000000000..939eb12274 --- /dev/null +++ b/NDTensors/src/SortedSets/src/BaseExt/sorted.jl @@ -0,0 +1,54 @@ +# TODO: +# Add ` +# Version that uses an `Ordering`. +function _insorted( + x, + v::AbstractVector; + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, +) + return _insorted(x, v, ord(lt, by, rev, order)) +end +_insorted(x, v::AbstractVector, o::Ordering) = !isempty(searchsorted(v, x, o)) + +function alluniquesorted( + vec; lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward +) + return alluniquesorted(vec, ord(lt, by, rev, order)) +end + +function alluniquesorted(vec, order::Ordering) + length(vec) < 2 && return true + iter = eachindex(vec) + I = iterate(iter) + while I !== nothing + i, s = I + J = iterate(iter, s) + isnothing(J) && return true + j, _ = J + !lt(order, @inbounds(vec[i]), @inbounds(vec[j])) && return false + I = J + end + return true +end + +function uniquesorted(vec; lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward) + return uniquesorted(vec, ord(lt, by, rev, order)) +end + +function uniquesorted(vec::AbstractVector, order::Ordering) + vec = copy(vec) + i = firstindex(vec) + stopi = lastindex(vec) + while i < stopi + if !lt(order, @inbounds(vec[i]), @inbounds(vec[i + 1])) + deleteat!(vec, i) + stopi -= 1 + else + i += 1 + end + end + return vec +end diff --git a/NDTensors/src/SortedSets/src/DictionariesExt/insert.jl b/NDTensors/src/SortedSets/src/DictionariesExt/insert.jl new file mode 100644 index 0000000000..1721487b9b --- /dev/null +++ b/NDTensors/src/SortedSets/src/DictionariesExt/insert.jl @@ -0,0 +1,2 @@ +SmallVectors.insert(inds::AbstractIndices, i) = insert!(copy(inds), i) +SmallVectors.delete(inds::AbstractIndices, i) = delete!(copy(inds), i) diff --git a/NDTensors/src/SortedSets/src/DictionariesExt/isinsertable.jl b/NDTensors/src/SortedSets/src/DictionariesExt/isinsertable.jl new file mode 100644 index 0000000000..6c9599ce39 --- /dev/null +++ b/NDTensors/src/SortedSets/src/DictionariesExt/isinsertable.jl @@ -0,0 +1 @@ +Dictionaries.isinsertable(::AbstractArray) = true diff --git a/NDTensors/src/SortedSets/src/SmallVectorsDictionariesExt/interface.jl b/NDTensors/src/SortedSets/src/SmallVectorsDictionariesExt/interface.jl new file mode 100644 index 0000000000..20dca8d56f --- /dev/null +++ b/NDTensors/src/SortedSets/src/SmallVectorsDictionariesExt/interface.jl @@ -0,0 +1,3 @@ +Dictionaries.isinsertable(::AbstractSmallVector) = true +Dictionaries.isinsertable(::SmallVector) = false +Dictionaries.empty_type(::Type{SmallVector{S,T}}, ::Type{T}) where {S,T} = MSmallVector{S,T} diff --git a/NDTensors/src/SortedSets/src/SortedSets.jl b/NDTensors/src/SortedSets/src/SortedSets.jl new file mode 100644 index 0000000000..09343deb3b --- /dev/null +++ b/NDTensors/src/SortedSets/src/SortedSets.jl @@ -0,0 +1,21 @@ +module SortedSets +using Compat +using Dictionaries +using Random +using ..SmallVectors + +using Base: @propagate_inbounds +using Base.Order: Ordering, Forward, ord, lt + +export AbstractWrappedIndices, SortedSet, SmallSet, MSmallSet + +include("BaseExt/sorted.jl") +include("DictionariesExt/insert.jl") +include("DictionariesExt/isinsertable.jl") +include("abstractset.jl") +include("abstractwrappedset.jl") +include("SmallVectorsDictionariesExt/interface.jl") +include("sortedset.jl") +include("SortedSetsSmallVectorsExt/smallset.jl") + +end diff --git a/NDTensors/src/SortedSets/src/SortedSetsSmallVectorsExt/smallset.jl b/NDTensors/src/SortedSets/src/SortedSetsSmallVectorsExt/smallset.jl new file mode 100644 index 0000000000..8c70f8fe60 --- /dev/null +++ b/NDTensors/src/SortedSets/src/SortedSetsSmallVectorsExt/smallset.jl @@ -0,0 +1,16 @@ +const AbstractSmallSet{T} = SortedSet{T,<:AbstractSmallVector{T}} +const SmallSet{S,T} = SortedSet{T,SmallVector{S,T}} +const MSmallSet{S,T} = SortedSet{T,MSmallVector{S,T}} + +# Specialized constructors +@propagate_inbounds SmallSet{S}(; kwargs...) where {S} = SmallSet{S}([]; kwargs...) +@propagate_inbounds SmallSet{S}(iter; kwargs...) where {S} = + SmallSet{S}(collect(iter); kwargs...) +@propagate_inbounds SmallSet{S}(a::AbstractArray{I}; kwargs...) where {S,I} = + SmallSet{S,I}(a; kwargs...) + +@propagate_inbounds MSmallSet{S}(; kwargs...) where {S} = MSmallSet{S}([]; kwargs...) +@propagate_inbounds MSmallSet{S}(iter; kwargs...) where {S} = + MSmallSet{S}(collect(iter); kwargs...) +@propagate_inbounds MSmallSet{S}(a::AbstractArray{I}; kwargs...) where {S,I} = + MSmallSet{S,I}(a; kwargs...) diff --git a/NDTensors/src/SortedSets/src/abstractset.jl b/NDTensors/src/SortedSets/src/abstractset.jl new file mode 100644 index 0000000000..57ccb90619 --- /dev/null +++ b/NDTensors/src/SortedSets/src/abstractset.jl @@ -0,0 +1,88 @@ +abstract type AbstractSet{T} <: AbstractIndices{T} end + +# Specialized versions of set operations for `AbstractSet` +# that allow more specialization. + +function Base.union(i::AbstractSet, itr) + return union(InsertStyle(i), i, itr) +end + +function Base.union(::InsertStyle, i::AbstractSet, itr) + return error("Not implemented") +end + +function Base.union(::IsInsertable, i::AbstractSet, itr) + out = copy(i) + union!(out, itr) + return out +end + +function Base.union(::NotInsertable, i::AbstractSet, itr) + out = empty(i) + union!(out, i) + union!(out, itr) + return out +end + +function Base.intersect(i::AbstractSet, itr) + return intersect(InsertStyle(i), i, itr) +end + +function Base.intersect(::InsertStyle, i::AbstractSet, itr) + return error("Not implemented") +end + +function Base.intersect(::IsInsertable, i::AbstractSet, itr) + out = copy(i) + intersect!(out, itr) + return out +end + +function Base.intersect(::NotInsertable, i::AbstractSet, itr) + out = empty(i) + union!(out, i) + intersect!(out, itr) + return out +end + +function Base.setdiff(i::AbstractSet, itr) + return setdiff(InsertStyle(i), i, itr) +end + +function Base.setdiff(::InsertStyle, i::AbstractSet, itr) + return error("Not implemented") +end + +function Base.setdiff(::IsInsertable, i::AbstractSet, itr) + out = copy(i) + setdiff!(out, itr) + return out +end + +function Base.setdiff(::NotInsertable, i::AbstractSet, itr) + out = empty(i) + union!(out, i) + setdiff!(out, itr) + return out +end + +function Base.symdiff(i::AbstractSet, itr) + return symdiff(InsertStyle(i), i, itr) +end + +function Base.symdiff(::InsertStyle, i::AbstractSet, itr) + return error("Not implemented") +end + +function Base.symdiff(::IsInsertable, i::AbstractSet, itr) + out = copy(i) + symdiff!(out, itr) + return out +end + +function Base.symdiff(::NotInsertable, i::AbstractSet, itr) + out = empty(i) + union!(out, i) + symdiff!(out, itr) + return out +end diff --git a/NDTensors/src/SortedSets/src/abstractwrappedset.jl b/NDTensors/src/SortedSets/src/abstractwrappedset.jl new file mode 100644 index 0000000000..e1bbf50d24 --- /dev/null +++ b/NDTensors/src/SortedSets/src/abstractwrappedset.jl @@ -0,0 +1,111 @@ +# AbstractWrappedIndices: a wrapper around an `AbstractIndices` +# with methods automatically forwarded via `parent` +# and rewrapped via `rewrap`. +abstract type AbstractWrappedIndices{T,D} <: AbstractIndices{T} end + +# Required interface +Base.parent(inds::AbstractWrappedIndices) = error("Not implemented") +function Dictionaries.empty_type(::Type{AbstractWrappedIndices{I}}, ::Type{I}) where {I} + return error("Not implemented") +end +SmallVectors.thaw(::AbstractWrappedIndices) = error("Not implemented") +SmallVectors.freeze(::AbstractWrappedIndices) = error("Not implemented") +rewrap(::AbstractWrappedIndices, data) = error("Not implemented") + +# Traits +SmallVectors.InsertStyle(::Type{<:AbstractWrappedIndices{T,D}}) where {T,D} = InsertStyle(D) + +# AbstractIndices interface +@propagate_inbounds function Base.iterate(inds::AbstractWrappedIndices, state...) + return iterate(parent(inds), state...) +end + +# `I` is needed to avoid ambiguity error. +@inline Base.in(tag::I, inds::AbstractWrappedIndices{I}) where {I} = in(tag, parent(inds)) +@inline Base.IteratorSize(inds::AbstractWrappedIndices) = Base.IteratorSize(parent(inds)) +@inline Base.length(inds::AbstractWrappedIndices) = length(parent(inds)) + +@inline Dictionaries.istokenizable(inds::AbstractWrappedIndices) = + istokenizable(parent(inds)) +@inline Dictionaries.tokentype(inds::AbstractWrappedIndices) = tokentype(parent(inds)) +@inline Dictionaries.iteratetoken(inds::AbstractWrappedIndices, s...) = + iterate(parent(inds), s...) +@inline function Dictionaries.iteratetoken_reverse(inds::AbstractWrappedIndices) + return iteratetoken_reverse(parent(inds)) +end +@inline function Dictionaries.iteratetoken_reverse(inds::AbstractWrappedIndices, t) + return iteratetoken_reverse(parent(inds), t) +end + +@inline function Dictionaries.gettoken(inds::AbstractWrappedIndices, i) + return gettoken(parent(inds), i) +end +@propagate_inbounds Dictionaries.gettokenvalue(inds::AbstractWrappedIndices, x) = + gettokenvalue(parent(inds), x) + +@inline Dictionaries.isinsertable(inds::AbstractWrappedIndices) = isinsertable(parent(inds)) + +# Specify `I` to fix ambiguity error. +@inline function Dictionaries.gettoken!( + inds::AbstractWrappedIndices{I}, i::I, values=() +) where {I} + return gettoken!(parent(inds), i, values) +end + +@inline function Dictionaries.deletetoken!(inds::AbstractWrappedIndices, x, values=()) + deletetoken!(parent(inds), x, values) + return inds +end + +@inline function Base.empty!(inds::AbstractWrappedIndices, values=()) + empty!(parent(inds)) + return inds +end + +# Not defined to be part of the `AbstractIndices` interface, +# but seems to be needed. +@inline function Base.filter!(pred, inds::AbstractWrappedIndices) + filter!(pred, parent(inds)) + return inds +end + +# TODO: Maybe require an implementation? +@inline function Base.copy(inds::AbstractWrappedIndices, eltype::Type) + return typeof(inds)(copy(parent(inds), eltype)) +end + +# Not required for AbstractIndices interface but +# helps with faster code paths +SmallVectors.insert(inds::AbstractWrappedIndices, tag) = insert(parent(inds), tag) +Base.insert!(inds::AbstractWrappedIndices, tag) = insert!(parent(inds), tag) + +SmallVectors.delete(inds::AbstractWrappedIndices, tag) = delete(parent(inds), tag) +Base.delete!(inds::AbstractWrappedIndices, tag) = delete!(parent(inds), tag) + +function Base.union(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) + return rewrap(inds1, union(parent(inds1), parent(inds2))) +end +function Base.union(inds1::AbstractWrappedIndices, inds2) + return rewrap(inds1, union(parent(inds1), inds2)) +end + +function Base.intersect(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) + return rewrap(inds1, intersect(parent(inds1), parent(inds2))) +end +function Base.intersect(inds1::AbstractWrappedIndices, inds2) + return rewrap(inds1, intersect(parent(inds1), inds2)) +end + +function Base.setdiff(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) + return rewrap(inds1, setdiff(parent(inds1), parent(inds2))) +end +function Base.setdiff(inds1::AbstractWrappedIndices, inds2) + return rewrap(inds1, setdiff(parent(inds1), inds2)) +end + +function Base.symdiff(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) + return rewrap(inds1, symdiff(parent(inds1), parent(inds2))) +end +function Base.symdiff(inds1::AbstractWrappedIndices, inds2) + return rewrap(inds1, symdiff(parent(inds1), inds2)) +end diff --git a/NDTensors/src/SortedSets/src/sortedset.jl b/NDTensors/src/SortedSets/src/sortedset.jl new file mode 100644 index 0000000000..43e2db1c8c --- /dev/null +++ b/NDTensors/src/SortedSets/src/sortedset.jl @@ -0,0 +1,287 @@ +""" + SortedIndices(iter) + +Construct an `SortedIndices <: AbstractIndices` from an arbitrary Julia iterable with unique +elements. Lookup uses that they are sorted. + +SortedIndices can be faster than ArrayIndices which use naive search that may be optimal for +small collections. Larger collections are better handled by containers like `Indices`. +""" +struct SortedIndices{I,Inds<:AbstractArray{I},Order<:Ordering} <: AbstractSet{I} + inds::Inds + order::Order + global @inline _SortedIndices( + inds::Inds, order::Order + ) where {I,Inds<:AbstractArray{I},Order<:Ordering} = new{I,Inds,Order}(inds, order) +end + +# Inner constructor +function SortedIndices{I,Inds,Order}( + a::Inds, order::Order; issorted=issorted, allunique=allunique +) where {I,Inds<:AbstractArray{I},Order<:Ordering} + if !issorted(a, order) + a = sort(a, order) + end + if !alluniquesorted(a, order) + a = uniquesorted(a, order) + end + return _SortedIndices(a, order) +end + +@inline function SortedIndices{I,Inds,Order}( + a::AbstractArray, order::Ordering; issorted=issorted, allunique=allunique +) where {I,Inds<:AbstractArray{I},Order<:Ordering} + return SortedIndices{I,Inds,Order}( + convert(Inds, a), convert(Order, order); issorted, allunique + ) +end + +@inline function SortedIndices{I,Inds}( + a::AbstractArray, order::Order; issorted=issorted, allunique=allunique +) where {I,Inds<:AbstractArray{I},Order<:Ordering} + return SortedIndices{I,Inds,Order}(a, order; issorted, allunique) +end + +@inline function SortedIndices( + a::Inds, order::Ordering; issorted=issorted, allunique=allunique +) where {I,Inds<:AbstractArray{I}} + return SortedIndices{I,Inds}(a, order; issorted, allunique) +end + +@inline function SortedIndices{I,Inds}( + a::Inds; + lt=isless, + by=identity, + rev::Bool=false, + order::Ordering=Forward, + issorted=issorted, + allunique=allunique, +) where {I,Inds<:AbstractArray{I}} + order = ord(lt, by, rev, order) + return SortedIndices{I,Inds}(a, order; issorted, allunique) +end + +const SortedSet = SortedIndices + +# Traits +@inline SmallVectors.InsertStyle(::Type{<:SortedIndices{I,Inds}}) where {I,Inds} = + InsertStyle(Inds) +@inline SmallVectors.thaw(i::SortedIndices) = SortedIndices(thaw(i.inds), i.order) +@inline SmallVectors.freeze(i::SortedIndices) = SortedIndices(freeze(i.inds), i.order) + +@propagate_inbounds SortedIndices(; kwargs...) = SortedIndices{Any}([]; kwargs...) +@propagate_inbounds SortedIndices{I}(; kwargs...) where {I} = + SortedIndices{I,Vector{I}}(I[]; kwargs...) +@propagate_inbounds SortedIndices{I,Inds}(; kwargs...) where {I,Inds} = + SortedIndices{I}(Inds(); kwargs...) + +@propagate_inbounds SortedIndices(iter; kwargs...) = SortedIndices(collect(iter); kwargs...) +@propagate_inbounds SortedIndices{I}(iter; kwargs...) where {I} = + SortedIndices{I}(collect(I, iter); kwargs...) + +@propagate_inbounds SortedIndices(a::AbstractArray{I}; kwargs...) where {I} = + SortedIndices{I}(a; kwargs...) +@propagate_inbounds SortedIndices{I}(a::AbstractArray{I}; kwargs...) where {I} = + SortedIndices{I,typeof(a)}(a; kwargs...) + +@propagate_inbounds SortedIndices{I,Inds}( + a::AbstractArray; kwargs... +) where {I,Inds<:AbstractArray{I}} = SortedIndices{I,Inds}(Inds(a); kwargs...) + +function Base.convert(::Type{AbstractIndices{I}}, inds::SortedIndices) where {I} + return convert(SortedIndices{I}, inds) +end +function Base.convert(::Type{SortedIndices}, inds::AbstractIndices{I}) where {I} + return convert(SortedIndices{I}, inds) +end +function Base.convert(::Type{SortedIndices{I}}, inds::AbstractIndices) where {I} + return convert(SortedIndices{I,Vector{I}}, inds) +end +function Base.convert( + ::Type{SortedIndices{I,Inds}}, inds::AbstractIndices +) where {I,Inds<:AbstractArray{I}} + a = convert(Inds, collect(I, inds)) + return @inbounds SortedIndices{I,typeof(a)}(a) +end + +Base.convert(::Type{SortedIndices{I}}, inds::SortedIndices{I}) where {I} = inds +function Base.convert( + ::Type{SortedIndices{I}}, inds::SortedIndices{<:Any,Inds} +) where {I,Inds<:AbstractArray{I}} + return convert(SortedIndices{I,Inds}, inds) +end +function Base.convert( + ::Type{SortedIndices{I,Inds}}, inds::SortedIndices{I,Inds} +) where {I,Inds<:AbstractArray{I}} + return inds +end +function Base.convert( + ::Type{SortedIndices{I,Inds}}, inds::SortedIndices +) where {I,Inds<:AbstractArray{I}} + a = convert(Inds, parent(inds)) + return @inbounds SortedIndices{I,Inds}(a) +end + +@inline Base.parent(inds::SortedIndices) = getfield(inds, :inds) + +# Basic interface +@propagate_inbounds function Base.iterate(i::SortedIndices{I}, state...) where {I} + return iterate(parent(i), state...) +end + +@inline function Base.in(i::I, inds::SortedIndices{I}) where {I} + return _insorted(i, parent(inds), inds.order) +end +@inline Base.IteratorSize(::SortedIndices) = Base.HasLength() +@inline Base.length(inds::SortedIndices) = length(parent(inds)) + +@inline Dictionaries.istokenizable(i::SortedIndices) = true +@inline Dictionaries.tokentype(::SortedIndices) = Int +@inline Dictionaries.iteratetoken(inds::SortedIndices, s...) = + iterate(LinearIndices(parent(inds)), s...) +@inline function Dictionaries.iteratetoken_reverse(inds::SortedIndices) + li = LinearIndices(parent(inds)) + if isempty(li) + return nothing + else + t = last(li) + return (t, t) + end +end +@inline function Dictionaries.iteratetoken_reverse(inds::SortedIndices, t) + li = LinearIndices(parent(inds)) + t -= 1 + if t < first(li) + return nothing + else + return (t, t) + end +end + +@inline function Dictionaries.gettoken(inds::SortedIndices, i) + a = parent(inds) + r = searchsorted(a, i, inds.order) + @assert 0 ≤ length(r) ≤ 1 # If > 1, means the elements are not unique + length(r) == 0 && return (false, 0) + return (true, convert(Int, only(r))) +end +@propagate_inbounds Dictionaries.gettokenvalue(inds::SortedIndices, x::Int) = + parent(inds)[x] + +@inline Dictionaries.isinsertable(i::SortedIndices) = isinsertable(parent(inds)) + +@inline function Dictionaries.gettoken!(inds::SortedIndices{I}, i::I, values=()) where {I} + a = parent(inds) + r = searchsorted(a, i, inds.order) + @assert 0 ≤ length(r) ≤ 1 # If > 1, means the elements are not unique + if length(r) == 0 + insert!(a, first(r), i) + foreach(v -> resize!(v, length(v) + 1), values) + return (false, last(LinearIndices(a))) + end + return (true, convert(Int, only(r))) +end + +@inline function Dictionaries.deletetoken!(inds::SortedIndices, x::Int, values=()) + deleteat!(parent(inds), x) + foreach(v -> deleteat!(v, x), values) + return inds +end + +@inline function Base.empty!(inds::SortedIndices, values=()) + empty!(parent(inds)) + foreach(empty!, values) + return inds +end + +# TODO: Make into `MSmallVector`? +# More generally, make a `thaw(::AbstractArray)` function to return +# a mutable version of an AbstractArray. +@inline Dictionaries.empty_type( + ::Type{SortedIndices{I,D,Order}}, ::Type{I} +) where {I,D,Order} = SortedIndices{I,Dictionaries.empty_type(D, I),Order} + +@inline Dictionaries.empty_type(::Type{<:AbstractVector}, ::Type{I}) where {I} = Vector{I} + +function Base.empty(inds::SortedIndices{I,D}, ::Type{I}) where {I,D} + return Dictionaries.empty_type(typeof(inds), I)(D(), inds.order) +end + +@inline function Base.copy(inds::SortedIndices, ::Type{I}) where {I} + if I === eltype(inds) + SortedIndices( + copy(parent(inds)), inds.order; issorted=Returns(true), allunique=Returns(true) + ) + else + SortedIndices( + convert(AbstractArray{I}, parent(inds)), + inds.order; + issorted=Returns(true), + allunique=Returns(true), + ) + end +end + +# TODO: Can this take advantage of sorting? +@inline function Base.filter!(pred, inds::SortedIndices) + filter!(pred, parent(inds)) + return inds +end + +function Dictionaries.randtoken(rng::Random.AbstractRNG, inds::SortedIndices) + return rand(rng, keys(parent(inds))) +end + +@inline function Base.sort!( + inds::SortedIndices; lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward +) + # No-op, should be sorted already. + # TODO: Check `ord(lt, by, rev, order) == inds.ord`. + return inds +end + +# Custom faster operations (not required for interface) +function Base.union!(inds::SortedIndices, items::SortedIndices) + if inds.order ≠ items.order + # Reorder if the orderings are different. + items = SortedIndices(parent(inds), inds.order) + end + unionsortedunique!(parent(inds), parent(items), inds.order) + return inds +end + +function Base.union(inds::SortedIndices, items::SortedIndices) + if inds.order ≠ items.order + # Reorder if the orderings are different. + items = SortedIndices(parent(inds), inds.order) + end + out = unionsortedunique(parent(inds), parent(items), inds.order) + return SortedIndices(out, inds.order; issorted=Returns(true), allunique=Returns(true)) +end + +function Base.union(inds::SortedIndices, items) + return union(inds, SortedIndices(items, inds.order)) +end + +function Base.intersect(inds::SortedIndices, items::SortedIndices) + # TODO: Make an `intersectsortedunique`. + return intersect(NotInsertable(), inds, items) +end + +function Base.setdiff(inds::SortedIndices, items) + return setdiff(inds, SortedIndices(items, inds.order)) +end + +function Base.setdiff(inds::SortedIndices, items::SortedIndices) + # TODO: Make an `setdiffsortedunique`. + return setdiff(NotInsertable(), inds, items) +end + +function Base.symdiff(inds::SortedIndices, items) + return symdiff(inds, SortedIndices(items, inds.order)) +end + +function Base.symdiff(inds::SortedIndices, items::SortedIndices) + # TODO: Make an `symdiffsortedunique`. + return symdiff(NotInsertable(), inds, items) +end diff --git a/NDTensors/src/SortedSets/test/runtests.jl b/NDTensors/src/SortedSets/test/runtests.jl new file mode 100644 index 0000000000..24500f474d --- /dev/null +++ b/NDTensors/src/SortedSets/test/runtests.jl @@ -0,0 +1,20 @@ +using Test +using NDTensors.SortedSets +using NDTensors.SmallVectors + +@testset "Test NDTensors.SortedSets" begin + for V in (Vector, MSmallVector{10}, SmallVector{10}) + s1 = SortedSet(V([1, 3, 5])) + s2 = SortedSet(V([2, 3, 6])) + + # Set interface + @test union(s1, s2) == SortedSet([1, 2, 3, 5, 6]) + @test setdiff(s1, s2) == SortedSet([1, 5]) + @test symdiff(s1, s2) == SortedSet([1, 2, 5, 6]) + @test intersect(s1, s2) == SortedSet([3]) + if SmallVectors.InsertStyle(V) isa IsInsertable + @test insert!(copy(s1), 4) == SortedSet([1, 3, 4, 5]) + @test delete!(copy(s1), 3) == SortedSet([1, 5]) + end + end +end diff --git a/NDTensors/src/TagSets/README.md b/NDTensors/src/TagSets/README.md new file mode 100644 index 0000000000..346026463b --- /dev/null +++ b/NDTensors/src/TagSets/README.md @@ -0,0 +1,20 @@ +# TagSets.jl + +A sorted collection of unique tags of type `T`. + +# TODO + +- Add `skipchars` (see `skipmissing`) and `delim` for delimiter. +- https://docs.julialang.org/en/v1/base/strings/#Base.strip +- https://docs.julialang.org/en/v1/stdlib/DelimitedFiles/#Delimited-Files +- Add a `Bool` param for bounds checking/ignoring overflow/spillover? +- Make `S` a first argument, hardcode `SmallVector` storage? +- https://juliacollections.github.io/DataStructures.jl/v0.9/sorted_containers.html +- https://github.com/JeffreySarnoff/SortingNetworks.jl +- https://github.com/vvjn/MergeSorted.jl +- https://bkamins.github.io/julialang/2023/08/25/infiltrate.html +- https://github.com/Jutho/TensorKit.jl/blob/master/src/auxiliary/dicts.jl +- https://github.com/tpapp/SortedVectors.jl +- https://discourse.julialang.org/t/special-purpose-subtypes-of-arrays/20327 +- https://discourse.julialang.org/t/all-the-ways-to-group-reduce-sorted-vectors-ideas/45239 +- https://discourse.julialang.org/t/sorting-a-vector-of-fixed-size/71766 diff --git a/NDTensors/src/TagSets/examples/benchmark.jl b/NDTensors/src/TagSets/examples/benchmark.jl new file mode 100644 index 0000000000..98a40e5f46 --- /dev/null +++ b/NDTensors/src/TagSets/examples/benchmark.jl @@ -0,0 +1,47 @@ +using NDTensors.TagSets +using NDTensors.InlineStrings +using NDTensors.SmallVectors +using NDTensors.SortedSets +using NDTensors.TagSets + +using BenchmarkTools +using Cthulhu +using Profile +using PProf + +function main(; profile=false) + TS = SmallTagSet{10,String31} + ts1 = TS(["a", "b"]) + ts2 = TS(["b", "c", "d"]) + + @btime $TS($("x,y")) + + @show union(ts1, ts2) + @show intersect(ts1, ts2) + @show setdiff(ts1, ts2) + @show symdiff(ts1, ts2) + + @btime union($ts1, $ts2) + @btime intersect($ts1, $ts2) + @btime setdiff($ts1, $ts2) + @btime symdiff($ts1, $ts2) + + @show addtags(ts1, ts2) + @show commontags(ts1, ts2) + @show removetags(ts1, ts2) + @show noncommontags(ts1, ts2) + @show replacetags(ts1, ["b"], ["c", "d"]) + + @btime addtags($ts1, $ts2) + @btime commontags($ts1, $ts2) + @btime removetags($ts1, $ts2) + @btime noncommontags($ts1, $ts2) + @btime replacetags($ts1, $(["b"]), $(["c", "d"])) + + if profile + Profile.clear() + @profile foreach(_ -> TagSet("x,y"; data_type=set_type), 1:1_000_000) + return pprof() + end + return nothing +end diff --git a/NDTensors/src/TagSets/src/TagSets.jl b/NDTensors/src/TagSets/src/TagSets.jl new file mode 100644 index 0000000000..53df13d0ee --- /dev/null +++ b/NDTensors/src/TagSets/src/TagSets.jl @@ -0,0 +1,76 @@ +module TagSets +using Dictionaries +using ..SmallVectors +using ..SortedSets + +using Base: @propagate_inbounds + +export TagSet, SmallTagSet, addtags, removetags, replacetags, commontags, noncommontags + +# A sorted collection of unique tags of type `T`. +struct TagSet{T,D<:AbstractIndices{T}} <: AbstractWrappedIndices{T,D} + data::D +end + +TagSet{T}(data::D) where {T,D<:AbstractIndices{T}} = TagSet{T,D}(data) + +TagSet{T,D}(vec::AbstractVector) where {T,D<:AbstractIndices{T}} = TagSet{T,D}(D(vec)) +TagSet{T,D}() where {T,D<:AbstractIndices{T}} = TagSet{T,D}(D()) + +# Defaults to Indices if unspecified. +default_data_type() = Indices{String} +TagSet(vec::AbstractVector) = TagSet(default_data_type()(vec)) + +# Constructor from string +default_delim() = ',' +@inline function TagSet(str::AbstractString; delim=default_delim()) + return TagSet(default_data_type(), str) +end +@inline function TagSet( + ::Type{D}, str::AbstractString; delim=default_delim() +) where {T,D<:AbstractIndices{T}} + return TagSet{T,D}(str) +end +@inline function TagSet{T,D}( + str::AbstractString; delim=default_delim() +) where {T,D<:AbstractIndices{T}} + return TagSet{T,D}(split(str, delim)) +end + +const SmallTagSet{S,T} = TagSet{T,SmallSet{S,T}} +@propagate_inbounds SmallTagSet{S}(; kwargs...) where {S} = SmallTagSet{S}([]; kwargs...) +@propagate_inbounds SmallTagSet{S}(iter; kwargs...) where {S} = + SmallTagSet{S}(collect(iter); kwargs...) +@propagate_inbounds SmallTagSet{S}(a::AbstractArray{I}; kwargs...) where {S,I} = + SmallTagSet{S,I}(a; kwargs...) +# Specialized `SmallSet{S,T} = SortedSet{T,SmallVector{S,T}}` constructor +function SmallTagSet{S,T}(str::AbstractString; delim=default_delim()) where {S,T} + # TODO: Optimize for `SmallSet`. + return SmallTagSet{S,T}(split(str, delim)) +end + +# Field accessors +Base.parent(tags::TagSet) = getfield(tags, :data) + +# AbstractWrappedSet interface. +# Specialized version when they are the same data type is faster. +@inline SortedSets.rewrap(vec::TagSet{T,D}, data::D) where {T,D<:AbstractIndices{T}} = + TagSet{T,D}(data) +@inline SortedSets.rewrap(vec::TagSet{T,D}, data) where {T,D<:AbstractIndices{T}} = + TagSet{T,D}(data) + +# TagSet interface +addtags(tags::TagSet, items) = union(tags, items) +removetags(tags::TagSet, items) = setdiff(tags, items) +commontags(tags::TagSet, items) = intersect(tags, items) +noncommontags(tags::TagSet, items) = symdiff(tags, items) +function replacetags(tags::TagSet, rem, add) + remtags = setdiff(tags, rem) + if length(tags) ≠ length(remtags) + length(rem) + # Not all are removed, no replacement + return tags + end + return union(remtags, add) +end + +end diff --git a/NDTensors/src/TagSets/test/runtests.jl b/NDTensors/src/TagSets/test/runtests.jl new file mode 100644 index 0000000000..cf7336bbc1 --- /dev/null +++ b/NDTensors/src/TagSets/test/runtests.jl @@ -0,0 +1,33 @@ +using Test +using NDTensors.TagSets +using NDTensors.SortedSets +using NDTensors.SmallVectors +using NDTensors.InlineStrings +using NDTensors.Dictionaries + +@testset "Test NDTensors.TagSets" begin + for data_type in (Vector,) # SmallVector{10}) + d1 = data_type{String31}(["1", "3", "5"]) + d2 = data_type{String31}(["2", "3", "6"]) + for set_type in (Indices, SortedSet) + s1 = TagSet(set_type(d1)) + s2 = TagSet(set_type(d2)) + + @test issetequal(union(s1, s2), ["1", "2", "3", "5", "6"]) + @test issetequal(setdiff(s1, s2), ["1", "5"]) + @test issetequal(symdiff(s1, s2), ["1", "2", "5", "6"]) + @test issetequal(intersect(s1, s2), ["3"]) + + # TagSet interface + @test issetequal(addtags(s1, ["4"]), ["1", "3", "4", "5"]) + @test issetequal(removetags(s1, ["3"]), ["1", "5"]) + @test issetequal(replacetags(s1, ["3"], ["6", "7"]), ["1", "5", "6", "7"]) + @test issetequal(replacetags(s1, ["3", "4"], ["6, 7"]), ["1", "3", "5"]) + + # Only test if `isinsertable`. Make sure that is false + # for `SmallVector`. + ## @test issetequal(insert!(copy(s1), "4"), ["1", "3", "4", "5"]) + ## @test issetequal(delete!(copy(s1), "3"), ["1", "5"]) + end + end +end diff --git a/NDTensors/src/empty/empty.jl b/NDTensors/src/empty/empty.jl index 437d0d62ad..3d4c7e8188 100644 --- a/NDTensors/src/empty/empty.jl +++ b/NDTensors/src/empty/empty.jl @@ -40,11 +40,11 @@ storagetype(::Type{EmptyStorage{ElT,StoreT}}) where {ElT,StoreT} = StoreT storagetype(::EmptyStorage{ElT,StoreT}) where {ElT,StoreT} = StoreT # Get the EmptyStorage version of the TensorStorage -function emptytype(::Type{StoreT}) where {StoreT} - return EmptyStorage{eltype(StoreT),StoreT} +function emptytype(storagetype::Type{<:TensorStorage}) + return EmptyStorage{eltype(storagetype),storagetype} end -empty(::Type{StoreT}) where {StoreT} = emptytype(StoreT)() +empty(storagetype::Type{<:TensorStorage}) = emptytype(storagetype)() data(S::EmptyStorage) = NoData() diff --git a/NDTensors/test/SortedSets.jl b/NDTensors/test/SortedSets.jl new file mode 100644 index 0000000000..e5a885737d --- /dev/null +++ b/NDTensors/test/SortedSets.jl @@ -0,0 +1,4 @@ +using Test +using NDTensors + +include(joinpath(pkgdir(NDTensors), "src", "SortedSets", "test", "runtests.jl")) diff --git a/NDTensors/test/TagSets.jl b/NDTensors/test/TagSets.jl new file mode 100644 index 0000000000..3ce0fbfd98 --- /dev/null +++ b/NDTensors/test/TagSets.jl @@ -0,0 +1,4 @@ +using Test +using NDTensors + +include(joinpath(pkgdir(NDTensors), "src", "TagSets", "test", "runtests.jl")) diff --git a/NDTensors/test/runtests.jl b/NDTensors/test/runtests.jl index 8e4f1733f8..9c00f31e1f 100644 --- a/NDTensors/test/runtests.jl +++ b/NDTensors/test/runtests.jl @@ -21,6 +21,8 @@ end @testset "$filename" for filename in [ "SetParameters.jl", "SmallVectors.jl", + "SortedSets.jl", + "TagSets.jl", "linearalgebra.jl", "dense.jl", "blocksparse.jl",