diff --git a/howto/work_with_steps.md b/howto/work_with_steps.md
index cdf87522..21e37df3 100644
--- a/howto/work_with_steps.md
+++ b/howto/work_with_steps.md
@@ -22,6 +22,7 @@ The hyper-parameters that refer to the *policy steps* are:
 * `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms.
 * `max_episode_steps`: the maximum number of policy steps an episode can last (`max_steps`); when this number is reached, `truncated=True` is returned by the environment. This means that if you decide to have an action repeat greater than one (`action_repeat > 1`), then the environment performs a maximum number of steps equal to: `env_steps = max_steps * action_repeat`.
 * `learning_starts`: how many policy steps the agent has to perform before starting the training. During the first `learning_starts` steps the buffer is pre-filled with random actions sampled by the environment.
+* `train_on_episode_end`: if set to `true`, training occurs only when at least one environment has just finished an episode, rather than after every policy step. This is particularly useful when a high step rate (environment steps per second) must be maintained, for example in real-time or physical simulations. In distributed training this feature is disabled automatically to avoid conflicts between parallel processes.
 
 ## Gradient steps
 A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The number of gradient steps is proportional to the number of parallel processes: if there are $n$ parallel processes, `n * per_rank_gradient_steps` calls to the *train* method will be executed.
diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py
index 30eadde7..15f2288e 100644
--- a/sheeprl/algos/dreamer_v1/dreamer_v1.py
+++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -645,36 +645,40 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
-            per_rank_gradient_steps = ratio(ratio_steps / world_size)
-            if per_rank_gradient_steps > 0:
-                with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
-                    sample = rb.sample_tensors(
-                        batch_size=cfg.algo.per_rank_batch_size,
-                        sequence_length=cfg.algo.per_rank_sequence_length,
-                        n_samples=per_rank_gradient_steps,
-                        dtype=None,
-                        device=device,
-                        from_numpy=cfg.buffer.from_numpy,
-                    )  # [N_samples, Seq_len, Batch_size, ...]
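To make the new trigger easier to follow, here is a minimal, self-contained sketch of the gating condition that this patch places in front of every training block. It only mirrors the boolean logic shown in the hunks; `reset_envs` is assumed to be the number of environments that finished an episode in the current iteration, which is computed elsewhere and not shown in this diff.

```python
# Illustrative sketch of the training gate added by this patch (not the
# library's API). `reset_envs` is assumed to count the environments that
# finished an episode in the current iteration.


def should_train(train_on_episode_end: bool, reset_envs: int, world_size: int) -> bool:
    """Replicate `(train_on_episode_end and reset_envs > 0 and not is_distributed)
    or not train_on_episode_end` from the hunks in this patch."""
    if not train_on_episode_end:
        # Default behaviour: train on every iteration once `learning_starts` is reached.
        return True
    is_distributed = world_size > 1
    # Episode-end training only fires in single-process runs and only when at
    # least one environment has just ended an episode.
    return reset_envs > 0 and not is_distributed


# Single process, flag enabled, one episode just ended -> train.
assert should_train(True, reset_envs=1, world_size=1)
# Flag disabled -> train every iteration, as before this patch.
assert should_train(False, reset_envs=0, world_size=4)
```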
- for i in range(per_rank_gradient_steps): - batch = {k: v[i].float() for k, v in sample.items()} - train( - fabric, - world_model, - actor, - critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - batch, - aggregator, - cfg, - ) - cumulative_per_rank_gradient_steps += 1 - train_step += world_size - if aggregator: - aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step)) + is_distributed = fabric.world_size > 1 + if ( + cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed + ) or not cfg.algo.train_on_episode_end: + ratio_steps = policy_step - prefill_steps * policy_steps_per_iter + per_rank_gradient_steps = ratio(ratio_steps / world_size) + if per_rank_gradient_steps > 0: + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, + ) # [N_samples, Seq_len, Batch_size, ...] + for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} + train( + fabric, + world_model, + actor, + critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + batch, + aggregator, + cfg, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size + if aggregator: + aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step)) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 49f16751..89445c29 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -680,42 +680,46 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if iter_num >= learning_starts: - ratio_steps = policy_step - prefill_steps * policy_steps_per_iter - per_rank_gradient_steps = ratio(ratio_steps / world_size) - if per_rank_gradient_steps > 0: - local_data = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=per_rank_gradient_steps, - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(per_rank_gradient_steps): - if ( - cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq - == 0 - ): - for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): - tcp.data.copy_(cp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - batch, - aggregator, - cfg, - actions_dim, - ) - cumulative_per_rank_gradient_steps += 1 - train_step += world_size + is_distributed = fabric.world_size > 1 + if ( + cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed + ) or not cfg.algo.train_on_episode_end: + ratio_steps = policy_step - prefill_steps * policy_steps_per_iter + per_rank_gradient_steps = ratio(ratio_steps / world_size) + if per_rank_gradient_steps > 0: + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + 
device=fabric.device, + from_numpy=cfg.buffer.from_numpy, + ) + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): + tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + batch, + aggregator, + cfg, + actions_dim, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index babcebf8..31ed1909 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -658,45 +658,49 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if iter_num >= learning_starts: - ratio_steps = policy_step - prefill_steps * policy_steps_per_iter - per_rank_gradient_steps = ratio(ratio_steps / world_size) - if per_rank_gradient_steps > 0: - local_data = rb.sample_tensors( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=per_rank_gradient_steps, - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(per_rank_gradient_steps): - if ( - cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq - == 0 - ): - tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau - for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): - tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - batch, - aggregator, - cfg, - is_continuous, - actions_dim, - moments, - ) - cumulative_per_rank_gradient_steps += 1 - train_step += world_size + is_distributed = fabric.world_size > 1 + if ( + cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed + ) or not cfg.algo.train_on_episode_end: + ratio_steps = policy_step - prefill_steps * policy_steps_per_iter + per_rank_gradient_steps = ratio(ratio_steps / world_size) + if per_rank_gradient_steps > 0: + local_data = rb.sample_tensors( + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, + ) + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau + for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + 
critic_optimizer, + batch, + aggregator, + cfg, + is_continuous, + actions_dim, + moments, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 6c184b45..a171f5aa 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -669,46 +669,50 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if iter_num >= learning_starts: - ratio_steps = policy_step - prefill_steps * policy_steps_per_iter - per_rank_gradient_steps = ratio(ratio_steps / world_size) - if per_rank_gradient_steps > 0: - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - sample = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=per_rank_gradient_steps, - dtype=None, - device=device, - from_numpy=cfg.buffer.from_numpy, - ) # [N_samples, Seq_len, Batch_size, ...] - for i in range(per_rank_gradient_steps): - batch = {k: v[i].float() for k, v in sample.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, + is_distributed = fabric.world_size > 1 + if ( + cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed + ) or not cfg.algo.train_on_episode_end: + ratio_steps = policy_step - prefill_steps * policy_steps_per_iter + per_rank_gradient_steps = ratio(ratio_steps / world_size) + if per_rank_gradient_steps > 0: + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, + ) # [N_samples, Seq_len, Batch_size, ...] 
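The DreamerV2 and DreamerV3 hunks above keep their existing target-critic refresh inside the (now gated) training loop: DreamerV2 hard-copies the critic weights every `per_rank_target_network_update_freq` gradient steps, while DreamerV3 blends them with `tau` (using `tau = 1` on the very first update). Below is a standalone sketch of those two updates using plain torch modules rather than sheeprl's wrapped ones; it is illustrative only.

```python
# Sketch only: standalone version of the target-critic refresh seen in the
# DreamerV2/DreamerV3 hunks above. `critic` and `target_critic` are ordinary
# torch modules here, not sheeprl's Fabric-wrapped ones.
import torch
from torch import nn


@torch.no_grad()
def hard_update(critic: nn.Module, target_critic: nn.Module) -> None:
    # DreamerV2-style refresh: copy the critic weights verbatim.
    for cp, tcp in zip(critic.parameters(), target_critic.parameters()):
        tcp.data.copy_(cp.data)


@torch.no_grad()
def soft_update(critic: nn.Module, target_critic: nn.Module, tau: float) -> None:
    # DreamerV3-style refresh: exponential moving average with rate `tau`
    # (tau=1 degenerates to a hard copy, as on the first gradient step).
    for cp, tcp in zip(critic.parameters(), target_critic.parameters()):
        tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data)


critic, target_critic = nn.Linear(8, 1), nn.Linear(8, 1)
hard_update(critic, target_critic)        # every `per_rank_target_network_update_freq` steps
soft_update(critic, target_critic, 0.02)  # small tau, value illustrative
```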
+ for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size + + if aggregator and not aggregator.disabled: + aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) + aggregator.update( + "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) ) - cumulative_per_rank_gradient_steps += 1 - train_step += world_size - - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) - aggregator.update( - "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) - ) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index b071c977..5266a64f 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -324,44 +324,48 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the agent if iter_num >= learning_starts: - ratio_steps = policy_step - prefill_steps * policy_steps_per_iter - per_rank_gradient_steps = ratio(ratio_steps / world_size) - if per_rank_gradient_steps > 0: - if player.actor_type != "task": - player.actor_type = "task" - player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): - p.data = agent_p.data - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - sample = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=per_rank_gradient_steps, - dtype=None, - device=device, - from_numpy=cfg.buffer.from_numpy, - ) # [N_samples, Seq_len, Batch_size, ...] 
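In the finetuning scripts, the block that swaps the player's actor from the exploration policy to the task policy is kept as-is and simply moves inside the new gate. As a reminder of what that swap does, here is a stripped-down sketch with plain torch modules; `fabric_player.setup_module` and `unwrap_fabric` from the surrounding code are deliberately left out, so this is an illustration rather than the real setup.

```python
# Simplified sketch of the actor swap retained in the finetuning hunks.
# The real code wraps the module with fabric_player.setup_module(unwrap_fabric(...));
# here both actors are plain torch modules for illustration.
import torch
from torch import nn

actor_task = nn.Linear(16, 4)    # policy being finetuned on the task reward
player_actor = nn.Linear(16, 4)  # copy used by the environment-interaction player
player_actor_type = "exploration"

if player_actor_type != "task":
    player_actor_type = "task"
    # Point the player's parameters at the task actor's tensors, so later
    # optimizer updates on `actor_task` are immediately visible to the player.
    for agent_p, p in zip(actor_task.parameters(), player_actor.parameters()):
        p.data = agent_p.data

assert torch.equal(player_actor.weight, actor_task.weight)
```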
- for i in range(per_rank_gradient_steps): - batch = {k: v[i].float() for k, v in sample.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, + is_distributed = fabric.world_size > 1 + if ( + cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed + ) or not cfg.algo.train_on_episode_end: + ratio_steps = policy_step - prefill_steps * policy_steps_per_iter + per_rank_gradient_steps = ratio(ratio_steps / world_size) + if per_rank_gradient_steps > 0: + if player.actor_type != "task": + player.actor_type = "task" + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) + for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): + p.data = agent_p.data + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, + ) # [N_samples, Seq_len, Batch_size, ...] + for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size + if aggregator and not aggregator.disabled: + aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) + aggregator.update( + "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) ) - cumulative_per_rank_gradient_steps += 1 - train_step += world_size - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) - aggregator.update( - "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) - ) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index ecf285aa..7717a9fe 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -813,55 +813,59 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if iter_num >= learning_starts: - ratio_steps = policy_step - prefill_steps * policy_steps_per_iter - per_rank_gradient_steps = ratio(ratio_steps / world_size) - if per_rank_gradient_steps > 0: - local_data = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=per_rank_gradient_steps, - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(per_rank_gradient_steps): - if ( - cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq - == 0 - ): - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(cp.data) - for cp, tcp in zip( - critic_exploration.module.parameters(), target_critic_exploration.parameters() + is_distributed = fabric.world_size > 1 + if ( + cfg.algo.train_on_episode_end and reset_envs 
> 0 and not is_distributed + ) or not cfg.algo.train_on_episode_end: + ratio_steps = policy_step - prefill_steps * policy_steps_per_iter + per_rank_gradient_steps = ratio(ratio_steps / world_size) + if per_rank_gradient_steps > 0: + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, + ) + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 ): - tcp.data.copy_(cp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - target_critic_exploration=target_critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, - is_continuous=is_continuous, - actions_dim=actions_dim, - ) - cumulative_per_rank_gradient_steps += 1 - train_step += world_size + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(cp.data) + for cp, tcp in zip( + critic_exploration.module.parameters(), target_critic_exploration.parameters() + ): + tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + target_critic_exploration=target_critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + is_continuous=is_continuous, + actions_dim=actions_dim, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index fcc59fe4..9d1d3f6a 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -347,48 +347,52 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the agent if iter_num >= learning_starts: - ratio_steps = policy_step - prefill_steps * policy_steps_per_iter - per_rank_gradient_steps = ratio(ratio_steps / world_size) - if per_rank_gradient_steps > 0: - if player.actor_type != "task": - player.actor_type = "task" - player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): - p.data = agent_p.data - local_data = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=per_rank_gradient_steps, - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - # Start 
training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(per_rank_gradient_steps): - if ( - cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq - == 0 - ): - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(cp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - actions_dim=actions_dim, - ) - cumulative_per_rank_gradient_steps += 1 - train_step += world_size + is_distributed = fabric.world_size > 1 + if ( + cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed + ) or not cfg.algo.train_on_episode_end: + ratio_steps = policy_step - prefill_steps * policy_steps_per_iter + per_rank_gradient_steps = ratio(ratio_steps / world_size) + if per_rank_gradient_steps > 0: + if player.actor_type != "task": + player.actor_type = "task" + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) + for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): + p.data = agent_p.data + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, + ) + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + actions_dim=actions_dim, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): diff --git a/sheeprl/configs/algo/dreamer_v1.yaml b/sheeprl/configs/algo/dreamer_v1.yaml index 2aaa7b40..bb4e0ec5 100644 --- a/sheeprl/configs/algo/dreamer_v1.yaml +++ b/sheeprl/configs/algo/dreamer_v1.yaml @@ -11,6 +11,7 @@ horizon: 15 name: dreamer_v1 # Training recipe +train_on_episode_end: false replay_ratio: 0.1 learning_starts: 5000 per_rank_pretrain_steps: 0 diff --git a/sheeprl/configs/algo/dreamer_v2.yaml b/sheeprl/configs/algo/dreamer_v2.yaml index 2261aaf8..dda24294 100644 --- a/sheeprl/configs/algo/dreamer_v2.yaml +++ b/sheeprl/configs/algo/dreamer_v2.yaml @@ -11,6 +11,7 @@ lmbda: 0.95 horizon: 15 # Training recipe +train_on_episode_end: false replay_ratio: 0.2 learning_starts: 1000 per_rank_pretrain_steps: 100 diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index 15f859b4..3e5749be 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -13,6 +13,7 @@ lmbda: 0.95 horizon: 15 # Training recipe +train_on_episode_end: false replay_ratio: 1 learning_starts: 1024 per_rank_pretrain_steps: 0 diff --git a/sheeprl/configs/algo/p2e_dv1.yaml b/sheeprl/configs/algo/p2e_dv1.yaml index 64484cac..07e847d4 
100644
--- a/sheeprl/configs/algo/p2e_dv1.yaml
+++ b/sheeprl/configs/algo/p2e_dv1.yaml
@@ -5,6 +5,7 @@ defaults:
 
 name: p2e_dv1
 intrinsic_reward_multiplier: 10000
+train_on_episode_end: false
 
 player:
   actor_type: exploration
diff --git a/sheeprl/configs/algo/p2e_dv2.yaml b/sheeprl/configs/algo/p2e_dv2.yaml
index 43a6f1aa..95593635 100644
--- a/sheeprl/configs/algo/p2e_dv2.yaml
+++ b/sheeprl/configs/algo/p2e_dv2.yaml
@@ -5,6 +5,7 @@ defaults:
 
 name: p2e_dv2
 intrinsic_reward_multiplier: 1
+train_on_episode_end: false
 
 player:
   actor_type: exploration
diff --git a/sheeprl/configs/algo/p2e_dv3.yaml b/sheeprl/configs/algo/p2e_dv3.yaml
index 49ce538c..641d3808 100644
--- a/sheeprl/configs/algo/p2e_dv3.yaml
+++ b/sheeprl/configs/algo/p2e_dv3.yaml
@@ -5,6 +5,7 @@ defaults:
 
 name: p2e_dv3
 intrinsic_reward_multiplier: 1
+train_on_episode_end: false
 
 player:
   actor_type: exploration
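All the algorithm configs above default the new flag to `false`, so existing experiments keep their per-step training behaviour. The snippet below is a minimal OmegaConf sketch (sheeprl's configs are Hydra/OmegaConf based) of how the new default combines with a command-line style override; the key names mirror the YAML hunks above, but the loading code itself is illustrative and is not sheeprl's actual entry point.

```python
# Illustrative only: how the new default interacts with an override.
# Key names come from the YAML hunks above; the loading logic is a sketch,
# not sheeprl's real Hydra entry point.
from omegaconf import OmegaConf

algo_defaults = OmegaConf.create(
    {
        "name": "dreamer_v3",
        "replay_ratio": 1,
        "train_on_episode_end": False,  # new default added by this patch
    }
)

# Roughly equivalent to passing `algo.train_on_episode_end=true` on the command line.
override = OmegaConf.from_dotlist(["train_on_episode_end=true"])
cfg_algo = OmegaConf.merge(algo_defaults, override)

assert cfg_algo.train_on_episode_end
assert cfg_algo.replay_ratio == 1
```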