Skip to content

Commit

Permalink
Add Callback function to BLIP-Diffusion pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Skquark committed Sep 26, 2023
1 parent e032adc commit 88a5f8b
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from typing import Callable

import PIL.Image
import torch
Expand Down Expand Up @@ -202,6 +203,8 @@ def __call__(
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -248,6 +251,12 @@ def __call__(
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Returns:
Expand Down Expand Up @@ -335,6 +344,8 @@ def __call__(
t,
latents,
)["prev_sample"]
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
Expand Down

0 comments on commit 88a5f8b

Please sign in to comment.