From 52e58c53fe61e629607817e6ade7fbc41b8382e8 Mon Sep 17 00:00:00 2001 From: Vidya Ganapati Date: Mon, 10 Jul 2023 18:54:37 -0700 Subject: [PATCH] adding PML to PINN --- README.md | 7 ++-- finite_difference.py | 12 +++---- utils.py | 81 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index c42ec78..7316e6d 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,10 @@ python $SCRATCH/PINN/main.py --2d --dist --epochs 1000 --bs 8836 --siren --upc - python $SCRATCH/PINN/main.py --2d --dist --epochs 100 --bs 17672 --siren ``` - +Debugging for adding the PML: +``` +python $SCRATCH/PINN/main.py --2d --epochs 100 +``` ## How to run the SLURM script on NERSC @@ -202,7 +205,7 @@ salloc -N 1 --time=60 -C cpu -A m3562 --qos=interactive Run the code: ``` cd $WORKING_DIR -python $SCRATCH/PINN/finite_differences.py +python $SCRATCH/PINN/finite_difference.py ``` The resulting scattered field is saved in `scattered.png`. diff --git a/finite_difference.py b/finite_difference.py index 4521b74..565ce92 100644 --- a/finite_difference.py +++ b/finite_difference.py @@ -4,11 +4,11 @@ import matplotlib.pyplot as plt # Inputs -x_start = [-7.0,-7.0] -x_end = [7.0,7.0] -x_step = [0.1,0.1] +x_start = [-10.0,-10.0] +x_end = [10.0,10.0] +x_step = [0.075,0.075] wavelength = 1.0 # wavelength in free space -pml_grid_points = 20 +pml_grid_points = 30 n_background = 1.33 radius = 3.0 @@ -43,9 +43,9 @@ # Helper function for PML -def e_i(i, domain_size_i, pml_grid_points, a_0=1): +def e_i(i, domain_size_i, pml_grid_points, a_0=0.25): """ - sigma in this code = sigma_paper/omega + sigma in this code = sigma_siren_paper/omega """ if i < pml_grid_points or i > domain_size_i-pml_grid_points-1: dist_to_edge = min(i,domain_size_i-1-i) diff --git a/utils.py b/utils.py index 4a8cf8c..8fa0129 100644 --- a/utils.py +++ b/utils.py @@ -134,6 +134,31 @@ def create_plane_wave_2d(data, k = torch.tensor([[kx, kz]], dtype=torch.cfloat, device=device) return(amplitude*torch.exp(-1j*(torch.sum(k*data, dim=1)))) +# Helper function for PML + +def e_i(i, domain_size_i, L_pml, a_0=0.25): + """ + sigma in this code = sigma_siren_paper/omega + domain_size_i and L_pml are in the same units + """ + if i < L_pml or i > domain_size_i-L_pml: + dist_to_edge = min(i,domain_size_i-i) + else: + dist_to_edge = 0 + sigma = a_0*(dist_to_edge/L_pml)**2 + + e = 1-1j*sigma + + d_dist_squared_d_x = 0 + if i < L_pml: + d_dist_squared_d_x = 2*i + elif i > domain_size_i-L_pml: + d_dist_squared_d_x = -2*(domain_size_i-i) + + coeff = -1j*a_0/L_pml**2 + + return e, coeff, d_dist_squared_d_x + def transform_linear_pde(data, k0, n_background, @@ -154,8 +179,53 @@ def transform_linear_pde(data, hess.append(hess_i) # Concatenate the outputs to form a single tensor hess = torch.stack(hess, dim=0) + breakpoint() + # get the Jacobian for computing the right hand side of the PDE with PML boundary conditions + # Right hand side of PDE with the PML boundary is (left hand side is unchanged): + # (d/dx eps_y/eps_x d/dx)(u_scatter) + (d/dy eps_x/eps_y d/dy)(u_scatter) + eps_x*eps_y*n**2*k0**2*u_scatter + jacobian_fn = torch.func.jacfwd(model, argnums=0) + use_vmap = False + if use_vmap: + jacobian = torch.vmap(jacobian_fn,in_dims=(0))(data) # jacobian + else: + jacobian = [] + for i in range(data.size(0)): + jacobian_i = jacobian_fn(data[i]) + jacobian.append(jacobian_i) + # Concatenate the outputs to form a single tensor + jacobian = torch.stack(jacobian, dim=0) + breakpoint() + jacobian_func = torch.func.jacrev(model, argnums=0) + jacobian = jacobian_func(data) + + jacobian = torch.func.jacrev(model, argnums=0)(data) + hess_2 = torch.func.jacfwd(jacobian_func)(data) + + breakpoint() + du_scatter_x = torch.squeeze(jacobian[:,:,:,:,0], dim=1) + if two_d: + du_scatter_z = torch.squeeze(jacobian[:,:,:,:,1], dim=1) + else: + du_scatter_y = torch.squeeze(jacobian[:,:,:,:,1], dim=1) + du_scatter_z = torch.squeeze(jacobian[:,:,:,:,2], dim=1) + + du_scatter_x_complex = du_scatter_x[:,:,0]+1j*du_scatter_x[:,:,1] + if not two_d: + du_scatter_y_complex = du_scatter_y[:,:,0]+1j*du_scatter_y[:,:,1] + du_scatter_z_complex = du_scatter_z[:,:,0]+1j*du_scatter_z[:,:,1] + + domain_size_x = 14 + L_pml_x = 2 + + domain_size_y = 14 + L_pml_y = 2 + + e_x, coeff_x, d_dist_squared_d_x = e_i(data_x, domain_size_x, L_pml_x, a_0=0.25) + e_y, coeff_y, d_dist_squared_d_y = e_i(data_y, domain_size_y, L_pml_y, a_0=0.25) + breakpoint() + refractive_index = evalulate_refractive_index(data, n_background) du_scatter_xx = torch.squeeze(hess[:,:,:,:,0,0], dim=1) @@ -171,11 +241,18 @@ def transform_linear_pde(data, du_scatter_yy_complex = du_scatter_yy[:,:,0]+1j*du_scatter_yy[:,:,1] du_scatter_zz_complex = du_scatter_zz[:,:,0]+1j*du_scatter_zz[:,:,1] u_scatter_complex = u_scatter[:,:,0]+1j*u_scatter[:,:,1] - + if two_d: - linear_pde = du_scatter_xx_complex+du_scatter_zz_complex+k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex + # linear_pde = du_scatter_xx_complex+du_scatter_zz_complex+k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex + linear_pde = (-1)*e_y*du_scatter_x_complex*(e_x)**(-2)*coeff_x*d_dist_squared_d_x + \ + (e_y/e_x) * (du_scatter_xx_complex) + \ + (-1)*e_x*du_scatter_y_complex*(e_y)**(-2)*coeff_y*d_dist_squared_d_y + \ + (e_x/e_y) * (du_scatter_zz_complex) + \ + e_x*e_y*k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex + else: linear_pde = du_scatter_xx_complex+du_scatter_yy_complex+du_scatter_zz_complex+k0**2*torch.unsqueeze(refractive_index,dim=1)**2*u_scatter_complex + # boundary condition not implemented ERROR return linear_pde, refractive_index, u_scatter_complex def transform_affine_pde(wavelength,