Skip to content

Commit

Permalink
Make nft web app nicer (#602)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Dec 13, 2024
1 parent 1b45fbd commit b670cfe
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from enum import Enum

import streamlit as st
from microchain.functions import Reasoning
from microchain.functions import Reasoning, Stop
from prediction_market_agent_tooling.tools.balances import get_balances
from prediction_market_agent_tooling.tools.datetime_utc import DatetimeUTC
from prediction_market_agent_tooling.tools.utils import check_not_none
from streamlit_extras.stylable_container import stylable_container

from prediction_market_agent.agents.identifiers import AgentIdentifier
from prediction_market_agent.agents.microchain_agent.messages_functions import (
Expand All @@ -27,6 +28,7 @@
DeployableAgentNFTGameAbstract,
)
from prediction_market_agent.db.long_term_memory_table_handler import (
LongTermMemories,
LongTermMemoryTableHandler,
)
from prediction_market_agent.db.prompt_table_handler import PromptTableHandler
Expand Down Expand Up @@ -63,8 +65,8 @@ def send_message_part(nft_agent: type[DeployableAgentNFTGameAbstract]) -> None:


def parse_function_and_body(
role: t.Literal["user", "assistant", "system"], message: str
) -> t.Tuple[str | None, str | None]:
role: t.Literal["user", "assistant"], message: str
) -> t.Tuple[str, str]:
message = message.strip()

if role == "assistant":
Expand All @@ -75,37 +77,32 @@ def parse_function_and_body(
# Responses from the individual functions are stored under `user` role.
parsed_function = DummyFunctionName.RESPONSE_FUNCTION_NAME
parsed_body = message
elif role == "system":
# System message isn't shown in the chat history, so ignore.
parsed_function = None
parsed_body = None
else:
raise ValueError(f"Unknown role: {role}")

return parsed_function, parsed_body


def customized_chat_message(
role: t.Literal["user", "assistant", "system"],
message: str,
created_at: DatetimeUTC,
function_call: LongTermMemories,
function_output: LongTermMemories,
) -> None:
parsed_function, parsed_body = parse_function_and_body(role, message)
if parsed_function is None:
return
# If the message is output from one of these functions, skip it, because it's not interesting to read `The reasoning has been recorded` and similar over and over again.
if parsed_body in (
Reasoning()(""),
BroadcastPublicMessageToHumans.OUTPUT_TEXT,
SendPaidMessageToAnotherAgent.OUTPUT_TEXT,
):
return
created_at = function_output.datetime_

parsed_function_call_name, parsed_function_call_body = parse_function_and_body(
check_not_none(function_call.metadata_dict)["role"],
check_not_none(function_call.metadata_dict)["content"],
)
parsed_function_output_name, parsed_function_output_body = parse_function_and_body(
check_not_none(function_output.metadata_dict)["role"],
check_not_none(function_output.metadata_dict)["content"],
)

match parsed_function:
match parsed_function_call_name:
case Reasoning.__name__:
icon = "🧠"
case DummyFunctionName.RESPONSE_FUNCTION_NAME:
icon = "✔️"
case Stop.__name__:
icon = "😴"
case ReceiveMessage.__name__:
icon = "👤"
case BroadcastPublicMessageToHumans.__name__:
Expand All @@ -116,11 +113,28 @@ def customized_chat_message(
icon = "🤖"

with st.chat_message(icon):
if parsed_function:
st.markdown(f"**{parsed_function}**")
st.write(created_at.strftime("%Y-%m-%d %H:%M:%S"))
if message:
st.markdown(parsed_body)
if parsed_function_call_name == Reasoning.__name__:
# Don't show reasoning as function call, to make it a bit nicer.
st.markdown(
parsed_function_call_body.replace("reasoning='", "").replace("')", "")
)
elif parsed_function_call_name == Stop.__name__:
# If the agent decided to stop, show it as a break, as it will be started soon again.
st.markdown("Taking a break.")
else:
# Otherwise, show it as a normal function-response call, e.g. `ReceiveMessages() -> ...`.
st.markdown(
f"**{parsed_function_call_name}**({parsed_function_call_body}) *{created_at.strftime('%Y-%m-%d %H:%M:%S')}*"
)

# Only show the output if it's supposed to be interesting.
if parsed_function_call_name not in (
Reasoning.__name__,
Stop.__name__,
BroadcastPublicMessageToHumans.__name__,
SendPaidMessageToAnotherAgent.__name__,
):
st.markdown(parsed_function_output_body)


@st.fragment(run_every=timedelta(seconds=5))
Expand All @@ -134,14 +148,20 @@ def show_function_calls_part(nft_agent: type[DeployableAgentNFTGameAbstract]) ->
st.markdown("No actions yet.")
return

for item in calls:
if item.metadata_dict is None:
continue
customized_chat_message(
item.metadata_dict["role"],
item.metadata_dict["content"],
item.datetime_,
)
# Filter out system calls, because they aren't supposed to be shown in the chat history itself.
calls = [
call for call in calls if check_not_none(call.metadata_dict)["role"] != "system"
]

# Microchain works on `function call` - `funtion response` pairs, so we will process them together.
for index, (function_output, function_call) in enumerate(
zip(calls[::2], calls[1::2])
):
with stylable_container(
key=f"function_call_{index}",
css_styles=f"{{background-color: {'#f0f0f0' if (index % 2 == 0) else 'white'}; border-radius: 5px;}}",
):
customized_chat_message(function_call, function_output)


@st.fragment(run_every=timedelta(seconds=5))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def search(
from_: DatetimeUTC | None = None,
to_: DatetimeUTC | None = None,
limit: int | None = None,
) -> t.Sequence[LongTermMemories]:
) -> list[LongTermMemories]:
"""Searches the LongTermMemoryTableHandler for entries within a specified datetime range that match
self.task_description."""
query_filters = [
Expand Down
2 changes: 1 addition & 1 deletion prediction_market_agent/db/sql_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_with_filter_and_order(
order_by_column_name: str | None = None,
order_desc: bool = True,
limit: int | None = None,
) -> t.Sequence[SQLModelType]:
) -> list[SQLModelType]:
with self.db_manager.get_session() as session:
query = session.query(self.table)
for exp in query_filters:
Expand Down

0 comments on commit b670cfe

Please sign in to comment.