distributed computing for pde-cl PINN
vganapati committed Jun 19, 2023
1 parent cb6af76 commit 3301feb
Showing 3 changed files with 40 additions and 22 deletions.
README.md (5 additions, 0 deletions)
@@ -124,11 +124,16 @@ Navigate to the working directory and run the code:
```
cd $SCRATCH/output_PINN
export MASTER_ADDR=$(hostname)
python main.py --2d --epochs 500 --nb 50 --dist --upc
+ python $SCRATCH/PINN/main.py --2d --epochs 500 --nb 50 --dist --upc
+ python $SCRATCH/PINN/main.py --2d --epochs 2 --bs 8192 --nb 100 --dist --upc
+ python $SCRATCH/PINN/main.py --2d --epochs 2 --bs 8192 --nb 50 --dist --upc
+ python $SCRATCH/PINN/main.py --upc --2d --dist --epochs 2 --bs 8192
```
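In these examples, `--2d` selects the two-dimensional problem, `--dist` turns on distributed training, `--upc` enables the pde-cl constraint (`use_pde_cl`), `--bs` sets the per-GPU batch size, and `--nb` sets the number of basis functions (N in the pde-cl paper).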



## How to run the SLURM script on NERSC

Open a Perlmutter terminal.
main.py (31 additions, 22 deletions)
@@ -38,7 +38,7 @@ def get_args():
parser = argparse.ArgumentParser(description='Get command line args')

parser.add_argument('--bs', type=int, action='store', dest='batch_size',
- help='batch size', default = 8192)
+ help='batch size per gpu', default = 8836)
parser.add_argument('--nb', type=int, action='store', dest='num_basis',
help='number of basis functions, N in pde-cl paper', default = 200)
parser.add_argument('--upc', action='store_true', dest='use_pde_cl',
@@ -54,41 +54,41 @@ def get_args():
parser.add_argument('--lr', type=float, action='store', dest='learning_rate',
help='learning rate', default = 1e-3)
parser.add_argument('-j', type=float, action='store', dest='jitter',
- help='jitter for training data', default = 0.015)
+ help='jitter for training data', default = 0.6)
parser.add_argument('--show', action='store_true', dest='show_figures',
help='show figures')

# set the training region
parser.add_argument('--train_x_start', action='store', dest='training_data_x_start',
- help='training data x start', nargs='+', default = [-14,-14])
+ help='training data x start', nargs='+', default = [-7,-7])
parser.add_argument('--train_x_end', action='store', dest='training_data_x_end',
- help='training data x end', nargs='+', default = [14,14])
+ help='training data x end', nargs='+', default = [7,7])
parser.add_argument('--train_x_step', action='store', dest='training_data_x_step',
- help='training data x step', nargs='+', default = [0.03,0.03])
+ help='training data x step', nargs='+', default = [0.15,0.15])

# set the test region
parser.add_argument('--test_x_start', action='store', dest='test_data_x_start',
help='test data x start', nargs='+', default = [-14,-14])
parser.add_argument('--test_x_end', action='store', dest='test_data_x_end',
help='test data x end', nargs='+', default = [14,14])
parser.add_argument('--test_x_step', action='store', dest='test_data_x_step',
- help='test data x step', nargs='+', default = [0.5,0.5])
+ help='test data x step', nargs='+', default = [0.3,0.3])

# set the evaluation region subset for evaluating w
parser.add_argument('--eval_x_start_subset', action='store', dest='eval_data_x_start_subset',
help='evaluation data x start', nargs='+', default = [-7,-7])
parser.add_argument('--eval_x_end_subset', action='store', dest='eval_data_x_end_subset',
help='eval data x end', nargs='+', default = [7,7])
parser.add_argument('--eval_x_step_subset', action='store', dest='eval_data_x_step_subset',
- help='evaluation data x step', nargs='+', default = [0.06,0.06])
+ help='evaluation data x step', nargs='+', default = [0.15,0.15])

# set the evaluation region
parser.add_argument('--eval_x_start', action='store', dest='eval_data_x_start',
help='evaluation data x start', nargs='+', default = [-14,-14])
parser.add_argument('--eval_x_end', action='store', dest='eval_data_x_end',
help='eval data x end', nargs='+', default = [14,14])
parser.add_argument('--eval_x_step', action='store', dest='eval_data_x_step',
- help='evaluation data x step', nargs='+', default = [0.03,0.03])
+ help='evaluation data x step', nargs='+', default = [0.05,0.05])

parser.add_argument('--load', action='store_true', dest='load_model',
help='load model from model.pth')
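Worth noting: the new defaults appear internally consistent. With the training region at [-7,7] per axis and step 0.15, a half-open grid has 94 points per axis, and 94 squared is exactly the new per-GPU batch-size default of 8836, presumably so that one batch covers the whole training grid; the eval subset defaults ([-7,7], step 0.15) yield the same grid. A quick check (illustrative; assumes an arange-style half-open grid, which is what the start/end/step flags suggest):

```
# Sanity check on the new defaults (illustrative; assumes a half-open grid).
import numpy as np

x = np.arange(-7, 7, 0.15)   # new --train_x_start/--train_x_end/--train_x_step
print(len(x))                # 94 points per axis
print(len(x) ** 2)           # 8836 grid points, matching the new --bs default
```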
@@ -101,7 +101,7 @@ def get_args():


def setup(rank, world_size, fn, args,
- training_partition, training_2_partition, test_partition, bsz,
+ training_partition, training_2_partition, test_partition,
backend='nccl'):
os.environ['MASTER_PORT'] = '29500'

@@ -114,7 +114,7 @@
# initialize the process group
dist.init_process_group(backend, rank=rank, world_size=world_size)
fn(rank,world_size, args,
- training_partition, training_2_partition, test_partition, bsz) # this will be the run function
+ training_partition, training_2_partition, test_partition) # this will be the run function



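For context, setup() follows the standard torch.distributed rendezvous pattern: every rank exports the same MASTER_ADDR/MASTER_PORT, calls dist.init_process_group, then runs the training function. A minimal self-contained sketch of that pattern (names here are illustrative, not the repo's):

```
# Minimal sketch of the rendezvous pattern used by setup() above.
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def example_worker(rank, world_size):
    # All ranks must agree on the master address/port before initializing.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    # 'nccl' requires GPUs; use 'gloo' on CPU-only machines.
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    # ... per-rank training work goes here ...
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(example_worker, args=(world_size,), nprocs=world_size)
```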
@@ -145,7 +145,6 @@ def partition_dataset(args, world_size):
size is the world size (number of ranks)
"""

- bsz = args.batch_size//world_size
partition_sizes = [1.0 / world_size for _ in range(world_size)]

# Create full dataset
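Dropping `bsz = args.batch_size // world_size` changes the meaning of `--bs` from a global batch split across ranks to a per-GPU batch, matching the new help string. A quick illustration with hypothetical values:

```
# Old vs. new batch-size conventions (hypothetical values).
world_size = 4
batch_size = 8836                          # --bs

old_per_rank = batch_size // world_size    # before this commit: 2209 per rank
new_per_rank = batch_size                  # after: 8836 per rank
new_global = batch_size * world_size       # 35344 samples per step overall

# The per-rank dataset fractions are unchanged:
partition_sizes = [1.0 / world_size for _ in range(world_size)]  # [0.25] * 4
```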
@@ -174,33 +173,43 @@
test_partition = DataPartitioner(test_data, partition_sizes, shuffle=True)


- return training_partition, training_2_partition, test_partition, bsz
+ return training_partition, training_2_partition, test_partition



def run(rank, world_size, args,
- training_partition, training_2_partition, test_partition, bsz,
+ training_partition, training_2_partition, test_partition,
dtype = torch.float,
):

if args.use_dist:
print("Running on rank " + str(rank) + ". Running on rank " + str(get_rank(args.use_dist)))

- training_partition = training_partition.use(get_rank(args.use_dist))
+ if args.use_dist and args.use_pde_cl:
+     training_partition = training_partition.use_all()
+ else:
+     training_partition = training_partition.use(get_rank(args.use_dist))
train_set = torch.utils.data.DataLoader(training_partition,
- batch_size=bsz,
+ batch_size=args.batch_size,
shuffle=True)
if args.batch_size < args.num_basis:
- training_2_partition = training_2_partition.use(get_rank(args.use_dist))
+ if args.use_dist and args.use_pde_cl:
+     training_2_partition = training_2_partition.use_all()
+ else:
+     training_2_partition = training_2_partition.use(get_rank(args.use_dist))
train_set_2 = torch.utils.data.DataLoader(training_2_partition,
- batch_size=bsz,
+ batch_size=args.batch_size,
shuffle=True)
else:
train_set_2 = None

- test_partition = test_partition.use(get_rank(args.use_dist))
+ if args.use_dist and args.use_pde_cl:
+     test_partition = test_partition.use_all()
+ else:
+     test_partition = test_partition.use(get_rank(args.use_dist))
test_set = torch.utils.data.DataLoader(test_partition,
- batch_size=bsz,
+ batch_size=args.batch_size,
shuffle=True)

# Force num_basis = 1 if not using pde-cl
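The branching above means that under `--dist` with `--upc`, every rank's loader draws from the full dataset (`use_all()`), while plain data parallelism shards the data per rank (`use(rank)`). Our reading, stated as an assumption: the pde-cl least-squares solve for w couples all basis functions, so each rank needs access to all collocation points. A condensed sketch of the rule:

```
# Condensed sketch of the loader-selection rule in run() (illustrative).
from torch.utils.data import DataLoader

def make_loader(partition, rank, batch_size, use_dist, use_pde_cl):
    # With pde-cl under distributed training, every rank sees the full set;
    # otherwise each rank trains only on its own shard.
    if use_dist and use_pde_cl:
        part = partition.use_all()
    else:
        part = partition.use(rank)
    return DataLoader(part, batch_size=batch_size, shuffle=True)
```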
@@ -508,14 +517,14 @@ def loss_fn(data, u_scatter, data_2,w=None):
world_size = 1
print('world_size is: ' + str(world_size))

- training_partition, training_2_partition, test_partition, bsz = partition_dataset(args, world_size)
+ training_partition, training_2_partition, test_partition = partition_dataset(args, world_size)
start = time.time()
if args.use_dist:
processes = []
mp.set_start_method("spawn")
for rank in range(world_size):
p = mp.Process(target=setup, args=(rank, world_size, run, args,
- training_partition, training_2_partition, test_partition, bsz,
+ training_partition, training_2_partition, test_partition,
))
p.start()
processes.append(p)
Expand All @@ -531,7 +540,7 @@ def loss_fn(data, u_scatter, data_2,w=None):
else:
rank = 0
run(rank, world_size, args,
- training_partition, training_2_partition, test_partition, bsz,
+ training_partition, training_2_partition, test_partition,
)

visualize(args, world_size)
utils.py (4 additions, 0 deletions)
@@ -50,14 +50,18 @@ def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234, shuffle=True):
indexes = [x for x in range(0, data_len)]
if shuffle:
rng.shuffle(indexes)
+ self.indexes = indexes

for ind,frac in enumerate(sizes):
part_len = int(frac*data_len)
self.partitions.append(indexes[0:part_len])
indexes = indexes[part_len:]


def use(self, partition):
return Partition(self.data, self.partitions[partition])
+ def use_all(self):
+     return Partition(self.data, self.indexes)

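A hypothetical usage of the new `use_all()` next to the existing `use()`, assuming `Partition` exposes `__len__`/`__getitem__` as in the standard PyTorch data-partitioning tutorial this class follows. Note that `self.indexes` is bound before the loop rebinds the local `indexes`, so it retains the full index list:

```
# Hypothetical usage (assumes Partition implements __len__/__getitem__).
from utils import DataPartitioner

data = list(range(10))
partitioner = DataPartitioner(data, sizes=[0.5, 0.5], shuffle=False)

shard0 = partitioner.use(0)         # this rank's half of the indexes
everything = partitioner.use_all()  # all ten items, regardless of rank

print(len(shard0), len(everything))  # 5 10
```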
class NeuralNetwork(nn.Module):
def __init__(self, num_basis, two_d, num_hidden_layers=4, hidden_layer_width=64):
