Skip to content

Commit

Permalink
added pml
Browse files Browse the repository at this point in the history
  • Loading branch information
vganapati committed Jul 13, 2023
1 parent 4672048 commit 705253f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ python $SCRATCH/PINN/main.py --2d --dist --epochs 100 --bs 17672 --siren
Debugging for adding the PML:
```
python $SCRATCH/PINN/main.py --2d --epochs 100 --bs 100
python $SCRATCH/PINN/main.py --2d --epochs 100 --bs 4900
```

## How to run the SLURM script on NERSC
Expand Down
26 changes: 19 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_args():
parser.add_argument('--lr', type=float, action='store', dest='learning_rate',
help='learning rate', default = 1e-3)
parser.add_argument('-j', type=float, action='store', dest='jitter',
help='jitter for training data', default = 0.5)
help='jitter for training data', default = 0.4)
parser.add_argument('--show', action='store_true', dest='show_figures',
help='show figures')

Expand All @@ -66,17 +66,21 @@ def get_args():
parser.add_argument('--x_end', action='store', dest='data_x_end',
help='boundary data x end', nargs='+', default = [14.0,14.0])

# set the pml thickness
parser.add_argument('--pml_thickness', action='store', dest='pml_thickness',
help='pml thickness', nargs='+', default = [2.0,2.0])

# set the training spacing
parser.add_argument('--train_x_step', action='store', dest='training_data_x_step',
help='training data x step', nargs='+', default = [0.015,0.015])
help='training data x step', nargs='+', default = [0.4,0.4])

# set the test spacing
parser.add_argument('--test_x_step', action='store', dest='test_data_x_step',
help='test data x step', nargs='+', default = [0.3,0.3])
help='test data x step', nargs='+', default = [0.4,0.4])

# set the evaluation region subset spacing for evaluting w
parser.add_argument('--eval_x_step_subset', action='store', dest='eval_data_x_step_subset',
help='evaluation data x step', nargs='+', default = [0.15,0.15])
help='evaluation data x step', nargs='+', default = [0.2,0.2])

# set the evaluation region spacing for final visualization
parser.add_argument('--eval_x_step', action='store', dest='eval_data_x_step',
Expand Down Expand Up @@ -158,7 +162,7 @@ def partition_dataset(args, world_size):

# Test data for validation of pde loss
# Test data is a list of coordinates
test_data, _ = create_data(args.test_data_x_start, args.test_data_x_end,
test_data, _ = create_data(args.data_x_start, args.data_x_end,
args.test_data_x_step, args.two_d)
test_partition = DataPartitioner(test_data, partition_sizes, shuffle=True)

Expand Down Expand Up @@ -229,6 +233,10 @@ def loss_fn(data, u_scatter, data_2):
device,
args.use_pde_cl,
args.two_d,
args.data_x_end[0]-args.data_x_start[0],
args.data_x_end[1]-args.data_x_start[1],
args.pml_thickness[0],
args.pml_thickness[1],
data_2=data_2,
)

Expand Down Expand Up @@ -291,10 +299,10 @@ def visualize(args,
"""
device = get_device(args)
# Solve the linear system for a subset of the points, use those weights for all points
eval_data_subset, _ = create_data(args.eval_data_x_start_subset, args.eval_data_x_end_subset,
eval_data_subset, _ = create_data(args.data_x_start, args.data_x_end,
args.eval_data_x_step_subset, args.two_d)

eval_data, lengths = create_data(args.eval_data_x_start, args.eval_data_x_end,
eval_data, lengths = create_data(args.data_x_start, args.data_x_end,
args.eval_data_x_step, args.two_d)
eval_dataloader = DataLoader(eval_data, batch_size=args.batch_size, shuffle=False)

Expand Down Expand Up @@ -323,6 +331,10 @@ def loss_fn(data, u_scatter, data_2,w=None):
device,
args.use_pde_cl,
args.two_d,
args.data_x_end[0]-args.data_x_start[0],
args.data_x_end[1]-args.data_x_start[1],
args.pml_thickness[0],
args.pml_thickness[1],
data_2=data_2,
w=w,
)
Expand Down
28 changes: 24 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ def transform_linear_pde(data,
model,
two_d,
device,
domain_size_x=14,
domain_size_z=14,
L_pml_x=2,
L_pml_z=2,
domain_size_x,
domain_size_z,
L_pml_x,
L_pml_z,
a_0=0.25,
use_vmap=True,
):
Expand Down Expand Up @@ -276,6 +276,10 @@ def transform_affine_pde(wavelength,
model,
device,
two_d,
domain_size_x,
domain_size_z,
L_pml_x,
L_pml_z,
):
'''Get the right and left hand side of the PDE (del**2 + n**2*k0**2)*u_scatter = -(n**2-n_background**2)*k0**2*u_in))'''

Expand All @@ -287,6 +291,10 @@ def transform_affine_pde(wavelength,
model,
two_d,
device,
domain_size_x,
domain_size_z,
L_pml_x,
L_pml_z,
)
if two_d:
u_in = create_plane_wave_2d(data,
Expand All @@ -312,6 +320,10 @@ def get_pde_loss(data,
device,
use_pde_cl,
two_d,
domain_size_x,
domain_size_z,
L_pml_x,
L_pml_z,
data_2=None,
w=None,
):
Expand All @@ -327,6 +339,10 @@ def get_pde_loss(data,
model,
device,
two_d,
domain_size_x,
domain_size_z,
L_pml_x,
L_pml_z,
)


Expand Down Expand Up @@ -356,6 +372,10 @@ def get_pde_loss(data,
model,
device,
two_d,
domain_size_x,
domain_size_z,
L_pml_x,
L_pml_z,
)
linear_pde_combine = torch.matmul(linear_pde,w)

Expand Down

0 comments on commit 705253f

Please sign in to comment.