Skip to content

Commit

Permalink
Hotfix for issue #25; setup.py now restricts installation of incompat…
Browse files Browse the repository at this point in the history
…ible jax versions

Also adds fixes for old jax version tests for which jaxlib cannot be installed from pypi anymore.
  • Loading branch information
lumip committed Mar 1, 2024
1 parent f3aa237 commit 709d3dc
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 25 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/build_wheels_and_publish.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2023 Aalto University
# SPDX-FileCopyrightText: © 2023 Aalto University, © 2024 Lukas Prediger

name: Build_pypi_wheels

Expand Down Expand Up @@ -36,12 +36,6 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: 'recursive'

# - uses: actions/setup-python@v4
# name: Install Python
# with:
# python-version: '3.9'

- name: Set up QEMU
if: runner.os == 'Linux' && matrix.arch == 'aarch64'
uses: docker/setup-qemu-action@v2
Expand Down
39 changes: 34 additions & 5 deletions .github/workflows/jax_compatibility_tests.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2023 Aalto University
# SPDX-FileCopyrightText: © 2023 Aalto University, © 2024 Aalto University

name: Jax Compatibility Tests

Expand Down Expand Up @@ -33,7 +33,7 @@ jobs:
python -m pip install pytest
- name: Install dependencies
run: |
python -m pip install "jax[minimum-jaxlib]==${{ matrix.jax-version }}" "numpy < 1.24"
python -m pip install "jax[minimum-jaxlib]==${{ matrix.jax-version }}" "numpy < 1.24" -f https://storage.googleapis.com/jax-releases/jax_releases.html
python -m pip install .
- name: Test with pytest
run: |
Expand All @@ -46,8 +46,7 @@ jobs:
matrix:
jax-version: [
0.2.18, 0.2.19, 0.2.20, 0.2.21,
0.2.22, 0.2.27, 0.3.1, 0.3.13, 0.3.15,
0.3.17, 0.3.23, 0.3.25
0.2.22, 0.2.27
]

steps:
Expand All @@ -64,7 +63,37 @@ jobs:
python -m pip install pytest
- name: Install dependencies
run: |
python -m pip install "jax[minimum-jaxlib]==${{ matrix.jax-version }}"
python -m pip install "jax[minimum-jaxlib]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_releases.html
python -m pip install .
- name: Test with pytest
run: |
pytest tests/
unittests-with-current-jaxlib:
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
jax-version: [
0.3.1, 0.3.13, 0.3.15,
0.3.17, 0.3.23, 0.3.25, 0.4.14
]

steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install environment
run: |
python -m pip install --upgrade pip
python -m pip install pytest
- name: Install dependencies
run: |
python -m pip install "jax[cpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_releases.html
python -m pip install .
- name: Test with pytest
run: |
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/python_unittests.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2021 Aalto University
# SPDX-FileCopyrightText: © 2023 Aalto University, © 2024 Lukas Prediger

name: Python Unittests

Expand Down Expand Up @@ -37,7 +37,7 @@ jobs:
flake8 --count --exit-zero --max-complexity=10 --max-line-length=120 --statistics --ignore=E266 chacha tests
- name: Install dependencies
run: |
python -m pip install -e .[tests]
python -m pip install -e .[tests] -f https://storage.googleapis.com/jax-releases/jax_releases.html
- name: Test with pytest
run: |
pytest --cov=chacha --cov-report term-missing tests/
Expand All @@ -58,7 +58,7 @@ jobs:
python -m pip install mypy
- name: Install dependencies
run: |
python -m pip install .[tests]
python -m pip install .[tests] -f https://storage.googleapis.com/jax-releases/jax_releases.html
- name: Type checking
continue-on-error: true
run: |
Expand Down
2 changes: 2 additions & 0 deletions ChangeLog.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
- 1.4.2:
- Hotfix: preventing installation from pulling incompatible jax versions >0.4.14
- 1.4.1:
- Fix: native module not built for correct Python verison for MacOS.
- Fix: Wheels for MacOS no longer build with OpenMP to avoid complicatons with conda environments.
Expand Down
2 changes: 1 addition & 1 deletion ChangeLog.txt.license
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2021 Aalto University
# SPDX-FileCopyrightText: © 2023 Aalto University, © 2024 Aalto University
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ JAX version known to be compatible with JAX-ChaCha-PRNG:
pip install .[compatible-jax]
```

JAX-ChaCha-PRNG is currently known to work reliably with JAX versions 0.2.12 - 0.3.25 .
JAX-ChaCha-PRNG is currently known to work reliably with JAX versions 0.2.12 - 0.4.14 .
We regularly check the compatible version range, but do not expect new versions of JAX to be immediately tested.

## Versioning
Expand Down
2 changes: 1 addition & 1 deletion README.md.license
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2021 Aalto University
# SPDX-FileCopyrightText: © 2023 Aalto University, © 2024 Aalto University
4 changes: 2 additions & 2 deletions chacha/version.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2023 Aalto University
# SPDX-FileCopyrightText: © 2023 Aalto University, © 2024 Lukas Prediger

MAJOR_VERSION = 1
MINOR_VERSION = 4
PATCH_VERSION = 1
PATCH_VERSION = 2
EXT_VERSION = ""
POST_VERSION = ""

Expand Down
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2022 Aalto University
# SPDX-FileCopyrightText: © 2023 Aalto University, © 2024 Lukas Prediger

import setuptools
from setuptools import Extension
Expand Down Expand Up @@ -84,8 +84,8 @@ def build_extension(self, ext):
spec.loader.exec_module(version_module)

_jax_version_lower_constraint = ' >= 0.2.12'
_jax_version_optimistic_upper_constraint = ', <= 2.0.0'
_jax_version_upper_constraint = ', <= 0.4.8'
_jax_version_optimistic_upper_constraint = ', <= 0.4.14'
_jax_version_upper_constraint = ', <= 0.4.14'

_version = version_module.VERSION
if 'JAX_CHACHA_PRNG_BUILD' in os.environ:
Expand All @@ -101,15 +101,15 @@ def build_extension(self, ext):
long_description_content_type="text/markdown",
url="https://github.com/DPBayes/jax-chacha-prng",
packages=setuptools.find_packages(include=['chacha', 'chacha.*']),
python_requires='>=3.6',
python_requires='>=3.6, <3.12',
install_requires=[
"numpy >= 1.16, < 2",
"deprecation < 3",
f"jax{_jax_version_lower_constraint}{_jax_version_optimistic_upper_constraint}"
],
extras_require={
"tests": [
f"jax[minimum-jaxlib]",
"jax[cpu]",
"pytest"
],
"compatible-jax": [f"jax{_jax_version_lower_constraint}{_jax_version_upper_constraint}"]
Expand Down

0 comments on commit 709d3dc

Please sign in to comment.