Script to plot training data (#9)
Co-authored-by: Richard Lane <[email protected]>
richard-lane and Richard Lane authored Oct 21, 2024
1 parent 9a25b3a commit 9783223
Showing 3 changed files with 95 additions and 7 deletions.
3 changes: 3 additions & 0 deletions config.yml
@@ -1,6 +1,9 @@
# You shouldn't need to change this
# User-defined configuration is in userconf.yml

# Where random stuff from the scripts goes
script_output: "script_output/"

# This contains directories of 2d tifs
wahabs_old_tifs: "DATABASE/uCT/Wahab_clean_dataset/low_res_clean_v3/"
felix_old_labels_dir: "1Felix and Rich make models/Training dataset Tiffs"
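
The new `plot_train_data.py` script below looks this value up via `fishjaw.util.util.config()`. Purely as a minimal standalone sketch of reading the file directly (assuming PyYAML and the repository root as the working directory - not the repo's own helper):

```python
# Minimal sketch of reading config.yml directly with PyYAML; the repo's own
# scripts go through fishjaw.util.util.config() instead.
import pathlib

import yaml

with open("config.yml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Directory that miscellaneous script output gets written to
output_dir = pathlib.Path(config["script_output"])
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Script output will go to {output_dir.resolve()}")
```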
17 changes: 10 additions & 7 deletions scripts/README.md
@@ -5,24 +5,27 @@ Some scripts

Setup
----
- `create_dicoms.py`: create DICOM files for each CT scan - label pair
- `plot_dicoms.py`: create plots visualising these DICOM files

Training Models
----
- `train_model.py`: train the model

Hyperparameter Tuning
----
- `explore_hyperparams.py`: train lots of models with different hyperparameters, to see what's best.
  This script isn't particularly good or robust - most of the options are defined
  by the config `dict` within this script, but there are some other parameters and
  extra things (e.g. the transformations that get applied) that are hard-coded in other
  places. A rough illustration of this kind of dict-driven search is sketched below.
- `plot_hyperparams.py`: plot the result of the hyperparameter tuning
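
Purely as an illustration of that pattern (a hypothetical sketch, not the contents of `explore_hyperparams.py` - the parameter names and ranges are made up):

```python
# Hypothetical dict-driven hyperparameter search - an illustration only,
# not the code in explore_hyperparams.py; names and ranges are invented.
import random

search_space = {
    "batch_size": [1, 2, 4],
    "n_filters": [8, 16, 32],
}


def sample_params(rng: random.Random) -> dict:
    """Draw one hyperparameter combination from the search space."""
    return {
        "learning_rate": 10 ** rng.uniform(-5, -2),  # log-uniform in [1e-5, 1e-2]
        "batch_size": rng.choice(search_space["batch_size"]),
        "n_filters": rng.choice(search_space["n_filters"]),
    }


rng = random.Random(0)
for i in range(5):
    params = sample_params(rng)
    # The real script would train and evaluate a model with each combination
    print(f"run {i}: {params}")
```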

Other Stuff
----
- `mesh.py`: example showing the conversion from tiff (which is what the segmentation model creates) to a mesh - see the rough sketch after this list
  - [ ] TODO make this useful
- `arch_summary.py`: summarise the architecture of the model (at the moment just prints the feature map sizes)
- `plot_train_data.py`: plot the training data. Useful for visualising the extent of the data
  augmentation, and for making sure you have the expected number of batches,
  images per batch, etc.
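
For the tiff-to-mesh conversion that `mesh.py` demonstrates, a minimal sketch of the general technique (marching cubes) might look like the following - this assumes `tifffile`, `scikit-image` and `trimesh` are installed, the file names are placeholders, and it is not necessarily how `mesh.py` itself does it:

```python
# Rough sketch: binary tiff segmentation -> surface mesh via marching cubes.
# Placeholder file names; not necessarily the approach taken in mesh.py.
import tifffile
import trimesh
from skimage import measure

# Load the 3d segmentation volume (e.g. the model's output saved as a tiff stack)
volume = tifffile.imread("segmentation.tif")

# Extract the isosurface separating background (0) from foreground (1)
verts, faces, normals, _ = measure.marching_cubes(volume, level=0.5)

# Build a mesh from the vertices/faces and save it
mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
mesh.export("segmentation.stl")
```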
82 changes: 82 additions & 0 deletions scripts/plot_train_data.py
@@ -0,0 +1,82 @@
"""
Plot the training data - to e.g. visualize the transforms
"""

import pathlib
import argparse

import torch
import numpy as np
import torchio as tio
import matplotlib.pyplot as plt

from fishjaw.util import util
from fishjaw.model import data
from fishjaw.visualisation import images_3d


def _data_config() -> data.DataConfig:
"""
Get the training data configuration
"""
# Create training config
config = util.userconf()
torch.manual_seed(config["torch_seed"])
rng = np.random.default_rng(seed=config["test_train_seed"])

train_subjects, val_subjects, _ = data.read_dicoms_from_disk(config, rng)
return data.DataConfig(config, train_subjects, val_subjects)


def main(*, step: int, epochs: int):
"""
Read the DICOMs from disk () Create the data config
"""
output_dir = (
pathlib.Path(__file__).parents[1]
/ util.config()["script_output"]
/ "train_data"
)
if not output_dir.exists():
output_dir.mkdir(parents=True)

data_config = _data_config()

# Epochs
for epoch in range(0, epochs, step):
# Batches
for i, batch in enumerate(data_config.train_data):
images = batch[tio.IMAGE][tio.DATA]
masks = batch[tio.LABEL][tio.DATA]
# Images per batch
for j, (image, mask) in enumerate(zip(images, masks)):
out_path = str(
output_dir / f"traindata_epoch_{epoch}_batch_{i}_img_{j}.png"
)

fig, _ = images_3d.plot_slices(
image.squeeze().numpy(), mask.squeeze().numpy()
)
fig.savefig(out_path)
plt.close(fig)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Plot the training data")
parser.add_argument(
"--step",
type=int,
help="Interval between plots - step of 1 plots all data",
default=1,
)
parser.add_argument(
"--epochs",
type=int,
help="How many complete passes over the training data to make",
default=1,
)

main(**vars(parser.parse_args()))
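
Usage note (assuming the training DICOMs have already been created, e.g. with `create_dicoms.py` as described in the README above): running `python scripts/plot_train_data.py --epochs 5 --step 2` from the repository root would plot the batches for epochs 0, 2 and 4, writing files named like `traindata_epoch_0_batch_0_img_0.png` into `script_output/train_data/`.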
