Skip to content

Commit

Permalink
Add SubManagedIndex and view
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliadmtru committed Sep 1, 2024
1 parent 1c4e0f5 commit 611716e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 14 deletions.
14 changes: 7 additions & 7 deletions src/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -742,12 +742,12 @@ Builds a `PineconeIndex` containing a Pinecone context (API key, index and names
function build_index(
indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString};
metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(),
context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
index::Pinecone.PineconeIndexv3 = nothing,
namespace::AbstractString = "",
pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
pinecone_index::Pinecone.PineconeIndexv3 = nothing,
pinecone_namespace::AbstractString = "",
upsert::Bool = false,
verbose::Integer = 1,
index_id = gensym(namespace),
index_id = gensym(pinecone_namespace),
chunker::AbstractChunker = indexer.chunker,
chunker_kwargs::NamedTuple = NamedTuple(),
embedder::AbstractEmbedder = indexer.embedder,
Expand All @@ -756,7 +756,7 @@ function build_index(
tagger_kwargs::NamedTuple = NamedTuple(),
api_kwargs::NamedTuple = NamedTuple(),
cost_tracker = Threads.Atomic{Float64}(0.0))
@assert !isempty(context.apikey) && !isnothing(index) "Pinecone context and index not set"
@assert !isempty(pinecone_context.apikey) && !isnothing(pinecone_index) "Pinecone context and index not set"

## Split into chunks
chunks, sources = get_chunks(chunker, files_or_docs;
Expand Down Expand Up @@ -788,12 +788,12 @@ function build_index(
embeddings_arr = [embeddings[:,i] for i in axes(embeddings,2)]
for (idx, emb) in enumerate(embeddings_arr)
pinevector = Pinecone.PineconeVector(string(UUIDs.uuid4()), emb, metadata[idx])
Pinecone.upsert(context, index, [pinevector], namespace)
Pinecone.upsert(pinecone_context, pinecone_index, [pinevector], pinecone_namespace)
@info "Upsert #$idx complete"
end
end

index = PineconeIndex(; id = index_id, context, index, namespace, chunks, embeddings, tags, tags_vocab, metadata, sources)
index = PineconeIndex(; id = index_id, pinecone_context, pinecone_index, pinecone_namespace, chunks, embeddings, tags, tags_vocab, metadata, sources)

(verbose > 0) && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))"

Expand Down
6 changes: 3 additions & 3 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ function find_closest(
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_n::Int = 10, kwargs...)
# get Pinecone info
pinecone_context = index.context
pinecone_index = index.index
pinecone_namespace = index.namespace
pinecone_context = index.pinecone_context
pinecone_index = index.pinecone_index
pinecone_namespace = index.pinecone_namespace

# query candidates
pinecone_results = Pinecone.query(pinecone_context, pinecone_index,
Expand Down
111 changes: 107 additions & 4 deletions src/Experimental/RAGTools/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ chunkdata(index::ChunkEmbeddingsIndex) = embeddings(index)
const ChunkIndex = ChunkEmbeddingsIndex

# Generic accessors for managed indices: each one returns the field of the same name
# (`id`, `chunks`, `sources`) stored on the index.
indexid(index::AbstractManagedIndex) = index.id
chunks(index::AbstractManagedIndex) = index.chunks
sources(index::AbstractManagedIndex) = index.sources

using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3
@kwdef struct PineconeIndex{
Expand All @@ -145,9 +147,9 @@ using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3
T3 <: Union{Nothing, AbstractMatrix{<:Bool}}
} <: AbstractManagedIndex
id::Symbol # namespace
context::Pinecone.PineconeContextv3
index::Pinecone.PineconeIndexv3
namespace::String
pinecone_context::Pinecone.PineconeContextv3
pinecone_index::Pinecone.PineconeIndexv3
pinecone_namespace::String
# underlying document chunks / snippets
chunks::Vector{T1} = nothing
# for semantic search
Expand Down Expand Up @@ -546,6 +548,84 @@ Base.@propagate_inbounds function translate_positions_to_parent(
return sub_positions[pos]
end


# A subset ("view") of a managed index: it holds the parent index together with the
# positions selected from it. The accessors defined for this type use `positions`
# to index into the parent's chunks, sources, embeddings, tags and extras.
@kwdef struct SubManagedIndex{T <: AbstractManagedIndex} <: AbstractManagedIndex
# the managed index this view was taken from
parent::T
# positions in the parent index that belong to this view
positions::Vector{Int}
end

## Field accessors for the view itself
positions(index::SubManagedIndex) = index.positions
Base.parent(index::SubManagedIndex) = index.parent
## A view reports the id of the index it was taken from
indexid(index::SubManagedIndex) = indexid(parent(index))
## Traits are answered by the parent index
HasEmbeddings(index::SubManagedIndex) = parent(index) |> HasEmbeddings
HasKeywords(index::SubManagedIndex) = parent(index) |> HasKeywords

"""
    chunks(index::SubManagedIndex)

Return a view of the parent's chunks restricted to the positions of `index`.
"""
Base.@propagate_inbounds function chunks(index::SubManagedIndex)
    parent_chunks = chunks(parent(index))
    return view(parent_chunks, positions(index))
end
"""
    sources(index::SubManagedIndex)

Return a view of the parent's sources restricted to the positions of `index`.
"""
Base.@propagate_inbounds function sources(index::SubManagedIndex)
    parent_sources = sources(parent(index))
    return view(parent_sources, positions(index))
end
"""
    chunkdata(index::SubManagedIndex)

Return the parent's chunkdata for the positions held by `index`, by delegating to
the two-argument `chunkdata` method of the parent index.
"""
Base.@propagate_inbounds function chunkdata(index::SubManagedIndex)
    parent_index = parent(index)
    return chunkdata(parent_index, positions(index))
end
"Access chunkdata for a subset of chunks, `chunk_idx` is a vector of chunk indices in the index"
Base.@propagate_inbounds function chunkdata(
        index::SubManagedIndex, chunk_idx::AbstractVector{<:Integer})
    ## A dedicated accessor is required because different chunk indices can store
    ## their chunks along different dimensions.
    ## Step 1: map the view-local indices to positions in the parent index
    parent_positions = translate_positions_to_parent(index, chunk_idx)
    ## Step 2: keep only positions that belong to this view (in the view's own order)
    selected = intersect(positions(index), parent_positions)
    ## Step 3: let the parent index fetch the data for those positions
    return chunkdata(parent(index), selected)
end
"""
    embeddings(index::SubManagedIndex)

Return a view of the parent's embeddings with the second dimension restricted to the
positions of `index`. Throws an `ArgumentError` when `HasEmbeddings(index)` is false.
"""
function embeddings(index::SubManagedIndex)
    ## Guard clause: indices without embeddings cannot be viewed this way
    HasEmbeddings(index) ||
        throw(ArgumentError("`embeddings` not implemented for $(typeof(index))"))
    parent_embeddings = embeddings(parent(index))
    return view(parent_embeddings, :, positions(index))
end
"""
    tags(index::SubManagedIndex)

Return a view of the parent's tags with the first dimension restricted to the
positions of `index`, or `nothing` when the parent has no tags.
"""
function tags(index::SubManagedIndex)
    parent_tags = tags(parent(index))
    if isnothing(parent_tags)
        return nothing
    else
        return view(parent_tags, positions(index), :)
    end
end
## The tag vocabulary is not subset by the view; it is taken from the parent unchanged.
tags_vocab(index::SubManagedIndex) = parent(index) |> tags_vocab
"""
    extras(index::SubManagedIndex)

Return a view of the parent's extras restricted to the positions of `index`, or
`nothing` when the parent has no extras.
"""
function extras(index::SubManagedIndex)
    parent_extras = extras(parent(index))
    if isnothing(parent_extras)
        return nothing
    else
        return view(parent_extras, positions(index))
    end
end
## Fallback for two views whose concrete types differ: merging is not supported.
function Base.vcat(i1::SubManagedIndex, i2::SubManagedIndex)
    msg = "vcat not implemented for type $(typeof(i1)) and $(typeof(i2))"
    throw(ArgumentError(msg))
end
"""
    Base.vcat(i1::T, i2::T) where {T <: SubManagedIndex}

Merge two views of the same parent index into a single `SubManagedIndex` whose
positions are the positions of `i1` followed by those of `i2` (duplicates are kept).

Throws an `ArgumentError` when the ids of the two parent indices differ.
"""
function Base.vcat(i1::T, i2::T) where {T <: SubManagedIndex}
    ## Views can only be merged when they refer to the same parent index
    if indexid(parent(i1)) != indexid(parent(i2))
        throw(ArgumentError("Parent indices must be the same (provided: $(indexid(parent(i1))) and $(indexid(parent(i2))))"))
    end
    ## Build a `SubManagedIndex` (not a `SubChunkIndex`): the parent is a managed index
    return SubManagedIndex(parent(i1), vcat(positions(i1), positions(i2)))
end
"""
    Base.unique(index::SubManagedIndex)

Return a new `SubManagedIndex` over the same parent with duplicate positions removed.
"""
function Base.unique(index::SubManagedIndex)
    ## Build a `SubManagedIndex` (not a `SubChunkIndex`): the parent is a managed index
    return SubManagedIndex(parent(index), unique(positions(index)))
end
## The length of a view is the number of positions it holds.
Base.length(index::SubManagedIndex) = positions(index) |> length
## A view is empty when it holds no positions.
Base.isempty(index::SubManagedIndex) = positions(index) |> isempty
## One-line summary: parent type name, parent id and number of chunks in the view.
function Base.show(io::IO, index::SubManagedIndex)
    parent_index = parent(index)
    summary_text = "A view of $(nameof(typeof(parent_index))) (id: $(indexid(parent_index))) with $(length(index)) chunks"
    return print(io, summary_text)
end
"""
    translate_positions_to_parent(index::SubManagedIndex, pos::AbstractVector{<:Integer})

Translate view-local indices `pos` into positions of the parent index by indexing
into the positions held by `index`.
"""
Base.@propagate_inbounds function translate_positions_to_parent(
        index::SubManagedIndex, pos::AbstractVector{<:Integer})
    return positions(index)[pos]
end

# # CandidateChunks for Retrieval

"""
Expand Down Expand Up @@ -864,7 +944,18 @@ Base.@propagate_inbounds function Base.view(index::SubChunkIndex, cc::MultiCandi
end
"""
    Base.view(index::AbstractManagedIndex, cc::CandidateWithChunks)

Create a `SubManagedIndex` over `index` containing the positions held by `cc`.

If the id of `cc` does not match the id of `index`, the resulting view is empty.
Unless bounds checking is elided, a `BoundsError` is thrown when any position of `cc`
falls outside the chunks of the parent index.
"""
Base.@propagate_inbounds function Base.view(index::AbstractManagedIndex, cc::CandidateWithChunks)
    ## NOTE(review): relies on `parent` being defined for the type of `index` -- confirm
    @boundscheck let chk_vector = chunks(parent(index))
        if !checkbounds(Bool, axes(chk_vector, 1), positions(cc))
            ## Avoid printing huge position arrays, show the extremas of the attempted range
            max_pos = extrema(positions(cc))
            throw(BoundsError(chk_vector, max_pos))
        end
    end
    ## Candidates that belong to a different index select nothing
    pos = indexid(index) == indexid(cc) ? positions(cc) : Int[]
    return SubManagedIndex(parent(index), pos)
end
## Viewing a view: delegate to the `SubManagedIndex(index, cc)` constructor method.
Base.@propagate_inbounds function Base.view(index::SubManagedIndex, cc::CandidateWithChunks)
    narrowed = SubManagedIndex(index, cc)
    return narrowed
end
Base.@propagate_inbounds function SubChunkIndex(index::SubChunkIndex, cc::CandidateChunks)
pos = indexid(index) == indexid(cc) ? positions(cc) : Int[]
Expand Down Expand Up @@ -892,6 +983,18 @@ Base.@propagate_inbounds function SubChunkIndex(
end
return SubChunkIndex(parent(index), intersect_pos)
end
"""
    SubManagedIndex(index::SubManagedIndex, cc::CandidateWithChunks)

Narrow an existing view: keep the positions of `cc` that are also positions of `index`
and return a new `SubManagedIndex` over the same parent. If the id of `cc` does not
match the id of `index`, the result is empty.
"""
Base.@propagate_inbounds function SubManagedIndex(index::SubManagedIndex, cc::CandidateWithChunks)
    ## Candidates that belong to a different index select nothing
    candidate_pos = if indexid(index) == indexid(cc)
        positions(cc)
    else
        Int[]
    end
    ## Keep only the candidates that are already part of this view
    kept = intersect(candidate_pos, positions(index))
    @boundscheck let parent_chunks = chunks(parent(index))
        in_range = checkbounds(Bool, axes(parent_chunks, 1), kept)
        if !in_range
            ## Report only the extrema so the error does not print a huge position array
            throw(BoundsError(parent_chunks, extrema(kept)))
        end
    end
    return SubManagedIndex(parent(index), kept)
end

## Getindex

Expand Down

0 comments on commit 611716e

Please sign in to comment.