diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 29167ac031..2c4703cdd3 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -16,7 +16,7 @@ from flytekit.core.tracker import TrackedInstance, extract_task_module from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit from flytekit.extras.accelerators import BaseAccelerator -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduped_hash_from_image_spec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext @@ -276,8 +276,12 @@ def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg: :return: """ if isinstance(img, ImageSpec): - ImageBuildEngine.build(img) - return img.image_name() + image = cfg.find_image(_calculate_deduped_hash_from_image_spec(img)) + image_name = image.full if image else None + if not image_name: + ImageBuildEngine.build(img) + image_name = img.image_name() + return image_name if img is not None and img != "": matches = _IMAGE_REPLACE_REGEX.findall(img) diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 98f6c05cdc..37f87549d0 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -280,6 +280,21 @@ def _build_image(cls, builder, image_spec, img_name): cls._IMAGE_NAME_TO_REAL_NAME[img_name] = fully_qualified_image_name +@lru_cache +def _calculate_deduped_hash_from_image_spec(image_spec: ImageSpec): + """ + Calculate this special hash from the image spec, + and it used to identify the imageSpec in the ImageConfig in the serialization context. + + ImageConfig: + - deduced hash 1: flyteorg/flytekit: 123 + - deduced hash 2: flyteorg/flytekit: 456 + """ + image_spec_bytes = asdict(image_spec).__str__().encode("utf-8") + # copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different. + return base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii").rstrip("=") + + @lru_cache def calculate_hash_from_image_spec(image_spec: ImageSpec): """ diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b49639d23a..a77e0a0bf5 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -6,9 +6,10 @@ from flyteidl.admin import schedule_pb2 -from flytekit import PythonFunctionTask, SourceCode -from flytekit.configuration import SerializationSettings +from flytekit import ImageSpec, PythonFunctionTask, SourceCode +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import constants as _common_constants +from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode @@ -22,6 +23,7 @@ from flytekit.core.task import ReferenceTask from flytekit.core.utils import ClassDecorator, _dnsify from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase +from flytekit.image_spec.image_spec import _calculate_deduped_hash_from_image_spec from flytekit.models import common as _common_models from flytekit.models import common as common_models from flytekit.models import interface as interface_models @@ -176,6 +178,19 @@ def get_serializable_task( ) if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC: + for e in context_manager.FlyteEntities.entities: + if isinstance(e, PythonAutoContainerTask): + # 1. Build the ImageSpec for all the entities that are inside the current context, + # 2. Add images to the serialization context, so the dynamic task can look it up at runtime. + if isinstance(e.container_image, ImageSpec): + if settings.image_config.images is None: + settings.image_config = ImageConfig.create_from(settings.image_config.default_image) + settings.image_config.images.append( + Image.look_up_image_info( + _calculate_deduped_hash_from_image_spec(e.container_image), e.get_image(settings) + ) + ) + # In case of Dynamic tasks, we want to pass the serialization context, so that they can reconstruct the state # from the serialization context. This is passed through an environment variable, that is read from # during dynamic serialization diff --git a/plugins/flytekit-envd/tests/test_image_spec.py b/plugins/flytekit-envd/tests/test_image_spec.py index 7fd3cd1be0..5b7b73f755 100644 --- a/plugins/flytekit-envd/tests/test_image_spec.py +++ b/plugins/flytekit-envd/tests/test_image_spec.py @@ -57,7 +57,7 @@ def build(): run(commands=["echo hello"]) install.python_packages(name=["pandas"]) install.apt_packages(name=["git"]) - runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) + runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) config.pip_index(url="https://private-pip-index/simple") install.python(version="3.8") io.copy(source="./", target="/root") @@ -88,7 +88,7 @@ def build(): run(commands=[]) install.python_packages(name=["flytekit"]) install.apt_packages(name=[]) - runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) + runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) config.pip_index(url="https://pypi.org/simple") install.conda(use_mamba=True) install.conda_packages(name=["pytorch", "cpuonly"], channel=["pytorch"]) @@ -122,7 +122,7 @@ def build(): run(commands=[]) install.python_packages(name=["-U --pre pandas", "torch", "torchvision"]) install.apt_packages(name=[]) - runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) + runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) config.pip_index(url="https://pypi.org/simple", extra_url="https://download.pytorch.org/whl/cpu https://pypi.anaconda.org/scientific-python-nightly-wheels/simple") """ ) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index fc3284ca10..684f49031b 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -14,12 +14,15 @@ from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion from flytekit.extras.accelerators import A100, T4 +from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.models import literals as _literal_models from flytekit.models.task import Resources as _resources_models from flytekit.tools.translator import get_serializable -def test_normal_task(): +def test_normal_task(mock_image_spec_builder): + ImageBuildEngine.register("test", mock_image_spec_builder) + @task def t1(a: str) -> str: return a + " world" diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 58492fca06..5068da53de 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -9,7 +9,7 @@ from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image from flytekit.core.resources import Resources -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduped_hash_from_image_spec from flytekit.tools.translator import get_serializable_task @@ -55,9 +55,17 @@ def serialization_settings(request): def test_image_name_interpolation(default_image_config): + image_spec = ImageSpec(name="image-1", registry="localhost:30000", builder="test") + + new_img_cfg = ImageConfig.create_from( + default_image_config.default_image, + other_images=[Image.look_up_image_info(_calculate_deduped_hash_from_image_spec(image_spec), "flyte/test:d1")], + ) img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special" - img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config) + img = get_registerable_container_image(img=img_to_interpolate, cfg=new_img_cfg) assert img == "docker.io/xyz:some-git-hash-special" + img = get_registerable_container_image(img=image_spec, cfg=new_img_cfg) + assert img == "flyte/test:d1" class DummyAutoContainerTask(PythonAutoContainerTask): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 9b11a2a16a..88297f43f4 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -6,12 +6,13 @@ import pytest import flytekit.configuration -from flytekit import ContainerTask, kwtypes +from flytekit import ContainerTask, ImageSpec, kwtypes from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.types import SimpleType from flytekit.tools.translator import get_serializable @@ -250,7 +251,9 @@ def test_bad_configuration(): get_registerable_container_image(container_image, image_config) -def test_serialization_images(): +def test_serialization_images(mock_image_spec_builder): + ImageBuildEngine.register("test", mock_image_spec_builder) + @task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}") def t1(a: int) -> int: return a @@ -271,10 +274,24 @@ def t5(a: int) -> int: def t6(a: int) -> int: return a + image_spec = ImageSpec( + packages=["mypy"], + apt_packages=["git"], + registry="ghcr.io/flyteorg", + builder="test", + ) + + @task(container_image=image_spec) + def t7(a: int) -> int: + return a + with mock.patch.dict(os.environ, {"FLYTE_INTERNAL_IMAGE": "docker.io/default:version"}): imgs = ImageConfig.auto( config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config") ) + imgs.images.append( + Image(name=_calculate_deduped_hash_from_image_spec(image_spec), fqn="docker.io/t7", tag="latest") + ) rs = flytekit.configuration.SerializationSettings( project="project", domain="domain", @@ -295,8 +312,11 @@ def t6(a: int) -> int: t5_spec = get_serializable(OrderedDict(), rs, t5) assert t5_spec.template.container.image == "docker.io/org/myimage:latest" - t5_spec = get_serializable(OrderedDict(), rs, t6) - assert t5_spec.template.container.image == "docker.io/xyz_123:v1" + t6_spec = get_serializable(OrderedDict(), rs, t6) + assert t6_spec.template.container.image == "docker.io/xyz_123:v1" + + t7_spec = get_serializable(OrderedDict(), rs, t7) + assert t7_spec.template.container.image == "docker.io/t7:latest" def test_serialization_command1():