diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 41c88cbd2b1e..be6fe203d7e5 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -17,7 +17,7 @@
 import math
 import re
 import urllib.parse as ul
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import AutoModel, AutoTokenizer
@@ -629,6 +629,8 @@ def __call__(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         return_dict: bool = True,
         clean_caption: bool = True,
         max_sequence_length: int = 256,
@@ -879,6 +881,14 @@ def __call__(
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    if callback_outputs is not None:
+                        latents = callback_outputs.pop("latents", latents)
+
                 progress_bar.update()
 
         if not output_type == "latent":
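
For reviewers, a minimal usage sketch of the step-end callback this patch wires in, following the convention used by other diffusers pipelines. The checkpoint id and the logging callback below are illustrative assumptions, not part of the patch:

    import torch
    from diffusers import LuminaText2ImgPipeline

    # Illustrative checkpoint id; any Lumina text-to-image checkpoint works.
    pipe = LuminaText2ImgPipeline.from_pretrained(
        "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    def on_step_end(pipeline, step, timestep, callback_kwargs):
        # The patch passes the tensors named in callback_on_step_end_tensor_inputs
        # through callback_kwargs; here we just log the latents' norm.
        latents = callback_kwargs["latents"]
        print(f"step {step}: latent norm {latents.float().norm().item():.2f}")
        # Returning the dict lets the pipeline pick up an edited "latents" tensor,
        # matching the callback_outputs.pop("latents", latents) added above.
        return callback_kwargs

    image = pipe(
        "a photo of an astronaut riding a horse",
        callback_on_step_end=on_step_end,
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]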