diff --git a/paddlers/utils/stats.py b/paddlers/utils/stats.py index 33519c75..214b5f19 100644 --- a/paddlers/utils/stats.py +++ b/paddlers/utils/stats.py @@ -49,7 +49,12 @@ def update(self, stats): for k in stats.keys() } for k, v in self.meters.items(): - v.update(stats[k].numpy()) + stat = stats[k] + if stat.ndim == 0: + stat = float(stat) + else: + stat = stat.numpy() + v.update(stat) def get(self, extras=None): stats = collections.OrderedDict()