From f028b4889294c0796f3e7cc7e553ccffac8e1d11 Mon Sep 17 00:00:00 2001 From: LawyZheng Date: Fri, 13 Dec 2024 02:37:37 +0800 Subject: [PATCH] refactor loopblock value (#1381) --- skyvern/forge/sdk/workflow/exceptions.py | 14 +++- skyvern/forge/sdk/workflow/models/block.py | 79 +++++++++++++++------- skyvern/forge/sdk/workflow/models/yaml.py | 1 + skyvern/forge/sdk/workflow/service.py | 19 +++++- 4 files changed, 87 insertions(+), 26 deletions(-) diff --git a/skyvern/forge/sdk/workflow/exceptions.py b/skyvern/forge/sdk/workflow/exceptions.py index 7d8b38d52..1ba5ca61c 100644 --- a/skyvern/forge/sdk/workflow/exceptions.py +++ b/skyvern/forge/sdk/workflow/exceptions.py @@ -110,5 +110,17 @@ def __init__(self, workflow_parameter_type: str, workflow_parameter_key: str, re class InvalidWaitBlockTime(SkyvernException): - def __init__(self, max_sec: int): + def __init__(self, max_sec: int) -> None: super().__init__(f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}.") + + +class FailedToFormatJinjaStyleParameter(SkyvernException): + def __init__(self, template: str, msg: str) -> None: + super().__init__( + f"Failed to format Jinja style parameter {template}. Please make sure the variable reference is correct. reason: {msg}" + ) + + +class NoIterableValueFound(SkyvernException): + def __init__(self) -> None: + super().__init__("No iterable value found for the loop block") diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 4b011fb6a..33cfbc8db 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -47,8 +47,10 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext from skyvern.forge.sdk.workflow.exceptions import ( + FailedToFormatJinjaStyleParameter, InvalidEmailClientConfiguration, InvalidFileType, + NoIterableValueFound, NoValidEmailRecipient, ) from skyvern.forge.sdk.workflow.models.parameter import ( @@ -576,14 +578,17 @@ def get_failure_reason(self) -> str | None: class ForLoopBlock(Block): block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP - loop_over: PARAMETER_TYPE loop_blocks: list[BlockTypeVar] + loop_over: PARAMETER_TYPE | None = None + loop_variable_reference: str | None = None def get_all_parameters( self, workflow_run_id: str, ) -> list[PARAMETER_TYPE]: - parameters = {self.loop_over} + parameters = set() + if self.loop_over is not None: + parameters.add(self.loop_over) for loop_block in self.loop_blocks: for parameter in loop_block.get_all_parameters(workflow_run_id): @@ -600,6 +605,9 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any if isinstance(parameter, ContextParameter): context_parameters.append(parameter) + if self.loop_over is None: + return context_parameters + for context_parameter in context_parameters: if context_parameter.source.key != self.loop_over.key: continue @@ -620,29 +628,44 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any return context_parameters def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]: - if isinstance(self.loop_over, WorkflowParameter): - parameter_value = workflow_run_context.get_value(self.loop_over.key) - elif isinstance(self.loop_over, OutputParameter): - # If the output parameter is for a TaskBlock, it will be a TaskOutput object. We need to extract the - # value from the TaskOutput object's extracted_information field. - output_parameter_value = workflow_run_context.get_value(self.loop_over.key) - if isinstance(output_parameter_value, dict) and "extracted_information" in output_parameter_value: - parameter_value = output_parameter_value["extracted_information"] - else: - parameter_value = output_parameter_value - elif isinstance(self.loop_over, ContextParameter): - parameter_value = self.loop_over.value - if not parameter_value: - source_parameter_value = workflow_run_context.get_value(self.loop_over.source.key) - if isinstance(source_parameter_value, dict): - if "extracted_information" in source_parameter_value: - parameter_value = source_parameter_value["extracted_information"].get(self.loop_over.key) - else: - parameter_value = source_parameter_value.get(self.loop_over.key) + # parse the value from self.loop_variable_reference and then from self.loop_over + if self.loop_variable_reference: + value_template = f'{{{{ {self.loop_variable_reference.strip(" {}")} | tojson }}}}' + try: + value_json = self.format_block_parameter_template_from_workflow_run_context( + value_template, workflow_run_context + ) + except Exception as e: + raise FailedToFormatJinjaStyleParameter(value_template, str(e)) + parameter_value = json.loads(value_json) + + elif self.loop_over is not None: + if isinstance(self.loop_over, WorkflowParameter): + parameter_value = workflow_run_context.get_value(self.loop_over.key) + elif isinstance(self.loop_over, OutputParameter): + # If the output parameter is for a TaskBlock, it will be a TaskOutput object. We need to extract the + # value from the TaskOutput object's extracted_information field. + output_parameter_value = workflow_run_context.get_value(self.loop_over.key) + if isinstance(output_parameter_value, dict) and "extracted_information" in output_parameter_value: + parameter_value = output_parameter_value["extracted_information"] else: - raise ValueError("ContextParameter source value should be a dict") + parameter_value = output_parameter_value + elif isinstance(self.loop_over, ContextParameter): + parameter_value = self.loop_over.value + if not parameter_value: + source_parameter_value = workflow_run_context.get_value(self.loop_over.source.key) + if isinstance(source_parameter_value, dict): + if "extracted_information" in source_parameter_value: + parameter_value = source_parameter_value["extracted_information"].get(self.loop_over.key) + else: + parameter_value = source_parameter_value.get(self.loop_over.key) + else: + raise ValueError("ContextParameter source value should be a dict") + else: + raise NotImplementedError() + else: - raise NotImplementedError + raise NoIterableValueFound() if isinstance(parameter_value, list): return parameter_value @@ -725,7 +748,15 @@ async def execute_loop_helper( async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: workflow_run_context = self.get_workflow_run_context(workflow_run_id) - loop_over_values = self.get_loop_over_parameter_values(workflow_run_context) + try: + loop_over_values = self.get_loop_over_parameter_values(workflow_run_context) + except Exception as e: + return self.build_block_result( + success=False, + failure_reason=f"failed to get loop values: {str(e)}", + status=BlockStatus.failed, + ) + LOG.info( f"Number of loop_over values: {len(loop_over_values)}", block_type=self.block_type, diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index 8c6b847e2..453e950bb 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -142,6 +142,7 @@ class ForLoopBlockYAML(BlockYAML): loop_over_parameter_key: str loop_blocks: list["BLOCK_YAML_SUBCLASSES"] + loop_variable_reference: str | None = None class CodeBlockYAML(BlockYAML): diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 838f7b2f7..378d4beab 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -1319,10 +1319,27 @@ async def block_yaml_to_block( await WorkflowService.block_yaml_to_block(workflow, loop_block, parameters) for loop_block in block_yaml.loop_blocks ] - loop_over_parameter = parameters[block_yaml.loop_over_parameter_key] + + loop_over_parameter: Parameter | None = None + if block_yaml.loop_over_parameter_key: + loop_over_parameter = parameters[block_yaml.loop_over_parameter_key] + + if block_yaml.loop_variable_reference: + # it's backaward compatible with jinja style parameter and context paramter + # we trim the format like {{ loop_key }} into loop_key to initialize the context parater, + # otherwise it might break the context parameter initialization chain, blow up the worklofw parameters + # TODO: consider remove this if we totally give up context parameter + trimmed_key = block_yaml.loop_variable_reference.strip(" {}") + if trimmed_key in parameters: + loop_over_parameter = parameters[trimmed_key] + + if loop_over_parameter is None and not block_yaml.loop_variable_reference: + raise Exception("empty loop value parameter") + return ForLoopBlock( label=block_yaml.label, loop_over=loop_over_parameter, + loop_variable_reference=block_yaml.loop_variable_reference, loop_blocks=loop_blocks, output_parameter=output_parameter, continue_on_failure=block_yaml.continue_on_failure,