-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(SDK): adds Artifact base class. (#4895)
* Add base Artifact class * add __repr__ test * add __repr__ test * add serialization * add serialization tests * add deserialization tests * add absl-py * update requirements * Resolve comments
- Loading branch information
Jiaxiao Zheng
authored
Dec 15, 2020
1 parent
89e4210
commit 7591805
Showing
4 changed files
with
532 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,336 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Base class for MLMD artifact ontology in KFP SDK.""" | ||
from typing import Any, Dict, Optional | ||
|
||
from absl import logging | ||
import enum | ||
import importlib | ||
from google.protobuf import json_format | ||
import yaml | ||
|
||
from kfp.pipeline_spec import pipeline_spec_pb2 | ||
|
||
_KFP_ARTIFACT_TITLE_PATTERN = 'kfp.{}' | ||
_KFP_ARTIFACT_ONTOLOGY_MODULE = 'kfp.v2.dsl.artifacts' | ||
|
||
|
||
# Enum for property types. | ||
# This is introduced to decouple the MLMD ontology with Python built-in types. | ||
class PropertyType(enum.Enum): | ||
INT = 1 | ||
DOUBLE = 2 | ||
STRING = 3 | ||
|
||
|
||
class Property(object): | ||
"""Property specified for an Artifact.""" | ||
|
||
# Mapping from Python enum to primitive type in the IR proto. | ||
_ALLOWED_PROPERTY_TYPES = { | ||
PropertyType.INT: pipeline_spec_pb2.PrimitiveType.INT, | ||
PropertyType.DOUBLE: pipeline_spec_pb2.PrimitiveType.DOUBLE, | ||
PropertyType.STRING: pipeline_spec_pb2.PrimitiveType.STRING, | ||
} | ||
|
||
def __init__(self, type: PropertyType, description: Optional[str] = None): | ||
if type not in Property._ALLOWED_PROPERTY_TYPES: | ||
raise ValueError('Property type must be one of %s.' % | ||
list(Property._ALLOWED_PROPERTY_TYPES.keys())) | ||
self.type = type | ||
self.description = description | ||
|
||
@classmethod | ||
def from_dict(cls, dict_data: Dict[str, Any]) -> 'Property': | ||
"""Deserializes the Property object from YAML dict.""" | ||
if not dict_data.get('type'): | ||
raise TypeError('Missing type keyword in property dict.') | ||
if dict_data['type'] == 'string': | ||
kind = PropertyType.STRING | ||
elif dict_data['type'] == 'int': | ||
kind = PropertyType.INT | ||
elif dict_data['type'] == 'double': | ||
kind = PropertyType.DOUBLE | ||
else: | ||
raise TypeError('Got unknown type: %s' % dict_data['type']) | ||
|
||
return Property( | ||
type=kind, | ||
description=dict_data['description'] | ||
) | ||
|
||
def get_ir_type(self): | ||
"""Gets the IR primitive type.""" | ||
return Property._ALLOWED_PROPERTY_TYPES[self.type] | ||
|
||
def get_type_name(self): | ||
"""Gets the type name used in YAML instance.""" | ||
if self.type == PropertyType.INT: | ||
return 'int' | ||
elif self.type == PropertyType.DOUBLE: | ||
return 'double' | ||
elif self.type == PropertyType.STRING: | ||
return 'string' | ||
else: | ||
raise TypeError('Unexpected property type: %s' % self.type) | ||
|
||
|
||
class Artifact(object): | ||
"""KFP Artifact Python class. | ||
Artifact Python class/object mainly serves following purposes in different | ||
period of its lifecycle. | ||
1. During compile time, users can use Artifact class to annotate I/O types of | ||
their components. | ||
2. At runtime, Artifact objects provide helper function/utilities to access | ||
the underlying RuntimeArtifact pb message, and provide additional layers | ||
of validation to ensure type compatibility. | ||
""" | ||
|
||
# Name of the Artifact type. | ||
TYPE_NAME = None | ||
# Property schema. | ||
# Example usage: | ||
# | ||
# PROPERTIES = { | ||
# 'span': Property(type=PropertyType.INT), | ||
# # Comma separated of splits for an artifact. Empty string means artifact | ||
# # has no split. | ||
# 'split_names': Property(type=PropertyType.STRING), | ||
# } | ||
PROPERTIES = None | ||
|
||
# Initialization flag to support setattr / getattr behavior. | ||
_initialized = False | ||
|
||
def __init__(self, instance_schema: Optional[str] = None): | ||
"""Constructs an instance of Artifact""" | ||
if self.__class__ == Artifact: | ||
if not instance_schema: | ||
raise ValueError( | ||
'The "instance_schema" argument must be passed to specify a ' | ||
'type for this Artifact.') | ||
schema_yaml = yaml.safe_load(instance_schema) | ||
if 'properties' not in schema_yaml: | ||
raise ValueError('Invalid instance_schema, properties must be present. ' | ||
'Got %s' % instance_schema) | ||
schema = schema_yaml['properties'] | ||
self.TYPE_NAME = yaml.safe_load(instance_schema)['title'] | ||
self.PROPERTIES = {} | ||
for k, v in schema.items(): | ||
self.PROPERTIES[k] = Property.from_dict(v) | ||
else: | ||
if instance_schema: | ||
raise ValueError( | ||
'The "mlmd_artifact_type" argument must not be passed for ' | ||
'Artifact subclass %s.' % self.__class__) | ||
instance_schema = self._get_artifact_type() | ||
|
||
# MLMD artifact type schema string. | ||
self._type_schema = instance_schema | ||
# Instantiate a RuntimeArtifact pb message as the POD data structure. | ||
self._artifact = pipeline_spec_pb2.RuntimeArtifact() | ||
self._artifact.type.CopyFrom(pipeline_spec_pb2.ArtifactTypeSchema( | ||
instance_schema=instance_schema | ||
)) | ||
# Initialization flag to prevent recursive getattr / setattr errors. | ||
self._initialized = True | ||
|
||
def _get_artifact_type(self) -> str: | ||
"""Gets the instance_schema according to the Python schema spec.""" | ||
title = _KFP_ARTIFACT_TITLE_PATTERN.format(self.TYPE_NAME) | ||
schema_map = {} | ||
for k, v in self.PROPERTIES.items(): | ||
schema_map[k] = { | ||
'type': v.get_type_name(), | ||
'description': v.description | ||
} | ||
result_map = { | ||
'title': title, | ||
'type': 'object', | ||
'properties': schema_map | ||
} | ||
return yaml.safe_dump(result_map) | ||
|
||
@property | ||
def type_schema(self) -> str: | ||
return self._type_schema | ||
|
||
def __repr__(self) -> str: | ||
return 'Artifact(artifact: {}, type_schema: {})'.format( | ||
str(self._artifact), str(self.type_schema)) | ||
|
||
def __getattr__(self, name: str) -> Any: | ||
"""Custom __getattr__ to allow access to artifact properties.""" | ||
if name == '_artifact_type': | ||
# Prevent infinite recursion when used with copy.deepcopy(). | ||
raise AttributeError() | ||
if name not in self.PROPERTIES: | ||
raise AttributeError( | ||
'%s artifact has no property %r.' % (self.TYPE_NAME, name)) | ||
property_type = self.PROPERTIES[name].type | ||
if property_type == PropertyType.STRING: | ||
if name not in self._artifact.properties: | ||
# Avoid populating empty property protobuf with the [] operator. | ||
return '' | ||
return self._artifact.properties[name].string_value | ||
elif property_type == PropertyType.INT: | ||
if name not in self._artifact.properties: | ||
# Avoid populating empty property protobuf with the [] operator. | ||
return 0 | ||
return self._artifact.properties[name].int_value | ||
elif property_type == PropertyType.DOUBLE: | ||
if name not in self._artifact.properties: | ||
# Avoid populating empty property protobuf with the [] operator. | ||
return 0.0 | ||
return self._artifact.properties[name].double_value | ||
else: | ||
raise Exception('Unknown MLMD type %r for property %r.' % | ||
(property_type, name)) | ||
|
||
def __setattr__(self, name: str, value: Any): | ||
"""Custom __setattr__ to allow access to artifact properties.""" | ||
if not self._initialized: | ||
object.__setattr__(self, name, value) | ||
return | ||
if name not in self.PROPERTIES: | ||
if (name in self.__dict__ or | ||
any(name in c.__dict__ for c in self.__class__.mro())): | ||
# Use any provided getter / setter if available. | ||
object.__setattr__(self, name, value) | ||
return | ||
# In the case where we do not handle this via an explicit getter / | ||
# setter, we assume that the user implied an artifact attribute store, | ||
# and we raise an exception since such an attribute was not explicitly | ||
# defined in the Artifact PROPERTIES dictionary. | ||
raise AttributeError('Cannot set unknown property %r on artifact %r.' % | ||
(name, self)) | ||
property_type = self.PROPERTIES[name].type | ||
if property_type == PropertyType.STRING: | ||
if not isinstance(value, str): | ||
raise Exception( | ||
'Expected string value for property %r; got %r instead.' % | ||
(name, value)) | ||
self._artifact.properties[name].string_value = value | ||
elif property_type == PropertyType.INT: | ||
if not isinstance(value, int): | ||
raise Exception( | ||
'Expected integer value for property %r; got %r instead.' % | ||
(name, value)) | ||
self._artifact.properties[name].int_value = value | ||
elif property_type == PropertyType.DOUBLE: | ||
if not isinstance(value, float): | ||
raise Exception( | ||
'Expected integer value for property %r; got %r instead.' % | ||
(name, value)) | ||
self._artifact.properties[name].double_value = value | ||
else: | ||
raise Exception('Unknown property type %r for property %r.' % | ||
(property_type, name)) | ||
|
||
@property | ||
def type(self): | ||
return self.__class__ | ||
|
||
@property | ||
def type_name(self): | ||
return self.TYPE_NAME | ||
|
||
@property | ||
def runtime_artifact(self) -> pipeline_spec_pb2.RuntimeArtifact: | ||
return self._artifact | ||
|
||
@runtime_artifact.setter | ||
def runtime_artifact(self, artifact: pipeline_spec_pb2.RuntimeArtifact): | ||
self._artifact = artifact | ||
|
||
@property | ||
def uri(self) -> str: | ||
return self._artifact.uri | ||
|
||
@uri.setter | ||
def uri(self, uri: str) -> None: | ||
self._artifact.uri = uri | ||
|
||
@property | ||
def name(self) -> str: | ||
return self._artifact.name | ||
|
||
@name.setter | ||
def name(self, name: str) -> None: | ||
self._artifact.name = name | ||
|
||
# Custom property accessors. | ||
def set_string_custom_property(self, key: str, value: str): | ||
"""Sets a custom property of string type.""" | ||
self._artifact.custom_properties[key].string_value = value | ||
|
||
def set_int_custom_property(self, key: str, value: int): | ||
"""Sets a custom property of int type.""" | ||
self._artifact.custom_properties[key].int_value = value | ||
|
||
def set_float_custom_property(self, key: str, value: float): | ||
"""Sets a custom property of float type.""" | ||
self._artifact.custom_properties[key].double_value = value | ||
|
||
def has_custom_property(self, key: str) -> bool: | ||
return key in self._artifact.custom_properties | ||
|
||
def get_string_custom_property(self, key: str) -> str: | ||
"""Gets a custom property of string type.""" | ||
if key not in self._artifact.custom_properties: | ||
return '' | ||
return self._artifact.custom_properties[key].string_value | ||
|
||
def get_int_custom_property(self, key: str) -> int: | ||
"""Gets a custom property of int type.""" | ||
if key not in self._artifact.custom_properties: | ||
return 0 | ||
return self._artifact.custom_properties[key].int_value | ||
|
||
def get_float_custom_property(self, key: str) -> float: | ||
"""Gets a custom property of float type.""" | ||
if key not in self._artifact.custom_properties: | ||
return 0.0 | ||
return self._artifact.custom_properties[key].double_value | ||
|
||
@classmethod | ||
def deserialize(cls, data: str) -> Any: | ||
"""Deserializes an Artifact object from JSON dict.""" | ||
artifact = pipeline_spec_pb2.RuntimeArtifact() | ||
json_format.Parse(data, artifact, ignore_unknown_fields=True) | ||
instance_schema = yaml.safe_load(artifact.type.instance_schema) | ||
type_name = instance_schema['title'][len('kfp.')] | ||
result = None | ||
try: | ||
artifact_cls = getattr( | ||
importlib.import_module(_KFP_ARTIFACT_ONTOLOGY_MODULE), type_name) | ||
# TODO(numerology): Add deserialization tests for first party classes. | ||
result = artifact_cls() | ||
except (AttributeError, ImportError, ValueError): | ||
logging.warning(( | ||
'Could not load artifact class %s.%s; using fallback deserialization ' | ||
'for the relevant artifact. Please make sure that any artifact ' | ||
'classes can be imported within your container or environment.'), | ||
_KFP_ARTIFACT_ONTOLOGY_MODULE, type_name) | ||
if not result: | ||
# Otherwise generate a generic Artifact object. | ||
result = Artifact(instance_schema=artifact.type.instance_schema) | ||
result.runtime_artifact = artifact | ||
return result | ||
|
||
def serialize(self) -> str: | ||
"""Serializes an Artifact to JSON dict format.""" | ||
return json_format.MessageToJson(self._artifact, sort_keys=True) |
Oops, something went wrong.