Skip to content

Commit

Permalink
add ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 19, 2023
1 parent fb7a029 commit 49e28cb
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 53 deletions.
7 changes: 2 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
all: test

lint: FORCE
flake8
black --check .
isort --check .
ruff .
python scripts/update_headers.py --check

license: FORCE
python scripts/update_headers.py

format: license FORCE
black .
isort .
ruff . --fix

install: FORCE
pip install -e .[dev,doc,test,examples]
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

if "READTHEDOCS" not in os.environ:
# if developing locally, use numpyro.__version__ as version
from numpyro import __version__ # noqaE402
from numpyro import __version__ # noqa: E402

version = __version__

Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,11 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
)
init_state = new_init_state if init_state is None else init_state
sample_fn, postprocess_fn = self._get_cached_fns()
diagnostics = (
diagnostics = ( # noqa: E731
lambda x: self.sampler.get_diagnostics_str(x[0])
if is_prng_key(rng_key)
else ""
) # noqa: E731
)
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]
Expand Down
67 changes: 67 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"docs/src",
"node_modules",
"site-packages",
"venv",
]

# Same as Black.
line-length = 120
indent-width = 4

# Assume Python 3.8
target-version = "py38"

[tool.ruff.lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
select = ["E4", "E7", "E9", "F"]
ignore = ["E203"]

# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = []

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

[tool.ruff.extend-per-file-ignores]
"numpyro/contrib/tfp/distributions.py" = ["F811"]
"numpyro/distributions/kl.py" = ["F811"]
20 changes: 0 additions & 20 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,23 +1,3 @@
[flake8]
max-line-length = 120
exclude = docs/src, build, dist, .ipynb_checkpoints
ignore = W503,E203
per-file-ignores =
numpyro/contrib/tfp/distributions.py:F811
numpyro/distributions/kl.py:F811

[isort]
profile = black
skip_glob = .ipynb_checkpoints
known_first_party = funsor, numpyro, test
known_third_party = opt_einsum
known_jax = flax, haiku, jax, optax, tensorflow_probability
sections = FUTURE, STDLIB, THIRDPARTY, JAX, FIRSTPARTY, LOCALFOLDER
force_sort_within_sections = true
combine_as_imports = true
multi_line_output = 3
skip=docs

[tool:pytest]
filterwarnings = error
ignore:numpy.ufunc size changed,:RuntimeWarning
Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@
],
"test": [
"importlib-metadata<5.0",
"black[jupyter]>=21.8b0",
"flake8",
"importlib-metadata<5.0",
"isort>=5.0",
"ruff>=0.1.8",
"pytest>=4.1",
"pyro-api>=0.1.1",
"scipy>=1.9",
Expand Down
31 changes: 9 additions & 22 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from copy import deepcopy

import numpy as np
from numpy.testing import assert_allclose
import pytest

from jax import random
from jax.tree_util import tree_all, tree_map
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.module import (
ParamShape,
Expand All @@ -20,12 +20,9 @@
random_flax_module,
random_haiku_module,
)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

pytestmark = pytest.mark.filterwarnings(
"ignore:jax.tree_.+ is deprecated:FutureWarning"
)
pytestmark = pytest.mark.filterwarnings("ignore:jax.tree_.+ is deprecated:FutureWarning")


def haiku_model_by_shape(x, y):
Expand Down Expand Up @@ -119,16 +116,12 @@ def test_haiku_module():
100,
100,
)
assert haiku_tr["nn$params"]["value"]["test_haiku_module/w_linear"]["b"].shape == (
100,
)
assert haiku_tr["nn$params"]["value"]["test_haiku_module/w_linear"]["b"].shape == (100,)
assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["w"].shape == (
100,
100,
)
assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["b"].shape == (
100,
)
assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["b"].shape == (100,)


def test_update_params():
Expand All @@ -137,9 +130,7 @@ def test_update_params():
new_params = deepcopy(params)
with handlers.seed(rng_seed=0):
_update_params(params, new_params, prior)
assert params == {
"a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))}
}
assert params == {"a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))}}

tree_all(
tree_map(
Expand Down Expand Up @@ -194,7 +185,7 @@ def test_random_module_mcmc(backend, init, callable_prior):
kwargs = {}

if callable_prior:
prior = (
prior = ( # noqa: E731
lambda name, shape: dist.Cauchy() if name == bias_name else dist.Normal()
)
else:
Expand All @@ -206,9 +197,7 @@ def model(data, labels):
numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

kernel = NUTS(model=model)
mcmc = MCMC(
kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False)
mcmc.run(random.PRNGKey(2), data, labels)
mcmc.print_summary()
samples = mcmc.get_samples()
Expand All @@ -232,9 +221,7 @@ def fn(x):
if dropout:
x = hk.dropout(hk.next_rng_key(), 0.5, x)
if batchnorm:
x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(
x, is_training=True
)
x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(x, is_training=True)
return x

def model():
Expand Down

0 comments on commit 49e28cb

Please sign in to comment.