Feature/fine tune (#25)
Fix hyperparameter tuning, change parameters to make the model a bit better, and put the inference output in the right place.
---------

Co-authored-by: Richard Lane <[email protected]>
richard-lane and Richard Lane authored Nov 8, 2024
1 parent b6d167c commit 5a38404
Showing 3 changed files with 18 additions and 13 deletions.
scripts/explore_hyperparams.py (13 changes: 8 additions & 5 deletions)
@@ -62,15 +62,17 @@ def _lr(rng: np.random.Generator, mode: str) -> float:
         lr_range = (-6, 1)
     elif mode == "fine":
         # From centering around a value that seems to broadly work
-        lr_range = (-5, 0)
+        lr_range = (-3, 0)
     else:
         raise ValueError(f"Unknown mode {mode}")
 
     return 10 ** rng.uniform(*lr_range)
 
 
-def _batch_size(rng: np.random.Generator) -> int:
+def _batch_size(rng: np.random.Generator, mode: str) -> int:
     # Maximum here sort of depends on what you can fit on the GPU
+    if mode == "fine":
+        return int(rng.integers(1, 21))
     return int(rng.integers(1, 33))
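Note on the sampling above: drawing the exponent uniformly and computing 10 ** u makes the learning rate log-uniform, so each decade in the range is equally likely. A minimal standalone check (illustrative only, not part of the diff):

import numpy as np

rng = np.random.default_rng(seed=0)

# u ~ Uniform(-3, 0)  =>  10**u is log-uniform on [1e-3, 1]
learning_rates = 10 ** rng.uniform(-3, 0, size=10_000)
for lo, hi in [(1e-3, 1e-2), (1e-2, 1e-1), (1e-1, 1.0)]:
    frac = np.mean((learning_rates >= lo) & (learning_rates < hi))
    print(f"[{lo:g}, {hi:g}): {frac:.2f}")  # each decade gets ~1/3 of the draws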


@@ -79,7 +81,8 @@ def _epochs(rng: np.random.Generator, mode: str) -> int:
         return 5
     if mode == "med":
         return 15
-    return int(rng.integers(25, 500))
+    # return int(rng.integers(100, 500))
+    return 400
 
 
 def _alpha(rng: np.random.Generator) -> float:
@@ -93,7 +96,7 @@ def _n_filters(rng: np.random.Generator) -> int:
 def _lambda(rng: np.random.Generator, mode: str) -> float:
     if mode != "fine":
         return 1
-    return 1 - (10 ** rng.uniform(-17, -1))
+    return 1 - (10 ** rng.uniform(-12, -1))
 
 
 def _config(rng: np.random.Generator, mode: str) -> dict:
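The narrowed _lambda range keeps the decay factor in (0.9, 1): with u ~ Uniform(-12, -1), the term 10 ** u lies in (1e-12, 0.1), so 1 - 10 ** u runs from 0.9 (noticeable decay) up to about 1 - 1e-12 (effectively no decay). A quick check, again illustrative rather than from the repo:

import numpy as np

rng = np.random.default_rng(seed=0)

# 10**u falls in (1e-12, 0.1), so the sampled factor stays within (0.9, 1)
lambdas = 1 - 10 ** rng.uniform(-12, -1, size=1_000)
print(lambdas.min(), lambdas.max())  # ~0.9 at one end, just below 1 at the other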
@@ -118,7 +121,7 @@ def _config(rng: np.random.Generator, mode: str) -> dict:
         "device": "cuda",
         "learning_rate": _lr(rng, mode),
         "optimiser": "Adam",
-        "batch_size": _batch_size(rng),
+        "batch_size": _batch_size(rng, mode),
         "epochs": _epochs(rng, mode),
         "lr_lambda": _lambda(rng, mode),
         "loss": "monai.losses.TverskyLoss",
scripts/inference_example.py (8 changes: 5 additions & 3 deletions)
@@ -46,7 +46,9 @@ def _subject(config: dict, args: argparse.Namespace) -> tio.Subject:
     """
     # Load the testing subject
     if args.test:
-        with open("train_output/test_subject.pkl", "rb") as f:
+        with open(
+            str(files.script_out_dir() / "train_output" / "test_subject.pkl"), "rb"
+        ) as f:
             return pickle.load(f)
     else:
         window_size = transform.window_size(config)
@@ -60,7 +62,7 @@ def _subject(config: dict, args: argparse.Namespace) -> tio.Subject:
     crop_lookup = {
         218: (1700, 396, 296),  # 24month wt wt dvl:gfp contrast enhance
         219: (1411, 344, 420),  # 24month wt wt dvl:gfp contrast enhance
-        247: (1710, 431, 290),  # 14month het sp7 sp7+/-
+        # 247: (1710, 431, 290),  # 14month het sp7 sp7+/-
         273: (1685, 221, 286),  # 9month het sp7 sp7 het
         274: (1413, 174, 240),  # 9month hom sp7 sp7 mut
         120: (1595, 251, 398),  # 10month wt giantin giantin sib
@@ -167,7 +169,7 @@ def _make_plots(
     # Convert the image to a 3d numpy array - for plotting
     image = subject[tio.IMAGE][tio.DATA].squeeze().numpy()
 
-    out_dir = pathlib.Path("inference/")
+    out_dir = files.script_out_dir() / "inference"
     if not out_dir.exists():
         out_dir.mkdir()
 
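Both changes above route output through files.script_out_dir(), whose implementation is not shown in this diff. As a hypothetical sketch only: such a helper typically anchors output paths to the source tree rather than the current working directory, so results land in the same place however the script is launched.

import pathlib


def script_out_dir() -> pathlib.Path:
    # Hypothetical stand-in for the repo's files.script_out_dir():
    # resolve relative to this file, not the caller's working directory.
    return pathlib.Path(__file__).resolve().parent


out_dir = script_out_dir() / "inference"
out_dir.mkdir(parents=True, exist_ok=True)  # unlike bare mkdir(), also creates parents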
userconf.yml (10 changes: 5 additions & 5 deletions)
@@ -26,8 +26,8 @@ loss: "monai.losses.TverskyLoss"
 loss_options: {
     "include_background": false,
     "to_onehot_y": true,
-    "alpha": 0.5,
-    "beta": 0.5,
+    "alpha": 0.3,
+    "beta": 0.7,
     "sigmoid": true,
 }
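In the Tversky index TP / (TP + alpha * FP + beta * FN), alpha weights false positives and beta weights false negatives, so moving from 0.5/0.5 (which reduces to Dice) to 0.3/0.7 penalises missed foreground more heavily, trading precision for recall. A minimal sketch of constructing the loss with these options (tensor shapes are illustrative, not from the repo):

import torch
from monai.losses import TverskyLoss

loss_fn = TverskyLoss(
    include_background=False,
    to_onehot_y=True,  # expand integer labels to one-hot channels
    sigmoid=True,
    alpha=0.3,         # weight on false positives
    beta=0.7,          # weight on false negatives -> favours recall
)

logits = torch.randn(2, 2, 16, 16, 16)                    # (batch, class, Z, Y, X)
labels = torch.randint(0, 2, (2, 1, 16, 16, 16)).float()  # integer class labels
print(loss_fn(logits, labels))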

@@ -44,9 +44,9 @@ test_train_seed: 1
 device: "cuda"
 window_size: "192,192,192" # Comma-separated ZYX. Needs to be large enough to hold the whole jaw
 patch_size: "160,160,160" # Bigger holds more context, smaller is faster and allows for bigger batches
-batch_size: 14
-epochs: 100
-lr_lambda: 0.9999 # Exponential decay factor (multiplicative with each epoch)
+batch_size: 6
+epochs: 400
+lr_lambda: 0.99999 # Exponential decay factor (multiplicative with each epoch)
 
 # Options should be passed
 transforms:
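The userconf.yml changes pair a smaller batch (14 -> 6) with a longer run (100 -> 400 epochs) and slow the decay to match: over 400 epochs a per-epoch factor of 0.9999 compounds to 0.9999**400 ≈ 0.961 of the initial learning rate, while 0.99999 keeps about 0.996 of it. How the training loop applies lr_lambda is not shown in this diff; assuming a plain multiplicative per-epoch decay as the comment describes, the standard PyTorch equivalent is ExponentialLR (a sketch, with a stand-in model):

import torch

model = torch.nn.Linear(4, 2)  # stand-in for the real network
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimiser, gamma=0.99999)

for epoch in range(400):
    # ... train one epoch, then decay: lr *= gamma
    scheduler.step()

print(optimiser.param_groups[0]["lr"])  # ~1e-3 * 0.99999**400 ≈ 9.96e-4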
