[Feature] DDPG compatibility with compile
ghstack-source-id: f18928a419f81794d6870fd4e9fe1205c1b137e1
Pull Request resolved: #2555
vmoens committed Dec 14, 2024
1 parent 01a421e commit 7d7cd95
Showing 5 changed files with 115 additions and 69 deletions.
7 changes: 6 additions & 1 deletion sota-implementations/ddpg/config.yaml
@@ -13,7 +13,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
reset_at_each_iter: False
device: cpu
device:
env_per_collector: 1


@@ -40,6 +40,11 @@ network:
activation: relu
noise_type: "ou" # ou or gaussian

compile:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
backend: wandb
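
The new compile block above adds three knobs to the DDPG example: compile toggles torch.compile for the collector policy and the update step, compile_mode picks the torch.compile mode, and cudagraphs additionally wraps the compiled update in tensordict's CudaGraphModule. As a quick illustration, here is a minimal sketch of how those flags are resolved into a compile mode; it mirrors the logic added to ddpg.py below, and the cfg object is just a stand-in for the Hydra config.

```python
from __future__ import annotations

# Minimal sketch: mirrors the compile-mode resolution added to ddpg.py below.
# `cfg` is a stand-in for the Hydra DictConfig loaded from config.yaml.
def resolve_compile_mode(cfg) -> str | None:
    """Return a torch.compile mode string, or None if compilation is disabled."""
    if not cfg.compile.compile:
        return None
    if cfg.compile.compile_mode not in (None, ""):
        # An explicit mode from the config wins, e.g. "reduce-overhead" or "max-autotune".
        return cfg.compile.compile_mode
    if cfg.compile.cudagraphs:
        # CudaGraphModule handles graph capture itself, so the plain "default" mode is enough.
        return "default"
    return "reduce-overhead"
```
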
135 changes: 81 additions & 54 deletions sota-implementations/ddpg/ddpg.py
@@ -12,17 +12,21 @@
"""
from __future__ import annotations

import time
import warnings

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
@@ -46,6 +50,14 @@ def main(cfg: "DictConfig"): # noqa: F821
device = "cpu"
device = torch.device(device)

collector_device = cfg.collector.device
if collector_device in ("", None):
if torch.cuda.is_available():
collector_device = "cuda:0"
else:
collector_device = "cpu"
collector_device = torch.device(collector_device)

# Create logger
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
logger = None
@@ -75,8 +87,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create DDPG loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.compile.compile:
if cfg.compile.compile_mode not in (None, ""):
compile_mode = cfg.compile.compile_mode
elif cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg,
train_env,
exploration_policy,
compile=cfg.compile.compile,
compile_mode=compile_mode,
cudagraph=cfg.compile.cudagraphs,
device=collector_device,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
@@ -89,9 +118,29 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
optimizer = group_optimizers(optimizer_actor, optimizer_critic)

def update(sampled_tensordict):
optimizer.zero_grad(set_to_none=True)

td_loss: TensorDict = loss_module(sampled_tensordict)
td_loss.sum(reduce=True).backward()
optimizer.step()

# Update qnet_target params
target_net_updater.step()
return td_loss.detach()

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -106,63 +155,43 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
tensordict = next(c_iter)
# Update exploration policy
exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
pbar.update(current_frames)

# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("rb - extend"):
tensordict = tensordict.reshape(-1)
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Update critic
q_loss, *_ = loss_module.loss_value(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Update actor
actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())

# Update qnet_target params
target_net_updater.step()
with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
td_loss = update(sampled_tensordict)
tds.append(td_loss.clone())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds)

training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
@@ -180,38 +209,36 @@ def main(cfg: "DictConfig"): # noqa: F821
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
tds = TensorDict(train=tds).flatten_keys("/").mean()
metrics_to_log.update(tds.to_dict())

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
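
The main change in ddpg.py is collapsing the separate actor and critic optimization steps into a single update function whose optimizers are fused with group_optimizers, and then optionally wrapping that function in torch.compile and tensordict.nn.CudaGraphModule. The sketch below reproduces that wrapping pattern on a throwaway network so it can be read in isolation; the linear net, random data, and MSE loss are placeholders and not part of the commit.

```python
# Sketch of the compile/cudagraph update pattern used in ddpg.py; only the wrapping
# pattern (torch.compile + optional CudaGraphModule + cudagraph_mark_step_begin)
# follows the commit, the rest is a toy stand-in.
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = torch.nn.Linear(8, 1, device=device)
optimizer = torch.optim.Adam(net.parameters())

def update(batch: TensorDict) -> TensorDict:
    optimizer.zero_grad(set_to_none=True)
    loss = (net(batch["obs"]) - batch["target"]).pow(2).mean()
    loss.backward()
    optimizer.step()
    return TensorDict({"loss": loss.detach()}, batch_size=[])

use_compile = True
use_cudagraphs = torch.cuda.is_available()  # CudaGraphModule requires a CUDA device

if use_compile:
    update = torch.compile(update, mode="default" if use_cudagraphs else "reduce-overhead")
    if use_cudagraphs:
        # Experimental, as the warning added in ddpg.py notes: may silently produce wrong results.
        update = CudaGraphModule(update, warmup=50)

for _ in range(5):
    batch = TensorDict(
        {"obs": torch.randn(32, 8, device=device), "target": torch.randn(32, 1, device=device)},
        batch_size=[32],
    )
    # Required before each compiled / graph-captured step, as in the training loop.
    torch.compiler.cudagraph_mark_step_begin()
    td_loss = update(batch)
```
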
33 changes: 19 additions & 14 deletions sota-implementations/ddpg/utils.py
@@ -8,7 +8,7 @@

import torch

from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
@@ -32,8 +32,6 @@
AdditiveGaussianModule,
MLP,
OrnsteinUhlenbeckProcessModule,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
@@ -115,7 +113,15 @@ def make_environment(cfg, logger):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraph=False,
device: torch.device | None = None,
):
"""Make collector."""
collector = SyncDataCollector(
train_env,
@@ -124,7 +130,9 @@ def make_collector(cfg, train_env, actor_model_explore):
init_random_frames=cfg.collector.init_random_frames,
reset_at_each_iter=cfg.collector.reset_at_each_iter,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
compile_policy={"mode": compile_mode, "fullgraph": True} if compile else False,
cudagraph_policy=cudagraph,
)
collector.set_seed(cfg.env.seed)
return collector
@@ -174,9 +182,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
"""Make DDPG agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
@@ -186,19 +192,16 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
actor_net = MLP(**actor_net_kwargs)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
out_keys=["action"],
spec=action_spec,
),
)

@@ -237,6 +240,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
spec=action_spec,
annealing_num_steps=1_000_000,
device=device,
safe=False,
),
)
elif cfg.network.noise_type == "gaussian":
@@ -249,6 +253,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
mean=0.0,
std=0.1,
device=device,
safe=False,
),
)
else:
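
In utils.py, make_collector now receives the compile settings and the resolved device and forwards them to SyncDataCollector through the compile_policy and cudagraph_policy arguments. Below is a minimal sketch of a collector built that way, assuming gymnasium's Pendulum-v1 is installed and using a throwaway linear policy in place of the DDPG actor.

```python
# Sketch only: Pendulum-v1 and the linear policy are stand-ins; the compile_policy /
# cudagraph_policy keywords mirror what make_collector now passes to SyncDataCollector.
import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
policy = TensorDictModule(
    torch.nn.Linear(
        env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
    ),
    in_keys=["observation"],
    out_keys=["action"],
)

collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=200,
    total_frames=1_000,
    device="cpu",
    # A dict configures torch.compile for the policy; False disables compilation.
    compile_policy={"mode": "reduce-overhead", "fullgraph": True},
    cudagraph_policy=False,  # True additionally wraps the policy in CudaGraphModule
)

for batch in collector:
    print(batch["action"].shape)  # torch.Size([200, 1])
    break
collector.shutdown()
```
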
1 change: 1 addition & 0 deletions torchrl/collectors/collectors.py
@@ -67,6 +67,7 @@
set_exploration_type,
)


try:
from torch.compiler import cudagraph_mark_step_begin
except ImportError:
(Diff for the remaining changed file is not shown.)

1 comment on commit 7d7cd95

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark suite 'CPU Benchmark Results'.
The result for this commit is worse than the previous result by more than the alert threshold of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400]
Current (7d7cd95): 36.16408216919682 iter/sec (stddev: 0.16446754651912104)
Previous (01a421e): 241.73781953462816 iter/sec (stddev: 0.0005608093730264233)
Ratio: 6.68
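
For reference, the Ratio value appears to be the previous throughput divided by this commit's throughput (an assumption about github-action-benchmark's reporting), which is what trips the threshold of 2:

```python
# Quick check of the reported ratio (assuming Ratio = previous / current iter/sec).
previous = 241.73781953462816  # iter/sec at 01a421e
current = 36.16408216919682    # iter/sec at 7d7cd95
print(round(previous / current, 2))  # 6.68
```
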

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
