Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pinecone integration #189

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@

# exclude scratch files
**/_*
docs/package-lock.json
docs/package-lock.json

.env
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OpenAI = "e9f21f70-7185-4079-aca2-91159181367c"
Pinecone = "ee90fdae-f7f0-4648-8b00-9c0307cf46d9"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Expand Down
4 changes: 2 additions & 2 deletions src/Experimental/RAGTools/RAGTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ include("api_services.jl")

include("rag_interface.jl")

export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, CandidateChunks, RAGResult
export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, PineconeIndex, CandidateChunks, CandidateWithChunks, RAGResult
export MultiIndex, SubChunkIndex, MultiCandidateChunks
include("types.jl")

export build_index, get_chunks, get_embeddings, get_keywords, get_tags, SimpleIndexer,
KeywordsIndexer
KeywordsIndexer, PineconeIndexer
include("preparation.jl")

include("rank_gpt.jl")
Expand Down
129 changes: 116 additions & 13 deletions src/Experimental/RAGTools/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ context = build_context(ContextEnumerator(), index, candidates; chunks_window_ma
```
"""
function build_context(contexter::ContextEnumerator,
index::AbstractDocumentIndex, candidates::AbstractCandidateChunks;
index::AbstractDocumentIndex,
candidates::AbstractCandidateChunks;
verbose::Bool = true,
chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...)
## Checks
Expand All @@ -63,6 +64,35 @@ function build_context(contexter::ContextEnumerator,
return context
end

"""
    build_context(contexter::ContextEnumerator,
        index::AbstractManagedIndex, candidates::AbstractCandidateWithChunks;
        verbose::Bool = true, kwargs...)

Dispatch for `AbstractManagedIndex` with `AbstractCandidateWithChunks`.

Build an enumerated context (a `Vector{String}`) from the retrieved chunks, one
entry per chunk (eg, `"1. <chunk text>"`).

Unlike the `AbstractDocumentIndex` method, the managed-service candidates carry
their chunk texts directly, so no chunk lookup in `index` is performed and no
`chunks_window_margin` expansion is supported.
"""
function build_context(contexter::ContextEnumerator,
        index::AbstractManagedIndex,
        candidates::AbstractCandidateWithChunks;
        verbose::Bool = true, kwargs...)
    ## NOTE(review): the previous implementation joined ALL candidate chunks for
    ## every position (duplicating the full context N times) and computed an
    ## unused sub-index lookup; we now emit one enumerated entry per chunk.
    context = String[]
    for (i, chunk) in enumerate(chunks(candidates))
        push!(context, "$(i). $(chunk)")
    end
    return context
end

function build_context!(contexter::AbstractContextBuilder,
index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
throw(ArgumentError("Contexter $(typeof(contexter)) not implemented"))
Expand All @@ -74,6 +104,11 @@ function build_context!(contexter::ContextEnumerator,
result.context = build_context(contexter, index, result.reranked_candidates; kwargs...)
return result
end
"""
    build_context!(contexter::ContextEnumerator,
        index::AbstractManagedIndex, result::AbstractRAGResult; kwargs...)

Mutating dispatch for `AbstractManagedIndex`: builds the context from
`result.reranked_candidates` and stores it in `result.context`.
"""
function build_context!(contexter::ContextEnumerator,
        index::AbstractManagedIndex, result::AbstractRAGResult; kwargs...)
    ## Delegate to the non-mutating builder, then save the context in-place.
    reranked = result.reranked_candidates
    result.context = build_context(contexter, index, reranked; kwargs...)
    return result
end

## First step: Answerer

Expand Down Expand Up @@ -139,6 +174,42 @@ function answer!(
return result
end

"""
    answer!(
        answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult;
        model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
        template::Symbol = :RAGAnswerFromContext,
        cost_tracker = Threads.Atomic{Float64}(0.0),
        kwargs...)

Dispatch for `AbstractManagedIndex`. Generates the answer from `result.context`
and `result.question` via `aigenerate` with the given `template`, then mutates
`result` in-place (`result.answer`, `result.conversations[:answer]`) and
increments `cost_tracker` by the call cost.

NOTE(review): the method body is identical to the `AbstractDocumentIndex`
dispatch (the `index` argument is unused); kept as a separate method to avoid a
reported dispatch-ambiguity with the `Union` signature — confirm before merging.
"""
function answer!(
        answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult;
        model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
        template::Symbol = :RAGAnswerFromContext,
        cost_tracker = Threads.Atomic{Float64}(0.0),
        kwargs...)
    ## Checks
    placeholders = only(aitemplates(template)).variables # only one template should be found
    @assert (:question in placeholders)&&(:context in placeholders) "Provided RAG Template $(template) is not suitable. It must have placeholders: `question` and `context`."
    ##
    (; context, question) = result
    conv = aigenerate(template; question,
        context = join(context, "\n\n"), model, verbose = false,
        return_all = true,
        kwargs...)
    msg = conv[end]
    result.answer = strip(msg.content)
    result.conversations[:answer] = conv
    ## Increment the cost tracker
    Threads.atomic_add!(cost_tracker, msg.cost)
    verbose &&
        @info "Done generating the answer. Cost: \$$(round(msg.cost,digits=3))"

    return result
end

## Refine
"""
NoRefiner <: AbstractRefiner
Expand All @@ -162,11 +233,12 @@ Refines the answer by executing a web search using the Tavily API. This method a
struct TavilySearchRefiner <: AbstractRefiner end

"""
    refine!(
        refiner::AbstractRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex},
        result::AbstractRAGResult; kwargs...)

Catch-all fallback: throws an `ArgumentError` for any `AbstractRefiner` subtype
without a specialized `refine!` method.
"""
function refine!(
        refiner::AbstractRefiner,
        index::Union{AbstractDocumentIndex, AbstractManagedIndex},
        result::AbstractRAGResult;
        kwargs...)
    ## NOTE(review): removed a stale duplicated signature line left over from the diff.
    throw(ArgumentError("Refiner $(typeof(refiner)) not implemented"))
end


"""
refine!(
refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
Expand All @@ -175,7 +247,7 @@ end
Simple no-op function for `refine!`. It simply copies the `result.answer` and `result.conversations[:answer]` without any changes.
"""
function refine!(
refiner::NoRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
refiner::NoRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
kwargs...)
result.final_answer = result.answer
if haskey(result.conversations, :answer)
Expand All @@ -184,9 +256,10 @@ function refine!(
return result
end


"""
refine!(
refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
refiner::SimpleRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
template::Symbol = :RAGAnswerRefiner,
Expand All @@ -210,7 +283,7 @@ This method uses the same context as the original answer, however, it can be mod
- `cost_tracker`: An atomic counter to track the cost of the operation.
"""
function refine!(
refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
refiner::SimpleRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
template::Symbol = :RAGAnswerRefiner,
Expand Down Expand Up @@ -238,9 +311,10 @@ function refine!(
return result
end


"""
refine!(
refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
refiner::TavilySearchRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
include_answer::Bool = true,
Expand Down Expand Up @@ -288,7 +362,7 @@ pprint(result)
```
"""
function refine!(
refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
refiner::TavilySearchRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
include_answer::Bool = true,
Expand Down Expand Up @@ -353,13 +427,13 @@ Overload this method to add custom postprocessing steps, eg, logging, saving con
"""
struct NoPostprocessor <: AbstractPostprocessor end

"""
    postprocess!(postprocessor::AbstractPostprocessor,
        index::Union{AbstractDocumentIndex, AbstractManagedIndex},
        result::AbstractRAGResult; kwargs...)

Catch-all fallback: throws an `ArgumentError` for any `AbstractPostprocessor`
subtype without a specialized `postprocess!` method.
"""
function postprocess!(postprocessor::AbstractPostprocessor,
        index::Union{AbstractDocumentIndex, AbstractManagedIndex},
        result::AbstractRAGResult; kwargs...)
    ## NOTE(review): removed a stale duplicated signature line left over from the diff.
    throw(ArgumentError("Postprocessor $(typeof(postprocessor)) not implemented"))
end

"""
    postprocess!(::NoPostprocessor,
        index::Union{AbstractDocumentIndex, AbstractManagedIndex},
        result::AbstractRAGResult; kwargs...)

No-op postprocessor: returns `result` unchanged.
"""
function postprocess!(
        ::NoPostprocessor,
        index::Union{AbstractDocumentIndex, AbstractManagedIndex},
        result::AbstractRAGResult; kwargs...)
    ## NOTE(review): removed a stale duplicated signature line left over from the diff.
    return result
end

Expand Down Expand Up @@ -394,7 +468,7 @@ end

"""
generate!(
generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
generator::AbstractGenerator, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Integer = 1,
api_kwargs::NamedTuple = NamedTuple(),
contexter::AbstractContextBuilder = generator.contexter,
Expand Down Expand Up @@ -459,7 +533,7 @@ result = generate!(index, result)
```
"""
function generate!(
generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
generator::AbstractGenerator, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
verbose::Integer = 1,
api_kwargs::NamedTuple = NamedTuple(),
contexter::AbstractContextBuilder = generator.contexter,
Expand Down Expand Up @@ -524,8 +598,9 @@ function Base.show(io::IO, cfg::AbstractRAGConfig)
dump(io, cfg; maxdepth = 2)
end

# NOTE: a Pinecone usage example is included in the docstring below
"""
airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
airag(cfg::AbstractRAGConfig, index::Union{AbstractDocumentIndex, AbstractManagedIndex};
question::AbstractString,
verbose::Integer = 1, return_all::Bool = false,
api_kwargs::NamedTuple = NamedTuple(),
Expand Down Expand Up @@ -644,11 +719,35 @@ result = airag(cfg, multi_index; question, return_all=true)

# Pretty-print the result
PT.pprint(result)


Example for a managed Pinecone index:

```julia
import LinearAlgebra, Unicode, SparseArrays
using Pinecone

# configure your Pinecone API key, index and namespace

docs_files = ... # files containing docs that you want to upsert to Pinecone
metadata = [Dict{String, Any}("source" => <docs_source>) for file in docs_files] # replace <docs_source> with your docs' sources
index_pinecone = RT.build_index(
RT.PineconeIndexer(),
docs_files;
pinecone_context, # API key wrapped with `Pinecone.jl`
pinecone_index,
pinecone_namespace,
metadata,
upsert = true
)

question = "How do I multiply two vectors in Julia?"
result = RT.airag(index_pinecone; question)
```

For easier manipulation of nested kwargs, see utilities `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`.
"""
function airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
function airag(cfg::AbstractRAGConfig, index::Union{AbstractDocumentIndex, AbstractManagedIndex};
question::AbstractString,
verbose::Integer = 1, return_all::Bool = false,
api_kwargs::NamedTuple = NamedTuple(),
Expand Down Expand Up @@ -693,6 +792,10 @@ const DEFAULT_RAG_CONFIG = RAGConfig()
## Convenience entry point for document indices: runs `airag` with the default
## RAG configuration (`DEFAULT_RAG_CONFIG`).
function airag(index::AbstractDocumentIndex; question::AbstractString, kwargs...)
    cfg = DEFAULT_RAG_CONFIG
    return airag(cfg, index; question, kwargs...)
end
## Default pipeline configuration for managed (Pinecone) indices.
## NOTE(review): `RAGConfig` is constructed with positional arguments here
## (indexer, retriever, generator) — confirm this matches the constructor's
## positional order; other call sites may use keyword construction.
const DEFAULT_RAG_CONFIG_PINECONE = RAGConfig(PineconeIndexer(), PineconeRetriever(), AdvancedGenerator())
## Convenience entry point for managed indices: runs `airag` with the default
## Pinecone configuration (`DEFAULT_RAG_CONFIG_PINECONE`).
function airag(index::AbstractManagedIndex; question::AbstractString, kwargs...)
return airag(DEFAULT_RAG_CONFIG_PINECONE, index; question, kwargs...)
end

# Special method to pretty-print the airag results
function PT.pprint(io::IO, airag_result::Tuple{PT.AIMessage, AbstractRAGResult},
Expand Down
Loading