Commit

Fix error when using fibad with multiple gpus (#94)
drewoldag authored Oct 15, 2024
1 parent 24f2f51 commit de91389
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions src/fibad/pytorch_ignite.py
@@ -90,11 +90,6 @@ def create_engine(funcname: str, device: torch.device, model: torch.nn.Module):
         The Model the engine will be using
     """
 
-    def _extract_model_method(model, method_name):
-        # Extract `train_step` or `forward` from model, which can be wrapped after idist.auto_model(...)
-        wrapped = type(model) == DistributedDataParallel or type(model) == DataParallel
-        return getattr(model.module if wrapped else model, method_name)
-
     # This wraps a model-specific function (func) to move data to the appropriate device.
     def _inner_loop(func, device, engine, batch):
         #! This feels brittle, it would be worth revisiting this.
@@ -106,14 +101,34 @@ def _inner_loop(func, device, engine, batch):
         return func(batch)
 
     def _create_process_func(funcname, device, model):
-        inner_step = _extract_model_method(model, funcname)
+        inner_step = extract_model_method(model, funcname)
         inner_loop = functools.partial(_inner_loop, inner_step, device)
         return inner_loop
 
     model = idist.auto_model(model)
     return Engine(_create_process_func(funcname, device, model))
 
 
+def extract_model_method(model, method_name):
+    """Extract a method from a model, which may be wrapped in a DistributedDataParallel
+    or DataParallel object. For instance, method_name could be `train_step` or
+    `forward`.
+
+    Parameters
+    ----------
+    model : nn.Module, DistributedDataParallel, or DataParallel
+        The model to extract the method from
+    method_name : str
+        Name of the method to extract
+    Returns
+    -------
+    Callable
+        The method extracted from the model
+    """
+    wrapped = type(model) == DistributedDataParallel or type(model) == DataParallel
+    return getattr(model.module if wrapped else model, method_name)
+
+
 def create_evaluator(model: torch.nn.Module, save_function: Callable[[torch.Tensor], Any]) -> Engine:
     """Creates an evaluator engine
     Primary purpose of this function is to attach the appropriate handlers to an evaluator engine
@@ -176,9 +191,11 @@ def create_trainer(model: torch.nn.Module, config: ConfigDict, results_directory
     model = idist.auto_model(model)
     trainer = create_engine("train_step", device, model)
 
+    optimizer = extract_model_method(model, "optimizer")
+
     to_save = {
         "model": model,
-        "optimizer": model.optimizer,
+        "optimizer": optimizer,
         "trainer": trainer,
     }
 
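Background on the fix, with a minimal standalone sketch that is not part of the commit: once idist.auto_model(...) wraps a model in DataParallel or DistributedDataParallel for multi-GPU runs, plain attributes such as an optimizer the model sets on itself live only on the wrapped .module, so model.optimizer fails on the wrapper. The ToyModel below is hypothetical, and the helper simply mirrors the extract_model_method added above.

import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for a model that stores its own optimizer."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, x):
        return self.layer(x)


def extract_model_method(model, method_name):
    # Same unwrapping check as the helper added in this commit.
    wrapped = type(model) == DistributedDataParallel or type(model) == DataParallel
    return getattr(model.module if wrapped else model, method_name)


wrapped_model = DataParallel(ToyModel())
# wrapped_model.optimizer would raise AttributeError: the attribute lives on the
# inner module, not on the DataParallel wrapper (presumably the error this commit fixes).
optimizer = extract_model_method(wrapped_model, "optimizer")
forward_fn = extract_model_method(wrapped_model, "forward")

Moving the helper to module level is what lets create_trainer reuse it for the optimizer lookup instead of reaching for model.optimizer on a possibly wrapped model.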
