
[Feature] DQN compatibility with compile #2571

Merged (47 commits, Dec 15, 2024)

Commits
cfa38bb  Update (vmoens, Nov 15, 2024)
9adac70  Update (vmoens, Nov 18, 2024)
6eb8bbc  Update (vmoens, Nov 18, 2024)
b5d6033  Update (vmoens, Nov 18, 2024)
f35c0d3  Update (vmoens, Nov 18, 2024)
19db1ab  Update (vmoens, Nov 18, 2024)
d97b7ab  Update (vmoens, Nov 18, 2024)
f564293  Update (vmoens, Nov 18, 2024)
340b119  Update (vmoens, Nov 18, 2024)
e64b4e1  Update (vmoens, Nov 18, 2024)
143f226  Update (vmoens, Nov 18, 2024)
e84f1b0  Update (vmoens, Nov 18, 2024)
79075cb  Update (vmoens, Nov 18, 2024)
a61fc38  Update (vmoens, Nov 18, 2024)
4081e21  Update (vmoens, Nov 18, 2024)
8d37c31  Update (vmoens, Nov 18, 2024)
7636777  Update (vmoens, Nov 18, 2024)
8cba43d  Update (vmoens, Nov 18, 2024)
122c4aa  Update (vmoens, Nov 18, 2024)
ea690f2  Update (vmoens, Nov 18, 2024)
ba52d14  Update (vmoens, Nov 18, 2024)
0f400be  Update (vmoens, Nov 18, 2024)
5ad014e  Update (vmoens, Nov 19, 2024)
7e3484e  Update (vmoens, Nov 19, 2024)
6ed4040  Update (vmoens, Nov 19, 2024)
8913064  Update (vmoens, Nov 20, 2024)
f5406d7  Update (vmoens, Nov 20, 2024)
102003e  Update (vmoens, Nov 21, 2024)
8ec6725  Update (vmoens, Nov 25, 2024)
9303e34  Update (vmoens, Dec 13, 2024)
163ca84  Update (vmoens, Dec 13, 2024)
2a802dd  Update (vmoens, Dec 13, 2024)
21798a6  Update (vmoens, Dec 13, 2024)
d6077e0  Update (vmoens, Dec 13, 2024)
725f38b  Update (vmoens, Dec 13, 2024)
e56d5a5  Update (vmoens, Dec 13, 2024)
1db8142  Update (vmoens, Dec 13, 2024)
71d1268  Update (vmoens, Dec 13, 2024)
7683cd3  Update (vmoens, Dec 14, 2024)
b627fef  Update (vmoens, Dec 14, 2024)
5245398  Update (vmoens, Dec 14, 2024)
79bbf44  Update (vmoens, Dec 14, 2024)
3da531a  Update (vmoens, Dec 14, 2024)
ed89a10  Update (vmoens, Dec 14, 2024)
89c8d98  Update (vmoens, Dec 14, 2024)
0dc4622  Update (vmoens, Dec 14, 2024)
e5a358b  Update (vmoens, Dec 15, 2024)
5 changes: 5 additions & 0 deletions sota-implementations/dqn/config_atari.yaml
@@ -39,3 +39,8 @@ loss:
   gamma: 0.99
   hard_update_freq: 10_000
   num_updates: 1
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
5 changes: 5 additions & 0 deletions sota-implementations/dqn/config_cartpole.yaml
@@ -38,3 +38,8 @@ loss:
   gamma: 0.99
   hard_update_freq: 50
   num_updates: 1
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
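
Both configs gain the same compile block, disabled by default, so existing runs are unaffected. Assuming the standard Hydra override syntax these scripts already rely on, the new options can be toggled from the command line without editing the YAML, for example:

    python dqn_atari.py compile.compile=True compile.compile_mode=reduce-overhead compile.cudagraphs=True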
123 changes: 71 additions & 52 deletions sota-implementations/dqn/dqn_atari.py
@@ -10,14 +10,14 @@
 from __future__ import annotations
 
 import tempfile
-import time
+import warnings
 
 import hydra
 import torch.nn
 import torch.optim
 import tqdm
-from tensordict.nn import TensorDictSequential
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule, TensorDictSequential
+from torchrl._utils import timeit
 
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
@@ -48,28 +48,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Make the components
-    model = make_dqn_model(cfg.env.env_name, frame_skip)
+    model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
     greedy_module = EGreedyModule(
         annealing_num_steps=cfg.collector.annealing_frames,
         eps_init=cfg.collector.eps_start,
         eps_end=cfg.collector.eps_end,
         spec=model.spec,
+        device=device,
     )
     model_explore = TensorDictSequential(
         model,
         greedy_module,
     ).to(device)
 
-    # Create the collector
-    collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
-        policy=model_explore,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-        init_random_frames=init_random_frames,
-    )
-
     # Create the replay buffer
@@ -129,25 +118,70 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
     test_env.eval()
 
+    def update(sampled_tensordict):
+        loss_td = loss_module(sampled_tensordict)
+        q_loss = loss_td["loss"]
+        optimizer.zero_grad()
+        q_loss.backward()
+        torch.nn.utils.clip_grad_norm_(
+            list(loss_module.parameters()), max_norm=max_grad
+        )
+        optimizer.step()
+        target_net_updater.step()
+        return q_loss.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Create the collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
+        policy=model_explore,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        max_frames_per_traj=-1,
+        init_random_frames=init_random_frames,
+        compile_policy={"mode": compile_mode, "fullgraph": True}
+        if compile_mode is not None
+        else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
+
     # Main loop
     collected_frames = 0
-    start_time = time.time()
-    sampling_start = time.time()
     num_updates = cfg.loss.num_updates
     max_grad = cfg.optim.max_grad_norm
     num_test_episodes = cfg.logger.num_test_episodes
     q_losses = torch.zeros(num_updates, device=device)
     pbar = tqdm.tqdm(total=total_frames)
-    for i, data in enumerate(collector):
 
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(c_iter)
         log_info = {}
-        sampling_time = time.time() - sampling_start
         pbar.update(data.numel())
         data = data.reshape(-1)
         current_frames = data.numel() * frame_skip
         collected_frames += current_frames
         greedy_module.step(current_frames)
-        replay_buffer.extend(data)
+        with timeit("rb - extend"):
+            replay_buffer.extend(data)
 
         # Get and log training rewards and episode lengths
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
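
The hunk above is the core of the PR: the whole optimization step is factored into a standalone `update` function so that `torch.compile` sees forward, backward and optimizer update as one graph, and `CudaGraphModule` can then capture the compiled callable. The same pattern, reduced to a skeleton (hedged: `CudaGraphModule` comes from `tensordict.nn` and is flagged experimental by the warning in the diff; `loss_fn` and `opt` are illustrative names, not part of the PR):

    import torch
    from tensordict.nn import CudaGraphModule

    def make_update(loss_fn, opt, mode="reduce-overhead", use_cudagraphs=False):
        def update(td):
            # One callable covering loss, backward and optimizer step.
            loss = loss_fn(td)["loss"]
            opt.zero_grad()
            loss.backward()
            opt.step()
            return loss.detach()

        # Order matters: compile first, then capture the compiled callable.
        update = torch.compile(update, mode=mode)
        if use_cudagraphs:
            # warmup iterations let recompilations settle before graph capture.
            update = CudaGraphModule(update, warmup=50)
        return update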
@@ -169,74 +203,59 @@ def main(cfg: "DictConfig"):  # noqa: F821
             continue
 
         # optimization steps
-        training_start = time.time()
         for j in range(num_updates):
-
-            sampled_tensordict = replay_buffer.sample()
-            sampled_tensordict = sampled_tensordict.to(device)
-
-            loss_td = loss_module(sampled_tensordict)
-            q_loss = loss_td["loss"]
-            optimizer.zero_grad()
-            q_loss.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=max_grad
-            )
-            optimizer.step()
-            target_net_updater.step()
-            q_losses[j].copy_(q_loss.detach())
-
-        training_time = time.time() - training_start
+            with timeit("rb - sample"):
+                sampled_tensordict = replay_buffer.sample()
+                sampled_tensordict = sampled_tensordict.to(device)
+            with timeit("update"):
+                q_loss = update(sampled_tensordict)
+            q_losses[j].copy_(q_loss)
 
         # Get and log q-values, loss, epsilon, sampling time and training time
         log_info.update(
            {
-                "train/q_values": (data["action_value"] * data["action"]).sum().item()
-                / frames_per_batch,
-                "train/q_loss": q_losses.mean().item(),
+                "train/q_values": data["chosen_action_value"].sum() / frames_per_batch,
+                "train/q_loss": q_losses.mean(),
                 "train/epsilon": greedy_module.eps,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
             }
         )
 
         # Get and log evaluation rewards and eval time
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
             cur_test_frame = (i * frames_per_batch) // test_interval
             final = current_frames >= collector.total_frames
             if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
                 model.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
                     model, test_env, num_episodes=num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "eval/reward": test_rewards,
-                        "eval/eval_time": eval_time,
                     }
                 )
                 model.train()
 
+        if i % 200 == 0:
+            timeit.print()
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.erase()
+
         # Log all the information
         if logger:
             for key, value in log_info.items():
                 logger.log_scalar(key, value, step=collected_frames)
 
         # update weights of the inference policy
         collector.update_policy_weights_()
-        sampling_start = time.time()
 
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
 
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
-
 
 if __name__ == "__main__":
     main()
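
A note on the profiling change visible throughout the diff: the manual `time.time()` bookkeeping is replaced by `torchrl._utils.timeit`, used both as a context manager keyed by a label and as a global registry of accumulated timings. A minimal sketch of the consumption pattern, assuming only the calls the diff itself uses (`timeit(key)`, `timeit.print()`, `timeit.todict(prefix=...)`, `timeit.erase()`):

    from torchrl._utils import timeit

    with timeit("collecting"):
        pass  # timed section; wall-clock time accumulates under "collecting"

    timeit.print()                        # dump the aggregated timings
    stats = timeit.todict(prefix="time")  # export, e.g. {"time/collecting": ...}
    timeit.erase()                        # reset the accumulators

Exporting via `todict` every 200 iterations (as the loop above does) keeps the logged timings cheap while still catching drift in collection, sampling, update and eval costs.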