Skip to content

Commit

Permalink
Makes select_speaker more robust by checking for mentions anywhere. (m…
Browse files Browse the repository at this point in the history
…icrosoft#669)

* Makes select_speaker more robust by checking for agents mentioned anywhere in the selection string. Addresses 663.

* Added test coverage for group chat mentions. Refactored mention counter to own function.

* Fixed pre-commit formatting.
  • Loading branch information
afourney authored Nov 17, 2023
1 parent d340159 commit f939dda
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
25 changes: 24 additions & 1 deletion autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import re
from .agent import Agent
from .conversable_agent import ConversableAgent

Expand Down Expand Up @@ -101,6 +101,13 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
return self.next_agent(last_speaker, agents)

# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1:
name = next(iter(mentions))

# Return the result
try:
return self.agent_by_name(name)
except ValueError:
Expand All @@ -119,6 +126,22 @@ def _participant_roles(self):
roles.append(f"{agent.name}: {agent.system_message}")
return "\n".join(roles)

def _mentioned_agents(self, message_content: str, agents: List[Agent]) -> Dict:
"""
Finds and counts agent mentions in the string message_content, taking word boundaries into account.
Returns: A dictionary mapping agent names to mention counts (to be included, at least one mention must occur)
"""
mentions = dict()
for agent in agents:
regex = (
r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
) # Finds agent mentions, taking word boundaries into account
count = len(re.findall(regex, " " + message_content + " ")) # Pad the message to help with matching
if count > 0:
mentions[agent.name] = count
return mentions


class GroupChatManager(ConversableAgent):
"""(In preview) A chat manager agent that can manage a group chat of multiple agents."""
Expand Down
73 changes: 73 additions & 0 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import autogen
import json


def test_func_call_groupchat():
Expand Down Expand Up @@ -112,8 +113,80 @@ def test_plugin():
assert len(groupchat.messages) == 2


def test_agent_mentions():
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=2,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice sepaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
max_consecutive_auto_reply=2,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
)
agent3 = autogen.ConversableAgent(
"sam",
max_consecutive_auto_reply=2,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is sam speaking.",
)
groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=2)

# Basic counting
assert json.dumps(groupchat._mentioned_agents("", [agent1, agent2, agent3]), sort_keys=True) == "{}"
assert json.dumps(groupchat._mentioned_agents("alice", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 1}'
assert (
json.dumps(groupchat._mentioned_agents("alice bob alice", [agent1, agent2, agent3]), sort_keys=True)
== '{"alice": 2, "bob": 1}'
)
assert (
json.dumps(groupchat._mentioned_agents("alice bob alice sam", [agent1, agent2, agent3]), sort_keys=True)
== '{"alice": 2, "bob": 1, "sam": 1}'
)
assert (
json.dumps(groupchat._mentioned_agents("alice bob alice sam robert", [agent1, agent2, agent3]), sort_keys=True)
== '{"alice": 2, "bob": 1, "sam": 1}'
)

# Substring
assert (
json.dumps(groupchat._mentioned_agents("sam samantha basam asami", [agent1, agent2, agent3]), sort_keys=True)
== '{"sam": 1}'
)

# Word boundaries
assert (
json.dumps(groupchat._mentioned_agents("alice! .alice. .alice", [agent1, agent2, agent3]), sort_keys=True)
== '{"alice": 3}'
)

# Special characters in agent names
agent4 = autogen.ConversableAgent(
".*",
max_consecutive_auto_reply=2,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="Match everything.",
)

groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3, agent4], messages=[], max_round=2)
assert (
json.dumps(
groupchat._mentioned_agents("alice bob alice sam robert .*", [agent1, agent2, agent3, agent4]),
sort_keys=True,
)
== '{".*": 1, "alice": 2, "bob": 1, "sam": 1}'
)


if __name__ == "__main__":
test_func_call_groupchat()
# test_broadcast()
test_chat_manager()
# test_plugin()
# test_agent_mentions()

0 comments on commit f939dda

Please sign in to comment.