diff --git a/co/mytorch.py b/co/mytorch.py index 19986e0..e3dee8a 100644 --- a/co/mytorch.py +++ b/co/mytorch.py @@ -542,7 +542,10 @@ def get_eval_data_loader(self, dset): def format_err_str(self, errs, div=1): err_list = [] for v in errs.values(): - if isinstance(v, (list, np.ndarray)): + if isinstance(v, np.ndarray): + err_list.extend(v.ravel()) + elif isinstance(v, list): + v=np.array(v) err_list.extend(v.ravel()) else: err_list.append(v)