Merge pull request #1281 from qiboteam/fix_import
Fix doc deployment by moving `torch` import
BrunoLiegiBastonLiegi authored Mar 28, 2024
2 parents 8fd5657 + e249e85 commit cb60314
Showing 1 changed file with 24 additions and 29 deletions.
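
As context for the change: after this commit the backend module no longer imports `torch` at module level, so `torch` is only loaded when a backend is instantiated. The sketch below is illustrative only (the `importlib` usage is not part of the diff) and assumes qibo and torch are installed.

# Illustration, not part of the diff: the module itself has no top-level torch
# import anymore, so torch is first imported when a backend instance is created.
import importlib

module = importlib.import_module("qibo.backends.pytorch")  # no torch import from this module
backend = module.PyTorchBackend()  # torch gets imported inside __init__ here
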
53 changes: 24 additions & 29 deletions src/qibo/backends/pytorch.py
@@ -1,36 +1,24 @@
"""PyTorch backend."""

from typing import Union

import numpy as np
import torch

from qibo import __version__
from qibo.backends.npmatrices import NumpyMatrices
from qibo.backends.numpy import NumpyBackend

torch_dtype_dict = {
"int": torch.int32,
"float": torch.float32,
"complex": torch.complex64,
"int32": torch.int32,
"int64": torch.int64,
"float32": torch.float32,
"float64": torch.float64,
"complex64": torch.complex64,
"complex128": torch.complex128,
}


class TorchMatrices(NumpyMatrices):
"""Matrix representation of every gate as a torch Tensor."""

def __init__(self, dtype):
import torch # pylint: disable=import-outside-toplevel

super().__init__(dtype)
self.dtype = torch_dtype_dict[dtype]
self.torch = torch
self.dtype = dtype

def _cast(self, x, dtype):
return torch.as_tensor(x, dtype=dtype)
return self.torch.as_tensor(x, dtype=dtype)

def Unitary(self, u):
return self._cast(u, dtype=self.dtype)
@@ -39,34 +27,41 @@ def Unitary(self, u):
 class PyTorchBackend(NumpyBackend):
     def __init__(self):
         super().__init__()
+        import torch  # pylint: disable=import-outside-toplevel
+
+        self.np = torch
 
         self.name = "pytorch"
         self.versions = {
             "qibo": __version__,
             "numpy": np.__version__,
-            "torch": torch.__version__,
+            "torch": self.np.__version__,
         }
 
+        self.dtype = self._torch_dtype(self.dtype)
         self.matrices = TorchMatrices(self.dtype)
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.device = self.np.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.nthreads = 0
-        self.np = torch
-        self.dtype = torch_dtype_dict[self.dtype]
         self.tensor_types = (self.np.Tensor, np.ndarray)
 
         # These functions in Torch works in a different way than numpy or have different names
-        self.np.transpose = torch.permute
+        self.np.transpose = self.np.permute
         self.np.expand_dims = self.np.unsqueeze
-        self.np.mod = torch.remainder
-        self.np.right_shift = torch.bitwise_right_shift
+        self.np.mod = self.np.remainder
+        self.np.right_shift = self.np.bitwise_right_shift
+
+    def _torch_dtype(self, dtype):
+        if dtype == "float":
+            dtype += "32"
+        return getattr(self.np, dtype)
 
     def set_device(self, device):  # pragma: no cover
         self.device = device
 
     def cast(
         self,
-        x: Union[torch.Tensor, list[torch.Tensor], np.ndarray, list[np.ndarray]],
-        dtype: Union[str, torch.dtype, np.dtype, type] = None,
+        x,
+        dtype=None,
         copy: bool = False,
     ):
         """Casts input as a Torch tensor of the specified dtype.
@@ -86,9 +81,9 @@ def cast(
         if dtype is None:
             dtype = self.dtype
         elif isinstance(dtype, type):
-            dtype = torch_dtype_dict[dtype.__name__]
-        elif not isinstance(dtype, torch.dtype):
-            dtype = torch_dtype_dict[str(dtype)]
+            dtype = self._torch_dtype(dtype.__name__)
+        elif not isinstance(dtype, self.np.dtype):
+            dtype = self._torch_dtype(str(dtype))
 
         if isinstance(x, self.np.Tensor):
             x = x.to(dtype)
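
The `_torch_dtype` helper added above replaces the removed `torch_dtype_dict` lookup with a `getattr` call, special-casing the bare "float" name. A standalone sketch of that behavior (the free function `torch_dtype` below is illustrative, not part of the diff):

# Mirrors the new _torch_dtype helper: resolve a dtype name via getattr on torch,
# widening the bare "float" name to "float32" first.
import torch

def torch_dtype(name: str) -> torch.dtype:
    if name == "float":
        name += "32"
    return getattr(torch, name)

assert torch_dtype("float") is torch.float32
assert torch_dtype("complex128") is torch.complex128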
