Skip to content

Commit

Permalink
Renaming verb from predict to infer.
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag committed Dec 13, 2024
1 parent 3ac3130 commit e8e9cf8
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/fibad/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = Non
def create_results_dir(config: ConfigDict, postfix: Union[Path, str]) -> Path:
"""Creates a results directory for this run.
Postfix is the verb name of the run e.g. (predict, train, etc)
Postfix is the verb name of the run e.g. (infer, train, etc)
The directory is created within the results dir (set with config results_dir)
and follows the pattern <timestamp>-<postfix>
Expand Down
8 changes: 4 additions & 4 deletions src/fibad/fibad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Fibad:
CLI functions in fibad_cli are implemented by calling this class
"""

verbs = ["train", "predict", "download", "prepare", "rebuild_manifest"]
verbs = ["train", "infer", "download", "prepare", "rebuild_manifest"]

def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool = True):
"""Initialize fibad. Always applies the default config, and merges it with any provided config file.
Expand Down Expand Up @@ -170,11 +170,11 @@ def download(self, **kwargs):
downloader = Downloader(config=self.config)
return downloader.run(**kwargs)

def predict(self, **kwargs):
def infer(self, **kwargs):
"""
See Fibad.predict.run()
See Fibad.infer.run()
"""
from .predict import run
from .infer import run

Check warning on line 177 in src/fibad/fibad.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/fibad.py#L177

Added line #L177 was not covered by tests

return run(config=self.config, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ batch_size = 32
shuffle = false
num_workers = 2

[predict]
[infer]
model_weights_file = false
batch_size = 32
split = "test"
12 changes: 6 additions & 6 deletions src/fibad/predict.py → src/fibad/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@


def run(config: ConfigDict):
"""Run Prediction
"""Run inference on a model using a dataset
Parameters
----------
config : ConfigDict
The parsed config file as a nested dict
"""

data_set = setup_dataset(config, split=config["predict"]["split"])
data_set = setup_dataset(config, split=config["infer"]["split"])

Check warning on line 27 in src/fibad/infer.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/infer.py#L27

Added line #L27 was not covered by tests
model = setup_model(config, data_set)
logger.info(f"data set has length {len(data_set)}")
data_loader = dist_data_loader(data_set, config)

# Create a results directory and dump our config there
results_dir = create_results_dir(config, "predict")
results_dir = create_results_dir(config, "infer")

Check warning on line 33 in src/fibad/infer.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/infer.py#L33

Added line #L33 was not covered by tests
log_runtime_config(config, results_dir)
load_model_weights(config, model)

Expand Down Expand Up @@ -69,13 +69,13 @@ def load_model_weights(config: ConfigDict, model):
The model class to load weights into
"""
weights_file = config["predict"]["model_weights_file"]
weights_file = config["infer"]["model_weights_file"]

Check warning on line 72 in src/fibad/infer.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/infer.py#L72

Added line #L72 was not covered by tests

if not weights_file:
# TODO: Look at the last predict run from the rundir
# TODO: Look at the last infer (or train) run from the rundir
# use config["model"]["weights_filename"] to find the weights
# Proceed with those weights
raise RuntimeError("Must define pretrained_model in the predict section of fibad config.")
raise RuntimeError("Must define model_weights_file in the [infer] section of fibad config.")

Check warning on line 78 in src/fibad/infer.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/infer.py#L78

Added line #L78 was not covered by tests

weights_file = Path(weights_file)

Expand Down
2 changes: 1 addition & 1 deletion tests/fibad/test_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_get_runtime_config():
"model_class": "new_thing.cool_model.CoolModel",
"model": {"model_weights_filepath": "final_best.pth", "layers": 3},
},
"predict": {"batch_size": 8},
"infer": {"batch_size": 8},
}

assert runtime_config == expected
Expand Down
2 changes: 1 addition & 1 deletion tests/fibad/test_data/test_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ model_class = "new_thing.cool_model.CoolModel" # Use a custom model
model_weights_filepath = "example_model.pth"
layers = 3

[predict]
[infer]
batch_size = 32
2 changes: 1 addition & 1 deletion tests/fibad/test_data/test_user_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ dev_mode = false
model_weights_filepath = "final_best.pth"
layers = 3

[predict]
[infer]
batch_size = 8

0 comments on commit e8e9cf8

Please sign in to comment.