Skip to content

Commit

Permalink
Fix warnings by downgrading pytorch hehe (#8)
Browse files Browse the repository at this point in the history
* use newer grad function

* change sampling to be uniform, change DataLoader to SubjectsLoader

* downgrade pytorch so that a warning message goes away

* format

---------

Co-authored-by: Richard Lane <[email protected]>
  • Loading branch information
richard-lane and Richard Lane authored Oct 15, 2024
1 parent 411e182 commit 9a25b3a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 18 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- cudnn=8.1.0
- pandas
- pip
- pytorch
- pytorch=2.2
- pytorch-cuda=11.8
- torchvision
- monai
Expand Down
18 changes: 6 additions & 12 deletions fishjaw/model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _train_val_loader(
train: bool,
patch_size: tuple[int, int, int] = None,
batch_size: int,
) -> torch.utils.data.DataLoader:
) -> tio.SubjectsLoader:
"""
Create a dataloader from a SubjectsDataset
Expand All @@ -94,11 +94,7 @@ def _train_val_loader(
shuffle = train is True
drop_last = train is True

# Even probability of the patches being centred on each value
label_probs = {0: 1, 1: 1}
patch_sampler = tio.LabelSampler(
patch_size=patch_size, label_probabilities=label_probs
)
patch_sampler = tio.UniformSampler(patch_size=patch_size)

patches = tio.Queue(
subjects,
Expand All @@ -110,23 +106,21 @@ def _train_val_loader(
shuffle_subjects=True,
)

return torch.utils.data.DataLoader(
return tio.SubjectsLoader(
patches,
batch_size=batch_size,
shuffle=shuffle,
num_workers=0, # Load the data in the main process
# No idea why I have to set this to False, otherwise we get obscure errors
pin_memory=False,
num_workers=6, # TODO make this a config option
drop_last=drop_last,
)

@property
def train_data(self) -> torch.utils.data.DataLoader:
def train_data(self) -> tio.SubjectsLoader:
"""Get the training data"""
return self._train_data

@property
def val_data(self) -> torch.utils.data.DataLoader:
def val_data(self) -> tio.SubjectsLoader:
"""Get the validation data"""
return self._val_data

Expand Down
8 changes: 5 additions & 3 deletions fishjaw/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def train_step(
net: torch.nn.Module,
optim: torch.optim.Optimizer,
loss_fn: torch.nn.Module,
train_data: torch.utils.data.DataLoader,
train_data: tio.SubjectsLoader,
scaler: GradScaler,
*,
device: torch.device,
Expand All @@ -228,7 +228,9 @@ def train_step(
for data in train_data:
x, y = _get_data(data)

input_, target = x.to(device), y.to(device)
input_, target = x.to(device, non_blocking=True), y.to(
device, non_blocking=True
)

optim.zero_grad()
with autocast():
Expand All @@ -247,7 +249,7 @@ def train_step(
def validation_step(
net: torch.nn.Module,
loss_fn: torch.nn.Module,
validation_data: torch.utils.data.DataLoader,
validation_data: tio.SubjectsLoader,
*,
device: torch.device,
) -> tuple[torch.nn.Module, list[float]]:
Expand Down
4 changes: 2 additions & 2 deletions userconf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ test_train_seed: 1
device: "cuda"
window_size: "192,192,192" # Comma-separated ZYX. Needs to be large enough to hold the whole jaw
patch_size: "160,160,160" # Bigger holds more context, smaller is faster and allows for bigger batches
batch_size: 2
epochs: 25
batch_size: 27
epochs: 10
lr_lambda: 0.9999 # Exponential decay factor (multiplicative with each epoch)

model_params:
Expand Down

0 comments on commit 9a25b3a

Please sign in to comment.