From cdaa11e0e08f51745cb3c9a65ec4f9e4ff03f317 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Fri, 13 Dec 2024 13:38:23 +0100 Subject: [PATCH 1/2] Make nft web app nicer --- .../app_nft_treasury_game.py | 86 +++++++++++-------- .../db/long_term_memory_table_handler.py | 2 +- prediction_market_agent/db/sql_handler.py | 2 +- 3 files changed, 52 insertions(+), 38 deletions(-) diff --git a/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py b/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py index 21296f0b..5bf03757 100644 --- a/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py +++ b/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py @@ -11,7 +11,8 @@ import streamlit as st from microchain.functions import Reasoning from prediction_market_agent_tooling.tools.balances import get_balances -from prediction_market_agent_tooling.tools.datetime_utc import DatetimeUTC +from prediction_market_agent_tooling.tools.utils import check_not_none +from streamlit_extras.stylable_container import stylable_container from prediction_market_agent.agents.identifiers import AgentIdentifier from prediction_market_agent.agents.microchain_agent.messages_functions import ( @@ -27,6 +28,7 @@ DeployableAgentNFTGameAbstract, ) from prediction_market_agent.db.long_term_memory_table_handler import ( + LongTermMemories, LongTermMemoryTableHandler, ) from prediction_market_agent.db.prompt_table_handler import PromptTableHandler @@ -63,8 +65,8 @@ def send_message_part(nft_agent: type[DeployableAgentNFTGameAbstract]) -> None: def parse_function_and_body( - role: t.Literal["user", "assistant", "system"], message: str -) -> t.Tuple[str | None, str | None]: + role: t.Literal["user", "assistant"], message: str +) -> t.Tuple[str, str]: message = message.strip() if role == "assistant": @@ -75,10 +77,6 @@ def parse_function_and_body( # Responses from the individual functions are stored under `user` role. parsed_function = DummyFunctionName.RESPONSE_FUNCTION_NAME parsed_body = message - elif role == "system": - # System message isn't shown in the chat history, so ignore. - parsed_function = None - parsed_body = None else: raise ValueError(f"Unknown role: {role}") @@ -86,26 +84,23 @@ def parse_function_and_body( def customized_chat_message( - role: t.Literal["user", "assistant", "system"], - message: str, - created_at: DatetimeUTC, + function_call: LongTermMemories, + function_output: LongTermMemories, ) -> None: - parsed_function, parsed_body = parse_function_and_body(role, message) - if parsed_function is None: - return - # If the message is output from one of these functions, skip it, because it's not interesting to read `The reasoning has been recorded` and similar over and over again. - if parsed_body in ( - Reasoning()(""), - BroadcastPublicMessageToHumans.OUTPUT_TEXT, - SendPaidMessageToAnotherAgent.OUTPUT_TEXT, - ): - return + created_at = function_output.datetime_ + + parsed_function_call_name, parsed_function_call_body = parse_function_and_body( + check_not_none(function_call.metadata_dict)["role"], + check_not_none(function_call.metadata_dict)["content"], + ) + parsed_function_output_name, parsed_function_output_body = parse_function_and_body( + check_not_none(function_output.metadata_dict)["role"], + check_not_none(function_output.metadata_dict)["content"], + ) - match parsed_function: + match parsed_function_call_name: case Reasoning.__name__: icon = "🧠" - case DummyFunctionName.RESPONSE_FUNCTION_NAME: - icon = "✔️" case ReceiveMessage.__name__: icon = "👤" case BroadcastPublicMessageToHumans.__name__: @@ -116,11 +111,24 @@ def customized_chat_message( icon = "🤖" with st.chat_message(icon): - if parsed_function: - st.markdown(f"**{parsed_function}**") - st.write(created_at.strftime("%Y-%m-%d %H:%M:%S")) - if message: - st.markdown(parsed_body) + if parsed_function_call_name == Reasoning.__name__: + # Don't show reasoning as function call, to make it a bit nicer. + st.markdown( + parsed_function_call_body.replace("reasoning='", "").replace("')", "") + ) + else: + # Otherwise, show it as a normal function-response call, e.g. `ReceiveMessages() -> ...`. + st.markdown( + f"**{parsed_function_call_name}**({parsed_function_call_body}) *{created_at.strftime('%Y-%m-%d %H:%M:%S')}*" + ) + + # Only show the output if it's supposed to be interesting. + if parsed_function_call_name not in ( + Reasoning.__name__, + BroadcastPublicMessageToHumans.__name__, + SendPaidMessageToAnotherAgent.__name__, + ): + st.markdown(parsed_function_output_body) @st.fragment(run_every=timedelta(seconds=5)) @@ -134,14 +142,20 @@ def show_function_calls_part(nft_agent: type[DeployableAgentNFTGameAbstract]) -> st.markdown("No actions yet.") return - for item in calls: - if item.metadata_dict is None: - continue - customized_chat_message( - item.metadata_dict["role"], - item.metadata_dict["content"], - item.datetime_, - ) + # Filter out system calls, because they aren't supposed to be shown in the chat history itself. + calls = [ + call for call in calls if check_not_none(call.metadata_dict)["role"] != "system" + ] + + # Microchain works on `function call` - `funtion response` pairs, so we will process them together. + for index, (function_output, function_call) in enumerate( + zip(calls[::2], calls[1::2]) + ): + with stylable_container( + key=f"function_call_{index}", + css_styles=f"{{background-color: {'#f0f0f0' if (index % 2 == 0) else 'white'}; border-radius: 5px;}}", + ): + customized_chat_message(function_call, function_output) @st.fragment(run_every=timedelta(seconds=5)) diff --git a/prediction_market_agent/db/long_term_memory_table_handler.py b/prediction_market_agent/db/long_term_memory_table_handler.py index d9a7c3a7..463f2054 100644 --- a/prediction_market_agent/db/long_term_memory_table_handler.py +++ b/prediction_market_agent/db/long_term_memory_table_handler.py @@ -49,7 +49,7 @@ def search( from_: DatetimeUTC | None = None, to_: DatetimeUTC | None = None, limit: int | None = None, - ) -> t.Sequence[LongTermMemories]: + ) -> list[LongTermMemories]: """Searches the LongTermMemoryTableHandler for entries within a specified datetime range that match self.task_description.""" query_filters = [ diff --git a/prediction_market_agent/db/sql_handler.py b/prediction_market_agent/db/sql_handler.py index 35aee573..48094d14 100644 --- a/prediction_market_agent/db/sql_handler.py +++ b/prediction_market_agent/db/sql_handler.py @@ -35,7 +35,7 @@ def get_with_filter_and_order( order_by_column_name: str | None = None, order_desc: bool = True, limit: int | None = None, - ) -> t.Sequence[SQLModelType]: + ) -> list[SQLModelType]: with self.db_manager.get_session() as session: query = session.query(self.table) for exp in query_filters: From f58818d3355482cd00a7abd9d90180fca3f88644 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Fri, 13 Dec 2024 14:37:47 +0100 Subject: [PATCH 2/2] little updates --- .../nft_treasury_game/app_nft_treasury_game.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py b/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py index 5bf03757..001d5d9c 100644 --- a/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py +++ b/prediction_market_agent/agents/microchain_agent/nft_treasury_game/app_nft_treasury_game.py @@ -9,7 +9,7 @@ from enum import Enum import streamlit as st -from microchain.functions import Reasoning +from microchain.functions import Reasoning, Stop from prediction_market_agent_tooling.tools.balances import get_balances from prediction_market_agent_tooling.tools.utils import check_not_none from streamlit_extras.stylable_container import stylable_container @@ -101,6 +101,8 @@ def customized_chat_message( match parsed_function_call_name: case Reasoning.__name__: icon = "🧠" + case Stop.__name__: + icon = "😴" case ReceiveMessage.__name__: icon = "👤" case BroadcastPublicMessageToHumans.__name__: @@ -116,6 +118,9 @@ def customized_chat_message( st.markdown( parsed_function_call_body.replace("reasoning='", "").replace("')", "") ) + elif parsed_function_call_name == Stop.__name__: + # If the agent decided to stop, show it as a break, as it will be started soon again. + st.markdown("Taking a break.") else: # Otherwise, show it as a normal function-response call, e.g. `ReceiveMessages() -> ...`. st.markdown( @@ -125,6 +130,7 @@ def customized_chat_message( # Only show the output if it's supposed to be interesting. if parsed_function_call_name not in ( Reasoning.__name__, + Stop.__name__, BroadcastPublicMessageToHumans.__name__, SendPaidMessageToAnotherAgent.__name__, ):