Skip to content

Commit

Permalink
SplitBatchRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
emailweixu committed Dec 4, 2024
1 parent d30cdb4 commit f37a993
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions alf/utils/lean_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,49 @@ def _infer_device_type(*args):
return "cuda"
else:
return device_types[0]


class SplitBatchRunner(torch.nn.Module):
    """Split the input into smaller batches and run the model on each batch.

    Note that models using random number generators (e.g. DropOut) are not supported
    for training.

    Args:
        model (nn.Module): the model to run
        max_batch_size (int): the maximum batch size to run the model. A value
            <= 0 disables splitting.
    """

    def __init__(self, model, max_batch_size):
        super().__init__()
        self._model = lean_function(model)
        self._max_batch_size = max_batch_size

    def forward(self, *args, non_batched_inputs=None, **kwargs):
        """Run the model on the input.

        Args:
            args (tuple): positional arguments for the model
            non_batched_inputs (dict): non-batched inputs for the model. They
                are passed unchanged to every sub-batch call. Defaults to an
                empty dict.
            kwargs (dict): keyword arguments for the model
        Returns:
            the model output; when the batch is split, the per-sub-batch
            outputs concatenated along dim 0.
        """
        # Use a None sentinel instead of a mutable default argument ({} as a
        # default is shared across all calls).
        if non_batched_inputs is None:
            non_batched_inputs = {}
        batch_size = alf.nest.get_nest_batch_size((args, kwargs))
        if self._max_batch_size <= 0 or batch_size <= self._max_batch_size:
            # Small enough to run in one shot; bypass the lean_function
            # wrapper and call the original forward directly.
            return self._model._original_forward_for_lean_function(
                *args, **kwargs, **non_batched_inputs)

        outputs = []
        for i in range(0, batch_size, self._max_batch_size):
            # The last slice may be shorter than _max_batch_size; python
            # slicing handles that correctly.
            batch_args, batch_kwargs = alf.nest.map_structure(
                lambda x: x[i:i + self._max_batch_size], (args, kwargs))
            outputs.append(
                self._model(*batch_args, **batch_kwargs, **non_batched_inputs))

        # Concatenate corresponding leaves of all sub-batch outputs along the
        # batch dimension.
        return alf.nest.map_structure(lambda *x: torch.cat(x, dim=0), *outputs)

    def original_forward(self, *args, **kwargs):
        """Run the model on the input without splitting the batch."""
        return self._model._original_forward_for_lean_function(*args, **kwargs)

    def __getattr__(self, name):
        # __getattr__ is only invoked when normal lookup fails. nn.Module
        # keeps submodules in self._modules (not in __dict__), so reading
        # self._model here would fail normal lookup, re-enter this method,
        # and recurse forever. Resolve through nn.Module.__getattr__ first
        # (which knows about _parameters/_buffers/_modules), then delegate
        # any remaining unknown attribute to the wrapped model.
        try:
            return super().__getattr__(name)
        except AttributeError:
            model = super().__getattr__('_model')
            return getattr(model, name)

0 comments on commit f37a993

Please sign in to comment.