diff --git a/README.md b/README.md
index b3cfaf0..e3f888c 100644
--- a/README.md
+++ b/README.md
@@ -111,9 +111,19 @@ salloc -N 1 --time=120 -C gpu -A m3562_g --qos=interactive --ntasks-per-gpu=1 --cpus-per-task=32 -n 4
 ```
+Check that SLURM_NTASKS is set correctly:
+```
+echo $SLURM_NTASKS
+```
+It should be 4; if not, set it manually:
+```
+export SLURM_NTASKS=4
+```
+
 Navigate to the working directory and run the code:
 ```
 cd $SCRATCH/output_PINN
+export MASTER_ADDR=$(hostname)
 python $SCRATCH/PINN/main.py --upc --2d --dist --epochs 2 --bs 8192
 ```
diff --git a/main.py b/main.py
index 37686b9..06a58fe 100644
--- a/main.py
+++ b/main.py
@@ -38,7 +38,9 @@ def get_args():
     parser = argparse.ArgumentParser(description='Get command line args')
     parser.add_argument('--bs', type=int, action='store', dest='batch_size',
-                        help='batch size', default = 8192)
+                        help='batch size', default = 8192)
+    parser.add_argument('--ebs', type=int, action='store', dest='eval_batch_size',
+                        help='eval batch size', default = 78400)
     parser.add_argument('--nb', type=int, action='store', dest='num_basis',
                         help='number of basis functions, N in pde-cl paper', default = 200)
     parser.add_argument('--upc', action='store_true', dest='use_pde_cl',
@@ -80,7 +82,7 @@ def get_args():
     parser.add_argument('--eval_x_end', action='store', dest='eval_data_x_end',
                         help='eval data x end', nargs='+', default = [14,14])
     parser.add_argument('--eval_x_step', action='store', dest='eval_data_x_step',
-                        help='evaluation data x step', nargs='+', default = [0.03,0.03])
+                        help='evaluation data x step', nargs='+', default = [0.1,0.1])
     parser.add_argument('--load', action='store_true', dest='load_model',
                         help='load model from model.pth')
@@ -260,7 +262,7 @@ def loss_fn(data, u_scatter, data_2):
     if args.use_dist:
         cleanup()
 
-def visualize(args, num_devices):
+def visualize(args):
     """
     Visualize the PINN with list of evaluation coordinates
     Not yet implemented with distributed computing
@@ -268,7 +271,11 @@ def loss_fn(data, u_scatter, data_2):
     device = get_device(args)
     eval_data, lengths = create_data(args.eval_data_x_start, args.eval_data_x_end,
                                      args.eval_data_x_step, args.two_d)
-    eval_dataloader = DataLoader(eval_data, batch_size=args.batch_size//num_devices, shuffle=False)
+    eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size, shuffle=False)
+
+    # XXX Options for solving the linear system:
+    # - solve the full linear system for the weights
+    # - solve the linear system for a subset of the points, then use those weights for all points
 
     # Load model
@@ -489,6 +496,6 @@ def loss_fn(data, u_scatter, data_2):
             training_partition, training_2_partition, test_partition,
             bsz,
             )
-        visualize(args, world_size)
+        visualize(args)
         end = time.time()
         print("Time to train (s): " + str(end-start))
\ No newline at end of file
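A minimal sketch of the second option in the XXX comment above: solve for the pde-cl weights `w` on a subset of points, then reuse them for every point. The `linear_pde`/`w` names mirror the `torch.matmul(linear_pde, w)` combine in `get_pde_loss`; the shapes, the complex dtype, and the `torch.linalg.lstsq` solve are illustrative assumptions, not the repo's actual solver:

```python
import torch

# Assumed shapes: linear_pde is [num_points, num_basis] (num_basis = 200,
# the pde-cl paper's N) and rhs is [num_points]; both complex-valued.
num_points, num_basis, num_subset = 78400, 200, 1000
linear_pde = torch.randn(num_points, num_basis, dtype=torch.cfloat)
rhs = torch.randn(num_points, dtype=torch.cfloat)

# Option 2: least-squares solve for w on a random subset of the points...
idx = torch.randperm(num_points)[:num_subset]
w = torch.linalg.lstsq(linear_pde[idx], rhs[idx].unsqueeze(1)).solution

# ...then reuse those weights to combine the basis at every point,
# mirroring torch.matmul(linear_pde, w) in get_pde_loss.
u_combine = (linear_pde @ w).squeeze(1)  # [num_points]
```

Solving on a subset keeps the least-squares problem at num_subset x num_basis instead of num_points x num_basis, at the cost of weights that are only approximately optimal for the remaining points.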
diff --git a/notes.md b/notes.md
index f9583cd..76a2d42 100644
--- a/notes.md
+++ b/notes.md
@@ -34,4 +34,188 @@ All the above could only sense 1 GPU for some reason.
 Modified code to distribute with NERSC:
 > sbatch /pscratch/sd/v/vidyagan/PINN/slurm_train.sh 8192
-Submitted batch job 10217467
\ No newline at end of file
+Submitted batch job 10217467
+
+Without pde-cl, final eval loss:
+loss: 0.025223 [10000/872356]
+loss: 0.026122 [20000/872356]
+loss: 0.026625 [30000/872356]
+loss: 0.026221 [40000/872356]
+loss: 0.026760 [50000/872356]
+loss: 0.027671 [60000/872356]
+loss: 0.027561 [70000/872356]
+loss: 0.027397 [80000/872356]
+loss: 0.028575 [90000/872356]
+loss: 0.028962 [100000/872356]
+loss: 0.028374 [110000/872356]
+loss: 0.029337 [120000/872356]
+loss: 0.030252 [130000/872356]
+loss: 0.030004 [140000/872356]
+loss: 0.030252 [150000/872356]
+loss: 0.031832 [160000/872356]
+loss: 0.032304 [170000/872356]
+loss: 0.031949 [180000/872356]
+loss: 0.033780 [190000/872356]
+loss: 0.034950 [200000/872356]
+loss: 0.034724 [210000/872356]
+loss: 0.035757 [220000/872356]
+loss: 0.037564 [230000/872356]
+loss: 0.037605 [240000/872356]
+loss: 0.037405 [250000/872356]
+loss: 0.039424 [260000/872356]
+loss: 0.040056 [270000/872356]
+loss: 0.039693 [280000/872356]
+loss: 0.041099 [290000/872356]
+loss: 0.042823 [300000/872356]
+loss: 0.042949 [310000/872356]
+loss: 0.043323 [320000/872356]
+loss: 0.045520 [330000/872356]
+loss: 0.045651 [340000/872356]
+loss: 2.894951 [350000/872356]
+loss: 6.818404 [360000/872356]
+loss: 9.862532 [370000/872356]
+loss: 11.595950 [380000/872356]
+loss: 12.644929 [390000/872356]
+loss: 12.769813 [400000/872356]
+loss: 14.462465 [410000/872356]
+loss: 14.955965 [420000/872356]
+loss: 13.862768 [430000/872356]
+loss: 15.368446 [440000/872356]
+loss: 15.297609 [450000/872356]
+loss: 14.397958 [460000/872356]
+loss: 13.916547 [470000/872356]
+loss: 13.976737 [480000/872356]
+loss: 13.075364 [490000/872356]
+loss: 10.867071 [500000/872356]
+loss: 10.427895 [510000/872356]
+loss: 8.225926 [520000/872356]
+loss: 4.228253 [530000/872356]
+loss: 0.037942 [540000/872356]
+loss: 0.037947 [550000/872356]
+loss: 0.037693 [560000/872356]
+loss: 0.038724 [570000/872356]
+loss: 0.040180 [580000/872356]
+loss: 0.039634 [590000/872356]
+loss: 0.040058 [600000/872356]
+loss: 0.042105 [610000/872356]
+loss: 0.042263 [620000/872356]
+loss: 0.042339 [630000/872356]
+loss: 0.044006 [640000/872356]
+loss: 0.045372 [650000/872356]
+loss: 0.044975 [660000/872356]
+loss: 0.045743 [670000/872356]
+loss: 0.047650 [680000/872356]
+loss: 0.047620 [690000/872356]
+loss: 0.047710 [700000/872356]
+loss: 0.049751 [710000/872356]
+loss: 0.051020 [720000/872356]
+loss: 0.050766 [730000/872356]
+loss: 0.052310 [740000/872356]
+loss: 0.054302 [750000/872356]
+loss: 0.054105 [760000/872356]
+loss: 0.054899 [770000/872356]
+loss: 0.057041 [780000/872356]
+loss: 0.057848 [790000/872356]
+loss: 0.057734 [800000/872356]
+loss: 0.059664 [810000/872356]
+loss: 0.061118 [820000/872356]
+loss: 0.060753 [830000/872356]
+loss: 0.061968 [840000/872356]
+loss: 0.063881 [850000/872356]
+loss: 0.063879 [860000/872356]
+loss: 0.064224 [870000/872356]
+loss: 0.069460 [872356/872356]
+Final eval pde loss is 2.5506897413441303
+Time to train (s): 39.422139167785645
+
+With pde-cl, final eval loss:
+loss: 0.000005 [10000/872356]
+loss: 0.000005 [20000/872356]
+loss: 0.000005 [30000/872356]
+loss: 0.000005 [40000/872356]
+loss: 0.000005 [50000/872356]
+loss: 0.000005 [60000/872356]
+loss: 0.000005 [70000/872356]
+loss: 0.000005 [80000/872356]
+loss: 0.000005 [90000/872356]
+loss: 0.000005 [100000/872356]
+loss: 0.000005 [110000/872356]
+loss: 0.000005 [120000/872356]
+loss: 0.000005 [130000/872356]
+loss: 0.000005 [140000/872356]
+loss: 0.000005 [150000/872356]
+loss: 0.000005 [160000/872356]
+loss: 0.000005 [170000/872356]
+loss: 0.000005 [180000/872356]
+loss: 0.000005 [190000/872356]
+loss: 0.000005 [200000/872356]
+loss: 0.000005 [210000/872356]
+loss: 0.000005 [220000/872356]
+loss: 0.000005 [230000/872356]
+loss: 0.000005 [240000/872356]
+loss: 0.000005 [250000/872356]
+loss: 0.000005 [260000/872356]
+loss: 0.000005 [270000/872356]
+loss: 0.000005 [280000/872356]
+loss: 0.000005 [290000/872356]
+loss: 0.000005 [300000/872356]
+loss: 0.000005 [310000/872356]
+loss: 0.000005 [320000/872356]
+loss: 0.000005 [330000/872356]
+loss: 0.000005 [340000/872356]
+loss: 3.005005 [350000/872356]
+loss: 6.998339 [360000/872356]
+loss: 10.109420 [370000/872356]
+loss: 11.887402 [380000/872356]
+loss: 12.992826 [390000/872356]
+loss: 13.143788 [400000/872356]
+loss: 14.718863 [410000/872356]
+loss: 15.257994 [420000/872356]
+loss: 14.131932 [430000/872356]
+loss: 15.635180 [440000/872356]
+loss: 15.537781 [450000/872356]
+loss: 14.585422 [460000/872356]
+loss: 13.993762 [470000/872356]
+loss: 14.124367 [480000/872356]
+loss: 13.233357 [490000/872356]
+loss: 10.997641 [500000/872356]
+loss: 10.532175 [510000/872356]
+loss: 8.265468 [520000/872356]
+loss: 4.216262 [530000/872356]
+loss: 0.000005 [540000/872356]
+loss: 0.000005 [550000/872356]
+loss: 0.000005 [560000/872356]
+loss: 0.000005 [570000/872356]
+loss: 0.000005 [580000/872356]
+loss: 0.000005 [590000/872356]
+loss: 0.000005 [600000/872356]
+loss: 0.000005 [610000/872356]
+loss: 0.000005 [620000/872356]
+loss: 0.000005 [630000/872356]
+loss: 0.000005 [640000/872356]
+loss: 0.000005 [650000/872356]
+loss: 0.000005 [660000/872356]
+loss: 0.000005 [670000/872356]
+loss: 0.000005 [680000/872356]
+loss: 0.000005 [690000/872356]
+loss: 0.000005 [700000/872356]
+loss: 0.000005 [710000/872356]
+loss: 0.000005 [720000/872356]
+loss: 0.000005 [730000/872356]
+loss: 0.000005 [740000/872356]
+loss: 0.000005 [750000/872356]
+loss: 0.000005 [760000/872356]
+loss: 0.000005 [770000/872356]
+loss: 0.000005 [780000/872356]
+loss: 0.000005 [790000/872356]
+loss: 0.000005 [800000/872356]
+loss: 0.000005 [810000/872356]
+loss: 0.000005 [820000/872356]
+loss: 0.000005 [830000/872356]
+loss: 0.000005 [840000/872356]
+loss: 0.000005 [850000/872356]
+loss: 0.000005 [860000/872356]
+loss: 0.000005 [870000/872356]
+loss: 0.000005 [872356/872356]
+Final eval pde loss is 2.5605068343657864
+Time to train (s): 44.82499980926514
\ No newline at end of file
diff --git a/slurm_train.sh b/slurm_train.sh
index c9e8581..8ef25bb 100644
--- a/slurm_train.sh
+++ b/slurm_train.sh
@@ -6,7 +6,7 @@
 #SBATCH -A m2859_g # allocation
 #SBATCH -C gpu
 #SBATCH -q regular
-#SBATCH -t 02:00:00
+#SBATCH -t 07:00:00
 #SBATCH --gpus-per-node=4
 #SBATCH --ntasks-per-gpu=1
 #SBATCH --gpus 4
@@ -14,13 +14,12 @@
 #SBATCH -e %j.err
 
 export MASTER_ADDR=$(hostname)
-# export BATCH_SIZE=8192
 export SCRATCH_FOLDER=$SCRATCH/output_PINN/$SLURM_JOB_ID
 mkdir -p $SCRATCH_FOLDER; cd $SCRATCH_FOLDER
 
 echo "jobstart $(date)";pwd
 
-python $SCRATCH/PINN/main.py --2d --dist --epochs 3000 --bs 872356
-# python $SCRATCH/PINN/main.py --upc --2d --dist --bs $BATCH_SIZE --epochs 500
+python $SCRATCH/PINN/main.py --2d --dist --epochs 30000 --bs 872356
+# python $SCRATCH/PINN/main.py --upc --2d --dist --bs 8192 --epochs 500
 
 echo "jobend $(date)";pwd
\ No newline at end of file
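For reference, a minimal sketch of the distributed setup that slurm_train.sh feeds: MASTER_ADDR is exported by the script above, and SLURM_PROCID/SLURM_NTASKS are standard Slurm variables (one task per GPU via --ntasks-per-gpu=1). The setup_dist name and the fixed MASTER_PORT are assumptions for illustration, not the repo's actual code:

```python
import os
import torch.distributed as dist

def setup_dist():
    # MASTER_ADDR is exported by slurm_train.sh; the port is an arbitrary
    # free port chosen here for illustration.
    os.environ.setdefault("MASTER_PORT", "29500")
    rank = int(os.environ["SLURM_PROCID"])        # global rank, one per GPU
    world_size = int(os.environ["SLURM_NTASKS"])  # 4 tasks in this job
    dist.init_process_group("nccl", init_method="env://",
                            rank=rank, world_size=world_size)
    return rank, world_size
```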
diff --git a/utils.py b/utils.py
index c6cecca..5f153c2 100644
--- a/utils.py
+++ b/utils.py
@@ -140,8 +140,6 @@ def transform_linear_pde(data,
                          ):
     '''Get the right hand side of the PDE: (del**2 + n**2*k0**2)*u_scatter = -(n**2-n_background**2)*k0**2*u_in'''
     hess_fn = torch.func.hessian(model, argnums=0)
-
-
     if use_vmap:
         hess = torch.vmap(hess_fn,in_dims=(0))(data) # hessian
     else:
@@ -153,7 +151,6 @@ def transform_linear_pde(data,
     hess = torch.stack(hess, dim=0)
 
-    #hess = torch.zeros([data.size(0), 1, 200, 2, 2, 2],device=device)
     refractive_index = evalulate_refractive_index(data, n_background)
     du_scatter_xx = torch.squeeze(hess[:,:,:,:,0,0], dim=1)
@@ -242,6 +239,7 @@ def get_pde_loss(data,
         linear_pde_combine = torch.matmul(linear_pde,w)
         u_scatter_complex_combine = torch.matmul(u_scatter_complex,w)
         u_scatter_complex_combine = torch.squeeze(u_scatter_complex_combine, dim=1)
+        # breakpoint()
     else:
         linear_pde_combine = linear_pde[:,0]
         linear_pde_combine = torch.unsqueeze(linear_pde_combine,dim=1)
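The vmap-over-hessian pattern kept in transform_linear_pde computes one Hessian per coordinate in the batch; del**2 u is then read off the Hessian diagonal (du_scatter_xx above is the [..., 0, 0] entry). A self-contained toy version, with a quadratic stand-in for the network rather than the actual PINN:

```python
import torch

def model(x):
    # Toy scalar field standing in for the network: u(x) = |x|^2.
    return (x ** 2).sum()

hess_fn = torch.func.hessian(model, argnums=0)
data = torch.randn(8, 2)                        # batch of 2-D coordinates
hess = torch.vmap(hess_fn, in_dims=(0,))(data)  # per-sample Hessians, [8, 2, 2]

# del**2 u is the trace of each Hessian; for u = |x|^2 in 2-D this is 2 + 2 = 4.
laplacian = hess.diagonal(dim1=-2, dim2=-1).sum(-1)
print(laplacian)                                # tensor of 4.0s
```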