
Commit 67090af: "Cleaned up code"
vganapati committed Jun 15, 2023 (parent commit: 5510f0e)
1 changed file, main.py, with 17 additions and 16 deletions.
@@ -205,40 +205,41 @@ def run(rank, world_size, args,
     print(model)
 
     if args.use_dist:
-        device = rank #{'cuda:%d' % 0: 'cuda:%d' % rank}
-        ddp_model = model
-        ddp_model.to(rank)
+        # device = rank #{'cuda:%d' % 0: 'cuda:%d' % rank}
+        device = torch.device(rank)
+        model.to(rank)
+        #ddp_model = DDP(model, device_ids=[rank])
     else:
         device = get_device(args)
-        ddp_model = model
-        ddp_model.to(device)
+        model.to(device)
 
 
-    # if args.load_model:
-    #     if args.use_dist:
-    #         map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
-    #         ddp_model.load_state_dict(
-    #             torch.load(args.checkpoint_path, map_location=device)
-    #         )
-    #     else:
-    #         ddp_model.load_state_dict(torch.load(args.checkpoint_path))
 
+    if args.load_model:
+        if args.use_dist:
+            map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+            ddp_model.load_state_dict(
+                torch.load(args.checkpoint_path, map_location=device)
+            )
+        else:
+            model.load_state_dict(torch.load(args.checkpoint_path))
+        model.load_state_dict(torch.load(args.checkpoint_path))
 
 
     # PDE loss function
     def loss_fn(data, u_scatter, data_2):
         return get_pde_loss(data,
                             args.wavelength,
                             args.n_background,
                             u_scatter,
-                            ddp_model,
+                            model,
                             device,
                             args.use_pde_cl,
                             args.two_d,
                             data_2=data_2,
                             )
 
-    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=args.learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     optimizer.zero_grad()
 
     # Train the PINN
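The substantive change in this hunk is dropping the DistributedDataParallel (DDP) wrapper: `ddp_model` is deleted, the bare `model` is moved to each rank's device, and the wrapping line is left commented out. For reference, a minimal sketch of what re-enabling that path would look like, assuming the process group is already initialized by a launcher; the `setup_ddp` helper and its signature are illustrative, not code from main.py:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(model, rank):
    # Hypothetical helper mirroring the diff: one CUDA device per process.
    # Requires torch.distributed.init_process_group to have run already.
    device = torch.device(rank)
    model.to(device)
    # This is the line the commit comments out; re-enabling it makes DDP
    # average gradients across ranks during backward().
    ddp_model = DDP(model, device_ids=[rank])
    return ddp_model, device

Without the wrapper, gradient averaging across ranks no longer happens automatically; each process optimizes its own copy of the model unless gradients are synchronized some other way.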
@@ -251,7 +252,7 @@ def loss_fn(data, u_scatter, data_2):
         test_loss_vec.append(test_loss)
         # Automatically synced here, don't need barrier
         if rank == 0:
-            torch.save(ddp_model.state_dict(), args.checkpoint_path) # save model
+            torch.save(model.state_dict(), args.checkpoint_path) # save model
             print("Saved PyTorch Model State to: " + args.checkpoint_path)
     torch.save(test_loss_vec, "test_loss_vec_" + str(rank) + ".pth") # save test loss
     print("Done! Rank: " + str(rank))
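For context, the `loss_fn` closure in the first hunk is what wires `get_pde_loss` to the Adam optimizer. A hypothetical driver loop, assuming a `dataloader` yielding `(data, data_2)` pairs, a forward pass that produces `u_scatter`, and a scalar loss from `get_pde_loss`; none of this loop is shown in the diff itself:

# Uses model, loss_fn, and optimizer as defined in the diff above.
for data, data_2 in dataloader:              # hypothetical dataloader
    u_scatter = model(data)                  # assumed forward pass
    loss = loss_fn(data, u_scatter, data_2)  # PDE residual as the loss
    optimizer.zero_grad()
    loss.backward()                          # backprop through the residual
    optimizer.step()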
