Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix doc deployment by moving torch import #1281

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 24 additions & 29 deletions src/qibo/backends/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,24 @@
"""PyTorch backend."""

from typing import Union

import numpy as np
import torch

from qibo import __version__
from qibo.backends.npmatrices import NumpyMatrices
from qibo.backends.numpy import NumpyBackend

torch_dtype_dict = {
"int": torch.int32,
"float": torch.float32,
"complex": torch.complex64,
"int32": torch.int32,
"int64": torch.int64,
"float32": torch.float32,
"float64": torch.float64,
"complex64": torch.complex64,
"complex128": torch.complex128,
}


class TorchMatrices(NumpyMatrices):
"""Matrix representation of every gate as a torch Tensor."""

def __init__(self, dtype):
import torch # pylint: disable=import-outside-toplevel

super().__init__(dtype)
self.dtype = torch_dtype_dict[dtype]
self.torch = torch
self.dtype = dtype

def _cast(self, x, dtype):
return torch.as_tensor(x, dtype=dtype)
return self.torch.as_tensor(x, dtype=dtype)

def Unitary(self, u):
return self._cast(u, dtype=self.dtype)
Expand All @@ -39,34 +27,41 @@ def Unitary(self, u):
class PyTorchBackend(NumpyBackend):
def __init__(self):
super().__init__()
import torch # pylint: disable=import-outside-toplevel

self.np = torch

self.name = "pytorch"
self.versions = {
"qibo": __version__,
"numpy": np.__version__,
"torch": torch.__version__,
"torch": self.np.__version__,
}

self.dtype = self._torch_dtype(self.dtype)
self.matrices = TorchMatrices(self.dtype)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.device = self.np.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.nthreads = 0
self.np = torch
self.dtype = torch_dtype_dict[self.dtype]
self.tensor_types = (self.np.Tensor, np.ndarray)

# These functions in Torch works in a different way than numpy or have different names
self.np.transpose = torch.permute
self.np.transpose = self.np.permute
self.np.expand_dims = self.np.unsqueeze
self.np.mod = torch.remainder
self.np.right_shift = torch.bitwise_right_shift
self.np.mod = self.np.remainder
self.np.right_shift = self.np.bitwise_right_shift

def _torch_dtype(self, dtype):
if dtype == "float":
dtype += "32"
return getattr(self.np, dtype)

def set_device(self, device): # pragma: no cover
self.device = device

def cast(
self,
x: Union[torch.Tensor, list[torch.Tensor], np.ndarray, list[np.ndarray]],
dtype: Union[str, torch.dtype, np.dtype, type] = None,
x,
dtype=None,
copy: bool = False,
):
"""Casts input as a Torch tensor of the specified dtype.
Expand All @@ -86,9 +81,9 @@ def cast(
if dtype is None:
dtype = self.dtype
elif isinstance(dtype, type):
dtype = torch_dtype_dict[dtype.__name__]
elif not isinstance(dtype, torch.dtype):
dtype = torch_dtype_dict[str(dtype)]
dtype = self._torch_dtype(dtype.__name__)
elif not isinstance(dtype, self.np.dtype):
dtype = self._torch_dtype(str(dtype))

if isinstance(x, self.np.Tensor):
x = x.to(dtype)
Expand Down
Loading