[Feature] DT compatibility with compile
ghstack-source-id: b22fe60afb8d306a20df2d0c69b11907e020edb8
Pull Request resolved: #2556
vmoens committed Dec 13, 2024
1 parent dc13998 commit 0a96f6e
Showing 17 changed files with 192 additions and 120 deletions.
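
The core of the change, visible in dt.py and online_dt.py below, is to factor each optimization step into a single update function that can be wrapped with torch.compile and, optionally, CudaGraphModule, driven by a new compile block in the configs. A minimal, runnable sketch of that pattern; the toy model, loss and flags below are placeholders, not the scripts' actual objects:

import warnings

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
clip_grad = 0.25
enable_compile, compile_mode, use_cudagraphs = True, None, False  # stand-ins for cfg.compile.*


def update(data: TensorDict) -> TensorDict:
    # One optimization step: zero grads, forward, backward, clip, step.
    optim.zero_grad(set_to_none=True)
    loss = model(data["observation"]).pow(2).mean()  # toy loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    optim.step()
    return TensorDict({"loss": loss.detach()}, [])


if enable_compile:
    if compile_mode in ("", None):
        # CUDA graphs prefer the default compile mode; otherwise reduce-overhead.
        compile_mode = "default" if use_cudagraphs else "reduce-overhead"
    update = torch.compile(update, mode=compile_mode, dynamic=True)
if use_cudagraphs:
    warnings.warn(
        "CudaGraphModule is experimental and may lead to silently wrong results.",
        category=UserWarning,
    )
    update = CudaGraphModule(update, warmup=50)

print(update(TensorDict({"observation": torch.randn(8, 4)}, [8]))["loss"])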
9 changes: 4 additions & 5 deletions sota-implementations/a2c/utils_atari.py
@@ -7,7 +7,6 @@
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
from torchrl.data import Composite
from torchrl.data.tensor_specs import CategoricalBox
from torchrl.envs import (
CatFrames,
@@ -93,12 +92,12 @@ def make_ppo_modules_pixels(proof_environment, device):
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(proof_environment.action_spec.space, CategoricalBox):
num_outputs = proof_environment.action_spec.space.n
if isinstance(proof_environment.single_action_spec.space, CategoricalBox):
num_outputs = proof_environment.single_action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = proof_environment.action_spec.shape
num_outputs = proof_environment.single_action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -152,7 +151,7 @@ def make_ppo_modules_pixels(proof_environment, device):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
spec=Composite(action=proof_environment.action_spec.to(device)),
spec=proof_environment.single_full_action_spec.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
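
The spec changes above, repeated in utils_mujoco.py and cql/utils.py below, replace the batched action_spec with its per-env counterpart. A small illustration, assuming single_action_spec / single_full_action_spec return the unbatched specs (as the nearby use of action_spec_unbatched suggests); the Pendulum env is only an example:

from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))  # example batched env
print(env.action_spec.shape)         # batched shape, e.g. torch.Size([4, 1])
print(env.single_action_spec.shape)  # per-env shape, e.g. torch.Size([1])
# The actor spec is now taken directly from the unbatched composite spec,
#   spec=env.single_full_action_spec
# instead of wrapping the batched leaf spec by hand,
#   spec=Composite(action=env.action_spec)

This is also why the "from torchrl.data import Composite" imports could be dropped.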
7 changes: 3 additions & 4 deletions sota-implementations/a2c/utils_mujoco.py
@@ -8,7 +8,6 @@
import torch.optim

from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.data import Composite
from torchrl.envs import (
ClipTransform,
DoubleToFloat,
@@ -54,7 +53,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
input_shape = proof_environment.observation_spec["observation"].shape

# Define policy output distribution class
num_outputs = proof_environment.action_spec.shape[-1]
num_outputs = proof_environment.single_action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -82,7 +81,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(
proof_environment.action_spec.shape[-1], device=device
proof_environment.single_action_spec.shape[-1], device=device
),
)

@@ -94,7 +93,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
spec=Composite(action=proof_environment.action_spec.to(device)),
spec=proof_environment.single_full_action_spec.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
@@ -296,7 +296,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):


def make_cql_modules_state(model_cfg, proof_environment):
action_spec = proof_environment.action_spec
action_spec = proof_environment.single_action_spec

actor_net_kwargs = {
"num_cells": model_cfg.hidden_sizes,
78 changes: 52 additions & 26 deletions sota-implementations/decision_transformer/dt.py
@@ -6,13 +6,16 @@
This is a self-contained example of an offline Decision Transformer training script.
The helper functions are coded in the utils.py associated with this script.
"""
import time

import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -65,58 +68,77 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create policy model
actor = make_dt_model(cfg)
policy = actor.to(model_device)
actor = make_dt_model(cfg, device=model_device)

# Create loss
loss_module = make_dt_loss(cfg.loss, actor)
loss_module = make_dt_loss(cfg.loss, actor, device=model_device)

# Create optimizer
transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)

# Create inference policy
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
policy=actor,
inference_context=cfg.env.inference_context,
).to(model_device)
device=model_device,
)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
return_to_go="return_to_go_cat",
)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
clip_grad = cfg.optim.clip_grad
eval_steps = cfg.logger.eval_steps
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

torchrl_logger.info(" ***Pretraining*** ")
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(1)

# Sample data
data = offline_buffer.sample()
def update(data: TensorDict) -> TensorDict:
transformer_optim.zero_grad(set_to_none=True)
# Compute loss
loss_vals = loss_module(data.to(model_device))
loss_vals = loss_module(data)
transformer_loss = loss_vals["loss"]

transformer_optim.zero_grad()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
transformer_loss.backward()
torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad)
transformer_optim.step()

scheduler.step()
return loss_vals

if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode, dynamic=True)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

eval_steps = cfg.logger.eval_steps
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

torchrl_logger.info(" ***Pretraining*** ")
# Pretraining
pbar = tqdm.tqdm(range(pretrain_gradient_steps))
for i in pbar:
# Sample data
with timeit("rb - sample"):
data = offline_buffer.sample().to(model_device)
with timeit("update"):
loss_vals = update(data)
scheduler.step()
# Log metrics
to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
max_steps=eval_steps,
@@ -127,13 +149,17 @@ def main(cfg: "DictConfig"): # noqa: F821
to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)
if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, to_log, i)

pbar.close()
if not test_env.is_closed:
test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
7 changes: 6 additions & 1 deletion sota-implementations/decision_transformer/dt_config.yaml
@@ -55,7 +55,12 @@ optim:
# loss
loss:
  loss_function: "l2"


compile:
  compile: False
  compile_mode:
  cudagraphs: False

# transformer model
transformer:
  n_embd: 128
5 changes: 5 additions & 0 deletions sota-implementations/decision_transformer/odt_config.yaml
@@ -56,6 +56,11 @@ loss:
  alpha_init: 0.1
  target_entropy: auto

compile:
  compile: False
  compile_mode:
  cudagraphs: False

# transformer model
transformer:
  n_embd: 512
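
The same compile block now appears in both dt_config.yaml (above) and odt_config.yaml, matching the cfg.compile.* reads in the training scripts. Since both scripts are Hydra entry points, the flags can presumably also be toggled at launch with standard Hydra overrides, e.g. python dt.py compile.compile=True compile.cudagraphs=False (shown only to illustrate the override syntax; the exact invocation is not part of this diff).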
81 changes: 56 additions & 25 deletions sota-implementations/decision_transformer/online_dt.py
@@ -7,14 +7,15 @@
The helper functions are coded in the utils.py associated with this script.
"""
import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
from torchrl.record import VideoRecorder
@@ -63,8 +64,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create policy model
actor = make_odt_model(cfg)
policy = actor.to(model_device)
policy = make_odt_model(cfg, device=model_device)

# Create loss
loss_module = make_odt_loss(cfg.loss, policy)
@@ -78,13 +78,46 @@ def main(cfg: "DictConfig"): # noqa: F821
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
inference_context=cfg.env.inference_context,
).to(model_device)
device=model_device,
)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
return_to_go="return_to_go_cat",
)

def update(data):
transformer_optim.zero_grad(set_to_none=True)
temperature_optim.zero_grad(set_to_none=True)
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
temperature_loss = loss_vals["loss_alpha"]

(temperature_loss + transformer_loss).backward()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)

transformer_optim.step()
temperature_optim.step()

return loss_vals.detach()

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode, dynamic=True)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
@@ -98,35 +131,28 @@ def main(cfg: "DictConfig"): # noqa: F821
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(1)
# Sample data
data = offline_buffer.sample()
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
temperature_loss = loss_vals["loss_alpha"]
with timeit("sample"):
# Sample data
data = offline_buffer.sample()

transformer_optim.zero_grad()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
transformer_loss.backward()
transformer_optim.step()

temperature_optim.zero_grad()
temperature_loss.backward()
temperature_optim.step()
with timeit("update"):
loss_vals = update(data.to(model_device))

scheduler.step()

# Log metrics
to_log = {
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
"train/loss_entropy": loss_vals["loss_entropy"].item(),
"train/loss_alpha": loss_vals["loss_alpha"].item(),
"train/alpha": loss_vals["alpha"].item(),
"train/entropy": loss_vals["entropy"].item(),
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
"train/loss_entropy": loss_vals["loss_entropy"],
"train/loss_alpha": loss_vals["loss_alpha"],
"train/alpha": loss_vals["alpha"],
"train/entropy": loss_vals["entropy"],
}

# Evaluation
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
with torch.no_grad(), set_exploration_type(
ExplorationType.DETERMINISTIC
), timeit("eval"):
inference_policy.eval()
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
@@ -141,6 +167,11 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)

if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, to_log, i)

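
Both scripts also replace the manual time.time() bookkeeping with the timeit helper from torchrl._utils. A minimal sketch of that pattern as used above, assuming timeit is a context manager that accumulates per-key timings and exposes todict, print and erase (offline_buffer, update, model_device and pretrain_gradient_steps are the scripts' objects, shown here as placeholders):

from torchrl._utils import timeit

for i in range(pretrain_gradient_steps):
    with timeit("rb - sample"):
        data = offline_buffer.sample().to(model_device)
    with timeit("update"):
        loss_vals = update(data)
    if i % 200 == 0:
        # Periodically fold the accumulated timings into the logged metrics,
        # print them, and reset the counters.
        to_log = timeit.todict(prefix="time")
        timeit.print()
        timeit.erase()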