Skip to content

Commit

Permalink
Rework optional dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanluoyc committed Oct 26, 2023
1 parent 01689e9 commit 8267905
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
cache: 'pip'

- name: Install Dependencies
run: pip install -e '.[dev]'
run: pip install -e '.[tf,jax,test,dev]'

- name: Lint
run: |
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ cd corax
python3 -m venv .venv
source .venv/bin/activate
# Then run
pip install -e '.[dev]'
pip install -e '.[tf,jax,test,dev]'
# Install pre-commit hooks if you intend to create PRs.
pre-commit install
# Install the baselines by running
Expand Down
2 changes: 1 addition & 1 deletion hatch.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[envs.default]
python = "3.9"
features = ["jax", "tf", "dev"]
features = ["jax", "tf", "test", "dev"]
pre-install-commands = [
"pip install -r projects/baselines/requirements.txt",
]
Expand Down
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,24 @@ source = "vcs"

[project.optional-dependencies]
tf = [
# TF dependencies
"tensorflow-cpu~=2.13.0",
"tensorflow-probability~=0.21.0",
"tensorflow-datasets~=4.9.3",
"dm-reverb~=0.12.0",
"rlds",
]
jax = [
"corax[tf]",
"tensorflow-probability",
"jax",
"jaxlib",
"dm-haiku",
"flax",
"optax",
"rlax",
"chex",
"dm_env_wrappers",
]
all = ["corax[tf,jax]"]
test = [
"corax[all]",
"pytest",
"pytest-xdist",
"dill", # required for tfds tests
Expand All @@ -54,7 +52,7 @@ test = [
"ott-jax",
"dm-control",
]
dev = ["corax[test]", "black", "ruff", "pre-commit"]
dev = ["black", "ruff", "pre-commit"]

[tool.hatch.build]
directory = "dist"
Expand Down

0 comments on commit 8267905

Please sign in to comment.