flux img2img controlnet channels error #9979

Open
wen020 opened this issue Nov 21, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@wen020

wen020 commented Nov 21, 2024

Describe the bug

When I use flux's img2img controlnet for inference, a channel error occurs.

Reproduction

import numpy as np
import torch
import cv2
from PIL import Image
from diffusers.utils import load_image
from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetPipeline
from diffusers import FluxControlNetModel
from controlnet_aux import HEDdetector

base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model = "Xlabs-AI/flux-controlnet-hed-diffusers"
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
    base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("./toonystarkKoreanWebtoonFlux_fluxLoraAlpha.safetensors")

pipe.enable_sequential_cpu_offload()

hed = HEDdetector.from_pretrained("lllyasviel/Annotators")

# Build the HED control image and make sure it is RGB at the source size
image_source = load_image("./03.jpeg")
control_image = hed(image_source)
control_image = control_image.resize(image_source.size)
if control_image.mode != 'RGB':
    control_image = control_image.convert('RGB')
control_image.save("./hed_03.png")

prompt = "bird, cool, futuristic"
image = pipe(
    prompt,
    image=image_source,
    control_image=control_image,
    control_guidance_start=0.2,
    control_guidance_end=0.8,
    controlnet_conditioning_scale=0.5,
    num_inference_steps=50,
    guidance_scale=6,
).images[0]
image.save("flux.png")

Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 2
      1 prompt = "bird, cool, futuristic"
----> 2 image = pipe(
      3     prompt,
      4     image=image_source,
      5     control_image=control_image,
      6     control_guidance_start=0.2,
      7     control_guidance_end=0.8,
      8     controlnet_conditioning_scale=0.5,
      9     num_inference_steps=50,
     10     guidance_scale=6,
     11 ).images[0]
     12 image.save("flux.png")

File /opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py:924, in FluxControlNetImg2ImgPipeline.__call__(self, prompt, prompt_2, image, control_image, height, width, strength, num_inference_steps, timesteps, guidance_scale, control_guidance_start, control_guidance_end, control_mode, controlnet_conditioning_scale, num_images_per_prompt, generator, latents, prompt_embeds, pooled_prompt_embeds, output_type, return_dict, joint_attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    921         controlnet_cond_scale = controlnet_cond_scale[0]
    922     cond_scale = controlnet_cond_scale * controlnet_keep[i]
--> 924 controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
    925     hidden_states=latents,
    926     controlnet_cond=control_image,
    927     controlnet_mode=control_mode,
    928     conditioning_scale=cond_scale,
    929     timestep=timestep / 1000,
    930     guidance=guidance,
    931     pooled_projections=pooled_prompt_embeds,
    932     encoder_hidden_states=prompt_embeds,
    933     txt_ids=text_ids,
    934     img_ids=latent_image_ids,
    935     joint_attention_kwargs=self.joint_attention_kwargs,
    936     return_dict=False,
    937 )
    939 guidance = (
    940     torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
    941 )
    942 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.11/site-packages/diffusers/models/controlnets/controlnet_flux.py:281, in FluxControlNetModel.forward(self, hidden_states, controlnet_cond, controlnet_mode, conditioning_scale, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids, guidance, joint_attention_kwargs, return_dict)
    278 hidden_states = self.x_embedder(hidden_states)
    280 if self.input_hint_block is not None:
--> 281     controlnet_cond = self.input_hint_block(controlnet_cond)
    282     batch_size, channels, height_pw, width_pw = controlnet_cond.shape
    283     height = height_pw // self.config.patch_size

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/diffusers/models/controlnets/controlnet.py:99, in ControlNetConditioningEmbedding.forward(self, conditioning)
     98 def forward(self, conditioning):
---> 99     embedding = self.conv_in(conditioning)
    100     embedding = F.silu(embedding)
    102     for block in self.blocks:

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/conv.py:460, in Conv2d.forward(self, input)
    459 def forward(self, input: Tensor) -> Tensor:
--> 460     return self._conv_forward(input, self.weight, self.bias)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[1, 1, 4096, 64] to have 3 channels, but got 1 channels instead

System Info

latest diffusers

Who can help?

@yiyixuxu @sayakpaul

wen020 added the bug (Something isn't working) label Nov 21, 2024
@deepak-lenka

The error occurs because the control image has only 1 channel (grayscale) while the Flux ControlNet model expects a 3-channel (RGB) image, causing a channel mismatch during the image-to-image generation process.

@wen020
Author

wen020 commented Nov 21, 2024

I found that the image has 3 channels

@wen020
Author

wen020 commented Nov 21, 2024

I checked that the control image is 3-channel, but the problem still exists
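
A note on the shape in the error message: input[1, 1, 4096, 64] does not look like a grayscale image. It matches what a Flux-style packed latent tensor of shape (1, 4096, 64) would look like when handed to a Conv2d layer. That would be consistent with the control image being VAE-encoded and packed before it reaches the ControlNet's input_hint_block, which the traceback shows applying a convolution with a [16, 3, 3, 3] weight. This is only a hypothesis and is not confirmed anywhere in this thread. The sketch below is plain arithmetic, not diffusers code. The helper names are made up for illustration, and the 8x VAE downsampling, 16 latent channels, 2x2 patch packing, and 1024x1024 resolution are all assumptions.

```python
def packed_latent_shape(height, width, batch=1, vae_scale=8,
                        latent_channels=16, patch=2):
    """Shape of Flux-style packed latents for an image of the given pixel size.

    Assumptions (not taken from the issue): the VAE downsamples by 8,
    produces 16 latent channels, and latents are grouped into 2x2 patches.
    """
    lat_h, lat_w = height // vae_scale, width // vae_scale
    tokens = (lat_h // patch) * (lat_w // patch)    # sequence length
    features = latent_channels * patch * patch      # channels per token
    return (batch, tokens, features)


def as_conv2d_reports_it(shape):
    """Hypothetical helper: a 3-D tensor passed to Conv2d is treated as an
    unbatched (C, H, W) input, so the error shows a leading batch dim of 1."""
    return (1, *shape) if len(shape) == 3 else tuple(shape)


shape = packed_latent_shape(1024, 1024)
print(shape)                        # (1, 4096, 64)
print(as_conv2d_reports_it(shape))  # (1, 1, 4096, 64)
```

If this reading is right, converting the control image to RGB would not change anything, because the single "channel" in the message is the batch dimension of the packed tensor rather than an image channel.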
