Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to retrieve actions given context for general agent #616

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
from datetime import timedelta

from langchain.vectorstores.chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from microchain import Function
from prediction_market_agent_tooling.tools.utils import utcnow
from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow

from prediction_market_agent.agents.microchain_agent.memory import DatedChatMessage
from prediction_market_agent.agents.microchain_agent.microchain_agent_keys import (
MicrochainAgentKeys,
)
from prediction_market_agent.agents.utils import memories_to_learnings
from prediction_market_agent.db.long_term_memory_table_handler import (
LongTermMemories,
LongTermMemoryTableHandler,
)


class LookAtPastActions(Function):
class LongTermMemoryBasedFunction(Function):
def __init__(
self, long_term_memory: LongTermMemoryTableHandler, model: str
) -> None:
self.long_term_memory = long_term_memory
self.model = model
super().__init__()


class LookAtPastActionsFromLastDay(LongTermMemoryBasedFunction):
@property
def description(self) -> str:
return (
Expand All @@ -42,3 +50,49 @@ def __call__(self) -> str:
DatedChatMessage.from_long_term_memory(ltm) for ltm in memories
]
return memories_to_learnings(memories=simple_memories, model=self.model)


class CheckAllPastActionsGivenContext(LongTermMemoryBasedFunction):
@property
def description(self) -> str:
return (
"Use this function to fetch information about the actions you executed with respect to a specific context. "
"For example, you can use this function to look into all your past actions if you ever did form a coalition with another agent."
)

@property
def example_args(self) -> list[str]:
return ["What coalitions did I form?"]

def __call__(self, context: str) -> str:
keys = MicrochainAgentKeys()
all_memories = self.long_term_memory.search()

collection = Chroma(
embedding_function=OpenAIEmbeddings(
api_key=keys.openai_api_key_secretstr_v1
)
)
collection.add_texts(
texts=[
f"From: {check_not_none(x.metadata_dict)['role']} Content: {check_not_none(x.metadata_dict)['content']}"
for x in all_memories
],
metadatas=[{"json": x.model_dump_json()} for x in all_memories],
)

top_k_per_query_results = collection.similarity_search(context, k=50)
results = [
DatedChatMessage.from_long_term_memory(
LongTermMemories.model_validate_json(x.metadata["json"])
)
for x in top_k_per_query_results
]

return memories_to_learnings(memories=results, model=self.model)

Comment on lines +55 to +93
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add error handling for potential missing metadata keys

In the __call__ method of CheckAllPastActionsGivenContext, accessing x.metadata_dict['role'], x.metadata_dict['content'], and x.metadata['json'] may raise exceptions if these keys are missing or if the metadata is not as expected. To prevent runtime errors, consider adding checks or exception handling to ensure these keys exist before accessing them.


MEMORY_FUNCTIONS: list[type[LongTermMemoryBasedFunction]] = [
LookAtPastActionsFromLastDay,
CheckAllPastActionsGivenContext,
]
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
MARKET_FUNCTIONS,
)
from prediction_market_agent.agents.microchain_agent.memory_functions import (
LookAtPastActions,
MEMORY_FUNCTIONS,
)
from prediction_market_agent.agents.microchain_agent.nft_functions import NFT_FUNCTIONS
from prediction_market_agent.agents.microchain_agent.nft_treasury_game.messages_functions import (
Expand Down Expand Up @@ -171,8 +171,8 @@ def build_agent_functions(
functions.extend(f() for f in BALANCE_FUNCTIONS)

if long_term_memory:
functions.append(
LookAtPastActions(long_term_memory=long_term_memory, model=model)
functions.extend(
f(long_term_memory=long_term_memory, model=model) for f in MEMORY_FUNCTIONS
)

return functions
Expand Down
2 changes: 1 addition & 1 deletion prediction_market_agent/agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

Each memory comes with a timestamp. If the memories are clustered into
different times, then make a separate list for each cluster. Refer to each
cluster as a 'Trading Session', and display the range of timestamps for each.
cluster as a 'Session', and display the range of timestamps for each.

MEMORIES:
{memories}
Expand Down
33 changes: 30 additions & 3 deletions tests/agents/microchain/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
SellYes,
)
from prediction_market_agent.agents.microchain_agent.memory_functions import (
LookAtPastActions,
CheckAllPastActionsGivenContext,
LookAtPastActionsFromLastDay,
)
from prediction_market_agent.agents.microchain_agent.utils import (
get_balance,
Expand Down Expand Up @@ -163,7 +164,7 @@ def test_predict_probability(market_type: MarketType) -> None:


@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.")
def test_remember_past_learnings(
def test_look_at_past_actions(
long_term_memory_table_handler: LongTermMemoryTableHandler,
) -> None:
long_term_memory_table_handler.save_history(
Expand All @@ -175,13 +176,39 @@ def test_remember_past_learnings(
)
## Uncomment below to test with the memories accrued from use of https://autonomous-trader-agent.streamlit.app/
# long_term_memory = LongTermMemoryTableHandler(task_description="microchain-streamlit-app")
past_actions = LookAtPastActions(
past_actions = LookAtPastActionsFromLastDay(
long_term_memory=long_term_memory_table_handler,
model=DEFAULT_OPENAI_MODEL,
)
print(past_actions())


@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.")
def test_check_past_actions_given_context(
long_term_memory_table_handler: LongTermMemoryTableHandler,
) -> None:
long_term_memory_table_handler.save_history(
history=[
{
"role": "user",
"content": "Agent X sent me a message asking for a coalition.",
},
{
"role": "user",
"content": "I agreed with agent X to form a coalition, I'll send him my NFT key if he sends me 5 xDai",
},
{"role": "user", "content": "I went to the park and saw a bird."},
]
)
## Uncomment below to test with the memories accrued from use of https://autonomous-trader-agent.streamlit.app/
# long_term_memory = LongTermMemoryTableHandler(task_description="microchain-streamlit-app")
past_actions = CheckAllPastActionsGivenContext(
long_term_memory=long_term_memory_table_handler,
model=DEFAULT_OPENAI_MODEL,
)
print(past_actions(context="What coalitions did I form?"))


@pytest.mark.parametrize("market_type", [MarketType.OMEN])
def test_kelly_bet(market_type: MarketType) -> None:
get_kelly_bet = GetKellyBet(market_type=market_type, keys=APIKeys())
Expand Down
Loading