Skip to content

Commit

Permalink
Merge pull request #17 from philippesaade-wmde/main
Browse files Browse the repository at this point in the history
Include DataStax integration
  • Loading branch information
exowanderer authored Nov 7, 2024
2 parents cf0a8b0 + d8369b8 commit 4ee05e7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 57 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@ uvicorn==0.27.0
uvloop==0.19.0
watchfiles==0.21.0
websockets==12.0
langchain-astradb==0.3.3
astrapy==1.5.2
155 changes: 98 additions & 57 deletions wikidatachat/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,22 @@
setup_document_stream_from_list
)

from langchain_astradb import AstraDBVectorStore
from astrapy.info import CollectionVectorServiceOptions
import json

# --- Module-level configuration, read once at import time. ---
# Each value is None when its environment variable is unset (no
# validation is performed here; downstream calls must cope).

# Retrieve the SERAPI API key from environment variables.
# Used below to fetch Wikidata statements via web search
# (presumably SerpAPI — confirm the intended service).
SERAPI_API_KEY = os.environ.get("SERAPI_API_KEY")

# Sentence-embedding model identifier; defaults to a German
# GPL-adapted model when EMBEDDING_MODEL is not provided.
EMBEDDING_MODEL = os.environ.get(
    'EMBEDDING_MODEL',
    'svalabs/german-gpl-adapted-covid'
)

# Retrieve the DataStax API keys from environment variables.
# These configure the AstraDBVectorStore built in __init__.
# NOTE(review): COLLECTION_NAME is read here but __init__ hardcodes
# collection_name="wikidata" — confirm which one is intended.
COLLECTION_NAME = os.environ.get('COLLECTION_NAME')
ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT')
ASTRA_DB_KEYSPACE = os.environ.get('ASTRA_DB_KEYSPACE')

class RetreivalAugmentedGenerationPipeline:
def __init__(self, embedding_model=EMBEDDING_MODEL, device='cpu'):
Expand All @@ -51,6 +60,20 @@ def __init__(self, embedding_model=EMBEDDING_MODEL, device='cpu'):
device=self.device
)

collection_vector_service_options = CollectionVectorServiceOptions(
provider="nvidia",
model_name="NV-Embed-QA"
)

# Initialize the graph store
self.graph_store = AstraDBVectorStore(
collection_name="wikidata",
collection_vector_service_options=collection_vector_service_options,
token=ASTRA_DB_APPLICATION_TOKEN,
api_endpoint=ASTRA_DB_API_ENDPOINT,
namespace=ASTRA_DB_KEYSPACE,
)

def process_query(
self, query: str, top_k: int = 10, lang: str = 'de',
content_key: str = None, meta_keys: list = [],
Expand All @@ -71,63 +94,81 @@ def process_query(
Returns:
The first answer from the generated answers.
"""
if wikidata_kwargs is None:
# Default Wikidata query parameters.
wikidata_kwargs = {
'timeout': 10,
'n_cores': cpu_count(),
'verbose': False,
'api_url': 'https://www.wikidata.org/w',
'wikidata_base': '"wikidata.org"',
'return_list': True
}

# Create a Document object from the query.
query_document = Document(content=query)

# Embed the query document.
query_embedded = self.embedder.run([query_document])

# Extract the embedding of the query document.
query_embedding = query_embedded['documents'][0].embedding

# Retrieve Wikidata statements related to the query.
wikidata_statements = get_wikidata_statements_from_query(
query,
lang=lang,
serapi_api_key=SERAPI_API_KEY,
**wikidata_kwargs
)

# Log the retrieved Wikidata statements for debugging.
self.logger.debug(f'{wikidata_statements=}')
for wds_ in wikidata_statements:
# Log each Wikidata statement for debugging.
self.logger.debug(f'{wds_=}')

# Setup the document stream from the list of Wikidata statements.
_, retriever = setup_document_stream_from_list(
dict_list=wikidata_statements,
content_key=content_key,
meta_keys=meta_keys,
embedder=self.embedder,
embedding_similarity_function=embedding_similarity_function,
device=self.device
)

# Run the retriever to find relevant documents
# based on the query embedding.
retriever_results = retriever.run(
query_embedding=list(query_embedding),
filters=None,
top_k=top_k,
scale_score=None,
return_embedding=None
)
retriever_results = []
try:
results = self.graph_store.similarity_search_with_relevance_scores(query, k=top_k)

retriever_results = [
Document(
content=r[0].page_content,
score=r[1],
meta={'qid': r[0].metadata['QID']}
)
for r in results]
except Exception as e:
print(e)

# If DataStax fails, use SERAPI instead
if len(retriever_results) == 0:
if wikidata_kwargs is None:
# Default Wikidata query parameters.
wikidata_kwargs = {
'timeout': 10,
'n_cores': cpu_count(),
'verbose': False,
'api_url': 'https://www.wikidata.org/w',
'wikidata_base': '"wikidata.org"',
'return_list': True
}

# Create a Document object from the query.
query_document = Document(content=query)

# Embed the query document.
query_embedded = self.embedder.run([query_document])

# Extract the embedding of the query document.
query_embedding = query_embedded['documents'][0].embedding

# Retrieve Wikidata statements related to the query.
wikidata_statements = get_wikidata_statements_from_query(
query,
lang=lang,
serapi_api_key=SERAPI_API_KEY,
**wikidata_kwargs
)

# Log the retrieved Wikidata statements for debugging.
self.logger.debug(f'{wikidata_statements=}')
for wds_ in wikidata_statements:
# Log each Wikidata statement for debugging.
self.logger.debug(f'{wds_=}')

# Setup the document stream from the list of Wikidata statements.
_, retriever = setup_document_stream_from_list(
dict_list=wikidata_statements,
content_key=content_key,
meta_keys=meta_keys,
embedder=self.embedder,
embedding_similarity_function=embedding_similarity_function,
device=self.device
)

# Run the retriever to find relevant documents
# based on the query embedding.
retriever_results = retriever.run(
query_embedding=list(query_embedding),
filters=None,
top_k=top_k,
scale_score=None,
return_embedding=None
)
retriever_results = retriever_results['documents']

# Log the start of retriever results for debugging.
self.logger.debug('retriever results:')
for retriever_result_ in retriever_results['documents']:
for retriever_result_ in retriever_results:
# Log each retriever result for debugging.
self.logger.debug(retriever_result_)

Expand All @@ -140,8 +181,8 @@ def process_query(
# Build the user prompt based on the retrieved documents
# and the original query.
user_prompt_build = user_prompt_builder.run(
question=query_document.content,
documents=retriever_results['documents']
question=query,
documents=retriever_results
)

# Extract the constructed prompt.
Expand All @@ -167,10 +208,10 @@ def process_query(
# Build the answer based on the language model's response
# and the retrieved documents.
answer_build = answer_builder.run(
query=query_document.content,
query=query,
replies=response['replies'],
meta=[r.meta for r in response['replies']],
documents=retriever_results['documents']
documents=retriever_results
)

# Log the constructed answer for debugging.
Expand Down

0 comments on commit 4ee05e7

Please sign in to comment.