Skip to content

Commit

Permalink
Allow initiate_chat without passing message (microsoft#1244)
Browse files Browse the repository at this point in the history
* allow initiate_chat without passing message

* test human input

* assert called

* Add missing method a_generate_init_message

* fix tests

* add back skipif

* Update test/agentchat/test_async_get_human_input.py

---------

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
bitnom and sonichi authored Jan 19, 2024
1 parent 9729610 commit e97b639
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 12 deletions.
21 changes: 20 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ def initiate_chat(
silent (bool or None): (Experimental) whether to print the messages for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
Raises:
RuntimeError: if any async reply functions are registered and not ignored in sync chat.
Expand Down Expand Up @@ -707,9 +708,10 @@ async def a_initiate_chat(
silent (bool or None): (Experimental) whether to print the messages for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"""
self._prepare_chat(recipient, clear_history)
await self.a_send(self.generate_init_message(**context), recipient, silent=silent)
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)

def reset(self):
"""Reset the agent."""
Expand Down Expand Up @@ -1583,7 +1585,24 @@ def generate_init_message(self, **context) -> Union[str, Dict]:
Args:
**context: any context information, and "message" parameter needs to be provided.
If message is not given, prompt for it via input()
"""
if "message" not in context:
context["message"] = self.get_human_input(">")
return context["message"]

async def a_generate_init_message(self, **context) -> Union[str, Dict]:
"""Generate the initial message for the agent.
Override this function to customize the initial message based on user's request.
If not overridden, "message" needs to be provided in the context.
Args:
**context: any context information, and "message" parameter needs to be provided.
If message is not given, prompt for it via input()
"""
if "message" not in context:
context["message"] = await self.a_get_human_input(">")
return context["message"]

def register_function(self, function_map: Dict[str, Callable]):
Expand Down
21 changes: 10 additions & 11 deletions test/agentchat/test_async_get_human_input.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import os
import sys
from unittest.mock import AsyncMock

import autogen
import pytest
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai # noqa: E402
Expand All @@ -25,20 +27,17 @@ async def test_async_get_human_input():
assistant = autogen.AssistantAgent(
name="assistant",
max_consecutive_auto_reply=2,
llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
llm_config={"seed": 41, "config_list": config_list, "temperature": 0},
)

user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)

async def custom_a_get_human_input(prompt):
return "This is a test"

user_proxy.a_get_human_input = custom_a_get_human_input
user_proxy.a_get_human_input = AsyncMock(return_value="This is a test")

user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)

await user_proxy.a_initiate_chat(assistant, clear_history=True, message="Hello.")


if __name__ == "__main__":
test_async_get_human_input()
# Test without message
await user_proxy.a_initiate_chat(assistant, clear_history=True)
# Assert that custom a_get_human_input was called at least once
user_proxy.a_get_human_input.assert_called()
46 changes: 46 additions & 0 deletions test/agentchat/test_human_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import autogen
import pytest
from unittest.mock import MagicMock
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai # noqa: E402

try:
from openai import OpenAI
except ImportError:
skip = True
else:
skip = False or skip_openai


@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
def test_get_human_input():
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, KEY_LOC)

# create an AssistantAgent instance named "assistant"
assistant = autogen.AssistantAgent(
name="assistant",
max_consecutive_auto_reply=2,
llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
)

user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)

# Use MagicMock to create a mock get_human_input function
user_proxy.get_human_input = MagicMock(return_value="This is a test")

user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)

user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
# Test without supplying messages parameter
user_proxy.initiate_chat(assistant, clear_history=True)

# Assert that custom_a_get_human_input was called at least once
user_proxy.get_human_input.assert_called()


if __name__ == "__main__":
test_get_human_input()

0 comments on commit e97b639

Please sign in to comment.