diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py index cd5293f183ad..9dca31f4965d 100644 --- a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Union +from typing import Callable import PIL.Image import torch @@ -202,6 +203,8 @@ def __call__( prompt_reps: int = 20, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, ): """ Function invoked when calling the pipeline for generation. @@ -248,6 +251,12 @@ def __call__( (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Examples: Returns: @@ -335,6 +344,8 @@ def __call__( t, latents, )["prev_sample"] + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type)