Added missing callback_on_step_end to Lumina
Skquark committed Jul 8, 2024
1 parent 3e7a4e1 commit b0dcf1a
Showing 1 changed file with 11 additions and 1 deletion.
src/diffusers/pipelines/lumina/pipeline_lumina.py: 12 changes (11 additions & 1 deletion)
@@ -17,7 +17,7 @@
 import math
 import re
 import urllib.parse as ul
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import AutoModel, AutoTokenizer
@@ -629,6 +629,8 @@ def __call__(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         return_dict: bool = True,
         clean_caption: bool = True,
         max_sequence_length: int = 256,
@@ -879,6 +881,14 @@ def __call__(
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    if callback_outputs is not None:
+                        latents = callback_outputs.pop("latents", latents)
+
                 progress_bar.update()

         if not output_type == "latent":
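
For context, below is a minimal usage sketch (not part of the commit) of the new argument. It assumes the LuminaText2ImgPipeline class exported by diffusers and an illustrative checkpoint name. Per the call site in the diff, the callback is invoked with the pipeline, the step index, the timestep, and a dict holding the tensors named in callback_on_step_end_tensor_inputs, and it may return a dict whose "latents" entry replaces the current latents.

import torch
from diffusers import LuminaText2ImgPipeline

# Illustrative checkpoint name; substitute whichever Lumina checkpoint you use.
pipe = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

def log_step(pipeline, step, timestep, callback_kwargs):
    # callback_kwargs contains the tensors requested via
    # callback_on_step_end_tensor_inputs (only "latents" by default).
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent mean {latents.mean().item():.4f}")
    # Returning a dict lets the pipeline pick up a (possibly modified) "latents";
    # with this patch, returning None leaves the latents unchanged.
    return {"latents": latents}

image = pipe(
    "A photo of a corgi wearing a wizard hat",
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]

As the added block shows, only the tensors listed in callback_on_step_end_tensor_inputs are exposed to the callback, and only "latents" is read back from the returned dict.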
