From 93e6d8f6974ec6947243e6b87f9fe5ed7615b530 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Tue, 14 Feb 2023 20:57:39 -0800 Subject: [PATCH 1/3] WIP torch MPS backend --- himalaya/backend/_utils.py | 2 + himalaya/backend/tests/test_backends.py | 8 +- himalaya/backend/torch_mps.py | 128 ++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 himalaya/backend/torch_mps.py diff --git a/himalaya/backend/_utils.py b/himalaya/backend/_utils.py index 97025f1..c599ea3 100644 --- a/himalaya/backend/_utils.py +++ b/himalaya/backend/_utils.py @@ -8,6 +8,7 @@ "cupy", "torch", "torch_cuda", + "torch_mps", ] CURRENT_BACKEND = "numpy" @@ -17,6 +18,7 @@ "cupy": "numpy", "torch": "torch", "torch_cuda": "torch", + "torch_mps": "torch", } diff --git a/himalaya/backend/tests/test_backends.py b/himalaya/backend/tests/test_backends.py index caf858d..f89254e 100644 --- a/himalaya/backend/tests/test_backends.py +++ b/himalaya/backend/tests/test_backends.py @@ -6,6 +6,9 @@ from himalaya.utils import assert_array_almost_equal +BACKENDS_NO_MPS = ALL_BACKENDS.copy() +BACKENDS_NO_MPS.remove("torch_mps") + @pytest.mark.parametrize('backend', ALL_BACKENDS) def test_apply_argmax(backend): backend = set_backend(backend) @@ -151,7 +154,7 @@ def test_svd(backend, full_matrices, three_dim): @pytest.mark.parametrize('backend_in', ALL_BACKENDS) def test_changed_backend_asarray(backend_in, backend_out): backend = set_backend(backend_in) - array_in = backend.asarray([1.2, 2.4, 4.8]) + array_in = backend.asarray([1.2, 2.4, 4.8], dtype="float32") assert array_in is not None # change the backend, and cast to the correct class @@ -183,6 +186,9 @@ def test_changed_backend_asarray(backend_in, backend_out): @pytest.mark.parametrize('backend_out', ALL_BACKENDS) @pytest.mark.parametrize('backend_in', ALL_BACKENDS) def test_asarray_dtype(backend_in, backend_out, dtype_in, dtype_out): + if (backend_in == "torch_mps" and dtype_in == "float64") or \ + (backend_out == "torch_mps" and dtype_out == "float64"): + pytest.skip("torch_mps does not support float64 dtype") backend = set_backend(backend_in) array_in = backend.asarray([1.2, 2.4, 4.8], dtype=dtype_in) assert _dtype_to_str(array_in.dtype) == dtype_in diff --git a/himalaya/backend/torch_mps.py b/himalaya/backend/torch_mps.py new file mode 100644 index 0000000..1e95ef8 --- /dev/null +++ b/himalaya/backend/torch_mps.py @@ -0,0 +1,128 @@ +"""The "torch_mps" GPU backend, based on PyTorch. + +To use this backend, call ``himalaya.backend.set_backend("torch_mps")``. +""" +from .torch import * # noqa +import torch +import warnings + +if not torch.backends.mps.is_available(): + import sys + if "pytest" in sys.modules: # if run through pytest + import pytest + pytest.skip("PyTorch with MPS is not available.") + raise RuntimeError("PyTorch with MPS is not available.") + +from ._utils import _dtype_to_str +from ._utils import warn_if_not_float32 + +############################################################################### + +name = "torch_mps" + + +def randn(*args, **kwargs): + return torch.randn(*args, **kwargs).to("mps") + + +def rand(*args, **kwargs): + return torch.rand(*args, **kwargs).to("mps") + + +def asarray(x, dtype=None, device="mps"): + if dtype is None: + if isinstance(x, torch.Tensor): + dtype = x.dtype + if hasattr(x, "dtype") and hasattr(x.dtype, "name"): + dtype = x.dtype.name + if dtype is not None: + dtype = _dtype_to_str(dtype) + dtype = _check_dtype_torch_mps(dtype) + dtype = getattr(torch, dtype) + if device is None: + if isinstance(x, torch.Tensor): + device = x.device + else: + device = "mps" + try: + tensor = torch.as_tensor(x, dtype=dtype, device=device) + except Exception: + import numpy as np + array = np.asarray(x, dtype=_dtype_to_str(dtype)) + tensor = torch.as_tensor(array, dtype=dtype, device=device) + return tensor + + +_already_warned = [False] + + +def _check_dtype_torch_mps(dtype): + """Warn that X will be cast from float64 to float32 and return the correct dtype""" + if _dtype_to_str(dtype) == "float64": + if not _already_warned[0]: # avoid warning multiple times + warnings.warn( + f"GPU backend torch_mps requires single " + f"precision floats (float32), got input in {dtype}. " + "Data will be automatically cast to float32", UserWarning) + _already_warned[0] = True + return "float32" + return dtype + + +def check_arrays(*all_inputs): + """Change all inputs into Tensors (or list of Tensors) using the same + precision and device as the first one. Some tensors can be None. float64 tensors + are automatically cast to float32 due to the requirement of torch MPS backend. + """ + all_tensors = [] + all_tensors.append(asarray(all_inputs[0])) + dtype = all_tensors[0].dtype + dtype = _check_dtype_torch_mps(dtype) + device = all_tensors[0].device + for tensor in all_inputs[1:]: + if tensor is None: + pass + elif isinstance(tensor, list): + tensor = [asarray(tt, dtype=dtype, device=device) for tt in tensor] + else: + tensor = asarray(tensor, dtype=dtype, device=device) + all_tensors.append(tensor) + return all_tensors + + +def zeros(shape, dtype="float32", device="mps"): + if isinstance(shape, int): + shape = (shape, ) + if isinstance(dtype, str): + dtype = getattr(torch, dtype) + return torch.zeros(shape, dtype=dtype, device=device) + + +def to_cpu(array): + return array.cpu() + + +def to_gpu(array, device="mps"): + return asarray(array, device=device) + + +# Workaround to maintain the same API and allow torch_mps +def std_float64(X, axis=None, demean=True, keepdims=False): + """Compute the standard deviation of X with double precision, + and cast back the result to original dtype. + """ + X_64 = torch.as_tensor(X, dtype=torch.float32) + X_std = (X_64 ** 2).sum(dim=axis, dtype=torch.float32) + if demean: + X_std -= X_64.sum(axis, dtype=torch.float32) ** 2 / X.shape[axis] + X_std = X_std ** .5 + X_std /= (X.shape[axis] ** .5) + + X_std = torch.as_tensor(X_std, dtype=X.dtype, device=X.device) + if keepdims: + X_std = X_std.unsqueeze(dim=axis) + + return X_std + + +eigh = torch.linalg.eigh From 4c792872edd788cc6a956430df54b386d13d2471 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 15 Feb 2023 19:03:11 -0800 Subject: [PATCH 2/3] More fixes and tests for torch MPS backend --- himalaya/backend/tests/test_backends.py | 53 +++++++++++++++++++++---- himalaya/backend/torch.py | 2 + himalaya/backend/torch_mps.py | 24 +++++++---- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/himalaya/backend/tests/test_backends.py b/himalaya/backend/tests/test_backends.py index f89254e..9bd0817 100644 --- a/himalaya/backend/tests/test_backends.py +++ b/himalaya/backend/tests/test_backends.py @@ -28,6 +28,29 @@ def test_apply_argmax(backend): ) +@pytest.mark.parametrize('dtype_str', ["float32", "float64"]) +@pytest.mark.parametrize('backend', ALL_BACKENDS) +def test_mean_float64(backend, dtype_str): + backend = set_backend(backend) + for array in [ + backend.randn(1), + backend.randn(10), + backend.randn(10, 1), + backend.randn(10, 4), + backend.randn(10, 1, 8), + backend.randn(10, 4, 8), + ]: + array = backend.asarray(array, dtype=dtype_str) + array_64 = backend.asarray(array, dtype="float64") + for axis in range(array.ndim): + result = backend.mean_float64(array, axis=axis) + reference = backend.to_numpy(array_64).mean( + axis=axis, dtype="float64" + ) + reference = backend.asarray(reference, dtype=dtype_str) + assert_array_almost_equal(result, reference) + + @pytest.mark.parametrize('dtype_str', ["float32", "float64"]) @pytest.mark.parametrize('backend', ALL_BACKENDS) def test_std_float64(backend, dtype_str): @@ -100,15 +123,20 @@ def test_eigh(backend): values, vectors = backend.eigh(kernel) values_ref, vectors_ref = scipy.linalg.eigh(backend.to_numpy(kernel)) - assert_array_almost_equal(values, values_ref) + decimal = 4 if backend.name == "torch_mps" else 6 + assert_array_almost_equal(values, values_ref, decimal=decimal) # vectors can be flipped in sign assert vectors.shape == vectors_ref.shape for ii in range(vectors.shape[1]): try: - assert_array_almost_equal(vectors[:, ii], vectors_ref[:, ii]) + assert_array_almost_equal( + vectors[:, ii], vectors_ref[:, ii], decimal=decimal + ) except AssertionError: - assert_array_almost_equal(vectors[:, ii], -vectors_ref[:, ii]) + assert_array_almost_equal( + vectors[:, ii], -vectors_ref[:, ii], decimal=decimal + ) @pytest.mark.parametrize('backend', ALL_BACKENDS) @@ -129,7 +157,8 @@ def test_svd(backend, full_matrices, three_dim): U_ref, s_ref, V_ref = numpy.linalg.svd(backend.to_numpy(array), full_matrices=full_matrices) - assert_array_almost_equal(s, s_ref) + decimal = 4 if backend.name == "torch_mps" else 6 + assert_array_almost_equal(s, s_ref, decimal=decimal) if not three_dim: U_ref = U_ref[None] @@ -143,11 +172,19 @@ def test_svd(backend, full_matrices, three_dim): for kk in range(U.shape[0]): for ii in range(U.shape[2]): try: - assert_array_almost_equal(U[kk, :, ii], U_ref[kk, :, ii]) - assert_array_almost_equal(V[kk, ii, :], V_ref[kk, ii, :]) + assert_array_almost_equal( + U[kk, :, ii], U_ref[kk, :, ii], decimal=decimal + ) + assert_array_almost_equal( + V[kk, ii, :], V_ref[kk, ii, :], decimal=decimal + ) except AssertionError: - assert_array_almost_equal(U[kk, :, ii], -U_ref[kk, :, ii]) - assert_array_almost_equal(V[kk, ii, :], -V_ref[kk, ii, :]) + assert_array_almost_equal( + U[kk, :, ii], -U_ref[kk, :, ii], decimal=decimal + ) + assert_array_almost_equal( + V[kk, ii, :], -V_ref[kk, ii, :], decimal=decimal + ) @pytest.mark.parametrize('backend_out', ALL_BACKENDS) diff --git a/himalaya/backend/torch.py b/himalaya/backend/torch.py index 9d2d337..d49b276 100644 --- a/himalaya/backend/torch.py +++ b/himalaya/backend/torch.py @@ -176,6 +176,8 @@ def asarray(x, dtype=None, device="cpu"): tensor = torch.as_tensor(x, dtype=dtype, device=device) except Exception: import numpy as np + if x.device != "cpu": + x = x.cpu() array = np.asarray(x, dtype=_dtype_to_str(dtype)) tensor = torch.as_tensor(array, dtype=dtype, device=device) return tensor diff --git a/himalaya/backend/torch_mps.py b/himalaya/backend/torch_mps.py index 1e95ef8..ec0a430 100644 --- a/himalaya/backend/torch_mps.py +++ b/himalaya/backend/torch_mps.py @@ -106,23 +106,31 @@ def to_gpu(array, device="mps"): return asarray(array, device=device) -# Workaround to maintain the same API and allow torch_mps def std_float64(X, axis=None, demean=True, keepdims=False): - """Compute the standard deviation of X with double precision, - and cast back the result to original dtype. + """Compute the standard deviation of X with double precision on CPU, + then cast back the result to original dtype on the original device. """ - X_64 = torch.as_tensor(X, dtype=torch.float32) - X_std = (X_64 ** 2).sum(dim=axis, dtype=torch.float32) + X_64 = torch.as_tensor(X.to("cpu"), dtype=torch.float64, device="cpu") + X_std = (X_64 ** 2).sum(dim=axis, dtype=torch.float64) if demean: - X_std -= X_64.sum(axis, dtype=torch.float32) ** 2 / X.shape[axis] + X_std -= X_64.sum(axis, dtype=torch.float64) ** 2 / X.shape[axis] X_std = X_std ** .5 X_std /= (X.shape[axis] ** .5) - X_std = torch.as_tensor(X_std, dtype=X.dtype, device=X.device) + X_std = torch.as_tensor(X_std, dtype=torch.float32, device=X.device) if keepdims: X_std = X_std.unsqueeze(dim=axis) return X_std -eigh = torch.linalg.eigh +def mean_float64(X, axis=None, keepdims=False): + """Compute the mean of X with double precision on CPU, + then cast back the result to original dtype on the original device. + """ + X_mean = X.to("cpu").sum(axis, dtype=torch.float64) / X.shape[axis] + + X_mean = torch.as_tensor(X_mean, dtype=X.dtype, device=X.device) + if keepdims: + X_mean = X_mean.unsqueeze(dim=axis) + return X_mean From 2e2ba681304cb8588115953ac88a8213e9ad18bd Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 15 Feb 2023 19:14:44 -0800 Subject: [PATCH 3/3] FIX asarray --- himalaya/backend/torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/himalaya/backend/torch.py b/himalaya/backend/torch.py index d49b276..d2a3b55 100644 --- a/himalaya/backend/torch.py +++ b/himalaya/backend/torch.py @@ -176,7 +176,7 @@ def asarray(x, dtype=None, device="cpu"): tensor = torch.as_tensor(x, dtype=dtype, device=device) except Exception: import numpy as np - if x.device != "cpu": + if torch.is_tensor(x) and x.device != "cpu": x = x.cpu() array = np.asarray(x, dtype=_dtype_to_str(dtype)) tensor = torch.as_tensor(array, dtype=dtype, device=device)