diff --git a/src/fibad/config_utils.py b/src/fibad/config_utils.py index 0962871..0d7a19b 100644 --- a/src/fibad/config_utils.py +++ b/src/fibad/config_utils.py @@ -268,7 +268,7 @@ def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = Non def create_results_dir(config: ConfigDict, postfix: Union[Path, str]) -> Path: """Creates a results directory for this run. - Postfix is the verb name of the run e.g. (predict, train, etc) + Postfix is the verb name of the run e.g. (infer, train, etc) The directory is created within the results dir (set with config results_dir) and follows the pattern - diff --git a/src/fibad/fibad.py b/src/fibad/fibad.py index 4d8cd13..3683c44 100644 --- a/src/fibad/fibad.py +++ b/src/fibad/fibad.py @@ -14,7 +14,7 @@ class Fibad: CLI functions in fibad_cli are implemented by calling this class """ - verbs = ["train", "predict", "download", "prepare", "rebuild_manifest"] + verbs = ["train", "infer", "download", "prepare", "rebuild_manifest"] def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool = True): """Initialize fibad. Always applies the default config, and merges it with any provided config file. @@ -170,11 +170,11 @@ def download(self, **kwargs): downloader = Downloader(config=self.config) return downloader.run(**kwargs) - def predict(self, **kwargs): + def infer(self, **kwargs): """ - See Fibad.predict.run() + See Fibad.infer.run() """ - from .predict import run + from .infer import run return run(config=self.config, **kwargs) diff --git a/src/fibad/fibad_default_config.toml b/src/fibad/fibad_default_config.toml index 258688e..e5415d8 100644 --- a/src/fibad/fibad_default_config.toml +++ b/src/fibad/fibad_default_config.toml @@ -148,7 +148,7 @@ batch_size = 32 shuffle = false num_workers = 2 -[predict] +[infer] model_weights_file = false batch_size = 32 split = "test" diff --git a/src/fibad/predict.py b/src/fibad/infer.py similarity index 86% rename from src/fibad/predict.py rename to src/fibad/infer.py index a417388..5d6161a 100644 --- a/src/fibad/predict.py +++ b/src/fibad/infer.py @@ -16,7 +16,7 @@ def run(config: ConfigDict): - """Run Prediction + """Run inference on a model using a dataset Parameters ---------- @@ -24,13 +24,13 @@ def run(config: ConfigDict): The parsed config file as a nested dict """ - data_set = setup_dataset(config, split=config["predict"]["split"]) + data_set = setup_dataset(config, split=config["infer"]["split"]) model = setup_model(config, data_set) logger.info(f"data set has length {len(data_set)}") data_loader = dist_data_loader(data_set, config) # Create a results directory and dump our config there - results_dir = create_results_dir(config, "predict") + results_dir = create_results_dir(config, "infer") log_runtime_config(config, results_dir) load_model_weights(config, model) @@ -69,13 +69,13 @@ def load_model_weights(config: ConfigDict, model): The model class to load weights into """ - weights_file = config["predict"]["model_weights_file"] + weights_file = config["infer"]["model_weights_file"] if not weights_file: - # TODO: Look at the last predict run from the rundir + # TODO: Look at the last infer (or train) run from the rundir # use config["model"]["weights_filename"] to find the weights # Proceed with those weights - raise RuntimeError("Must define pretrained_model in the predict section of fibad config.") + raise RuntimeError("Must define model_weights_file in the [infer] section of fibad config.") weights_file = Path(weights_file) diff --git a/tests/fibad/test_config_utils.py b/tests/fibad/test_config_utils.py index 0b61c72..7173e3e 100644 --- a/tests/fibad/test_config_utils.py +++ b/tests/fibad/test_config_utils.py @@ -56,7 +56,7 @@ def test_get_runtime_config(): "model_class": "new_thing.cool_model.CoolModel", "model": {"model_weights_filepath": "final_best.pth", "layers": 3}, }, - "predict": {"batch_size": 8}, + "infer": {"batch_size": 8}, } assert runtime_config == expected diff --git a/tests/fibad/test_data/test_default_config.toml b/tests/fibad/test_data/test_default_config.toml index 04b17ad..6d371ca 100644 --- a/tests/fibad/test_data/test_default_config.toml +++ b/tests/fibad/test_data/test_default_config.toml @@ -9,5 +9,5 @@ model_class = "new_thing.cool_model.CoolModel" # Use a custom model model_weights_filepath = "example_model.pth" layers = 3 -[predict] +[infer] batch_size = 32 diff --git a/tests/fibad/test_data/test_user_config.toml b/tests/fibad/test_data/test_user_config.toml index 4128f79..5e792cf 100644 --- a/tests/fibad/test_data/test_user_config.toml +++ b/tests/fibad/test_data/test_user_config.toml @@ -5,5 +5,5 @@ dev_mode = false model_weights_filepath = "final_best.pth" layers = 3 -[predict] +[infer] batch_size = 8