diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 6db31a9aa81..a9fb9bfed0c 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -14,7 +14,7 @@ collector: multi_step: 0 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 annealing_frames: 10000 eps_start: 1.0 diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 5a8be9616a0..5c9e649f17f 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5_000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 1000 diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index c262a46e563..5abd65f8c0f 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -122,6 +122,12 @@ def make_collector( cudagraph=False, ): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -129,7 +135,7 @@ def make_collector( frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, compile_policy={"mode": compile_mode} if compile else False, cudagraph_policy=cudagraph, ) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index ae5c22b95a5..79f5454a87d 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -393,7 +393,7 @@ def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] - actor(td) + actor(td.to(device)) return actor diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index cf6c8053037..2e057b08220 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -41,6 +41,11 @@ gail: gp_lambda: 10.0 device: null +compile: + compile: False + compile_mode: + cudagraphs: False + replay_buffer: dataset: halfcheetah-expert-v2 batch_size: 256 diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index a3c64693fb3..d8187e68635 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -16,13 +16,14 @@ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer from ppo_utils import eval_model, make_env, make_ppo_models +from tensordict.nn import CudaGraphModule from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.objectives import ClipPPOLoss, GAILLoss +from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers from torchrl.objectives.value.advantages import GAE from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger @@ -69,20 +70,9 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = make_ppo_models(cfg.env.env_name, compile=cfg.compile.compile) actor, critic = actor.to(device), critic.to(device) - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.ppo.collector.frames_per_batch, - total_frames=cfg.ppo.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) - # Create data buffer data_buffer = TensorDictReplayBuffer( storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch), @@ -111,6 +101,30 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) + optim = group_optimizers(actor_optim, critic_optim) + del actor_optim, critic_optim + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.ppo.collector.frames_per_batch, + total_frames=cfg.ppo.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) @@ -138,32 +152,9 @@ def main(cfg: "DictConfig"): # noqa: F821 VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) ) test_env.eval() + num_network_updates = torch.zeros((), dtype=torch.int64, device=device) - # Training loop - collected_frames = 0 - num_network_updates = 0 - pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) - - # extract cfg variables - cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs - cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr - cfg_optim_lr = cfg.ppo.optim.lr - cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon - cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon - cfg_logger_test_interval = cfg.logger.test_interval - cfg_logger_num_test_episodes = cfg.logger.num_test_episodes - - for i, data in enumerate(collector): - - log_info = {} - frames_in_batch = data.numel() - collected_frames += frames_in_batch - pbar.update(data.numel()) - - # Update discriminator - # Get expert data - expert_data = replay_buffer.sample() - expert_data = expert_data.to(device) + def update(data, expert_data, num_network_updates=num_network_updates): # Add collector data to expert data expert_data.set( discriminator_loss.tensor_keys.collector_action, @@ -176,9 +167,9 @@ def main(cfg: "DictConfig"): # noqa: F821 d_loss = discriminator_loss(expert_data) # Backward pass - discriminator_optim.zero_grad() d_loss.get("loss").backward() discriminator_optim.step() + discriminator_optim.zero_grad(set_to_none=True) # Compute discriminator reward with torch.no_grad(): @@ -188,32 +179,19 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set discriminator rewards to tensordict data.set(("next", "reward"), d_rewards) - # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( - { - "train/reward": episode_rewards.mean().item(), - "train/episode_length": episode_length.sum().item() - / len(episode_length), - } - ) # Update PPO for _ in range(cfg_loss_ppo_epochs): - # Compute GAE with torch.no_grad(): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer + data_buffer.empty() data_buffer.extend(data_reshape) - for _, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) + for batch in data_buffer: + optim.zero_grad(set_to_none=True) # Linearly decrease the learning rate and clip epsilon alpha = 1.0 @@ -233,20 +211,66 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_loss = loss["loss_objective"] + loss["loss_entropy"] # Backward pass - actor_loss.backward() - critic_loss.backward() + (actor_loss + critic_loss).backward() # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + optim.step() + return d_loss.detach() + + if cfg.compile.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Training loop + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr + cfg_optim_lr = cfg.ppo.optim.lr + cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon + cfg_logger_test_interval = cfg.logger.test_interval + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + + for i, data in enumerate(collector): + + log_info = {} + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Update discriminator + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) + + d_loss = update(data, expert_data) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) log_info.update( { - "train/actor_loss": actor_loss.item(), - "train/critic_loss": critic_loss.item(), - "train/discriminator_loss": d_loss["loss"].item(), + # "train/actor_loss": actor_loss.item(), + # "train/critic_loss": critic_loss.item(), + "train/discriminator_loss": d_loss["loss"], "train/lr": alpha * cfg_optim_lr, "train/clip_epsilon": ( alpha * cfg_loss_clip_epsilon diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index f2966ebcc7f..d90b58d35f9 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -42,7 +42,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False) # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, compile): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -54,6 +54,7 @@ def make_ppo_models_state(proof_environment): "low": proof_environment.action_spec_unbatched.space.low, "high": proof_environment.action_spec_unbatched.space.high, "tanh_loc": False, + "safe_tanh": not compile, } # Define policy architecture @@ -116,9 +117,9 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, compile): proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) + actor, critic = make_ppo_models_state(proof_environment, compile=compile) return actor, critic diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index 9245d4c4832..d28c02cf499 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ -15,7 +15,7 @@ collector: total_frames: 20000 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml index 1f7bb361e6c..64ad7466192 100644 --- a/sota-implementations/iql/online_config.yaml +++ b/sota-implementations/iql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index a1b9ed54db6..e342ba9281a 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -118,6 +118,12 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -125,7 +131,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml index 29586f2e9a7..5cf531a3be2 100644 --- a/sota-implementations/sac/config.yaml +++ b/sota-implementations/sac/config.yaml @@ -12,7 +12,7 @@ collector: init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 - device: cpu + device: env_per_collector: 1 reset_at_each_iter: False diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index d1dbb2db791..e827e4e6416 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -105,13 +105,19 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index 7f7854b68b3..5bdf22ea6fa 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -13,7 +13,7 @@ collector: init_env_steps: 1000 frames_per_batch: 1000 reset_at_each_iter: False - device: cpu + device: env_per_collector: 1 num_workers: 1 diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 665c2e0c674..d32375a7649 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -116,6 +116,12 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -123,7 +129,7 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector