removed use_dist and made distributed training the only option
vganapati committed Jul 13, 2023
1 parent 80cd3d4 commit 4672048
Showing 2 changed files with 53 additions and 93 deletions.
main.py (141 changes: 51 additions & 90 deletions)
@@ -60,42 +60,31 @@ def get_args():
parser.add_argument('--show', action='store_true', dest='show_figures',
help='show figures')

# set the training region
parser.add_argument('--train_x_start', action='store', dest='training_data_x_start',
help='training data x start', nargs='+', default = [-21.0,-21.0])
parser.add_argument('--train_x_end', action='store', dest='training_data_x_end',
help='training data x end', nargs='+', default = [21.0,21.0])
# set the region
parser.add_argument('--x_start', action='store', dest='data_x_start',
help='boundary x start', nargs='+', default = [-14.0,-14.0])
parser.add_argument('--x_end', action='store', dest='data_x_end',
help='boundary data x end', nargs='+', default = [14.0,14.0])

# set the training spacing
parser.add_argument('--train_x_step', action='store', dest='training_data_x_step',
help='training data x step', nargs='+', default = [0.015,0.015])

# set the test region
parser.add_argument('--test_x_start', action='store', dest='test_data_x_start',
help='test data x start', nargs='+', default = [-14,-14])
parser.add_argument('--test_x_end', action='store', dest='test_data_x_end',
help='test data x end', nargs='+', default = [14,14])
# set the test spacing
parser.add_argument('--test_x_step', action='store', dest='test_data_x_step',
help='test data x step', nargs='+', default = [0.3,0.3])

# set the evaluation region subset for evaluating w
parser.add_argument('--eval_x_start_subset', action='store', dest='eval_data_x_start_subset',
help='evaluation data x start', nargs='+', default = [-7,-7])
parser.add_argument('--eval_x_end_subset', action='store', dest='eval_data_x_end_subset',
help='eval data x end', nargs='+', default = [7,7])
# set the evaluation region subset spacing for evaluating w
parser.add_argument('--eval_x_step_subset', action='store', dest='eval_data_x_step_subset',
help='evaluation data x step', nargs='+', default = [0.15,0.15])

# set the evaluation region
parser.add_argument('--eval_x_start', action='store', dest='eval_data_x_start',
help='evaluation data x start', nargs='+', default = [-14,-14])
parser.add_argument('--eval_x_end', action='store', dest='eval_data_x_end',
help='eval data x end', nargs='+', default = [14,14])
# set the evaluation region spacing for final visualization
parser.add_argument('--eval_x_step', action='store', dest='eval_data_x_step',
help='evaluation data x step', nargs='+', default = [0.05,0.05])

parser.add_argument('--load', action='store_true', dest='load_model',
help='load model from model.pth')
parser.add_argument('--dist', action='store_true', dest='use_dist',
help='use distributed training')

parser.add_argument('--checkpoint', action='store', dest='checkpoint_path',
help='path to checkpoint', default='model.pth')
args = parser.parse_args()
@@ -111,7 +100,7 @@ def setup(rank, world_size, fn, args,
# proc_id = int(os.environ['SLURM_PROCID'])

# print("Hello from " + str(proc_id))
# print(get_rank(args.use_dist))
# print(get_rank())

# initialize the process group
dist.init_process_group(backend, rank=rank, world_size=world_size)
@@ -135,11 +124,9 @@ def get_device(args, force_cpu=False):
print("Using " + str(device) + " device")
return device

def get_rank(use_dist):
if use_dist:
return dist.get_rank()
else:
return 0
def get_rank():
return dist.get_rank()


def partition_dataset(args, world_size):
"""
@@ -153,7 +140,7 @@ def partition_dataset(args, world_size):

# Training data to compute weights w
# Training data is a list of coordinates
training_data, _ = create_data(args.training_data_x_start, args.training_data_x_end,
training_data, _ = create_data(args.data_x_start, args.data_x_end,
args.training_data_x_step, args.two_d)

training_partition = DataPartitioner(training_data, partition_sizes, shuffle=True)
@@ -162,9 +149,9 @@
# Training data to compute pde loss
# Training data is a list of coordinates
# This is only used if the linear system is underdetermined
if args.batch_size < args.num_basis:
training_data_2, _ = create_data(np.array(args.training_data_x_start),
np.array(args.training_data_x_end), args.training_data_x_step, args.two_d)
if args.batch_size < args.num_basis and args.use_pde_cl:
training_data_2, _ = create_data(np.array(args.data_x_start),
np.array(args.data_x_end), args.training_data_x_step, args.two_d)
training_2_partition = DataPartitioner(training_data_2, partition_sizes, shuffle=True)
else:
training_2_partition = None
@@ -185,32 +172,29 @@ def run(rank, world_size, args,
dtype = torch.float,
):

if args.use_dist:
print("Running on rank " + str(rank) + ". Running on rank " + str(get_rank(args.use_dist)))

if args.use_dist and args.use_pde_cl:
print("Running on rank " + str(rank) + ". Running on rank " + str(get_rank()))

if args.use_pde_cl:
training_partition = training_partition.use_all()
else:
training_partition = training_partition.use(get_rank(args.use_dist))
training_partition = training_partition.use(get_rank())
train_set = torch.utils.data.DataLoader(training_partition,
batch_size=args.batch_size,
shuffle=True)
if args.batch_size < args.num_basis and args.use_pde_cl:
if args.use_dist:
training_2_partition = training_2_partition.use_all()
else:
training_2_partition = training_2_partition.use(get_rank(args.use_dist))
training_2_partition = training_2_partition.use_all()
train_set_2 = torch.utils.data.DataLoader(training_2_partition,
batch_size=args.batch_size,
shuffle=True)
else:
train_set_2 = None


if args.use_dist and args.use_pde_cl:
if args.use_pde_cl:
test_partition = test_partition.use_all()
else:
test_partition = test_partition.use(get_rank(args.use_dist))
test_partition = test_partition.use(get_rank())
test_set = torch.utils.data.DataLoader(test_partition,
batch_size=args.batch_size,
shuffle=True)
@@ -227,26 +211,10 @@
model = NeuralNetwork(args.num_basis, args.two_d)
print(model)

if args.use_dist:
# device = rank #{'cuda:%d' % 0: 'cuda:%d' % rank}
# device = torch.device(rank)

device = torch.device(f'cuda:{rank}')
model.to(device)
#ddp_model = DDP(model, device_ids=[rank])
else:
device = get_device(args)
model.to(device)


# if args.load_model:
# if args.use_dist:
# map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
# ddp_model.load_state_dict(
# torch.load(args.checkpoint_path, map_location=device)
# )
# else:
# ddp_model.load_state_dict(torch.load(args.checkpoint_path))
device = torch.device(f'cuda:{rank}')
# device = get_device(args)
model.to(device)

if args.load_model:
model.load_state_dict(torch.load(args.checkpoint_path))
@@ -272,7 +240,7 @@ def loss_fn(data, u_scatter, data_2):
start = time.time()
for t in range(args.epochs):
print("Epoch " + str(t+1) + "\n-------------------------------")
train(train_set, train_set_2, model, loss_fn, optimizer, dtype, args.jitter, device, args.use_dist)
train(train_set, train_set_2, model, loss_fn, optimizer, dtype, args.jitter, device)
test_loss = test(test_set, model, loss_fn, device)
test_loss_vec.append(test_loss)
# Automatically synced here, don't need barrier
@@ -282,8 +250,7 @@ def loss_fn(data, u_scatter, data_2):
torch.save(test_loss_vec, "test_loss_vec_" + str(rank) + ".pth") # save test loss
print("Done! Rank: " + str(rank))

if args.use_dist:
cleanup()
cleanup()

def evaluate(eval_data_i,
device,
@@ -521,38 +488,32 @@ def loss_fn(data, u_scatter, data_2,w=None):
args = get_args()

print(str(torch.cuda.device_count()) + " GPUs detected!")
if args.use_dist:
world_size = torch.cuda.device_count()
# world_size = int(os.environ['SLURM_NTASKS'])
else:
world_size = 1

# world_size = torch.cuda.device_count()
world_size = int(os.environ['SLURM_NTASKS'])

print('world_size is: ' + str(world_size))

training_partition, training_2_partition, test_partition = partition_dataset(args, world_size)
start = time.time()
if args.use_dist:
processes = []
mp.set_start_method("spawn")
for rank in range(world_size):
p = mp.Process(target=setup, args=(rank, world_size, run, args,
training_partition, training_2_partition, test_partition,
))
p.start()
processes.append(p)
for p in processes:
p.join()

processes = []
mp.set_start_method("spawn")
for rank in range(world_size):
p = mp.Process(target=setup, args=(rank, world_size, run, args,
training_partition, training_2_partition, test_partition,
))
p.start()
processes.append(p)

for p in processes:
p.join()


# mp.spawn(run,
# args=(world_size,args,torch.float), # arguments passed to demo_fn after the rank argument
# nprocs=world_size, # number of processes to spawn
# join=True)
else:
rank = 0
run(rank, world_size, args,
training_partition, training_2_partition, test_partition,
)
# rank = 0
# run(rank, world_size, args,
# training_partition, training_2_partition, test_partition,
# )

visualize(args)
end = time.time()
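For context on the launch path that main.py now follows unconditionally: the script reads the world size from SLURM_NTASKS and spawns one process per rank, each of which joins the process group before calling run. The following is a minimal, self-contained sketch of that pattern, not the repository's exact code; the gloo backend, master address, and port are illustrative assumptions.

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size, fn, backend="gloo"):
    # Each spawned process joins the process group, runs the work function,
    # and tears the group down afterwards (the cleanup() step above).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")   # assumed rendezvous address
    os.environ.setdefault("MASTER_PORT", "29500")       # assumed rendezvous port
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size)
    dist.destroy_process_group()

def run(rank, world_size):
    print("Hello from rank " + str(dist.get_rank()) + " of " + str(world_size))

if __name__ == "__main__":
    world_size = int(os.environ["SLURM_NTASKS"])  # set by the SLURM launcher
    mp.set_start_method("spawn")
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=setup, args=(rank, world_size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()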
utils.py (5 changes: 2 additions & 3 deletions)
@@ -380,7 +380,6 @@ def train(dataloader,
dtype,
jitter,
device,
use_dist
):

"""Train the model for one epoch"""
@@ -416,8 +415,8 @@ def train(dataloader,
# Backpropagation
optimizer.zero_grad()
pde_loss.backward()
if use_dist:
average_gradients(model)

average_gradients(model)
torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
optimizer.step()
total_examples_finished += len(data)
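With the use_dist guard gone, train() now calls average_gradients(model) on every backward pass. That helper is not shown in this diff; assuming it follows the standard manual gradient-averaging pattern from the PyTorch distributed tutorial, a minimal sketch would be:

import torch.distributed as dist

def average_gradients(model):
    # Assumed implementation: sum each parameter's gradient across all ranks
    # with all_reduce, then divide by the world size to get the average.
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size

Because the call is now unconditional, running the script without an initialized process group would raise an error inside dist.get_world_size(), which is consistent with distributed training being the only supported mode after this commit.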
