From 6a12f90ee60c68b5988dcf90dc7f5e0bd394f05b Mon Sep 17 00:00:00 2001 From: mattluutrang Date: Thu, 5 Aug 2021 16:04:20 -0400 Subject: [PATCH] Account for differences between type() and .type --- adept/utils/util.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/adept/utils/util.py b/adept/utils/util.py index aab872a..96bfd0b 100644 --- a/adept/utils/util.py +++ b/adept/utils/util.py @@ -81,7 +81,7 @@ def json_to_dict(file_path): _numpy_to_torch_dtype = { np.float16: torch.float16, - np.float32:torch.float32, + np.float32: torch.float32, np.float64: torch.float64, np.uint8: torch.uint8, np.int8: torch.int8, @@ -96,21 +96,17 @@ def numpy_to_torch_dtype(dtype): if inspect.isclass(dtype): name = dtype else: - name = type(dtype) + name = dtype.type if name not in _numpy_to_torch_dtype: - raise ValueError( - f"Could not convert numpy dtype {dtype.name} to a torch dtype." - ) + name = type(dtype) + if name not in _numpy_to_torch_dtype: + raise ValueError(f"Could not convert numpy dtype {dtype.name} to a torch dtype.") return _numpy_to_torch_dtype[name] def torch_to_numpy_dtype(dtype): if dtype not in _torch_to_numpy_dtype: - raise ValueError( - "Could not convert torch dtype {} to a numpy dtype.".format( - dtype - ) - ) + raise ValueError("Could not convert torch dtype {} to a numpy dtype.".format(dtype)) return _torch_to_numpy_dtype[dtype]