Skip to content

Commit

Permalink
Added Callback function to controlnet_blip and shap_e pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
Skquark committed Sep 26, 2023
1 parent 88a5f8b commit 9152deb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from typing import Callable

import PIL.Image
import torch
Expand Down Expand Up @@ -251,6 +252,8 @@ def __call__(
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -296,6 +299,12 @@ def __call__(
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Returns:
Expand Down Expand Up @@ -401,6 +410,8 @@ def __call__(
t,
latents,
)["prev_sample"]
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

Expand Down
11 changes: 11 additions & 0 deletions src/diffusers/pipelines/shap_e/pipeline_shap_e.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import Callable

import numpy as np
import PIL.Image
Expand Down Expand Up @@ -192,6 +193,8 @@ def __call__(
frame_size: int = 64,
output_type: Optional[str] = "pil", # pil, np, latent, mesh
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
The call function to the pipeline for generation.
Expand Down Expand Up @@ -222,6 +225,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain
tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Expand Down Expand Up @@ -290,6 +299,8 @@ def __call__(
timestep=t,
sample=latents,
).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# Offload all models
self.maybe_free_model_hooks()
Expand Down

0 comments on commit 9152deb

Please sign in to comment.