Update (base update)
[ghstack-poisoned]
vmoens committed Dec 3, 2024
2 parents 5535a05 + 1cffffe commit 710290c
Showing 26 changed files with 558 additions and 204 deletions.
4 changes: 2 additions & 2 deletions .github/unittest/linux_sota/scripts/test_sota.py
@@ -190,12 +190,12 @@
logger.backend=
""",
"dreamer": """python sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
collector.total_frames=600 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.video=True \
logger.video=False \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
8 changes: 4 additions & 4 deletions docs/source/reference/trainers.rst
@@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key ``"log_pbar"`` is reserved for a boolean value indicating whether the logged value
should be displayed on the progress bar printed in the training log.
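
For illustration, a minimal logging hook could look like the sketch below. The hook name and the ``("next", "reward")`` key are assumptions made for this example; any callable with this behaviour can be registered with ``trainer.register_op("pre_steps_log", ...)``::

    def log_mean_reward(batch):
        # reduce the collected batch to a single scalar for the logger
        reward = batch.get(("next", "reward")).mean().item()
        # "log_pbar" asks the trainer to also display the value on the progress bar
        return {"r_training": reward, "log_pbar": True}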

@@ -174,9 +174,9 @@ Trainer and hooks
BatchSubSampler
ClearCudaCache
CountFramesLog
LogReward
LogScalar
OptimizerHook
Recorder
LogValidationReward
ReplayBufferTrainer
RewardNormalizer
SelectKeys
6 changes: 4 additions & 2 deletions sota-implementations/dreamer/dreamer.py
@@ -321,8 +321,10 @@ def compile_rssms(module):

t_collect_init = time.time()

test_env.close()
train_env.close()
if not test_env.is_closed:
test_env.close()
if not train_env.is_closed:
train_env.close()
collector.shutdown()

del test_env
10 changes: 5 additions & 5 deletions sota-implementations/redq/utils.py
@@ -81,8 +81,8 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
Trainer,
@@ -331,7 +331,7 @@ def make_trainer(

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -347,7 +347,7 @@
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
recorder_obj_explore = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -369,7 +369,7 @@
"post_steps", UpdateWeights(collector, update_weights_interval=1)
)

trainer.register_op("pre_steps_log", LogReward())
trainer.register_op("pre_steps_log", LogScalar())
trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip))

return trainer
83 changes: 83 additions & 0 deletions test/mocking_classes.py
@@ -1927,3 +1927,86 @@ def _step(
def _set_seed(self, seed: Optional[int]):
self.manual_seed = seed
return seed


class EnvWithScalarAction(EnvBase):
def __init__(self, singleton: bool = False, **kwargs):
super().__init__(**kwargs)
self.singleton = singleton
self.action_spec = Bounded(
-1,
1,
shape=(
*self.batch_size,
1,
)
if self.singleton
else self.batch_size,
)
self.observation_spec = Composite(
observation=Unbounded(
shape=(
*self.batch_size,
3,
)
),
shape=self.batch_size,
)
self.done_spec = Composite(
done=Unbounded(self.batch_size + (1,), dtype=torch.bool),
terminated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
truncated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
shape=self.batch_size,
)
self.reward_spec = Unbounded(
shape=(
*self.batch_size,
1,
)
)

def _reset(self, td: TensorDict):
return TensorDict(
observation=torch.randn(*self.batch_size, 3, device=self.device),
done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
truncated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
terminated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
device=self.device,
)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
return TensorDict(
observation=torch.randn(*self.batch_size, 3, device=self.device),
reward=torch.zeros(1, device=self.device),
done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
truncated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
terminated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
)

def _set_seed(self, seed: Optional[int]):
...


class EnvThatDoesNothing(EnvBase):
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
return TensorDict(batch_size=self.batch_size, device=self.device)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
return TensorDict(batch_size=self.batch_size, device=self.device)

def _set_seed(self, seed):
...
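
For illustration, a minimal usage sketch of the new mock environment (assuming it is imported from test/mocking_classes.py as in the repository's tests; the variable names are illustrative only):

from mocking_classes import EnvWithScalarAction

env = EnvWithScalarAction(singleton=False)
env.check_env_specs()  # specs must agree with what _reset/_step actually return
td = env.rollout(3)  # each step's "action" entry is a scalar, with no trailing action dimension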
2 changes: 2 additions & 0 deletions test/test_cost.py
@@ -4493,6 +4493,7 @@ def test_sac_terminating(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
skip_done_states=True,
)
loss.set_keys(
action=action_key,
@@ -5204,6 +5205,7 @@ def test_discrete_sac_terminating(
qvalue_network=qvalue,
num_actions=actor.spec[action_key].space.n,
action_space="one-hot",
skip_done_states=True,
)
loss.set_keys(
action=action_key,
30 changes: 30 additions & 0 deletions test/test_env.py
@@ -44,6 +44,7 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
@@ -81,6 +82,7 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
@@ -3554,6 +3556,34 @@ def test_auto_spec():
env.check_env_specs(tensordict=td.copy())


def test_env_that_does_nothing():
env = EnvThatDoesNothing()
env.check_env_specs()
r = env.rollout(3)
r.exclude(
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
)
assert r.is_empty()
p_env = SerialEnv(2, EnvThatDoesNothing)
p_env.check_env_specs()
r = p_env.rollout(3)
r.exclude(
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
)
assert r.is_empty()
p_env = ParallelEnv(2, EnvThatDoesNothing)
try:
p_env.check_env_specs()
r = p_env.rollout(3)
r.exclude(
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
)
assert r.is_empty()
finally:
p_env.close()
del p_env


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
27 changes: 13 additions & 14 deletions test/test_trainer.py
@@ -35,14 +35,14 @@
TensorDictReplayBuffer,
)
from torchrl.envs.libs.gym import _has_gym
from torchrl.trainers import Recorder, Trainer
from torchrl.trainers import LogValidationReward, Trainer
from torchrl.trainers.helpers import transformed_env_constructor
from torchrl.trainers.trainers import (
_has_tqdm,
_has_ts,
BatchSubSampler,
CountFramesLog,
LogReward,
LogScalar,
mask_batch,
OptimizerHook,
ReplayBufferTrainer,
@@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
trainer.register_op("pre_steps_log", log_reward)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
log_reward.register(trainer)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -873,7 +873,7 @@ def test_recorder(self, N=8):
logger=logger,
)()

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
Expand Down Expand Up @@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8):
os.environ["CKPT_BACKEND"] = backend
state_dict_has_been_called = [False]
load_state_dict_has_been_called = [False]
Recorder.state_dict, Recorder_state_dict = _fun_checker(
Recorder.state_dict, state_dict_has_been_called
LogValidationReward.state_dict, Recorder_state_dict = _fun_checker(
LogValidationReward.state_dict, state_dict_has_been_called
)
(
Recorder.load_state_dict,
Recorder_load_state_dict,
) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called)
(LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker(
LogValidationReward.load_state_dict, load_state_dict_has_been_called
)

args = self._get_args()

@@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname):
)()
environment.rollout(2)

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
@@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname):
assert recorder2._count == 8
assert state_dict_has_been_called[0]
assert load_state_dict_has_been_called[0]
Recorder.state_dict = Recorder_state_dict
Recorder.load_state_dict = Recorder_load_state_dict
LogValidationReward.state_dict = Recorder_state_dict
LogValidationReward.load_state_dict = Recorder_load_state_dict


def test_updateweights():