diff --git a/main.py b/main.py
index bdbbf58..09cb902 100644
--- a/main.py
+++ b/main.py
@@ -205,40 +205,41 @@ def run(rank, world_size, args,
     print(model)
 
     if args.use_dist:
-        device = rank #{'cuda:%d' % 0: 'cuda:%d' % rank}
-        ddp_model = model
-        ddp_model.to(rank)
+        # device = rank #{'cuda:%d' % 0: 'cuda:%d' % rank}
+        device = torch.device(rank)
+        model.to(rank)
         #ddp_model = DDP(model, device_ids=[rank])
     else:
         device = get_device(args)
-        ddp_model = model
-        ddp_model.to(device)
+        model.to(device)
 
+    # if args.load_model:
+    #     if args.use_dist:
+    #         map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+    #         ddp_model.load_state_dict(
+    #             torch.load(args.checkpoint_path, map_location=device)
+    #         )
+    #     else:
+    #         ddp_model.load_state_dict(torch.load(args.checkpoint_path))
+
     if args.load_model:
-        if args.use_dist:
-            map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
-            ddp_model.load_state_dict(
-                torch.load(args.checkpoint_path, map_location=device)
-            )
-        else:
-            model.load_state_dict(torch.load(args.checkpoint_path))
+        model.load_state_dict(torch.load(args.checkpoint_path))
 
-    # PDE loss function
     def loss_fn(data, u_scatter, data_2):
         return get_pde_loss(data,
                             args.wavelength,
                             args.n_background,
                             u_scatter,
-                            ddp_model,
+                            model,
                             device,
                             args.use_pde_cl,
                             args.two_d,
                             data_2=data_2,
                             )
 
-    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=args.learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     optimizer.zero_grad()
 
     # Train the PINN
@@ -251,7 +252,7 @@ def loss_fn(data, u_scatter, data_2):
         test_loss_vec.append(test_loss)
         # Automatically synced here, don't need barrier
         if rank == 0:
-            torch.save(ddp_model.state_dict(), args.checkpoint_path) # save model
+            torch.save(model.state_dict(), args.checkpoint_path) # save model
             print("Saved PyTorch Model State to: " + args.checkpoint_path)
     torch.save(test_loss_vec, "test_loss_vec_" + str(rank) + ".pth") # save test loss
     print("Done! Rank: " + str(rank))
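
Note on the commented-out DDP path: checkpoints are now saved from the bare `model` (see `torch.save(model.state_dict(), ...)` above), so if the `DDP` wrapper is re-enabled later, the model should be wrapped *after* loading the checkpoint; otherwise the saved keys would need `module.`-prefix handling. A minimal sketch, assuming the same `model`, `rank`, and `args` locals from `run()` and an already-initialized process group (not part of this change):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Sketch only -- assumes torch.distributed is initialized and that
# `model`, `rank`, and `args` are the locals of run().
device = torch.device(rank)
model.to(device)

if args.load_model:
    # Checkpoints are written by rank 0; remap cuda:0 tensors onto this rank.
    map_location = {'cuda:0': 'cuda:%d' % rank}
    model.load_state_dict(torch.load(args.checkpoint_path, map_location=map_location))

# Wrap after loading so the checkpoint keys match the bare model's state_dict.
ddp_model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=args.learning_rate)
```

Wrapping after loading keeps the save path (`model.state_dict()` on rank 0) and the load path symmetric, whether or not `args.use_dist` is set.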