Commit

Merge branch 'main' into fix/diambra-env-done
belerico authored Sep 18, 2023
2 parents bc16310 + bfe50a9 commit 2f49b07
Showing 23 changed files with 1,600 additions and 1,110 deletions.
2 changes: 1 addition & 1 deletion examples/architecture_template.py
@@ -139,7 +139,7 @@ def main():
if devices is None or devices in ("1", "2"):
raise RuntimeError(
"Please run the script with the number of devices greater than 2: "
"`lightning run model --devices=3 sheeprl.py ...`"
"`lightning run model --devices=3 examples/architecture_template.py ...`"
)

world_collective = TorchCollective()
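Context for the hunk above: the guard refuses to run the template on fewer than three devices, and the fix simply corrects the script path quoted in the error message. A minimal sketch of such a guard is shown below; it assumes the device count is read from the LT_DEVICES environment variable exported by `lightning run model`, which is an assumption since that assignment sits outside this hunk.

import os

def ensure_enough_devices() -> None:
    # Assumed source of `devices`: the LT_DEVICES env var set by `lightning run model`.
    devices = os.environ.get("LT_DEVICES", None)
    if devices is None or devices in ("1", "2"):
        raise RuntimeError(
            "Please run the script with the number of devices greater than 2: "
            "`lightning run model --devices=3 examples/architecture_template.py ...`"
        )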
2 changes: 0 additions & 2 deletions howto/register_new_algorithm.md
@@ -129,7 +129,6 @@ def main(cfg: DictConfig):
{
"Rewards/rew_avg": MeanMetric(),
"Game/ep_len_avg": MeanMetric(),
"Time/step_per_second": MeanMetric(),
"Loss/value_loss": MeanMetric(),
"Loss/policy_loss": MeanMetric(),
"Loss/entropy_loss": MeanMetric(),
@@ -222,7 +221,6 @@ def main(cfg: DictConfig):

# Log metrics
metrics_dict = aggregator.compute()
fabric.log("Time/step_per_second", int(global_step / (time.perf_counter() - start_time)), global_step)
fabric.log_dict(metrics_dict, global_step)
aggregator.reset()

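The two deletions above drop the hand-maintained Time/step_per_second metric from the how-to, so the documented template logs exclusively through the aggregator. A minimal sketch of the remaining pattern, assuming a Fabric instance with no attached logger and a hypothetical global_step counter and loss values:

from lightning.fabric import Fabric
from torchmetrics import MeanMetric

from sheeprl.utils.metric import MetricAggregator

fabric = Fabric(accelerator="cpu", devices=1)
aggregator = MetricAggregator(
    {
        "Loss/value_loss": MeanMetric(),
        "Loss/policy_loss": MeanMetric(),
    }
)

global_step = 100  # hypothetical step counter
aggregator.update("Loss/value_loss", 0.123)   # hypothetical values recorded during training
aggregator.update("Loss/policy_loss", 0.456)

# Log metrics
metrics_dict = aggregator.compute()
fabric.log_dict(metrics_dict, global_step)
aggregator.reset()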
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -26,9 +26,10 @@ dependencies = [
"tensorboard>=2.10",
"python-dotenv>=1.0.0",
"lightning==2.0.*",
"lightning-utilities<0.9",
"lightning-utilities<=0.9",
"hydra-core==1.3.0",
"torchmetrics==1.1.*",
"rich==13.5.*",
"opencv-python==4.8.0.*"
]
dynamic = ["version"]
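The dependency changes above relax the lightning-utilities pin from <0.9 to <=0.9 and add a pin on rich. A quick, non-authoritative way to see which versions actually ended up in an environment, using only the standard library (this snippet is not part of the repository):

# Small sanity check: print the installed versions of the packages whose pins changed in this hunk.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("lightning-utilities", "rich"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")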
2 changes: 1 addition & 1 deletion sheeprl/__init__.py
@@ -31,4 +31,4 @@
np.int = np.int64
np.bool = bool

__version__ = "0.3.0"
__version__ = "0.3.2"
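Besides the version bump to 0.3.2, the context lines above re-add the np.int and np.bool aliases that NumPy removed in release 1.24, since some wrapped environments and older dependencies may still reference them. A slightly more defensive variant of the same shim, shown only as a sketch (the repository applies the assignments unconditionally):

import numpy as np

# Restore aliases removed in NumPy >= 1.24 only if they are actually missing,
# so older NumPy installations are left untouched.
if not hasattr(np, "int"):
    np.int = np.int64
if not hasattr(np, "bool"):
    np.bool = bool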
218 changes: 123 additions & 95 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py

Large diffs are not rendered by default.

245 changes: 136 additions & 109 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py

Large diffs are not rendered by default.

229 changes: 128 additions & 101 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py

Large diffs are not rendered by default.

205 changes: 120 additions & 85 deletions sheeprl/algos/droq/droq.py
@@ -1,5 +1,4 @@
import os
import time
import warnings
from math import prod

@@ -15,7 +14,7 @@
from torch.optim import Optimizer
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import BatchSampler
from torchmetrics import MeanMetric
from torchmetrics import MeanMetric, SumMetric

from sheeprl.algos.droq.agent import DROQAgent, DROQCritic
from sheeprl.algos.sac.agent import SACActor
@@ -27,6 +26,8 @@
from sheeprl.utils.logger import create_tensorboard_logger
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import print_config


def train(
@@ -44,86 +45,97 @@ def train(
sample = rb.sample(
cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs
)
gathered_data = fabric.all_gather(sample.to_dict())
gathered_data = make_tensordict(gathered_data).view(-1)
critic_data = fabric.all_gather(sample.to_dict())
critic_data = make_tensordict(critic_data).view(-1)
if fabric.world_size > 1:
dist_sampler: DistributedSampler = DistributedSampler(
range(len(gathered_data)),
range(len(critic_data)),
num_replicas=fabric.world_size,
rank=fabric.global_rank,
shuffle=True,
seed=cfg.seed,
drop_last=False,
)
sampler: BatchSampler = BatchSampler(sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False)
critic_sampler: BatchSampler = BatchSampler(
sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False
)
else:
sampler = BatchSampler(sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False)

# Update the soft-critic
for batch_idxes in sampler:
data = gathered_data[batch_idxes]
next_target_qf_value = agent.get_next_target_q_values(
data["next_observations"],
data["rewards"],
data["dones"],
cfg.algo.gamma,
critic_sampler = BatchSampler(
sampler=range(len(critic_data)), batch_size=cfg.per_rank_batch_size, drop_last=False
)
for qf_value_idx in range(agent.num_critics):
# Line 8 - Algorithm 2
qf_loss = F.mse_loss(
agent.get_ith_q_value(data["observations"], data["actions"], qf_value_idx), next_target_qf_value
)
qf_optimizer.zero_grad(set_to_none=True)
fabric.backward(qf_loss)
qf_optimizer.step()
aggregator.update("Loss/value_loss", qf_loss)

# Update the target networks with EMA
agent.qfs_target_ema(critic_idx=qf_value_idx)

# Sample a different minibatch in a distributed way to update actor and alpha parameter
sample = rb.sample(cfg.per_rank_batch_size)
data = fabric.all_gather(sample.to_dict())
data = make_tensordict(data).view(-1)
actor_data = fabric.all_gather(sample.to_dict())
actor_data = make_tensordict(actor_data).view(-1)
if fabric.world_size > 1:
sampler: DistributedSampler = DistributedSampler(
range(len(data)),
actor_sampler: DistributedSampler = DistributedSampler(
range(len(actor_data)),
num_replicas=fabric.world_size,
rank=fabric.global_rank,
shuffle=True,
seed=cfg.seed,
drop_last=False,
)
data = data[next(iter(sampler))]

# Update the actor
actions, logprobs = agent.get_actions_and_log_probs(data["observations"])
qf_values = agent.get_q_values(data["observations"], actions)
min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True)
actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values)
actor_optimizer.zero_grad(set_to_none=True)
fabric.backward(actor_loss)
actor_optimizer.step()
aggregator.update("Loss/policy_loss", actor_loss)

# Update the entropy value
alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy)
alpha_optimizer.zero_grad(set_to_none=True)
fabric.backward(alpha_loss)
agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad)
alpha_optimizer.step()
aggregator.update("Loss/alpha_loss", alpha_loss)
actor_data = actor_data[next(iter(actor_sampler))]

with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
# Update the soft-critic
for batch_idxes in critic_sampler:
critic_batch_data = critic_data[batch_idxes]
next_target_qf_value = agent.get_next_target_q_values(
critic_batch_data["next_observations"],
critic_batch_data["rewards"],
critic_batch_data["dones"],
cfg.algo.gamma,
)
for qf_value_idx in range(agent.num_critics):
# Line 8 - Algorithm 2
qf_loss = F.mse_loss(
agent.get_ith_q_value(
critic_batch_data["observations"], critic_batch_data["actions"], qf_value_idx
),
next_target_qf_value,
)
qf_optimizer.zero_grad(set_to_none=True)
fabric.backward(qf_loss)
qf_optimizer.step()
aggregator.update("Loss/value_loss", qf_loss)

# Update the target networks with EMA
agent.qfs_target_ema(critic_idx=qf_value_idx)

# Update the actor
actions, logprobs = agent.get_actions_and_log_probs(actor_data["observations"])
qf_values = agent.get_q_values(actor_data["observations"], actions)
min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True)
actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values)
actor_optimizer.zero_grad(set_to_none=True)
fabric.backward(actor_loss)
actor_optimizer.step()
aggregator.update("Loss/policy_loss", actor_loss)

# Update the entropy value
alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy)
alpha_optimizer.zero_grad(set_to_none=True)
fabric.backward(alpha_loss)
agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad)
alpha_optimizer.step()
aggregator.update("Loss/alpha_loss", alpha_loss)

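A note on the data flow in the train() hunk above: each rank all-gathers its replay-buffer sample, a DistributedSampler then shards the gathered indices across ranks, and a BatchSampler cuts each shard into per-rank minibatches (the single-process branch skips the sharding). Below is a self-contained sketch of that pattern with hypothetical sizes; it is an illustration, not code from this commit.

from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import BatchSampler

gathered_len = 512            # hypothetical: len(critic_data) after fabric.all_gather
world_size, global_rank = 1, 0
per_rank_batch_size = 64

if world_size > 1:
    dist_sampler = DistributedSampler(
        range(gathered_len),
        num_replicas=world_size,
        rank=global_rank,
        shuffle=True,
        seed=42,
        drop_last=False,
    )
    critic_sampler = BatchSampler(sampler=dist_sampler, batch_size=per_rank_batch_size, drop_last=False)
else:
    critic_sampler = BatchSampler(sampler=range(gathered_len), batch_size=per_rank_batch_size, drop_last=False)

for batch_idxes in critic_sampler:
    # In the real code these indices slice the gathered TensorDict,
    # e.g. critic_batch_data = critic_data[batch_idxes].
    pass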

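The critic, actor, and alpha updates above are now wrapped in `with timer("Time/train_time", SumMetric(...))`, which replaces the old Time/step_per_second bookkeeping. The sketch below shows one way such a context-manager timer can be built on top of a torchmetrics metric; it only illustrates the pattern, it is not the actual sheeprl.utils.timer implementation, and the helper names (simple_timer, compute_timers) are made up.

import time
from contextlib import contextmanager

from torchmetrics import Metric, SumMetric

_TIMERS: dict[str, Metric] = {}

@contextmanager
def simple_timer(name: str, metric: Metric):
    # Register the metric once per name and accumulate elapsed seconds into it.
    m = _TIMERS.setdefault(name, metric)
    start = time.perf_counter()
    try:
        yield
    finally:
        m.update(time.perf_counter() - start)

def compute_timers() -> dict[str, float]:
    # Rough analogue of the later `timer.compute()` call: one scalar per registered timer.
    return {name: float(m.compute()) for name, m in _TIMERS.items()}

with simple_timer("Time/train_time", SumMetric()):
    time.sleep(0.01)  # stand-in for the critic/actor updates
print(compute_timers()["Time/train_time"])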
@register_algorithm()
@hydra.main(version_base=None, config_path="../../configs", config_name="config")
def main(cfg: DictConfig):
print_config(cfg)

# Initialize Fabric
fabric = Fabric(callbacks=[CheckpointCallback()])
if not _is_using_cli():
fabric.launch()
rank = fabric.global_rank
device = fabric.device
rank = fabric.global_rank
world_size = fabric.world_size
fabric.seed_everything(cfg.seed)
torch.backends.cudnn.deterministic = cfg.torch_deterministic

@@ -196,17 +208,15 @@ def main(cfg: DictConfig):
)

# Metrics
with device:
aggregator = MetricAggregator(
{
"Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
}
)
aggregator = MetricAggregator(
{
"Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
"Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute),
}
).to(device)

# Local data
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1
@@ -220,25 +230,26 @@ def main(cfg: DictConfig):
step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device)

# Global variables
policy_step = 0
last_log = 0
last_train = 0
train_step = 0
policy_step = 0
last_checkpoint = 0
start_time = time.perf_counter()
policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size)
num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1
learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0

# Warning for log and checkpoint every
if cfg.metric.log_every % policy_steps_per_update != 0:
warnings.warn(
f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
warnings.warn(
f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
@@ -249,23 +260,25 @@ def main(cfg: DictConfig):
obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs]

for update in range(1, num_updates + 1):
# Sample an action given the observation received by the environment
with torch.no_grad():
actions, _ = actor.module(obs)
actions = actions.cpu().numpy()
next_obs, rewards, dones, truncated, infos = envs.step(actions)
dones = np.logical_or(dones, truncated)

policy_step += cfg.env.num_envs * fabric.world_size

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with torch.no_grad():
actions, _ = actor.module(obs)
actions = actions.cpu().numpy()
next_obs, rewards, dones, truncated, infos = envs.step(actions)
dones = np.logical_or(dones, truncated)

if "final_info" in infos:
for i, agent_final_info in enumerate(infos["final_info"]):
if agent_final_info is not None and "episode" in agent_final_info:
fabric.print(
f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}"
)
aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0])
aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0])
for i, agent_ep_info in enumerate(infos["final_info"]):
if agent_ep_info is not None:
ep_rew = agent_ep_info["episode"]["r"]
ep_len = agent_ep_info["episode"]["l"]
aggregator.update("Rewards/rew_avg", ep_rew)
aggregator.update("Game/ep_len_avg", ep_len)
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

# Save the real next observation
real_next_obs = next_obs.copy()
@@ -295,12 +308,34 @@ def main(cfg: DictConfig):
# Train the agent
if update > learning_starts:
train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg)
aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)))
if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run:
last_log = policy_step
fabric.log_dict(aggregator.compute(), policy_step)
train_step += world_size

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
# Sync distributed metrics
metrics_dict = aggregator.compute()
fabric.log_dict(metrics_dict, policy_step)
aggregator.reset()

# Sync distributed timers
timer_metrics = timer.compute()
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
/ timer_metrics["Time/env_interaction_time"],
policy_step,
)
timer.reset()

# Reset counters
last_log = policy_step
last_train = train_step

# Checkpoint model
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
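For reference, the two fabric.log calls in the hunk above turn the timer aggregates into throughput figures: Time/sps_train divides the gradient-step delta by the accumulated training time, while Time/sps_env_interaction divides the per-rank policy-step delta (scaled by cfg.env.action_repeat) by the accumulated environment-interaction time. A worked example with hypothetical numbers:

world_size = 2
action_repeat = 1
train_step, last_train = 128, 0      # hypothetical gradient steps since the last log
policy_step, last_log = 4096, 0      # hypothetical policy steps since the last log
train_time = 3.2                     # seconds accumulated in "Time/train_time"
env_interaction_time = 5.0           # seconds accumulated in "Time/env_interaction_time"

sps_train = (train_step - last_train) / train_time
sps_env_interaction = ((policy_step - last_log) / world_size * action_repeat) / env_interaction_time
print(f"Time/sps_train={sps_train:.1f}")                        # 40.0
print(f"Time/sps_env_interaction={sps_env_interaction:.1f}")    # 409.6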
(Remaining file diffs not loaded.)
