Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliadmtru committed Sep 3, 2024
1 parent 611716e commit b027d05
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 48 deletions.
36 changes: 27 additions & 9 deletions src/Experimental/RAGTools/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ function build_context(contexter::ContextEnumerator,
return context
end

"""
build_context(contexter::ContextEnumerator,
index::AbstractManagedIndex, candidates::AbstractCandidateWithChunks;
verbose::Bool = true,
chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...)
build_context!(contexter::ContextEnumerator,
index::AbstractManagedIndex, result::AbstractRAGResult; kwargs...)
Dispatch for `AbstractManagedIndex` with `AbstractCandidateWithChunks`.
"""
function build_context(contexter::ContextEnumerator,
index::AbstractManagedIndex,
candidates::AbstractCandidateWithChunks;
Expand Down Expand Up @@ -124,7 +135,6 @@ function answer!(
throw(ArgumentError("Answerer $(typeof(answerer)) not implemented"))
end

# TODO: update docs signature
"""
answer!(
answerer::SimpleAnswerer, index::AbstractDocumentIndex, result::AbstractRAGResult;
Expand Down Expand Up @@ -173,6 +183,17 @@ function answer!(

return result
end

"""
answer!(
answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult;
model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
template::Symbol = :RAGAnswerFromContext,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
Dispatch for `AbstractManagedIndex`.
"""
function answer!(
answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult;
model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
Expand Down Expand Up @@ -228,7 +249,6 @@ function refine!(
end


# TODO: update docs signature
"""
refine!(
refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
Expand All @@ -247,10 +267,9 @@ function refine!(
end


# TODO: update docs signature
"""
refine!(
refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
refiner::SimpleRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
template::Symbol = :RAGAnswerRefiner,
Expand Down Expand Up @@ -303,10 +322,9 @@ function refine!(
end


# TODO: update docs signature
"""
refine!(
refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
refiner::TavilySearchRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
include_answer::Bool = true,
Expand Down Expand Up @@ -458,10 +476,9 @@ It uses `ContextEnumerator`, `SimpleAnswerer`, `SimpleRefiner`, and `NoPostproce
postprocessor::AbstractPostprocessor = NoPostprocessor()
end

# TODO: update docs signature
"""
generate!(
generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
generator::AbstractGenerator, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Integer = 1,
api_kwargs::NamedTuple = NamedTuple(),
contexter::AbstractContextBuilder = generator.contexter,
Expand Down Expand Up @@ -591,8 +608,9 @@ function Base.show(io::IO, cfg::AbstractRAGConfig)
dump(io, cfg; maxdepth = 2)
end

# TODO: add example for Pinecone
"""
airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
airag(cfg::AbstractRAGConfig, index::Union{AbstractDocumentIndex, AbstractManagedIndex};
question::AbstractString,
verbose::Integer = 1, return_all::Bool = false,
api_kwargs::NamedTuple = NamedTuple(),
Expand Down
89 changes: 84 additions & 5 deletions src/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ end
PineconeIndexer <: AbstractIndexBuilder
Pinecone index to be returned by `build_index`.
It uses `FileChunker`, `SimpleEmbedder` and `NoTagger` as default chunker, embedder and tagger.
"""
@kwdef mutable struct PineconeIndexer <: AbstractIndexBuilder
chunker::AbstractChunker = FileChunker()
# TODO: BatchEmbedder?
embedder::AbstractEmbedder = SimpleEmbedder()
tagger::AbstractTagger = NoTagger()
end
Expand Down Expand Up @@ -726,26 +729,102 @@ function build_index(
return index
end

# TODO: where to put these?
using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3, init_v3, Index, PineconeVector, upsert
using UUIDs: UUIDs, uuid4
# TODO: change docs
"""
build_index(
indexer::PineconeIndexer;
namespace::AbstractString,
indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString};
metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(),
pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
pinecone_index::Pinecone.PineconeIndexv3 = nothing,
pinecone_namespace::AbstractString = "",
upsert::Bool = true,
verbose::Integer = 1,
index_id = gensym("PTPineconeIndex"),
index_id = gensym(pinecone_namespace),
chunker::AbstractChunker = indexer.chunker,
chunker_kwargs::NamedTuple = NamedTuple(),
embedder::AbstractEmbedder = indexer.embedder,
embedder_kwargs::NamedTuple = NamedTuple(),
tagger::AbstractTagger = indexer.tagger,
tagger_kwargs::NamedTuple = NamedTuple(),
api_kwargs::NamedTuple = NamedTuple(),
cost_tracker = Threads.Atomic{Float64}(0.0))
Builds a `PineconeIndex` containing a Pinecone context (API key, index and namespace).
The index stores the document chunks and their embeddings (and potentially other information).
The function processes each file or document (depending on `chunker`), splits its content into chunks, embeds these chunks
and then combines this information into a retrievable index. The chunks and embeddings are upsert to Pinecone using
the provided Pinecone context (unless the `upsert` flag is set to `false`).
# Arguments
- `indexer::PineconeIndexer`: The indexing logic for Pinecone operations.
- `files_or_docs`: A vector of valid file paths to be indexed (chunked and embedded).
- `metadata::Vector{Dict{String, Any}}`: A vector of metadata attributed to each docs file, given as dictionaries with `String` keys. Default is empty vector.
- `pinecone_context::Pinecone.PineconeContextv3`: The Pinecone API key generated using Pinecone.jl. Must be specified.
- `pinecone_index::Pinecone.PineconeIndexv3`: The Pinecone index generated using Pinecone.jl. Must be specified.
- `pinecone_namespace::AbstractString`: The Pinecone namespace associated to `pinecone_index`.
- `upsert::Bool = true`: A flag specifying whether to upsert the chunks and embeddings to Pinecone. Defaults to `true`.
- `verbose`: An Integer specifying the verbosity of the logs. Default is `1` (high-level logging). `0` is disabled.
- `index_id`: A unique identifier for the index. Default is a generated symbol.
- `chunker`: The chunker logic to use for splitting the documents. Default is `TextChunker()`.
- `chunker_kwargs`: Parameters to be provided to the `get_chunks` function. Useful to change the `separators` or `max_length`.
- `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs`.
- `embedder`: The embedder logic to use for embedding the chunks. Default is `BatchEmbedder()`.
- `embedder_kwargs`: Parameters to be provided to the `get_embeddings` function. Useful to change the `target_batch_size_length` or reduce asyncmap tasks `ntasks`.
- `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`.
- `tagger`: The tagger logic to use for extracting tags from the chunks. Default is `NoTagger()`, ie, skip tag extraction. There are also `PassthroughTagger` and `OpenTagger`.
- `tagger_kwargs`: Parameters to be provided to the `get_tags` function.
- `model`: The model to use for tags extraction. Default is `PT.MODEL_CHAT`.
- `template`: A template to be used for tags extraction. Default is `:RAGExtractMetadataShort`.
- `tags`: A vector of vectors of strings directly providing the tags for each chunk. Applicable for `tagger::PasstroughTagger`.
- `api_kwargs`: Parameters to be provided to the API endpoint. Shared across all API calls if provided.
- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call.
# Returns
- `PineconeIndex`: An object containing the compiled index of chunks, embeddings, tags, vocabulary, sources and metadata, together with the Pinecone connection data.
See also: `PineconeIndex`, `get_chunks`, `get_embeddings`, `get_tags`, `CandidateWithChunks`, `find_closest`, `find_tags`, `rerank`, `retrieve`, `generate!`, `airag`
# Examples
```julia
using Pinecone
# Prepare the Pinecone connection data
pinecone_context = Pinecone.init_v3(ENV["PINECONE_API_KEY"])
pindex = ENV["PINECONE_INDEX"]
pinecone_index = !isempty(pindex) ? Pinecone.Index(pinecone_context, pindex) : nothing
namespace = "my-namespace"
# Add metadata about the sources in Pinecone
metadata = [Dict{String, Any}("source" => doc_file) for doc_file in docs_files]
# Build the index. By default, the chunks and embeddings get upserted to Pinecone.
const RT = PromptingTools.Experimental.RAGTools
index_pinecone = RT.build_index(
RT.PineconeIndexer(),
docs_files;
pinecone_context = pinecone_context,
pinecone_index = pinecone_index,
pinecone_namespace = namespace,
metadata = metadata
)
# Notes
- If you get errors about exceeding embedding input sizes, first check the `max_length` in your chunks.
If that does NOT resolve the issue, try changing the `embedding_kwargs`.
In particular, reducing the `target_batch_size_length` parameter (eg, 10_000) and number of tasks `ntasks=1`.
Some providers cannot handle large batch sizes (eg, Databricks).
"""
function build_index(
indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString};
metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(),
pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
pinecone_index::Pinecone.PineconeIndexv3 = nothing,
pinecone_namespace::AbstractString = "",
upsert::Bool = false,
upsert::Bool = true,
verbose::Integer = 1,
index_id = gensym(pinecone_namespace),
chunker::AbstractChunker = indexer.chunker,
Expand Down
102 changes: 72 additions & 30 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,37 @@ function find_closest(
return CandidateChunks(indexid(index), positions, Float32.(scores))
end

# Dispatch to find scores for multiple embeddings
function find_closest(
finder::AbstractSimilarityFinder, index::AbstractChunkIndex,
query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}();
top_k::Int = 100, kwargs...)
if isnothing(chunkdata(parent(index)))
return CandidateChunks(; index_id = indexid(index))
end
## reduce top_k since we have more than one query
top_k_ = top_k ÷ size(query_emb, 2)
## simply vcat together (gets sorted from the highest similarity to the lowest)
if isempty(query_tokens)
mapreduce(
c -> find_closest(finder, index, c; top_k = top_k_, kwargs...), vcat, eachcol(query_emb))
else
@assert length(query_tokens)==size(query_emb, 2) "Length of `query_tokens` must be equal to the number of columns in `query_emb`."
mapreduce(
(emb, tok) -> find_closest(finder, index, emb, tok; top_k = top_k_, kwargs...), vcat, eachcol(query_emb), query_tokens)
end
end

"""
find_closest(
finder::AbstractSimilarityFinder, index::PineconeIndex,
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_k::Int = 10, kwargs...)
Finds the indices of chunks that are closest to query embedding (`query_emb`) by querying Pinecone.
Returns only `top_k` closest indices.
"""
function find_closest(
finder::AbstractSimilarityFinder, index::PineconeIndex,
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
Expand All @@ -261,6 +292,7 @@ function find_closest(
scores = [m.score for m in matches]
chunks = [m.metadata.content for m in matches]
metadata = [JSON3.read(JSON3.write(m.metadata), Dict{String, Any}) for m in matches]
# TODO: metadata might not have `source`, change this
sources = [m.metadata.source for m in matches]

return CandidateWithChunks(
Expand All @@ -272,6 +304,7 @@ function find_closest(
sources = Vector{String}(sources))
end

# Dispatch to find scores for multiple embeddings
function find_closest(
finder::AbstractSimilarityFinder, index::PineconeIndex,
query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}();
Expand All @@ -290,27 +323,6 @@ function find_closest(
end
end

# Dispatch to find scores for multiple embeddings
function find_closest(
finder::AbstractSimilarityFinder, index::AbstractChunkIndex,
query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}();
top_k::Int = 100, kwargs...)
if isnothing(chunkdata(parent(index)))
return CandidateChunks(; index_id = indexid(index))
end
## reduce top_k since we have more than one query
top_k_ = top_k ÷ size(query_emb, 2)
## simply vcat together (gets sorted from the highest similarity to the lowest)
if isempty(query_tokens)
mapreduce(
c -> find_closest(finder, index, c; top_k = top_k_, kwargs...), vcat, eachcol(query_emb))
else
@assert length(query_tokens)==size(query_emb, 2) "Length of `query_tokens` must be equal to the number of columns in `query_emb`."
mapreduce(
(emb, tok) -> find_closest(finder, index, emb, tok; top_k = top_k_, kwargs...), vcat, eachcol(query_emb), query_tokens)
end
end

### For MultiIndex
function find_closest(
finder::MultiFinder, index::AbstractMultiIndex,
Expand Down Expand Up @@ -612,20 +624,14 @@ function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
end

"""
find_tags(method::NoTagFilter, index::AbstractChunkIndex,
find_tags(method::NoTagFilter, index::Union{AbstractChunkIndex, AbstractManagedIndex},
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{
AbstractString, Regex, Nothing}}
tags; kwargs...)
Returns all chunks in the index, ie, no filtering, so we simply return `nothing` (easier for dispatch).
"""
# function find_tags(method::NoTagFilter, index::AbstractChunkIndex,
# tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
# Union{
# AbstractString, Regex, Nothing}}
# return nothing
# end
function find_tags(
method::NoTagFilter, index::Union{AbstractChunkIndex,
AbstractManagedIndex},
Expand Down Expand Up @@ -748,8 +754,6 @@ function rerank(reranker::NoReranker,
candidates::AbstractCandidateWithChunks;
top_n::Integer = length(candidates),
kwargs...)
# Since this is almost a passthrough strategy, it returns the candidate_chunks unchanged
# but it truncates to `top_n` if necessary
return first(candidates, top_n)
end

Expand Down Expand Up @@ -1017,11 +1021,22 @@ end
PineconeRetriever <: AbstractRetriever
Dispatch for `retrieve` for Pinecone.
# Fields
- `rephraser::AbstractRephraser`: the rephrasing method, dispatching `rephrase` - uses `NoRephraser`
- `embedder::AbstractEmbedder`: the embedding method, dispatching `get_embeddings` (see Preparation Stage for more details) - uses `SimpleEmbedder`
- `processor::AbstractProcessor`: the processor method, dispatching `get_keywords` (see Preparation Stage for more details) - uses `NoProcessor`
- `finder::AbstractSimilarityFinder`: the similarity search method, dispatching `find_closest` - uses `CosineSimilarity`
- `tagger::AbstractTagger`: the tag generating method, dispatching `get_tags` (see Preparation Stage for more details) - uses `NoTagger`
- `filter::AbstractTagFilter`: the tag matching method, dispatching `find_tags` - uses `NoTagFilter`
- `reranker::AbstractReranker`: the reranking method, dispatching `rerank` - uses `NoReranker`
"""
@kwdef mutable struct PineconeRetriever <: AbstractRetriever
rephraser::AbstractRephraser = NoRephraser()
# TODO: BatchEmbedder?
embedder::AbstractEmbedder = SimpleEmbedder()
processor::AbstractProcessor = NoProcessor()
# TODO: actually do something with this; Pinecone allows choosing finder
finder::AbstractSimilarityFinder = CosineSimilarity()
tagger::AbstractTagger = NoTagger()
filter::AbstractTagFilter = NoTagFilter()
Expand Down Expand Up @@ -1242,6 +1257,33 @@ function retrieve(retriever::AbstractRetriever,
return result
end

"""
retrieve(retriever::PineconeRetriever,
index::PineconeIndex,
question::AbstractString;
verbose::Integer = 1,
top_k::Integer = 100,
top_n::Integer = 10,
api_kwargs::NamedTuple = NamedTuple(),
rephraser::AbstractRephraser = retriever.rephraser,
rephraser_kwargs::NamedTuple = NamedTuple(),
embedder::AbstractEmbedder = retriever.embedder,
embedder_kwargs::NamedTuple = NamedTuple(),
processor::AbstractProcessor = retriever.processor,
processor_kwargs::NamedTuple = NamedTuple(),
finder::AbstractSimilarityFinder = retriever.finder,
finder_kwargs::NamedTuple = NamedTuple(),
tagger::AbstractTagger = retriever.tagger,
tagger_kwargs::NamedTuple = NamedTuple(),
filter::AbstractTagFilter = retriever.filter,
filter_kwargs::NamedTuple = NamedTuple(),
reranker::AbstractReranker = retriever.reranker,
reranker_kwargs::NamedTuple = NamedTuple(),
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
Dispatch method for `PineconeIndex`.
"""
function retrieve(retriever::PineconeRetriever,
index::PineconeIndex,
question::AbstractString;
Expand Down
Loading

0 comments on commit b027d05

Please sign in to comment.