Skip to content

Commit

Permalink
[Feature] GAIL compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 5fadce953ca929deb9c7aef25b497a227a0bb0a0
Pull Request resolved: #2573
  • Loading branch information
vmoens committed Dec 14, 2024
1 parent 54b6b08 commit 1206f65
Show file tree
Hide file tree
Showing 14 changed files with 132 additions and 78 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ collector:
multi_step: 0
init_random_frames: 1000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 200
annealing_frames: 10000
eps_start: 1.0
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ collector:
multi_step: 0
init_random_frames: 5_000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 1000


Expand Down
8 changes: 7 additions & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,20 @@ def make_collector(
cudagraph=False,
):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
collector = SyncDataCollector(
train_env,
actor_model_explore,
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
max_frames_per_traj=cfg.collector.max_frames_per_traj,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule:
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = proof_environment.rollout(max_steps=100)
td["action"] = td["next", "action"]
actor(td)
actor(td.to(device))

return actor

Expand Down
5 changes: 5 additions & 0 deletions sota-implementations/gail/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ gail:
gp_lambda: 10.0
device: null

compile:
compile: False
compile_mode:
cudagraphs: False

replay_buffer:
dataset: halfcheetah-expert-v2
batch_size: 256
152 changes: 88 additions & 64 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
from tensordict.nn import CudaGraphModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss, GAILLoss
from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
Expand Down Expand Up @@ -71,20 +72,9 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.env.seed)

# Create models (check utils_mujoco.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = make_ppo_models(cfg.env.env_name, compile=cfg.compile.compile)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)

# Create data buffer
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
Expand Down Expand Up @@ -113,6 +103,30 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
optim = group_optimizers(actor_optim, critic_optim)
del actor_optim, critic_optim

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
Expand Down Expand Up @@ -140,32 +154,9 @@ def main(cfg: "DictConfig"): # noqa: F821
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
)
test_env.eval()
num_network_updates = torch.zeros((), dtype=torch.int64, device=device)

# Training loop
collected_frames = 0
num_network_updates = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)
def update(data, expert_data, num_network_updates=num_network_updates):
# Add collector data to expert data
expert_data.set(
discriminator_loss.tensor_keys.collector_action,
Expand All @@ -178,9 +169,9 @@ def main(cfg: "DictConfig"): # noqa: F821
d_loss = discriminator_loss(expert_data)

# Backward pass
discriminator_optim.zero_grad()
d_loss.get("loss").backward()
discriminator_optim.step()
discriminator_optim.zero_grad(set_to_none=True)

# Compute discriminator reward
with torch.no_grad():
Expand All @@ -190,32 +181,19 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set discriminator rewards to tensordict
data.set(("next", "reward"), d_rewards)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)
# Update PPO
for _ in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.empty()
data_buffer.extend(data_reshape)

for _, batch in enumerate(data_buffer):

# Get a data batch
batch = batch.to(device)
for batch in data_buffer:
optim.zero_grad(set_to_none=True)

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
Expand All @@ -235,20 +213,66 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
actor_loss.backward()
critic_loss.backward()
(actor_loss + critic_loss).backward()

# Update the networks
actor_optim.step()
critic_optim.step()
actor_optim.zero_grad()
critic_optim.zero_grad()
optim.step()
return d_loss.detach()

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Training loop
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)

d_loss = update(data, expert_data)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]

log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)

log_info.update(
{
"train/actor_loss": actor_loss.item(),
"train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"].item(),
# "train/actor_loss": actor_loss.item(),
# "train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"],
"train/lr": alpha * cfg_optim_lr,
"train/clip_epsilon": (
alpha * cfg_loss_clip_epsilon
Expand Down
7 changes: 4 additions & 3 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False)
# --------------------------------------------------------------------


def make_ppo_models_state(proof_environment):
def make_ppo_models_state(proof_environment, compile):

# Define input shape
input_shape = proof_environment.observation_spec["observation"].shape
Expand All @@ -55,6 +55,7 @@ def make_ppo_models_state(proof_environment):
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
"safe_tanh": not compile,
}

# Define policy architecture
Expand Down Expand Up @@ -117,9 +118,9 @@ def make_ppo_models_state(proof_environment):
return policy_module, value_module


def make_ppo_models(env_name):
def make_ppo_models(env_name, compile):
proof_environment = make_env(env_name, device="cpu")
actor, critic = make_ppo_models_state(proof_environment)
actor, critic = make_ppo_models_state(proof_environment, compile=compile)
return actor, critic


Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/discrete_iql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ collector:
total_frames: 20000
init_random_frames: 1000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 200

# logger
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ collector:
multi_step: 0
init_random_frames: 5000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 200

# logger
Expand Down
8 changes: 7 additions & 1 deletion sota-implementations/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,20 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None):

def make_collector(cfg, train_env, actor_model_explore):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
collector = SyncDataCollector(
train_env,
actor_model_explore,
frames_per_batch=cfg.collector.frames_per_batch,
init_random_frames=cfg.collector.init_random_frames,
max_frames_per_traj=cfg.collector.max_frames_per_traj,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/sac/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ collector:
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
device:
env_per_collector: 1
reset_at_each_iter: False

Expand Down
8 changes: 7 additions & 1 deletion sota-implementations/sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,19 @@ def make_environment(cfg, logger=None):

def make_collector(cfg, train_env, actor_model_explore):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
collector = SyncDataCollector(
train_env,
actor_model_explore,
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/td3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ collector:
init_env_steps: 1000
frames_per_batch: 1000
reset_at_each_iter: False
device: cpu
device:
env_per_collector: 1
num_workers: 1

Expand Down
Loading

0 comments on commit 1206f65

Please sign in to comment.