diff --git a/poetry.lock b/poetry.lock index ae65ddd4..6ae89ac8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -24,13 +24,13 @@ files = [ [[package]] name = "astroid" -version = "3.1.0" +version = "3.2.2" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" files = [ - {file = "astroid-3.1.0-py3-none-any.whl", hash = "sha256:951798f922990137ac090c53af473db7ab4e70c770e6d7fae0cec59f74411819"}, - {file = "astroid-3.1.0.tar.gz", hash = "sha256:ac248253bfa4bd924a0de213707e7ebeeb3138abeb48d798784ead1e56d419d4"}, + {file = "astroid-3.2.2-py3-none-any.whl", hash = "sha256:e8a0083b4bb28fcffb6207a3bfc9e5d0a68be951dd7e336d5dcf639c682388c0"}, + {file = "astroid-3.2.2.tar.gz", hash = "sha256:8ead48e31b92b2e217b6c9733a21afafe479d52d6e164dd25fb1a770c7c3cf94"}, ] [package.dependencies] @@ -207,13 +207,13 @@ numpy = ">=1.21,<2.0" [[package]] name = "custatevec-cu12" -version = "1.6.0" +version = "1.6.0.post1" description = "cuStateVec - a component of NVIDIA cuQuantum SDK" optional = true python-versions = "*" files = [ - {file = "custatevec_cu12-1.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:e9296be9eec8b4c32407424a6a7fc6e386c38eefbb649beb73f5a517c6dd3704"}, - {file = "custatevec_cu12-1.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:4d3505eff236f822adba4d313358b659fe145824b5be2c7283c1d6243b05be00"}, + {file = "custatevec_cu12-1.6.0.post1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4f38175cb6cb9dfa0008e5109e22bf92eeedd3aad843be3ce27ad41b53318f95"}, + {file = "custatevec_cu12-1.6.0.post1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:0c875981de852091f5f3c040d705b284424171aa5df07f843728ff2fae450016"}, ] [[package]] @@ -230,13 +230,13 @@ files = [ [[package]] name = "cutensornet-cu12" -version = "2.4.0" +version = "2.4.0.post1" description = "cuTensorNet - a component of NVIDIA cuQuantum SDK" optional = true python-versions = "*" files = [ - {file = "cutensornet_cu12-2.4.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1c64b4802c57a727c12129142419cbfcf4db8f2534738f1c99106f3c7a346882"}, - {file = "cutensornet_cu12-2.4.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:af08991e732b6f8672b72fef175e3be8d9609403cbc3871c172c61a3d437a4bb"}, + {file = "cutensornet_cu12-2.4.0.post1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0356564387165914ea3c07047a1ffef2d92dac74a97b544e63664c6cef0af599"}, + {file = "cutensornet_cu12-2.4.0.post1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2301e1f6e24fd002b69a89358592898a949a319ba6497bd3d5ff569ec5841d45"}, ] [package.dependencies] @@ -756,13 +756,13 @@ files = [ [[package]] name = "platformdirs" -version = "4.2.1" +version = "4.2.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"}, - {file = "platformdirs-4.2.1.tar.gz", hash = "sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf"}, + {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, + {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, ] [package.extras] @@ -865,17 +865,17 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pylint" -version = "3.1.0" +version = "3.2.2" description = "python code static checker" optional = false python-versions = ">=3.8.0" files = [ - {file = "pylint-3.1.0-py3-none-any.whl", hash = "sha256:507a5b60953874766d8a366e8e8c7af63e058b26345cfcb5f91f89d987fd6b74"}, - {file = "pylint-3.1.0.tar.gz", hash = "sha256:6a69beb4a6f63debebaab0a3477ecd0f559aa726af4954fc948c51f7a2549e23"}, + {file = "pylint-3.2.2-py3-none-any.whl", hash = "sha256:3f8788ab20bb8383e06dd2233e50f8e08949cfd9574804564803441a4946eab4"}, + {file = "pylint-3.2.2.tar.gz", hash = "sha256:d068ca1dfd735fb92a07d33cb8f288adc0f6bc1287a139ca2425366f7cbe38f8"}, ] [package.dependencies] -astroid = ">=3.1.0,<=3.2.0-dev0" +astroid = ">=3.2.2,<=3.3.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, @@ -980,7 +980,7 @@ torch = ["torch (>=2.1.1,<3.0.0)"] type = "git" url = "https://github.com/qiboteam/qibo.git" reference = "HEAD" -resolved_reference = "c132711dbf1eb26399b4c4aef26921c683ac00cd" +resolved_reference = "2ace99a9217116bc2db6e50d811c6e89cb47fe19" [[package]] name = "scipy" @@ -1092,13 +1092,13 @@ files = [ [[package]] name = "tomlkit" -version = "0.12.4" +version = "0.12.5" description = "Style preserving TOML library" optional = false python-versions = ">=3.7" files = [ - {file = "tomlkit-0.12.4-py3-none-any.whl", hash = "sha256:5cd82d48a3dd89dee1f9d64420aa20ae65cfbd00668d6f094d7578a78efbb77b"}, - {file = "tomlkit-0.12.4.tar.gz", hash = "sha256:7ca1cfc12232806517a8515047ba66a19369e71edf2439d0f5824f91032b6cc3"}, + {file = "tomlkit-0.12.5-py3-none-any.whl", hash = "sha256:af914f5a9c59ed9d0762c7b64d3b5d5df007448eb9cd2edc8a46b1eafead172f"}, + {file = "tomlkit-0.12.5.tar.gz", hash = "sha256:eef34fba39834d4d6b73c9ba7f3e4d1c417a4e56f89a7e96e090dd0d24b8fb3c"}, ] [[package]] @@ -1160,18 +1160,18 @@ files = [ [[package]] name = "zipp" -version = "3.18.1" +version = "3.18.2" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, - {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, + {file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"}, + {file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"}, ] [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] cupy = ["cupy-cuda12x"] diff --git a/src/qibojit/__init__.py b/src/qibojit/__init__.py index dc57e90b..088fdf9a 100644 --- a/src/qibojit/__init__.py +++ b/src/qibojit/__init__.py @@ -1,3 +1,5 @@ import importlib.metadata as im +from qibojit.backends import MetaBackend + __version__ = im.version(__package__) diff --git a/src/qibojit/backends/__init__.py b/src/qibojit/backends/__init__.py index d73f46d8..4d06546e 100644 --- a/src/qibojit/backends/__init__.py +++ b/src/qibojit/backends/__init__.py @@ -1,2 +1,46 @@ +from typing import Union + from qibojit.backends.cpu import NumbaBackend from qibojit.backends.gpu import CupyBackend, CuQuantumBackend + +QibojitBackend = Union[NumbaBackend, CupyBackend, CuQuantumBackend] + +PLATFORMS = ("numba", "cupy", "cuquantum") + + +class MetaBackend: + """Meta-backend class which takes care of loading the qibojit backends.""" + + @staticmethod + def load(platform: str = None) -> QibojitBackend: + """Loads the backend. + + Args: + platform (str): Name of the backend to load: either `numba`, `cupy` or `cuquantum`. + Returns: + qibo.backends.abstract.Backend: The loaded backend. + """ + + if platform == "numba": + return NumbaBackend() + elif platform == "cupy": + return CupyBackend() + elif platform == "cuquantum": + return CuQuantumBackend() + else: # pragma: no cover + try: + return CupyBackend() + except (ModuleNotFoundError, ImportError): + return NumbaBackend() + + def list_available(self) -> dict: + """Lists all the available qibojit backends.""" + available_backends = {} + for platform in PLATFORMS: + try: + MetaBackend.load(platform=platform) + available = True + except: + available = False + available_backends[platform] = available + return available_backends diff --git a/src/qibojit/tests/test_backends.py b/src/qibojit/tests/test_backends.py index d4dfa381..8d60fe7b 100644 --- a/src/qibojit/tests/test_backends.py +++ b/src/qibojit/tests/test_backends.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from qibojit.backends import MetaBackend + def test_device_setter(backend): if backend.platform == "numba": @@ -114,3 +116,8 @@ def test_backend_eigh_sparse(backend, sparse_type, k): eigvals1 = backend.to_numpy(eigvals1) eigvals2 = backend.to_numpy(eigvals2) backend.assert_allclose(sorted(eigvals1), sorted(eigvals2)) + + +def test_metabackend_list_available(): + available_backends = dict(zip(("numba", "cupy", "cuquantum"), (True, False, False))) + assert MetaBackend().list_available() == available_backends