From e83578eee055aa5297e092bc1395a5db0c1a16fe Mon Sep 17 00:00:00 2001 From: Benjoyo Date: Fri, 22 Sep 2023 16:40:11 +0200 Subject: [PATCH] fix test --- python/src/gpt/agents/retrieval_agent/retrieval_agent.py | 2 +- python/src/gpt/server/server.py | 1 + python/tests/manual_integration/test.py | 1 + python/tests/server/test_server.py | 1 + python/tests/test_ui/agent_retrieval.py | 3 ++- 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/src/gpt/agents/retrieval_agent/retrieval_agent.py b/python/src/gpt/agents/retrieval_agent/retrieval_agent.py index 1811117..cb33728 100644 --- a/python/src/gpt/agents/retrieval_agent/retrieval_agent.py +++ b/python/src/gpt/agents/retrieval_agent/retrieval_agent.py @@ -56,7 +56,7 @@ def json_schema_to_pydantic_model(name: str, schema: Dict[str, Any]) -> Any: def create_retrieval_agent( llm: BaseChatModel, vector_store: VectorStore, - filter_llm: BaseChatModel = ChatOpenAI(temperature=0), + filter_llm: BaseChatModel, summary_store: Optional[VectorStore] = None, output_schema: Optional[Dict[str, Union[str, dict]]] = None, multi_query_expansion: bool = False, diff --git a/python/src/gpt/server/server.py b/python/src/gpt/server/server.py index b10f89d..a9874e7 100644 --- a/python/src/gpt/server/server.py +++ b/python/src/gpt/server/server.py @@ -159,6 +159,7 @@ async def post(task: RetrievalTask): summary_store = None agent = create_retrieval_agent( llm=model_id_to_llm(task.model), + filter_llm=model_id_to_llm(task.model), vector_store=vector_store, output_schema=task.output_schema, reranker=task.reranker, diff --git a/python/tests/manual_integration/test.py b/python/tests/manual_integration/test.py index c2d0160..6af5bed 100644 --- a/python/tests/manual_integration/test.py +++ b/python/tests/manual_integration/test.py @@ -362,6 +362,7 @@ def test_retrieve(): qa = create_retrieval_agent( llm=ChatOpenAI(model_name="gpt-4", temperature=0), + filter_llm=ChatOpenAI(model_name="gpt-4", temperature=0), vector_store=vector_store, multi_query_expansion=False, # metadata_field_info=[ diff --git a/python/tests/server/test_server.py b/python/tests/server/test_server.py index 974813b..eb526e3 100644 --- a/python/tests/server/test_server.py +++ b/python/tests/server/test_server.py @@ -230,6 +230,7 @@ def test_retrieval(agent_function_mock): agent_function_mock.assert_called_with( llm=None, + filter_llm=None, vector_store=None, output_schema={'result': 'the result'}, reranker='test_reranker', diff --git a/python/tests/test_ui/agent_retrieval.py b/python/tests/test_ui/agent_retrieval.py index 106a750..08725b4 100644 --- a/python/tests/test_ui/agent_retrieval.py +++ b/python/tests/test_ui/agent_retrieval.py @@ -50,7 +50,8 @@ ) agent = create_retrieval_agent( - llm=ChatOpenAI(model="gpt-4", streaming=True), + llm=ChatOpenAI(model="gpt-4", streaming=True, temperature=0), + filter_llm=ChatOpenAI(model="gpt-4", temperature=0), vector_store=vector_store, parent_document_store=parent_document_store, summary_store=summary_store,