Skip to content

Commit

Permalink
Refactored blockchain_message and blockchain_message_handler (#597)
Browse files Browse the repository at this point in the history
* Refactored blockchain_message and blockchain_message_handler

* Returning formatted message on ReceiveMessage

* Fixed tests

* Fixed test (2)

* Update prediction_market_agent/db/models.py

Co-authored-by: Peter Jung <[email protected]>

* Implemented PR comments

* Fixing tests

* typo

* Renamed test

* Trying to fix test again

* Tests now passing

* Fixed additional tests

---------

Co-authored-by: Peter Jung <[email protected]>
  • Loading branch information
gabrielfior and kongzii authored Dec 12, 2024
1 parent 7cdf762 commit 2ad0275
Show file tree
Hide file tree
Showing 17 changed files with 245 additions and 186 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from prediction_market_agent.agents.microchain_agent.microchain_agent_keys import (
MicrochainAgentKeys,
)
from prediction_market_agent.agents.microchain_agent.utils import compress_message
from prediction_market_agent.db.blockchain_transaction_fetcher import (
BlockchainTransactionFetcher,
)
from prediction_market_agent.db.models import BlockchainMessage
from prediction_market_agent.tools.message_utils import compress_message


class BroadcastPublicMessageToHumans(Function):
Expand Down Expand Up @@ -77,7 +76,7 @@ def description(self) -> str:
def example_args(self) -> list[str]:
return []

def __call__(self) -> BlockchainMessage | None:
def __call__(self) -> str:
keys = MicrochainAgentKeys()
fetcher = BlockchainTransactionFetcher()
message_to_process = (
Expand All @@ -101,7 +100,7 @@ def __call__(self) -> BlockchainMessage | None:
logger.info(
f"Funded the treasury with xDai, tx_hash: {HexBytes(tx_receipt['transactionHash']).hex()}"
)
return message_to_process
return str(message_to_process) if message_to_process else "No new messages"


MESSAGES_FUNCTIONS: list[type[Function]] = [
Expand Down
10 changes: 0 additions & 10 deletions prediction_market_agent/agents/microchain_agent/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import typing as t
import zlib

import pandas as pd
from microchain import Agent
Expand Down Expand Up @@ -141,12 +140,3 @@ def get_function_useage_from_history(
data={"Usage Count": list(function_useage.values())},
index=function_names,
)


def compress_message(message: str) -> bytes:
"""Used to reduce size of the message before sending it to reduce gas costs."""
return zlib.compress(message.encode(), level=zlib.Z_BEST_COMPRESSION)


def decompress_message(message: bytes) -> str:
return zlib.decompress(message).decode()
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from prediction_market_agent.agents.microchain_agent.microchain_agent_keys import (
MicrochainAgentKeys,
)
from prediction_market_agent.agents.microchain_agent.utils import decompress_message
from prediction_market_agent.db.blockchain_message_table_handler import (
BlockchainMessageTableHandler,
)
from prediction_market_agent.db.models import BlockchainMessage
from prediction_market_agent.tools.message_utils import decompress_message
from prediction_market_agent.utils import APIKeys


Expand Down
12 changes: 9 additions & 3 deletions prediction_market_agent/db/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from typing import Any, Optional

from prediction_market_agent_tooling.gtypes import wei_type
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.tools.utils import DatetimeUTC
from sqlalchemy import BigInteger, Column
from prediction_market_agent_tooling.tools.web3_utils import wei_to_xdai
from sqlalchemy import Column, Numeric
from sqlmodel import Field, SQLModel


Expand Down Expand Up @@ -74,6 +76,10 @@ class BlockchainMessage(SQLModel, table=True):
consumer_address: str
sender_address: str
transaction_hash: str = Field(unique=True)
block: int = Field(sa_column=Column(BigInteger, nullable=False))
value_wei: int = Field(sa_column=Column(BigInteger, nullable=False))
block: int = Field(sa_column=Column(Numeric, nullable=False))
value_wei: int = Field(sa_column=Column(Numeric, nullable=False))
data_field: Optional[str]

def __str__(self) -> str:
return f"""Sender: {self.sender_address} \n Value: {wei_to_xdai(wei_type(self.value_wei))} \n Message: {self.data_field}
"""
20 changes: 20 additions & 0 deletions prediction_market_agent/tools/message_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import zlib

from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes


def compress_message(message: str) -> bytes:
"""Used to reduce size of the message before sending it to reduce gas costs."""
return zlib.compress(message.encode(), level=zlib.Z_BEST_COMPRESSION)


def decompress_message(message: bytes) -> str:
return zlib.decompress(message).decode()


def unzip_message_else_do_nothing(data_field: str) -> str:
"""We try decompressing the message, else we return the original data field."""
try:
return decompress_message(HexBytes(data_field))
except Exception:
return data_field
7 changes: 0 additions & 7 deletions tests/agents/microchain/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from prediction_market_agent_tooling.markets.omen.omen_contracts import (
WrappedxDaiContract,
)
from pydantic import SecretStr
from web3 import Web3

from prediction_market_agent.agents.microchain_agent.blockchain.code_interpreter import (
Expand All @@ -15,7 +14,6 @@
from prediction_market_agent.agents.microchain_agent.blockchain.contract_class_converter import (
ContractClassConverter,
)
from prediction_market_agent.utils import DBKeys


def mock_summaries(function_names: list[str]) -> Summaries:
Expand Down Expand Up @@ -85,8 +83,3 @@ def wxdai_contract_mocked_rag(
yield ContractClassConverter(
contract_address=contract_address, contract_name=wxdai.__class__.__name__
)


@pytest.fixture(scope="session")
def session_keys_with_mocked_db() -> Generator[DBKeys, None, None]:
yield DBKeys(SQLALCHEMY_DB_URL=SecretStr("sqlite://"))
19 changes: 5 additions & 14 deletions tests/agents/microchain/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Generator

import numpy as np
import pytest
from microchain import Engine
Expand Down Expand Up @@ -35,15 +33,6 @@
from tests.utils import RUN_PAID_TESTS


@pytest.fixture(scope="session")
def long_term_memory() -> Generator[LongTermMemoryTableHandler, None, None]:
"""Creates a in-memory SQLite DB for testing"""
long_term_memory = LongTermMemoryTableHandler(
task_description="test", sqlalchemy_db_url="sqlite://"
)
yield long_term_memory


# TODO investigate why this fails for polymarket https://github.com/gnosis/prediction-market-agent/issues/62
@pytest.mark.parametrize("market_type", [MarketType.OMEN, MarketType.MANIFOLD])
def test_get_markets(market_type: MarketType) -> None:
Expand Down Expand Up @@ -174,8 +163,10 @@ def test_predict_probability(market_type: MarketType) -> None:


@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.")
def test_remember_past_learnings(long_term_memory: LongTermMemoryTableHandler) -> None:
long_term_memory.save_history(
def test_remember_past_learnings(
long_term_memory_table_handler: LongTermMemoryTableHandler,
) -> None:
long_term_memory_table_handler.save_history(
history=[
{"role": "user", "content": "I went to the park and saw a dog."},
{"role": "user", "content": "I went to the park and saw a cat."},
Expand All @@ -185,7 +176,7 @@ def test_remember_past_learnings(long_term_memory: LongTermMemoryTableHandler) -
## Uncomment below to test with the memories accrued from use of https://autonomous-trader-agent.streamlit.app/
# long_term_memory = LongTermMemoryTableHandler(task_description="microchain-streamlit-app")
remember_past_learnings = RememberPastActions(
long_term_memory=long_term_memory,
long_term_memory=long_term_memory_table_handler,
model=DEFAULT_OPENAI_MODEL,
)
print(remember_past_learnings())
Expand Down
67 changes: 46 additions & 21 deletions tests/agents/microchain/test_messages_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,51 @@
import polars as pl
import pytest
from eth_typing import ChecksumAddress
from prediction_market_agent_tooling.gtypes import xdai_type
from prediction_market_agent_tooling.tools.web3_utils import xdai_to_wei
from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes
from pydantic import SecretStr
from web3 import Web3

from prediction_market_agent.agents.microchain_agent.messages_functions import (
ReceiveMessage,
)
from prediction_market_agent.agents.microchain_agent.utils import compress_message
from prediction_market_agent.db.blockchain_message_table_handler import (
BlockchainMessageTableHandler,
)
from prediction_market_agent.db.blockchain_transaction_fetcher import (
BlockchainTransactionFetcher,
)
from prediction_market_agent.utils import DBKeys
from prediction_market_agent.tools.message_utils import compress_message


@pytest.fixture(scope="module")
def agent2_address() -> ChecksumAddress:
return Web3.to_checksum_address("0xb4D8C8BedE2E49b08d2A22485f72fA516116FE7F")
@pytest.fixture(scope="session")
def account2_address() -> ChecksumAddress:
# anvil account # 2
return Web3.to_checksum_address("0x70997970C51812dc3A010C7d01b50e0d17dc79C8")


@pytest.fixture(scope="session")
def account2_private_key() -> SecretStr:
"Anvil test account private key. It's public already."
return SecretStr(
"0x59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d"
)


# Random transactions found on Gnosisscan.
MOCK_HASH_1 = "0x5ba6dd51d3660f98f02683e032daa35644d3f7f975975da3c2628a5b4b1f5cb6"
MOCK_HASH_2 = "0x429f61ea3e1afdd104fdd0a6f3b88432ec4c7b298fd126378e53a63bc60fed6a"
MOCK_SENDER_SPICE_QUERY = Web3.to_checksum_address(
"0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266"
) # anvil account 1


def mock_spice_query(query: str, api_key: str) -> pl.DataFrame:
anvil_account_1 = Web3.to_checksum_address(
"0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266"
)
return pl.DataFrame(
{
"hash": [MOCK_HASH_1, MOCK_HASH_2],
"value": [xdai_to_wei(xdai_type(1)), xdai_to_wei(xdai_type(2))],
"value": [Web3.to_wei(1, "ether"), Web3.to_wei(2, "ether")],
"block_number": [1, 2],
"from": [anvil_account_1, anvil_account_1],
"from": [MOCK_SENDER_SPICE_QUERY, MOCK_SENDER_SPICE_QUERY],
"data": ["test", Web3.to_hex(compress_message("test"))],
}
)
Expand All @@ -64,29 +74,41 @@ def patch_spice() -> Generator[PropertyMock, None, None]:
yield mock_spice


@pytest.fixture(scope="module")
def patch_send_xdai() -> Generator[PropertyMock, None, None]:
# Note that we patch where the function is called (see https://docs.python.org/3/library/unittest.mock.html#where-to-patch).
with patch(
"prediction_market_agent.agents.microchain_agent.messages_functions.send_xdai_to",
return_value={"transactionHash": HexBytes(MOCK_HASH_1)},
) as mock_send_xdai:
yield mock_send_xdai


@pytest.fixture
def patch_public_key(
agent2_address: ChecksumAddress,
account2_address: ChecksumAddress, account2_private_key: SecretStr
) -> Generator[PropertyMock, None, None]:
with patch(
"prediction_market_agent.agents.microchain_agent.microchain_agent_keys.MicrochainAgentKeys.public_key",
new_callable=PropertyMock,
) as mock_public_key:
mock_public_key.return_value = agent2_address
) as mock_public_key, patch(
"prediction_market_agent.agents.microchain_agent.microchain_agent_keys.MicrochainAgentKeys.bet_from_private_key",
new_callable=PropertyMock,
) as mock_private_key:
mock_public_key.return_value = account2_address
mock_private_key.return_value = account2_private_key
yield mock_public_key


@pytest.fixture
@pytest.fixture(scope="function")
def patch_pytest_db(
session_keys_with_mocked_db: DBKeys,
memory_blockchain_handler: BlockchainMessageTableHandler,
) -> Generator[PropertyMock, None, None]:
with patch(
"prediction_market_agent_tooling.config.APIKeys.sqlalchemy_db_url",
new_callable=PropertyMock,
) as mock_sqlalchemy_db_url:
mock_sqlalchemy_db_url.return_value = (
session_keys_with_mocked_db.SQLALCHEMY_DB_URL
)
mock_sqlalchemy_db_url.return_value = SecretStr("sqlite://")
yield mock_sqlalchemy_db_url


Expand All @@ -95,6 +117,7 @@ def test_receive_message_description(
patch_public_key: PropertyMock,
patch_spice: PropertyMock,
patch_dune_api_key: PropertyMock,
patch_send_xdai: PropertyMock,
) -> None:
r = ReceiveMessage()
description = r.description
Expand All @@ -107,6 +130,7 @@ def test_receive_message_description(


def test_receive_message_call(
patch_send_xdai: PropertyMock,
patch_pytest_db: PropertyMock,
patch_public_key: PropertyMock,
patch_spice: PropertyMock,
Expand All @@ -116,14 +140,15 @@ def test_receive_message_call(

blockchain_message = r()
assert blockchain_message is not None
assert blockchain_message.transaction_hash == MOCK_HASH_1
assert MOCK_SENDER_SPICE_QUERY in blockchain_message


def test_receive_message_then_check_count_unseen_messages(
patch_pytest_db: PropertyMock,
patch_public_key: PropertyMock,
patch_spice: typing.Any,
patch_dune_api_key: PropertyMock,
patch_send_xdai: PropertyMock,
) -> None:
# Idea here is to fetch the next message, and then fetch the count of unseen messages, asserting that
# this number decreased by 1.
Expand Down
10 changes: 0 additions & 10 deletions tests/agents/microchain/test_utils.py

This file was deleted.

Loading

0 comments on commit 2ad0275

Please sign in to comment.