Skip to content

Commit

Permalink
FEA Extend OpenBLAS controller to support scipy_openblas (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb authored Apr 29, 2024
1 parent 5282c0b commit 8dd980b
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 70 deletions.
6 changes: 6 additions & 0 deletions .azure_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ stages:
name: Linux
vmImage: ubuntu-20.04
matrix:
# Linux environment with development versions of numpy and scipy
pylatest_pip_dev:
PACKAGER: 'pip-dev'
PYTHON_VERSION: '*'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'gcc'
# Linux environment to test that packages that comes with Ubuntu 20.04
# are correctly handled.
py38_ubuntu_atlas_gcc_gcc:
Expand Down
8 changes: 8 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
3.5.0 (TDB)
===========

- Added support for the Scientific Python version of OpenBLAS
(https://github.com/MacPython/openblas-libs), which exposes symbols with different
names than the ones of the original OpenBLAS library.
https://github.com/joblib/threadpoolctl/pull/175

3.4.0 (2024-03-20)
==================

Expand Down
9 changes: 9 additions & 0 deletions continuous_integration/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ elif [[ "$PACKAGER" == "pip" ]]; then
pip install numpy scipy
fi

elif [[ "$PACKAGER" == "pip-dev" ]]; then
# Use conda to build an empty python env and then use pip to install
# numpy and scipy dev versions
TO_INSTALL="python=$PYTHON_VERSION pip"
make_conda $TO_INSTALL

dev_anaconda_url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
pip install --pre --upgrade --timeout=60 --extra-index $dev_anaconda_url numpy scipy

elif [[ "$PACKAGER" == "ubuntu" ]]; then
# Remove the ubuntu toolchain PPA that seems to be invalid:
# https://github.com/scikit-learn/scikit-learn/pull/13934
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/posix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
steps:
- bash: echo "##vso[task.prependpath]$CONDA/bin"
displayName: Add conda to PATH
condition: or(startsWith(variables['PACKAGER'], 'conda'), eq(variables['PACKAGER'], 'pip'))
condition: or(startsWith(variables['PACKAGER'], 'conda'), startsWith(variables['PACKAGER'], 'pip'))
- bash: sudo chown -R $USER $CONDA
# On Hosted macOS, the agent user doesn't have ownership of Miniconda's installation directory/
# We need to take ownership if we want to update conda or install packages globally
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/test_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -e
if [[ "$PACKAGER" == conda* ]]; then
source activate $VIRTUALENV
conda list
elif [[ "$PACKAGER" == "pip" ]]; then
elif [[ "$PACKAGER" == pip* ]]; then
# we actually use conda to install the base environment:
source activate $VIRTUALENV
pip list
Expand Down
127 changes: 59 additions & 68 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import re
import sys
import ctypes
import itertools
import textwrap
from typing import final
import warnings
Expand Down Expand Up @@ -111,20 +112,19 @@ def __init__(self, *, filepath=None, prefix=None, parent=None):
self.prefix = prefix
self.filepath = filepath
self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
self._symbol_prefix, self._symbol_suffix = self._find_affixes()
self.version = self.get_version()
self.set_additional_attributes()

def info(self):
"""Return relevant info wrapped in a dict"""
exposed_attrs = {
hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
return {
"user_api": self.user_api,
"internal_api": self.internal_api,
"num_threads": self.num_threads,
**vars(self),
**{k: v for k, v in vars(self).items() if k not in hidden_attrs},
}
exposed_attrs.pop("dynlib")
exposed_attrs.pop("parent")
return exposed_attrs

def set_additional_attributes(self):
"""Set additional attributes meant to be exposed in the info dict"""
Expand All @@ -149,96 +149,87 @@ def set_num_threads(self, num_threads):
def get_version(self):
"""Return the version of the shared library"""

def _find_affixes(self):
"""Return the affixes for the symbols of the shared library"""
return "", ""

def _get_symbol(self, name):
"""Return the symbol of the shared library accounding for the affixes"""
return getattr(
self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None
)


class OpenBLASController(LibController):
"""Controller class for OpenBLAS"""

user_api = "blas"
internal_api = "openblas"
filename_prefixes = ("libopenblas", "libblas")
check_symbols = (
"openblas_get_num_threads",
"openblas_get_num_threads64_",
"openblas_set_num_threads",
"openblas_set_num_threads64_",
"openblas_get_config",
"openblas_get_config64_",
"openblas_get_parallel",
"openblas_get_parallel64_",
"openblas_get_corename",
"openblas_get_corename64_",
filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas")

_symbol_prefixes = ("", "scipy_")
_symbol_suffixes = ("", "64_", "_64")

# All variations of "openblas_get_num_threads", accounting for the affixes
check_symbols = tuple(
f"{prefix}openblas_get_num_threads{suffix}"
for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes)
)

def _find_affixes(self):
for prefix, suffix in itertools.product(
self._symbol_prefixes, self._symbol_suffixes
):
if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"):
return prefix, suffix

def set_additional_attributes(self):
self.threading_layer = self._get_threading_layer()
self.architecture = self._get_architecture()

def get_num_threads(self):
get_func = getattr(
self.dynlib,
"openblas_get_num_threads",
# Symbols differ when built for 64bit integers in Fortran
getattr(self.dynlib, "openblas_get_num_threads64_", lambda: None),
)

return get_func()
get_num_threads_func = self._get_symbol("openblas_get_num_threads")
if get_num_threads_func is not None:
return get_num_threads_func()
return None

def set_num_threads(self, num_threads):
set_func = getattr(
self.dynlib,
"openblas_set_num_threads",
# Symbols differ when built for 64bit integers in Fortran
getattr(
self.dynlib, "openblas_set_num_threads64_", lambda num_threads: None
),
)
return set_func(num_threads)
set_num_threads_func = self._get_symbol("openblas_set_num_threads")
if set_num_threads_func is not None:
return set_num_threads_func(num_threads)
return None

def get_version(self):
# None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
# did not expose its version before that.
get_config = getattr(
self.dynlib,
"openblas_get_config",
getattr(self.dynlib, "openblas_get_config64_", None),
)
if get_config is None:
get_version_func = self._get_symbol("openblas_get_config")
if get_version_func is not None:
get_version_func.restype = ctypes.c_char_p
config = get_version_func().split()
if config[0] == b"OpenBLAS":
return config[1].decode("utf-8")
return None

get_config.restype = ctypes.c_char_p
config = get_config().split()
if config[0] == b"OpenBLAS":
return config[1].decode("utf-8")
return None

def _get_threading_layer(self):
"""Return the threading layer of OpenBLAS"""
openblas_get_parallel = getattr(
self.dynlib,
"openblas_get_parallel",
getattr(self.dynlib, "openblas_get_parallel64_", None),
)
if openblas_get_parallel is None:
return "unknown"
threading_layer = openblas_get_parallel()
if threading_layer == 2:
return "openmp"
elif threading_layer == 1:
return "pthreads"
return "disabled"
get_threading_layer_func = self._get_symbol("openblas_get_parallel")
if get_threading_layer_func is not None:
threading_layer = get_threading_layer_func()
if threading_layer == 2:
return "openmp"
elif threading_layer == 1:
return "pthreads"
return "disabled"
return "unknown"

def _get_architecture(self):
"""Return the architecture detected by OpenBLAS"""
get_corename = getattr(
self.dynlib,
"openblas_get_corename",
getattr(self.dynlib, "openblas_get_corename64_", None),
)
if get_corename is None:
return None

get_corename.restype = ctypes.c_char_p
return get_corename().decode("utf-8")
get_architecture_func = self._get_symbol("openblas_get_corename")
if get_architecture_func is not None:
get_architecture_func.restype = ctypes.c_char_p
return get_architecture_func().decode("utf-8")
return None


class BLISController(LibController):
Expand Down

0 comments on commit 8dd980b

Please sign in to comment.