From 9152debac9e3b09bd8aa512ee6e50f3ee1ab9dff Mon Sep 17 00:00:00 2001 From: Alan Bedian Date: Tue, 26 Sep 2023 01:22:33 -0700 Subject: [PATCH] Added Callback function to controlnet_blip and shap_e pipelines --- .../controlnet/pipeline_controlnet_blip_diffusion.py | 11 +++++++++++ src/diffusers/pipelines/shap_e/pipeline_shap_e.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py index e10a8624f068..ca9da9c552aa 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Union +from typing import Callable import PIL.Image import torch @@ -251,6 +252,8 @@ def __call__( prompt_reps: int = 20, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, ): """ Function invoked when calling the pipeline for generation. @@ -296,6 +299,12 @@ def __call__( to amplify the prompt. prompt_reps (`int`, *optional*, defaults to 20): The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Examples: Returns: @@ -401,6 +410,8 @@ def __call__( t, latents, )["prev_sample"] + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py index 87e756b8bd79..874ec7b21858 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -15,6 +15,7 @@ import math from dataclasses import dataclass from typing import List, Optional, Union +from typing import Callable import numpy as np import PIL.Image @@ -192,6 +193,8 @@ def __call__( frame_size: int = 64, output_type: Optional[str] = "pil", # pil, np, latent, mesh return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, ): """ The call function to the pipeline for generation. @@ -222,6 +225,12 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Examples: @@ -290,6 +299,8 @@ def __call__( timestep=t, sample=latents, ).prev_sample + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # Offload all models self.maybe_free_model_hooks()