Skip to content

Commit

Permalink
fcn to load model
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard Lane committed Nov 15, 2024
1 parent 13b4338 commit bd30ccb
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 24 deletions.
22 changes: 21 additions & 1 deletion fishjaw/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.cuda.amp import autocast, GradScaler

from .data import DataConfig
from ..util import util
from ..util import util, files


@dataclass(frozen=True)
Expand Down Expand Up @@ -461,3 +461,23 @@ def predict(
aggregator.add_batch(prediction, locations=locations)

return aggregator.get_output_tensor()[1].numpy()


def load_model(model_name: str) -> ModelState:
"""
Load a pickled model from disk given its name
:param model_name: the name of the model to load, e.g. "model_state.pkl", as specified
in userconf.yml. Must end in ".pkl".
:returns: the model
:raises FileNotFoundError: if the model file does not exist
:raises ValueError: if the model name does not end in ".pkl"
"""
if not model_name.endswith(".pkl"):
raise ValueError(f"Model name should end with .pkl: {model_name}")

with open(files.model_path({"model_path": model_name}), "rb") as f:
return pickle.load(f)
3 changes: 2 additions & 1 deletion fishjaw/util/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def model_path(config: dict) -> pathlib.Path:
configuration used to initialise the model/define the architecture and
training parameters all in one place.
:param config: the configuration, e.g. from userconf.yml
:param config: the configuration, e.g. from userconf.yml.
Just needs to have a "model_path" key
:returns: Path to the model. If the path doesn't end in .pkl, it will be appended
"""
Expand Down
26 changes: 9 additions & 17 deletions scripts/arch_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,16 @@ def replace_layers_with_tracker(net: torch.nn.Module):
replace_layers_with_tracker(layer)


def _load_model(config: dict) -> torch.nn.Module:
"""
Load the model from disk
"""

# Load the state dict
with open(str(files.model_path(config)), "rb") as f:
model_state: model.ModelState = pickle.load(f)
return model_state.load_model(set_eval=True)


def main():
def main(*, model_name: str):
"""
Load the model, read the chosen image and perform inference
Save the output image
"""
# Load the model
net = _load_model(util.userconf())
model_state: model.ModelState = model.load_model(model_name)

net = model_state.load_model(set_eval=True)
net.to("cuda")

# Print the number of trainable parameters
Expand All @@ -98,7 +88,9 @@ def main():
parser = argparse.ArgumentParser(
description="Summarise the architecture of the model"
)
parser.add_argument(
"model_name",
help="Which model to load from the models dir; e.g. 'model_state.pkl'",
)

# Not passing any arguments
parser.parse_args()
main()
main(**vars(parser.parse_args()))
6 changes: 1 addition & 5 deletions scripts/inference_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,11 @@ def main(args):
Save the output image
"""
if not args.model_name.endswith(".pkl"):
raise ValueError("Model name should end with .pkl")

if args.subject == 247:
raise RuntimeError("I think this one was in the training dataset...")

# Load the model and training-time config
with open(str(files.model_path({"model_path": args.model_name})), "rb") as f:
model_state: model.ModelState = pickle.load(f)
model_state = model.load_model(args.model_name)

config = model_state.config
net = model_state.load_model(set_eval=True)
Expand Down

0 comments on commit bd30ccb

Please sign in to comment.