Skip to content

Commit

Permalink
[advanced_dreambooth_lora_sdxl_tranining_script] save embeddings loca…
Browse files Browse the repository at this point in the history
…lly fix (huggingface#6058)

* Update train_dreambooth_lora_sdxl_advanced.py

* remove global function args from dreamboothdataset class

* style

* style

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
apolinario and sayakpaul authored Dec 5, 2023
1 parent 53bc30d commit 6e22133
Showing 1 changed file with 69 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,26 @@ def save_model_card(
"""

trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
diffusers_example_pivotal = ""
if train_text_encoder_ti:
trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
"in you prompt with the new inserted tokens:\n"
)
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
"""
if token_abstraction_dict:
for key, value in token_abstraction_dict.items():
tokens = "".join(value)
trigger_str += f"""
to trigger concept `{key}->` use `{tokens}` in your prompt \n
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""

yaml = f"""
Expand Down Expand Up @@ -172,7 +182,21 @@ def save_model_card(
{trigger_str}
## Download model
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
{diffusers_imports_pivotal}
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
{diffusers_example_pivotal}
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
Weights for this model are available in Safetensors format.
Expand Down Expand Up @@ -791,6 +815,12 @@ def __init__(
instance_data_root,
instance_prompt,
class_prompt,
dataset_name,
dataset_config_name,
cache_dir,
image_column,
caption_column,
train_text_encoder_ti,
class_data_root=None,
class_num=None,
token_abstraction_dict=None, # token mapping for textual inversion
Expand All @@ -805,10 +835,10 @@ def __init__(
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.token_abstraction_dict = token_abstraction_dict

self.train_text_encoder_ti = train_text_encoder_ti
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
if dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
Expand All @@ -821,38 +851,37 @@ def __init__(
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
dataset_name,
dataset_config_name,
cache_dir=cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
if args.image_column is None:
if image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]

if args.caption_column is None:
if caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
"column as --caption_column"
)
self.custom_instance_prompts = None
else:
if args.caption_column not in column_names:
if caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][args.caption_column]
custom_instance_prompts = dataset["train"][caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
Expand Down Expand Up @@ -907,7 +936,7 @@ def __getitem__(self, index):
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
if args.train_text_encoder_ti:
if self.train_text_encoder_ti:
# replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in self.token_abstraction_dict.items():
caption = caption.replace(token_abs, "".join(token_replacement))
Expand Down Expand Up @@ -1093,10 +1122,10 @@ def main(args):
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

model_id = args.hub_model_id or Path(args.output_dir).name
repo_id = None
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id

# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -1464,6 +1493,12 @@ def load_model_hook(models, input_dir):
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
dataset_name=args.dataset_name,
dataset_config_name=args.dataset_config_name,
cache_dir=args.cache_dir,
image_column=args.image_column,
train_text_encoder_ti=args.train_text_encoder_ti,
caption_column=args.caption_column,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
class_num=args.num_class_images,
Expand Down Expand Up @@ -2004,23 +2039,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
}
)

if args.push_to_hub:
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/embeddings.safetensors",
)
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
token_abstraction_dict=train_dataset.token_abstraction_dict,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/embeddings.safetensors",
)
save_model_card(
model_id if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
token_abstraction_dict=train_dataset.token_abstraction_dict,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
Expand Down

0 comments on commit 6e22133

Please sign in to comment.