From 8d5f1766c9e5725f1c5d162fa73f081fd2fb5ce5 Mon Sep 17 00:00:00 2001 From: afourney Date: Wed, 29 Nov 2023 12:43:57 -0800 Subject: [PATCH] GroupChat handle is_termination_msg (#804) * Have GroupChatManager check is_termination_msg * Added test cases. --- autogen/agentchat/groupchat.py | 10 ++++++ test/agentchat/test_groupchat.py | 59 ++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 0b52716f1e55..7689246a0b25 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -257,6 +257,11 @@ def run_chat( if message["role"] != "function": message["name"] = speaker.name groupchat.messages.append(message) + + if self._is_termination_msg(message): + # The conversation is over + break + # broadcast the message to all agents except the speaker for agent in groupchat.agents: if agent != speaker: @@ -302,6 +307,11 @@ async def a_run_chat( if message["role"] != "function": message["name"] = speaker.name groupchat.messages.append(message) + + if self._is_termination_msg(message): + # The conversation is over + break + # broadcast the message to all agents except the speaker for agent in groupchat.agents: if agent != speaker: diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index 441dcb7c6251..9cbbe814cef0 100644 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -265,7 +265,7 @@ def test_agent_mentions(): max_consecutive_auto_reply=2, human_input_mode="NEVER", llm_config=False, - default_auto_reply="This is alice sepaking.", + default_auto_reply="This is alice speaking.", ) agent2 = autogen.ConversableAgent( "bob", @@ -330,11 +330,64 @@ def test_agent_mentions(): ) +def test_termination(): + agent1 = autogen.ConversableAgent( + "alice", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is alice speaking.", + ) + agent2 = autogen.ConversableAgent( + "bob", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is bob speaking.", + ) + agent3 = autogen.ConversableAgent( + "sam", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is sam speaking. TERMINATE", + ) + + # Test empty is_termination_msg function + groupchat = autogen.GroupChat( + agents=[agent1, agent2, agent3], messages=[], speaker_selection_method="round_robin", max_round=10 + ) + + group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False, is_termination_msg=None) + + agent1.initiate_chat(group_chat_manager, message="'None' is_termination_msg function.") + assert len(groupchat.messages) == 10 + + # Test user-provided is_termination_msg function + agent1.reset() + agent2.reset() + agent3.reset() + + groupchat = autogen.GroupChat( + agents=[agent1, agent2, agent3], messages=[], speaker_selection_method="round_robin", max_round=10 + ) + + group_chat_manager = autogen.GroupChatManager( + groupchat=groupchat, + llm_config=False, + is_termination_msg=lambda x: x.get("content", "").rstrip().find("TERMINATE") >= 0, + ) + + agent1.initiate_chat(group_chat_manager, message="User-provided is_termination_msg function.") + assert len(groupchat.messages) == 3 + + if __name__ == "__main__": # test_func_call_groupchat() # test_broadcast() # test_chat_manager() # test_plugin() - test_speaker_selection_method() - test_n_agents_less_than_3() + # test_speaker_selection_method() + # test_n_agents_less_than_3() # test_agent_mentions() + test_termination()