diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py index 9faa4e2cb808..e2448929e618 100644 --- a/autogen/oai/anthropic.py +++ b/autogen/oai/anthropic.py @@ -49,10 +49,10 @@ "claude-3-5-sonnet-20240620": (0.003, 0.015), "claude-3-sonnet-20240229": (0.003, 0.015), "claude-3-opus-20240229": (0.015, 0.075), - "claude-2.0": (0.008, 0.024), + "claude-3-haiku-20240307": (0.00025, 0.00125), "claude-2.1": (0.008, 0.024), - "claude-3.0-opus": (0.015, 0.075), - "claude-3.0-haiku": (0.00025, 0.00125), + "claude-2.0": (0.008, 0.024), + "claude-instant-1.2": (0.008, 0.024), } @@ -250,6 +250,7 @@ def oai_messages_to_anthropic_messages(params: Dict[str, Any]) -> list[dict[str, tool_use_messages = 0 tool_result_messages = 0 last_tool_use_index = -1 + last_tool_result_index = -1 for message in params["messages"]: if message["role"] == "system": params["system"] = message["content"] @@ -290,25 +291,26 @@ def oai_messages_to_anthropic_messages(params: Dict[str, Any]) -> list[dict[str, } ) elif "tool_call_id" in message: - - if expected_role == "assistant": - # Insert an extra assistant message as we will append a user message - processed_messages.append(assistant_continue_message) - if has_tools: # Map the tool usage call to tool_result for Anthropic - processed_messages.append( - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": message["tool_call_id"], - "content": message["content"], - } - ], - } - ) + tool_result = { + "type": "tool_result", + "tool_use_id": message["tool_call_id"], + "content": message["content"], + } + + # If the previous message also had a tool_result, add it to that + # Otherwise append a new message + if last_tool_result_index == len(processed_messages) - 1: + processed_messages[-1]["content"].append(tool_result) + else: + if expected_role == "assistant": + # Insert an extra assistant message as we will append a user message + processed_messages.append(assistant_continue_message) + + processed_messages.append({"role": "user", "content": [tool_result]}) + last_tool_result_index = len(processed_messages) - 1 + tool_result_messages += 1 else: # Not using tools, so put in a plain text message