diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py index 588a93fe29..69be35d028 100644 --- a/llama_stack/apis/memory_banks/client.py +++ b/llama_stack/apis/memory_banks/client.py @@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool): response = await client.list_memory_banks() cprint(f"list_memory_banks response={response}", "green") + # register memory bank for the first time + response = await client.register_memory_bank( + VectorMemoryBankDef( + identifier="test_bank2", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + ) + cprint(f"register_memory_bank response={response}", "blue") + + # list again after registering + response = await client.list_memory_banks() + cprint(f"list_memory_banks response={response}", "green") + def main(host: str, port: int, stream: bool = True): asyncio.run(run_main(host, port, stream)) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 17755f0e42..ede30aea13 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -110,10 +110,16 @@ def get_object_by_identifier( async def register_object(self, obj: RoutableObjectWithProvider): entries = self.registry.get(obj.identifier, []) for entry in entries: - if entry.provider_id == obj.provider_id: - print(f"`{obj.identifier}` already registered with `{obj.provider_id}`") + if entry.provider_id == obj.provider_id or not obj.provider_id: + print( + f"`{obj.identifier}` already registered with `{entry.provider_id}`" + ) return + # if provider_id is not specified, we'll pick an arbitrary one from existing entries + if not obj.provider_id and len(self.impls_by_provider_id) > 0: + obj.provider_id = list(self.impls_by_provider_id.keys())[0] + if obj.provider_id not in self.impls_by_provider_id: raise ValueError(f"Provider `{obj.provider_id}` not 
@pytest.fixture
def attachment_message():
    """Return a one-element message list announcing the attached Torchtune docs."""
    # Built as a named value first so the returned list literal stays trivial.
    intro = UserMessage(
        content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
    )
    return [intro]
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
    agents_settings, attachment_message, query_attachment_messages
):
    """Run two streamed agent turns: one carrying attachments, one querying them.

    The first turn sends ``attachment_message`` together with a list of
    ``Attachment`` objects whose content is a raw GitHub URL for a Torchtune
    tutorial page.  The second turn, in the same session, sends
    ``query_attachment_messages`` with no attachments.  Each turn only has to
    yield at least one chunk; the content of the answer is not inspected.

    Args:
        agents_settings: Mapping read here for ``"impl"`` (the agents
            implementation under test) and ``"common_params"`` (``"model"``
            and ``"instructions"``).
        attachment_message: Messages sent with the attachments on turn one.
        query_attachment_messages: Messages sent on the follow-up turn.
    """
    # These are file names, not full URLs; each is appended to the base below.
    tutorial_files = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]

    # Plain iteration: the original wrapped this in enumerate() but never
    # used the index.
    attachments = [
        Attachment(
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{file_name}",
            mime_type="text/plain",
        )
        for file_name in tutorial_files
    ]

    agents_impl = agents_settings["impl"]

    agent_config = AgentConfig(
        model=agents_settings["common_params"]["model"],
        instructions=agents_settings["common_params"]["instructions"],
        enable_session_persistence=True,
        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
        input_shields=[],
        output_shields=[],
        tools=[
            MemoryToolDefinition(
                # No pre-registered banks are listed here.
                # NOTE(review): presumably the attachments are indexed into a
                # per-session bank by the agent implementation -- confirm.
                memory_bank_configs=[],
                query_generator_config={
                    "type": "default",
                    "sep": " ",
                },
                max_tokens_in_context=4096,
                max_chunks=10,
            ),
        ],
        max_infer_iters=5,
    )

    create_response = await agents_impl.create_agent(agent_config)
    agent_id = create_response.agent_id

    # Create a session
    session_create_response = await agents_impl.create_agent_session(
        agent_id, "Test Session"
    )
    session_id = session_create_response.session_id

    # First turn: hand the documents to the agent.
    turn_request = dict(
        agent_id=agent_id,
        session_id=session_id,
        messages=attachment_message,
        attachments=attachments,
        stream=True,
    )

    turn_response = [
        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
    ]

    assert len(turn_response) > 0

    # Second turn: ask about the documents in the same session.
    turn_request = dict(
        agent_id=agent_id,
        session_id=session_id,
        messages=query_attachment_messages,
        stream=True,
    )

    turn_response = [
        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
    ]

    assert len(turn_response) > 0
isinstance(response, list) + assert len(response) == 1 + + @pytest.mark.asyncio async def test_query_documents(memory_settings, sample_documents): memory_impl = memory_settings["memory_impl"]