From 8dd980b64073349a51e4b681338f9aee721ae79a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Mon, 29 Apr 2024 09:46:19 +0200 Subject: [PATCH] FEA Extend OpenBLAS controller to support scipy_openblas (#175) --- .azure_pipeline.yml | 6 ++ CHANGES.md | 8 ++ continuous_integration/install.sh | 9 ++ continuous_integration/posix.yml | 2 +- continuous_integration/test_script.sh | 2 +- threadpoolctl.py | 127 ++++++++++++-------------- 6 files changed, 84 insertions(+), 70 deletions(-) diff --git a/.azure_pipeline.yml b/.azure_pipeline.yml index 701792a1..29267725 100644 --- a/.azure_pipeline.yml +++ b/.azure_pipeline.yml @@ -58,6 +58,12 @@ stages: name: Linux vmImage: ubuntu-20.04 matrix: + # Linux environment with development versions of numpy and scipy + pylatest_pip_dev: + PACKAGER: 'pip-dev' + PYTHON_VERSION: '*' + CC_OUTER_LOOP: 'gcc' + CC_INNER_LOOP: 'gcc' # Linux environment to test that packages that comes with Ubuntu 20.04 # are correctly handled. py38_ubuntu_atlas_gcc_gcc: diff --git a/CHANGES.md b/CHANGES.md index 1b13d7e5..e19c536d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +3.5.0 (TDB) +=========== + +- Added support for the Scientific Python version of OpenBLAS + (https://github.com/MacPython/openblas-libs), which exposes symbols with different + names than the ones of the original OpenBLAS library. + https://github.com/joblib/threadpoolctl/pull/175 + 3.4.0 (2024-03-20) ================== diff --git a/continuous_integration/install.sh b/continuous_integration/install.sh index 86d77fdb..d0e8402f 100755 --- a/continuous_integration/install.sh +++ b/continuous_integration/install.sh @@ -65,6 +65,15 @@ elif [[ "$PACKAGER" == "pip" ]]; then pip install numpy scipy fi +elif [[ "$PACKAGER" == "pip-dev" ]]; then + # Use conda to build an empty python env and then use pip to install + # numpy and scipy dev versions + TO_INSTALL="python=$PYTHON_VERSION pip" + make_conda $TO_INSTALL + + dev_anaconda_url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + pip install --pre --upgrade --timeout=60 --extra-index $dev_anaconda_url numpy scipy + elif [[ "$PACKAGER" == "ubuntu" ]]; then # Remove the ubuntu toolchain PPA that seems to be invalid: # https://github.com/scikit-learn/scikit-learn/pull/13934 diff --git a/continuous_integration/posix.yml b/continuous_integration/posix.yml index 5d4fe982..a1598d37 100644 --- a/continuous_integration/posix.yml +++ b/continuous_integration/posix.yml @@ -14,7 +14,7 @@ jobs: steps: - bash: echo "##vso[task.prependpath]$CONDA/bin" displayName: Add conda to PATH - condition: or(startsWith(variables['PACKAGER'], 'conda'), eq(variables['PACKAGER'], 'pip')) + condition: or(startsWith(variables['PACKAGER'], 'conda'), startsWith(variables['PACKAGER'], 'pip')) - bash: sudo chown -R $USER $CONDA # On Hosted macOS, the agent user doesn't have ownership of Miniconda's installation directory/ # We need to take ownership if we want to update conda or install packages globally diff --git a/continuous_integration/test_script.sh b/continuous_integration/test_script.sh index 2864f584..8eeaba19 100755 --- a/continuous_integration/test_script.sh +++ b/continuous_integration/test_script.sh @@ -5,7 +5,7 @@ set -e if [[ "$PACKAGER" == conda* ]]; then source activate $VIRTUALENV conda list -elif [[ "$PACKAGER" == "pip" ]]; then +elif [[ "$PACKAGER" == pip* ]]; then # we actually use conda to install the base environment: source activate $VIRTUALENV pip list diff --git a/threadpoolctl.py b/threadpoolctl.py index 8cbb869e..c53b33a7 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -15,6 +15,7 @@ import re import sys import ctypes +import itertools import textwrap from typing import final import warnings @@ -111,20 +112,19 @@ def __init__(self, *, filepath=None, prefix=None, parent=None): self.prefix = prefix self.filepath = filepath self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD) + self._symbol_prefix, self._symbol_suffix = self._find_affixes() self.version = self.get_version() self.set_additional_attributes() def info(self): """Return relevant info wrapped in a dict""" - exposed_attrs = { + hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix") + return { "user_api": self.user_api, "internal_api": self.internal_api, "num_threads": self.num_threads, - **vars(self), + **{k: v for k, v in vars(self).items() if k not in hidden_attrs}, } - exposed_attrs.pop("dynlib") - exposed_attrs.pop("parent") - return exposed_attrs def set_additional_attributes(self): """Set additional attributes meant to be exposed in the info dict""" @@ -149,96 +149,87 @@ def set_num_threads(self, num_threads): def get_version(self): """Return the version of the shared library""" + def _find_affixes(self): + """Return the affixes for the symbols of the shared library""" + return "", "" + + def _get_symbol(self, name): + """Return the symbol of the shared library accounding for the affixes""" + return getattr( + self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None + ) + class OpenBLASController(LibController): """Controller class for OpenBLAS""" user_api = "blas" internal_api = "openblas" - filename_prefixes = ("libopenblas", "libblas") - check_symbols = ( - "openblas_get_num_threads", - "openblas_get_num_threads64_", - "openblas_set_num_threads", - "openblas_set_num_threads64_", - "openblas_get_config", - "openblas_get_config64_", - "openblas_get_parallel", - "openblas_get_parallel64_", - "openblas_get_corename", - "openblas_get_corename64_", + filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas") + + _symbol_prefixes = ("", "scipy_") + _symbol_suffixes = ("", "64_", "_64") + + # All variations of "openblas_get_num_threads", accounting for the affixes + check_symbols = tuple( + f"{prefix}openblas_get_num_threads{suffix}" + for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes) ) + def _find_affixes(self): + for prefix, suffix in itertools.product( + self._symbol_prefixes, self._symbol_suffixes + ): + if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"): + return prefix, suffix + def set_additional_attributes(self): self.threading_layer = self._get_threading_layer() self.architecture = self._get_architecture() def get_num_threads(self): - get_func = getattr( - self.dynlib, - "openblas_get_num_threads", - # Symbols differ when built for 64bit integers in Fortran - getattr(self.dynlib, "openblas_get_num_threads64_", lambda: None), - ) - - return get_func() + get_num_threads_func = self._get_symbol("openblas_get_num_threads") + if get_num_threads_func is not None: + return get_num_threads_func() + return None def set_num_threads(self, num_threads): - set_func = getattr( - self.dynlib, - "openblas_set_num_threads", - # Symbols differ when built for 64bit integers in Fortran - getattr( - self.dynlib, "openblas_set_num_threads64_", lambda num_threads: None - ), - ) - return set_func(num_threads) + set_num_threads_func = self._get_symbol("openblas_set_num_threads") + if set_num_threads_func is not None: + return set_num_threads_func(num_threads) + return None def get_version(self): # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS # did not expose its version before that. - get_config = getattr( - self.dynlib, - "openblas_get_config", - getattr(self.dynlib, "openblas_get_config64_", None), - ) - if get_config is None: + get_version_func = self._get_symbol("openblas_get_config") + if get_version_func is not None: + get_version_func.restype = ctypes.c_char_p + config = get_version_func().split() + if config[0] == b"OpenBLAS": + return config[1].decode("utf-8") return None - - get_config.restype = ctypes.c_char_p - config = get_config().split() - if config[0] == b"OpenBLAS": - return config[1].decode("utf-8") return None def _get_threading_layer(self): """Return the threading layer of OpenBLAS""" - openblas_get_parallel = getattr( - self.dynlib, - "openblas_get_parallel", - getattr(self.dynlib, "openblas_get_parallel64_", None), - ) - if openblas_get_parallel is None: - return "unknown" - threading_layer = openblas_get_parallel() - if threading_layer == 2: - return "openmp" - elif threading_layer == 1: - return "pthreads" - return "disabled" + get_threading_layer_func = self._get_symbol("openblas_get_parallel") + if get_threading_layer_func is not None: + threading_layer = get_threading_layer_func() + if threading_layer == 2: + return "openmp" + elif threading_layer == 1: + return "pthreads" + return "disabled" + return "unknown" def _get_architecture(self): """Return the architecture detected by OpenBLAS""" - get_corename = getattr( - self.dynlib, - "openblas_get_corename", - getattr(self.dynlib, "openblas_get_corename64_", None), - ) - if get_corename is None: - return None - - get_corename.restype = ctypes.c_char_p - return get_corename().decode("utf-8") + get_architecture_func = self._get_symbol("openblas_get_corename") + if get_architecture_func is not None: + get_architecture_func.restype = ctypes.c_char_p + return get_architecture_func().decode("utf-8") + return None class BLISController(LibController):