Predict saves data to numpy files
- Minor rework to how data sets are handled by pytorch_ignite common code
- New .ids() public member for fibad datasets so results can be tracked by ID
- Output format is object_id.npy in numpy binary format.
mtauraso committed Sep 24, 2024
1 parent 1a5a36d commit d7f868a
Showing 6 changed files with 109 additions and 59 deletions.
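
Given the new per-object output format, predictions can be loaded back with numpy. A minimal sketch,
assuming an illustrative results path (not taken from this commit):

from pathlib import Path
import numpy as np

results_dir = Path("results/predict-20240924-131502")  # assumed path
# Map each object_id (the file stem) to its saved prediction array
predictions = {p.stem: np.load(p, allow_pickle=False) for p in results_dir.glob("*.npy")}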
2 changes: 1 addition & 1 deletion src/fibad/config_utils.py
@@ -257,7 +257,7 @@ def create_results_dir(config: ConfigDict, prefix: Union[Path, str]) -> Path:
The path created by this function
"""
results_root = Path(config["general"]["results_dir"]).resolve()
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%m%S")
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
directory = results_root / f"{prefix}-{timestamp}"
directory.mkdir(parents=True, exist_ok=False)
return directory
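
For reference, %m formats the month while %M formats minutes, so the old pattern repeated the month
in the time half of the stamp. A quick illustration:

import datetime

dt = datetime.datetime(2024, 9, 24, 13, 15, 2)
dt.strftime("%Y%m%d-%H%m%S")  # '20240924-130902' -- month where minutes belong
dt.strftime("%Y%m%d-%H%M%S")  # '20240924-131502' -- correct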

6 changes: 3 additions & 3 deletions src/fibad/data_loaders/hsc_data_loader.py
@@ -171,7 +171,7 @@ def _scan_file_dimensions(self) -> dict[str, tuple[int, int]]:
# Scan the filesystem to get the widths and heights of all images into a dict
return {
object_id: [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)]
for object_id in self._all_object_ids()
for object_id in self.ids()
}

def _prune_objects(self, filters_ref: list[str]) -> list[str]:
@@ -356,7 +356,7 @@ def _get_file(self, index: int) -> Path:
filter = filter_names[index % self.num_filters]
return self._file_to_path(filters[filter])

def _all_object_ids(self):
def ids(self):
"""Private read-only iterator over all object_ids that enforces a strict total order across
objects. Will not work prior to self.files initialization in __init__
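
Because predict pairs each saved file with an ID by position, callers rely on this ordering being
stable across iterations. A minimal sketch of the contract (data_set stands in for an initialized
fibad dataset):

ids_first = list(data_set.ids())
ids_second = list(data_set.ids())
# Strict total order: repeated iteration must yield IDs in the same sequence
assert ids_first == ids_second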
@@ -378,7 +378,7 @@ def _all_files(self):
Path
The path to the file.
"""
for object_id in self._all_object_ids():
for object_id in self.ids():
for filename in self._object_files(object_id):
yield filename

11 changes: 6 additions & 5 deletions src/fibad/models/example_autoencoder.py
@@ -100,10 +100,10 @@ def _eval_decoder(self, x):
x = CenterCrop(size=(self.image_width, self.image_height))(x)
return x

def forward(self, x):
z = self._eval_encoder(x)
x_hat = self._eval_decoder(z)
return x_hat
def forward(self, batch):
# When we run on a supervised dataset like CIFAR10, drop the labels given by the data loader
x = batch[0] if isinstance(batch, tuple) else batch
return self._eval_encoder(x)


def train_step(self, batch):
"""This function contains the logic for a single training step. i.e. the
@@ -122,7 +122,8 @@ def train_step(self, batch):
# When we run on a supervised dataset like CIFAR10, drop the labels given by the data loader
x = batch[0] if isinstance(batch, tuple) else batch

x_hat = self.forward(x)
z = self._eval_encoder(x)
x_hat = self._eval_decoder(z)
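# Note: forward() now returns only the encoder output, so train_step composes
# the encoder and decoder explicitly to obtain the reconstruction x_hat.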

loss = F.mse_loss(x, x_hat, reduction="none")
loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])

33 changes: 25 additions & 8 deletions src/fibad/predict.py
@@ -1,8 +1,11 @@
import logging
from pathlib import Path


import numpy as np
from torch import Tensor


from fibad.config_utils import ConfigDict, create_results_dir, log_runtime_config
from fibad.pytorch_ignite import create_evaluator, setup_model_and_dataloader
from fibad.pytorch_ignite import create_evaluator, dist_data_loader, setup_model_and_dataset


logger = logging.getLogger(__name__)

@@ -16,22 +19,36 @@ def run(config: ConfigDict):
The parsed config file as a nested dict
"""

model, data_loader = setup_model_and_dataloader(config)
model, data_set = setup_model_and_dataset(config)
data_loader = dist_data_loader(data_set, config)


# Create a results directory and dump our config there
results_dir = create_results_dir(config, "predict")
log_runtime_config(config, results_dir)

load_model_weights(config, model)


evaluator = create_evaluator(model)

write_index = 0
object_ids = list(data_set.ids() if hasattr(data_set, "ids") else range(len(data_set)))


def _save_batch(batch_results: Tensor):

"""Receive and write results tensors to results_dir immediately
This function writes a single numpy binary file for each object.
"""
nonlocal write_index
batch_results = batch_results.detach().to("cpu")
for tensor in batch_results:
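# Note (assumption): this index-based pairing relies on the dataloader
# preserving dataset order, i.e. shuffle disabled, so outputs match ids().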
object_id = object_ids[write_index]
filename = f"{object_id}.npy"
savepath = results_dir / filename
if savepath.exists():
raise RuntimeError(f"The path to save results for object {object_id} already exists.")
np.save(savepath, tensor.numpy(), allow_pickle=False)
write_index += 1


evaluator = create_evaluator(model, _save_batch)
evaluator.run(data_loader)

logger.info("finished evaluating...")


# Run inference across the data set...


def load_model_weights(config: ConfigDict, model):

"""Loads the model weights from a file. Raises RuntimeError if this is not possible due to
111 changes: 71 additions & 40 deletions src/fibad/pytorch_ignite.py
@@ -1,9 +1,12 @@
import functools
import logging
from typing import Any, Callable


import ignite.distributed as idist
import torch
from ignite.engine import Engine, Events
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import Dataset


from fibad.config_utils import ConfigDict
from fibad.data_loaders.data_loader_registry import fetch_data_loader_class
@@ -12,7 +15,7 @@
logger = logging.getLogger(__name__)



def setup_model_and_dataloader(config: ConfigDict) -> tuple:
def setup_model_and_dataset(config: ConfigDict) -> tuple:

"""
Construct the model and data set according to configuration.
@@ -39,87 +42,118 @@ def setup_model_and_dataloader(config: ConfigDict) -> tuple:
# Get the pytorch.dataset from the dataloader
data_set = data_loader.data_set()


return model, _dist_data_loader(data_set, config)
return model, data_set



def _dist_data_loader(data_set, config):
def dist_data_loader(data_set: Dataset, config: ConfigDict):

"""Create a Pytorch Ignite distributed data loader
Parameters
----------
data_set : Dataset
A Pytorch Dataset object
config : ConfigDict
Fibad runtime configuration
Returns
-------
Dataloader (or an ignite-wrapped equivalent)
This is the distributed dataloader, formed by calling ignite.distributed.auto_dataloader
"""
# ~ idist.auto_dataloader will accept a **kwargs parameter, and pass values
# ~ through to the underlying pytorch DataLoader.
# ~ Currently, our config includes unexpected keys like `name`, that cause
# ~ an exception. It would be nice to reduce this to:
# ~ `data_loader = idist.auto_dataloader(data_set, **config)`
data_loader = idist.auto_dataloader(
return idist.auto_dataloader(

data_set,
batch_size=config["data_loader"]["batch_size"],
shuffle=config["data_loader"]["shuffle"],
num_workers=config["data_loader"]["num_workers"],
)

return data_loader
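
For reference, a minimal usage sketch of the now-public dist_data_loader (config values assumed,
not taken from the repo defaults):

import torch
from torch.utils.data import TensorDataset

config = {"data_loader": {"batch_size": 32, "shuffle": False, "num_workers": 0}}
data_set = TensorDataset(torch.zeros(8, 3, 32, 32))
# Outside a distributed context, idist.auto_dataloader falls back to a plain DataLoader
loader = dist_data_loader(data_set, config)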


def _extract_inner_function(model, funcname):
# Extract `train_step` or `forward` from model, which can be wrapped after idist.auto_model(...)
if (
type(model) == torch.nn.parallel.DistributedDataParallel
or type(model) == torch.nn.parallel.DataParallel
):
inner_step = getattr(model.module, funcname)
else:
inner_step = getattr(model, funcname)

return inner_step
def create_engine(funcname: str, device: torch.device, model: torch.nn.Module):

"""Unified creation of the pytorch engine object for either an evaluator or trainer.
This function will automatically unwrap a distributed model to find the requested method, and wrap it
so data is moved to the appropriate device on every batch, letting model code stay the same no matter
where the model is being run.
Parameters
----------
funcname : str
The function name on the model that we will call in the core of the engine loop, and be called once
per batch
device : torch.device
The device the engine will run the model on
model : torch.nn.Module
The Model the engine will be using
"""

def _extract_model_method(model, method_name):

# Extract `train_step` or `forward` from model, which can be wrapped after idist.auto_model(...)
wrapped = isinstance(model, (DistributedDataParallel, DataParallel))
return getattr(model.module if wrapped else model, method_name)


# This wraps a model-specific function (func) to move data to the appropriate device.
def _inner_loop(func, device, engine, batch):

#! This feels brittle, it would be worth revisiting this.
# We assume that the batch data will generally have two forms.
# 1) A torch.Tensor that represents N samples.
# 2) A tuple (or list) of torch.Tensors, where the first tensor is the
# data, and the second is labels.
batch = batch.to(device) if isinstance(batch, torch.Tensor) else tuple(i.to(device) for i in batch)
return func(batch)


def _create_process_func(funcname, device, model):
inner_step = _extract_model_method(model, funcname)
inner_loop = functools.partial(_inner_loop, inner_step, device)
return inner_loop


def _create_engine_loop(funcname, device, model):
inner_step = _extract_inner_function(model, funcname)
inner_loop = functools.partial(_inner_loop, inner_step, device)
return inner_loop
model = idist.auto_model(model)
return Engine(_create_process_func(funcname, device, model))



def create_evaluator(model):
"""Based on create_trainer. This creates a pytorch ignite evaluator object with appropriate event
handlers
def create_evaluator(model: torch.nn.Module, save_function: Callable[[torch.Tensor], Any]) -> Engine:

"""Creates an evaluator engine
Primary purpose of this function is to attach the appropriate handlers to an evaluator engine
Parameters
----------
model : torch.nn.Module
The model to evaluate
save_function : Callable[[torch.Tensor], Any]
A function which will receive Engine.state.output at the end of each iteration. The intent
is for the results of evaluation to be saved.
Returns
-------
pytorch-ignite.Engine
Engine object which when run will evaluate the model.
"""
device = idist.device()
model = idist.auto_model(model)
evaluator = Engine(_create_engine_loop("forward", device, model))
evaluator = create_engine("forward", device, model)


@evaluator.on(Events.STARTED)
def log_eval_start(evaluator):
logger.info(f"Evaluating model on device: {device}")
logger.info(f"Total epochs: {evaluator.state.max_epochs}")


@evaluator.on(Events.ITERATION_COMPLETED)
def log_iteration_complete(evaluator):
save_function(evaluator.state.output)


@evaluator.on(Events.COMPLETED)
def log_total_time(evaluator):
logger.info(f"Total evaluation time: {evaluator.state.times['COMPLETED']:.2f}[s]")


return evaluator
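
Any callable that accepts the per-iteration output can serve as save_function; a minimal in-memory
collector as a sketch (model and data_loader assumed from the surrounding setup):

import torch

collected = []

def collect(batch_output: torch.Tensor):
    collected.append(batch_output.detach().to("cpu"))

evaluator = create_evaluator(model, collect)
evaluator.run(data_loader)
results = torch.cat(collected)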



def create_trainer(model):
def create_trainer(model) -> Engine:

"""This function is originally copied from here:
https://github.com/pytorch-ignite/examples/blob/main/tutorials/intermediate/cifar10-distributed.py#L164
@@ -135,12 +169,9 @@ def create_trainer(model):
pytorch-ignite.Engine
Engine object that will be used to train the model.
"""
# Get currently available device for training, and set the model to use it
device = idist.device()
# logger.info(f"Training on device: {device}")
model = idist.auto_model(model)

trainer = Engine(_create_engine_loop("train_step", device, model))
trainer = create_engine("train_step", device, model)


@trainer.on(Events.STARTED)
def log_training_start(trainer):
5 changes: 3 additions & 2 deletions src/fibad/train.py
@@ -1,7 +1,7 @@
import logging

from fibad.config_utils import create_results_dir, log_runtime_config
from fibad.pytorch_ignite import create_trainer, setup_model_and_dataloader
from fibad.pytorch_ignite import create_trainer, dist_data_loader, setup_model_and_dataset


logger = logging.getLogger(__name__)

@@ -19,7 +19,8 @@ def run(config):
results_dir = create_results_dir(config, "train")
log_runtime_config(config, results_dir)


model, data_loader = setup_model_and_dataloader(config)
model, data_set = setup_model_and_dataset(config)
data_loader = dist_data_loader(data_set, config)


# Create trainer, a pytorch-ignite `Engine` object
trainer = create_trainer(model)

