diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f3577ef..2f3abbe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,7 @@ jobs: cache: 'pip' - name: Install Dependencies - run: pip install -e '.[dev]' + run: pip install -e '.[tf,jax,test,dev]' - name: Lint run: | diff --git a/README.md b/README.md index 565338e..78f2d4c 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ cd corax python3 -m venv .venv source .venv/bin/activate # Then run -pip install -e '.[dev]' +pip install -e '.[tf,jax,test,dev]' # Install pre-commit hooks if you intend to create PRs. pre-commit install # Install the baselines by running diff --git a/hatch.toml b/hatch.toml index f068108..05de692 100644 --- a/hatch.toml +++ b/hatch.toml @@ -1,6 +1,6 @@ [envs.default] python = "3.9" -features = ["jax", "tf", "dev"] +features = ["jax", "tf", "test", "dev"] pre-install-commands = [ "pip install -r projects/baselines/requirements.txt", ] diff --git a/pyproject.toml b/pyproject.toml index e905788..368a64e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ source = "vcs" [project.optional-dependencies] tf = [ - # TF dependencies "tensorflow-cpu~=2.13.0", "tensorflow-probability~=0.21.0", "tensorflow-datasets~=4.9.3", @@ -30,8 +29,9 @@ tf = [ "rlds", ] jax = [ - "corax[tf]", + "tensorflow-probability", "jax", + "jaxlib", "dm-haiku", "flax", "optax", @@ -39,9 +39,7 @@ jax = [ "chex", "dm_env_wrappers", ] -all = ["corax[tf,jax]"] test = [ - "corax[all]", "pytest", "pytest-xdist", "dill", # required for tfds tests @@ -54,7 +52,7 @@ test = [ "ott-jax", "dm-control", ] -dev = ["corax[test]", "black", "ruff", "pre-commit"] +dev = ["black", "ruff", "pre-commit"] [tool.hatch.build] directory = "dist"