WIP torch MPS backend #43

Draft: wants to merge 3 commits into base: main
2 changes: 2 additions & 0 deletions himalaya/backend/_utils.py
@@ -8,6 +8,7 @@
"cupy",
"torch",
"torch_cuda",
"torch_mps",
]

CURRENT_BACKEND = "numpy"
@@ -17,6 +18,7 @@
"cupy": "numpy",
"torch": "torch",
"torch_cuda": "torch",
"torch_mps": "torch",
}


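These two additions register the new backend name with the resolution tables, so it can be selected like the existing backends. A minimal usage sketch, assuming an MPS-enabled PyTorch build on Apple silicon:

from himalaya.backend import set_backend

# Select the new backend by name; set_backend returns the backend module.
backend = set_backend("torch_mps")
print(backend.name)  # "torch_mps"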
61 changes: 52 additions & 9 deletions himalaya/backend/tests/test_backends.py
@@ -6,6 +6,9 @@
from himalaya.utils import assert_array_almost_equal


BACKENDS_NO_MPS = ALL_BACKENDS.copy()
BACKENDS_NO_MPS.remove("torch_mps")

@pytest.mark.parametrize('backend', ALL_BACKENDS)
def test_apply_argmax(backend):
backend = set_backend(backend)
@@ -25,6 +28,29 @@ def test_apply_argmax(backend):
)


@pytest.mark.parametrize('dtype_str', ["float32", "float64"])
@pytest.mark.parametrize('backend', ALL_BACKENDS)
def test_mean_float64(backend, dtype_str):
backend = set_backend(backend)
for array in [
backend.randn(1),
backend.randn(10),
backend.randn(10, 1),
backend.randn(10, 4),
backend.randn(10, 1, 8),
backend.randn(10, 4, 8),
]:
array = backend.asarray(array, dtype=dtype_str)
array_64 = backend.asarray(array, dtype="float64")
for axis in range(array.ndim):
result = backend.mean_float64(array, axis=axis)
reference = backend.to_numpy(array_64).mean(
axis=axis, dtype="float64"
)
reference = backend.asarray(reference, dtype=dtype_str)
assert_array_almost_equal(result, reference)


@pytest.mark.parametrize('dtype_str', ["float32", "float64"])
@pytest.mark.parametrize('backend', ALL_BACKENDS)
def test_std_float64(backend, dtype_str):
@@ -97,15 +123,20 @@ def test_eigh(backend):
values, vectors = backend.eigh(kernel)
values_ref, vectors_ref = scipy.linalg.eigh(backend.to_numpy(kernel))

assert_array_almost_equal(values, values_ref)
decimal = 4 if backend.name == "torch_mps" else 6
assert_array_almost_equal(values, values_ref, decimal=decimal)

# vectors can be flipped in sign
assert vectors.shape == vectors_ref.shape
for ii in range(vectors.shape[1]):
try:
assert_array_almost_equal(vectors[:, ii], vectors_ref[:, ii])
assert_array_almost_equal(
vectors[:, ii], vectors_ref[:, ii], decimal=decimal
)
except AssertionError:
assert_array_almost_equal(vectors[:, ii], -vectors_ref[:, ii])
assert_array_almost_equal(
vectors[:, ii], -vectors_ref[:, ii], decimal=decimal
)


@pytest.mark.parametrize('backend', ALL_BACKENDS)
@@ -126,7 +157,8 @@ def test_svd(backend, full_matrices, three_dim):
U_ref, s_ref, V_ref = numpy.linalg.svd(backend.to_numpy(array),
full_matrices=full_matrices)

assert_array_almost_equal(s, s_ref)
decimal = 4 if backend.name == "torch_mps" else 6
assert_array_almost_equal(s, s_ref, decimal=decimal)

if not three_dim:
U_ref = U_ref[None]
@@ -140,18 +172,26 @@ def test_svd(backend, full_matrices, three_dim):
for kk in range(U.shape[0]):
for ii in range(U.shape[2]):
try:
assert_array_almost_equal(U[kk, :, ii], U_ref[kk, :, ii])
assert_array_almost_equal(V[kk, ii, :], V_ref[kk, ii, :])
assert_array_almost_equal(
U[kk, :, ii], U_ref[kk, :, ii], decimal=decimal
)
assert_array_almost_equal(
V[kk, ii, :], V_ref[kk, ii, :], decimal=decimal
)
except AssertionError:
assert_array_almost_equal(U[kk, :, ii], -U_ref[kk, :, ii])
assert_array_almost_equal(V[kk, ii, :], -V_ref[kk, ii, :])
assert_array_almost_equal(
U[kk, :, ii], -U_ref[kk, :, ii], decimal=decimal
)
assert_array_almost_equal(
V[kk, ii, :], -V_ref[kk, ii, :], decimal=decimal
)


@pytest.mark.parametrize('backend_out', ALL_BACKENDS)
@pytest.mark.parametrize('backend_in', ALL_BACKENDS)
def test_changed_backend_asarray(backend_in, backend_out):
backend = set_backend(backend_in)
array_in = backend.asarray([1.2, 2.4, 4.8])
array_in = backend.asarray([1.2, 2.4, 4.8], dtype="float32")
assert array_in is not None

# change the backend, and cast to the correct class
@@ -183,6 +223,9 @@ def test_changed_backend_asarray(backend_in, backend_out):
@pytest.mark.parametrize('backend_out', ALL_BACKENDS)
@pytest.mark.parametrize('backend_in', ALL_BACKENDS)
def test_asarray_dtype(backend_in, backend_out, dtype_in, dtype_out):
if (backend_in == "torch_mps" and dtype_in == "float64") or \
(backend_out == "torch_mps" and dtype_out == "float64"):
pytest.skip("torch_mps does not support float64 dtype")
backend = set_backend(backend_in)
array_in = backend.asarray([1.2, 2.4, 4.8], dtype=dtype_in)
assert _dtype_to_str(array_in.dtype) == dtype_in
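For reference, the new cases can be run on their own by filtering on the backend name in the parametrized test IDs; a sketch, assuming pytest is installed and an MPS device is available:

import pytest

# Runs only the parametrized cases that exercise the new backend,
# e.g. test_eigh[torch_mps] or test_svd[...-torch_mps].
pytest.main(["himalaya/backend/tests/test_backends.py", "-k", "torch_mps"])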
2 changes: 2 additions & 0 deletions himalaya/backend/torch.py
@@ -176,6 +176,8 @@ def asarray(x, dtype=None, device="cpu"):
tensor = torch.as_tensor(x, dtype=dtype, device=device)
except Exception:
import numpy as np
if torch.is_tensor(x) and x.device.type != "cpu":  # move GPU (e.g. MPS) tensors to CPU first
x = x.cpu()
array = np.asarray(x, dtype=_dtype_to_str(dtype))
tensor = torch.as_tensor(array, dtype=dtype, device=device)
return tensor
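The added guard matters because NumPy cannot read a tensor that still lives on a GPU device, so the fallback path now moves it to host memory first. A minimal sketch of the failure mode, assuming an MPS tensor (the exact exception type can vary across PyTorch versions):

import numpy as np
import torch

x = torch.ones(3, device="mps")
try:
    np.asarray(x)                # fails: the data lives in GPU memory
except (TypeError, RuntimeError):
    array = np.asarray(x.cpu())  # works after copying the tensor to the CPU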
136 changes: 136 additions & 0 deletions himalaya/backend/torch_mps.py
@@ -0,0 +1,136 @@
"""The "torch_mps" GPU backend, based on PyTorch.

To use this backend, call ``himalaya.backend.set_backend("torch_mps")``.
"""
from .torch import * # noqa
import torch
import warnings

if not torch.backends.mps.is_available():
import sys
if "pytest" in sys.modules: # if run through pytest
import pytest
pytest.skip("PyTorch with MPS is not available.")
raise RuntimeError("PyTorch with MPS is not available.")

from ._utils import _dtype_to_str
from ._utils import warn_if_not_float32

###############################################################################

name = "torch_mps"


def randn(*args, **kwargs):
return torch.randn(*args, **kwargs).to("mps")


def rand(*args, **kwargs):
return torch.rand(*args, **kwargs).to("mps")


def asarray(x, dtype=None, device="mps"):
if dtype is None:
if isinstance(x, torch.Tensor):
dtype = x.dtype
if hasattr(x, "dtype") and hasattr(x.dtype, "name"):
dtype = x.dtype.name
if dtype is not None:
dtype = _dtype_to_str(dtype)
dtype = _check_dtype_torch_mps(dtype)
dtype = getattr(torch, dtype)
if device is None:
if isinstance(x, torch.Tensor):
device = x.device
else:
device = "mps"
try:
tensor = torch.as_tensor(x, dtype=dtype, device=device)
except Exception:
import numpy as np
array = np.asarray(x, dtype=_dtype_to_str(dtype))
tensor = torch.as_tensor(array, dtype=dtype, device=device)
return tensor


_already_warned = [False]


def _check_dtype_torch_mps(dtype):
"""Warn that X will be cast from float64 to float32 and return the correct dtype"""
if _dtype_to_str(dtype) == "float64":
if not _already_warned[0]: # avoid warning multiple times
warnings.warn(
f"GPU backend torch_mps requires single "
f"precision floats (float32), got input in {dtype}. "
"Data will be automatically cast to float32", UserWarning)
_already_warned[0] = True
return "float32"
return dtype


def check_arrays(*all_inputs):
"""Change all inputs into Tensors (or list of Tensors) using the same
precision and device as the first one. Some tensors can be None. Float64
tensors are automatically cast to float32 (the MPS backend lacks float64).
"""
all_tensors = []
all_tensors.append(asarray(all_inputs[0]))
dtype = all_tensors[0].dtype
dtype = _check_dtype_torch_mps(dtype)
device = all_tensors[0].device
for tensor in all_inputs[1:]:
if tensor is None:
pass
elif isinstance(tensor, list):
tensor = [asarray(tt, dtype=dtype, device=device) for tt in tensor]
else:
tensor = asarray(tensor, dtype=dtype, device=device)
all_tensors.append(tensor)
return all_tensors


def zeros(shape, dtype="float32", device="mps"):
if isinstance(shape, int):
shape = (shape, )
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
return torch.zeros(shape, dtype=dtype, device=device)


def to_cpu(array):
return array.cpu()


def to_gpu(array, device="mps"):
return asarray(array, device=device)


def std_float64(X, axis=None, demean=True, keepdims=False):
"""Compute the standard deviation of X with double precision on CPU,
then cast the result back to the original dtype on the original device.
"""
X_64 = torch.as_tensor(X.to("cpu"), dtype=torch.float64, device="cpu")
X_std = (X_64 ** 2).sum(dim=axis, dtype=torch.float64)
if demean:
X_std -= X_64.sum(axis, dtype=torch.float64) ** 2 / X.shape[axis]
X_std = X_std ** .5
X_std /= (X.shape[axis] ** .5)

X_std = torch.as_tensor(X_std, dtype=torch.float32, device=X.device)
if keepdims:
X_std = X_std.unsqueeze(dim=axis)

return X_std


def mean_float64(X, axis=None, keepdims=False):
"""Compute the mean of X with double precision on CPU,
then cast the result back to the original dtype on the original device.
"""
X_mean = X.to("cpu").sum(axis, dtype=torch.float64) / X.shape[axis]

X_mean = torch.as_tensor(X_mean, dtype=X.dtype, device=X.device)
if keepdims:
X_mean = X_mean.unsqueeze(dim=axis)
return X_mean
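
Taken together, the module behaves as follows on an MPS machine: float64 inputs are cast down to float32 with a one-time warning, while the float64 reductions are accumulated on CPU and returned on the MPS device. A short sketch of that behaviour, assuming the backend is available and selected:

import numpy as np
from himalaya.backend import set_backend

backend = set_backend("torch_mps")

X = np.random.randn(10, 4)         # float64 array on the host
X_mps = backend.asarray(X)         # warns once, returns a float32 MPS tensor
print(X_mps.dtype, X_mps.device)   # torch.float32 mps:0

mean = backend.mean_float64(X_mps, axis=0)  # accumulated in float64 on CPU
print(mean.dtype, mean.device)              # torch.float32 mps:0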