From d8569652c0201af1a5f334fc8ebfbce47e35ab44 Mon Sep 17 00:00:00 2001 From: simone bordoni Date: Thu, 21 Mar 2024 15:39:38 +0400 Subject: [PATCH 1/2] fix torch import --- src/qibo/backends/pytorch.py | 57 ++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/src/qibo/backends/pytorch.py b/src/qibo/backends/pytorch.py index 403cea10da..15c206550a 100644 --- a/src/qibo/backends/pytorch.py +++ b/src/qibo/backends/pytorch.py @@ -1,36 +1,24 @@ """PyTorch backend.""" -from typing import Union - import numpy as np -import torch from qibo import __version__ from qibo.backends.npmatrices import NumpyMatrices from qibo.backends.numpy import NumpyBackend -torch_dtype_dict = { - "int": torch.int32, - "float": torch.float32, - "complex": torch.complex64, - "int32": torch.int32, - "int64": torch.int64, - "float32": torch.float32, - "float64": torch.float64, - "complex64": torch.complex64, - "complex128": torch.complex128, -} - class TorchMatrices(NumpyMatrices): """Matrix representation of every gate as a torch Tensor.""" def __init__(self, dtype): + import torch # pylint: disable=import-outside-toplevel + super().__init__(dtype) - self.dtype = torch_dtype_dict[dtype] + self.torch = torch + self.dtype = dtype def _cast(self, x, dtype): - return torch.as_tensor(x, dtype=dtype) + return self.torch.as_tensor(x, dtype=dtype) def Unitary(self, u): return self._cast(u, dtype=self.dtype) @@ -39,34 +27,45 @@ def Unitary(self, u): class PyTorchBackend(NumpyBackend): def __init__(self): super().__init__() + import torch # pylint: disable=import-outside-toplevel + + self.np = torch self.name = "pytorch" self.versions = { "qibo": __version__, "numpy": np.__version__, - "torch": torch.__version__, + "torch": self.np.__version__, } + self.dtype = self._torch_dtype(self.dtype) self.matrices = TorchMatrices(self.dtype) - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.device = self.np.device("cuda:0" if torch.cuda.is_available() else "cpu") self.nthreads = 0 - self.np = torch - self.dtype = torch_dtype_dict[self.dtype] self.tensor_types = (self.np.Tensor, np.ndarray) # These functions in Torch works in a different way than numpy or have different names - self.np.transpose = torch.permute + self.np.transpose = self.np.permute self.np.expand_dims = self.np.unsqueeze - self.np.mod = torch.remainder - self.np.right_shift = torch.bitwise_right_shift + self.np.mod = self.np.remainder + self.np.right_shift = self.np.bitwise_right_shift + + def _torch_dtype(self, dtype): + if dtype == "int": + dtype += "32" + elif dtype == "float": + dtype += "32" + elif dtype == "complex": + dtype += "64" + return getattr(self.np, dtype) def set_device(self, device): # pragma: no cover self.device = device def cast( self, - x: Union[torch.Tensor, list[torch.Tensor], np.ndarray, list[np.ndarray]], - dtype: Union[str, torch.dtype, np.dtype, type] = None, + x, + dtype=None, copy: bool = False, ): """Casts input as a Torch tensor of the specified dtype. @@ -86,9 +85,9 @@ def cast( if dtype is None: dtype = self.dtype elif isinstance(dtype, type): - dtype = torch_dtype_dict[dtype.__name__] - elif not isinstance(dtype, torch.dtype): - dtype = torch_dtype_dict[str(dtype)] + dtype = self._torch_dtype(dtype.__name__) + elif not isinstance(dtype, self.np.dtype): + dtype = self._torch_dtype(str(dtype)) if isinstance(x, self.np.Tensor): x = x.to(dtype) From 51256eeafc549f3b5cb13061c04dfaf36ebcd46e Mon Sep 17 00:00:00 2001 From: simone bordoni Date: Thu, 21 Mar 2024 16:18:46 +0400 Subject: [PATCH 2/2] remove unused lines --- src/qibo/backends/pytorch.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/qibo/backends/pytorch.py b/src/qibo/backends/pytorch.py index 15c206550a..a8233d4e3b 100644 --- a/src/qibo/backends/pytorch.py +++ b/src/qibo/backends/pytorch.py @@ -51,12 +51,8 @@ def __init__(self): self.np.right_shift = self.np.bitwise_right_shift def _torch_dtype(self, dtype): - if dtype == "int": + if dtype == "float": dtype += "32" - elif dtype == "float": - dtype += "32" - elif dtype == "complex": - dtype += "64" return getattr(self.np, dtype) def set_device(self, device): # pragma: no cover