Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] TD3+BC #2249

Merged
merged 22 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
logger.backend=

# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \
optim.gradient_steps=55 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr
- [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py)
- [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py)
- [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py)
- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py)
- [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py)
- [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py)
- [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py)
Expand Down
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ TD3

TD3Loss

TD3+BC
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

TD3BCLoss

PPO
---

Expand Down
26 changes: 26 additions & 0 deletions sota-check/run_td3bc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=td3bc_offline
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/td3bc_offline_%j.txt
#SBATCH --error=slurm_errors/td3bc_offline_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="td3bc_offline"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log
fi
1 change: 1 addition & 0 deletions sota-check/submitit-release-check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ scripts=(
run_ppo_mujoco.sh
run_sac.sh
run_td3.sh
run_td3bc.sh
run_dt.sh
run_dt_online.sh
)
Expand Down
45 changes: 45 additions & 0 deletions sota-implementations/td3_bc/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# task and env
env:
name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency
task: ""
library: gymnasium
seed: 42
max_episode_steps: 1000

# replay buffer
replay_buffer:
dataset: halfcheetah-medium-v2
batch_size: 256

# optim
optim:
gradient_steps: 100000
gamma: 0.99
loss_function: l2
lr: 3.0e-4
weight_decay: 0.0
adam_eps: 1e-4
batch_size: 256
target_update_polyak: 0.995
policy_update_delay: 2
policy_noise: 0.2
noise_clip: 0.5
alpha: 2.5

# network
network:
hidden_sizes: [256, 256]
activation: relu
device: null

# logging
logger:
backend: wandb
project_name: td3+bc_${replay_buffer.dataset}
group_name: null
exp_name: TD3+BC_${replay_buffer.dataset}
mode: online
eval_iter: 5000
eval_steps: 1000
eval_envs: 1
video: False
146 changes: 146 additions & 0 deletions sota-implementations/td3_bc/td3_bc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""TD3+BC Example.

This is a self-contained example of an offline RL TD3+BC training script.

The helper functions are coded in the utils.py associated with this script.

"""
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
dump_video,
log_metrics,
make_environment,
make_loss_module,
make_offline_replay_buffer,
make_optimizer,
make_td3_agent,
)


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.library).set()

# Create logger
exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="td3bc_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = cfg.network.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)

# Creante env
eval_env = make_environment(
cfg,
logger=logger,
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create agent
model, _ = make_td3_agent(cfg, eval_env, device)

# Create loss
loss_module, target_net_updater = make_loss_module(cfg.optim, model)

# Create optimizer
optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)

gradient_steps = cfg.optim.gradient_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps
delayed_updates = cfg.optim.policy_update_delay
update_counter = 0
pbar = tqdm.tqdm(range(gradient_steps))
# Training loop
start_time = time.time()
for i in pbar:
pbar.update(1)
# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(device)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_loss.item()

to_log = {"q_loss": q_loss.item()}

# Update actor
if update_actor:
actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

# Update target params
target_net_updater.step()

to_log["actor_loss"] = actor_loss.item()
to_log.update(actorloss_metadata)

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
if logger is not None:
log_metrics(logger, to_log, i)

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
main()
Loading
Loading