Skip to content

Commit

Permalink
Ability to generate handoff message from AssistantAgent (microsoft#3968)
Browse files Browse the repository at this point in the history
* Ability to generate handoff message from AssistantAgent

* Fix mypy

* Validation

---------

Co-authored-by: Victor Dibia <[email protected]>
  • Loading branch information
ekzhu and victordibia authored Oct 29, 2024
1 parent 14846a3 commit eb4b1f8
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._assistant_agent import AssistantAgent
from ._assistant_agent import AssistantAgent, Handoff
from ._base_chat_agent import BaseChatAgent
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent
Expand All @@ -7,6 +7,7 @@
__all__ = [
"BaseChatAgent",
"AssistantAgent",
"Handoff",
"CodeExecutorAgent",
"CodingAssistantAgent",
"ToolUseAssistantAgent",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, List, Sequence
from typing import Any, Awaitable, Callable, Dict, List, Sequence

from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
Expand All @@ -15,11 +15,12 @@
UserMessage,
)
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, model_validator

from .. import EVENT_LOGGER_NAME
from ..messages import (
ChatMessage,
HandoffMessage,
StopMessage,
TextMessage,
)
Expand All @@ -31,6 +32,9 @@
class ToolCallEvent(BaseModel):
"""A tool call event."""

source: str
"""The source of the event."""

tool_calls: List[FunctionCall]
"""The tool call message."""

Expand All @@ -40,12 +44,58 @@ class ToolCallEvent(BaseModel):
class ToolCallResultEvent(BaseModel):
    """A tool call result event."""

    # In this file the event is built right after the tool calls of one model
    # response have been executed and is passed to ``event_logger.debug``.

    # AssistantAgent passes its own ``name`` here when it builds the event.
    source: str
    """The source of the event."""

    # One entry per executed tool call; AssistantAgent fills this with the
    # gathered return values of ``_execute_tool_call``.
    tool_call_results: List[FunctionExecutionResult]
    """The tool call result message."""

    # NOTE(review): presumably needed because FunctionExecutionResult is not a
    # pydantic-native type -- confirm against autogen_core before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class Handoff(BaseModel):
    """Handoff configuration for :class:`AssistantAgent`."""

    target: str
    """The name of the target agent to hand off to."""

    description: str = Field(default=None)
    """The description of the handoff, such as the condition under which it should happen and the target agent's ability.
    If not provided, it is generated from the target agent's name."""

    name: str = Field(default=None)
    """The name of this handoff configuration. If not provided, it is generated from the target agent's name."""

    message: str = Field(default=None)
    """The message to the target agent.
    If not provided, it is generated from the target agent's name."""

    @model_validator(mode="before")
    @classmethod
    def set_defaults(cls, values: Any) -> Any:
        """Validate ``name`` and derive omitted fields from ``target``.

        Args:
            values: The raw, not yet validated input handed to the model.

        Returns:
            A copy of the input dict with ``description``, ``name`` and
            ``message`` filled in when they were missing or ``None``. Input
            that is not a dict, or that has no ``target``, is returned without
            generated defaults so that pydantic's own field validation reports
            the problem.

        Raises:
            ValueError: If an explicitly supplied ``name`` is not a string or
                is not a valid Python identifier.
        """
        # A "before" validator sees the raw input, which is not guaranteed to
        # be a dict; anything else is left for pydantic to handle.
        if not isinstance(values, dict):
            return values

        # Work on a shallow copy so the caller's dict is never mutated.
        values = dict(values)

        name = values.get("name")
        if name is not None:
            if not isinstance(name, str):
                raise ValueError(f"Handoff name must be a string: {name}")
            # Check if name is a valid identifier.
            if not name.isidentifier():
                raise ValueError(f"Handoff name must be a valid identifier: {name}")

        target = values.get("target")
        if target is None:
            # Without a target nothing can be derived. Returning here lets
            # pydantic raise its regular validation error for the required
            # field instead of a bare KeyError escaping from this validator.
            return values

        if values.get("description") is None:
            values["description"] = f"Handoff to {target}."
        if name is None:
            values["name"] = f"transfer_to_{target}".lower()
        if values.get("message") is None:
            values["message"] = f"Transferred to {target}, adopting the role of {target} immediately."
        return values

    @property
    def handoff_tool(self) -> Tool:
        """Create a handoff tool from this handoff configuration.

        The tool takes no arguments and returns :attr:`message`; it is
        registered under :attr:`name` with :attr:`description`.
        """

        # A named function with an explicit return annotation is used instead
        # of a lambda so the wrapped callable exposes a typed signature.
        def _handoff_tool() -> str:
            return self.message

        return FunctionTool(_handoff_tool, name=self.name, description=self.description)


class AssistantAgent(BaseChatAgent):
"""An agent that provides assistance with tool use.
Expand All @@ -55,8 +105,52 @@ class AssistantAgent(BaseChatAgent):
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
handoffs (List[Handoff | str] | None, optional): The handoff configurations for the agent, allowing it to transfer to other agents by responding with a HandoffMessage.
If a handoff is a string, it should represent the target agent's name.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model.
Raises:
ValueError: If tool names are not unique.
ValueError: If handoff names are not unique.
ValueError: If handoff names are not unique from tool names.
Examples:
The following example demonstrates how to create an assistant agent with
a model client and generate a response to a simple task.
.. code-block:: python
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client)
await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
The following example demonstrates how to create an assistant agent with
a model client and a tool, and generate a response to a simple task using the tool.
.. code-block:: python
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
async def get_current_time() -> str:
return "The current time is 12:00 PM."
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3))
"""

def __init__(
Expand All @@ -65,6 +159,7 @@ def __init__(
model_client: ChatCompletionClient,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[Handoff | str] | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
):
Expand All @@ -84,33 +179,71 @@ def __init__(
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
tool_names = [tool.name for tool in self._tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, Handoff] = {}
if handoffs is not None:
for handoff in handoffs:
if isinstance(handoff, str):
handoff = Handoff(target=handoff)
if isinstance(handoff, Handoff):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
# Check that handoff tool names do not collide with regular tool names.
if any(name in tool_names for name in handoff_tool_names):
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
# Add messages to the model context.
for msg in messages:
# TODO: add special handling for handoff messages
self._model_context.append(UserMessage(content=msg.content, source=msg.source))

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content))
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
self._model_context.append(FunctionExecutionResultMessage(content=results))

# Detect handoff requests.
handoffs: List[Handoff] = []
for call in result.content:
if call.name in self._handoffs:
handoffs.append(self._handoffs[call.name])
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Respond with a handoff message.
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)

# Generate an inference result based on the current model context.
result = await self._model_client.create(
self._model_context, tools=self._tools, cancellation_token=cancellation_token
self._model_context, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

Expand All @@ -127,9 +260,9 @@ async def _execute_tool_call(
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
try:
if not self._tools:
if not self._tools + self._handoff_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools if t.name == tool_call.name), None)
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ class StopMessage(BaseMessage):
class HandoffMessage(BaseMessage):
"""A message requesting handoff of a conversation to another agent."""

target: str
"""The name of the target agent to handoff to."""

content: str
"""The agent name to handoff the conversation to."""
"""The handoff message to the target agent."""


ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
"""Select a speaker from the participants based on handoff message."""
if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage):
self._current_speaker = thread[-1].agent_message.content
self._current_speaker = thread[-1].agent_message.target
if self._current_speaker not in self._participant_topic_types:
raise ValueError("The selected speaker in the handoff message is not a participant.")
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id))
Expand All @@ -47,7 +47,40 @@ async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:


class Swarm(BaseGroupChat):
"""(Experimental) A group chat that selects the next speaker based on handoff message only."""
"""A group chat team that selects the next speaker based on handoff message only.
The first participant in the list of participants is the initial speaker.
The next speaker is selected based on the :class:`~autogen_agentchat.messages.HandoffMessage` message
sent by the current speaker. If no handoff message is sent, the current speaker
continues to be the speaker.
Args:
participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker.
Examples:
.. code-block:: python
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import Swarm
from autogen_agentchat.task import MaxMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent(
"Alice",
model_client=model_client,
handoffs=["Bob"],
system_message="You are Alice and you only answer questions about yourself.",
)
agent2 = AssistantAgent(
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
)
team = Swarm([agent1, agent2])
await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
"""

def __init__(self, participants: List[ChatAgent]):
super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager)
Expand Down
60 changes: 58 additions & 2 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio
import json
import logging
from typing import Any, AsyncGenerator, List

import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import StopMessage, TextMessage
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_core.base import CancellationToken
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
Expand All @@ -14,6 +18,10 @@
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.completion_usage import CompletionUsage

logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)
logger.addHandler(FileLogHandler("test_assistant_agent.log"))


class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
Expand Down Expand Up @@ -107,3 +115,51 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], TextMessage)
assert isinstance(result.messages[2], StopMessage)


@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
    """An agent whose model calls a handoff tool must answer with a HandoffMessage.

    The OpenAI completions endpoint is replaced by a mock that returns a single
    scripted completion containing one tool call named after the handoff.
    """
    model = "gpt-4o-2024-05-13"
    handoff = Handoff(target="agent2")

    # Build the scripted completion from the inside out: tool call -> assistant
    # message -> completion.
    handoff_call = ChatCompletionMessageToolCall(
        id="1",
        type="function",
        function=Function(name=handoff.name, arguments=json.dumps({})),
    )
    assistant_message = ChatCompletionMessage(
        content=None,
        tool_calls=[handoff_call],
        role="assistant",
    )
    scripted_completion = ChatCompletion(
        id="id1",
        choices=[Choice(finish_reason="tool_calls", index=0, message=assistant_message)],
        created=0,
        model=model,
        object="chat.completion",
        usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
    )

    # Route every completion request of the client to the mock.
    mock = _MockChatCompletion([scripted_completion])
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
        handoffs=[handoff],
    )
    incoming = [TextMessage(content="task", source="user")]
    response = await tool_use_agent.on_messages(incoming, cancellation_token=CancellationToken())

    assert isinstance(response, HandoffMessage)
    assert response.target == "agent2"
Loading

0 comments on commit eb4b1f8

Please sign in to comment.