Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Sep 25, 2023
1 parent 740cd02 commit 81216b9
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ function Base.setindex!(vec::AbstractSmallVector, item, index::Integer)
return throw(NotImplemented())
end
Base.IndexStyle(::Type{<:AbstractSmallVector}) = IndexLinear()

Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractSmallVector} = a isa T ? a : T(a)::T
11 changes: 11 additions & 0 deletions NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
# symdiff[!]
# unique[!]

# unionsorted[!]
# setdiffsorted[!]
# deletesorted[!] (delete all or one?)
# deletesortedfirst[!] (delete all or one?)

Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented())

@inline function resize(vec::AbstractSmallVector, len)
Expand Down Expand Up @@ -80,6 +85,12 @@ function StaticArrays.popfirst(vec::AbstractSmallVector)
return convert(similar_type(vec), mvec)
end

# This implementation of `midpoint` is performance-optimized but safe
# only if `lo <= hi`.
# TODO: Replace with `Base.midpoint`.
midpoint(lo::T, hi::T) where T<:Integer = lo + ((hi - lo) >>> 0x01)
midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)

@inline function Base.reverse!(vec::AbstractSmallVector)
start, stop = firstindex(vec), lastindex(vec)
r = stop
Expand Down
4 changes: 1 addition & 3 deletions NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ MSmallVector{10}(SA[1, 2, 3])
```
"""
function MSmallVector{S,T}(vec::AbstractVector) where {S,T}
buffer = zeros(MVector{S,T})
buffer = MVector{S,T}(undef)
copyto!(buffer, vec)
return MSmallVector(buffer, length(vec))
end

# Derive the buffer length.
MSmallVector(vec::AbstractSmallVector) = MSmallVector{length(buffer(vec))}(vec)

Base.convert(::Type{T}, a::AbstractArray) where {T<:MSmallVector} = a isa T ? a : T(a)::T

function MSmallVector{S}(vec::AbstractVector) where {S}
return MSmallVector{S,eltype(vec)}(vec)
end
Expand Down
7 changes: 4 additions & 3 deletions NDTensors/src/SmallVectors/src/smallvector/smallvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ SmallVector{10}(SA[1, 2, 3])
```
"""
function SmallVector{S,T}(vec::AbstractVector) where {S,T}
mvec = MSmallVector{S,T}(vec)
return SmallVector{S,T}(buffer(mvec), length(mvec))
return SmallVector{S,T}(MSmallVector{S,T}(vec))
end
# Special optimization codepath for `MSmallVector`
# to avoid a copy.
Expand All @@ -48,7 +47,9 @@ end
# Derive the buffer length.
SmallVector(vec::AbstractSmallVector) = SmallVector{length(buffer(vec))}(vec)

Base.convert(::Type{T}, a::AbstractArray) where {T<:SmallVector} = a isa T ? a : T(a)::T
# Empty constructor
(smallvector_type::Type{SmallVector{S,T}} where {S,T})() = smallvector_type(undef, 0)
SmallVector{S,T}(::UndefInitializer, length::Integer) where {S,T} = SmallVector{S,T}(SVector{S,T}(MVector{S,T}(undef)), length)

# Buffer interface
buffer(vec::SmallVector) = vec.buffer
Expand Down

0 comments on commit 81216b9

Please sign in to comment.