[Feature] DDPG compatibility with compile #2555

Merged: 47 commits, Dec 14, 2024
Changes from all commits
[Commit list omitted: all 47 commits are by vmoens, dated Nov 12 to Dec 14, 2024, each titled "Update".]
7 changes: 6 additions & 1 deletion sota-implementations/ddpg/config.yaml
@@ -13,7 +13,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
reset_at_each_iter: False
device: cpu
device:
env_per_collector: 1


@@ -40,6 +40,11 @@ network:
activation: relu
noise_type: "ou" # ou or gaussian

compile:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
backend: wandb
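
Usage note (not part of the diff): the new compile block is off by default, and ddpg.py resolves the mode as follows: an explicit compile_mode wins; otherwise "default" is used when cudagraphs is True (CudaGraphModule already removes launch overhead) and "reduce-overhead" when it is False. A sketch of an enabled configuration:

compile:
  compile: True
  compile_mode:   # left empty: ddpg.py picks a mode as described above
  cudagraphs: True
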
135 changes: 81 additions & 54 deletions sota-implementations/ddpg/ddpg.py
@@ -12,17 +12,21 @@
"""
from __future__ import annotations

import time
import warnings

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
@@ -46,6 +50,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
device = "cpu"
device = torch.device(device)

collector_device = cfg.collector.device
if collector_device in ("", None):
if torch.cuda.is_available():
collector_device = "cuda:0"
else:
collector_device = "cpu"
collector_device = torch.device(collector_device)

# Create logger
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
logger = None
@@ -75,8 +87,25 @@ def main(cfg: "DictConfig"):  # noqa: F821
# Create DDPG loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.compile.compile:
if cfg.compile.compile_mode not in (None, ""):
compile_mode = cfg.compile.compile_mode
elif cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg,
train_env,
exploration_policy,
compile=cfg.compile.compile,
compile_mode=compile_mode,
cudagraph=cfg.compile.cudagraphs,
device=collector_device,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
Expand All @@ -89,9 +118,29 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
optimizer = group_optimizers(optimizer_actor, optimizer_critic)

def update(sampled_tensordict):
optimizer.zero_grad(set_to_none=True)

td_loss: TensorDict = loss_module(sampled_tensordict)
td_loss.sum(reduce=True).backward()
optimizer.step()

# Update qnet_target params
target_net_updater.step()
return td_loss.detach()

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -106,63 +155,43 @@ def main(cfg: "DictConfig"):  # noqa: F821
eval_iter = cfg.logger.eval_iter
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
tensordict = next(c_iter)
# Update exploration policy
exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
pbar.update(current_frames)

# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("rb - extend"):
tensordict = tensordict.reshape(-1)
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Update critic
q_loss, *_ = loss_module.loss_value(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Update actor
actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())

# Update qnet_target params
target_net_updater.step()
with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
td_loss = update(sampled_tensordict)
tds.append(td_loss.clone())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds)

training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
Expand All @@ -180,38 +209,36 @@ def main(cfg: "DictConfig"): # noqa: F821
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
tds = TensorDict(train=tds).flatten_keys("/").mean()
metrics_to_log.update(tds.to_dict())

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
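
A side note on the timeit helper that replaces the manual time.time() bookkeeping above: it is a context manager that accumulates wall-clock time under a string key. A small sketch of the API as used in this file (the key name and sleep are illustrative):

import time

from torchrl._utils import timeit

for _ in range(3):
    with timeit("loop/sleep"):  # successive entries accumulate under the same key
        time.sleep(0.01)

timeit.print()                        # log the aggregated timings
stats = timeit.todict(prefix="time")  # flatten to a dict, e.g. {"time/loop/sleep": ...}
timeit.erase()                        # reset all counters
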
33 changes: 19 additions & 14 deletions sota-implementations/ddpg/utils.py
@@ -8,7 +8,7 @@

import torch

from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
@@ -32,8 +32,6 @@
AdditiveGaussianModule,
MLP,
OrnsteinUhlenbeckProcessModule,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
@@ -115,7 +113,15 @@ def make_environment(cfg, logger):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraph=False,
device: torch.device | None = None,
):
"""Make collector."""
collector = SyncDataCollector(
train_env,
Expand All @@ -124,7 +130,9 @@ def make_collector(cfg, train_env, actor_model_explore):
init_random_frames=cfg.collector.init_random_frames,
reset_at_each_iter=cfg.collector.reset_at_each_iter,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
compile_policy={"mode": compile_mode, "fullgraph": True} if compile else False,
cudagraph_policy=cudagraph,
)
collector.set_seed(cfg.env.seed)
return collector
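
For reference, the two new keyword arguments map directly onto SyncDataCollector as shown above: compile_policy accepts a bool or a dict of torch.compile keyword arguments, and cudagraph_policy asks the collector to wrap the policy for CUDA-graph capture. A hypothetical stand-alone call (env and policy are placeholders, not from the PR):

from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env,     # a TorchRL EnvBase instance (placeholder)
    policy,  # a TensorDictModule policy (placeholder)
    frames_per_batch=1000,
    total_frames=1_000_000,
    device="cuda:0",
    compile_policy={"mode": "reduce-overhead", "fullgraph": True},
    cudagraph_policy=True,
)
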
@@ -174,9 +182,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
"""Make DDPG agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
Expand All @@ -186,19 +192,16 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
actor_net = MLP(**actor_net_kwargs)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
out_keys=["action"],
spec=action_spec,
),
)
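
The switch from SafeModule/SafeSequential to the plain tensordict classes works because TanhModule already maps the unbounded network output into the bounds carried by the action spec, so the extra safety projection was redundant. A toy sketch under that reading (shapes and bounds are illustrative; Bounded is the boxed spec class in recent torchrl versions):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import Bounded
from torchrl.modules import MLP, TanhModule

action_spec = Bounded(low=-1.0, high=1.0, shape=(2,))
actor = TensorDictSequential(
    TensorDictModule(
        MLP(num_cells=[32], out_features=2),  # input size is inferred lazily
        in_keys=["observation"],
        out_keys=["param"],
    ),
    TanhModule(in_keys=["param"], out_keys=["action"], spec=action_spec),
)
td = actor(TensorDict({"observation": torch.randn(4)}, []))
assert (td["action"].abs() <= 1).all()  # squashed into the spec bounds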

@@ -237,6 +240,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
spec=action_spec,
annealing_num_steps=1_000_000,
device=device,
safe=False,
),
)
elif cfg.network.noise_type == "gaussian":
Expand All @@ -249,6 +253,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
mean=0.0,
std=0.1,
device=device,
safe=False,
),
)
else:
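
Similarly, safe=False on the exploration modules skips the step that would project noisy actions back into the action spec after noise is added; presumably this keeps the policy free of data-dependent control flow under torch.compile (an interpretation, not stated in the diff). Continuing the toy example above, OU exploration would be appended like this:

from torchrl.modules import OrnsteinUhlenbeckProcessModule

exploration_policy = TensorDictSequential(
    actor,  # the deterministic actor from the previous sketch
    OrnsteinUhlenbeckProcessModule(
        spec=action_spec,
        annealing_num_steps=1_000_000,
        safe=False,  # do not project noisy actions back into the spec
    ),
)
exploration_policy[1].step(1000)  # advance the noise schedule, as ddpg.py does once per batch
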
1 change: 1 addition & 0 deletions torchrl/collectors/collectors.py
@@ -67,6 +67,7 @@
set_exploration_type,
)


try:
from torch.compiler import cudagraph_mark_step_begin
except ImportError:
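
The hunk is truncated here and the fallback body is not visible. A typical shape for this kind of version guard, assuming a placeholder fallback (not necessarily the PR's actual code):

try:
    from torch.compiler import cudagraph_mark_step_begin
except ImportError:

    def cudagraph_mark_step_begin():
        # Placeholder for older torch versions that lack the API.
        raise NotImplementedError("torch.compiler.cudagraph_mark_step_begin is not available")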