Skip to content

Commit

Permalink
Train some models (#31)
Browse files Browse the repository at this point in the history
Save images of inference meshes to the right place
---------

Co-authored-by: Richard Lane <[email protected]>
  • Loading branch information
richard-lane and Richard Lane authored Nov 15, 2024
1 parent 494b3fd commit 13b4338
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
13 changes: 10 additions & 3 deletions scripts/inference_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def _mesh_projections(stl_mesh: mesh.Mesh) -> plt.Figure:
return fig


def _save_mesh(segmentation: np.ndarray, subject_name: str, threshold: float) -> None:
def _save_mesh(
segmentation: np.ndarray, subject_name: str, threshold: float, out_dir: pathlib.Path
) -> None:
"""
Turn a segmentation into a mesh and save it
Expand All @@ -149,7 +151,7 @@ def _save_mesh(segmentation: np.ndarray, subject_name: str, threshold: float) ->

# Save projections
fig = _mesh_projections(stl_mesh)
fig.savefig(f"inference/{subject_name}_mesh_{threshold:.3f}_projections.png")
fig.savefig(f"{out_dir}/{subject_name}_mesh_{threshold:.3f}_projections.png")
plt.close(fig)


Expand Down Expand Up @@ -213,8 +215,13 @@ def _make_plots(

# Save the mesh
if args.mesh:

for threshold in np.arange(0.1, 1, 0.1):
_save_mesh(prediction, args.subject, threshold)
_save_mesh(prediction, prefix, threshold, out_dir)

if args.test:
# Mesh the ground truth too
_save_mesh(truth, f"{prefix}_truth", 0.5, out_dir)


def _inference(args: argparse.Namespace, net: torch.nn.Module, config: dict) -> None:
Expand Down
10 changes: 5 additions & 5 deletions userconf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dicom_dirs:

# Must end in .pkl (the model will be pickled)
# The model will be saved in the model/ directory
model_path: "alpha_001_beta_099_model.pkl"
model_path: "with_attention.pkl"

# Which ones to use for testing and validation
# all the others will be used for testing
Expand All @@ -33,8 +33,8 @@ loss: "monai.losses.TverskyLoss"
loss_options: {
"include_background": false,
"to_onehot_y": true,
"alpha": 0.01,
"beta": 0.99,
"alpha": 0.2,
"beta": 0.8,
"sigmoid": true,
}

Expand All @@ -51,8 +51,8 @@ test_train_seed: 1
device: "cuda"
window_size: "192,192,192" # Comma-separated ZYX. Needs to be large enough to hold the whole jaw
patch_size: "160,160,160" # Bigger holds more context, smaller is faster and allows for bigger batches
batch_size: 6
epochs: 400
batch_size: 12
epochs: 500
lr_lambda: 0.99999 # Exponential decay factor (multiplicative with each epoch)

# Options should be passed
Expand Down

0 comments on commit 13b4338

Please sign in to comment.