Skip to content

Commit

Permalink
[Feature] DQN compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 48a014ab50d28ecae2e519ee9c80dd2eaa36c5be
Pull Request resolved: #2571
  • Loading branch information
vmoens committed Dec 13, 2024
1 parent 9f1c3e0 commit 9ed04b0
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 89 deletions.
5 changes: 5 additions & 0 deletions sota-implementations/dqn/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ loss:
gamma: 0.99
hard_update_freq: 10_000
num_updates: 1

compile:
compile: False
compile_mode:
cudagraphs: False
5 changes: 5 additions & 0 deletions sota-implementations/dqn/config_cartpole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@ loss:
gamma: 0.99
hard_update_freq: 50
num_updates: 1

compile:
compile: False
compile_mode:
cudagraphs: False
113 changes: 65 additions & 48 deletions sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
Deep Q-Learning Algorithm on Atari Environments.
"""
import tempfile
import time
import warnings

import hydra
import torch.nn
import torch.optim
import tqdm
from tensordict.nn import TensorDictSequential
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule, TensorDictSequential
from torchrl._utils import timeit

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
Expand Down Expand Up @@ -58,18 +58,6 @@ def main(cfg: "DictConfig"): # noqa: F821
greedy_module,
).to(device)

# Create the collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
policy=model_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
init_random_frames=init_random_frames,
)

# Create the replay buffer
if cfg.buffer.scratch_dir is None:
tempdir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -127,25 +115,68 @@ def main(cfg: "DictConfig"): # noqa: F821
)
test_env.eval()

def update(sampled_tensordict):
loss_td = loss_module(sampled_tensordict)
q_loss = loss_td["loss"]
optimizer.zero_grad()
q_loss.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=max_grad
)
optimizer.step()
target_net_updater.step()
return q_loss.detach()

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Create the collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
policy=model_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
init_random_frames=init_random_frames,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Main loop
collected_frames = 0
start_time = time.time()
sampling_start = time.time()
num_updates = cfg.loss.num_updates
max_grad = cfg.optim.max_grad_norm
num_test_episodes = cfg.logger.num_test_episodes
q_losses = torch.zeros(num_updates, device=device)
pbar = tqdm.tqdm(total=total_frames)
for i, data in enumerate(collector):

c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
data = next(c_iter)
log_info = {}
sampling_time = time.time() - sampling_start
pbar.update(data.numel())
data = data.reshape(-1)
current_frames = data.numel() * frame_skip
collected_frames += current_frames
greedy_module.step(current_frames)
replay_buffer.extend(data)
with timeit("rb - extend"):
replay_buffer.extend(data)

# Get and log training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
Expand All @@ -167,24 +198,13 @@ def main(cfg: "DictConfig"): # noqa: F821
continue

# optimization steps
training_start = time.time()
for j in range(num_updates):

sampled_tensordict = replay_buffer.sample()
sampled_tensordict = sampled_tensordict.to(device)

loss_td = loss_module(sampled_tensordict)
q_loss = loss_td["loss"]
optimizer.zero_grad()
q_loss.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=max_grad
)
optimizer.step()
target_net_updater.step()
q_losses[j].copy_(q_loss.detach())

training_time = time.time() - training_start
with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample()
sampled_tensordict = sampled_tensordict.to(device)
with timeit("update"):
q_loss = update(sampled_tensordict)
q_losses[j].copy_(q_loss)

# Get and log q-values, loss, epsilon, sampling time and training time
log_info.update(
Expand All @@ -193,48 +213,45 @@ def main(cfg: "DictConfig"): # noqa: F821
/ frames_per_batch,
"train/q_loss": q_losses.mean().item(),
"train/epsilon": greedy_module.eps,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
}
)

# Get and log evaluation rewards and eval time
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
with torch.no_grad(), set_exploration_type(
ExplorationType.DETERMINISTIC
), timeit("eval"):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
model.eval()
eval_start = time.time()
test_rewards = eval_model(
model, test_env, num_episodes=num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
{
"eval/reward": test_rewards,
"eval/eval_time": eval_time,
}
)
model.train()

if i % 200 == 0:
timeit.print()
log_info.update(timeit.todict(prefix="time"))
timeit.erase()

# Log all the information
if logger:
for key, value in log_info.items():
logger.log_scalar(key, value, step=collected_frames)

# update weights of the inference policy
collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()

end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
Loading

0 comments on commit 9ed04b0

Please sign in to comment.