Use uv pip compile and nox (#11)
* Fix some typing issues
* Fix more warning messages
* Use uv for compiling dependencies
* Use nox
ethanluoyc authored Feb 21, 2024
1 parent bc78a22 commit f882e9a
Showing 20 changed files with 813 additions and 2,854 deletions.
54 changes: 9 additions & 45 deletions .github/workflows/test.yml
@@ -10,54 +10,18 @@ concurrency:
cancel-in-progress: true

env:
FORCE_COLOR: "1"
PYTHONUNBUFFERED: "1"
FORCE_COLOR: 3

jobs:
test:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout (GitHub)
uses: actions/checkout@v3

- name: pdm cache
uses: actions/cache@v3
with:
path: .cache/pdm
key: ${{ runner.os }}-pdm-${{ hashFiles('pdm.lock') }}
restore-keys: |
${{ runner.os }}-pdm-
- name: TFDS cache
uses: actions/cache@v3
with:
path: .tensorflow_datasets
key: ${{ runner.os }}-tfds-${{ hashFiles('pdm.lock') }}
restore-keys: |
${{ runner.os }}-tfds-
- name: Login to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and run Dev Container task
uses: devcontainers/[email protected]
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
# Change this to point to your image name
imageName: ghcr.io/ethanluoyc/corax
cacheFrom: ghcr.io/ethanluoyc/corax
# Change this to be your CI task/script
runCmd: |
# Add multiple commands to run if needed
export TFDS_DATA_DIR=$PWD/.tensorflow_datasets
mkdir -p $TFDS_DATA_DIR
pdm config cache_dir .cache/pdm
pdm sync -G:all
pdm lint
pdm test
pdm run python projects/baselines/baselines/iql/train_test.py
python-version: '3.10'
cache: 'pip' # caching pip dependencies
- name: Run nox
run: |
python -m pip install nox
nox -v
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.10.13
3 changes: 2 additions & 1 deletion corax/agents/agent.py
@@ -89,7 +89,8 @@ def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
self._actor.observe(action, next_timestep)

def _has_data_for_training(self):
if self._iterator.ready(): # type: ignore
assert self._replay_tables is not None and self._iterator is not None
if self._iterator.ready():
return True
for table, batch_size in zip(
self._replay_tables,
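The new assert is the usual way to narrow Optional attributes for a type checker, which is why the old `# type: ignore` on the `ready()` call can be dropped. A minimal, self-contained sketch of the same pattern (the class and names below are hypothetical, not taken from corax):

from typing import Optional


class _Iterator:
    def ready(self) -> bool:
        return True


class _Learner:
    def __init__(self, iterator: Optional[_Iterator] = None):
        self._iterator = iterator

    def _has_data_for_training(self) -> bool:
        # The assert narrows Optional[_Iterator] to _Iterator, so the call
        # below type-checks without a "# type: ignore" comment.
        assert self._iterator is not None
        return self._iterator.ready()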
6 changes: 3 additions & 3 deletions corax/agents/jax/decision_transformer/dataset.py
@@ -54,7 +54,7 @@ def add_return_to_go(episode):
return episode

def _pad_along_axis(x, padded_size, axis=0, value=0):
pad_width = padded_size - tf.shape(x)[axis]
pad_width = padded_size - tf.shape(x)[axis] # type: ignore
if pad_width <= 0:
return x
padding = [(0, 0)] * len(x.shape.as_list())
@@ -72,10 +72,10 @@ def pad_steps(steps, max_len):
padded_discounts = _pad_along_axis(steps["discount"], max_len, 0, 2)
padded_timesteps = _pad_along_axis(steps["timestep"], max_len, 0, 0)
mask = _pad_along_axis(
tf.ones(tf.shape(steps["reward"])[0], dtype=bool),
tf.ones(tf.shape(steps["reward"])[0], dtype=bool), # type: ignore
max_len,
0,
False, # type: ignore
False,
)
return {
"observation": padded_obs,
1 change: 1 addition & 0 deletions corax/datasets/reverb.py
@@ -99,6 +99,7 @@ def _make_dataset(unused_idx: tf.Tensor) -> tf.data.Dataset:
datasets, weights=tables.values()
)
else:
assert len(datasets) == 1
dataset = datasets[0]

# Post-process each element if a post-processing function is passed, e.g.
6 changes: 3 additions & 3 deletions corax/jax/running_statistics_test.py
@@ -20,7 +20,7 @@

from absl.testing import absltest
import jax
from jax.config import config as jax_config # type: ignore
from jax import config as jax_config
import jax.numpy as jnp
import numpy as np
import tree
@@ -31,7 +31,7 @@
update_and_validate = functools.partial(running_statistics.update, validate_shapes=True)


class TestNestedSpec(NamedTuple):
class _TestNestedSpec(NamedTuple):
# Note: the fields are intentionally in reverse order to test ordering.
a: specs.Array
b: specs.Array
@@ -183,7 +183,7 @@ def test_pmap_update_nested(self):
tree.map_structure(lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)

def test_different_structure_normalize(self):
spec = TestNestedSpec(
spec = _TestNestedSpec(
a=specs.Array((5,), jnp.float32), b=specs.Array((2,), jnp.float32)
)
state = running_statistics.init_state(spec)
2 changes: 1 addition & 1 deletion corax/utils/counting_test.py
@@ -38,7 +38,7 @@ def wait(self):
"""Waits on the barrier until all threads have called this method."""
with self._cond:
self._count += 1
self._cond.notifyAll()
self._cond.notify_all()
while self._count < self._num_threads:
self._cond.wait()

2 changes: 1 addition & 1 deletion corax/utils/loggers/base.py
@@ -68,7 +68,7 @@ def close(self):
def tensor_to_numpy(value: Any):
if hasattr(value, "numpy"):
return value.numpy() # tf.Tensor (TF2).
if hasattr(value, "device_buffer"):
if hasattr(value, "addressable_data"):
return np.asarray(value) # jnp.DeviceArray.
return value

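The duck-typing check in tensor_to_numpy moves from device_buffer (an attribute of the old DeviceArray class) to addressable_data, which current jax.Array objects expose. A quick check of that assumption, using only public JAX APIs:

import jax.numpy as jnp
import numpy as np

x = jnp.arange(4)
# On recent JAX versions, arrays are jax.Array instances and expose
# addressable_data(); the old device_buffer attribute is gone.
print(hasattr(x, "addressable_data"))  # expected: True
print(np.asarray(x))                   # the conversion tensor_to_numpy performs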
13 changes: 13 additions & 0 deletions noxfile.py
@@ -0,0 +1,13 @@
import nox


@nox.session
def test(session):
    session.install("-r", "requirements/test.txt", "jax[cpu]", "-e", ".[tf,jax]")
    session.run("pytest", "-n", "auto", "corax/")


@nox.session
def lint(session):
    session.install("pre-commit")
    session.run("pre-commit", "run", "--all-files")
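The noxfile covers testing and linting; the uv pip compile step named in the commit title does not appear in this diff. A locking session could plausibly be wired into the same noxfile along these lines (the lock session name, the test extra, and the pyproject.toml input are assumptions; only requirements/test.txt is visible above):

import nox


@nox.session
def lock(session):
    # Recompile the pinned requirements file that the test session installs.
    session.install("uv")
    session.run(
        "uv", "pip", "compile", "pyproject.toml",
        "--extra", "test",
        "--output-file", "requirements/test.txt",
    )

Individual sessions can be selected locally with nox -s test or nox -s lint, while the CI workflow above simply runs nox -v to execute them all.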