diff --git a/requirements.txt b/requirements.txt index 900837d..49fa7c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,3 +48,5 @@ uvicorn==0.27.0 uvloop==0.19.0 watchfiles==0.21.0 websockets==12.0 +langchain-astradb==0.3.3 +astrapy==1.5.2 diff --git a/wikidatachat/rag.py b/wikidatachat/rag.py index 67eeb1f..00f0c25 100644 --- a/wikidatachat/rag.py +++ b/wikidatachat/rag.py @@ -24,6 +24,10 @@ setup_document_stream_from_list ) +from langchain_astradb import AstraDBVectorStore +from astrapy.info import CollectionVectorServiceOptions +import json + # Retrieve the SERAPI API key from environment variables. SERAPI_API_KEY = os.environ.get("SERAPI_API_KEY") EMBEDDING_MODEL = os.environ.get( @@ -31,6 +35,11 @@ 'svalabs/german-gpl-adapted-covid' ) +# Retrieve the DataStax API keys from environment variables. +COLLECTION_NAME = os.environ.get('COLLECTION_NAME') +ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN') +ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT') +ASTRA_DB_KEYSPACE = os.environ.get('ASTRA_DB_KEYSPACE') class RetreivalAugmentedGenerationPipeline: def __init__(self, embedding_model=EMBEDDING_MODEL, device='cpu'): @@ -51,6 +60,20 @@ def __init__(self, embedding_model=EMBEDDING_MODEL, device='cpu'): device=self.device ) + collection_vector_service_options = CollectionVectorServiceOptions( + provider="nvidia", + model_name="NV-Embed-QA" + ) + + # Initialize the graph store + self.graph_store = AstraDBVectorStore( + collection_name="wikidata", + collection_vector_service_options=collection_vector_service_options, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + ) + def process_query( self, query: str, top_k: int = 10, lang: str = 'de', content_key: str = None, meta_keys: list = [], @@ -71,63 +94,81 @@ def process_query( Returns: The first answer from the generated answers. """ - if wikidata_kwargs is None: - # Default Wikidata query parameters. - wikidata_kwargs = { - 'timeout': 10, - 'n_cores': cpu_count(), - 'verbose': False, - 'api_url': 'https://www.wikidata.org/w', - 'wikidata_base': '"wikidata.org"', - 'return_list': True - } - - # Create a Document object from the query. - query_document = Document(content=query) - - # Embed the query document. - query_embedded = self.embedder.run([query_document]) - - # Extract the embedding of the query document. - query_embedding = query_embedded['documents'][0].embedding - - # Retrieve Wikidata statements related to the query. - wikidata_statements = get_wikidata_statements_from_query( - query, - lang=lang, - serapi_api_key=SERAPI_API_KEY, - **wikidata_kwargs - ) - # Log the retrieved Wikidata statements for debugging. - self.logger.debug(f'{wikidata_statements=}') - for wds_ in wikidata_statements: - # Log each Wikidata statement for debugging. - self.logger.debug(f'{wds_=}') - - # Setup the document stream from the list of Wikidata statements. - _, retriever = setup_document_stream_from_list( - dict_list=wikidata_statements, - content_key=content_key, - meta_keys=meta_keys, - embedder=self.embedder, - embedding_similarity_function=embedding_similarity_function, - device=self.device - ) - - # Run the retriever to find relevant documents - # based on the query embedding. - retriever_results = retriever.run( - query_embedding=list(query_embedding), - filters=None, - top_k=top_k, - scale_score=None, - return_embedding=None - ) + retriever_results = [] + try: + results = self.graph_store.similarity_search_with_relevance_scores(query, k=top_k) + + retriever_results = [ + Document( + content=r[0].page_content, + score=r[1], + meta={'qid': r[0].metadata['QID']} + ) + for r in results] + except Exception as e: + print(e) + + # If DataStax fails, use SERAPI instead + if len(retriever_results) == 0: + if wikidata_kwargs is None: + # Default Wikidata query parameters. + wikidata_kwargs = { + 'timeout': 10, + 'n_cores': cpu_count(), + 'verbose': False, + 'api_url': 'https://www.wikidata.org/w', + 'wikidata_base': '"wikidata.org"', + 'return_list': True + } + + # Create a Document object from the query. + query_document = Document(content=query) + + # Embed the query document. + query_embedded = self.embedder.run([query_document]) + + # Extract the embedding of the query document. + query_embedding = query_embedded['documents'][0].embedding + + # Retrieve Wikidata statements related to the query. + wikidata_statements = get_wikidata_statements_from_query( + query, + lang=lang, + serapi_api_key=SERAPI_API_KEY, + **wikidata_kwargs + ) + + # Log the retrieved Wikidata statements for debugging. + self.logger.debug(f'{wikidata_statements=}') + for wds_ in wikidata_statements: + # Log each Wikidata statement for debugging. + self.logger.debug(f'{wds_=}') + + # Setup the document stream from the list of Wikidata statements. + _, retriever = setup_document_stream_from_list( + dict_list=wikidata_statements, + content_key=content_key, + meta_keys=meta_keys, + embedder=self.embedder, + embedding_similarity_function=embedding_similarity_function, + device=self.device + ) + + # Run the retriever to find relevant documents + # based on the query embedding. + retriever_results = retriever.run( + query_embedding=list(query_embedding), + filters=None, + top_k=top_k, + scale_score=None, + return_embedding=None + ) + retriever_results = retriever_results['documents'] # Log the start of retriever results for debugging. self.logger.debug('retriever results:') - for retriever_result_ in retriever_results['documents']: + for retriever_result_ in retriever_results: # Log each retriever result for debugging. self.logger.debug(retriever_result_) @@ -140,8 +181,8 @@ def process_query( # Build the user prompt based on the retrieved documents # and the original query. user_prompt_build = user_prompt_builder.run( - question=query_document.content, - documents=retriever_results['documents'] + question=query, + documents=retriever_results ) # Extract the constructed prompt. @@ -167,10 +208,10 @@ def process_query( # Build the answer based on the language model's response # and the retrieved documents. answer_build = answer_builder.run( - query=query_document.content, + query=query, replies=response['replies'], meta=[r.meta for r in response['replies']], - documents=retriever_results['documents'] + documents=retriever_results ) # Log the constructed answer for debugging.