diff --git a/samples/cpp/text2image/README.md b/samples/cpp/text2image/README.md index 590efb4f6d..c5ffd53a84 100644 --- a/samples/cpp/text2image/README.md +++ b/samples/cpp/text2image/README.md @@ -39,6 +39,26 @@ Prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk gol ![](./512x512.bmp) +## Run with callback + +You can also add a callback to the `main.cpp` file to interrupt the image generation process earlier if you are satisfied with the intermediate result of the image generation or to add logs. + +Please find the template of the callback usage below. + +```cpp +auto callback = [](size_t step, ov::Tensor& intermediate_res) -> bool { + std::cout << "Image generation step: " << step << std::endl; + if (your_condition) // return true if you want to interrupt image generation + return true; + return false; +}; + +ov::genai::Text2ImagePipeline pipe(models_path, device); +ov::Tensor image = pipe.generate(prompt, + ... + ov::genai::callback(callback) +); +``` ## Run with optional LoRA adapters diff --git a/samples/cpp/text2image/imwrite.cpp b/samples/cpp/text2image/imwrite.cpp index b25db03051..373da7dfe8 100644 --- a/samples/cpp/text2image/imwrite.cpp +++ b/samples/cpp/text2image/imwrite.cpp @@ -135,7 +135,12 @@ void imwrite_single_image(const std::string& name, ov::Tensor image, bool conver void imwrite(const std::string& name, ov::Tensor images, bool convert_bgr2rgb) { - const ov::Shape shape = images.get_shape(), img_shape = {1, shape[1], shape[2], shape[3]}; + const ov::Shape shape = images.get_shape(); + OPENVINO_ASSERT(images.get_element_type() == ov::element::u8 && shape.size() == 4, + "Image of u8 type and [1, H, W, 3] shape is expected. ", + "Given image has shape ", shape, " and element type ", images.get_element_type()); + + const ov::Shape img_shape = {1, shape[1], shape[2], shape[3]}; uint8_t* img_data = images.data(); for (int img_num = 0, num_images = shape[0], img_size = ov::shape_size(img_shape); img_num < num_images; 
++img_num, img_data += img_size) { diff --git a/samples/python/text2image/README.md b/samples/python/text2image/README.md index bacb2e2838..9421061885 100644 --- a/samples/python/text2image/README.md +++ b/samples/python/text2image/README.md @@ -39,6 +39,25 @@ Prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk gol ![](./image.bmp) +## Run with callback + +You can also add a callback to the `main.py` file to interrupt the image generation process earlier if you are satisfied with the intermediate result of the image generation or to add logs. + +Please find the template of the callback usage below. + +```python +def callback(step, intermediate_res): + print("Image generation step: ", step) + if your_condition: # return True if you want to interrupt image generation + return True + return False + +pipe = openvino_genai.Text2ImagePipeline(model_dir, device) +image = pipe.generate( + ... + callback = callback +) +``` ## Run with optional LoRA adapters diff --git a/src/cpp/include/openvino/genai/image_generation/generation_config.hpp b/src/cpp/include/openvino/genai/image_generation/generation_config.hpp index 8478c2c12b..e798651580 100644 --- a/src/cpp/include/openvino/genai/image_generation/generation_config.hpp +++ b/src/cpp/include/openvino/genai/image_generation/generation_config.hpp @@ -99,7 +99,9 @@ static constexpr ov::Property strength{"strength"}; static constexpr ov::Property> generator{"generator"}; -static constexpr ov::Property max_sequence_length{"max_sequence_length"}; +static constexpr ov::Property max_sequence_length{"max_sequence_length"}; + +static constexpr ov::Property> callback{"callback"}; OPENVINO_GENAI_EXPORTS std::pair generation_config(const ImageGenerationConfig& generation_config); diff --git a/src/cpp/include/openvino/genai/image_generation/text2image_pipeline.hpp b/src/cpp/include/openvino/genai/image_generation/text2image_pipeline.hpp index f3a00196b9..cdfe372df9 100644 --- 
a/src/cpp/include/openvino/genai/image_generation/text2image_pipeline.hpp +++ b/src/cpp/include/openvino/genai/image_generation/text2image_pipeline.hpp @@ -111,6 +111,8 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline { return generate(positive_prompt, ov::AnyMap{std::forward(properties)...}); } + ov::Tensor decode(const ov::Tensor latent); + private: std::shared_ptr m_impl; diff --git a/src/cpp/src/image_generation/diffusion_pipeline.hpp b/src/cpp/src/image_generation/diffusion_pipeline.hpp index 7a932cdb9d..786115e2f1 100644 --- a/src/cpp/src/image_generation/diffusion_pipeline.hpp +++ b/src/cpp/src/image_generation/diffusion_pipeline.hpp @@ -82,6 +82,8 @@ class DiffusionPipeline { virtual ov::Tensor generate(const std::string& positive_prompt, ov::Tensor initial_image, const ov::AnyMap& properties) = 0; + virtual ov::Tensor decode(const ov::Tensor latent) = 0; + virtual ~DiffusionPipeline() = default; protected: diff --git a/src/cpp/src/image_generation/flux_pipeline.hpp b/src/cpp/src/image_generation/flux_pipeline.hpp index 94949cb191..20a7afd432 100644 --- a/src/cpp/src/image_generation/flux_pipeline.hpp +++ b/src/cpp/src/image_generation/flux_pipeline.hpp @@ -319,6 +319,14 @@ class FluxPipeline : public DiffusionPipeline { std::vector timesteps = m_scheduler->get_float_timesteps(); size_t num_inference_steps = timesteps.size(); + // Use callback if defined + std::function callback; + auto callback_iter = properties.find(ov::genai::callback.name()); + bool do_callback = callback_iter != properties.end(); + if (do_callback) { + callback = callback_iter->second.as>(); + } + // 6. 
Denoising loop ov::Tensor timestep(ov::element::f32, {1}); float* timestep_data = timestep.data(); @@ -330,12 +338,22 @@ class FluxPipeline : public DiffusionPipeline { auto scheduler_step_result = m_scheduler->step(noise_pred_tensor, latents, inference_step, generation_config.generator); latents = scheduler_step_result["latent"]; + + if (do_callback) { + if (callback(inference_step, latents)) { + return ov::Tensor(ov::element::u8, {}); + } + } } latents = unpack_latents(latents, generation_config.height, generation_config.width, vae_scale_factor); return m_vae->decode(latents); } + ov::Tensor decode(const ov::Tensor latent) override { + return m_vae->decode(latent); + } + private: void initialize_generation_config(const std::string& class_name) override { assert(m_transformer != nullptr); @@ -368,7 +386,7 @@ class FluxPipeline : public DiffusionPipeline { void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override { check_image_size(generation_config.width, generation_config.height); - OPENVINO_ASSERT(generation_config.max_sequence_length < 512, "T5's 'max_sequence_length' must be less than 512"); + OPENVINO_ASSERT(generation_config.max_sequence_length <= 512, "T5's 'max_sequence_length' must be less than or equal to 512"); OPENVINO_ASSERT(generation_config.negative_prompt == std::nullopt, "Negative prompt is not used by FluxPipeline"); OPENVINO_ASSERT(generation_config.negative_prompt_2 == std::nullopt, "Negative prompt 2 is not used by FluxPipeline"); diff --git a/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp index f9b08ca456..71d9fdd6ff 100644 --- a/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp +++ b/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp @@ -537,6 +537,14 @@ class StableDiffusion3Pipeline : public DiffusionPipeline { // 6. 
Denoising loop ov::Tensor noisy_residual_tensor(ov::element::f32, {}); + // Use callback if defined + std::function callback; + auto callback_iter = properties.find(ov::genai::callback.name()); + bool do_callback = callback_iter != properties.end(); + if (do_callback) { + callback = callback_iter->second.as>(); + } + for (size_t inference_step = 0; inference_step < timesteps.size(); ++inference_step) { // concat the same latent twice along a batch dimension in case of CFG if (batch_size_multiplier > 1) { @@ -571,11 +579,21 @@ class StableDiffusion3Pipeline : public DiffusionPipeline { auto scheduler_step_result = m_scheduler->step(noisy_residual_tensor, latent, inference_step, generation_config.generator); latent = scheduler_step_result["latent"]; + + if (do_callback) { + if (callback(inference_step, latent)) { + return ov::Tensor(ov::element::u8, {}); + } + } } return m_vae->decode(latent); } + ov::Tensor decode(const ov::Tensor latent) override { + return m_vae->decode(latent); + } + private: bool do_classifier_free_guidance(float guidance_scale) const { return guidance_scale > 1.0; diff --git a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp index 5009649825..ecf810827c 100644 --- a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp +++ b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp @@ -242,6 +242,14 @@ class StableDiffusionPipeline : public DiffusionPipeline { latent_shape_cfg[0] *= batch_size_multiplier; ov::Tensor latent_cfg(ov::element::f32, latent_shape_cfg); + // use callback if defined + std::function callback; + auto callback_iter = properties.find(ov::genai::callback.name()); + bool do_callback = callback_iter != properties.end(); + if (do_callback) { + callback = callback_iter->second.as>(); + } + ov::Tensor denoised, noisy_residual_tensor(ov::element::f32, {}); for (size_t inference_step = 0; inference_step < timesteps.size(); inference_step++) { 
batch_copy(latent, latent_cfg, 0, 0, generation_config.num_images_per_prompt); @@ -280,11 +288,21 @@ class StableDiffusionPipeline : public DiffusionPipeline { // check whether scheduler returns "denoised" image, which should be passed to VAE decoder const auto it = scheduler_step_result.find("denoised"); denoised = it != scheduler_step_result.end() ? it->second : latent; + + if (do_callback) { + if (callback(inference_step, denoised)) { + return ov::Tensor(ov::element::u8, {}); + } + } } return m_vae->decode(denoised); } + ov::Tensor decode(const ov::Tensor latent) override { + return m_vae->decode(latent); + } + private: void initialize_generation_config(const std::string& class_name) override { assert(m_unet != nullptr); diff --git a/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp index 07de4e6b2d..70f69d37fa 100644 --- a/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp +++ b/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp @@ -405,6 +405,14 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { latent_shape_cfg[0] *= batch_size_multiplier; ov::Tensor latent_cfg(ov::element::f32, latent_shape_cfg); + // use callback if defined + std::function callback; + auto callback_iter = properties.find(ov::genai::callback.name()); + bool do_callback = callback_iter != properties.end(); + if (do_callback) { + callback = callback_iter->second.as>(); + } + ov::Tensor denoised, noisy_residual_tensor(ov::element::f32, {}); for (size_t inference_step = 0; inference_step < timesteps.size(); inference_step++) { batch_copy(latent, latent_cfg, 0, 0, generation_config.num_images_per_prompt); @@ -443,11 +451,21 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { // check whether scheduler returns "denoised" image, which should be passed to VAE decoder const auto it = scheduler_step_result.find("denoised"); denoised = it != scheduler_step_result.end() ? 
it->second : latent; + + if (do_callback) { + if (callback(inference_step, denoised)) { + return ov::Tensor(ov::element::u8, {}); + } + } } return m_vae->decode(denoised); } + ov::Tensor decode(const ov::Tensor latent) override { + return m_vae->decode(latent); + } + private: void initialize_generation_config(const std::string& class_name) override { assert(m_unet != nullptr); diff --git a/src/cpp/src/image_generation/text2image_pipeline.cpp b/src/cpp/src/image_generation/text2image_pipeline.cpp index 24f908de55..e0ecfeb452 100644 --- a/src/cpp/src/image_generation/text2image_pipeline.cpp +++ b/src/cpp/src/image_generation/text2image_pipeline.cpp @@ -146,5 +146,9 @@ ov::Tensor Text2ImagePipeline::generate(const std::string& positive_prompt, cons return m_impl->generate(positive_prompt, {}, properties); } +ov::Tensor Text2ImagePipeline::decode(const ov::Tensor latent) { + return m_impl->decode(latent); +} + } // namespace genai } // namespace ov diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 61ab25a954..a16b74b703 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1247,6 +1247,8 @@ class Text2ImagePipeline: device (str): Device to run the model on (e.g., CPU, GPU). kwargs: Device properties. """ + def decode(self, latent: openvino._pyopenvino.Tensor) -> openvino._pyopenvino.Tensor: + ... def generate(self, prompt: str, **kwargs) -> openvino._pyopenvino.Tensor: """ Generates images for text-to-image models. 
diff --git a/src/python/py_image_generation_pipelines.cpp b/src/python/py_image_generation_pipelines.cpp index 1413d07026..f70faaca61 100644 --- a/src/python/py_image_generation_pipelines.cpp +++ b/src/python/py_image_generation_pipelines.cpp @@ -146,6 +146,8 @@ ov::AnyMap text2image_kwargs_to_any_map(const py::kwargs& kwargs, bool allow_com params.insert({ov::genai::strength(std::move(py::cast(value)))}); } else if (key == "max_sequence_length") { params.insert({ov::genai::max_sequence_length(std::move(py::cast(value)))}); + } else if (key == "callback") { + params.insert({ov::genai::callback(std::move(py::cast>(value)))}); } else { if (allow_compile_properties) { @@ -291,6 +293,6 @@ void init_image_generation_pipelines(py::module_& m) { return py::cast(pipe.generate(prompt, params)); }, py::arg("prompt"), "Input string", - (text2image_generate_docstring + std::string(" \n ")).c_str() - ); + (text2image_generate_docstring + std::string(" \n ")).c_str()) + .def("decode", &ov::genai::Text2ImagePipeline::decode, py::arg("latent")); }