Skip to content

Commit

Permalink
soft boundary on refractive index
Browse files Browse the repository at this point in the history
  • Loading branch information
vganapati committed Jul 18, 2023
1 parent fecf462 commit 6706f0d
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 13 deletions.
54 changes: 54 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,60 @@ export SLURM_NTASKS=8
srun --ntasks-per-node 4 -N 2 --gpus-per-task=1 python $SCRATCH/PINN/main.py --2d --epochs 100 --bs 160000 -j 0.025 --train_x_step 0.05 0.05
```

```
export MASTER_ADDR=$(hostname)
export SLURM_NTASKS=4
python $SCRATCH/PINN/main.py --2d --epochs 100 --bs 160000 -j 0.025 --train_x_step 0.05 0.05
```
Results:
```
Final eval pde loss is 4640.781312
Time to train (s): 79.0293037891388
```


```
export MASTER_ADDR=$(hostname)
export SLURM_NTASKS=1
python $SCRATCH/PINN/main.py --2d --epochs 100 --bs 160000 -j 0.025 --train_x_step 0.05 0.05
```
Results:
```
Final eval pde loss is 4640.165376
Time to train (s): 76.92376923561096
```

Same as above with 1000 epochs:
```
export MASTER_ADDR=$(hostname)
export SLURM_NTASKS=4
python $SCRATCH/PINN/main.py --2d --epochs 1000 --bs 160000 -j 0.025 --train_x_step 0.05 0.05
```
Results:
```
Final eval pde loss is 4640.256512
Time to train (s): 370.7292535305023
```

```
export MASTER_ADDR=$(hostname)
export SLURM_NTASKS=1
python $SCRATCH/PINN/main.py --2d --epochs 1000 --bs 160000 -j 0.025 --train_x_step 0.05 0.05
```
Results:
```
Final eval pde loss is 244.715568
Time to train (s): 625.481682062149
```


PDE-CL
```
export MASTER_ADDR=$(hostname)
export SLURM_NTASKS=4
python $SCRATCH/PINN/main.py --2d --upc --epochs 100 --bs 10000 --nb 10 -j 0.025 --train_x_step 0.05 0.05
```

## How to run the SLURM script on NERSC

Open a Perlmutter terminal.
Expand Down
9 changes: 5 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ def run(rank, world_size, args,
dtype = torch.float,
):

local_rank = get_rank()[1]
print("Running on rank " + str(rank) + ". Running on local rank " + str(local_rank))
# local_rank = get_rank()[1]
# print("Running on rank " + str(rank) + ". Running on local rank " + str(local_rank))
print("Running on rank " + str(rank) + ". Running on rank " + str(get_rank()))

train_set, train_set_2, test_set = get_train_test_sets(args, training_partition, training_2_partition, test_partition)

Expand All @@ -108,8 +109,8 @@ def run(rank, world_size, args,
model = NeuralNetwork(args.num_basis, args.two_d)
print(model)


device = torch.device(f'cuda:{local_rank}')
# device = torch.device(f'cuda:{local_rank}')
device = torch.device(f'cuda:{rank}')
model.to(device)

if args.load_model:
Expand Down
15 changes: 10 additions & 5 deletions utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,17 @@ def partition_dataset(args, world_size):

def get_train_test_sets(args, training_partition, training_2_partition, test_partition):
if args.use_pde_cl:
training_partition = training_partition.use_all()
# training_partition = training_partition.use_all()
training_partition = training_partition.use(get_rank())
else:
training_partition = training_partition.use(get_rank()[0])
training_partition = training_partition.use(get_rank())
# training_partition = training_partition.use(get_rank()[0])
train_set = torch.utils.data.DataLoader(training_partition,
batch_size=args.batch_size,
shuffle=True)
if args.batch_size < args.num_basis and args.use_pde_cl:
training_2_partition = training_2_partition.use_all()
training_2_partition = training_2_partition.use(get_rank())
# training_2_partition = training_2_partition.use_all()
train_set_2 = torch.utils.data.DataLoader(training_2_partition,
batch_size=args.batch_size,
shuffle=True)
Expand All @@ -115,9 +118,11 @@ def get_train_test_sets(args, training_partition, training_2_partition, test_par


if args.use_pde_cl:
test_partition = test_partition.use_all()
# test_partition = test_partition.use_all()
test_partition = test_partition.use(get_rank())
else:
test_partition = test_partition.use(get_rank()[0])
test_partition = test_partition.use(get_rank())
# test_partition = test_partition.use(get_rank()[0])
test_set = torch.utils.data.DataLoader(test_partition,
batch_size=args.batch_size,
shuffle=True)
Expand Down
4 changes: 3 additions & 1 deletion utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def setup(rank, world_size, fn, args,
training_partition, training_2_partition, test_partition) # this will be the run function

def get_rank():

'''
rank = int(os.environ["SLURM_PROCID"])
gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"])
local_rank = rank - gpus_per_node * (rank // gpus_per_node)
return rank, local_rank
'''
return dist.get_rank()

def cleanup():
dist.destroy_process_group()
Expand Down
8 changes: 5 additions & 3 deletions utils/physics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import numpy as np
import torch

def evalulate_refractive_index(data,
def evaluate_refractive_index(data,
n_background,
n_inclusion=1.88, # refractive index of the inclusion
radius=3, # um
sharpness=9, # higher value means sharper boundary
):
"""evalulate the refractive index at the data points for spherical/cylinderical dielectric"""
# refractive_index = torch.where(torch.sum(data**2,dim=1)<radius**2, 1.88, n_background)

dist = torch.sqrt(torch.sum(data**2,dim=1))
refractive_index = torch.sigmoid(2*(-dist+radius))
refractive_index = (n_inclusion-n_background)*torch.sigmoid(sharpness*(-dist+radius))+n_background
return refractive_index


Expand Down Expand Up @@ -161,7 +163,7 @@ def transform_linear_pde(data,
d_dist_squared_d_x = torch.unsqueeze(d_dist_squared_d_x,dim=1)
d_dist_squared_d_z = torch.unsqueeze(d_dist_squared_d_z,dim=1)

refractive_index = evalulate_refractive_index(data, n_background)
refractive_index = evaluate_refractive_index(data, n_background)

du_scatter_xx = torch.squeeze(hess[:,:,:,:,0,0], dim=1)

Expand Down
23 changes: 23 additions & 0 deletions utils/test_refractive_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
import torch
import matplotlib.pyplot as plt
from physics import evaluate_refractive_index

sharpness_vec = np.arange(1,10,1)
n_inclusion=1.88
n_background=1.33
radius=3
sharpness=2
x = np.arange(-10,10,0.1)
y = np.zeros_like(x)
data = np.stack((x,y),axis=1)
data = torch.tensor(data).float()

plt.figure()

for sharpness in sharpness_vec:
refractive_index = evaluate_refractive_index(data, n_background, n_inclusion, radius, sharpness)
plt.plot(x, refractive_index, label="sharpness = " + str(sharpness))
breakpoint()
plt.legend()
plt.savefig("refractive_index_line_plot.png")

0 comments on commit 6706f0d

Please sign in to comment.