From d66e67b42605e53c8546f94da4aa0472b1cf443e Mon Sep 17 00:00:00 2001 From: Maksim Ivanov Date: Mon, 16 Dec 2024 14:58:36 +0100 Subject: [PATCH] Add with_skyvern_context decorator --- skyvern/forge/sdk/log_artifacts.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/skyvern/forge/sdk/log_artifacts.py b/skyvern/forge/sdk/log_artifacts.py index 660e13ae4..ccbd1d744 100644 --- a/skyvern/forge/sdk/log_artifacts.py +++ b/skyvern/forge/sdk/log_artifacts.py @@ -1,6 +1,9 @@ import json import structlog +import functools +from typing import Callable + from skyvern.forge import app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.skyvern_json_encoder import SkyvernJSONLogEncoder @@ -9,6 +12,21 @@ LOG = structlog.get_logger() +def with_skyvern_context(func: Callable): + """ + Decorator to ensure the presence of a Skyvern context for a function. + If no context is available, the function will not execute. + """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + context = skyvern_context.current() + if not context: + LOG.warning("No Skyvern context found, skipping function execution", func=func.__name__) + return + return await func(*args, **kwargs) + + return wrapper + def primary_key_from_log_entity_type(log_entity_type: LogEntityType) -> str: if log_entity_type == LogEntityType.STEP: return "step_id" @@ -21,6 +39,7 @@ def primary_key_from_log_entity_type(log_entity_type: LogEntityType) -> str: else: raise ValueError(f"Invalid log entity type: {log_entity_type}") +@with_skyvern_context async def save_step_logs(step_id: str) -> None: log = skyvern_context.current().log organization_id = skyvern_context.current().organization_id @@ -36,6 +55,7 @@ async def save_step_logs(step_id: str) -> None: ) +@with_skyvern_context async def save_task_logs(task_id: str) -> None: log = skyvern_context.current().log organization_id = skyvern_context.current().organization_id @@ -51,6 +71,7 @@ async def save_task_logs(task_id: str) -> None: ) +@with_skyvern_context async def save_workflow_run_logs(workflow_run_id: str) -> None: log = skyvern_context.current().log organization_id = skyvern_context.current().organization_id @@ -66,6 +87,7 @@ async def save_workflow_run_logs(workflow_run_id: str) -> None: ) +@with_skyvern_context async def save_workflow_run_block_logs(workflow_run_block_id: str) -> None: log = skyvern_context.current().log organization_id = skyvern_context.current().organization_id