Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 18, 2024
1 parent 78ec233 commit 0fa1c3d
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,12 +1966,13 @@ def test_env_device(self, env_name, frame_skip, transformed_out, device):


@pytest.mark.skipif(not _has_brax, reason="brax not installed")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("envname", ["fast"])
class TestBrax:
@pytest.mark.parametrize("requires_grad", [False, True])
def test_brax_constructor(self, envname, requires_grad):
env0 = BraxEnv(envname, requires_grad=requires_grad)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad)
def test_brax_constructor(self, envname, requires_grad, device):
env0 = BraxEnv(envname, requires_grad=requires_grad, device=device)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad, device=device)

env0.set_seed(0)
torch.manual_seed(0)
Expand All @@ -1994,12 +1995,12 @@ def test_brax_constructor(self, envname, requires_grad):
assert r1.requires_grad == requires_grad
assert_allclose_td(r0.data, r1.data)

def test_brax_seeding(self, envname):
def test_brax_seeding(self, envname, device):
final_seed = []
tdreset = []
tdrollout = []
for _ in range(2):
env = BraxEnv(envname)
env = BraxEnv(envname, device=device)
torch.manual_seed(0)
np.random.seed(0)
final_seed.append(env.set_seed(0))
Expand All @@ -2012,8 +2013,8 @@ def test_brax_seeding(self, envname):
assert_allclose_td(*tdrollout)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_batch_size(self, envname, batch_size):
env = BraxEnv(envname, batch_size=batch_size)
def test_brax_batch_size(self, envname, batch_size, device):
env = BraxEnv(envname, batch_size=batch_size, device=device)
env.set_seed(0)
tdreset = env.reset()
tdrollout = env.rollout(max_steps=50)
Expand All @@ -2023,8 +2024,8 @@ def test_brax_batch_size(self, envname, batch_size):
assert tdrollout.batch_size[:-1] == batch_size

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_spec_rollout(self, envname, batch_size):
env = BraxEnv(envname, batch_size=batch_size)
def test_brax_spec_rollout(self, envname, batch_size, device):
env = BraxEnv(envname, batch_size=batch_size, device=device)
env.set_seed(0)
check_env_specs(env)

Expand All @@ -2036,7 +2037,7 @@ def test_brax_spec_rollout(self, envname, batch_size):
False,
],
)
def test_brax_consistency(self, envname, batch_size, requires_grad):
def test_brax_consistency(self, envname, batch_size, requires_grad, device):
import jax
import jax.numpy as jnp
from torchrl.envs.libs.jax_utils import (
Expand All @@ -2045,7 +2046,7 @@ def test_brax_consistency(self, envname, batch_size, requires_grad):
_tree_flatten,
)

env = BraxEnv(envname, batch_size=batch_size, requires_grad=requires_grad)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=requires_grad, device=device)
env.set_seed(1)
rollout = env.rollout(10)

Expand All @@ -2064,9 +2065,9 @@ def test_brax_consistency(self, envname, batch_size, requires_grad):
torch.testing.assert_close(t1, t2)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_grad(self, envname, batch_size):
def test_brax_grad(self, envname, batch_size, device):
batch_size = (1,)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=True)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=True, device=device)
env.set_seed(0)
td1 = env.reset()
action = torch.randn(env.action_spec.shape)
Expand All @@ -2080,10 +2081,10 @@ def test_brax_grad(self, envname, batch_size):
@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
@pytest.mark.parametrize("parallel", [False, True])
def test_brax_parallel(
self, envname, batch_size, parallel, maybe_fork_ParallelEnv, n=1
self, envname, batch_size, parallel, maybe_fork_ParallelEnv, device, n=1
):
def make_brax():
env = BraxEnv(envname, batch_size=batch_size, requires_grad=False)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=False, device=device)
env.set_seed(1)
return env

Expand Down

0 comments on commit 0fa1c3d

Please sign in to comment.