From e41dee5ad753fc5bcf127e7a3a9ca53d3cfe70d8 Mon Sep 17 00:00:00 2001
From: Hannes Hapke
Date: Thu, 11 Aug 2022 18:16:02 -0700
Subject: [PATCH 1/4] initial predictions-to-bigquery component

---
 tfx_addons/predictions_to_biquery/__init__.py       |   0
 tfx_addons/predictions_to_biquery/component.py      |  98 ++++++++++
 tfx_addons/predictions_to_biquery/executor.py       | 172 ++++++++++++++++++
 tfx_addons/predictions_to_biquery/test_component.py |  34 ++++
 tfx_addons/predictions_to_biquery/utils.py          | 170 +++++++++++++++++
 5 files changed, 474 insertions(+)
 create mode 100644 tfx_addons/predictions_to_biquery/__init__.py
 create mode 100644 tfx_addons/predictions_to_biquery/component.py
 create mode 100644 tfx_addons/predictions_to_biquery/executor.py
 create mode 100644 tfx_addons/predictions_to_biquery/test_component.py
 create mode 100644 tfx_addons/predictions_to_biquery/utils.py

diff --git a/tfx_addons/predictions_to_biquery/__init__.py b/tfx_addons/predictions_to_biquery/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tfx_addons/predictions_to_biquery/component.py b/tfx_addons/predictions_to_biquery/component.py
new file mode 100644
index 00000000..5ad290f4
--- /dev/null
+++ b/tfx_addons/predictions_to_biquery/component.py
@@ -0,0 +1,98 @@
+"""
+Digits Prediction-to-BigQuery: functionality to write prediction results, usually from a BulkInferrer, to BigQuery.
+"""
+
+from typing import Optional
+
+from tfx import types
+from tfx.dsl.components.base import base_component, executor_spec
+from tfx.types import standard_artifacts
+from tfx.types.component_spec import ChannelParameter, ExecutionParameter
+
+from .executor import Executor as AnnotateUnlabeledCategoryDataExecutor
+
+_MIN_THRESHOLD = 0.8
+_VOCAB_FILE = "vocab_label_txt"
+
+
+class AnnotateUnlabeledCategoryDataComponentSpec(types.ComponentSpec):
+
+  PARAMETERS = {
+      # These are parameters that will be passed in the call to
+      # create an instance of this component.
+      "vocab_label_file": ExecutionParameter(type=str),
+      "bq_table_name": ExecutionParameter(type=str),
+      "filter_threshold": ExecutionParameter(type=float),
+      "table_suffix": ExecutionParameter(type=str),
+      "table_partitioning": ExecutionParameter(type=bool),
+      "expiration_time_delta": ExecutionParameter(type=int),
+  }
+  INPUTS = {
+      # This will be a dictionary with input artifacts, including URIs
+      "transform_graph": ChannelParameter(type=standard_artifacts.TransformGraph),
+      "inference_results": ChannelParameter(type=standard_artifacts.InferenceResult),
+      "schema": ChannelParameter(type=standard_artifacts.Schema),
+  }
+  OUTPUTS = {
+      "bigquery_export": ChannelParameter(type=standard_artifacts.String),
+  }
+
+
+class AnnotateUnlabeledCategoryDataComponent(base_component.BaseComponent):
+  """
+  AnnotateUnlabeledCategoryData Component.
+
+  The component takes the following input artifacts:
+  * Inference results: InferenceResult
+  * Transform graph: TransformGraph
+  * Schema: Schema
+
+  The component takes the following parameters:
+  * vocab_label_file: str - The file name of the file containing the vocabulary labels
+    (produced by TFT).
+  * bq_table_name: str - The name of the BigQuery table to write the results to.
+  * filter_threshold: float - The minimum probability threshold for a prediction to
+    be considered a positive, trustworthy prediction. Default is 0.8.
+  * table_suffix: str (optional) - If provided, the generated datetime string will
+    be appended to the BigQuery table name as a suffix. The default is %Y%m%d.
+  * table_partitioning: bool - Whether to partition the table by DAY. If True,
+    the generated BigQuery table will be partitioned by date. If False, no
+    partitioning will be applied. Default is True.
+  * expiration_time_delta: int (optional) - The number of seconds after which the table will expire.
+
+  The component produces the following output artifacts:
+  * bigquery_export: String - The URI of the BigQuery table containing the results.
+  """
+
+  SPEC_CLASS = AnnotateUnlabeledCategoryDataComponentSpec
+  EXECUTOR_SPEC = executor_spec.BeamExecutorSpec(AnnotateUnlabeledCategoryDataExecutor)
+
+  def __init__(
+      self,
+      inference_results: types.Channel = None,
+      transform_graph: types.Channel = None,
+      schema: types.Channel = None,
+      bq_table_name: str = None,
+      vocab_label_file: str = _VOCAB_FILE,
+      filter_threshold: float = _MIN_THRESHOLD,
+      table_suffix: str = "%Y%m%d",
+      table_partitioning: bool = True,
+      expiration_time_delta: Optional[int] = 0,
+      bigquery_export: Optional[types.Channel] = None,
+  ):
+
+    bigquery_export = bigquery_export or types.Channel(type=standard_artifacts.String)
+
+    spec = AnnotateUnlabeledCategoryDataComponentSpec(
+        inference_results=inference_results,
+        transform_graph=transform_graph,
+        schema=schema,
+        bq_table_name=bq_table_name,
+        vocab_label_file=vocab_label_file,
+        filter_threshold=filter_threshold,
+        table_suffix=table_suffix,
+        table_partitioning=table_partitioning,
+        expiration_time_delta=expiration_time_delta,
+        bigquery_export=bigquery_export,
+    )
+    super().__init__(spec=spec)
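For orientation, a minimal sketch of how the component above could be wired into a TFX pipeline. The upstream handles (transform, schema_gen, bulk_inferrer) and the table name are assumed for illustration and are not part of this patch:

    from tfx_addons.predictions_to_biquery.component import \
        AnnotateUnlabeledCategoryDataComponent

    predictions_to_bq = AnnotateUnlabeledCategoryDataComponent(
        inference_results=bulk_inferrer.outputs["inference_result"],
        transform_graph=transform.outputs["transform_graph"],
        schema=schema_gen.outputs["schema"],
        bq_table_name="gcp_project:dataset.predictions",
        vocab_label_file="vocab_label_txt",
        filter_threshold=0.8,
    )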
diff --git a/tfx_addons/predictions_to_biquery/executor.py b/tfx_addons/predictions_to_biquery/executor.py
new file mode 100644
index 00000000..3199f1dd
--- /dev/null
+++ b/tfx_addons/predictions_to_biquery/executor.py
@@ -0,0 +1,172 @@
+"""
+Executor functionality to write prediction results, usually from a BulkInferrer, to BigQuery.
+"""
+
+import datetime
+import os
+from typing import Any, Dict, List
+
+import apache_beam as beam
+import numpy as np
+import tensorflow as tf
+import tensorflow_transform as tft
+from absl import logging
+from tensorflow.python.eager.context import eager_mode
+from tensorflow_serving.apis import prediction_log_pb2
+from tfx import types
+from tfx.dsl.components.base import base_beam_executor
+from tfx.types import artifact_utils
+
+from .utils import convert_single_value_to_native_py_value, create_annotation_fields, feature_to_bq_schema, load_schema
+
+_SCORE_MULTIPLIER = 1e6
+_SCHEMA_FILE = "schema.pbtxt"
+_ADDITIONAL_BQ_PARAMETERS = {}
+
+
+@beam.typehints.with_input_types(prediction_log_pb2.PredictionLog)
+@beam.typehints.with_output_types(Dict[str, Any])
+class FilterPredictionToDictFn(beam.DoFn):
+  """
+  Filter a prediction by the configured threshold and convert it to a dictionary (one BigQuery row).
+  """
+
+  def __init__(
+      self,
+      labels: List[str],
+      features: Any,
+      ts: datetime.datetime,
+      filter_threshold: float,
+      score_multiplier: float = _SCORE_MULTIPLIER,
+  ):
+    self.labels = labels
+    self.features = features
+    self.filter_threshold = filter_threshold
+    self.score_multiplier = score_multiplier
+    self.ts = ts
+
+  def _fix_types(self, example):
+    with eager_mode():
+      return [convert_single_value_to_native_py_value(v) for v in example.values()]
+
+  def _parse_prediction(self, predictions):
+    prediction_id = np.argmax(predictions)
+    logging.debug("Prediction id: %s", prediction_id)
+    logging.debug("Predictions: %s", predictions)
+    label = self.labels[prediction_id]
+    score = predictions[0][prediction_id]
+    return label, score
+
+  def process(self, element):
+    parsed_examples = tf.make_ndarray(element.predict_log.request.inputs["examples"])
+    parsed_predictions = tf.make_ndarray(element.predict_log.response.outputs["output_0"])
+
+    example_values = self._fix_types(tf.io.parse_single_example(parsed_examples[0], self.features))
+    label, score = self._parse_prediction(parsed_predictions)
+
+    if score > self.filter_threshold:
+      # @piero generate dict dynamically
+      yield {
+          # @piero set keys dynamically
+          "feature0": example_values[0],
+          "feature1": example_values[1],
+          "feature2": example_values[2],
+          "category_label": label,
+          "score": int(score * self.score_multiplier),
+          "datetime": self.ts,
+      }
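# --- Illustration (not part of the patch): what _parse_prediction and the
# threshold filter above do for one record, with invented values.
#
#   import numpy as np
#   predictions = np.array([[0.05, 0.90, 0.05]])  # shape (1, num_classes)
#   prediction_id = np.argmax(predictions)        # -> 1
#   score = predictions[0][prediction_id]         # -> 0.9
#
# With filter_threshold=0.8 the record is kept, and its score is written
# as int(0.9 * _SCORE_MULTIPLIER) = 900000 into the "score" column.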
+
+
+class Executor(base_beam_executor.BaseBeamExecutor):
+  """
+  Beam Executor for predictions_to_bq.
+  """
+
+  def Do(
+      self,
+      input_dict: Dict[str, List[types.Artifact]],
+      output_dict: Dict[str, List[types.Artifact]],
+      exec_properties: Dict[str, Any],
+  ) -> None:
+    """Do function for predictions_to_bq executor."""
+
+    ts = datetime.datetime.now().replace(second=0, microsecond=0)
+
+    # check required execution properties
+    if exec_properties["bq_table_name"] is None:
+      raise ValueError("bq_table_name must be set in exec_properties")
+    if exec_properties["filter_threshold"] is None:
+      raise ValueError("filter_threshold must be set in exec_properties")
+    if exec_properties["vocab_label_file"] is None:
+      raise ValueError("vocab_label_file must be set in exec_properties")
+
+    # get features from tfx schema
+    schema_uri = os.path.join(artifact_utils.get_single_uri(input_dict["schema"]), _SCHEMA_FILE)
+    features = tft.tf_metadata.schema_utils.schema_as_feature_spec(load_schema(schema_uri)).feature_spec
+
+    # get labels from tf transform generated vocab file
+    transform_output = artifact_utils.get_single_uri(input_dict["transform_graph"])
+    tf_transform_output = tft.TFTransformOutput(transform_output)
+    tft_vocab = tf_transform_output.vocabulary_by_name(vocab_filename=exec_properties["vocab_label_file"])
+    labels = [label.decode() for label in tft_vocab]
+    logging.info(f"found the following labels from TFT vocab: {labels}")
+
+    # get predictions from predict log
+    inference_results_uri = artifact_utils.get_single_uri(input_dict["inference_results"])
+
+    # set table prefix and partitioning parameters
+    bq_table_name = exec_properties["bq_table_name"]
+    if exec_properties["table_suffix"]:
+      bq_table_name += "_" + ts.strftime(exec_properties["table_suffix"])
+
+    if exec_properties["expiration_time_delta"]:
+      expiration_time = int(ts.timestamp()) + exec_properties["expiration_time_delta"]
+      _ADDITIONAL_BQ_PARAMETERS.update({"expirationTime": str(expiration_time)})
+      logging.info(f"expiration time on {bq_table_name} set to {expiration_time}")
+
+    if exec_properties["table_partitioning"]:
+      _ADDITIONAL_BQ_PARAMETERS.update({"timePartitioning": {"type": "DAY"}})
+      logging.info(f"time partitioning on {bq_table_name} set to DAY")
+
+    # set prediction result file path and decoder
+    prediction_log_path = f"{inference_results_uri}/*.gz"
+    prediction_log_decoder = beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)
+
+    # generate bigquery schema from tfx schema (features)
+    bq_schema_fields = feature_to_bq_schema(features, required=True)
+    bq_schema_fields.extend(
+        create_annotation_fields(
+            label_field_name="category_label", score_field_name="score", required=True, add_datetime_field=True
+        )
+    )
+    bq_schema = {"fields": bq_schema_fields}
+    logging.info(f"generated bq_schema: {bq_schema}")
+
+    with self._make_beam_pipeline() as pipeline:
+      _ = (
+          pipeline
+          | "Read Prediction Log" >> beam.io.ReadFromTFRecord(prediction_log_path, coder=prediction_log_decoder)
+          | "Filter and Convert to Dict"
+          >> beam.ParDo(
+              FilterPredictionToDictFn(
+                  labels=labels,
+                  features=features,
+                  ts=ts,
+                  filter_threshold=exec_properties["filter_threshold"],
+              )
+          )
+          | "Write Dict to BQ"
+          >> beam.io.gcp.bigquery.WriteToBigQuery(
+              table=bq_table_name,
+              schema=bq_schema,
+              additional_bq_parameters=_ADDITIONAL_BQ_PARAMETERS,
+              create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
+              write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE,
+          )
+      )
+
+    bigquery_export = artifact_utils.get_single_instance(output_dict["bigquery_export"])
+
+    bigquery_export.set_string_custom_property("generated_bq_table_name", bq_table_name)
+
+    logging.info(f"Annotated data exported to {bq_table_name}")
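To make the table naming and the extra BigQuery parameters concrete, a sketch of what Do() assembles with the default table_suffix="%Y%m%d", table_partitioning=True, and an assumed one-day expiration_time_delta:

    import datetime

    ts = datetime.datetime(2023, 2, 5, 15, 27)
    bq_table_name = "gcp_project:dataset.predictions"    # assumed value
    bq_table_name += "_" + ts.strftime("%Y%m%d")         # -> "..._20230205"

    additional_bq_parameters = {
        "expirationTime": str(int(ts.timestamp()) + 86400),  # expiration_time_delta=86400
        "timePartitioning": {"type": "DAY"},                 # table_partitioning=True
    }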
diff --git a/tfx_addons/predictions_to_biquery/test_component.py b/tfx_addons/predictions_to_biquery/test_component.py
new file mode 100644
index 00000000..b1697a87
--- /dev/null
+++ b/tfx_addons/predictions_to_biquery/test_component.py
@@ -0,0 +1,34 @@
+"""
+Tests around the Digits Prediction-to-BigQuery component.
+"""
+
+import tensorflow as tf
+from tfx.types import channel_utils, standard_artifacts
+
+from . import component
+
+
+class ComponentTest(tf.test.TestCase):
+  def setUp(self):
+    super(ComponentTest, self).setUp()
+    self._transform_graph = channel_utils.as_channel([standard_artifacts.TransformGraph()])
+    self._inference_results = channel_utils.as_channel([standard_artifacts.InferenceResult()])
+    self._schema = channel_utils.as_channel([standard_artifacts.Schema()])
+
+  def testConstruct(self):
+    # not a real test, just checking if the component can be
+    # instantiated
+    _ = component.AnnotateUnlabeledCategoryDataComponent(
+        transform_graph=self._transform_graph,
+        inference_results=self._inference_results,
+        schema=self._schema,
+        bq_table_name="gcp_project:bq_database.table",
+        vocab_label_file="vocab_txt",
+        filter_threshold=0.1,
+        table_suffix="%Y",
+        table_partitioning=False,
+    )
+
+
+if __name__ == "__main__":
+  tf.test.main()
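A possible follow-up assertion for the construction test above - a sketch that assumes the standard spec.exec_properties mapping exposed by TFX component specs:

    def testParameterWiring(self):
      instance = component.AnnotateUnlabeledCategoryDataComponent(
          transform_graph=self._transform_graph,
          inference_results=self._inference_results,
          schema=self._schema,
          bq_table_name="gcp_project:bq_database.table",
      )
      self.assertEqual(instance.spec.exec_properties["bq_table_name"],
                       "gcp_project:bq_database.table")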
diff --git a/tfx_addons/predictions_to_biquery/utils.py b/tfx_addons/predictions_to_biquery/utils.py
new file mode 100644
index 00000000..41c8e53c
--- /dev/null
+++ b/tfx_addons/predictions_to_biquery/utils.py
@@ -0,0 +1,170 @@
+"""
+Util functions for the Digits Prediction-to-BigQuery component.
+"""
+
+from typing import Any, Dict, List
+
+import numpy as np
+import tensorflow as tf
+from absl import logging
+from google.protobuf import text_format
+from tensorflow.python.lib.io import file_io
+from tensorflow_metadata.proto.v0 import schema_pb2
+
+
+def load_schema(input_path: str) -> schema_pb2.Schema:
+  """
+  Loads a TFX schema from a file and returns a schema object.
+
+  Args:
+    input_path: Path to the file containing the schema.
+
+  Returns:
+    A schema object.
+  """
+
+  schema = schema_pb2.Schema()
+  schema_text = file_io.read_file_to_string(input_path)
+  text_format.Parse(schema_text, schema)
+  return schema
+
+
+def convert_python_numpy_to_bq_type(python_type: Any) -> str:
+  """
+  Converts a Python or NumPy value to a BigQuery type string.
+
+  Args:
+    python_type: A Python or NumPy value.
+
+  Returns:
+    A BigQuery type string.
+  """
+  if isinstance(python_type, (int, np.int64)):
+    return "INTEGER"
+  elif isinstance(python_type, (float, np.float32, np.float64)):
+    return "FLOAT"
+  elif isinstance(python_type, (str, bytes)):
+    return "STRING"
+  elif isinstance(python_type, (bool, np.bool_)):
+    return "BOOLEAN"
+  else:
+    raise ValueError(f"Unsupported type: {python_type}")
+
+
+def convert_single_value_to_native_py_value(tensor: Any) -> Any:
+  """
+  Converts a single-value tensor to a native Python value.
+
+  Args:
+    tensor: A dense or sparse tensor holding a single value.
+
+  Returns:
+    The value cast to a native Python type.
+  """
+
+  if isinstance(tensor, tf.sparse.SparseTensor):
+    value = tensor.values.numpy()[0]
+    logging.debug(f"sparse value: {value}")
+  else:
+    value = tensor.numpy()[0]
+    logging.debug(f"dense value: {value}")
+
+  if isinstance(value, (int, np.int64, np.int32)):
+    return int(value)
+  elif isinstance(value, (float, np.float32, np.float64)):
+    return float(value)
+  elif isinstance(value, str):
+    return value
+  elif isinstance(value, bytes):
+    return value.decode("utf-8")
+  elif isinstance(value, (bool, np.bool_)):
+    return bool(value)
+  else:
+    raise ValueError(f"Unsupported value type: {value} of type {type(value)}")
+
+
+def convert_tensorflow_dtype_to_bq_type(tf_dtype: tf.dtypes.DType) -> str:
+  """
+  Converts a TensorFlow dtype to a BigQuery type string.
+
+  Args:
+    tf_dtype: A TensorFlow dtype.
+
+  Returns:
+    A BigQuery type string.
+  """
+  if tf_dtype in (tf.int32, tf.int64):
+    return "INTEGER"
+  elif tf_dtype in (tf.float32, tf.float64):
+    return "FLOAT"
+  elif tf_dtype == tf.string:
+    return "STRING"
+  elif tf_dtype == tf.bool:
+    return "BOOLEAN"
+  else:
+    raise ValueError(f"Unsupported type: {tf_dtype}")
+
+
+def feature_to_bq_schema(features: Dict[str, Any], required: bool = True) -> List[Dict]:
+  """
+  Converts a feature spec dict to a list of BigQuery schema fields.
+
+  Args:
+    features: A dict mapping feature names to feature definitions.
+    required: Whether the fields are required.
+
+  Returns:
+    A list of BigQuery schema fields.
+  """
+  return [
+      {
+          "name": feature_name,
+          "type": convert_tensorflow_dtype_to_bq_type(feature_def.dtype),
+          "mode": "REQUIRED" if required else "NULLABLE",
+      }
+      for feature_name, feature_def in features.items()
+  ]
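# --- Illustration (not part of the patch): feature_to_bq_schema applied
# to a minimal feature spec.
#
#   features = {"feature0": tf.io.FixedLenFeature((), tf.string)}
#   feature_to_bq_schema(features, required=True)
#   # -> [{"name": "feature0", "type": "STRING", "mode": "REQUIRED"}]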
+
+
+def create_annotation_fields(
+    label_field_name: str = "category_label",
+    score_field_name: str = "score",
+    required: bool = True,
+    add_datetime_field: bool = True,
+) -> List[Dict]:
+  """
+  Create a list of BigQuery schema fields for the annotation fields.
+
+  Args:
+    label_field_name: The name of the label field.
+    score_field_name: The name of the score field.
+    required: Whether the fields are required.
+    add_datetime_field: Whether to add a datetime field.
+
+  Returns:
+    A list of BigQuery schema fields.
+  """
+
+  label_field = {
+      "name": label_field_name,
+      "type": "STRING",
+      "mode": "REQUIRED" if required else "NULLABLE",
+  }
+
+  score_field = {
+      "name": score_field_name,
+      "type": "INTEGER",
+      "mode": "REQUIRED" if required else "NULLABLE",
+  }
+
+  fields = [label_field, score_field]
+
+  if add_datetime_field:
+    datetime_field = {
+        "name": "datetime",
+        "type": "TIMESTAMP",
+        "mode": "REQUIRED" if required else "NULLABLE",
+    }
+    fields.append(datetime_field)
+
+  return fields
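For reference, the executor combines the two helpers above into the final BigQuery schema. A sketch with a single assumed string feature and the default annotation field names:

    fields = feature_to_bq_schema(
        {"feature0": tf.io.FixedLenFeature((), tf.string)}, required=True)
    fields.extend(create_annotation_fields())
    bq_schema = {"fields": fields}
    # fields now contains:
    #   {"name": "feature0",       "type": "STRING",    "mode": "REQUIRED"}
    #   {"name": "category_label", "type": "STRING",    "mode": "REQUIRED"}
    #   {"name": "score",          "type": "INTEGER",   "mode": "REQUIRED"}
    #   {"name": "datetime",       "type": "TIMESTAMP", "mode": "REQUIRED"}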
From 5b3b2dc2fa79eb16d227f26cc165b9e9fbeaa991 Mon Sep 17 00:00:00 2001
From: Hannes Hapke
Date: Sun, 5 Feb 2023 15:27:01 -0800
Subject: [PATCH 2/4] dyn schema - wip

---
 tfx_addons/predictions_to_biquery/component.py |  6 +-
 tfx_addons/predictions_to_biquery/executor.py  | 17 ++++--
 tfx_addons/predictions_to_biquery/utils.py     | 58 ++++++++++++++++++-
 3 files changed, 72 insertions(+), 9 deletions(-)

diff --git a/tfx_addons/predictions_to_biquery/component.py b/tfx_addons/predictions_to_biquery/component.py
index 5ad290f4..20a76d3d 100644
--- a/tfx_addons/predictions_to_biquery/component.py
+++ b/tfx_addons/predictions_to_biquery/component.py
@@ -45,7 +45,8 @@ class AnnotateUnlabeledCategoryDataComponent(base_component.BaseComponent):
   The component takes the following input artifacts:
   * Inference results: InferenceResult
   * Transform graph: TransformGraph
-  * Schema: Schema
+  * Schema: Schema (optional) - if not present, the component will determine the schema
+    (only prediction results are supported at the moment)
 
   The component takes the following parameters:
   * vocab_label_file: str - The file name of the file containing the vocabulary labels
@@ -71,17 +72,18 @@ def __init__(
       self,
       inference_results: types.Channel = None,
       transform_graph: types.Channel = None,
-      schema: types.Channel = None,
       bq_table_name: str = None,
       vocab_label_file: str = _VOCAB_FILE,
       filter_threshold: float = _MIN_THRESHOLD,
       table_suffix: str = "%Y%m%d",
       table_partitioning: bool = True,
+      schema: Optional[types.Channel] = None,
       expiration_time_delta: Optional[int] = 0,
       bigquery_export: Optional[types.Channel] = None,
   ):
 
     bigquery_export = bigquery_export or types.Channel(type=standard_artifacts.String)
+    schema = schema or types.Channel(type=standard_artifacts.Schema)
 
     spec = AnnotateUnlabeledCategoryDataComponentSpec(
         inference_results=inference_results,
diff --git a/tfx_addons/predictions_to_biquery/executor.py b/tfx_addons/predictions_to_biquery/executor.py
index 3199f1dd..a94a5704 100644
--- a/tfx_addons/predictions_to_biquery/executor.py
+++ b/tfx_addons/predictions_to_biquery/executor.py
@@ -17,7 +17,9 @@
 from tfx.dsl.components.base import base_beam_executor
 from tfx.types import artifact_utils
 
-from .utils import convert_single_value_to_native_py_value, create_annotation_fields, feature_to_bq_schema, load_schema
+from .utils import (convert_single_value_to_native_py_value,
+                    create_annotation_fields, feature_to_bq_schema,
+                    load_schema, parse_schema)
 
 _SCORE_MULTIPLIER = 1e6
 _SCHEMA_FILE = "schema.pbtxt"
@@ -100,10 +102,6 @@
     if exec_properties["vocab_label_file"] is None:
      raise ValueError("vocab_label_file must be set in exec_properties")
 
-    # get features from tfx schema
-    schema_uri = os.path.join(artifact_utils.get_single_uri(input_dict["schema"]), _SCHEMA_FILE)
-    features = tft.tf_metadata.schema_utils.schema_as_feature_spec(load_schema(schema_uri)).feature_spec
-
     # get labels from tf transform generated vocab file
     transform_output = artifact_utils.get_single_uri(input_dict["transform_graph"])
     tf_transform_output = tft.TFTransformOutput(transform_output)
@@ -132,6 +130,15 @@
     prediction_log_path = f"{inference_results_uri}/*.gz"
     prediction_log_decoder = beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)
 
+    # get features from tfx schema if present
+    if input_dict["schema"]:
+      schema_uri = os.path.join(artifact_utils.get_single_uri(input_dict["schema"]), _SCHEMA_FILE)
+      features = load_schema(schema_uri)
+
+    # generate features from predictions
+    else:
+      features = parse_schema(prediction_log_path)
+
     # generate bigquery schema from tfx schema (features)
     bq_schema_fields = feature_to_bq_schema(features, required=True)
     bq_schema_fields.extend(
diff --git a/tfx_addons/predictions_to_biquery/utils.py b/tfx_addons/predictions_to_biquery/utils.py
index 41c8e53c..fa35c973 100644
--- a/tfx_addons/predictions_to_biquery/utils.py
+++ b/tfx_addons/predictions_to_biquery/utils.py
@@ -2,17 +2,19 @@
 Util functions for the Digits Prediction-to-BigQuery component.
 """
 
+import glob
 from typing import Any, Dict, List
 
 import numpy as np
 import tensorflow as tf
+import tensorflow_transform as tft
 from absl import logging
 from google.protobuf import text_format
 from tensorflow.python.lib.io import file_io
 from tensorflow_metadata.proto.v0 import schema_pb2
 
 
-def load_schema(input_path: str) -> schema_pb2.Schema:
+def load_schema(input_path: str) -> Dict:
   """
   Loads a TFX schema from a file and returns a schema object.
@@ -26,7 +28,59 @@
   schema = schema_pb2.Schema()
   schema_text = file_io.read_file_to_string(input_path)
   text_format.Parse(schema_text, schema)
-  return schema
+  return tft.tf_metadata.schema_utils.schema_as_feature_spec(schema).feature_spec
+
+
+def _get_compress_type(file_path):
+  magic_bytes = {
+      b'x\x01': 'ZLIB',
+      b'x^': 'ZLIB',
+      b'x\x9c': 'ZLIB',
+      b'x\xda': 'ZLIB',
+      b'\x1f\x8b': 'GZIP'}
+
+  two_bytes = open(file_path, 'rb').read(2)
+  return magic_bytes.get(two_bytes)
+
+
+def _get_feature_type(feature=None, type_=None):
+
+  if type_:
+    return {
+        int: tf.int64,
+        bool: tf.int64,
+        float: tf.float32,
+        str: tf.string,
+        bytes: tf.string,
+    }[type_]
+
+  if feature:
+    if feature.HasField('int64_list'):
+      return tf.int64
+    if feature.HasField('float_list'):
+      return tf.float32
+    if feature.HasField('bytes_list'):
+      return tf.string
+
+
+def parse_schema(prediction_log_path: str, compression_type: str = 'auto') -> Dict:
+
+  features = {}
+
+  file_paths = glob.glob(prediction_log_path)
+  if compression_type == 'auto':
+    compression_type = _get_compress_type(file_paths[0])
+
+  dataset = tf.data.TFRecordDataset(
+      file_paths, compression_type=compression_type)
+
+  serialized = next(iter(dataset.map(lambda serialized: serialized)))
+  seq_ex = tf.train.SequenceExample.FromString(serialized.numpy())
+
+  if seq_ex.feature_lists.feature_list:
+    raise NotImplementedError("FeatureLists aren't supported at the moment.")
+
+  for key, feature in seq_ex.context.feature.items():
+    features[key] = tf.io.FixedLenFeature(
+        (), _get_feature_type(feature=feature))
+  return features
 
 
 def convert_python_numpy_to_bq_type(python_type: Any) -> str:
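To illustrate the schema inference added in this patch: _get_feature_type maps tf.train.Feature kinds (or plain Python types) to TensorFlow dtypes, which parse_schema then wraps in FixedLenFeatures. A small sketch:

    import tensorflow as tf

    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[42]))
    feature.HasField("int64_list")  # -> True, so _get_feature_type returns tf.int64
    # Via the type_ argument, bool also maps to tf.int64, since tf.train.Example
    # stores booleans in int64_lists.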
From 544a2d5bb729af149f2c5bb30a7e6401b3b26e03 Mon Sep 17 00:00:00 2001
From: Hannes Hapke
Date: Sun, 5 Feb 2023 15:30:42 -0800
Subject: [PATCH 3/4] added comment

---
 tfx_addons/predictions_to_biquery/executor.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tfx_addons/predictions_to_biquery/executor.py b/tfx_addons/predictions_to_biquery/executor.py
index a94a5704..49b5d07c 100644
--- a/tfx_addons/predictions_to_biquery/executor.py
+++ b/tfx_addons/predictions_to_biquery/executor.py
@@ -67,9 +67,8 @@ def process(self, element):
     label, score = self._parse_prediction(parsed_predictions)
 
     if score > self.filter_threshold:
-      # @piero generate dict dynamically
       yield {
-          # @piero set keys dynamically
+          # TODO: features should be read dynamically
           "feature0": example_values[0],
          "feature1": example_values[1],
           "feature2": example_values[2],

From f3ae48c2054d9e5d1ab3017d5216c93c5a629dee Mon Sep 17 00:00:00 2001
From: Hannes Hapke
Date: Sun, 5 Feb 2023 15:30:53 -0800
Subject: [PATCH 4/4] added code owner

---
 CODEOWNERS | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/CODEOWNERS b/CODEOWNERS
index 736fdc81..4400a871 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -34,7 +34,7 @@
 /tfx_addons/sampling @kindalime @cent5 @casassg
 
 # Feast ExampleGen Component
-/tfx_addons/feast_examplegen @BACtaki @casassg @wihanbooyse 
+/tfx_addons/feast_examplegen @BACtaki @casassg @wihanbooyse
 /examples/fraud_feast @BACtaki @casassg @wihanbooyse
 
 # Feature Selection Component
@@ -46,3 +46,5 @@
 
 # Message Exit Handler
 /tfx_addons/message_exit_handler @hanneshapke
+
+# Predictions to BigQuery Component
+/tfx_addons/predictions_to_bigquery @hanneshapke
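Once the pipeline has run, a downstream step can read the generated table name back from the output artifact. A sketch, assuming bigquery_export is the resolved output artifact (get_string_custom_property is the matching TFX Artifact accessor):

    table_name = bigquery_export.get_string_custom_property("generated_bq_table_name")
    # e.g. "gcp_project:dataset.predictions_20230205", depending on table_suffix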