Updated train_on_episode_end #320

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions howto/work_with_steps.md
@@ -22,6 +22,7 @@ The hyper-parameters that refer to the *policy steps* are:
* `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms.
* `max_episode_steps`: the maximum number of policy steps an episode can last (`max_steps`); when this number is reached, the environment returns `truncated=True`. This means that if you choose an action repeat greater than one (`action_repeat > 1`), then the environment performs a maximum number of steps equal to `env_steps = max_steps * action_repeat`.
* `learning_starts`: how many policy steps the agent has to perform before training starts. During the first `learning_starts` steps, the buffer is pre-filled with random actions sampled from the environment's action space.
* `train_on_episode_end`: if set to `true`, training occurs only at the end of an episode rather than after every policy step. This setting is particularly useful when maintaining a high step rate (steps per second) is crucial, such as in real-time or physical simulations. Note that in distributed training this feature is disabled automatically to avoid conflicts between parallel processes (a minimal sketch of this gating logic is shown right after this list).
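
As a rough illustration, the effect of this flag can be sketched with a tiny helper (a simplified sketch written for this how-to, not the actual implementation; the real training loops check `cfg.algo.train_on_episode_end` together with the environments' done flags):

```python
def should_train_now(train_on_episode_end: bool, episode_ended: bool) -> bool:
    """Illustrative only: decide whether to run a training step this iteration.

    With the flag off, the agent trains after every policy step; with the
    flag on, it trains only when an episode has just finished.
    """
    return (not train_on_episode_end) or episode_ended
```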

## Gradient steps
A *gradient step* consists of an update of the agent's parameters, i.e., a call to the *train* function. The number of gradient steps is proportional to the number of parallel processes: if there are $n$ parallel processes, `n * per_rank_gradient_steps` calls to the *train* method are executed.
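
For instance (a toy numerical example added here for clarity, with made-up values):

```python
# Toy example of the relationship above: total train() calls per update.
world_size = 4                 # number of parallel processes (ranks)
per_rank_gradient_steps = 16   # gradient steps performed by each rank
total_train_calls = world_size * per_rank_gradient_steps
print(total_train_calls)       # 64
```
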
64 changes: 34 additions & 30 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -645,36 +645,40 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor,
critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
if aggregator:
aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step))
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed

Member commented:

Hi @LucaVendruscolo, here I see a problem: if you set cfg.algo.train_on_episode_end = True and you start distributed training, then you will have the following situation:

  • cfg.algo.train_on_episode_end = True
  • reset_envs > 0 = True (let us suppose that the episode ended)
  • not is_distributed = False
  • not cfg.algo.train_on_episode_end = False

This becomes: (True and True and False) or False = False

In this case, the agent will never enter the if statement, so it will never be trained.
What is missing is overriding the cfg.algo.train_on_episode_end config when is_distributed is True.
For example, by adding something like this around line 385:

if fabric.world_size > 1:
    cfg.algo.train_on_episode_end = False

Or you need to modify the condition so that it takes the situation described above into account.
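
For illustration, one possible shape for such a modified condition (an untested sketch, reusing the names from the surrounding diff, and not necessarily the fix that will be adopted) is to fall back to per-step training whenever the run is distributed:

```python
# Untested sketch: treat the flag as disabled when world_size > 1, so
# distributed runs keep training every policy step instead of never
# entering the branch.
is_distributed = fabric.world_size > 1
train_now = (
    not cfg.algo.train_on_episode_end  # flag off: train after every policy step
    or is_distributed                  # distributed: ignore the flag
    or reset_envs > 0                  # flag on, single process: train at episode end
)
if train_now:
    ...
```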

) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor,
critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
if aggregator:
aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step))

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
76 changes: 40 additions & 36 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -680,42 +680,46 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()):
tcp.data.copy_(cp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
actions_dim,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed

Member commented:

Same for all the files you modified

) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()):
tcp.data.copy_(cp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
actions_dim,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
82 changes: 43 additions & 39 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -658,45 +658,49 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau
for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()):
tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
is_continuous,
actions_dim,
moments,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed
) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau
for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()):
tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
is_continuous,
actions_dim,
moments,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
82 changes: 43 additions & 39 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -669,46 +669,50 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor_task,
critic_task,
world_optimizer,
actor_task_optimizer,
critic_task_optimizer,
batch,
aggregator,
cfg,
ensembles=ensembles,
ensemble_optimizer=ensemble_optimizer,
actor_exploration=actor_exploration,
critic_exploration=critic_exploration,
actor_exploration_optimizer=actor_exploration_optimizer,
critic_exploration_optimizer=critic_exploration_optimizer,
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed
) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor_task,
critic_task,
world_optimizer,
actor_task_optimizer,
critic_task_optimizer,
batch,
aggregator,
cfg,
ensembles=ensembles,
ensemble_optimizer=ensemble_optimizer,
actor_exploration=actor_exploration,
critic_exploration=critic_exploration,
actor_exploration_optimizer=actor_exploration_optimizer,
critic_exploration_optimizer=critic_exploration_optimizer,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

if aggregator and not aggregator.disabled:
aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step))
aggregator.update(
"Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step)
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

if aggregator and not aggregator.disabled:
aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step))
aggregator.update(
"Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step)
)

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):