Skip to content

Commit

Permalink
Merge pull request #420 from lnccbrown/419-drop-black-in-favor-of-ruf…
Browse files Browse the repository at this point in the history
…f-formatter
  • Loading branch information
digicosmos86 authored May 8, 2024
2 parents 20225e0 + e99817e commit e8f16b1
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 71 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build_and_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ jobs:
pip install -e .
- name: Run mypy
run: mypy src
run: mypy src/hssm

- name: Check styling
run: black . --check
- name: Check formatting
run: ruff format --check .

- name: Linting
run: ruff check .
run: ruff check src/hssm

- name: Run tests
run: pytest -n auto -s
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ jobs:
pip install -e .
- name: Run mypy
run: mypy src
run: mypy src/hssm

- name: Check styling
run: black . --check
- name: Check formatting
run: ruff format --check .

- name: Linting
run: ruff check .
run: ruff check src/hssm

- name: Run tests
run: pytest -n auto -s
Expand Down
24 changes: 3 additions & 21 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,13 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.3
rev: v0.4.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/psf/black
rev: 23.10.1
hooks:
- id: black-jupyter
args:
- --line-length=88
- --include='\.pyi?$'

# these folders wont be formatted by black
- --exclude="""\.git |
\.__pycache__|
\.hg|
\.mypy_cache|
\.tox|
\.venv|
_build|
buck-out|
build|
dist"""
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1 # Use the sha / tag you want to point at
rev: v1.10.0 # Use the sha / tag you want to point at
hooks:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
56 changes: 20 additions & 36 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,42 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
pymc = ">=5.12"
pymc = "^5.14.0"
arviz = "^0.18.0"
onnx = "^1.12.0"
onnx = "^1.16.0"
jax = "^0.4.25"
jaxlib = "^0.4.25"
ssm-simulators = "^0.7.0"
huggingface-hub = "^0.15.1"
ssm-simulators = "^0.7.2"
huggingface-hub = "^0.23.0"
bambi = "^0.13.0"
numpyro = "^0.14.0"
hddm-wfpt = "^0.1.1"
hddm-wfpt = "^0.1.4"
seaborn = "^0.13.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
black = { extras = ["jupyter"], version = "^23.10.1" }
mypy = "^1.6.1"
pytest = "^8.2.0"
mypy = "^1.10.0"
pre-commit = "^2.20.0"
jupyterlab = "^4.0.2"
ipykernel = "^6.16.0"
ipywidgets = "^8.0.3"
ruff = "^0.1.3"
mkdocs = "^1.4.3"
mkdocs-material = "^9.1.17"
mkdocstrings-python = "^1.1.2"
mkdocs-jupyter = "^0.24.1"
ruff = "^0.4.3"
graphviz = "^0.20.1"
pytest-xdist = "^3.5.0"
pytest-xdist = "^3.6.1"
onnxruntime = "^1.17.1"

[tool.black]
line-length = 88
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[tool.isort]
profile = "black"
mkdocs = "^1.6.0"
mkdocs-material = "^9.5.21"
mkdocstrings-python = "^1.10.0"
mkdocs-jupyter = "^0.24.7"

[tool.ruff]
line-length = 88
target-version = "py310"

[tool.ruff.format]
docstring-code-format = true

[tool.ruff.lint]
unfixable = ["E711"]

select = [
Expand Down Expand Up @@ -166,9 +150,9 @@ ignore = [
"TID252",
]

exclude = [".github", "docs", "notebook", "tests"]
exclude = [".github", "docs", "notebook", "tests/*"]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.mypy]
Expand Down
4 changes: 1 addition & 3 deletions src/hssm/distribution_utils/onnx/onnx2pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ def onnx_add(a, b, axis=None, broadcast=True):
return [pt.add(a, b)]


def pytensor_gemm(
a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0
): # pylint: disable=C0103
def pytensor_gemm(a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0): # pylint: disable=C0103
"""Perform the GEMM op.
Numpy-backed implementatio, of ONNX General Matrix Multiply (GeMM) op.
Expand Down
4 changes: 1 addition & 3 deletions src/hssm/distribution_utils/onnx/onnx2xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def onnx_add(a, b, axis=None, broadcast=True):


# Added by HSSM Developers
def onnx_gemm(
a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0
): # pylint: disable=C0103
def onnx_gemm(a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0): # pylint: disable=C0103
"""Numpy-backed implementatio of Onnx Gemm op."""
a = jnp.transpose(a) if transA else a
b = jnp.transpose(b) if transB else b
Expand Down

0 comments on commit e8f16b1

Please sign in to comment.