Skip to content

Commit

Permalink
fix loading simulator and resuming from middle of epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
Sikan Li committed Jul 11, 2024
1 parent 86b4dbb commit f72b172
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
cfg.model.path + cfg.model.train_state_file
):
# load model
if device == torch.device("cuda"):
if use_dist:
simulator.module.load(cfg.model.path + cfg.model.file)
else:
simulator.load(cfg.model.path + cfg.model.file)
Expand All @@ -481,9 +481,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist):

# set optimizer state
optimizer = torch.optim.Adam(
simulator.module.parameters()
if device == torch.device("cuda")
else simulator.parameters()
simulator.module.parameters() if use_dist else simulator.parameters()
)
optimizer.load_state_dict(train_state["optimizer_state"])
optimizer_to(optimizer, device_id)
Expand Down Expand Up @@ -523,7 +521,8 @@ def train(rank, cfg, world_size, device, verbose, use_dist):

# Create a tqdm progress bar for each epoch
with tqdm(
total=len(train_dl),
# resume from one step after the checkpoint
range(step % len(train_dl) + 1, len(train_dl)),
desc=f"Epoch {epoch}",
unit="batch",
disable=not verbose,
Expand Down Expand Up @@ -847,8 +846,6 @@ def main(cfg: Config):
world_size = torch.cuda.device_count()
if cfg.hardware.cuda_device_number is not None and torch.cuda.is_available():
device = torch.device(f"cuda:{int(cfg.hardware.cuda_device_number)}")
# test code
print(f"device is {device} world size is {world_size}")
predict(device, cfg)


Expand Down

0 comments on commit f72b172

Please sign in to comment.