Fix/logging #91

Merged
merged 23 commits on Sep 15, 2023
Changes from all commits

Commits (23)
96aa219
Add IndependentMeanMetric
belerico Sep 8, 2023
26d49db
Merge branch 'fix/logging' of https://github.com/Eclectic-Sheep/sheep…
belerico Sep 8, 2023
590b9f1
Added RankIndependentMetricAggregator and plot per env cum reward
belerico Sep 9, 2023
4c5fd63
Add RunningMean to PPO
belerico Sep 11, 2023
0698a7a
Call reduce every time in MetricAggregator
belerico Sep 11, 2023
f067daf
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Sep 11, 2023
533f1c0
Merge branch 'fix/logging' of https://github.com/Eclectic-Sheep/sheep…
belerico Sep 11, 2023
b6f75ca
For PPO, add timer + change timer logging + add print config
belerico Sep 11, 2023
881c877
Fix logging PPO recurrent
belerico Sep 12, 2023
e003389
Choose which metric is used to track time
belerico Sep 12, 2023
d4072d2
Fix aggregator.to(device) + check when we measure train time
belerico Sep 12, 2023
fb1a0d9
Fix SAC check if first info are sent to the trainers
belerico Sep 12, 2023
61bcfc0
Update DroQ logging
belerico Sep 12, 2023
38c6862
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Sep 12, 2023
94bc646
Update logging SACAE
belerico Sep 12, 2023
247dcbd
Add sps for Dreamer-V1
belerico Sep 13, 2023
0eb05bb
From `reconstruction_loss` to `world_model_loss` in all Dreamers + ad…
belerico Sep 13, 2023
17ad8d9
Add sps to Dreamer-V3 algo
belerico Sep 13, 2023
bb55fa1
Add print_config + Add sps Dreamer-V3
belerico Sep 13, 2023
6eb156a
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Sep 13, 2023
70fd28b
[skip ci] Removed unused metric
belerico Sep 13, 2023
5a50c5b
Update log cumulative rew on rank-0
belerico Sep 14, 2023
7ab2787
Fix sps: sps_train is computed globally while sps_env is computed loc…
belerico Sep 15, 2023
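
The commits above replace the old `Time/step_per_second` `MeanMetric` with timer-based throughput metrics (`Time/sps_train`, `Time/sps_env_interaction`), accumulated through `sheeprl.utils.timer.timer` as imported in the droq.py diff below. As a rough illustration of the pattern, here is a minimal sketch of a callable context-manager timer backed by a torchmetrics `SumMetric`; the `_Timer` class and its methods are assumptions for illustration only, not the actual sheeprl implementation.

```python
# Minimal sketch, NOT the actual sheeprl.utils.timer implementation: a callable
# context manager that accumulates elapsed wall-clock seconds into named
# torchmetrics SumMetric instances, so steps-per-second can be derived later.
import time
from contextlib import contextmanager
from typing import Dict, Iterator

from torchmetrics import SumMetric


class _Timer:
    def __init__(self) -> None:
        self._metrics: Dict[str, SumMetric] = {}

    @contextmanager
    def __call__(self, name: str, metric: SumMetric) -> Iterator[None]:
        # Register the metric the first time a name is seen, then time the block.
        self._metrics.setdefault(name, metric)
        start = time.perf_counter()
        try:
            yield
        finally:
            self._metrics[name].update(time.perf_counter() - start)

    def compute(self) -> Dict[str, float]:
        # Accumulated seconds per timer (synced across ranks if the metric syncs).
        return {name: m.compute().item() for name, m in self._metrics.items()}

    def reset(self) -> None:
        for m in self._metrics.values():
            m.reset()


timer = _Timer()

# Usage mirroring the droq.py diff below:
with timer("Time/train_time", SumMetric(sync_on_compute=False)):
    ...  # one training step
```

Accumulating a `SumMetric` of elapsed seconds, instead of logging a `MeanMetric` of instantaneous rates, lets the logging code divide the number of steps done in a window by the time actually spent, which is what the new `Time/sps_*` metrics compute.
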
2 changes: 1 addition & 1 deletion examples/architecture_template.py
@@ -139,7 +139,7 @@ def main():
if devices is None or devices in ("1", "2"):
raise RuntimeError(
"Please run the script with the number of devices greater than 2: "
"`lightning run model --devices=3 sheeprl.py ...`"
"`lightning run model --devices=3 examples/architecture_template.py ...`"
)

world_collective = TorchCollective()
2 changes: 0 additions & 2 deletions howto/register_new_algorithm.md
@@ -129,7 +129,6 @@ def main(cfg: DictConfig):
{
"Rewards/rew_avg": MeanMetric(),
"Game/ep_len_avg": MeanMetric(),
"Time/step_per_second": MeanMetric(),
"Loss/value_loss": MeanMetric(),
"Loss/policy_loss": MeanMetric(),
"Loss/entropy_loss": MeanMetric(),
@@ -222,7 +221,6 @@ def main(cfg: DictConfig):

# Log metrics
metrics_dict = aggregator.compute()
fabric.log("Time/step_per_second", int(global_step / (time.perf_counter() - start_time)), global_step)
fabric.log_dict(metrics_dict, global_step)
aggregator.reset()

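The two deletions above drop the hand-rolled `Time/step_per_second` metric from the how-to. Its replacement in this PR derives throughput from the accumulated timers at logging time; below is a hedged sketch of that pattern, modelled on the droq.py diff further down. The `log_throughput` helper and its signature are illustrative assumptions, not part of sheeprl; the counter arguments correspond to the variables kept by the training loop.

```python
# Sketch of the timer-based logging that replaces the removed step_per_second
# lines: divide the work done in a logging window by the wall-clock seconds the
# timers accumulated for that window.
from typing import Dict

from lightning.fabric import Fabric


def log_throughput(
    fabric: Fabric,
    timer_metrics: Dict[str, float],  # what timer.compute() would return
    policy_step: int,
    last_log: int,
    train_step: int,
    last_train: int,
    world_size: int,
    action_repeat: int,
) -> None:
    # Global training steps per second over the window.
    fabric.log(
        "Time/sps_train",
        (train_step - last_train) / timer_metrics["Time/train_time"],
        policy_step,
    )
    # Per-rank environment steps per second (policy_step counts every rank's steps).
    fabric.log(
        "Time/sps_env_interaction",
        ((policy_step - last_log) / world_size * action_repeat)
        / timer_metrics["Time/env_interaction_time"],
        policy_step,
    )
```
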
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -26,9 +26,10 @@ dependencies = [
"tensorboard>=2.10",
"python-dotenv>=1.0.0",
"lightning==2.0.*",
"lightning-utilities<0.9",
"lightning-utilities<=0.9",
"hydra-core==1.3.0",
"torchmetrics==1.1.*",
"rich==13.5.*",
"opencv-python==4.8.0.*"
]
dynamic = ["version"]
218 changes: 123 additions & 95 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py

Large diffs are not rendered by default.

245 changes: 136 additions & 109 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py

Large diffs are not rendered by default.

229 changes: 128 additions & 101 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py

Large diffs are not rendered by default.

205 changes: 120 additions & 85 deletions sheeprl/algos/droq/droq.py
@@ -1,5 +1,4 @@
import os
import time
import warnings
from math import prod

@@ -15,7 +14,7 @@
from torch.optim import Optimizer
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import BatchSampler
from torchmetrics import MeanMetric
from torchmetrics import MeanMetric, SumMetric

from sheeprl.algos.droq.agent import DROQAgent, DROQCritic
from sheeprl.algos.sac.agent import SACActor
@@ -27,6 +26,8 @@
from sheeprl.utils.logger import create_tensorboard_logger
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import print_config


def train(
@@ -44,86 +45,97 @@ def train(
sample = rb.sample(
cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs
)
gathered_data = fabric.all_gather(sample.to_dict())
gathered_data = make_tensordict(gathered_data).view(-1)
critic_data = fabric.all_gather(sample.to_dict())
critic_data = make_tensordict(critic_data).view(-1)
if fabric.world_size > 1:
dist_sampler: DistributedSampler = DistributedSampler(
range(len(gathered_data)),
range(len(critic_data)),
num_replicas=fabric.world_size,
rank=fabric.global_rank,
shuffle=True,
seed=cfg.seed,
drop_last=False,
)
sampler: BatchSampler = BatchSampler(sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False)
critic_sampler: BatchSampler = BatchSampler(
sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False
)
else:
sampler = BatchSampler(sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False)

# Update the soft-critic
for batch_idxes in sampler:
data = gathered_data[batch_idxes]
next_target_qf_value = agent.get_next_target_q_values(
data["next_observations"],
data["rewards"],
data["dones"],
cfg.algo.gamma,
critic_sampler = BatchSampler(
sampler=range(len(critic_data)), batch_size=cfg.per_rank_batch_size, drop_last=False
)
for qf_value_idx in range(agent.num_critics):
# Line 8 - Algorithm 2
qf_loss = F.mse_loss(
agent.get_ith_q_value(data["observations"], data["actions"], qf_value_idx), next_target_qf_value
)
qf_optimizer.zero_grad(set_to_none=True)
fabric.backward(qf_loss)
qf_optimizer.step()
aggregator.update("Loss/value_loss", qf_loss)

# Update the target networks with EMA
agent.qfs_target_ema(critic_idx=qf_value_idx)

# Sample a different minibatch in a distributed way to update actor and alpha parameter
sample = rb.sample(cfg.per_rank_batch_size)
data = fabric.all_gather(sample.to_dict())
data = make_tensordict(data).view(-1)
actor_data = fabric.all_gather(sample.to_dict())
actor_data = make_tensordict(actor_data).view(-1)
if fabric.world_size > 1:
sampler: DistributedSampler = DistributedSampler(
range(len(data)),
actor_sampler: DistributedSampler = DistributedSampler(
range(len(actor_data)),
num_replicas=fabric.world_size,
rank=fabric.global_rank,
shuffle=True,
seed=cfg.seed,
drop_last=False,
)
data = data[next(iter(sampler))]

# Update the actor
actions, logprobs = agent.get_actions_and_log_probs(data["observations"])
qf_values = agent.get_q_values(data["observations"], actions)
min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True)
actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values)
actor_optimizer.zero_grad(set_to_none=True)
fabric.backward(actor_loss)
actor_optimizer.step()
aggregator.update("Loss/policy_loss", actor_loss)

# Update the entropy value
alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy)
alpha_optimizer.zero_grad(set_to_none=True)
fabric.backward(alpha_loss)
agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad)
alpha_optimizer.step()
aggregator.update("Loss/alpha_loss", alpha_loss)
actor_data = actor_data[next(iter(actor_sampler))]

with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
# Update the soft-critic
for batch_idxes in critic_sampler:
critic_batch_data = critic_data[batch_idxes]
next_target_qf_value = agent.get_next_target_q_values(
critic_batch_data["next_observations"],
critic_batch_data["rewards"],
critic_batch_data["dones"],
cfg.algo.gamma,
)
for qf_value_idx in range(agent.num_critics):
# Line 8 - Algorithm 2
qf_loss = F.mse_loss(
agent.get_ith_q_value(
critic_batch_data["observations"], critic_batch_data["actions"], qf_value_idx
),
next_target_qf_value,
)
qf_optimizer.zero_grad(set_to_none=True)
fabric.backward(qf_loss)
qf_optimizer.step()
aggregator.update("Loss/value_loss", qf_loss)

# Update the target networks with EMA
agent.qfs_target_ema(critic_idx=qf_value_idx)

# Update the actor
actions, logprobs = agent.get_actions_and_log_probs(actor_data["observations"])
qf_values = agent.get_q_values(actor_data["observations"], actions)
min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True)
actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values)
actor_optimizer.zero_grad(set_to_none=True)
fabric.backward(actor_loss)
actor_optimizer.step()
aggregator.update("Loss/policy_loss", actor_loss)

# Update the entropy value
alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy)
alpha_optimizer.zero_grad(set_to_none=True)
fabric.backward(alpha_loss)
agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad)
alpha_optimizer.step()
aggregator.update("Loss/alpha_loss", alpha_loss)


@register_algorithm()
@hydra.main(version_base=None, config_path="../../configs", config_name="config")
def main(cfg: DictConfig):
print_config(cfg)

# Initialize Fabric
fabric = Fabric(callbacks=[CheckpointCallback()])
if not _is_using_cli():
fabric.launch()
rank = fabric.global_rank
device = fabric.device
rank = fabric.global_rank
world_size = fabric.world_size
fabric.seed_everything(cfg.seed)
torch.backends.cudnn.deterministic = cfg.torch_deterministic

@@ -196,17 +208,15 @@ def main(cfg: DictConfig):
)

# Metrics
with device:
aggregator = MetricAggregator(
{
"Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
}
)
aggregator = MetricAggregator(
{
"Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
}
).to(device)

# Local data
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1
@@ -220,25 +230,26 @@ def main(cfg: DictConfig):
step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device)

# Global variables
policy_step = 0
last_log = 0
last_train = 0
train_step = 0
policy_step = 0
last_checkpoint = 0
start_time = time.perf_counter()
policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size)
num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1
learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0

# Warning for log and checkpoint every
if cfg.metric.log_every % policy_steps_per_update != 0:
warnings.warn(
f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
warnings.warn(
f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
@@ -249,23 +260,25 @@ def main(cfg: DictConfig):
obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs]

for update in range(1, num_updates + 1):
# Sample an action given the observation received by the environment
with torch.no_grad():
actions, _ = actor.module(obs)
actions = actions.cpu().numpy()
next_obs, rewards, dones, truncated, infos = envs.step(actions)
dones = np.logical_or(dones, truncated)

policy_step += cfg.env.num_envs * fabric.world_size

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with torch.no_grad():
actions, _ = actor.module(obs)
actions = actions.cpu().numpy()
next_obs, rewards, dones, truncated, infos = envs.step(actions)
dones = np.logical_or(dones, truncated)

if "final_info" in infos:
for i, agent_final_info in enumerate(infos["final_info"]):
if agent_final_info is not None and "episode" in agent_final_info:
fabric.print(
f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}"
)
aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0])
aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0])
for i, agent_ep_info in enumerate(infos["final_info"]):
if agent_ep_info is not None:
ep_rew = agent_ep_info["episode"]["r"]
ep_len = agent_ep_info["episode"]["l"]
aggregator.update("Rewards/rew_avg", ep_rew)
aggregator.update("Game/ep_len_avg", ep_len)
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

# Save the real next observation
real_next_obs = next_obs.copy()
@@ -295,12 +308,34 @@ def main(cfg: DictConfig):
# Train the agent
if update > learning_starts:
train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg)
aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)))
if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run:
last_log = policy_step
fabric.log_dict(aggregator.compute(), policy_step)
train_step += world_size

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
# Sync distributed metrics
metrics_dict = aggregator.compute()
fabric.log_dict(metrics_dict, policy_step)
aggregator.reset()

# Sync distributed timers
timer_metrics = timer.compute()
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
/ timer_metrics["Time/env_interaction_time"],
policy_step,
)
timer.reset()

# Reset counters
last_log = policy_step
last_train = train_step

# Checkpoint model
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
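
For a concrete feel of the two new throughput logs, here is a small, self-contained worked example with invented numbers (2 ranks, 4 environments per rank, action repeat of 1). It only reproduces the arithmetic used above for `Time/sps_train` and `Time/sps_env_interaction`; none of the values are sheeprl defaults.

```python
# Worked example of the new throughput metrics with invented numbers; this only
# reproduces the arithmetic of Time/sps_train and Time/sps_env_interaction.
world_size = 2           # ranks
num_envs = 4             # environments per rank
action_repeat = 1

# One logging window: each rank ran 512 updates and accumulated these timers.
policy_step_delta = 512 * num_envs * world_size   # policy_step - last_log (counted globally)
train_step_delta = 512 * world_size               # train_step - last_train (counted globally)
train_time = 16.0                                 # seconds in Time/train_time
env_interaction_time = 8.0                        # seconds in Time/env_interaction_time

sps_train = train_step_delta / train_time
sps_env = (policy_step_delta / world_size * action_repeat) / env_interaction_time
print(sps_train)  # 64.0 training steps per second
print(sps_env)    # 256.0 environment steps per second on this rank
```

Dividing the policy-step delta by `world_size` makes `Time/sps_env_interaction` a per-rank rate, while `Time/sps_train` divides globally counted train steps by the (optionally synced) train time, in line with commit 7ab2787.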