DB cache in form of decorator (#534)
kongzii authored Nov 6, 2024
1 parent 73f0e21 commit c6ae204
Showing 15 changed files with 1,215 additions and 597 deletions.
855 changes: 494 additions & 361 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion prediction_market_agent_tooling/jobs/jobs_models.py
@@ -57,7 +57,9 @@ def get_job(id: str) -> "JobAgentMarket":
         """Get a single job by its id."""
 
     @abstractmethod
-    def submit_job_result(self, max_bond: float, result: str) -> ProcessedTradedMarket:
+    def submit_job_result(
+        self, agent_name: str, max_bond: float, result: str
+    ) -> ProcessedTradedMarket:
         """Submit the completed result for this job."""
 
     def to_simple_job(self, max_bond: float) -> SimpleJob:
6 changes: 4 additions & 2 deletions prediction_market_agent_tooling/jobs/omen/omen_jobs.py
@@ -71,7 +71,9 @@ def get_job(id: str) -> "OmenJobAgentMarket":
             OmenJobAgentMarket.get_binary_market(id=id)
         )
 
-    def submit_job_result(self, max_bond: float, result: str) -> ProcessedTradedMarket:
+    def submit_job_result(
+        self, agent_name: str, max_bond: float, result: str
+    ) -> ProcessedTradedMarket:
         if not APIKeys().enable_ipfs_upload:
             raise RuntimeError(
                 f"ENABLE_IPFS_UPLOAD must be set to True to upload job results."
@@ -86,7 +88,7 @@ def submit_job_result(self, max_bond: float, result: str) -> ProcessedTradedMarket:
         )
 
         keys = APIKeys()
-        self.store_trades(processed_traded_market, keys)
+        self.store_trades(processed_traded_market, keys, agent_name)
 
         return processed_traded_market
 
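The change above threads the agent's name through `submit_job_result` and into trade storage. A minimal sketch of the updated call, assuming ENABLE_IPFS_UPLOAD is configured; the market id, agent name, and result text are illustrative placeholders:

from prediction_market_agent_tooling.jobs.omen.omen_jobs import OmenJobAgentMarket

job = OmenJobAgentMarket.get_job(id="0x123...")  # placeholder market id
processed = job.submit_job_result(
    agent_name="my-job-agent",  # placeholder name, forwarded to store_trades
    max_bond=1.0,
    result="Job finished, deliverable uploaded.",
)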
351 changes: 351 additions & 0 deletions prediction_market_agent_tooling/tools/caches/db_cache.py
@@ -0,0 +1,351 @@
import hashlib
import inspect
import json
from datetime import date, timedelta
from functools import wraps
from typing import (
    Any,
    Callable,
    Sequence,
    TypeVar,
    cast,
    get_args,
    get_origin,
    overload,
)

from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field, Session, SQLModel, create_engine, desc, select

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.tools.datetime_utc import DatetimeUTC
from prediction_market_agent_tooling.tools.utils import utcnow

FunctionT = TypeVar("FunctionT", bound=Callable[..., Any])


class FunctionCache(SQLModel, table=True):
    __tablename__ = "function_cache"
    id: int | None = Field(default=None, primary_key=True)
    function_name: str = Field(index=True)
    full_function_name: str = Field(index=True)
    # Args are stored to see what the function was called with.
    args: Any = Field(sa_column=Column(JSONB, nullable=False))
    # Args hash is stored as a fast look-up option when looking for cache hits.
    args_hash: str = Field(index=True)
    result: Any = Field(sa_column=Column(JSONB, nullable=False))
    created_at: DatetimeUTC = Field(default_factory=utcnow, index=True)


@overload
def db_cache(
    func: None = None,
    *,
    max_age: timedelta | None = None,
    cache_none: bool = True,
    api_keys: APIKeys | None = None,
    ignore_args: Sequence[str] | None = None,
    ignore_arg_types: Sequence[type] | None = None,
) -> Callable[[FunctionT], FunctionT]:
    ...


@overload
def db_cache(
    func: FunctionT,
    *,
    max_age: timedelta | None = None,
    cache_none: bool = True,
    api_keys: APIKeys | None = None,
    ignore_args: Sequence[str] | None = None,
    ignore_arg_types: Sequence[type] | None = None,
) -> FunctionT:
    ...


def db_cache(
    func: FunctionT | None = None,
    *,
    max_age: timedelta | None = None,
    cache_none: bool = True,
    api_keys: APIKeys | None = None,
    ignore_args: Sequence[str] | None = None,
    ignore_arg_types: Sequence[type] | None = None,
) -> FunctionT | Callable[[FunctionT], FunctionT]:
    if func is None:
        # Ugly Pythonic way to support this decorator both as `@db_cache` and as `@db_cache(max_age=timedelta(days=3))`.
        def decorator(func: FunctionT) -> FunctionT:
            return db_cache(
                func,
                max_age=max_age,
                cache_none=cache_none,
                api_keys=api_keys,
                ignore_args=ignore_args,
                ignore_arg_types=ignore_arg_types,
            )

        return decorator

    api_keys = api_keys if api_keys is not None else APIKeys()

    sqlalchemy_db_url = api_keys.SQLALCHEMY_DB_URL
    if sqlalchemy_db_url is None:
        logger.warning(
            f"SQLALCHEMY_DB_URL not provided in the environment, skipping function caching."
        )

    engine = (
        create_engine(
            sqlalchemy_db_url.get_secret_value(),
            # Use custom json serializer and deserializer, because otherwise, for example, `datetime` serialization would fail.
            json_serializer=json_serializer,
            json_deserializer=json_deserializer,
        )
        if sqlalchemy_db_url is not None
        else None
    )

    # Create the table if it doesn't exist.
    if engine is not None:
        SQLModel.metadata.create_all(engine)

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # If caching is disabled, just call the function and return its result.
        if not api_keys.ENABLE_CACHE:
            return func(*args, **kwargs)

        # Convert *args and **kwargs into a single dictionary, so positional arguments are keyed by name as well.
        signature = inspect.signature(func)
        bound_arguments = signature.bind(*args, **kwargs)
        bound_arguments.apply_defaults()

        # Any argument that is a Pydantic model will be converted into a plain dictionary when serialized, otherwise it wouldn't be JSON-serializable.
        args_dict: dict[str, Any] = bound_arguments.arguments

        # Remove `self` or `cls` if present (in the case of methods and classmethods).
        if "self" in args_dict:
            del args_dict["self"]
        if "cls" in args_dict:
            del args_dict["cls"]

        # Remove ignored arguments.
        if ignore_args:
            for arg in ignore_args:
                if arg in args_dict:
                    del args_dict[arg]

        # Remove arguments of ignored types.
        if ignore_arg_types:
            args_dict = {
                k: v
                for k, v in args_dict.items()
                if not isinstance(v, tuple(ignore_arg_types))
            }

        # Compute a hash of the function arguments, used to look up cached results.
        arg_string = json.dumps(args_dict, sort_keys=True, default=str)
        args_hash = hashlib.md5(arg_string.encode()).hexdigest()

        # Get the full function name as the concatenation of module and qualname, so that names don't accidentally clash.
        full_function_name = func.__module__ + "." + func.__qualname__
        # But also keep the plain function name, to easily search for it in the database.
        function_name = func.__name__

        # Determine whether the function returns or contains Pydantic BaseModel(s).
        return_type = func.__annotations__.get("return", None)
        is_pydantic_model = False

        if return_type is not None and contains_pydantic_model(return_type):
            is_pydantic_model = True

        # If Postgres access was configured, try to find a cache hit.
        if engine is not None:
            with Session(engine) as session:
                # Try to get a cached result.
                statement = (
                    select(FunctionCache)
                    .where(
                        FunctionCache.function_name == function_name,
                        FunctionCache.full_function_name == full_function_name,
                        FunctionCache.args_hash == args_hash,
                    )
                    .order_by(desc(FunctionCache.created_at))
                )
                if max_age is not None:
                    cutoff_time = utcnow() - max_age
                    statement = statement.where(FunctionCache.created_at >= cutoff_time)
                cached_result = session.exec(statement).first()
        else:
            cached_result = None

        if cached_result:
            logger.info(
                # Keep the special [cache-hit] identifier so we can easily track it in GCP.
                f"[cache-hit] Cache hit for {full_function_name} with args {args_dict} and output {cached_result.result}"
            )
            if is_pydantic_model:
                # If the output contains any Pydantic models, we need to initialise them.
                try:
                    return convert_cached_output_to_pydantic(
                        return_type, cached_result.result
                    )
                except ValueError as e:
                    # In case of a backward-incompatible Pydantic model, treat it as a cache miss instead of erroring out.
                    logger.warning(
                        f"Can not validate {cached_result=} into {return_type=} because {e=}, treating as cache miss."
                    )
                    cached_result = None
            else:
                return cached_result.result

        # On a cache miss, compute the result.
        computed_result = func(*args, **kwargs)
        # Keep the special [cache-miss] identifier so we can easily track it in GCP.
        logger.info(
            f"[cache-miss] Cache miss for {full_function_name} with args {args_dict}, computed the output {computed_result}"
        )

        # If Postgres access was configured, save the result.
        if engine is not None and (cache_none or computed_result is not None):
            cache_entry = FunctionCache(
                function_name=function_name,
                full_function_name=full_function_name,
                args_hash=args_hash,
                args=args_dict,
                result=computed_result,
                created_at=utcnow(),
            )
            with Session(engine) as session:
                logger.info(f"Saving {cache_entry} into database.")
                session.add(cache_entry)
                session.commit()

        return computed_result

    return cast(FunctionT, wrapper)


def contains_pydantic_model(return_type: Any) -> bool:
    """
    Check if the return type contains anything that's a Pydantic model (including nested structures, like `list[BaseModel]`, `dict[str, list[BaseModel]]`, etc.).
    """
    if return_type is None:
        return False
    origin = get_origin(return_type)
    if origin is not None:
        return any(contains_pydantic_model(arg) for arg in get_args(return_type))
    if inspect.isclass(return_type):
        return issubclass(return_type, BaseModel)
    return False
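
# Illustrative checks for `contains_pydantic_model` (MyModel is a hypothetical BaseModel subclass):
#   contains_pydantic_model(MyModel)                  -> True  (direct BaseModel subclass)
#   contains_pydantic_model(list[MyModel])            -> True  (model nested in a generic)
#   contains_pydantic_model(dict[str, list[MyModel]]) -> True  (arbitrarily nested)
#   contains_pydantic_model(int)                      -> False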


def json_serializer_default_fn(
    y: DatetimeUTC | timedelta | date | BaseModel,
) -> str | dict[str, Any]:
    """
    Used to serialize objects that don't support it by default into a specific string that can be deserialized back later.
    If this function returns a dictionary, it will be called recursively.
    If you add something here, also add it to `replace_custom_stringified_objects` below.
    """
    if isinstance(y, DatetimeUTC):
        return f"DatetimeUTC::{y.isoformat()}"
    elif isinstance(y, timedelta):
        return f"timedelta::{y.total_seconds()}"
    elif isinstance(y, date):
        return f"date::{y.isoformat()}"
    elif isinstance(y, BaseModel):
        return y.model_dump()
    raise TypeError(
        f"Unsupported type for the default json serialize function, value is {y}."
    )


def json_serializer(x: Any) -> str:
    return json.dumps(x, default=json_serializer_default_fn)


def replace_custom_stringified_objects(obj: Any) -> Any:
    """
    Used to deserialize objects from `json_serializer_default_fn` into their proper form.
    """
    if isinstance(obj, str):
        if obj.startswith("DatetimeUTC::"):
            iso_str = obj[len("DatetimeUTC::") :]
            return DatetimeUTC.to_datetime_utc(iso_str)
        elif obj.startswith("timedelta::"):
            total_seconds_str = obj[len("timedelta::") :]
            return timedelta(seconds=float(total_seconds_str))
        elif obj.startswith("date::"):
            iso_str = obj[len("date::") :]
            return date.fromisoformat(iso_str)
        else:
            return obj
    elif isinstance(obj, dict):
        return {k: replace_custom_stringified_objects(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [replace_custom_stringified_objects(item) for item in obj]
    else:
        return obj


def json_deserializer(s: str) -> Any:
    data = json.loads(s)
    return replace_custom_stringified_objects(data)
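
# Illustrative round trip through the custom (de)serializers above (values are hypothetical):
#   json_serializer({"delta": timedelta(hours=2)})
#       -> '{"delta": "timedelta::7200.0"}'
#   json_deserializer('{"delta": "timedelta::7200.0"}')
#       -> {"delta": timedelta(seconds=7200.0)}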


def convert_cached_output_to_pydantic(return_type: Any, data: Any) -> Any:
    """
    Used to initialize Pydantic models from anything cached that was originally a Pydantic model in the output, including models in nested structures.
    """
    # Get the origin and arguments of the model type.
    origin = get_origin(return_type)
    args = get_args(return_type)

    # Check if the data is a dictionary.
    if isinstance(data, dict):
        # If the type has no origin, check if it is a subclass of BaseModel.
        if origin is None:
            if inspect.isclass(return_type) and issubclass(return_type, BaseModel):
                # Convert the dictionary to a Pydantic model.
                return return_type(
                    **{
                        k: convert_cached_output_to_pydantic(
                            getattr(return_type, k, None), v
                        )
                        for k, v in data.items()
                    }
                )
            else:
                # If not a Pydantic model, return the data as is.
                return data
        # If the origin is a dictionary, convert keys and values.
        elif origin is dict:
            key_type, value_type = args
            return {
                convert_cached_output_to_pydantic(
                    key_type, k
                ): convert_cached_output_to_pydantic(value_type, v)
                for k, v in data.items()
            }
        else:
            # If the origin is not a dictionary, return the data as is.
            return data
    # Check if the data is a list.
    elif isinstance(data, (list, tuple)):
        # If the origin is a list or tuple, convert each item.
        if origin in {list, tuple}:
            item_type = args[0]
            converted_items = [
                convert_cached_output_to_pydantic(item_type, item) for item in data
            ]
            return type(data)(converted_items)
        else:
            # If the origin is not a list or tuple, return the data as is.
            return data
    else:
        # If the data is neither a dictionary nor a list, return it as is.
        return data
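
Taken together, the new module caches function results in Postgres, keyed by a hash of the (filtered) arguments. A minimal usage sketch, assuming SQLALCHEMY_DB_URL and ENABLE_CACHE are configured via APIKeys; the model, function, and values below are illustrative placeholders rather than code from this repository:

from datetime import timedelta

from pydantic import BaseModel

from prediction_market_agent_tooling.tools.caches.db_cache import db_cache


class WeatherReport(BaseModel):  # illustrative model
    city: str
    temperature_celsius: float


@db_cache(max_age=timedelta(days=1), cache_none=False, ignore_args=["api_key"])
def fetch_weather(city: str, api_key: str) -> WeatherReport:  # illustrative function
    # Pretend this calls a slow external API; the decorator stores the result as
    # JSONB in the `function_cache` table, keyed by a hash of {"city": ...}
    # (`api_key` is excluded from the hash via `ignore_args`).
    return WeatherReport(city=city, temperature_celsius=21.5)


fetch_weather("Berlin", api_key="secret")  # cache miss: computed and stored
fetch_weather("Berlin", api_key="other")   # cache hit: api_key is ignored, result re-validated into WeatherReport

Because the return annotation contains a Pydantic model, cached rows are re-instantiated through convert_cached_output_to_pydantic; when SQLALCHEMY_DB_URL is unset or ENABLE_CACHE is false, the wrapper falls back to calling the function without caching.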
