Skip to content

Commit

Permalink
Merge pull request #189 from huggingface/main
Browse files Browse the repository at this point in the history
Merge changes
  • Loading branch information
Skquark authored Nov 10, 2024
2 parents 743ca08 + dac623b commit c6441f0
Show file tree
Hide file tree
Showing 68 changed files with 5,181 additions and 2,877 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/api/models/controlnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro

## ControlNetOutput

[[autodoc]] models.controlnet.ControlNetOutput
[[autodoc]] models.controlnets.controlnet.ControlNetOutput

## FlaxControlNetModel

[[autodoc]] FlaxControlNetModel

## FlaxControlNetOutput

[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput
2 changes: 1 addition & 1 deletion docs/source/en/api/models/controlnet_sd3.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di

## SD3ControlNetOutput

[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
Expand All @@ -59,12 +59,13 @@
)
from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import (
check_min_version,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
Expand Down Expand Up @@ -1319,6 +1320,37 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)

unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)

if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)

_set_state_dict_into_text_encoder(
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
)

# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

Expand Down
156 changes: 156 additions & 0 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif

| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
Expand Down Expand Up @@ -85,6 +86,161 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion

## Example usages

### Adaptive Mask Inpainting

**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**

**Seoul National University, Naver Webtoon**

Adaptive Mask Inpainting, presented in the ECCV'24 oral paper [*Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models*](https://snuvclab.github.io/coma), is an algorithm designed to insert humans into scene images without altering the background. Traditional inpainting methods often fail to preserve object geometry and details within the masked region, leading to false affordances. Adaptive Mask Inpainting addresses this issue by progressively specifying the inpainting region over diffusion timesteps, ensuring that the inserted human integrates seamlessly with the existing scene.

Here is the demonstration of Adaptive Mask Inpainting:

<video controls>
<source src="https://snuvclab.github.io/coma/static/videos/adaptive_mask_inpainting_vis.mp4" type="video/mp4">
Your browser does not support the video tag.
</video>

![teaser-img](https://snuvclab.github.io/coma/static/images/example_result_adaptive_mask_inpainting.png)


You can find additional information about Adaptive Mask Inpainting in the [paper](https://arxiv.org/pdf/2401.12978) or in the [project website](https://snuvclab.github.io/coma).

#### Usage example
First, clone the diffusers github repository, and run the following command to set environment.
```Shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers

conda create --name ami python=3.9 -y
conda activate ami

conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y
python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
pip install easydict
pip install diffusers==0.20.2 accelerate safetensors transformers
pip install setuptools==59.5.0
pip install opencv-python
pip install numpy==1.24.1
```
Then, run the below code under 'diffusers' directory.
```python
import numpy as np
import torch
from PIL import Image

from diffusers import DDIMScheduler
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

from examples.community.adaptive_mask_inpainting import download_file, AdaptiveMaskInpaintPipeline, AMI_INSTALL_MESSAGE

print(AMI_INSTALL_MESSAGE)

from easydict import EasyDict



if __name__ == "__main__":
"""
Download Necessary Files
"""
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true",
output_file = "model_final_edd263.pkl",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true",
output_file = "pointrend_rcnn_R_50_FPN_3x_coco.yaml",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true",
output_file = "input_img.png",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true",
output_file = "input_mask.png",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true",
output_file = "Base-PointRend-RCNN-FPN.yaml",
exist_ok=True,
)
download_file(
url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true",
output_file = "Base-RCNN-FPN.yaml",
exist_ok=True,
)

"""
Prepare Adaptive Mask Inpainting Pipeline
"""
# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
num_steps = 50

# Scheduler
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False
)
scheduler.set_timesteps(num_inference_steps=num_steps)

## load models as pipelines
pipeline = AdaptiveMaskInpaintPipeline.from_pretrained(
"Uminosachi/realisticVisionV51_v51VAE-inpainting",
scheduler=scheduler,
torch_dtype=torch.float16,
requires_safety_checker=False
).to(device)

## disable safety checker
enable_safety_checker = False
if not enable_safety_checker:
pipeline.safety_checker = None

"""
Run Adaptive Mask Inpainting
"""
default_mask_image = Image.open("./input_mask.png").convert("L")
init_image = Image.open("./input_img.png").convert("RGB")


seed = 59
generator = torch.Generator(device=device)
generator.manual_seed(seed)

image = pipeline(
prompt="a man sitting on a couch",
negative_prompt="worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw",
image=init_image,
default_mask_image=default_mask_image,
guidance_scale=11.0,
strength=0.98,
use_adaptive_mask=True,
generator=generator,
enforce_full_mask_ratio=0.0,
visualization_save_dir="./ECCV2024_adaptive_mask_inpainting_demo", # DON'T CHANGE THIS!!!
human_detection_thres=0.015,
).images[0]


image.save(f'final_img.png')
```
#### [Troubleshooting]

If you run into an error `cannot import name 'cached_download' from 'huggingface_hub'` (issue [1851](https://github.com/easydiffusion/easydiffusion/issues/1851)), remove `cached_download` from the import line in the file `diffusers/utils/dynamic_modules_utils.py`.

For example, change the import line from `.../env/lib/python3.8/site-packages/diffusers/utils/dynamic_modules_utils.py`.


### Flux with CFG

Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
Expand Down
Loading

0 comments on commit c6441f0

Please sign in to comment.