Skip to content

Commit

Permalink
Use Optuna and KFold to find best strategy for agent (#561)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Nov 29, 2024
1 parent 58f2b8e commit 26dad12
Show file tree
Hide file tree
Showing 15 changed files with 660 additions and 243 deletions.
535 changes: 340 additions & 195 deletions examples/monitor/match_bets_with_langfuse_traces.py

Large diffs are not rendered by default.

175 changes: 148 additions & 27 deletions poetry.lock

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@


class BaseSubgraphHandler(metaclass=SingletonMeta):
def __init__(self) -> None:
self.sg = Subgrounds()
def __init__(self, timeout: int = 30) -> None:
self.sg = Subgrounds(timeout=timeout)
# Patch methods to retry on failure.
self.sg.query_json = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
Expand Down
15 changes: 14 additions & 1 deletion prediction_market_agent_tooling/markets/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def from_trade(trade: Trade, id: str) -> "PlacedTrade":
)


class SimulationDetail(BaseModel):
class SimulatedBetDetail(BaseModel):
strategy: str
url: str
market_p_yes: float
Expand All @@ -161,3 +161,16 @@ class SharpeOutput(BaseModel):
annualized_volatility: float
mean_daily_return: float
annualized_sharpe_ratio: float


class SimulatedLifetimeDetail(BaseModel):
p_yes_mse: float
total_bet_amount: float
total_bet_profit: float
total_simulated_amount: float
total_simulated_profit: float
roi: float
simulated_roi: float
sharpe_output_original: SharpeOutput
sharpe_output_simulation: SharpeOutput
maximize: float
10 changes: 7 additions & 3 deletions prediction_market_agent_tooling/markets/omen/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,17 @@ class OmenBet(BaseModel):
creator: OmenBetCreator
creationTimestamp: int
collateralAmount: Wei
collateralAmountUSD: USD
feeAmount: Wei
outcomeIndex: int
outcomeTokensTraded: Wei
transactionHash: HexBytes
fpmm: OmenMarket

@property
def collateral_amount_usd(self) -> USD:
# Convert manually instad of using the field `collateralAmountUSD` available on the graph, because it's bugged, it's 0 for non-xDai markets.
return USD(wei_to_xdai(self.collateralAmount))

@property
def creation_datetime(self) -> DatetimeUTC:
return DatetimeUTC.to_datetime_utc(self.creationTimestamp)
Expand Down Expand Up @@ -544,7 +548,7 @@ def to_bet(self) -> Bet:
return Bet(
id=str(self.transactionHash),
# Use the transaction hash instead of the bet id - both are valid, but we return the transaction hash from the trade functions, so be consistent here.
amount=BetAmount(amount=self.collateralAmountUSD, currency=Currency.xDai),
amount=BetAmount(amount=self.collateral_amount_usd, currency=Currency.xDai),
outcome=self.boolean_outcome,
created_time=self.creation_datetime,
market_question=self.title,
Expand All @@ -560,7 +564,7 @@ def to_generic_resolved_bet(self) -> ResolvedBet:
return ResolvedBet(
id=self.transactionHash.hex(),
# Use the transaction hash instead of the bet id - both are valid, but we return the transaction hash from the trade functions, so be consistent here.
amount=BetAmount(amount=self.collateralAmountUSD, currency=Currency.xDai),
amount=BetAmount(amount=self.collateral_amount_usd, currency=Currency.xDai),
outcome=self.boolean_outcome,
created_time=self.creation_datetime,
market_question=self.title,
Expand Down
4 changes: 4 additions & 0 deletions prediction_market_agent_tooling/markets/omen/omen.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,13 +480,17 @@ def get_resolved_bets_made_since(
better_address: ChecksumAddress,
start_time: DatetimeUTC,
end_time: DatetimeUTC | None,
market_resolved_before: DatetimeUTC | None = None,
market_resolved_after: DatetimeUTC | None = None,
) -> list[ResolvedBet]:
subgraph_handler = OmenSubgraphHandler()
bets = subgraph_handler.get_resolved_bets_with_valid_answer(
better_address=better_address,
start_time=start_time,
end_time=end_time,
market_id=None,
market_resolved_before=market_resolved_before,
market_resolved_after=market_resolved_after,
)
generic_bets = [b.to_generic_resolved_bet() for b in bets]
return generic_bets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
WrappedxDaiContract,
sDaiContract,
)
from prediction_market_agent_tooling.tools.caches.inmemory_cache import (
persistent_inmemory_cache,
)
from prediction_market_agent_tooling.tools.utils import (
DatetimeUTC,
to_int_timestamp,
Expand Down Expand Up @@ -112,7 +115,6 @@ def _get_fields_for_bets(self, bets_field: FieldPath) -> list[FieldPath]:
bets_field.creator.id,
bets_field.creationTimestamp,
bets_field.collateralAmount,
bets_field.collateralAmountUSD,
bets_field.feeAmount,
bets_field.outcomeIndex,
bets_field.outcomeTokensTraded,
Expand Down Expand Up @@ -524,6 +526,8 @@ def get_trades(
filter_by_answer_finalized_not_null: bool = False,
type_: t.Literal["Buy", "Sell"] | None = None,
market_opening_after: DatetimeUTC | None = None,
market_resolved_before: DatetimeUTC | None = None,
market_resolved_after: DatetimeUTC | None = None,
collateral_amount_more_than: Wei | None = None,
sort_by_field: FieldPath | None = None,
sort_direction: str | None = None,
Expand All @@ -549,6 +553,15 @@ def get_trades(
where_stms.append(
trade.fpmm.openingTimestamp > to_int_timestamp(market_opening_after)
)
if market_resolved_after is not None:
where_stms.append(
trade.fpmm.resolutionTimestamp > to_int_timestamp(market_resolved_after)
)
if market_resolved_before is not None:
where_stms.append(
trade.fpmm.resolutionTimestamp
< to_int_timestamp(market_resolved_before)
)
if collateral_amount_more_than is not None:
where_stms.append(trade.collateralAmount > collateral_amount_more_than)

Expand Down Expand Up @@ -577,6 +590,8 @@ def get_bets(
market_id: t.Optional[ChecksumAddress] = None,
filter_by_answer_finalized_not_null: bool = False,
market_opening_after: DatetimeUTC | None = None,
market_resolved_before: DatetimeUTC | None = None,
market_resolved_after: DatetimeUTC | None = None,
collateral_amount_more_than: Wei | None = None,
) -> list[OmenBet]:
return self.get_trades(
Expand All @@ -587,37 +602,47 @@ def get_bets(
filter_by_answer_finalized_not_null=filter_by_answer_finalized_not_null,
type_="Buy", # We consider `bet` to be only the `Buy` trade types.
market_opening_after=market_opening_after,
market_resolved_before=market_resolved_before,
market_resolved_after=market_resolved_after,
collateral_amount_more_than=collateral_amount_more_than,
)

def get_resolved_bets(
self,
better_address: ChecksumAddress,
start_time: DatetimeUTC,
start_time: DatetimeUTC | None = None,
end_time: t.Optional[DatetimeUTC] = None,
market_id: t.Optional[ChecksumAddress] = None,
market_resolved_before: DatetimeUTC | None = None,
market_resolved_after: DatetimeUTC | None = None,
) -> list[OmenBet]:
omen_bets = self.get_bets(
better_address=better_address,
start_time=start_time,
end_time=end_time,
market_id=market_id,
filter_by_answer_finalized_not_null=True,
market_resolved_before=market_resolved_before,
market_resolved_after=market_resolved_after,
)
return [b for b in omen_bets if b.fpmm.is_resolved]

def get_resolved_bets_with_valid_answer(
self,
better_address: ChecksumAddress,
start_time: DatetimeUTC,
start_time: DatetimeUTC | None = None,
end_time: t.Optional[DatetimeUTC] = None,
market_resolved_before: DatetimeUTC | None = None,
market_resolved_after: DatetimeUTC | None = None,
market_id: t.Optional[ChecksumAddress] = None,
) -> list[OmenBet]:
bets = self.get_resolved_bets(
better_address=better_address,
start_time=start_time,
end_time=end_time,
market_id=market_id,
market_resolved_before=market_resolved_before,
market_resolved_after=market_resolved_after,
)
return [b for b in bets if b.fpmm.is_resolved_with_valid_answer]

Expand Down Expand Up @@ -926,3 +951,13 @@ def get_agent_results_for_bet(self, bet: OmenBet) -> ContractPrediction | None:
raise RuntimeError("Multiple results found for a single bet.")

return results[0]


@persistent_inmemory_cache
def get_omen_market_by_market_id_cached(
market_id: HexAddress,
block_number: int, # Force `block_number` to be provided, because `latest` block constantly updates.
) -> OmenMarket:
return OmenSubgraphHandler().get_omen_market_by_market_id(
market_id, block_number=block_number
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from prediction_market_agent_tooling.markets.data_models import (
SharpeOutput,
SimulationDetail,
SimulatedBetDetail,
)


class SharpeRatioCalculator:
def __init__(
self, details: list[SimulationDetail], risk_free_rate: float = 0.0
self, details: list[SimulatedBetDetail], risk_free_rate: float = 0.0
) -> None:
self.details = details
self.df = pd.DataFrame([d.model_dump() for d in self.details])
Expand All @@ -19,7 +19,9 @@ def __has_df_valid_columns_else_exception(
self, required_columns: list[str]
) -> None:
if not set(required_columns).issubset(self.df.columns):
raise ValueError("Dataframe doesn't contain all the required columns.")
raise ValueError(
f"Dataframe doesn't contain all the required columns. {required_columns=} {self.df.columns=}"
)

def prepare_wallet_daily_balance_df(
self, timestamp_col_name: str, profit_col_name: str
Expand Down
48 changes: 44 additions & 4 deletions prediction_market_agent_tooling/tools/caches/inmemory_cache.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,57 @@
from functools import cache
from typing import Any, Callable, TypeVar, cast
from typing import Any, Callable, TypeVar, cast, overload

from joblib import Memory

from prediction_market_agent_tooling.config import APIKeys

MEMORY = Memory(APIKeys().CACHE_DIR, verbose=0)


T = TypeVar("T", bound=Callable[..., Any])


def persistent_inmemory_cache(func: T) -> T:
@overload
def persistent_inmemory_cache(
func: None = None,
*,
in_memory_cache: bool = True,
) -> Callable[[T], T]:
...


@overload
def persistent_inmemory_cache(
func: T,
*,
in_memory_cache: bool = True,
) -> T:
...


def persistent_inmemory_cache(
func: T | None = None,
*,
in_memory_cache: bool = True,
) -> T | Callable[[T], T]:
"""
Wraps a function with both file cache (for persistent cache) and in-memory cache (for speed).
Wraps a function with both file cache (for persistent cache) and optional in-memory cache (for speed).
Can be used as @persistent_inmemory_cache or @persistent_inmemory_cache(in_memory_cache=False)
"""
return cast(T, cache(MEMORY.cache(func)) if APIKeys().ENABLE_CACHE else func)
if func is None:
# Ugly Pythonic way to support this decorator as `@persistent_inmemory_cache` but also `@persistent_inmemory_cache(in_memory_cache=False)`
def decorator(func: T) -> T:
return persistent_inmemory_cache(
func,
in_memory_cache=in_memory_cache,
)

return decorator
else:
# The decorator is called without arguments.
if not APIKeys().ENABLE_CACHE:
return func
cached_func = MEMORY.cache(func)
if in_memory_cache:
cached_func = cache(cached_func)
return cast(T, cached_func)
5 changes: 4 additions & 1 deletion prediction_market_agent_tooling/tools/httpx_cached_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

class HttpxCachedClient:
def __init__(self) -> None:
storage = hishel.FileStorage(ttl=3600, check_ttl_every=600)
storage = hishel.FileStorage(
ttl=24 * 60 * 60,
check_ttl_every=1 * 60 * 60,
)
controller = hishel.Controller(force_cache=True)
self.client = hishel.CacheClient(storage=storage, controller=controller)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def get_traces_for_agent(
from_timestamp: DatetimeUTC,
has_output: bool,
client: Langfuse,
to_timestamp: DatetimeUTC | None = None,
) -> list[TraceWithDetails]:
"""
Fetch agent traces using pagination
Expand All @@ -76,6 +77,7 @@ def get_traces_for_agent(
limit=100,
page=page,
from_timestamp=from_timestamp,
to_timestamp=to_timestamp,
)
if not traces.data:
break
Expand Down
4 changes: 2 additions & 2 deletions prediction_market_agent_tooling/tools/transaction_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

class TransactionBlockCache:
def __init__(self, web3: Web3):
self.block_number_cache = dc.Cache("block_cache_dir")
self.block_timestamp_cache = dc.Cache("timestamp_cache_dir")
self.block_number_cache = dc.Cache(".cache/block_cache_dir")
self.block_timestamp_cache = dc.Cache(".cache/timestamp_cache_dir")
self.web3 = web3

@tenacity.retry(
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "prediction-market-agent-tooling"
version = "0.57.1"
version = "0.57.2"
description = "Tools to benchmark, deploy and monitor prediction market agents."
authors = ["Gnosis"]
readme = "README.md"
Expand Down Expand Up @@ -53,11 +53,13 @@ types-python-dateutil = "^2.9.0.20240906"
pinatapy-vourhey = "^0.2.0"
hishel = "^0.0.31"
pytest-postgresql = "^6.1.1"
optuna = { version = "^4.1.0", optional = true}

[tool.poetry.extras]
openai = ["openai"]
langchain = ["langchain", "langchain-openai"]
google = ["google-api-python-client"]
optuna = ["optuna"]

[tool.poetry.group.dev.dependencies]
pytest = "*"
Expand Down
2 changes: 1 addition & 1 deletion scripts/sell_all_omen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def sell_all(
better_address=better_address,
market_opening_after=utcnow() + timedelta(days=closing_later_than_days),
)
bets_total_usd = sum(b.collateralAmountUSD for b in bets)
bets_total_usd = sum(b.collateral_amount_usd for b in bets)
unique_market_urls = set(b.fpmm.url for b in bets)
starting_balance = get_balances(better_address)
new_balance = starting_balance # initialisation
Expand Down
Loading

0 comments on commit 26dad12

Please sign in to comment.