Skip to content

Commit

Permalink
dynamically get patch size
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard Lane committed Nov 29, 2024
1 parent 7b440ea commit 61e5476
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions scripts/arch_summary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Summarise the architecture of the model
Summarise the architecture of the model - only works for the attention unet
"""

Expand All @@ -8,7 +8,7 @@
import torch
from prettytable import PrettyTable

from fishjaw.model import model
from fishjaw.model import model, data
from monai.networks.nets import attentionunet


Expand Down Expand Up @@ -68,16 +68,17 @@ def main(*, model_name: str):
model_state: model.ModelState = model.load_model(model_name)

net = model_state.load_model(set_eval=True)
if not isinstance(net, attentionunet.AttentionUnet):
raise ValueError("This script only works for the attention unet sorry")
net.to("cuda")

# Print the number of trainable parameters
count_parameters(net)

# Track the size of the receptive field throughout the model
replace_layers_with_tracker(net)
# This should really use the architecture to find the size of the input
# At the moment it's hard coded
dummy_input = torch.randn(1, 1, 160, 160, 160).to("cuda")

dummy_input = torch.randn(1, 1, *data.get_patch_size(model_state.config)).to("cuda")
net(dummy_input)


Expand Down

0 comments on commit 61e5476

Please sign in to comment.