Showing 8 changed files with 526 additions and 127 deletions.
@@ -0,0 +1,168 @@
from llama_index.core.agent.react import ReActChatFormatter, ReActOutputParser
from llama_index.core.agent.react.types import (
    ActionReasoningStep,
    ObservationReasoningStep,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import ToolSelection, ToolOutput
from llama_index.core.workflow import Event
from typing import Any

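# Events that route control between steps. In a llama_index Workflow, each
# @step declares the event type(s) it accepts and returns, and the runtime
# dispatches events to the matching step. PrepEvent loops the agent back to
# prompt formatting; InputEvent carries the formatted chat; ToolCallEvent
# carries parsed tool selections. FunctionOutputEvent is declared for
# completeness but is not consumed by any step below.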
class PrepEvent(Event):
    pass


class InputEvent(Event):
    input: list[ChatMessage]


class ToolCallEvent(Event):
    tool_calls: list[ToolSelection]


class FunctionOutputEvent(Event):
    output: ToolOutput

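# The agent itself: a ReAct loop expressed as a workflow. One turn flows
# StartEvent -> new_user_msg -> prepare_chat_history -> handle_llm_input,
# which either stops with a final answer, dispatches tools via
# handle_tool_calls, or loops back via PrepEvent until the LLM is done.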
class ReActAgent(Workflow):
    def __init__(
        self,
        *args: Any,
        llm: LLM | None = None,
        tools: list[BaseTool] | None = None,
        extra_context: str | None = None,
        **kwargs: Any,
    ) -> None:
        # Load the ReAct system prompt from disk so it can be tweaked
        # without code changes.
        with open("react_system_header.txt", "r", encoding="utf-8") as f:
            react_system_header_str = f.read()
        super().__init__(*args, **kwargs)
        self.tools = tools or []
        self.llm = llm or OpenAI()
        # Use the resolved LLM (not the raw argument, which may be None)
        # so the memory buffer sizes itself against the model actually used.
        self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
        self.formatter = ReActChatFormatter(
            context=extra_context or "",
            system_header=react_system_header_str,
        )
        self.output_parser = ReActOutputParser()
        self.sources = []
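    # Each @step below is a coroutine registered with the workflow; ctx is
    # per-run state, read and written with `await ctx.get(...)` /
    # `await ctx.set(...)`.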
    @step
    async def new_user_msg(self, ctx: Context, ev: StartEvent) -> PrepEvent:
        # clear sources
        self.sources = []

        # get user input
        user_input = ev.input
        user_msg = ChatMessage(role="user", content=user_input)
        self.memory.put(user_msg)

        # clear current reasoning
        await ctx.set("current_reasoning", [])

        return PrepEvent()
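    # Format the LLM input: the ReActChatFormatter merges the system header,
    # tool descriptions, chat history, and the reasoning trace accumulated so
    # far into the message list the LLM sees.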
    @step
    async def prepare_chat_history(self, ctx: Context, ev: PrepEvent) -> InputEvent:
        # get chat history
        chat_history = self.memory.get()
        current_reasoning = await ctx.get("current_reasoning", default=[])
        llm_input = self.formatter.format(
            self.tools, chat_history, current_reasoning=current_reasoning
        )
        return InputEvent(input=llm_input)
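    # Parse the LLM response into a reasoning step and decide what to do
    # next. The ReActOutputParser raises on malformed output, so parse
    # failures are recorded as observations and the loop retries.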
    @step
    async def handle_llm_input(
        self, ctx: Context, ev: InputEvent
    ) -> ToolCallEvent | PrepEvent | StopEvent:
        chat_history = ev.input

        response = await self.llm.achat(chat_history)
        try:
            reasoning_step = self.output_parser.parse(response.message.content)
            print(reasoning_step)  # debug output

            # Read-modify-write the reasoning list: ctx.get with a default
            # does not store the default, so write the list back explicitly.
            current_reasoning = await ctx.get("current_reasoning", default=[])
            current_reasoning.append(reasoning_step)
            await ctx.set("current_reasoning", current_reasoning)

            if reasoning_step.is_done:
                self.memory.put(
                    ChatMessage(role="assistant", content=reasoning_step.response)
                )
                return StopEvent(
                    result={
                        "response": reasoning_step.response,
                        "sources": [*self.sources],
                        "reasoning": current_reasoning,
                    }
                )
            elif isinstance(reasoning_step, ActionReasoningStep):
                tool_name = reasoning_step.action
                tool_args = reasoning_step.action_input
                return ToolCallEvent(
                    tool_calls=[
                        ToolSelection(
                            tool_id="fake",
                            tool_name=tool_name,
                            tool_kwargs=tool_args,
                        )
                    ]
                )
        except Exception as e:
            current_reasoning = await ctx.get("current_reasoning", default=[])
            current_reasoning.append(
                ObservationReasoningStep(
                    observation=f"There was an error in parsing my reasoning: {e}"
                )
            )
            await ctx.set("current_reasoning", current_reasoning)

        # if no tool calls or final response, iterate again
        return PrepEvent()
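    # Execute the requested tools, converting unknown tools and tool
    # exceptions into observations so the agent can react to the failure
    # instead of the workflow crashing.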
    @step
    async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> PrepEvent:
        tool_calls = ev.tool_calls
        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
        current_reasoning = await ctx.get("current_reasoning", default=[])

        # call tools -- safely!
        for tool_call in tool_calls:
            tool = tools_by_name.get(tool_call.tool_name)
            if not tool:
                current_reasoning.append(
                    ObservationReasoningStep(
                        observation=f"Tool {tool_call.tool_name} does not exist"
                    )
                )
                continue

            try:
                tool_output = tool(**tool_call.tool_kwargs)
                print(tool_output)  # debug output
                self.sources.append(tool_output)
                current_reasoning.append(
                    ObservationReasoningStep(observation=tool_output.content)
                )
            except Exception as e:
                current_reasoning.append(
                    ObservationReasoningStep(
                        observation=f"Error calling tool {tool.metadata.get_name()}: {e}"
                    )
                )

        await ctx.set("current_reasoning", current_reasoning)

        # prep the next iteration
        return PrepEvent()
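# Usage sketch: one minimal way to drive the agent above. Assumptions, not
# shown in this diff: OPENAI_API_KEY is set in the environment,
# react_system_header.txt is in the working directory, and `multiply` is a
# hypothetical example tool.
if __name__ == "__main__":
    import asyncio

    from llama_index.core.tools import FunctionTool

    def multiply(a: float, b: float) -> float:
        """Multiply two numbers and return the product."""
        return a * b

    async def main() -> None:
        agent = ReActAgent(
            llm=OpenAI(model="gpt-4o-mini"),
            tools=[FunctionTool.from_defaults(fn=multiply)],
            timeout=120,
        )
        # StartEvent attributes come from the kwargs passed to run();
        # new_user_msg reads them as ev.input.
        result = await agent.run(input="What is 3.2 * 4.5?")
        print(result["response"])

    asyncio.run(main())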