From 4e9da731a65643bf4c8fe9e06a8635c8882a2476 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Thu, 13 Jul 2023 00:57:31 -0700 Subject: [PATCH] Clu metrics + `__future__.annotations` PiperOrigin-RevId: 547715759 --- clu/metrics.py | 15 +++++++++++---- clu/metrics_test.py | 2 ++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/clu/metrics.py b/clu/metrics.py index a2faa1c..27fcb70 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -57,6 +57,7 @@ def evaluate(model, p_variables, test_ds): """ from __future__ import annotations from collections.abc import Mapping, Sequence +import inspect from typing import Any, TypeVar, Protocol from absl import logging @@ -534,8 +535,11 @@ def empty(cls: type[C]) -> C: _reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)), **{ metric_name: metric.empty() - for metric_name, metric in cls.__annotations__.items() - }) + for metric_name, metric in inspect.get_annotations( + cls, eval_str=True + ).items() + }, + ) @classmethod def _from_model_output(cls: type[C], **kwargs) -> C: @@ -544,8 +548,11 @@ def _from_model_output(cls: type[C], **kwargs) -> C: _reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)), **{ metric_name: metric.from_model_output(**kwargs) - for metric_name, metric in cls.__annotations__.items() - }) + for metric_name, metric in inspect.get_annotations( + cls, eval_str=True + ).items() + }, + ) @classmethod def single_from_model_output(cls: type[C], **kwargs) -> C: diff --git a/clu/metrics_test.py b/clu/metrics_test.py index 0c5a583..543f45c 100644 --- a/clu/metrics_test.py +++ b/clu/metrics_test.py @@ -14,6 +14,8 @@ """Tests for clu.metrics.""" +from __future__ import annotations + import functools from unittest import mock