From c9546209285913b7cd2b8e3b02d4b802b40469e4 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 8 May 2024 10:24:52 -0400 Subject: [PATCH 1/3] update pyproject.toml --- .pre-commit-config.yaml | 24 +++--------------- pyproject.toml | 54 +++++++++++++++-------------------------- 2 files changed, 22 insertions(+), 56 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2f91c585..191a2ec0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,31 +5,13 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.4.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/psf/black - rev: 23.10.1 - hooks: - - id: black-jupyter - args: - - --line-length=88 - - --include='\.pyi?$' - - # these folders wont be formatted by black - - --exclude="""\.git | - \.__pycache__| - \.hg| - \.mypy_cache| - \.tox| - \.venv| - _build| - buck-out| - build| - dist""" + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 # Use the sha / tag you want to point at + rev: v1.10.0 # Use the sha / tag you want to point at hooks: - id: mypy args: [--no-strict-optional, --ignore-missing-imports] diff --git a/pyproject.toml b/pyproject.toml index d53be14b..4ef08e0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,58 +16,42 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"] [tool.poetry.dependencies] python = ">=3.10,<3.12" -pymc = ">=5.12" +pymc = "^5.14.0" arviz = "^0.18.0" -onnx = "^1.12.0" +onnx = "^1.16.0" jax = "^0.4.25" jaxlib = "^0.4.25" -ssm-simulators = "^0.7.0" -huggingface-hub = "^0.15.1" +ssm-simulators = "^0.7.2" +huggingface-hub = "^0.23.0" bambi = "^0.13.0" numpyro = "^0.14.0" -hddm-wfpt = "^0.1.1" +hddm-wfpt = "^0.1.4" seaborn = "^0.13.2" [tool.poetry.group.dev.dependencies] -pytest = "^7.3.1" -black = { extras = ["jupyter"], version = "^23.10.1" } -mypy = "^1.6.1" +pytest = "^8.2.0" +mypy = "^1.10.0" pre-commit = "^2.20.0" jupyterlab = "^4.0.2" ipykernel = "^6.16.0" ipywidgets = "^8.0.3" -ruff = "^0.1.3" -mkdocs = "^1.4.3" -mkdocs-material = "^9.1.17" -mkdocstrings-python = "^1.1.2" -mkdocs-jupyter = "^0.24.1" +ruff = "^0.4.3" graphviz = "^0.20.1" -pytest-xdist = "^3.5.0" +pytest-xdist = "^3.6.1" onnxruntime = "^1.17.1" - -[tool.black] -line-length = 88 -include = '\.pyi?$' -exclude = ''' -/( - \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist -)/ -''' - -[tool.isort] -profile = "black" +mkdocs = "^1.6.0" +mkdocs-material = "^9.5.21" +mkdocstrings-python = "^1.10.0" +mkdocs-jupyter = "^0.24.7" [tool.ruff] line-length = 88 target-version = "py310" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] unfixable = ["E711"] select = [ @@ -168,7 +152,7 @@ ignore = [ exclude = [".github", "docs", "notebook", "tests"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "numpy" [tool.mypy] From 908e6404573016d1d3b16365388933a2b195bb0f Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 8 May 2024 10:25:33 -0400 Subject: [PATCH 2/3] Ruff-black incompatibilities --- src/hssm/distribution_utils/onnx/onnx2pt.py | 4 +--- src/hssm/distribution_utils/onnx/onnx2xla.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/hssm/distribution_utils/onnx/onnx2pt.py b/src/hssm/distribution_utils/onnx/onnx2pt.py index cd562f9c..9c04bb33 100644 --- a/src/hssm/distribution_utils/onnx/onnx2pt.py +++ b/src/hssm/distribution_utils/onnx/onnx2pt.py @@ -17,9 +17,7 @@ def onnx_add(a, b, axis=None, broadcast=True): return [pt.add(a, b)] -def pytensor_gemm( - a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0 -): # pylint: disable=C0103 +def pytensor_gemm(a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0): # pylint: disable=C0103 """Perform the GEMM op. Numpy-backed implementatio, of ONNX General Matrix Multiply (GeMM) op. diff --git a/src/hssm/distribution_utils/onnx/onnx2xla.py b/src/hssm/distribution_utils/onnx/onnx2xla.py index 76c386b7..a2142be1 100644 --- a/src/hssm/distribution_utils/onnx/onnx2xla.py +++ b/src/hssm/distribution_utils/onnx/onnx2xla.py @@ -102,9 +102,7 @@ def onnx_add(a, b, axis=None, broadcast=True): # Added by HSSM Developers -def onnx_gemm( - a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0 -): # pylint: disable=C0103 +def onnx_gemm(a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0): # pylint: disable=C0103 """Numpy-backed implementatio of Onnx Gemm op.""" a = jnp.transpose(a) if transA else a b = jnp.transpose(b) if transB else b From e99817ee29ac29f6f45a380190b13adf0cca7c15 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 8 May 2024 10:32:56 -0400 Subject: [PATCH 3/3] update CI scripts --- .github/workflows/build_and_publish.yml | 8 ++++---- .github/workflows/run_tests.yml | 8 ++++---- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml index 5dd19985..5a5f6ed1 100644 --- a/.github/workflows/build_and_publish.yml +++ b/.github/workflows/build_and_publish.yml @@ -41,13 +41,13 @@ jobs: pip install -e . - name: Run mypy - run: mypy src + run: mypy src/hssm - - name: Check styling - run: black . --check + - name: Check formatting + run: ruff format --check . - name: Linting - run: ruff check . + run: ruff check src/hssm - name: Run tests run: pytest -n auto -s diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index a143656e..3708c2c5 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -39,13 +39,13 @@ jobs: pip install -e . - name: Run mypy - run: mypy src + run: mypy src/hssm - - name: Check styling - run: black . --check + - name: Check formatting + run: ruff format --check . - name: Linting - run: ruff check . + run: ruff check src/hssm - name: Run tests run: pytest -n auto -s diff --git a/pyproject.toml b/pyproject.toml index 4ef08e0a..c0dd632d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,7 +150,7 @@ ignore = [ "TID252", ] -exclude = [".github", "docs", "notebook", "tests"] +exclude = [".github", "docs", "notebook", "tests/*"] [tool.ruff.lint.pydocstyle] convention = "numpy"