From dfa90caad04756b6956ecb462156e80149f24098 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Wed, 6 Nov 2024 16:32:59 +0100 Subject: [PATCH] Initialize db_cache engine in wrapper (#546) --- .../tools/caches/db_cache.py | 59 +++++++------------ pyproject.toml | 2 +- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/prediction_market_agent_tooling/tools/caches/db_cache.py b/prediction_market_agent_tooling/tools/caches/db_cache.py index 664309ba..500b686c 100644 --- a/prediction_market_agent_tooling/tools/caches/db_cache.py +++ b/prediction_market_agent_tooling/tools/caches/db_cache.py @@ -91,33 +91,22 @@ def decorator(func: FunctionT) -> FunctionT: api_keys = api_keys if api_keys is not None else APIKeys() - sqlalchemy_db_url = api_keys.SQLALCHEMY_DB_URL - if sqlalchemy_db_url is None: - logger.warning( - f"SQLALCHEMY_DB_URL not provided in the environment, skipping function caching." - ) + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # If caching is disabled, just call the function and return it + if not api_keys.ENABLE_CACHE: + return func(*args, **kwargs) - engine = ( - create_engine( - sqlalchemy_db_url.get_secret_value(), + engine = create_engine( + api_keys.sqlalchemy_db_url.get_secret_value(), # Use custom json serializer and deserializer, because otherwise, for example `datetime` serialization would fail. json_serializer=json_serializer, json_deserializer=json_deserializer, ) - if sqlalchemy_db_url is not None - else None - ) - # Create table if it doesn't exist - if engine is not None: + # Create table if it doesn't exist SQLModel.metadata.create_all(engine) - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - # If caching is disabled, just call the function and return it - if not api_keys.ENABLE_CACHE: - return func(*args, **kwargs) - # Convert *args and **kwargs to a single dictionary, where we have names for arguments passed as args as well. signature = inspect.signature(func) bound_arguments = signature.bind(*args, **kwargs) @@ -162,25 +151,21 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: if return_type is not None and contains_pydantic_model(return_type): is_pydantic_model = True - # If postgres access was specified, try to find a hit - if engine is not None: - with Session(engine) as session: - # Try to get cached result - statement = ( - select(FunctionCache) - .where( - FunctionCache.function_name == function_name, - FunctionCache.full_function_name == full_function_name, - FunctionCache.args_hash == args_hash, - ) - .order_by(desc(FunctionCache.created_at)) + with Session(engine) as session: + # Try to get cached result + statement = ( + select(FunctionCache) + .where( + FunctionCache.function_name == function_name, + FunctionCache.full_function_name == full_function_name, + FunctionCache.args_hash == args_hash, ) - if max_age is not None: - cutoff_time = utcnow() - max_age - statement = statement.where(FunctionCache.created_at >= cutoff_time) - cached_result = session.exec(statement).first() - else: - cached_result = None + .order_by(desc(FunctionCache.created_at)) + ) + if max_age is not None: + cutoff_time = utcnow() - max_age + statement = statement.where(FunctionCache.created_at >= cutoff_time) + cached_result = session.exec(statement).first() if cached_result: logger.info( diff --git a/pyproject.toml b/pyproject.toml index 36a87f08..83197c3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "prediction-market-agent-tooling" -version = "0.56.0" +version = "0.56.1" description = "Tools to benchmark, deploy and monitor prediction market agents." authors = ["Gnosis"] readme = "README.md"