diff --git a/imagen_pytorch/trainer.py b/imagen_pytorch/trainer.py
index d3f4f23..8bab483 100644
--- a/imagen_pytorch/trainer.py
+++ b/imagen_pytorch/trainer.py
@@ -491,12 +491,21 @@ def wrap_unet(self, unet_number):
 
     # hacking accelerator due to not having separate gradscaler per optimizer
 
     def set_accelerator_scaler(self, unet_number):
+        def patch_optimizer_step(accelerated_optimizer, method):
+            def patched_step(*args, **kwargs):
+                accelerated_optimizer._accelerate_step_called = True
+                return method(*args, **kwargs)
+            return patched_step
+
         unet_number = self.validate_unet_number(unet_number)
         scaler = getattr(self, f'scaler{unet_number - 1}')
 
         self.accelerator.scaler = scaler
         for optimizer in self.accelerator._optimizers:
             optimizer.scaler = scaler
+            optimizer._accelerate_step_called = False
+            optimizer._optimizer_original_step_method = optimizer.optimizer.step
+            optimizer._optimizer_patched_step_method = patch_optimizer_step(optimizer, optimizer.optimizer.step)
 
     # helper print
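
Context for the patch: when a GradScaler is attached, the scaler may silently skip the wrapped optimizer's step() if it finds inf/nan gradients, and the patched step method only flips `_accelerate_step_called` to True when the inner step actually runs. Because this trainer swaps in a different scaler per unet, it has to recreate the `_accelerate_step_called` / `_optimizer_original_step_method` / `_optimizer_patched_step_method` attributes on each wrapped optimizer. The sketch below is illustrative only, not code from this trainer or from accelerate; the helper name `step_with_scaler` is hypothetical. It shows how those three attributes can be used to detect whether a step was skipped.

    def step_with_scaler(accelerated_optimizer, scaler):
        # Hypothetical helper, for illustration only.
        # Temporarily swap in the patched step so the flag is set if (and only if)
        # the scaler actually invokes the inner optimizer's step().
        accelerated_optimizer._accelerate_step_called = False
        accelerated_optimizer.optimizer.step = accelerated_optimizer._optimizer_patched_step_method

        scaler.step(accelerated_optimizer.optimizer)
        scaler.update()

        # If the flag was never set, the scaler skipped the step (inf/nan gradients).
        step_was_skipped = not accelerated_optimizer._accelerate_step_called

        # Restore the original step method and reset the flag for the next iteration.
        accelerated_optimizer.optimizer.step = accelerated_optimizer._optimizer_original_step_method
        accelerated_optimizer._accelerate_step_called = False
        return step_was_skipped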