Merge pull request #95 from Eclectic-Sheep/feature/resume_from_checkpoint

Feature/resume from checkpoint
belerico authored Sep 18, 2023
2 parents 733faf1 + 96c09c4 commit d8a06aa
Showing 17 changed files with 523 additions and 176 deletions.
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -375,7 +375,7 @@ def main(fabric: Fabric, cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -395,7 +395,7 @@ def main(fabric: Fabric, cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -349,7 +349,7 @@ def main(fabric: Fabric, cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
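The three Dreamer hunks above only touch the tail of a resume block that already exists in those scripts; the diffs below add the same block to DroQ and PPO. The pattern is always the same: remember the new run's `root_dir` and `run_name`, load the checkpoint state with Fabric, reload the original run's Hydra config stored next to the checkpoint, point `checkpoint.resume_from` back at the checkpoint file, and split the saved global batch size across the current ranks. A minimal sketch of that logic as a standalone helper (the function name is hypothetical; the Fabric and OmegaConf calls are the ones already used in the repository):

```python
import pathlib
from typing import Any, Dict, Tuple

from lightning.fabric import Fabric
from omegaconf import DictConfig, OmegaConf


def load_resume_config(fabric: Fabric, cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    """Reload the original config and training state when resuming from a checkpoint."""
    # Values that belong to the new run and must survive the config reload
    root_dir = cfg.root_dir
    run_name = cfg.run_name
    # Training state previously saved through Fabric
    state = fabric.load(cfg.checkpoint.resume_from)
    ckpt_path = pathlib.Path(cfg.checkpoint.resume_from)
    # The original run's Hydra config sits three directories above the
    # checkpoint file, in .hydra/config.yaml
    cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml")
    cfg.checkpoint.resume_from = str(ckpt_path)
    # The checkpoint stores the global batch size; divide it across ranks
    cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
    cfg.root_dir = root_dir
    cfg.run_name = run_name
    return cfg, state
```

In the scripts the block is inlined right after seeding and before the TensorBoard logger is created, so everything downstream runs against the reloaded config.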
39 changes: 30 additions & 9 deletions sheeprl/algos/droq/agent.py
@@ -83,7 +83,30 @@ def __init__(

# Actor and critics
self._num_critics = len(critics)
self._actor = actor
self.actor = actor
self.critics = critics

# Automatic entropy tuning
self._target_entropy = torch.tensor(target_entropy, device=device)
self._log_alpha = torch.nn.Parameter(torch.log(torch.tensor([alpha], device=device)), requires_grad=True)

# EMA tau
self._tau = tau

def __setattr__(self, name: str, value: Union[Tensor, nn.Module]) -> None:
# Taken from https://github.com/pytorch/pytorch/pull/92044
# Check if a property setter exists. If it does, use it.
class_attr = getattr(self.__class__, name, None)
if isinstance(class_attr, property) and class_attr.fset is not None:
return class_attr.fset(self, value)
super().__setattr__(name, value)

@property
def critics(self) -> nn.ModuleList:
return self.qfs

@critics.setter
def critics(self, critics: Sequence[Union[DROQCritic, _FabricModule]]) -> None:
self._qfs = nn.ModuleList(critics)

# Create target critic unwrapping the DDP module from the critics to prevent
@@ -93,7 +116,7 @@ def __init__(
# This happens when we're using the decoupled version of SAC for example
qfs_unwrapped_modules = []
for critic in critics:
if getattr(critic, "module"):
if hasattr(critic, "module"):
critic_module = critic.module
else:
critic_module = critic
@@ -103,13 +126,6 @@ def __init__(
for p in self._qfs_target.parameters():
p.requires_grad = False

# Automatic entropy tuning
self._target_entropy = torch.tensor(target_entropy, device=device)
self._log_alpha = torch.nn.Parameter(torch.log(torch.tensor([alpha], device=device)), requires_grad=True)

# EMA tau
self._tau = tau

@property
def num_critics(self) -> int:
return self._num_critics
@@ -126,6 +142,11 @@ def qfs_unwrapped(self) -> nn.ModuleList:
def actor(self) -> Union[SACActor, _FabricModule]:
return self._actor

@actor.setter
def actor(self, actor: Union[SACActor, _FabricModule]) -> None:
self._actor = actor
return

@property
def qfs_target(self) -> nn.ModuleList:
return self._qfs_target
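Two details in the `DROQAgent` changes are worth calling out. `getattr(critic, "module")` raises `AttributeError` when the critic is not wrapped, so the check becomes `hasattr`, which simply returns `False` in that case. More subtly, `nn.Module.__setattr__` intercepts `Tensor` and `Module` values before a subclass property setter ever runs, so the new `@actor.setter` and `@critics.setter` would silently be bypassed; the `__setattr__` override (taken from pytorch/pytorch#92044) checks for a property setter first, which is what lets `droq.py` reassign `agent.actor` and `agent.critics` with Fabric-wrapped modules after `load_state_dict`. A self-contained toy showing the pattern (class and attribute names are made up, not sheeprl code):

```python
import torch.nn as nn


class Wrapped(nn.Module):
    """Toy module using the same property-setter pattern as DROQAgent."""

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = net  # routed through the property setter below

    def __setattr__(self, name, value) -> None:
        # nn.Module.__setattr__ registers Module/Tensor values itself and never
        # consults property setters; check for one first
        # (see https://github.com/pytorch/pytorch/pull/92044).
        class_attr = getattr(self.__class__, name, None)
        if isinstance(class_attr, property) and class_attr.fset is not None:
            return class_attr.fset(self, value)
        super().__setattr__(name, value)

    @property
    def net(self) -> nn.Module:
        return self._net

    @net.setter
    def net(self, value: nn.Module) -> None:
        # Any bookkeeping on reassignment goes here; DROQAgent, for example,
        # rebuilds an nn.ModuleList from the newly assigned critics.
        self._net = value


m = Wrapped(nn.Linear(4, 2))
m.net = nn.Linear(4, 8)  # goes through the setter, not nn.Module's default path
assert m.net.out_features == 8
```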
83 changes: 58 additions & 25 deletions sheeprl/algos/droq/droq.py
@@ -1,4 +1,5 @@
import os
import pathlib
import warnings
from math import prod

@@ -129,6 +130,18 @@ def main(fabric: Fabric, cfg: DictConfig):
fabric.seed_everything(cfg.seed)
torch.backends.cudnn.deterministic = cfg.torch_deterministic

# Resume from checkpoint
if cfg.checkpoint.resume_from:
root_dir = cfg.root_dir
run_name = cfg.run_name
state = fabric.load(cfg.checkpoint.resume_from)
ckpt_path = pathlib.Path(cfg.checkpoint.resume_from)
cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml")
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
logger, log_dir = create_tensorboard_logger(fabric, cfg)
@@ -165,36 +178,41 @@ def main(fabric: Fabric, cfg: DictConfig):
# Define the agent and the optimizer and setup them with Fabric
act_dim = prod(envs.single_action_space.shape)
obs_dim = prod(envs.single_observation_space.shape)
actor = fabric.setup_module(
SACActor(
observation_dim=obs_dim,
action_dim=act_dim,
hidden_size=cfg.algo.actor.hidden_size,
action_low=envs.single_action_space.low,
action_high=envs.single_action_space.high,
)
actor = SACActor(
observation_dim=obs_dim,
action_dim=act_dim,
hidden_size=cfg.algo.actor.hidden_size,
action_low=envs.single_action_space.low,
action_high=envs.single_action_space.high,
)
critics = [
fabric.setup_module(
DROQCritic(
observation_dim=obs_dim + act_dim,
hidden_size=cfg.algo.critic.hidden_size,
num_critics=1,
dropout=cfg.algo.critic.dropout,
)
DROQCritic(
observation_dim=obs_dim + act_dim,
hidden_size=cfg.algo.critic.hidden_size,
num_critics=1,
dropout=cfg.algo.critic.dropout,
)
for _ in range(cfg.algo.critic.n)
]
target_entropy = -act_dim
agent = DROQAgent(
actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device
)
if cfg.checkpoint.resume_from:
agent.load_state_dict(state["agent"])
agent.actor = fabric.setup_module(agent.actor)
agent.critics = [fabric.setup_module(critic) for critic in agent.critics]

# Optimizers
qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters())
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters())
alpha_optimizer = hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha])
if cfg.checkpoint.resume_from:
qf_optimizer.load_state_dict(state["qf_optimizer"])
actor_optimizer.load_state_dict(state["actor_optimizer"])
alpha_optimizer.load_state_dict(state["alpha_optimizer"])
qf_optimizer, actor_optimizer, alpha_optimizer = fabric.setup_optimizers(
hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()),
hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()),
hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]),
qf_optimizer, actor_optimizer, alpha_optimizer
)

# Metrics
@@ -217,17 +235,27 @@ def main(fabric: Fabric, cfg: DictConfig):
memmap=cfg.buffer.memmap,
memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"),
)
if cfg.checkpoint.resume_from and cfg.buffer.checkpoint:
if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]):
rb = state["rb"][fabric.global_rank]
elif isinstance(state["rb"], ReplayBuffer):
rb = state["rb"]
else:
raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated")
step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device)

# Global variables
last_log = 0
last_train = 0
train_step = 0
policy_step = 0
last_checkpoint = 0
start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1
policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size)
num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1
learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0
if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint:
learning_starts += start_step

# Warning for log and checkpoint every
if cfg.metric.log_every % policy_steps_per_update != 0:
@@ -249,14 +277,15 @@ def main(fabric: Fabric, cfg: DictConfig):
# Get the first environment observation and start the optimization
obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs]

for update in range(1, num_updates + 1):
for update in range(start_step, num_updates + 1):
policy_step += cfg.env.num_envs * fabric.world_size

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with torch.no_grad():
actions, _ = actor.module(obs)
# Sample an action given the observation received by the environment
actions, _ = agent.actor.module(obs)
actions = actions.cpu().numpy()
next_obs, rewards, dones, truncated, infos = envs.step(actions)
dones = np.logical_or(dones, truncated)
@@ -332,12 +361,16 @@ def main(fabric: Fabric, cfg: DictConfig):
or cfg.dry_run
or update == num_updates
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"qf_optimizer": qf_optimizer.state_dict(),
"actor_optimizer": actor_optimizer.state_dict(),
"alpha_optimizer": alpha_optimizer.state_dict(),
"update": update,
"update": update * fabric.world_size,
"batch_size": cfg.per_rank_batch_size * fabric.world_size,
"last_log": last_log,
"last_checkpoint": last_checkpoint,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call(
@@ -360,4 +393,4 @@ def main(fabric: Fabric, cfg: DictConfig):
mask_velocities=False,
vector_env_idx=0,
)()
test(actor.module, test_env, fabric, cfg)
test(agent.actor.module, test_env, fabric, cfg)
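Alongside the agent and the optimizers, the DroQ script can now restore the replay buffer. The checkpoint may contain either one buffer per process or a single shared buffer, and the world size of the resumed run has to match in the first case. A hedged sketch of that branch as a standalone helper (the function is hypothetical; the real code checks `isinstance(state["rb"], ReplayBuffer)` for the single-buffer case):

```python
from typing import Any, List, Union


def pick_resumed_buffer(saved_rb: Union[Any, List[Any]], global_rank: int, world_size: int) -> Any:
    """Mirror of the buffer-restore branch added to droq.py (illustrative only)."""
    if isinstance(saved_rb, list):
        # One buffer was checkpointed per process: counts must match, otherwise
        # the resumed run was launched with a different number of processes.
        if world_size != len(saved_rb):
            raise RuntimeError(
                f"Given {len(saved_rb)} buffers, but {world_size} processes are instantiated"
            )
        return saved_rb[global_rank]
    # A single buffer was checkpointed: every rank reuses it.
    return saved_rb
```

The restore is guarded by `cfg.checkpoint.resume_from and cfg.buffer.checkpoint`; when the buffer is not checkpointed, `learning_starts` is shifted forward by `start_step` so the resumed run refills the buffer before training continues. The same ordering applies to the agent and the optimizers: `load_state_dict` is called on the plain PyTorch objects first, and only then are they handed to `fabric.setup_module` / `fabric.setup_optimizers`.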
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1.py
@@ -380,7 +380,7 @@ def main(fabric: Fabric, cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2.py
@@ -482,7 +482,7 @@ def main(fabric: Fabric, cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
38 changes: 32 additions & 6 deletions sheeprl/algos/ppo/ppo.py
@@ -1,5 +1,6 @@
import copy
import os
import pathlib
import warnings
from typing import Union

@@ -123,6 +124,18 @@ def main(fabric: Fabric, cfg: DictConfig):
fabric.seed_everything(cfg.seed)
torch.backends.cudnn.deterministic = cfg.torch_deterministic

# Resume from checkpoint
if cfg.checkpoint.resume_from:
root_dir = cfg.root_dir
run_name = cfg.run_name
state = fabric.load(cfg.checkpoint.resume_from)
ckpt_path = pathlib.Path(cfg.checkpoint.resume_from)
cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml")
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
logger, log_dir = create_tensorboard_logger(fabric, cfg)
@@ -178,8 +191,15 @@ def main(fabric: Fabric, cfg: DictConfig):
is_continuous=is_continuous,
)

# Define the agent and the optimizer and setup them with Fabric
# Define the optimizer
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())

# Load the state from the checkpoint
if cfg.checkpoint.resume_from:
agent.load_state_dict(state["agent"])
optimizer.load_state_dict(state["optimizer"])

# Setup agent and optimizer with Fabric
agent = fabric.setup_module(agent)
optimizer = fabric.setup_optimizers(optimizer)

@@ -211,11 +231,12 @@ def main(fabric: Fabric, cfg: DictConfig):
step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device)

# Global variables
last_log = 0
last_train = 0
train_step = 0
policy_step = 0
last_checkpoint = 0
start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1
policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1

@@ -240,6 +261,8 @@ def main(fabric: Fabric, cfg: DictConfig):
from torch.optim.lr_scheduler import PolynomialLR

scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0)
if cfg.checkpoint.resume_from:
scheduler.load_state_dict(state["scheduler"])

# Get the first environment observation and start the optimization
o = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs]
@@ -255,7 +278,7 @@ def main(fabric: Fabric, cfg: DictConfig):
next_obs[k] = torch_obs
next_done = torch.zeros(cfg.env.num_envs, 1, dtype=torch.float32).to(fabric.device) # [N_envs, 1]

for update in range(1, num_updates + 1):
for update in range(start_step, num_updates + 1):
for _ in range(0, cfg.algo.rollout_steps):
policy_step += cfg.env.num_envs * world_size

@@ -407,8 +430,11 @@ def main(fabric: Fabric, cfg: DictConfig):
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"update_step": update,
"scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None,
"update": update * world_size,
"batch_size": cfg.per_rank_batch_size * fabric.world_size,
"last_log": last_log,
"last_checkpoint": last_checkpoint,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)