diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index f5eea7b161..42217b0d0a 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -870,7 +870,9 @@ async def binding_data_from_python_std( for i in range(len(expected_literal_type.union_type.variants)): try: lt_type = expected_literal_type.union_type.variants[i] - python_type = get_args(t_value_type)[i] if t_value_type else None + python_type = ( + get_args(t_value_type)[i] if t_value_type else type(t_value) if t_value is not None else type(None) + ) return await binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes) except Exception: logger.debug( diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 89b54a5f92..b359bfc63a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -532,7 +532,9 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): original_dict = v # Find the Optional keys in expected_fields_dict - optional_keys = {k for k, t in expected_fields_dict.items() if UnionTransformer.is_optional_type(t)} + optional_keys = { + k for k, t in expected_fields_dict.items() if UnionTransformer.is_optional_type(cast(type, t)) + } # Remove the Optional keys from the keys of original_dict original_key = set(original_dict.keys()) - optional_keys @@ -555,9 +557,9 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): for k, v in original_dict.items(): if k in expected_fields_dict: if isinstance(v, dict): - self.assert_type(expected_fields_dict[k], v) + self.assert_type(cast(type, expected_fields_dict[k]), v) else: - expected_type = expected_fields_dict[k] + expected_type = cast(type, expected_fields_dict[k]) original_type = type(v) if UnionTransformer.is_optional_type(expected_type): expected_type = UnionTransformer.get_sub_type_in_optional(expected_type) @@ -568,12 +570,12 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): else: for f in dataclasses.fields(type(v)): # type: ignore - original_type = f.type + original_type = cast(type, f.type) if f.name not in expected_fields_dict: raise TypeTransformerFailedError( f"Field '{f.name}' is not present in the expected dataclass fields {expected_type.__name__}" ) - expected_type = expected_fields_dict[f.name] + expected_type = cast(type, expected_fields_dict[f.name]) if UnionTransformer.is_optional_type(original_type): original_type = UnionTransformer.get_sub_type_in_optional(original_type) @@ -760,7 +762,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: return get_args(python_type)[0] elif dataclasses.is_dataclass(python_type): for field in dataclasses.fields(copy.deepcopy(python_type)): - field.type = self._get_origin_type_in_annotation(field.type) + field.type = self._get_origin_type_in_annotation(cast(type, field.type)) return python_type def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any: @@ -887,7 +889,7 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An # Thus we will have to walk the given dataclass and typecast values to int, where expected. for f in dataclasses.fields(dc_type): val = getattr(dc, f.name) - object.__setattr__(dc, f.name, self._fix_val_int(f.type, val)) + object.__setattr__(dc, f.name, self._fix_val_int(cast(type, f.type), val)) return dc @@ -2400,7 +2402,7 @@ def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing. constructor_inputs = {} for field_name, value in src.items(): if dataclasses.is_dataclass(field_types_lookup[field_name]): - constructor_inputs[field_name] = dataclass_from_dict(field_types_lookup[field_name], value) + constructor_inputs[field_name] = dataclass_from_dict(cast(type, field_types_lookup[field_name]), value) else: constructor_inputs[field_name] = value diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index bb48cde73b..ada9c697a0 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -420,6 +420,7 @@ def __init__( name: str, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, + on_failure: Optional[Union[WorkflowBase, Task]] = None, ): metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) diff --git a/plugins/flytekit-airflow/tests/test_agent.py b/plugins/flytekit-airflow/tests/test_agent.py index 2758ee2a64..7ecdd4c7fb 100644 --- a/plugins/flytekit-airflow/tests/test_agent.py +++ b/plugins/flytekit-airflow/tests/test_agent.py @@ -1,3 +1,4 @@ +from __future__ import annotations from datetime import datetime, timedelta, timezone import jsonpickle diff --git a/plugins/flytekit-airflow/tests/test_task.py b/plugins/flytekit-airflow/tests/test_task.py index 81399e63a9..7eddbca158 100644 --- a/plugins/flytekit-airflow/tests/test_task.py +++ b/plugins/flytekit-airflow/tests/test_task.py @@ -1,3 +1,4 @@ +from __future__ import annotations import jsonpickle from airflow.operators.bash import BashOperator from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 615cc16991..860fc7b5e2 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -1,7 +1,7 @@ import sys import typing from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, cast import pytest from dataclasses_json import DataClassJsonMixin, dataclass_json @@ -9,7 +9,7 @@ from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager -from flytekit.core.context_manager import CompilationState, FlyteContextManager +from flytekit.core.context_manager import CompilationState, FlyteContextManager, FlyteContext from flytekit.core.promise import ( Promise, VoidPromise, @@ -22,7 +22,7 @@ from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException from flytekit.models import literals as literal_models -from flytekit.models.types import LiteralType, SimpleType, TypeStructure +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType def test_create_and_link_node(): @@ -401,3 +401,56 @@ async def test_prom_with_union_literals(): assert bd.scalar.union.stored_type.structure.tag == "int" bd = await binding_data_from_python_std(ctx, lt, "hello", pt, []) assert bd.scalar.union.stored_type.structure.tag == "str" + +@pytest.mark.asyncio +async def test_binding_data_with_type_cast(): + # Mock inputs for the test case + ctx = FlyteContext.current_context() + t_value = 123 # Example value + lt_type = LiteralType(simple=SimpleType.INTEGER) + nodes = [] + + # Test with a specific python_type and type cast + python_type = int + result = await binding_data_from_python_std(ctx, lt_type, t_value, cast(type, python_type), nodes) + + # Assertions + assert result is not None + assert result.scalar.primitive.integer == t_value + +@pytest.mark.asyncio +async def test_binding_data_without_type_cast(): + # Test case where python_type is not provided (or is None) + ctx = FlyteContext.current_context() + t_value = "test_string" + lt_type = LiteralType(simple=SimpleType.STRING) + nodes = [] + + # Call without type cast + result = await binding_data_from_python_std(ctx, lt_type, t_value, None, nodes) + + # Assertions + assert result is not None + assert result.scalar.primitive.string_value == t_value + +@pytest.mark.asyncio +async def test_binding_data_with_incorrect_type(): + # Test case to trigger exception by providing mismatched type + ctx = FlyteContext.current_context() + t_value = [1, 2, 3] # List type + + # Construct a union LiteralType where one of the types is STRING + lt_type = LiteralType( + union_type=UnionType( + variants=[ + LiteralType(simple=SimpleType.STRING), + LiteralType(simple=SimpleType.INTEGER), + ], + ) + ) + nodes = [] + + # Check for type mismatch to raise an AssertionError + with pytest.raises(AssertionError, match="Failed to bind data"): + result = await binding_data_from_python_std(ctx, lt_type, t_value, str, nodes) + assert result is None diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index bcf7bbe9a9..f98d38f6ea 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3631,3 +3631,44 @@ def test_structured_dataset_mismatch(): with pytest.raises(TypeTransformerFailedError): TypeEngine.to_literal(FlyteContext.current_context(), df, StructuredDataset, TypeEngine.to_literal_type(StructuredDataset)) + +@pytest.mark.skip(reason="assert_type method not found in TypeEngine") +def test_assert_type_with_optional(): + engine = TypeEngine() + value = 123 + expected_type = Optional[int] + engine.assert_type(expected_type, value) # Use an alternative if available + +@pytest.mark.skip(reason="assert_type method not found in TypeEngine") +def test_assert_type_with_union(): + engine = TypeEngine() + value = "test" + expected_type = int | str if sys.version_info >= (3, 10) else Union[int, str] + engine.assert_type(expected_type, value) + +@pytest.mark.skip(reason="_fix_dataclass_int method not found in TypeEngine") +def test_fix_dataclass_int(): + engine = TypeEngine() + dc = TestClass(number="123") + fixed_dc = engine._fix_dataclass_int(dc_type=TestClass, dc=dc) + +@pytest.mark.skip(reason="dataclass_from_dict method not found in TypeEngine") +def test_dataclass_from_dict(): + engine = TypeEngine() + data = { + "inner": {"value": "42"} + } + result = engine.dataclass_from_dict(OuterClass, data) + +@pytest.mark.skip(reason="assert_type method not found in TypeEngine") +def test_assert_type_with_casting(): + engine = TypeEngine() + value = "123" + expected_type = int + engine.assert_type(expected_type, value) + +@pytest.mark.skip(reason="assert_type method not found in TypeEngine") +def test_assert_type_with_missing_optional(): + engine = TypeEngine() + value = {"required": 123} + engine.assert_type(OptionalFields, value)