Skip to content

Commit

Permalink
updated visualizer for pde-cl
Browse files Browse the repository at this point in the history
  • Loading branch information
vganapati committed Jun 17, 2023
1 parent f78643e commit 8bfeac2
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 13 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,19 @@ salloc -N 1 --time=120 -C gpu -A m3562_g --qos=interactive --ntasks-per-gpu=1 --cpus-per-task=32 -n 4
```

Check SLURM_NTASKS:
```
echo $SLURM_NTASKS
```
The value should be 4; if it is not, run:
```
export SLURM_NTASKS=4
```

Navigate to the working directory and run the code:
```
cd $SCRATCH/output_PINN
export MASTER_ADDR=$(hostname)
python $SCRATCH/PINN/main.py --upc --2d --dist --epochs 2 --bs 8192
```

Expand Down
17 changes: 12 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def get_args():
parser = argparse.ArgumentParser(description='Get command line args')

parser.add_argument('--bs', type=int, action='store', dest='batch_size',
help='batch size', default = 8192)
help='batch size', default = 8192)
parser.add_argument('--ebs', type=int, action='store', dest='eval_batch_size',
help='eval batch size', default = 78400)
parser.add_argument('--nb', type=int, action='store', dest='num_basis',
help='number of basis functions, N in pde-cl paper', default = 200)
parser.add_argument('--upc', action='store_true', dest='use_pde_cl',
Expand Down Expand Up @@ -80,7 +82,7 @@ def get_args():
parser.add_argument('--eval_x_end', action='store', dest='eval_data_x_end',
help='eval data x end', nargs='+', default = [14,14])
parser.add_argument('--eval_x_step', action='store', dest='eval_data_x_step',
help='evaluation data x step', nargs='+', default = [0.03,0.03])
help='evaluation data x step', nargs='+', default = [0.1,0.1])

parser.add_argument('--load', action='store_true', dest='load_model',
help='load model from model.pth')
Expand Down Expand Up @@ -260,15 +262,20 @@ def loss_fn(data, u_scatter, data_2):
if args.use_dist:
cleanup()

def visualize(args, num_devices):
def visualize(args,
):
"""
Visualize the PINN with list of evaluation coordinates
Not yet implemented with distributed computing
"""
device = get_device(args)
eval_data, lengths = create_data(args.eval_data_x_start, args.eval_data_x_end,
args.eval_data_x_step, args.two_d)
eval_dataloader = DataLoader(eval_data, batch_size=args.batch_size//num_devices, shuffle=False)
eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size, shuffle=False)

# XXX Options for solving the linear system:
# Solve the full linear system for the weights
# Solve the linear system for a subset of the points, use those weights for all points

# Load model

Expand Down Expand Up @@ -489,6 +496,6 @@ def loss_fn(data, u_scatter, data_2):
training_partition, training_2_partition, test_partition, bsz,
)

visualize(args, world_size)
visualize(args)
end = time.time()
print("Time to train (s): " + str(end-start))
186 changes: 185 additions & 1 deletion notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,188 @@ All the above could only sense 1 GPU for some reason.

Modified code to distribute with NERSC:
> sbatch /pscratch/sd/v/vidyagan/PINN/slurm_train.sh 8192
Submitted batch job 10217467
Submitted batch job 10217467

Without pde-cl, final eval loss:
loss: 0.025223 [10000/872356]
loss: 0.026122 [20000/872356]
loss: 0.026625 [30000/872356]
loss: 0.026221 [40000/872356]
loss: 0.026760 [50000/872356]
loss: 0.027671 [60000/872356]
loss: 0.027561 [70000/872356]
loss: 0.027397 [80000/872356]
loss: 0.028575 [90000/872356]
loss: 0.028962 [100000/872356]
loss: 0.028374 [110000/872356]
loss: 0.029337 [120000/872356]
loss: 0.030252 [130000/872356]
loss: 0.030004 [140000/872356]
loss: 0.030252 [150000/872356]
loss: 0.031832 [160000/872356]
loss: 0.032304 [170000/872356]
loss: 0.031949 [180000/872356]
loss: 0.033780 [190000/872356]
loss: 0.034950 [200000/872356]
loss: 0.034724 [210000/872356]
loss: 0.035757 [220000/872356]
loss: 0.037564 [230000/872356]
loss: 0.037605 [240000/872356]
loss: 0.037405 [250000/872356]
loss: 0.039424 [260000/872356]
loss: 0.040056 [270000/872356]
loss: 0.039693 [280000/872356]
loss: 0.041099 [290000/872356]
loss: 0.042823 [300000/872356]
loss: 0.042949 [310000/872356]
loss: 0.043323 [320000/872356]
loss: 0.045520 [330000/872356]
loss: 0.045651 [340000/872356]
loss: 2.894951 [350000/872356]
loss: 6.818404 [360000/872356]
loss: 9.862532 [370000/872356]
loss: 11.595950 [380000/872356]
loss: 12.644929 [390000/872356]
loss: 12.769813 [400000/872356]
loss: 14.462465 [410000/872356]
loss: 14.955965 [420000/872356]
loss: 13.862768 [430000/872356]
loss: 15.368446 [440000/872356]
loss: 15.297609 [450000/872356]
loss: 14.397958 [460000/872356]
loss: 13.916547 [470000/872356]
loss: 13.976737 [480000/872356]
loss: 13.075364 [490000/872356]
loss: 10.867071 [500000/872356]
loss: 10.427895 [510000/872356]
loss: 8.225926 [520000/872356]
loss: 4.228253 [530000/872356]
loss: 0.037942 [540000/872356]
loss: 0.037947 [550000/872356]
loss: 0.037693 [560000/872356]
loss: 0.038724 [570000/872356]
loss: 0.040180 [580000/872356]
loss: 0.039634 [590000/872356]
loss: 0.040058 [600000/872356]
loss: 0.042105 [610000/872356]
loss: 0.042263 [620000/872356]
loss: 0.042339 [630000/872356]
loss: 0.044006 [640000/872356]
loss: 0.045372 [650000/872356]
loss: 0.044975 [660000/872356]
loss: 0.045743 [670000/872356]
loss: 0.047650 [680000/872356]
loss: 0.047620 [690000/872356]
loss: 0.047710 [700000/872356]
loss: 0.049751 [710000/872356]
loss: 0.051020 [720000/872356]
loss: 0.050766 [730000/872356]
loss: 0.052310 [740000/872356]
loss: 0.054302 [750000/872356]
loss: 0.054105 [760000/872356]
loss: 0.054899 [770000/872356]
loss: 0.057041 [780000/872356]
loss: 0.057848 [790000/872356]
loss: 0.057734 [800000/872356]
loss: 0.059664 [810000/872356]
loss: 0.061118 [820000/872356]
loss: 0.060753 [830000/872356]
loss: 0.061968 [840000/872356]
loss: 0.063881 [850000/872356]
loss: 0.063879 [860000/872356]
loss: 0.064224 [870000/872356]
loss: 0.069460 [872356/872356]
Final eval pde loss is 2.5506897413441303
Time to train (s): 39.422139167785645

With pde-cl, final eval loss:
loss: 0.000005 [10000/872356]
loss: 0.000005 [20000/872356]
loss: 0.000005 [30000/872356]
loss: 0.000005 [40000/872356]
loss: 0.000005 [50000/872356]
loss: 0.000005 [60000/872356]
loss: 0.000005 [70000/872356]
loss: 0.000005 [80000/872356]
loss: 0.000005 [90000/872356]
loss: 0.000005 [100000/872356]
loss: 0.000005 [110000/872356]
loss: 0.000005 [120000/872356]
loss: 0.000005 [130000/872356]
loss: 0.000005 [140000/872356]
loss: 0.000005 [150000/872356]
loss: 0.000005 [160000/872356]
loss: 0.000005 [170000/872356]
loss: 0.000005 [180000/872356]
loss: 0.000005 [190000/872356]
loss: 0.000005 [200000/872356]
loss: 0.000005 [210000/872356]
loss: 0.000005 [220000/872356]
loss: 0.000005 [230000/872356]
loss: 0.000005 [240000/872356]
loss: 0.000005 [250000/872356]
loss: 0.000005 [260000/872356]
loss: 0.000005 [270000/872356]
loss: 0.000005 [280000/872356]
loss: 0.000005 [290000/872356]
loss: 0.000005 [300000/872356]
loss: 0.000005 [310000/872356]
loss: 0.000005 [320000/872356]
loss: 0.000005 [330000/872356]
loss: 0.000005 [340000/872356]
loss: 3.005005 [350000/872356]
loss: 6.998339 [360000/872356]
loss: 10.109420 [370000/872356]
loss: 11.887402 [380000/872356]
loss: 12.992826 [390000/872356]
loss: 13.143788 [400000/872356]
loss: 14.718863 [410000/872356]
loss: 15.257994 [420000/872356]
loss: 14.131932 [430000/872356]
loss: 15.635180 [440000/872356]
loss: 15.537781 [450000/872356]
loss: 14.585422 [460000/872356]
loss: 13.993762 [470000/872356]
loss: 14.124367 [480000/872356]
loss: 13.233357 [490000/872356]
loss: 10.997641 [500000/872356]
loss: 10.532175 [510000/872356]
loss: 8.265468 [520000/872356]
loss: 4.216262 [530000/872356]
loss: 0.000005 [540000/872356]
loss: 0.000005 [550000/872356]
loss: 0.000005 [560000/872356]
loss: 0.000005 [570000/872356]
loss: 0.000005 [580000/872356]
loss: 0.000005 [590000/872356]
loss: 0.000005 [600000/872356]
loss: 0.000005 [610000/872356]
loss: 0.000005 [620000/872356]
loss: 0.000005 [630000/872356]
loss: 0.000005 [640000/872356]
loss: 0.000005 [650000/872356]
loss: 0.000005 [660000/872356]
loss: 0.000005 [670000/872356]
loss: 0.000005 [680000/872356]
loss: 0.000005 [690000/872356]
loss: 0.000005 [700000/872356]
loss: 0.000005 [710000/872356]
loss: 0.000005 [720000/872356]
loss: 0.000005 [730000/872356]
loss: 0.000005 [740000/872356]
loss: 0.000005 [750000/872356]
loss: 0.000005 [760000/872356]
loss: 0.000005 [770000/872356]
loss: 0.000005 [780000/872356]
loss: 0.000005 [790000/872356]
loss: 0.000005 [800000/872356]
loss: 0.000005 [810000/872356]
loss: 0.000005 [820000/872356]
loss: 0.000005 [830000/872356]
loss: 0.000005 [840000/872356]
loss: 0.000005 [850000/872356]
loss: 0.000005 [860000/872356]
loss: 0.000005 [870000/872356]
loss: 0.000005 [872356/872356]
Final eval pde loss is 2.5605068343657864
Time to train (s): 44.82499980926514
7 changes: 3 additions & 4 deletions slurm_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@
#SBATCH -A m2859_g # allocation
#SBATCH -C gpu
#SBATCH -q regular
#SBATCH -t 02:00:00
#SBATCH -t 07:00:00
#SBATCH --gpus-per-node=4
#SBATCH --ntasks-per-gpu=1
#SBATCH --gpus 4
#SBATCH -o %j.out
#SBATCH -e %j.err

export MASTER_ADDR=$(hostname)
# export BATCH_SIZE=8192
export SCRATCH_FOLDER=$SCRATCH/output_PINN/$SLURM_JOB_ID
mkdir -p $SCRATCH_FOLDER; cd $SCRATCH_FOLDER

echo "jobstart $(date)";pwd

python $SCRATCH/PINN/main.py --2d --dist --epochs 3000 --bs 872356
# python $SCRATCH/PINN/main.py --upc --2d --dist --bs $BATCH_SIZE --epochs 500
python $SCRATCH/PINN/main.py --2d --dist --epochs 30000 --bs 872356
# python $SCRATCH/PINN/main.py --upc --2d --dist --bs 8192 --epochs 500

echo "jobend $(date)";pwd
4 changes: 1 addition & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ def transform_linear_pde(data,
):
'''Get the right hand side of the PDE (del**2 + n**2*k0**2)*u_scatter = -(n**2-n_background**2)*k0**2*u_in))'''
hess_fn = torch.func.hessian(model, argnums=0)


if use_vmap:
hess = torch.vmap(hess_fn,in_dims=(0))(data) # hessian
else:
Expand All @@ -153,7 +151,6 @@ def transform_linear_pde(data,
hess = torch.stack(hess, dim=0)


#hess = torch.zeros([data.size(0), 1, 200, 2, 2, 2],device=device)
refractive_index = evalulate_refractive_index(data, n_background)

du_scatter_xx = torch.squeeze(hess[:,:,:,:,0,0], dim=1)
Expand Down Expand Up @@ -242,6 +239,7 @@ def get_pde_loss(data,
linear_pde_combine = torch.matmul(linear_pde,w)
u_scatter_complex_combine = torch.matmul(u_scatter_complex,w)
u_scatter_complex_combine = torch.squeeze(u_scatter_complex_combine, dim=1)
# breakpoint()
else:
linear_pde_combine = linear_pde[:,0]
linear_pde_combine = torch.unsqueeze(linear_pde_combine,dim=1)
Expand Down

0 comments on commit 8bfeac2

Please sign in to comment.