Skip to content

Commit

Permalink
adding PML to PINN
Browse files Browse the repository at this point in the history
  • Loading branch information
vganapati committed Jul 11, 2023
1 parent d0d1483 commit 52e58c5
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 10 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ python $SCRATCH/PINN/main.py --2d --dist --epochs 1000 --bs 8836 --siren --upc -
python $SCRATCH/PINN/main.py --2d --dist --epochs 100 --bs 17672 --siren
```


Debugging for adding the PML:
```
python $SCRATCH/PINN/main.py --2d --epochs 100
```

## How to run the SLURM script on NERSC

Expand Down Expand Up @@ -202,7 +205,7 @@ salloc -N 1 --time=60 -C cpu -A m3562 --qos=interactive
Run the code:
```
cd $WORKING_DIR
python $SCRATCH/PINN/finite_differences.py
python $SCRATCH/PINN/finite_difference.py
```
The resulting scattered field is saved in `scattered.png`.

Expand Down
12 changes: 6 additions & 6 deletions finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import matplotlib.pyplot as plt

# Inputs
x_start = [-7.0,-7.0]
x_end = [7.0,7.0]
x_step = [0.1,0.1]
x_start = [-10.0,-10.0]
x_end = [10.0,10.0]
x_step = [0.075,0.075]
wavelength = 1.0 # wavelength in free space
pml_grid_points = 20
pml_grid_points = 30
n_background = 1.33
radius = 3.0

Expand Down Expand Up @@ -43,9 +43,9 @@

# Helper function for PML

def e_i(i, domain_size_i, pml_grid_points, a_0=1):
def e_i(i, domain_size_i, pml_grid_points, a_0=0.25):
"""
sigma in this code = sigma_paper/omega
sigma in this code = sigma_siren_paper/omega
"""
if i < pml_grid_points or i > domain_size_i-pml_grid_points-1:
dist_to_edge = min(i,domain_size_i-1-i)
Expand Down
81 changes: 79 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,31 @@ def create_plane_wave_2d(data,
k = torch.tensor([[kx, kz]], dtype=torch.cfloat, device=device)
return(amplitude*torch.exp(-1j*(torch.sum(k*data, dim=1))))

# Helper function for PML

def e_i(i, domain_size_i, L_pml, a_0=0.25):
"""
sigma in this code = sigma_siren_paper/omega
domain_size_i and L_pml are in the same units
"""
if i < L_pml or i > domain_size_i-L_pml:
dist_to_edge = min(i,domain_size_i-i)
else:
dist_to_edge = 0
sigma = a_0*(dist_to_edge/L_pml)**2

e = 1-1j*sigma

d_dist_squared_d_x = 0
if i < L_pml:
d_dist_squared_d_x = 2*i
elif i > domain_size_i-L_pml:
d_dist_squared_d_x = -2*(domain_size_i-i)

coeff = -1j*a_0/L_pml**2

return e, coeff, d_dist_squared_d_x

def transform_linear_pde(data,
k0,
n_background,
Expand All @@ -154,8 +179,53 @@ def transform_linear_pde(data,
hess.append(hess_i)
# Concatenate the outputs to form a single tensor
hess = torch.stack(hess, dim=0)
breakpoint()

# get the Jacobian for computing the right hand side of the PDE with PML boundary conditions
# Right hand side of PDE with the PML boundary is (left hand side is unchanged):
# (d/dx eps_y/eps_x d/dx)(u_scatter) + (d/dy eps_x/eps_y d/dy)(u_scatter) + eps_x*eps_y*n**2*k0**2*u_scatter
jacobian_fn = torch.func.jacfwd(model, argnums=0)
use_vmap = False
if use_vmap:
jacobian = torch.vmap(jacobian_fn,in_dims=(0))(data) # jacobian
else:
jacobian = []
for i in range(data.size(0)):
jacobian_i = jacobian_fn(data[i])
jacobian.append(jacobian_i)
# Concatenate the outputs to form a single tensor
jacobian = torch.stack(jacobian, dim=0)

breakpoint()
jacobian_func = torch.func.jacrev(model, argnums=0)
jacobian = jacobian_func(data)

jacobian = torch.func.jacrev(model, argnums=0)(data)
hess_2 = torch.func.jacfwd(jacobian_func)(data)

breakpoint()
du_scatter_x = torch.squeeze(jacobian[:,:,:,:,0], dim=1)
if two_d:
du_scatter_z = torch.squeeze(jacobian[:,:,:,:,1], dim=1)
else:
du_scatter_y = torch.squeeze(jacobian[:,:,:,:,1], dim=1)
du_scatter_z = torch.squeeze(jacobian[:,:,:,:,2], dim=1)

du_scatter_x_complex = du_scatter_x[:,:,0]+1j*du_scatter_x[:,:,1]
if not two_d:
du_scatter_y_complex = du_scatter_y[:,:,0]+1j*du_scatter_y[:,:,1]
du_scatter_z_complex = du_scatter_z[:,:,0]+1j*du_scatter_z[:,:,1]

domain_size_x = 14
L_pml_x = 2

domain_size_y = 14
L_pml_y = 2

e_x, coeff_x, d_dist_squared_d_x = e_i(data_x, domain_size_x, L_pml_x, a_0=0.25)
e_y, coeff_y, d_dist_squared_d_y = e_i(data_y, domain_size_y, L_pml_y, a_0=0.25)
breakpoint()

refractive_index = evalulate_refractive_index(data, n_background)

du_scatter_xx = torch.squeeze(hess[:,:,:,:,0,0], dim=1)
Expand All @@ -171,11 +241,18 @@ def transform_linear_pde(data,
du_scatter_yy_complex = du_scatter_yy[:,:,0]+1j*du_scatter_yy[:,:,1]
du_scatter_zz_complex = du_scatter_zz[:,:,0]+1j*du_scatter_zz[:,:,1]
u_scatter_complex = u_scatter[:,:,0]+1j*u_scatter[:,:,1]

if two_d:
linear_pde = du_scatter_xx_complex+du_scatter_zz_complex+k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex
# linear_pde = du_scatter_xx_complex+du_scatter_zz_complex+k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex
linear_pde = (-1)*e_y*du_scatter_x_complex*(e_x)**(-2)*coeff_x*d_dist_squared_d_x + \
(e_y/e_x) * (du_scatter_xx_complex) + \
(-1)*e_x*du_scatter_y_complex*(e_y)**(-2)*coeff_y*d_dist_squared_d_y + \
(e_x/e_y) * (du_scatter_zz_complex) + \
e_x*e_y*k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex

else:
linear_pde = du_scatter_xx_complex+du_scatter_yy_complex+du_scatter_zz_complex+k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex
# boundary condition not implemented ERROR
return linear_pde, refractive_index, u_scatter_complex

def transform_affine_pde(wavelength,
Expand Down

0 comments on commit 52e58c5

Please sign in to comment.