Generalising the training module to accommodate varying input #20
base: main
Changes from all commits: 3da88a0, 4cc821f, b29d46a, 793f693, 992cb06, da1eb23, 5cf19cc, 0b595cc, 88aa237, 8b6175c, 2639bbe, fec3e25, 28dd18f, b0ca08d, ae1ffa7, 460db98, e85a661, 83e88ab, 8f53dba, 08b6987, 667d5f7, 0d71a7d, 81f0167, fd780f5
@@ -1,158 +1,141 @@
-import torch
+"""Neural Network model for the CAM-EM."""
+
 import netCDF4 as nc
 import numpy as np
 import scipy.stats as st
+import torch
 import xarray as xr
 from torch import nn
 from torch.nn.utils import prune
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
 
 
-# Required for feeding the data iinto NN.
 class myDataset(Dataset):
+    """
+    Dataset class for loading features and labels.
+
+    Args:
+        X (numpy.ndarray): Input features.
+        Y (numpy.ndarray): Corresponding labels.
+    """
+
     def __init__(self, X, Y):
-        """
-        Parameters:
-            X (tensor): Input data.
-            Y (tensor): Output data.
-        """
+        """Create an instance of myDataset class."""
         self.features = torch.tensor(X, dtype=torch.float64)
         self.labels = torch.tensor(Y, dtype=torch.float64)
 
     def __len__(self):
-        """Function that is called when you call len(dataloader)"""
+        """Return the number of samples in the dataset."""
         return len(self.features.T)
 
     def __getitem__(self, idx):
-        """Function that is called when you call dataloader"""
+        """Return a sample from the dataset."""
        feature = self.features[:, idx]
        label = self.labels[:, idx]
 
        return feature, label
 
 
 # The NN model.
+class NormalizationLayer(nn.Module):
+    def __init__(self, mean, std):
+        super(NormalizationLayer, self).__init__()
+        self.mean = mean
+        self.std = std
+
+    def forward(self, x):
+        return (x - self.mean) / self.std
+
+
 class FullyConnected(nn.Module):
     """
     Fully connected neural network model.
 
     The model consists of multiple fully connected layers with SiLU activation function.
 
     Attributes
     ----------
-    linear_stack (torch.nn.Sequential): Sequential container for layers.
+    linear_stack : nn.Sequential
+        Sequential container of linear layers and activation functions.
     """
 
-    def __init__(self):
+    def __init__(self, ilev=93, in_ver=8, in_nover=4, out_ver=2, hidden_layers=8, hidden_size=500):
         """Create an instance of FullyConnected NN model."""
         super(FullyConnected, self).__init__()
-        ilev = 93
-
-        self.linear_stack = nn.Sequential(
-            nn.Linear(8 * ilev + 4, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 500, dtype=torch.float64),
-            nn.SiLU(),
-            nn.Linear(500, 2 * ilev, dtype=torch.float64),
-        )
+        self.ilev = ilev
+        self.in_ver = in_ver
+        self.in_nover = in_nover
+        self.out_ver = out_ver
+        self.hidden_layers = hidden_layers
+        self.hidden_size = hidden_size
+
+        layers = []
+        input_size = in_ver * ilev + in_nover
+        for _ in range(hidden_layers):
+            layers.append(nn.Linear(input_size, hidden_size, dtype=torch.float64))
+            layers.append(nn.SiLU())
+            input_size = hidden_size
+        layers.append(nn.Linear(hidden_size, out_ver * ilev, dtype=torch.float64))
+        self.linear_stack = nn.Sequential(*layers)
Comment on lines +61 to +68:
A brief comment to summarise what this code is doing for those used to writing it in the previous layout might be useful here.
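One possible shape for such a summarising comment, mirroring lines +61 to +68 (the wording is a suggestion made here, not part of the PR):

```python
# Build the hidden stack dynamically rather than listing each layer by hand.
# Equivalent to the previous hard-coded Sequential:
#   Linear(in_ver * ilev + in_nover -> hidden_size), SiLU,
#   then (hidden_layers - 1) x [Linear(hidden_size -> hidden_size), SiLU],
#   then Linear(hidden_size -> out_ver * ilev).
layers = []
input_size = in_ver * ilev + in_nover
for _ in range(hidden_layers):
    layers.append(nn.Linear(input_size, hidden_size, dtype=torch.float64))
    layers.append(nn.SiLU())
    input_size = hidden_size  # all layers after the first map hidden_size -> hidden_size
layers.append(nn.Linear(hidden_size, out_ver * ilev, dtype=torch.float64))
self.linear_stack = nn.Sequential(*layers)
```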
     def forward(self, X):
         """
         Forward pass through the network.
 
-        Args:
-            X (torch.Tensor): Input tensor.
+        Parameters
+        ----------
+        X : torch.Tensor
+            Input tensor.
 
         Returns
         -------
-        torch.Tensor: Output tensor.
+        torch.Tensor
+            Output tensor.
         """
         return self.linear_stack(X)
 
 
-# training loop
-def train_loop(dataloader, model, loss_fn, optimizer):
-    """
-    Training loop.
-
-    Args:
-        dataloader (DataLoader): DataLoader for training data.
-        model (nn.Module): Neural network model.
-        loss_fn (torch.nn.Module): Loss function.
-        optimizer (torch.optim.Optimizer): Optimizer.
-
-    Returns
-    -------
-    float: Average training loss.
-    """
-    size = len(dataloader.dataset)
-    avg_loss = 0
-    for batch, (X, Y) in enumerate(dataloader):
-        # Compute prediction and loss
-        pred = model(X)
-        loss = loss_fn(pred, Y)
-
-        # Backpropagation
-        optimizer.zero_grad(set_to_none=True)
-        loss.backward()
-        optimizer.step()
-
-        with torch.no_grad():
-            avg_loss += loss.item()
-
-    avg_loss /= len(dataloader)
-
-    return avg_loss
-
-
-# validating loop
-def val_loop(dataloader, model, loss_fn):
-    """
-    Validation loop.
-
-    Args:
-        dataloader (DataLoader): DataLoader for validation data.
-        model (nn.Module): Neural network model.
-        loss_fn (torch.nn.Module): Loss function.
-
-    Returns
-    -------
-    float: Average validation loss.
-    """
-    avg_loss = 0
-    with torch.no_grad():
-        for batch, (X, Y) in enumerate(dataloader):
-            # Compute prediction and loss
-            pred = model(X)
-            loss = loss_fn(pred, Y)
-            avg_loss += loss.item()
-
-    avg_loss /= len(dataloader)
-
-    return avg_loss
+class EarlyStopper:
Comment on class EarlyStopper:
It's not immediately clear to me why this should be a class rather than a function, and whether it belongs here with the Model, or if it would be better being moved to

Reply: @surbhigoel77 agrees this should be a single function with inputs, and moved to
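For illustration, a minimal sketch of the single-function version the reviewers describe (the function name, signature, and the idea of threading state through the caller are assumptions made here; the destination module is not named in the thread):

```python
import numpy as np


def early_stop(validation_loss, state, patience=1, min_delta=0.0):
    """Return (should_stop, new_state), where state is (best_loss, counter)."""
    best_loss, counter = state
    if validation_loss < best_loss:
        # Improvement: remember the new best loss and reset the counter.
        return False, (validation_loss, 0)
    if validation_loss > best_loss + min_delta:
        # No meaningful improvement this epoch.
        counter += 1
        if counter >= patience:
            return True, (best_loss, counter)
    return False, (best_loss, counter)


# The caller threads the state between epochs:
state = (np.inf, 0)
# stop, state = early_stop(val_loss, state, patience=5, min_delta=0.01)
```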
""" | ||
size = len(dataloader.dataset) | ||
avg_loss = 0 | ||
for batch, (X, Y) in enumerate(dataloader): | ||
# Compute prediction and loss | ||
pred = model(X) | ||
loss = loss_fn(pred, Y) | ||
|
||
# Backpropagation | ||
optimizer.zero_grad(set_to_none=True) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
with torch.no_grad(): | ||
avg_loss += loss.item() | ||
|
||
avg_loss /= len(dataloader) | ||
|
||
return avg_loss | ||
Early stopping utility to stop training when validation loss doesn't improve. | ||
|
||
Parameters | ||
---------- | ||
patience : int, optional | ||
Number of epochs to wait before stopping (default is 1). | ||
min_delta : float, optional | ||
Minimum change in the monitored quantity to qualify as an improvement (default is 0). | ||
|
||
# validating loop | ||
def val_loop(dataloader, model, loss_fn): | ||
Attributes | ||
---------- | ||
patience : int | ||
Number of epochs to wait before stopping. | ||
min_delta : float | ||
Minimum change in the monitored quantity to qualify as an improvement. | ||
counter : int | ||
Counter for the number of epochs without improvement. | ||
min_validation_loss : float | ||
Minimum validation loss recorded. | ||
""" | ||
Validation loop. | ||
|
||
Args: | ||
dataloader (DataLoader): DataLoader for validation data. | ||
model (nn.Module): Neural network model. | ||
loss_fn (torch.nn.Module): Loss function. | ||
def __init__(self, patience=1, min_delta=0): | ||
self.patience = patience | ||
self.min_delta = min_delta | ||
self.counter = 0 | ||
self.min_validation_loss = np.inf | ||
|
||
Returns | ||
------- | ||
float: Average validation loss. | ||
""" | ||
avg_loss = 0 | ||
with torch.no_grad(): | ||
for batch, (X, Y) in enumerate(dataloader): | ||
# Compute prediction and loss | ||
pred = model(X) | ||
loss = loss_fn(pred, Y) | ||
avg_loss += loss.item() | ||
def early_stop(self, validation_loss, model=None): | ||
""" | ||
Check if training should be stopped early. | ||
|
||
avg_loss /= len(dataloader) | ||
Parameters | ||
---------- | ||
validation_loss : float | ||
Current validation loss. | ||
model : nn.Module, optional | ||
Model to save if validation loss improves (default is None). | ||
|
||
return avg_loss | ||
Returns | ||
------- | ||
bool | ||
True if training should be stopped, False otherwise. | ||
""" | ||
if validation_loss < self.min_validation_loss: | ||
self.min_validation_loss = validation_loss | ||
self.counter = 0 | ||
# if model is not None: | ||
# # torch.save(model.state_dict(), 'conv_torch.pth') | ||
# torch.save(model.state_dict(), 'trained_models/weights_conv') | ||
Comment on lines +134 to +136:
Is this dead code, or should it be wrapped in a conditional of some sort, or replaced by a comment?

Reply: Unsure as to its purpose, ask @yqsun91.
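If the block is meant to be kept rather than deleted, one option is to make the save opt-in via an explicit path (a sketch only; `save_path` is a hypothetical parameter, not part of the PR):

```python
# Hypothetical revision of EarlyStopper.early_stop with opt-in checkpointing.
def early_stop(self, validation_loss, model=None, save_path=None):
    if validation_loss < self.min_validation_loss:
        self.min_validation_loss = validation_loss
        self.counter = 0
        if model is not None and save_path is not None:
            # Checkpoint the best-so-far weights only when a path is supplied.
            torch.save(model.state_dict(), save_path)
    elif validation_loss > (self.min_validation_loss + self.min_delta):
        self.counter += 1
        if self.counter >= self.patience:
            return True
    return False
```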
+        elif validation_loss > (self.min_validation_loss + self.min_delta):
+            self.counter += 1
+            if self.counter >= self.patience:
+                return True
+        return False
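Since train_loop and val_loop are removed from this file by this change, here is how EarlyStopper would typically be driven from an epoch loop (a sketch; the loader names, hyperparameters, and where train_loop/val_loop now live are assumptions):

```python
# Hypothetical wiring of EarlyStopper into an epoch loop.
model = FullyConnected(ilev=93)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
stopper = EarlyStopper(patience=5, min_delta=0.0)

for epoch in range(100):
    train_loss = train_loop(train_loader, model, loss_fn, optimizer)  # assumed import
    val_loss = val_loop(val_loader, model, loss_fn)  # assumed import
    if stopper.early_stop(val_loss):
        # Halt once validation loss has not improved for `patience` epochs.
        break
```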
Review comment:
This file is looking much better, the net is much cleaner and the docstrings really help understand what is going on.
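As a closing illustration of what the generalisation buys, the constructor can now be sized to different grids without editing the class (the non-default values below are made up for the example):

```python
# Defaults correspond to the old hard-coded shape: input 8 * 93 + 4 = 748
# features, output 2 * 93 = 186 (note the default depth of 8 hidden layers
# differs from the 12 Linear+SiLU pairs in the previous hard-coded stack).
model_default = FullyConnected()

# A hypothetical configuration for a different vertical grid.
model_custom = FullyConnected(ilev=60, in_ver=6, in_nover=2, out_ver=2,
                              hidden_layers=10, hidden_size=256)

x = torch.randn(16, 6 * 60 + 2, dtype=torch.float64)  # batch of 16 samples
y = model_custom(x)  # shape: (16, 2 * 60)
```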