Callback for Text2Image Pipeline, Flux pipeline fix (#1223)
Callback sample with decoding:

```cpp
auto callback = [&](size_t step, ov::Tensor& intermediate_res) -> bool {
    std::cout << "Image generation step: " << step << std::endl;
    if (step == 9) {
        ov::Tensor img = pipe.decode(intermediate_res);
        imwrite("callback_image_%d.bmp", img, true);
        return true;
    }
    return false;
};
```
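Returning `true` from the callback interrupts generation; in that case the pipeline returns an empty u8 `ov::Tensor` instead of a decoded image.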
likholat authored Nov 20, 2024
1 parent a2e1ae9 commit 40c249d
Showing 13 changed files with 135 additions and 5 deletions.
20 changes: 20 additions & 0 deletions samples/cpp/text2image/README.md
@@ -39,6 +39,26 @@ Prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk gol

![](./512x512.bmp)

## Run with callback

You can also add a callback to `main.cpp` to interrupt the image generation process early, for example once you are satisfied with an intermediate result, or to add logs.

A template for using the callback is shown below.

```cpp
auto callback = [](size_t step, ov::Tensor& intermediate_res) -> bool {
    std::cout << "Image generation step: " << step << std::endl;
    if (your_condition) // return true if you want to interrupt image generation
        return true;
    return false;
};

ov::genai::Text2ImagePipeline pipe(models_path, device);
ov::Tensor image = pipe.generate(prompt,
    ...
    ov::genai::callback(callback)
);
```
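For example, the callback can decode and save an intermediate result before stopping. The sketch below is adapted from this commit's sample: it assumes `pipe` is the `Text2ImagePipeline` created above (hence the by-reference capture) and reuses this sample's `imwrite` helper; stopping at step 9 is arbitrary.

```cpp
auto callback = [&](size_t step, ov::Tensor& intermediate_res) -> bool {
    std::cout << "Image generation step: " << step << std::endl;
    if (step == 9) {
        // decode the intermediate latent into an image tensor and save it
        ov::Tensor img = pipe.decode(intermediate_res);
        imwrite("callback_image_%d.bmp", img, true);
        return true;  // interrupt generation
    }
    return false;  // continue generation
};
```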

## Run with optional LoRA adapters

7 changes: 6 additions & 1 deletion samples/cpp/text2image/imwrite.cpp
@@ -135,7 +135,12 @@ void imwrite_single_image(const std::string& name, ov::Tensor image, bool conver


void imwrite(const std::string& name, ov::Tensor images, bool convert_bgr2rgb) {
const ov::Shape shape = images.get_shape(), img_shape = {1, shape[1], shape[2], shape[3]};
const ov::Shape shape = images.get_shape();
OPENVINO_ASSERT(images.get_element_type() == ov::element::u8 && shape.size() == 4,
"Image of u8 type and [1, H, W, 3] shape is expected.",
"Given image has shape ", shape, " and element type ", images.get_element_type());

const ov::Shape img_shape = {1, shape[1], shape[2], shape[3]};
uint8_t* img_data = images.data<uint8_t>();

for (int img_num = 0, num_images = shape[0], img_size = ov::shape_size(img_shape); img_num < num_images; ++img_num, img_data += img_size) {
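For reference, a minimal sketch of a call that passes the new check; the 512x512 size and the fill value are illustrative, and `%d` follows the sample's naming convention for writing several images:

```cpp
#include <algorithm>

// One 512x512 image in u8 NHWC layout, filled with mid-gray.
ov::Tensor image(ov::element::u8, {1, 512, 512, 3});
std::fill_n(image.data<uint8_t>(), image.get_size(), uint8_t{128});

// The last argument requests BGR to RGB conversion before writing.
imwrite("image_%d.bmp", image, true);
```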
19 changes: 19 additions & 0 deletions samples/python/text2image/README.md
@@ -39,6 +39,25 @@ Prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk gol

![](./image.bmp)

## Run with callback

You can also add a callback to `main.py` to interrupt the image generation process early, for example once you are satisfied with an intermediate result, or to add logs.

A template for using the callback is shown below.

```python
def callback(step, intermediate_res):
    print("Image generation step: ", step)
    if your_condition:  # return True if you want to interrupt image generation
        return True
    return False

pipe = openvino_genai.Text2ImagePipeline(model_dir, device)
image = pipe.generate(
    ...
    callback=callback
)
```
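The `intermediate_res` argument is a latent tensor. To inspect it as an image, the pipeline's new `decode` method (see the Python API changes below) converts a latent into an image tensor, e.g. `image = pipe.decode(intermediate_res)`.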

## Run with optional LoRA adapters

@@ -99,7 +99,9 @@ static constexpr ov::Property<float> strength{"strength"};

static constexpr ov::Property<std::shared_ptr<Generator>> generator{"generator"};

static constexpr ov::Property<size_t> max_sequence_length{"max_sequence_length"};
static constexpr ov::Property<int> max_sequence_length{"max_sequence_length"};

static constexpr ov::Property<std::function<bool(size_t, ov::Tensor&)>> callback{"callback"};

OPENVINO_GENAI_EXPORTS
std::pair<std::string, ov::Any> generation_config(const ImageGenerationConfig& generation_config);
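The new `callback` property carries a `std::function<bool(size_t, ov::Tensor&)>`: the pipeline invokes it with the current inference step and the intermediate latent, and a `true` return value interrupts generation.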
@@ -111,6 +111,8 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline {
return generate(positive_prompt, ov::AnyMap{std::forward<Properties>(properties)...});
}

ov::Tensor decode(const ov::Tensor latent);

private:
std::shared_ptr<DiffusionPipeline> m_impl;

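The new public `decode` method exposes the VAE decoder directly, so a latent obtained in a callback can be turned into an image without running generation to completion (see the samples above).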
2 changes: 2 additions & 0 deletions src/cpp/src/image_generation/diffusion_pipeline.hpp
@@ -82,6 +82,8 @@ class DiffusionPipeline {

virtual ov::Tensor generate(const std::string& positive_prompt, ov::Tensor initial_image, const ov::AnyMap& properties) = 0;

virtual ov::Tensor decode(const ov::Tensor latent) = 0;

virtual ~DiffusionPipeline() = default;

protected:
20 changes: 19 additions & 1 deletion src/cpp/src/image_generation/flux_pipeline.hpp
@@ -319,6 +319,14 @@ class FluxPipeline : public DiffusionPipeline {
std::vector<float> timesteps = m_scheduler->get_float_timesteps();
size_t num_inference_steps = timesteps.size();

// Use callback if defined
std::function<bool(size_t, ov::Tensor&)> callback;
auto callback_iter = properties.find(ov::genai::callback.name());
bool do_callback = callback_iter != properties.end();
if (do_callback) {
callback = callback_iter->second.as<std::function<bool(size_t, ov::Tensor&)>>();
}

// 6. Denoising loop
ov::Tensor timestep(ov::element::f32, {1});
float* timestep_data = timestep.data<float>();
@@ -330,12 +338,22 @@

auto scheduler_step_result = m_scheduler->step(noise_pred_tensor, latents, inference_step, generation_config.generator);
latents = scheduler_step_result["latent"];

if (do_callback) {
if (callback(inference_step, latents)) {
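// an empty u8 tensor signals that the callback interrupted generation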
return ov::Tensor(ov::element::u8, {});
}
}
}

latents = unpack_latents(latents, generation_config.height, generation_config.width, vae_scale_factor);
return m_vae->decode(latents);
}

ov::Tensor decode(const ov::Tensor latent) override {
return m_vae->decode(latent);
}

private:
void initialize_generation_config(const std::string& class_name) override {
assert(m_transformer != nullptr);
@@ -368,7 +386,7 @@
void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override {
check_image_size(generation_config.width, generation_config.height);

OPENVINO_ASSERT(generation_config.max_sequence_length < 512, "T5's 'max_sequence_length' must be less than 512");
OPENVINO_ASSERT(generation_config.max_sequence_length <= 512, "T5's 'max_sequence_length' must be less than or equal to 512");

OPENVINO_ASSERT(generation_config.negative_prompt == std::nullopt, "Negative prompt is not used by FluxPipeline");
OPENVINO_ASSERT(generation_config.negative_prompt_2 == std::nullopt, "Negative prompt 2 is not used by FluxPipeline");
18 changes: 18 additions & 0 deletions src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp
@@ -537,6 +537,14 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
// 6. Denoising loop
ov::Tensor noisy_residual_tensor(ov::element::f32, {});

// Use callback if defined
std::function<bool(size_t, ov::Tensor&)> callback;
auto callback_iter = properties.find(ov::genai::callback.name());
bool do_callback = callback_iter != properties.end();
if (do_callback) {
callback = callback_iter->second.as<std::function<bool(size_t, ov::Tensor&)>>();
}

for (size_t inference_step = 0; inference_step < timesteps.size(); ++inference_step) {
// concat the same latent twice along a batch dimension in case of CFG
if (batch_size_multiplier > 1) {
@@ -571,11 +579,21 @@

auto scheduler_step_result = m_scheduler->step(noisy_residual_tensor, latent, inference_step, generation_config.generator);
latent = scheduler_step_result["latent"];

if (do_callback) {
if (callback(inference_step, latent)) {
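// an empty u8 tensor signals that the callback interrupted generation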
return ov::Tensor(ov::element::u8, {});
}
}
}

return m_vae->decode(latent);
}

ov::Tensor decode(const ov::Tensor latent) override {
return m_vae->decode(latent);
}

private:
bool do_classifier_free_guidance(float guidance_scale) const {
return guidance_scale > 1.0;
18 changes: 18 additions & 0 deletions src/cpp/src/image_generation/stable_diffusion_pipeline.hpp
@@ -242,6 +242,14 @@ class StableDiffusionPipeline : public DiffusionPipeline {
latent_shape_cfg[0] *= batch_size_multiplier;
ov::Tensor latent_cfg(ov::element::f32, latent_shape_cfg);

// use callback if defined
std::function<bool(size_t, ov::Tensor&)> callback;
auto callback_iter = properties.find(ov::genai::callback.name());
bool do_callback = callback_iter != properties.end();
if (do_callback) {
callback = callback_iter->second.as<std::function<bool(size_t, ov::Tensor&)>>();
}

ov::Tensor denoised, noisy_residual_tensor(ov::element::f32, {});
for (size_t inference_step = 0; inference_step < timesteps.size(); inference_step++) {
batch_copy(latent, latent_cfg, 0, 0, generation_config.num_images_per_prompt);
@@ -280,11 +288,21 @@
// check whether scheduler returns "denoised" image, which should be passed to VAE decoder
const auto it = scheduler_step_result.find("denoised");
denoised = it != scheduler_step_result.end() ? it->second : latent;

if (do_callback) {
if (callback(inference_step, denoised)) {
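// an empty u8 tensor signals that the callback interrupted generation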
return ov::Tensor(ov::element::u8, {});
}
}
}

return m_vae->decode(denoised);
}

ov::Tensor decode(const ov::Tensor latent) override {
return m_vae->decode(latent);
}

private:
void initialize_generation_config(const std::string& class_name) override {
assert(m_unet != nullptr);
18 changes: 18 additions & 0 deletions src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp
@@ -405,6 +405,14 @@ class StableDiffusionXLPipeline : public DiffusionPipeline {
latent_shape_cfg[0] *= batch_size_multiplier;
ov::Tensor latent_cfg(ov::element::f32, latent_shape_cfg);

// use callback if defined
std::function<bool(size_t, ov::Tensor&)> callback;
auto callback_iter = properties.find(ov::genai::callback.name());
bool do_callback = callback_iter != properties.end();
if (do_callback) {
callback = callback_iter->second.as<std::function<bool(size_t, ov::Tensor&)>>();
}

ov::Tensor denoised, noisy_residual_tensor(ov::element::f32, {});
for (size_t inference_step = 0; inference_step < timesteps.size(); inference_step++) {
batch_copy(latent, latent_cfg, 0, 0, generation_config.num_images_per_prompt);
@@ -443,11 +451,21 @@
// check whether scheduler returns "denoised" image, which should be passed to VAE decoder
const auto it = scheduler_step_result.find("denoised");
denoised = it != scheduler_step_result.end() ? it->second : latent;

if (do_callback) {
if (callback(inference_step, denoised)) {
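// an empty u8 tensor signals that the callback interrupted generation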
return ov::Tensor(ov::element::u8, {});
}
}
}

return m_vae->decode(denoised);
}

ov::Tensor decode(const ov::Tensor latent) override {
return m_vae->decode(latent);
}

private:
void initialize_generation_config(const std::string& class_name) override {
assert(m_unet != nullptr);
4 changes: 4 additions & 0 deletions src/cpp/src/image_generation/text2image_pipeline.cpp
@@ -146,5 +146,9 @@ ov::Tensor Text2ImagePipeline::generate(const std::string& positive_prompt, cons
return m_impl->generate(positive_prompt, {}, properties);
}

ov::Tensor Text2ImagePipeline::decode(const ov::Tensor latent) {
return m_impl->decode(latent);
}

} // namespace genai
} // namespace ov
2 changes: 2 additions & 0 deletions src/python/openvino_genai/py_openvino_genai.pyi
@@ -1247,6 +1247,8 @@ class Text2ImagePipeline:
device (str): Device to run the model on (e.g., CPU, GPU).
kwargs: Device properties.
"""
def decode(self, latent: openvino._pyopenvino.Tensor) -> openvino._pyopenvino.Tensor:
...
def generate(self, prompt: str, **kwargs) -> openvino._pyopenvino.Tensor:
"""
Generates images for text-to-image models.
6 changes: 4 additions & 2 deletions src/python/py_image_generation_pipelines.cpp
@@ -146,6 +146,8 @@ ov::AnyMap text2image_kwargs_to_any_map(const py::kwargs& kwargs, bool allow_com
params.insert({ov::genai::strength(std::move(py::cast<float>(value)))});
} else if (key == "max_sequence_length") {
params.insert({ov::genai::max_sequence_length(std::move(py::cast<size_t>(value)))});
} else if (key == "callback") {
params.insert({ov::genai::callback(std::move(py::cast<std::function<bool(size_t, ov::Tensor&)>>(value)))});
}
else {
if (allow_compile_properties) {
@@ -291,6 +293,6 @@
return py::cast(pipe.generate(prompt, params));
},
py::arg("prompt"), "Input string",
(text2image_generate_docstring + std::string(" \n ")).c_str()
);
(text2image_generate_docstring + std::string(" \n ")).c_str())
.def("decode", &ov::genai::Text2ImagePipeline::decode, py::arg("latent"));;
}
