Prediction to bigquery component - initial code (#210)
* initial predictions-to-bigquery component
* dyn schema - wip
* added comment
* added code owner

Co-authored-by: Hannes Hapke <[email protected]>
1 parent 5b87707 · commit 4862345 · 6 changed files with 539 additions and 0 deletions.
Empty file (likely the package __init__.py).
component.py - the component definition (new file, +100 lines):
""" | ||
Digits Prediction-to-BigQuery: Functionality to write prediction results usually from a BulkInferrer to BigQuery. | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from tfx import types | ||
from tfx.dsl.components.base import base_component, executor_spec | ||
from tfx.types import standard_artifacts | ||
from tfx.types.component_spec import ChannelParameter, ExecutionParameter | ||
|
||
from .executor import Executor as AnnotateUnlabeledCategoryDataExecutor | ||
|
||
_MIN_THRESHOLD = 0.8 | ||
_VOCAB_FILE = "vocab_label_txt" | ||
|
||
|
||
class AnnotateUnlabeledCategoryDataComponentSpec(types.ComponentSpec): | ||
|
||
PARAMETERS = { | ||
# These are parameters that will be passed in the call to | ||
# create an instance of this component. | ||
"vocab_label_file": ExecutionParameter(type=str), | ||
"bq_table_name": ExecutionParameter(type=str), | ||
"filter_threshold": ExecutionParameter(type=float), | ||
"table_suffix": ExecutionParameter(type=str), | ||
"table_partitioning": ExecutionParameter(type=bool), | ||
"expiration_time_delta": ExecutionParameter(type=int), | ||
} | ||
INPUTS = { | ||
# This will be a dictionary with input artifacts, including URIs | ||
"transform_graph": ChannelParameter(type=standard_artifacts.TransformGraph), | ||
"inference_results": ChannelParameter(type=standard_artifacts.InferenceResult), | ||
"schema": ChannelParameter(type=standard_artifacts.Schema), | ||
} | ||
OUTPUTS = { | ||
"bigquery_export": ChannelParameter(type=standard_artifacts.String), | ||
} | ||
|
||
|
||
class AnnotateUnlabeledCategoryDataComponent(base_component.BaseComponent): | ||
""" | ||
AnnotateUnlabeledCategoryData Component. | ||
The component takes the following input artifacts: | ||
* Inference results: InferenceResult | ||
* Transform graph: TransformGraph | ||
* Schema: Schema (optional) if not present, the component will determine the schema | ||
(only predtion supported at the moment) | ||
The component takes the following parameters: | ||
* vocab_label_file: str - The file name of the file containing the vocabulary labels | ||
(produced by TFT). | ||
* bq_table_name: str - The name of the BigQuery table to write the results to. | ||
* filter_threshold: float - The minimum probability threshold for a prediction to | ||
be considered a positive, thrustworthy prediction. Default is 0.8. | ||
* table_suffix: str (optional) - If provided, the generated datetime string will | ||
be added the BigQuery table name as suffix. The default is %Y%m%d. | ||
* table_partitioning: bool - Whether to partition the table by DAY. If True, | ||
the generated BigQuery table will be partition by date. If False, no partitioning will | ||
be applied. Default is True. | ||
* expiration_time_delta: int (optional) - The number of seconds after which the table will expire. | ||
The component produces the following output artifacts: | ||
* bigquery_export: String - The URI of the BigQuery table containing the results. | ||
""" | ||
|
||
SPEC_CLASS = AnnotateUnlabeledCategoryDataComponentSpec | ||
EXECUTOR_SPEC = executor_spec.BeamExecutorSpec(AnnotateUnlabeledCategoryDataExecutor) | ||
|
||
def __init__( | ||
self, | ||
inference_results: types.Channel = None, | ||
transform_graph: types.Channel = None, | ||
bq_table_name: str = None, | ||
vocab_label_file: str = _VOCAB_FILE, | ||
filter_threshold: float = _MIN_THRESHOLD, | ||
table_suffix: str = "%Y%m%d", | ||
table_partitioning: bool = True, | ||
schema: Optional[types.Channel] = None, | ||
expiration_time_delta: Optional[int] = 0, | ||
bigquery_export: Optional[types.Channel] = None, | ||
): | ||
|
||
bigquery_export = bigquery_export or types.Channel(type=standard_artifacts.String) | ||
schema = schema or types.Channel(type=standard_artifacts.Schema()) | ||
|
||
spec = AnnotateUnlabeledCategoryDataComponentSpec( | ||
inference_results=inference_results, | ||
transform_graph=transform_graph, | ||
schema=schema, | ||
bq_table_name=bq_table_name, | ||
vocab_label_file=vocab_label_file, | ||
filter_threshold=filter_threshold, | ||
table_suffix=table_suffix, | ||
table_partitioning=table_partitioning, | ||
expiration_time_delta=expiration_time_delta, | ||
bigquery_export=bigquery_export, | ||
) | ||
super().__init__(spec=spec) |
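Usage sketch (not part of this commit): wiring the component into a TFX pipeline. The `transform`, `schema_gen`, and `bulk_inferrer` component names and the table name below are assumptions for illustration; the output keys are the standard ones for those TFX components.

    # Hypothetical wiring - assumes upstream Transform, SchemaGen, and
    # BulkInferrer components named `transform`, `schema_gen`, `bulk_inferrer`.
    annotate = AnnotateUnlabeledCategoryDataComponent(
        inference_results=bulk_inferrer.outputs["inference_result"],
        transform_graph=transform.outputs["transform_graph"],
        schema=schema_gen.outputs["schema"],
        bq_table_name="my_project:my_dataset.predictions",  # placeholder
        filter_threshold=0.9,
    )

The remaining parameters fall back to their defaults: the vocab_label_txt vocabulary file, a %Y%m%d table suffix, and DAY partitioning.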
executor.py - the Beam executor (new file, +178 lines):
""" | ||
Executor functionality to write prediction results usually from a BulkInferrer to BigQuery. | ||
""" | ||
|
||
import datetime | ||
import os | ||
from typing import Any, Dict, List, Tuple | ||
|
||
import apache_beam as beam | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_transform as tft | ||
from absl import logging | ||
from tensorflow.python.eager.context import eager_mode | ||
from tensorflow_serving.apis import prediction_log_pb2 | ||
from tfx import types | ||
from tfx.dsl.components.base import base_beam_executor | ||
from tfx.types import artifact_utils | ||
|
||
from .utils import (convert_single_value_to_native_py_value, | ||
create_annotation_fields, feature_to_bq_schema, | ||
load_schema, parse_schema) | ||
|
||
_SCORE_MULTIPLIER = 1e6 | ||
_SCHEMA_FILE = "schema.pbtxt" | ||
_ADDITIONAL_BQ_PARAMETERS = {} | ||
|
||
|
||
@beam.typehints.with_input_types(str) | ||
@beam.typehints.with_output_types(beam.typehints.Iterable[Tuple[str, str, Any]]) | ||
class FilterPredictionToDictFn(beam.DoFn): | ||
""" | ||
Convert a prediction to a dictionary. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
labels: List, | ||
features: Any, | ||
ts: datetime.datetime, | ||
filter_threshold: float, | ||
score_multiplier: int = _SCORE_MULTIPLIER, | ||
): | ||
self.labels = labels | ||
self.features = features | ||
self.filter_threshold = filter_threshold | ||
self.score_multiplier = score_multiplier | ||
self.ts = ts | ||
|
||
def _fix_types(self, example): | ||
with eager_mode(): | ||
return [convert_single_value_to_native_py_value(v) for v in example.values()] | ||
|
||
def _parse_prediction(self, predictions): | ||
prediction_id = np.argmax(predictions) | ||
logging.debug("Prediction id: %s", prediction_id) | ||
logging.debug("Predictions: %s", predictions) | ||
label = self.labels[prediction_id] | ||
score = predictions[0][prediction_id] | ||
return label, score | ||
|
||
def process(self, element): | ||
parsed_examples = tf.make_ndarray(element.predict_log.request.inputs["examples"]) | ||
parsed_predictions = tf.make_ndarray(element.predict_log.response.outputs["output_0"]) | ||
|
||
example_values = self._fix_types(tf.io.parse_single_example(parsed_examples[0], self.features)) | ||
label, score = self._parse_prediction(parsed_predictions) | ||
|
||
if score > self.filter_threshold: | ||
yield { | ||
# TODO: features should be read dynamically | ||
"feature0": example_values[0], | ||
"feature1": example_values[1], | ||
"feature2": example_values[2], | ||
"category_label": label, | ||
"score": int(score * self.score_multiplier), | ||
"datetime": self.ts, | ||
} | ||
|
||
|
||
class Executor(base_beam_executor.BaseBeamExecutor): | ||
""" | ||
Beam Executor for predictions_to_bq. | ||
""" | ||
|
||
def Do( | ||
self, | ||
input_dict: Dict[str, List[types.Artifact]], | ||
output_dict: Dict[str, List[types.Artifact]], | ||
exec_properties: Dict[str, Any], | ||
) -> None: | ||
"""Do function for predictions_to_bq executor.""" | ||
|
||
ts = datetime.datetime.now().replace(second=0, microsecond=0) | ||
|
||
# check required executive properties | ||
if exec_properties["bq_table_name"] is None: | ||
raise ValueError("bq_table_name must be set in exec_properties") | ||
if exec_properties["filter_threshold"] is None: | ||
raise ValueError("filter_threshold must be set in exec_properties") | ||
if exec_properties["vocab_label_file"] is None: | ||
raise ValueError("vocab_label_file must be set in exec_properties") | ||
|
||
# get labels from tf transform generated vocab file | ||
transform_output = artifact_utils.get_single_uri(input_dict["transform_graph"]) | ||
tf_transform_output = tft.TFTransformOutput(transform_output) | ||
tft_vocab = tf_transform_output.vocabulary_by_name(vocab_filename=exec_properties["vocab_label_file"]) | ||
labels = [label.decode() for label in tft_vocab] | ||
logging.info(f"found the following labels from TFT vocab: {labels}") | ||
|
||
# get predictions from predict log | ||
inference_results_uri = artifact_utils.get_single_uri(input_dict["inference_results"]) | ||
|
||
# set table prefix and partitioning parameters | ||
bq_table_name = exec_properties["bq_table_name"] | ||
if exec_properties["table_suffix"]: | ||
bq_table_name += "_" + ts.strftime(exec_properties["table_suffix"]) | ||
|
||
if exec_properties["expiration_time_delta"]: | ||
expiration_time = int(ts.timestamp()) + exec_properties["expiration_time_delta"] | ||
_ADDITIONAL_BQ_PARAMETERS.update({"expirationTime": str(expiration_time)}) | ||
logging.info(f"expiration time on {bq_table_name} set to {expiration_time}") | ||
|
||
if exec_properties["table_partitioning"]: | ||
_ADDITIONAL_BQ_PARAMETERS.update({"timePartitioning": {"type": "DAY"}}) | ||
logging.info(f"time partitioning on {bq_table_name} set to DAY") | ||
|
||
# set prediction result file path and decoder | ||
prediction_log_path = f"{inference_results_uri}/*.gz" | ||
prediction_log_decoder = beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog) | ||
|
||
# get features from tfx schema if present | ||
if input_dict["schema"]: | ||
schema_uri = os.path.join(artifact_utils.get_single_uri(input_dict["schema"]), _SCHEMA_FILE) | ||
features = load_schema(schema_uri) | ||
|
||
# generate features from predictions | ||
else: | ||
features = parse_schema(prediction_log_path) | ||
|
||
# generate bigquery schema from tfx schema (features) | ||
bq_schema_fields = feature_to_bq_schema(features, required=True) | ||
bq_schema_fields.extend( | ||
create_annotation_fields( | ||
label_field_name="category_label", score_field_name="score", required=True, add_datetime_field=True | ||
) | ||
) | ||
bq_schema = {"fields": bq_schema_fields} | ||
logging.info(f"generated bq_schema: {bq_schema}") | ||
|
||
with self._make_beam_pipeline() as pipeline: | ||
_ = ( | ||
pipeline | ||
| "Read Prediction Log" >> beam.io.ReadFromTFRecord(prediction_log_path, coder=prediction_log_decoder) | ||
| "Filter and Convert to Dict" | ||
>> beam.ParDo( | ||
FilterPredictionToDictFn( | ||
labels=labels, | ||
features=features, | ||
ts=ts, | ||
filter_threshold=exec_properties["filter_threshold"], | ||
) | ||
) | ||
| "Write Dict to BQ" | ||
>> beam.io.gcp.bigquery.WriteToBigQuery( | ||
table=bq_table_name, | ||
schema=bq_schema, | ||
additional_bq_parameters=_ADDITIONAL_BQ_PARAMETERS, | ||
create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, | ||
write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE, | ||
) | ||
) | ||
|
||
bigquery_export = artifact_utils.get_single_instance(output_dict["bigquery_export"]) | ||
|
||
bigquery_export.set_string_custom_property("generated_bq_table_name", bq_table_name) | ||
|
||
logging.info(f"Annotated data exported to {bq_table_name}") |
The component test (new file, +34 lines):
""" | ||
Tests around Digits Prediction-to-BigQuery component. | ||
""" | ||
|
||
import tensorflow as tf | ||
from tfx.types import channel_utils, standard_artifacts | ||
|
||
from . import component | ||
|
||
|
||
class ComponentTest(tf.test.TestCase): | ||
def setUp(self): | ||
super(ComponentTest, self).setUp() | ||
self._transform_graph = channel_utils.as_channel([standard_artifacts.TransformGraph()]) | ||
self._inference_results = channel_utils.as_channel([standard_artifacts.InferenceResult()]) | ||
self._schema = channel_utils.as_channel([standard_artifacts.Schema()]) | ||
|
||
def testConstruct(self): | ||
# not a real test, just checking if if the component can be | ||
# instantiated | ||
_ = component.AnnotateUnlabeledCategoryDataComponent( | ||
transform_graph=self._transform_graph, | ||
inference_results=self._inference_results, | ||
schema=self._schema, | ||
bq_table_name="gcp_project:bq_database.table", | ||
vocab_label_file="vocab_txt", | ||
filter_threshold=0.1, | ||
table_suffix="%Y", | ||
table_partitioning=False, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
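Not part of this commit, but a natural extension: a small unit test of the executor's label/score parsing, assuming the FilterPredictionToDictFn defined in executor.py above.

    import datetime

    import numpy as np
    import tensorflow as tf

    from . import executor


    class FilterPredictionToDictFnTest(tf.test.TestCase):
        def testParsePrediction(self):
            fn = executor.FilterPredictionToDictFn(
                labels=["cat", "dog"],
                features={},  # not needed for _parse_prediction
                ts=datetime.datetime(2023, 1, 1),
                filter_threshold=0.5,
            )
            # argmax of [[0.2, 0.8]] is index 1 -> label "dog", score 0.8
            label, score = fn._parse_prediction(np.array([[0.2, 0.8]]))
            self.assertEqual(label, "dog")
            self.assertAllClose(score, 0.8)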