diff --git a/src/Experimental/RAGTools/preparation.jl b/src/Experimental/RAGTools/preparation.jl index f2c8a4029..187f9495f 100644 --- a/src/Experimental/RAGTools/preparation.jl +++ b/src/Experimental/RAGTools/preparation.jl @@ -742,12 +742,12 @@ Builds a `PineconeIndex` containing a Pinecone context (API key, index and names function build_index( indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString}; metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(), - context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""), - index::Pinecone.PineconeIndexv3 = nothing, - namespace::AbstractString = "", + pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""), + pinecone_index::Pinecone.PineconeIndexv3 = nothing, + pinecone_namespace::AbstractString = "", upsert::Bool = false, verbose::Integer = 1, - index_id = gensym(namespace), + index_id = gensym(pinecone_namespace), chunker::AbstractChunker = indexer.chunker, chunker_kwargs::NamedTuple = NamedTuple(), embedder::AbstractEmbedder = indexer.embedder, @@ -756,7 +756,7 @@ function build_index( tagger_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cost_tracker = Threads.Atomic{Float64}(0.0)) - @assert !isempty(context.apikey) && !isnothing(index) "Pinecone context and index not set" + @assert !isempty(pinecone_context.apikey) && !isnothing(pinecone_index) "Pinecone context and index not set" ## Split into chunks chunks, sources = get_chunks(chunker, files_or_docs; @@ -788,12 +788,12 @@ function build_index( embeddings_arr = [embeddings[:,i] for i in axes(embeddings,2)] for (idx, emb) in enumerate(embeddings_arr) pinevector = Pinecone.PineconeVector(string(UUIDs.uuid4()), emb, metadata[idx]) - Pinecone.upsert(context, index, [pinevector], namespace) + Pinecone.upsert(pinecone_context, pinecone_index, [pinevector], pinecone_namespace) @info "Upsert #$idx complete" end end - index = PineconeIndex(; id = index_id, context, index, namespace, chunks, embeddings, tags, 
tags_vocab, metadata, sources) + index = PineconeIndex(; id = index_id, pinecone_context, pinecone_index, pinecone_namespace, chunks, embeddings, tags, tags_vocab, metadata, sources) (verbose > 0) && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))" diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index 664210f77..673fdafba 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -246,9 +246,9 @@ function find_closest( query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[]; top_n::Int = 10, kwargs...) # get Pinecone info - pinecone_context = index.context - pinecone_index = index.index - pinecone_namespace = index.namespace + pinecone_context = index.pinecone_context + pinecone_index = index.pinecone_index + pinecone_namespace = index.pinecone_namespace # query candidates pinecone_results = Pinecone.query(pinecone_context, pinecone_index, diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index 37854c031..3301c1b67 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -137,6 +137,8 @@ chunkdata(index::ChunkEmbeddingsIndex) = embeddings(index) const ChunkIndex = ChunkEmbeddingsIndex indexid(index::AbstractManagedIndex) = index.id +chunks(index::AbstractManagedIndex) = index.chunks +sources(index::AbstractManagedIndex) = index.sources using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3 @kwdef struct PineconeIndex{ @@ -145,9 +147,9 @@ using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3 T3 <: Union{Nothing, AbstractMatrix{<:Bool}} } <: AbstractManagedIndex id::Symbol # namespace - context::Pinecone.PineconeContextv3 - index::Pinecone.PineconeIndexv3 - namespace::String + pinecone_context::Pinecone.PineconeContextv3 + pinecone_index::Pinecone.PineconeIndexv3 + pinecone_namespace::String # underlying document chunks / snippets 
chunks::Vector{T1} = nothing # for semantic search
@@ -546,6 +548,102 @@ Base.@propagate_inbounds function translate_positions_to_parent(
     return sub_positions[pos]
 end
 
+
+"""
+    SubManagedIndex{T <: AbstractManagedIndex}
+
+A view of a managed index: `parent` is the wrapped index and `positions` are the
+chunk positions selected from it.
+"""
+@kwdef struct SubManagedIndex{T <: AbstractManagedIndex} <: AbstractManagedIndex
+    parent::T
+    positions::Vector{Int}
+end
+
+## Identity and traits are forwarded from the parent index
+indexid(index::SubManagedIndex) = parent(index) |> indexid
+positions(index::SubManagedIndex) = index.positions
+Base.parent(index::SubManagedIndex) = index.parent
+HasEmbeddings(index::SubManagedIndex) = HasEmbeddings(parent(index))
+HasKeywords(index::SubManagedIndex) = HasKeywords(parent(index))
+
+"Chunks of the parent index at `positions(index)`, returned as a view."
+Base.@propagate_inbounds function chunks(index::SubManagedIndex)
+    return view(chunks(parent(index)), positions(index))
+end
+"Sources of the parent index at `positions(index)`, returned as a view."
+Base.@propagate_inbounds function sources(index::SubManagedIndex)
+    return view(sources(parent(index)), positions(index))
+end
+"Chunkdata of the parent index restricted to `positions(index)`."
+Base.@propagate_inbounds function chunkdata(index::SubManagedIndex)
+    return chunkdata(parent(index), positions(index))
+end
+"Access chunkdata for a subset of chunks, `chunk_idx` is a vector of chunk indices in the index"
+Base.@propagate_inbounds function chunkdata(
+        index::SubManagedIndex, chunk_idx::AbstractVector{<:Integer})
+    ## We need this accessor because different chunk indices can have chunks in different dimensions!!
+    index_chunk_idx = translate_positions_to_parent(index, chunk_idx)
+    pos = intersect(positions(index), index_chunk_idx)
+    return chunkdata(parent(index), pos)
+end
+"Embeddings of the selected positions (columns of the parent's embeddings); throws `ArgumentError` when `HasEmbeddings(index)` is false."
+function embeddings(index::SubManagedIndex)
+    if HasEmbeddings(index)
+        view(embeddings(parent(index)), :, positions(index))
+    else
+        throw(ArgumentError("`embeddings` not implemented for $(typeof(index))"))
+    end
+end
+"Tags of the selected positions (rows of the parent's tags), or `nothing` if the parent has none."
+function tags(index::SubManagedIndex)
+    tagsdata = tags(parent(index))
+    isnothing(tagsdata) && return nothing
+    view(tagsdata, positions(index), :)
+end
+function tags_vocab(index::SubManagedIndex)
+    tags_vocab(parent(index))
+end
+"Extras of the selected positions, or `nothing` if the parent has none."
+function extras(index::SubManagedIndex)
+    extrasdata = extras(parent(index))
+    isnothing(extrasdata) && return nothing
+    view(extrasdata, positions(index))
+end
+## Fallback: reached when the two views do not share the same concrete type
+function Base.vcat(i1::SubManagedIndex, i2::SubManagedIndex)
+    throw(ArgumentError("vcat not implemented for type $(typeof(i1)) and $(typeof(i2))"))
+end
+"Merge two views of the same parent index by concatenating their positions."
+function Base.vcat(i1::T, i2::T) where {T <: SubManagedIndex}
+    ## Check if can be merged
+    if indexid(parent(i1)) != indexid(parent(i2))
+        throw(ArgumentError("Parent indices must be the same (provided: $(indexid(parent(i1))) and $(indexid(parent(i2))))"))
+    end
+    ## Must stay a `SubManagedIndex` (not `SubChunkIndex`), so the result keeps the input type
+    return SubManagedIndex(parent(i1), vcat(positions(i1), positions(i2)))
+end
+"View of the same parent with duplicate positions removed."
+function Base.unique(index::SubManagedIndex)
+    return SubManagedIndex(parent(index), unique(positions(index)))
+end
+function Base.length(index::SubManagedIndex)
+    return length(positions(index))
+end
+function Base.isempty(index::SubManagedIndex)
+    return isempty(positions(index))
+end
+function Base.show(io::IO, index::SubManagedIndex)
+    print(io,
+        "A view of $(typeof(parent(index))|>nameof) (id: $(indexid(parent(index)))) with $(length(index)) chunks")
+end
+"Translate positions within the view (`pos`) to positions in the parent index."
+Base.@propagate_inbounds function translate_positions_to_parent(
+        index::SubManagedIndex, pos::AbstractVector{<:Integer})
+    sub_positions = positions(index)
+    return sub_positions[pos]
+end
+
 # # CandidateChunks for Retrieval
 
 """
@@ -864,7 +962,20 @@ Base.@propagate_inbounds function Base.view(index::SubChunkIndex, cc::MultiCandi
 end
 # TODO: proper `view` -- `SubManagedIndex`?
 Base.@propagate_inbounds function Base.view(index::AbstractManagedIndex, cc::CandidateWithChunks)
-    return cc
+    @boundscheck let chk_vector = chunks(parent(index))
+        if !checkbounds(Bool, axes(chk_vector, 1), positions(cc))
+            ## Avoid printing huge position arrays, show the extremas of the attempted range
+            max_pos = extrema(positions(cc))
+            throw(BoundsError(chk_vector, max_pos))
+        end
+    end
+    ## Candidates that carry a different index id select nothing
+    pos = indexid(index) == indexid(cc) ? positions(cc) : Int[]
+    return SubManagedIndex(parent(index), pos)
+end
+"View of a view: narrow `index` to the candidate positions it already contains."
+Base.@propagate_inbounds function Base.view(index::SubManagedIndex, cc::CandidateWithChunks)
+    return SubManagedIndex(index, cc)
 end
 Base.@propagate_inbounds function SubChunkIndex(index::SubChunkIndex, cc::CandidateChunks)
     pos = indexid(index) == indexid(cc) ? positions(cc) : Int[]
@@ -892,6 +1003,24 @@ Base.@propagate_inbounds function SubChunkIndex(
     end
     return SubChunkIndex(parent(index), intersect_pos)
 end
+"""
+    SubManagedIndex(index::SubManagedIndex, cc::CandidateWithChunks)
+
+Build a view over `parent(index)` holding the positions of `cc` that are also in `index`.
+If `cc` carries a different index id, the resulting view is empty.
+"""
+Base.@propagate_inbounds function SubManagedIndex(index::SubManagedIndex, cc::CandidateWithChunks)
+    pos = indexid(index) == indexid(cc) ? positions(cc) : Int[]
+    intersect_pos = intersect(pos, positions(index))
+    @boundscheck let chk_vector = chunks(parent(index))
+        if !checkbounds(Bool, axes(chk_vector, 1), intersect_pos)
+            ## Avoid printing huge position arrays, show the extremas of the attempted range
+            max_pos = extrema(intersect_pos)
+            throw(BoundsError(chk_vector, max_pos))
+        end
+    end
+    return SubManagedIndex(parent(index), intersect_pos)
+end
 
 ## Getindex
 