From 441f63f7779d3e97c267b8b53fa33e025f2f2f37 Mon Sep 17 00:00:00 2001 From: bharatjetti Date: Tue, 8 Oct 2024 05:37:34 +0000 Subject: [PATCH] lint changes in few files --- official/nlp/modeling/layers/transformer_encoder_block.py | 5 +++-- .../uplift/layers/uplift_networks/base_uplift_networks.py | 3 +-- official/recommendation/uplift/metrics/label_mean.py | 3 ++- official/recommendation/uplift/metrics/label_variance.py | 2 +- official/recommendation/uplift/metrics/metric_configs.py | 2 +- official/recommendation/uplift/metrics/sliced_metric.py | 3 +-- official/recommendation/uplift/metrics/uplift_mean.py | 3 ++- official/recommendation/uplift/types.py | 3 +-- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/official/nlp/modeling/layers/transformer_encoder_block.py b/official/nlp/modeling/layers/transformer_encoder_block.py index 86f497b486..a8abc953df 100644 --- a/official/nlp/modeling/layers/transformer_encoder_block.py +++ b/official/nlp/modeling/layers/transformer_encoder_block.py @@ -28,7 +28,7 @@ class RMSNorm(tf_keras.layers.Layer): def __init__( self, - axis: Union[int , Sequence[int]] = -1, + axis: Union[int, Sequence[int]] = -1, epsilon: float = 1e-6, **kwargs ): @@ -43,7 +43,8 @@ def __init__( self.axis = [axis] if isinstance(axis, int) else axis self.epsilon = epsilon - def build(self, input_shape: Union[tf.TensorShape, Sequence[Union[int, None]]]): + def build(self, + input_shape: Union[tf.TensorShape, Sequence[Union[int, None]]]): input_shape = tf.TensorShape(input_shape) scale_shape = [1] * input_shape.rank for dim in self.axis: diff --git a/official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py b/official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py index 64552ab982..907ab663f4 100644 --- a/official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py +++ b/official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py @@ -15,13 +15,12 @@ """Defines base abstract uplift network layers.""" import abc +from typing import Union import tensorflow as tf, tf_keras from official.recommendation.uplift import types -from typing import Union - class BaseTwoTowerUpliftNetwork(tf_keras.layers.Layer, metaclass=abc.ABCMeta): """Abstract class for uplift layers that compute control and treatment logits. diff --git a/official/recommendation/uplift/metrics/label_mean.py b/official/recommendation/uplift/metrics/label_mean.py index 31b48bde31..aacaa0ef96 100644 --- a/official/recommendation/uplift/metrics/label_mean.py +++ b/official/recommendation/uplift/metrics/label_mean.py @@ -14,12 +14,13 @@ """Keras metric for computing the label mean sliced by treatment group.""" +from typing import Union + import tensorflow as tf, tf_keras from official.recommendation.uplift import types from official.recommendation.uplift.metrics import treatment_sliced_metric -from typing import Union @tf_keras.utils.register_keras_serializable(package="Uplift") class LabelMean(tf_keras.metrics.Metric): diff --git a/official/recommendation/uplift/metrics/label_variance.py b/official/recommendation/uplift/metrics/label_variance.py index fd6e78ca62..e9db2f45c0 100644 --- a/official/recommendation/uplift/metrics/label_variance.py +++ b/official/recommendation/uplift/metrics/label_variance.py @@ -13,6 +13,7 @@ # limitations under the License. """Keras metric for computing the label variance sliced by treatment group.""" +from typing import Union import tensorflow as tf, tf_keras @@ -20,7 +21,6 @@ from official.recommendation.uplift.metrics import treatment_sliced_metric from official.recommendation.uplift.metrics import variance -from typing import Union @tf_keras.utils.register_keras_serializable(package="Uplift") class LabelVariance(tf_keras.metrics.Metric): diff --git a/official/recommendation/uplift/metrics/metric_configs.py b/official/recommendation/uplift/metrics/metric_configs.py index e07cda4018..0487ccc04d 100644 --- a/official/recommendation/uplift/metrics/metric_configs.py +++ b/official/recommendation/uplift/metrics/metric_configs.py @@ -35,7 +35,7 @@ class SlicedMetricConfig(base_config.Config): slicing_feature: Union[str, None] = None slicing_spec: Union[Mapping[str, int], None] = None - slicing_feature_dtype: Union[str, None ]= None + slicing_feature_dtype: Union[str, None] = None def __post_init__( self, default_params: dict[str, Any], restrictions: list[str] diff --git a/official/recommendation/uplift/metrics/sliced_metric.py b/official/recommendation/uplift/metrics/sliced_metric.py index ef5152e2a4..ac0bf7dd5e 100644 --- a/official/recommendation/uplift/metrics/sliced_metric.py +++ b/official/recommendation/uplift/metrics/sliced_metric.py @@ -15,11 +15,10 @@ """Keras metric for reporting metrics sliced by a feature.""" import copy +from typing import Union import tensorflow as tf, tf_keras -from typing import Union - class SlicedMetric(tf_keras.metrics.Metric): """A metric sliced by integer, boolean, or string features. diff --git a/official/recommendation/uplift/metrics/uplift_mean.py b/official/recommendation/uplift/metrics/uplift_mean.py index 31d1e860eb..ee94f91d68 100644 --- a/official/recommendation/uplift/metrics/uplift_mean.py +++ b/official/recommendation/uplift/metrics/uplift_mean.py @@ -14,12 +14,13 @@ """Keras metric for computing the mean uplift sliced by treatment group.""" +from typing import Union + import tensorflow as tf, tf_keras from official.recommendation.uplift import types from official.recommendation.uplift.metrics import treatment_sliced_metric -from typing import Union @tf_keras.utils.register_keras_serializable(package="Uplift") class UpliftMean(tf_keras.metrics.Metric): diff --git a/official/recommendation/uplift/types.py b/official/recommendation/uplift/types.py index 3d21047ee1..95f316fc6e 100644 --- a/official/recommendation/uplift/types.py +++ b/official/recommendation/uplift/types.py @@ -13,12 +13,11 @@ # limitations under the License. """Defines types used by the keras uplift modeling library.""" +from typing import Union import tensorflow as tf, tf_keras -from typing import Union TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor] - ListOfTensors = list[TensorType] TupleOfTensors = tuple[TensorType, ...] DictOfTensors = dict[str, TensorType]