Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
vganapati committed Jun 22, 2023
1 parent 21679ac commit 03d56e5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_args():
parser.add_argument('--lr', type=float, action='store', dest='learning_rate',
help='learning rate', default = 1e-3)
parser.add_argument('-j', type=float, action='store', dest='jitter',
help='jitter for training data', default = 0.6)
help='jitter for training data', default = 0.5)
parser.add_argument('--show', action='store_true', dest='show_figures',
help='show figures')

Expand Down Expand Up @@ -311,7 +311,7 @@ def evaluate(eval_data_i,
return pde_loss_i, u_total, u_scatter, refractive_index, w, u_in

def visualize(args,
num_devices):
):
"""
Visualize the PINN with list of evaluation coordinates
Not yet implemented with distributed computing
Expand All @@ -323,7 +323,7 @@ def visualize(args,

eval_data, lengths = create_data(args.eval_data_x_start, args.eval_data_x_end,
args.eval_data_x_step, args.two_d)
eval_dataloader = DataLoader(eval_data, batch_size=args.batch_size//num_devices, shuffle=False)
eval_dataloader = DataLoader(eval_data, batch_size=args.batch_size, shuffle=False)


# Load model
Expand Down Expand Up @@ -543,6 +543,6 @@ def loss_fn(data, u_scatter, data_2,w=None):
training_partition, training_2_partition, test_partition,
)

visualize(args, world_size)
visualize(args)
end = time.time()
print("Time to train (s): " + str(end-start))
10 changes: 6 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,18 @@ def use(self, partition):
def use_all(self):
return Partition(self.data, self.indexes)


class NeuralNetwork(nn.Module):
def __init__(self, num_basis, two_d, num_hidden_layers=4, hidden_layer_width=64):
def __init__(self, num_basis, two_d, num_hidden_layers=3, hidden_layer_width=64):
super().__init__()
input_dim = 2 if two_d else 3
self.num_basis = num_basis
activation = nn.ELU #nn.Tanh
layers = [nn.Linear(input_dim, hidden_layer_width),
nn.Tanh()]
activation()]
for _ in range(num_hidden_layers):
layers.append(nn.Linear(hidden_layer_width, hidden_layer_width))
layers.append(nn.Tanh())
layers.append(activation())

layers.append(nn.Linear(hidden_layer_width, num_basis*2))
self.linear_relu_stack = nn.Sequential(*layers)
Expand Down Expand Up @@ -270,7 +272,7 @@ def get_pde_loss(data,

pde = linear_pde_combine-f
pde = torch.squeeze(pde, dim=1)
pde_loss = torch.sum(torch.abs(pde))
pde_loss = torch.sum(torch.abs(pde)**2)
return pde_loss, u_total, u_scatter_complex_combine, refractive_index, w

def average_gradients(model):
Expand Down

0 comments on commit 03d56e5

Please sign in to comment.