From 759180537788e01a7906432255ad663c296e9518 Mon Sep 17 00:00:00 2001 From: Jiaxiao Zheng Date: Tue, 15 Dec 2020 13:18:20 -0800 Subject: [PATCH] feat(SDK): adds Artifact base class. (#4895) * Add base Artifact class * add __repr__ test * add __repr__ test * add serialization * add serialization tests * add deserialization tests * add absl-py * update requirements * Resolve comments --- sdk/python/kfp/v2/dsl/artifact.py | 336 +++++++++++++++++++++++++ sdk/python/kfp/v2/dsl/artifact_test.py | 191 ++++++++++++++ sdk/python/requirements.in | 3 + sdk/python/requirements.txt | 3 +- 4 files changed, 532 insertions(+), 1 deletion(-) create mode 100644 sdk/python/kfp/v2/dsl/artifact.py create mode 100644 sdk/python/kfp/v2/dsl/artifact_test.py diff --git a/sdk/python/kfp/v2/dsl/artifact.py b/sdk/python/kfp/v2/dsl/artifact.py new file mode 100644 index 00000000000..3b1455344d5 --- /dev/null +++ b/sdk/python/kfp/v2/dsl/artifact.py @@ -0,0 +1,336 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for MLMD artifact ontology in KFP SDK.""" +from typing import Any, Dict, Optional + +from absl import logging +import enum +import importlib +from google.protobuf import json_format +import yaml + +from kfp.pipeline_spec import pipeline_spec_pb2 + +_KFP_ARTIFACT_TITLE_PATTERN = 'kfp.{}' +_KFP_ARTIFACT_ONTOLOGY_MODULE = 'kfp.v2.dsl.artifacts' + + +# Enum for property types. 
# This is introduced to decouple the MLMD ontology from Python built-in types.
class PropertyType(enum.Enum):
    """Primitive types that an artifact property may be declared with."""
    INT = 1
    DOUBLE = 2
    STRING = 3


class Property(object):
    """Property specified for an Artifact.

    Attributes:
      type: The PropertyType of the property.
      description: Optional human-readable description of the property.
    """

    # Mapping from Python enum to primitive type in the IR proto.
    _ALLOWED_PROPERTY_TYPES = {
        PropertyType.INT: pipeline_spec_pb2.PrimitiveType.INT,
        PropertyType.DOUBLE: pipeline_spec_pb2.PrimitiveType.DOUBLE,
        PropertyType.STRING: pipeline_spec_pb2.PrimitiveType.STRING,
    }

    # Mapping from Python enum to the type name used in the YAML schema.
    _TYPE_NAMES = {
        PropertyType.INT: 'int',
        PropertyType.DOUBLE: 'double',
        PropertyType.STRING: 'string',
    }

    def __init__(self, type: PropertyType, description: Optional[str] = None):
        """Constructs a Property.

        Args:
          type: One of the PropertyType members.
          description: Optional description of the property.

        Raises:
          ValueError: If `type` is not a supported PropertyType.
        """
        if type not in Property._ALLOWED_PROPERTY_TYPES:
            raise ValueError('Property type must be one of %s.' %
                             list(Property._ALLOWED_PROPERTY_TYPES.keys()))
        self.type = type
        self.description = description

    @classmethod
    def from_dict(cls, dict_data: Dict[str, Any]) -> 'Property':
        """Deserializes the Property object from YAML dict.

        Args:
          dict_data: Mapping with a mandatory 'type' key ('string', 'int' or
            'double') and an optional 'description' key.

        Returns:
          The corresponding Property.

        Raises:
          TypeError: If 'type' is missing or not one of the known type names.
        """
        type_name = dict_data.get('type')
        if not type_name:
            raise TypeError('Missing type keyword in property dict.')
        for kind, known_name in cls._TYPE_NAMES.items():
            if known_name == type_name:
                # 'description' is optional in Property, so a schema entry
                # that omits it must not fail with KeyError.
                return cls(type=kind,
                           description=dict_data.get('description'))
        raise TypeError('Got unknown type: %s' % (type_name,))

    def get_ir_type(self):
        """Gets the IR primitive type."""
        return Property._ALLOWED_PROPERTY_TYPES[self.type]

    def get_type_name(self) -> str:
        """Gets the type name used in YAML instance.

        Raises:
          TypeError: If `self.type` is not a known PropertyType.
        """
        try:
            return Property._TYPE_NAMES[self.type]
        except KeyError:
            raise TypeError(
                'Unexpected property type: %s' % self.type) from None


class Artifact(object):
    """KFP Artifact Python class.

    Artifact Python class/object mainly serves following purposes in different
    period of its lifecycle.

    1. During compile time, users can use Artifact class to annotate I/O types
       of their components.
    2. At runtime, Artifact objects provide helper function/utilities to access
       the underlying RuntimeArtifact pb message, and provide additional layers
       of validation to ensure type compatibility.
    """

    # Name of the Artifact type.
    TYPE_NAME = None
    # Property schema.
    # Example usage:
    #
    # PROPERTIES = {
    #   'span': Property(type=PropertyType.INT),
    #   # Comma separated of splits for an artifact. Empty string means
    #   # artifact has no split.
    #   'split_names': Property(type=PropertyType.STRING),
    # }
    PROPERTIES = None

    # Initialization flag to support setattr / getattr behavior.
    _initialized = False

    # Per property type: (value field on the property proto, value returned
    # when the property is unset, accepted Python type, label used in error
    # messages).
    _VALUE_SPECS = {
        PropertyType.STRING: ('string_value', '', str, 'string'),
        PropertyType.INT: ('int_value', 0, int, 'integer'),
        PropertyType.DOUBLE: ('double_value', 0.0, float, 'float'),
    }

    def __init__(self, instance_schema: Optional[str] = None):
        """Constructs an instance of Artifact.

        Args:
          instance_schema: YAML schema string with 'title' and 'properties'
            keys. Required when instantiating the base Artifact class
            directly; must not be passed for subclasses, whose schema is
            derived from TYPE_NAME and PROPERTIES.

        Raises:
          ValueError: If `instance_schema` is missing or malformed for the
            base class, or is passed to a subclass.
        """
        if self.__class__ == Artifact:
            if not instance_schema:
                raise ValueError(
                    'The "instance_schema" argument must be passed to specify '
                    'a type for this Artifact.')
            # Parse once and reuse the result for both title and properties.
            schema_yaml = yaml.safe_load(instance_schema)
            if (not isinstance(schema_yaml, dict) or
                    'properties' not in schema_yaml):
                raise ValueError(
                    'Invalid instance_schema, properties must be present. '
                    'Got %s' % instance_schema)
            if 'title' not in schema_yaml:
                raise ValueError(
                    'Invalid instance_schema, title must be present. '
                    'Got %s' % instance_schema)
            self.TYPE_NAME = schema_yaml['title']
            # An empty `properties:` entry loads as None; treat it as empty.
            self.PROPERTIES = {
                key: Property.from_dict(spec)
                for key, spec in (schema_yaml['properties'] or {}).items()
            }
        else:
            if instance_schema:
                raise ValueError(
                    'The "instance_schema" argument must not be passed for '
                    'Artifact subclass %s.' % self.__class__)
            instance_schema = self._get_artifact_type()

        # MLMD artifact type schema string.
        self._type_schema = instance_schema
        # Instantiate a RuntimeArtifact pb message as the POD data structure.
        self._artifact = pipeline_spec_pb2.RuntimeArtifact()
        self._artifact.type.CopyFrom(pipeline_spec_pb2.ArtifactTypeSchema(
            instance_schema=instance_schema
        ))
        # Initialization flag to prevent recursive getattr / setattr errors.
        self._initialized = True

    def _get_artifact_type(self) -> str:
        """Gets the instance_schema according to the Python schema spec."""
        schema_map = {
            key: {'type': prop.get_type_name(),
                  'description': prop.description}
            for key, prop in (self.PROPERTIES or {}).items()
        }
        return yaml.safe_dump({
            'title': _KFP_ARTIFACT_TITLE_PATTERN.format(self.TYPE_NAME),
            'type': 'object',
            'properties': schema_map,
        })

    @property
    def type_schema(self) -> str:
        """The YAML instance schema string of this artifact."""
        return self._type_schema

    def __repr__(self) -> str:
        return 'Artifact(artifact: {}, type_schema: {})'.format(
            str(self._artifact), str(self.type_schema))

    def __getattr__(self, name: str) -> Any:
        """Custom __getattr__ to allow access to artifact properties.

        Only invoked when normal attribute lookup fails. Returns the value of
        the declared property `name`, or the type's zero value ('' / 0 / 0.0)
        when the property has not been set.

        Raises:
          AttributeError: If `name` is not declared in PROPERTIES.
        """
        if name == '_artifact_type':
            # Prevent infinite recursion when used with copy.deepcopy().
            raise AttributeError()
        # PROPERTIES is still the class-level None on a base-class instance
        # whose __init__ has not run (e.g. while copy/pickle rebuild it);
        # hasattr() probes must then see AttributeError, not TypeError.
        declared = self.PROPERTIES or {}
        if name not in declared:
            raise AttributeError(
                '%s artifact has no property %r.' % (self.TYPE_NAME, name))
        property_type = declared[name].type
        if property_type not in self._VALUE_SPECS:
            raise Exception('Unknown MLMD type %r for property %r.' %
                            (property_type, name))
        field, default = self._VALUE_SPECS[property_type][:2]
        if name not in self._artifact.properties:
            # Avoid populating empty property protobuf with the [] operator.
            return default
        return getattr(self._artifact.properties[name], field)

    def __setattr__(self, name: str, value: Any):
        """Custom __setattr__ to allow access to artifact properties.

        Raises:
          AttributeError: If `name` is neither a declared property nor an
            existing attribute / descriptor of the object.
          TypeError: If `value` does not match the declared property type.
        """
        if not self._initialized:
            object.__setattr__(self, name, value)
            return
        declared = self.PROPERTIES or {}
        if name not in declared:
            if (name in self.__dict__ or
                    any(name in c.__dict__ for c in self.__class__.mro())):
                # Use any provided getter / setter if available.
                object.__setattr__(self, name, value)
                return
            # In the case where we do not handle this via an explicit getter /
            # setter, we assume that the user implied an artifact attribute
            # store, and we raise an exception since such an attribute was not
            # explicitly defined in the Artifact PROPERTIES dictionary.
            raise AttributeError(
                'Cannot set unknown property %r on artifact %r.' %
                (name, self))
        property_type = declared[name].type
        if property_type not in self._VALUE_SPECS:
            raise Exception('Unknown property type %r for property %r.' %
                            (property_type, name))
        field, _, expected_type, label = self._VALUE_SPECS[property_type]
        if not isinstance(value, expected_type):
            # TypeError is still an Exception, so handlers written against
            # the previous bare Exception keep working.
            raise TypeError(
                'Expected %s value for property %r; got %r instead.' %
                (label, name, value))
        setattr(self._artifact.properties[name], field, value)

    @property
    def type(self):
        """The Python class of this artifact."""
        return self.__class__

    @property
    def type_name(self):
        """The TYPE_NAME of this artifact."""
        return self.TYPE_NAME

    @property
    def runtime_artifact(self) -> pipeline_spec_pb2.RuntimeArtifact:
        """The underlying RuntimeArtifact pb message."""
        return self._artifact

    @runtime_artifact.setter
    def runtime_artifact(self, artifact: pipeline_spec_pb2.RuntimeArtifact):
        self._artifact = artifact

    @property
    def uri(self) -> str:
        """The uri field of the underlying RuntimeArtifact."""
        return self._artifact.uri

    @uri.setter
    def uri(self, uri: str) -> None:
        self._artifact.uri = uri

    @property
    def name(self) -> str:
        """The name field of the underlying RuntimeArtifact."""
        return self._artifact.name

    @name.setter
    def name(self, name: str) -> None:
        self._artifact.name = name

    # Custom property accessors.
    def set_string_custom_property(self, key: str, value: str):
        """Sets a custom property of string type."""
        self._artifact.custom_properties[key].string_value = value

    def set_int_custom_property(self, key: str, value: int):
        """Sets a custom property of int type."""
        self._artifact.custom_properties[key].int_value = value

    def set_float_custom_property(self, key: str, value: float):
        """Sets a custom property of float type."""
        self._artifact.custom_properties[key].double_value = value

    def has_custom_property(self, key: str) -> bool:
        """Returns whether a custom property named `key` has been set."""
        return key in self._artifact.custom_properties

    def get_string_custom_property(self, key: str) -> str:
        """Gets a custom property of string type ('' when unset)."""
        if key not in self._artifact.custom_properties:
            return ''
        return self._artifact.custom_properties[key].string_value

    def get_int_custom_property(self, key: str) -> int:
        """Gets a custom property of int type (0 when unset)."""
        if key not in self._artifact.custom_properties:
            return 0
        return self._artifact.custom_properties[key].int_value

    def get_float_custom_property(self, key: str) -> float:
        """Gets a custom property of float type (0.0 when unset)."""
        if key not in self._artifact.custom_properties:
            return 0.0
        return self._artifact.custom_properties[key].double_value

    @classmethod
    def deserialize(cls, data: str) -> Any:
        """Deserializes an Artifact object from a JSON string.

        Args:
          data: JSON serialization of a RuntimeArtifact message, as produced
            by `serialize`.

        Returns:
          An instance of the matching class from the artifact ontology module
          when it can be imported and instantiated, otherwise a generic
          Artifact built from the embedded instance schema. In both cases the
          parsed message becomes the object's runtime artifact.
        """
        artifact = pipeline_spec_pb2.RuntimeArtifact()
        json_format.Parse(data, artifact, ignore_unknown_fields=True)
        instance_schema = yaml.safe_load(artifact.type.instance_schema)
        title = instance_schema['title']
        # Strip the whole 'kfp.' prefix to recover the class name; indexing
        # at len(prefix) would only yield a single character.
        prefix = _KFP_ARTIFACT_TITLE_PATTERN.format('')
        type_name = title[len(prefix):] if title.startswith(prefix) else title
        result = None
        try:
            artifact_cls = getattr(
                importlib.import_module(_KFP_ARTIFACT_ONTOLOGY_MODULE),
                type_name)
            # TODO(numerology): Add deserialization tests for first party
            # classes.
            result = artifact_cls()
        except (AttributeError, ImportError, ValueError):
            logging.warning((
                'Could not load artifact class %s.%s; using fallback '
                'deserialization for the relevant artifact. Please make sure '
                'that any artifact classes can be imported within your '
                'container or environment.'),
                _KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
        if result is None:
            # Otherwise generate a generic Artifact object.
            result = Artifact(instance_schema=artifact.type.instance_schema)
        result.runtime_artifact = artifact
        return result

    def serialize(self) -> str:
        """Serializes an Artifact to a JSON string with sorted keys."""
        return json_format.MessageToJson(self._artifact, sort_keys=True)
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kfp.v2.dsl.artifact module."""
import importlib
import mock
import unittest
import textwrap

from kfp.v2.dsl import artifact


class _MyArtifact(artifact.Artifact):
    """Artifact subclass with two properties of each supported type."""
    TYPE_NAME = 'MyTypeName'
    PROPERTIES = {
        'int1': artifact.Property(
            type=artifact.PropertyType.INT,
            description='An integer-typed property'),
        'int2': artifact.Property(type=artifact.PropertyType.INT),
        'float1': artifact.Property(
            type=artifact.PropertyType.DOUBLE,
            description='A float-typed property'),
        'float2': artifact.Property(type=artifact.PropertyType.DOUBLE),
        'string1': artifact.Property(
            type=artifact.PropertyType.STRING,
            description='A string-typed property'),
        'string2': artifact.Property(type=artifact.PropertyType.STRING),
    }


# Expected JSON serialization of a _MyArtifact with int1, string1 and float1
# set. NOTE(review): the whitespace inside this literal (and inside the
# expected repr in testArtifact) has to match the serializer output exactly;
# two-space indentation is assumed for both the JSON and the embedded YAML --
# confirm against the actual json_format / yaml output.
_SERIALIZED_INSTANCE = """\
{
  "properties": {
    "float1": {
      "doubleValue": 1.11
    },
    "int1": {
      "intValue": "1"
    },
    "string1": {
      "stringValue": "111"
    }
  },
  "type": {
    "instanceSchema": "properties:\\n  float1:\\n    description: A float-typed property\\n    type: double\\n  float2:\\n    description: null\\n    type: double\\n  int1:\\n    description: An integer-typed property\\n    type: int\\n  int2:\\n    description: null\\n    type: int\\n  string1:\\n    description: A string-typed property\\n    type: string\\n  string2:\\n    description: null\\n    type: string\\ntitle: kfp.MyTypeName\\ntype: object\\n"
  }
}"""


class ArtifactTest(unittest.TestCase):
    """Tests for artifact.Artifact accessors and (de)serialization."""

    def testArtifact(self):
        """Covers uri/name accessors, custom properties and __repr__."""
        instance = _MyArtifact()

        # Test property getters.
        self.assertEqual('', instance.uri)
        self.assertEqual('', instance.name)

        # Default property does not have span or split_names.
        with self.assertRaisesRegex(AttributeError, "has no property 'span'"):
            _ = instance.span
        with self.assertRaisesRegex(AttributeError,
                                    "has no property 'split_names'"):
            _ = instance.split_names

        # Test property setters.
        instance.uri = '/tmp/uri2'
        self.assertEqual('/tmp/uri2', instance.uri)

        instance.name = '1'
        self.assertEqual('1', instance.name)

        # Testing artifact does not have span.
        with self.assertRaisesRegex(AttributeError,
                                    "unknown property 'span'"):
            instance.span = 20190101
        # Testing artifact does not have split_names.
        with self.assertRaisesRegex(AttributeError,
                                    "unknown property 'split_names'"):
            instance.split_names = ''

        instance.set_int_custom_property('int_key', 20)
        self.assertEqual(
            20,
            instance.runtime_artifact.custom_properties['int_key'].int_value)

        instance.set_string_custom_property('string_key', 'string_value')
        self.assertEqual(
            'string_value',
            instance.runtime_artifact.custom_properties[
                'string_key'].string_value)

        self.assertEqual(textwrap.dedent("""\
            Artifact(artifact: name: "1"
            type {
              instance_schema: "properties:\\n  float1:\\n    description: A float-typed property\\n    type: double\\n  float2:\\n    description: null\\n    type: double\\n  int1:\\n    description: An integer-typed property\\n    type: int\\n  int2:\\n    description: null\\n    type: int\\n  string1:\\n    description: A string-typed property\\n    type: string\\n  string2:\\n    description: null\\n    type: string\\ntitle: kfp.MyTypeName\\ntype: object\\n"
            }
            uri: "/tmp/uri2"
            custom_properties {
              key: "int_key"
              value {
                int_value: 20
              }
            }
            custom_properties {
              key: "string_key"
              value {
                string_value: "string_value"
              }
            }
            , type_schema: properties:
              float1:
                description: A float-typed property
                type: double
              float2:
                description: null
                type: double
              int1:
                description: An integer-typed property
                type: int
              int2:
                description: null
                type: int
              string1:
                description: A string-typed property
                type: string
              string2:
                description: null
                type: string
            title: kfp.MyTypeName
            type: object
            )"""), str(instance))

    def testArtifactProperties(self):
        """Covers defaults, round-trips and rejection of unknown properties."""
        my_artifact = _MyArtifact()

        self.assertEqual(0, my_artifact.int1)
        self.assertEqual(0, my_artifact.int2)
        my_artifact.int1 = 111
        my_artifact.int2 = 222
        self.assertEqual('', my_artifact.string1)
        self.assertEqual('', my_artifact.string2)
        my_artifact.string1 = '111'
        my_artifact.string2 = '222'
        self.assertEqual(0.0, my_artifact.float1)
        self.assertEqual(0.0, my_artifact.float2)
        my_artifact.float1 = 1.11
        my_artifact.float2 = 2.22
        self.assertEqual(my_artifact.int1, 111)
        self.assertEqual(my_artifact.int2, 222)
        self.assertEqual(my_artifact.string1, '111')
        self.assertEqual(my_artifact.string2, '222')
        self.assertEqual(1.11, my_artifact.float1)
        self.assertEqual(2.22, my_artifact.float2)
        self.assertEqual(
            my_artifact.get_string_custom_property('invalid'), '')
        self.assertEqual(my_artifact.get_int_custom_property('invalid'), 0)
        # Reading a missing custom property must not create it.
        self.assertNotIn('invalid', my_artifact._artifact.custom_properties)

        with self.assertRaisesRegex(
                AttributeError,
                "Cannot set unknown property 'invalid' on artifact"):
            my_artifact.invalid = 1

        with self.assertRaisesRegex(
                AttributeError,
                "Cannot set unknown property 'invalid' on artifact"):
            my_artifact.invalid = 'x'

        # Raw string: '\D' is a regex class, not a Python string escape.
        with self.assertRaisesRegex(
                AttributeError, r"\D+ artifact has no property 'invalid'"):
            _ = my_artifact.invalid

    def testSerialize(self):
        """Serializing a populated artifact yields the expected JSON."""
        instance = _MyArtifact()
        instance.int1 = 1
        instance.string1 = '111'
        instance.float1 = 1.11

        self.assertEqual(_SERIALIZED_INSTANCE, instance.serialize())

    def testDeserialize(self):
        """Deserializing the expected JSON restores the property values."""
        instance = artifact.Artifact.deserialize(_SERIALIZED_INSTANCE)
        self.assertEqual(1, instance.int1)
        self.assertEqual('111', instance.string1)
        self.assertEqual(1.11, instance.float1)
        self.assertEqual('kfp.MyTypeName', instance.type_name)
a/sdk/python/requirements.txt b/sdk/python/requirements.txt index 93b7cdc6381..bf9d132c7a3 100644 --- a/sdk/python/requirements.txt +++ b/sdk/python/requirements.txt @@ -4,6 +4,7 @@ # # pip-compile --output-file=requirements.txt requirements.in # +absl-py==0.11.0 # via -r requirements.in attrs==19.3.0 # via jsonschema cachetools==4.0.0 # via google-auth certifi==2019.11.28 # via kfp-server-api, kubernetes, requests @@ -35,7 +36,7 @@ requests-oauthlib==1.3.0 # via kubernetes requests-toolbelt==0.9.1 # via -r requirements.in requests==2.23.0 # via google-api-core, kubernetes, requests-oauthlib, requests-toolbelt rsa==4.0 # via google-auth -six==1.14.0 # via google-api-core, google-auth, google-resumable-media, jsonschema, kfp-server-api, kubernetes, protobuf, pyrsistent, python-dateutil, websocket-client +six==1.14.0 # via absl-py, google-api-core, google-auth, google-resumable-media, jsonschema, kfp-server-api, kubernetes, protobuf, pyrsistent, python-dateutil, websocket-client strip-hints==0.1.8 # via -r requirements.in tabulate==0.8.6 # via -r requirements.in urllib3==1.25.8 # via kfp-server-api, kubernetes, requests