From dd4a2da20495d1282b6a1efb71640bb6fee2f073 Mon Sep 17 00:00:00 2001 From: afourney Date: Sun, 17 Dec 2023 19:51:39 -0800 Subject: [PATCH] Enable allow_repeat_speaker to be a list of agents that are allowed to repeat, rather than just a global boolean. (#905) Co-authored-by: Qingyun Wu --- autogen/agentchat/groupchat.py | 23 +++++++++++++++-------- test/agentchat/test_groupchat.py | 4 ++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index c420d7b22044..5b12a97e6b17 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -30,16 +30,16 @@ class GroupChat: - "manual": the next speaker is selected manually by user input. - "random": the next speaker is selected randomly. - "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`. - - allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True. + - allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True, in which case all speakers are allowed to speak consecutively. If allow_repeat_speaker is a list of Agents, then only those listed agents are allowed to repeat. If set to False, then no speakers are allowed to repeat. """ agents: List[Agent] messages: List[Dict] - max_round: int = 10 - admin_name: str = "Admin" - func_call_filter: bool = True - speaker_selection_method: str = "auto" - allow_repeat_speaker: bool = True + max_round: Optional[int] = 10 + admin_name: Optional[str] = "Admin" + func_call_filter: Optional[bool] = True + speaker_selection_method: Optional[str] = "auto" + allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = True _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"] @@ -125,6 +125,13 @@ def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agen f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). " ) + # If provided a list, make sure the agent is in the list + allow_repeat_speaker = ( + self.allow_repeat_speaker + if isinstance(self.allow_repeat_speaker, bool) + else last_speaker in self.allow_repeat_speaker + ) + agents = self.agents n_agents = len(agents) # Warn if GroupChat is underpopulated @@ -133,7 +140,7 @@ def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agen f"GroupChat is underpopulated with {n_agents} agents. " "Please add more agents to the GroupChat or use direct communication instead." ) - elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and self.allow_repeat_speaker: + elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker: logger.warning( f"GroupChat is underpopulated with {n_agents} agents. " "It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False." @@ -159,7 +166,7 @@ def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agen "Please check the function_map of the agents." ) # remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False - agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker] + agents = agents if allow_repeat_speaker else [agent for agent in agents if agent != last_speaker] if self.speaker_selection_method.lower() == "manual": selected_agent = self.manual_select_speaker(agents) diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index 27fda1fd5249..6d592ae3fa3d 100644 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -187,7 +187,7 @@ def _test_n_agents_less_than_3(method): messages=[], max_round=6, speaker_selection_method=method, - allow_repeat_speaker=True if method == "random" else False, + allow_repeat_speaker=[agent1, agent2] if method == "random" else False, ) group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") @@ -434,7 +434,7 @@ def test_next_agent(): # test_broadcast() # test_chat_manager() # test_plugin() - # test_speaker_selection_method() + test_speaker_selection_method() # test_n_agents_less_than_3() # test_agent_mentions() # test_termination()