Skip to content

Commit

Permalink
Add a workaround for the removal of top-level TFMA attributes (#7637)
Browse files Browse the repository at this point in the history
* Add a workaround for the removal of top-level TFMA attributes in TFMA 0.47.0
  • Loading branch information
nikelite authored Nov 22, 2024
1 parent a3aa157 commit 58fa4a8
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
12 changes: 9 additions & 3 deletions tfx/components/model_validator/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
from tfx.utils import io_utils
from tfx.utils import path_utils

# TFMA 0.47.0 removed some top-level re-exports; resolve EvalResult from the
# package when available, otherwise fall back to its defining module.
if hasattr(tfma, 'EvalResult'):
    _EvalResult = tfma.EvalResult
else:
    from tensorflow_model_analysis.view.view_types import EvalResult as _EvalResult

class Executor(base_beam_executor.BaseBeamExecutor):
"""DEPRECATED: Please use `Evaluator` instead.
Expand All @@ -51,13 +57,13 @@ class Executor(base_beam_executor.BaseBeamExecutor):
"""

# TODO(jyzhao): customized threshold support.
def _pass_threshold(self, eval_result: tfma.EvalResult) -> bool:
def _pass_threshold(self, eval_result: _EvalResult) -> bool:
    """Check whether the evaluation result passes the blessing threshold.

    Currently a stub that always passes; customized threshold support is
    tracked in the TODO above.

    Args:
      eval_result: TFMA evaluation result for the model under validation.

    Returns:
      Always True (no threshold is applied yet).
    """
    return True

# TODO(jyzhao): customized validation support.
def _compare_eval_result(self, current_model_eval_result: tfma.EvalResult,
blessed_model_eval_result: tfma.EvalResult) -> bool:
def _compare_eval_result(self, current_model_eval_result: _EvalResult,
blessed_model_eval_result: _EvalResult) -> bool:
"""Compare accuracy of all metrics and return true if current is better or equal."""
for current_metric, blessed_metric in zip(
current_model_eval_result.slicing_metrics,
Expand Down
19 changes: 17 additions & 2 deletions tfx/components/testdata/module_file/evaluator_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,24 @@
from tfx_bsl.tfxio import tensor_adapter


# TFMA 0.47.0 removed some top-level re-exports. Resolve each name from the
# package when it is still exposed there, otherwise import it from the module
# that defines it. hasattr() suppresses exactly AttributeError, so this is
# equivalent to the try/except-AttributeError pattern.
if hasattr(tfma, 'EvalSharedModel'):
    _EvalSharedModel = tfma.EvalSharedModel
else:
    from tensorflow_model_analysis.api.types import EvalSharedModel as _EvalSharedModel

if hasattr(tfma, 'MaybeMultipleEvalSharedModels'):
    _MaybeMultipleEvalSharedModels = tfma.MaybeMultipleEvalSharedModels
else:
    from tensorflow_model_analysis.api.types import (
        MaybeMultipleEvalSharedModels as _MaybeMultipleEvalSharedModels,
    )


def custom_eval_shared_model(eval_saved_model_path: str, model_name: str,
eval_config: tfma.EvalConfig,
**kwargs: Dict[str, Any]) -> tfma.EvalSharedModel:
**kwargs: Dict[str, Any]) -> _EvalSharedModel:
return tfma.default_eval_shared_model(
eval_saved_model_path=eval_saved_model_path,
model_name=model_name,
Expand All @@ -30,7 +45,7 @@ def custom_eval_shared_model(eval_saved_model_path: str, model_name: str,


def custom_extractors(
eval_shared_model: tfma.MaybeMultipleEvalSharedModels,
eval_shared_model: _MaybeMultipleEvalSharedModels,
eval_config: tfma.EvalConfig,
tensor_adapter_config: tensor_adapter.TensorAdapterConfig,
) -> List[tfma.extractors.Extractor]:
Expand Down
15 changes: 11 additions & 4 deletions tfx/examples/penguin/experimental/sklearn_predict_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@

_PREDICT_EXTRACTOR_STAGE_NAME = 'SklearnPredict'

# TFMA 0.47.0 removed some top-level re-exports; prefer the package attribute
# when present, otherwise import EvalSharedModel from its defining module.
if hasattr(tfma, 'EvalSharedModel'):
    _EvalSharedModel = tfma.EvalSharedModel
else:
    from tensorflow_model_analysis.api.types import EvalSharedModel as _EvalSharedModel


def _make_sklearn_predict_extractor(
eval_shared_model: tfma.EvalSharedModel,) -> tfma.extractors.Extractor:
eval_shared_model: _EvalSharedModel,) -> tfma.extractors.Extractor:
"""Creates an extractor for performing predictions using a scikit-learn model.
The extractor's PTransform loads and runs the serving pickle against
Expand All @@ -54,7 +61,7 @@ def _make_sklearn_predict_extractor(
class _TFMAPredictionDoFn(tfma.utils.DoFnWithModels):
"""A DoFn that loads the models and predicts."""

def __init__(self, eval_shared_models: Dict[str, tfma.EvalSharedModel]):
def __init__(self, eval_shared_models: Dict[str, _EvalSharedModel]):
super().__init__({k: v.model_loader for k, v in eval_shared_models.items()})

def setup(self):
Expand Down Expand Up @@ -116,7 +123,7 @@ def process(self, elem: tfma.Extracts) -> Iterable[tfma.Extracts]:
@beam.typehints.with_output_types(tfma.Extracts)
def _ExtractPredictions( # pylint: disable=invalid-name
extracts: beam.pvalue.PCollection,
eval_shared_models: Dict[str, tfma.EvalSharedModel],
eval_shared_models: Dict[str, _EvalSharedModel],
) -> beam.pvalue.PCollection:
"""A PTransform that adds predictions and possibly other tensors to extracts.
Expand All @@ -139,7 +146,7 @@ def _custom_model_loader_fn(model_path: str):
# TFX Evaluator will call the following functions.
def custom_eval_shared_model(
eval_saved_model_path, model_name, eval_config,
**kwargs) -> tfma.EvalSharedModel:
**kwargs) -> _EvalSharedModel:
"""Returns a single custom EvalSharedModel."""
model_path = os.path.join(eval_saved_model_path, 'model.pkl')
return tfma.default_eval_shared_model(
Expand Down
10 changes: 9 additions & 1 deletion tfx/experimental/pipeline_testing/executor_verifier_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@
from tensorflow_metadata.proto.v0 import anomalies_pb2


# TFMA 0.47.0 removed some top-level re-exports; use the package-level
# EvalResult when it exists, falling back to the view_types module otherwise.
if hasattr(tfma, 'EvalResult'):
    _EvalResult = tfma.EvalResult
else:
    from tensorflow_model_analysis.view.view_types import EvalResult as _EvalResult


def compare_dirs(dir1: str, dir2: str):
"""Recursively compares contents of the two directories.
Expand Down Expand Up @@ -159,7 +167,7 @@ def verify_file_dir(output_uri: str,


def _group_metric_by_slice(
eval_result: tfma.EvalResult) -> Dict[str, Dict[str, float]]:
eval_result: _EvalResult) -> Dict[str, Dict[str, float]]:
"""Returns a dictionary holding metric values for every slice.
Args:
Expand Down

0 comments on commit 58fa4a8

Please sign in to comment.