Implement validation loss using FID scores and add corresponding documentation #326

Draft: wants to merge 19 commits into base: master

Commits (19)
ee15c56
Fix: Ensure sampling only happens *after* each epoch in training loop
IndigoDosSantos Jun 2, 2024
c0d325a
Save samples for each epoch in separate folders
IndigoDosSantos Jun 2, 2024
71ae93a
Add hidden epochs folder for organizing samples
IndigoDosSantos Jun 3, 2024
57c409d
Add `torch-fidelity` for FID score calculation
IndigoDosSantos Jun 3, 2024
7b78225
Calculate FID scores for the latest epoch of generated images against…
IndigoDosSantos Jun 3, 2024
7781ee5
Calculate FID scores during training
IndigoDosSantos Jun 3, 2024
09f2ba8
Add functionality to delete the hidden epochs folder after training
IndigoDosSantos Jun 4, 2024
c7c22b3
Return the epoch_fid_scores dictionary
IndigoDosSantos Jun 5, 2024
535723b
Use scalar epoch numbers as keys in epoch_fid_scores dictionary
IndigoDosSantos Jun 5, 2024
73b3e6b
Extract validation_images_path and epochs_path to a separate method
IndigoDosSantos Jun 5, 2024
d124c3a
Add FID score tensorboard logging
IndigoDosSantos Jun 5, 2024
163ef32
Update GenericTrainer.py
IndigoDosSantos Jun 5, 2024
0b92dea
Delete epochs folder before starting training
IndigoDosSantos Jun 6, 2024
074681c
Correct FID score logging and implement JSON file usage for storing s…
IndigoDosSantos Jun 6, 2024
13a4706
Load FID scores from fid_scores.json
IndigoDosSantos Jun 6, 2024
6ad8513
Create page1.md
IndigoDosSantos Jun 6, 2024
2772188
Add wiki documentation for validation loss using FID scores
IndigoDosSantos Jun 7, 2024
01c4161
Update and rename page1.md to validation_loss.md
IndigoDosSantos Jun 7, 2024
61899ce
Merge branch 'Nerogar:master' into TensorBoard
IndigoDosSantos Jun 9, 2024
156 changes: 145 additions & 11 deletions modules/trainer/GenericTrainer.py
@@ -6,6 +6,7 @@
import traceback
from pathlib import Path
from typing import Callable
import ctypes

import torch
from PIL.Image import Image
@@ -36,6 +37,7 @@
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.time_util import get_string_timestamp
from modules.util.torch_util import torch_gc
from scripts.calculate_fid_scores import calculate_fid_scores


class GenericTrainer(BaseTrainer):
@@ -83,7 +85,19 @@ def __init__(self, config: TrainConfig, callbacks: TrainCallbacks, commands: Tra

self.grad_hook_handles = []

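    def pre_training_check(self):
        # Remove any "epochs" folder left over from a previous run in the same workspace.
        # The path is built from the configured workspace directory, matching end().
        epochs_dir = os.path.join(self.config.workspace_dir, "epochs")
        if os.path.exists(epochs_dir):
            try:
                print("Found existing 'epochs' folder. Deleting...")
                shutil.rmtree(epochs_dir)
                print("Deletion successful.")
            except Exception as e:
                print(f"Error deleting 'epochs' folder: {e}")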

def start(self):
self.pre_training_check()

if self.config.clear_cache_before_training and self.config.latent_caching:
self.__clear_cache()

@@ -230,7 +244,7 @@ def __sample_loop(
def on_sample_default(image: Image):
if self.config.samples_to_tensorboard:
self.tensorboard.add_image(f"sample{str(i)} - {safe_prompt}", pil_to_tensor(image),
train_progress.global_step)
self.callbacks.on_sample_default(image)

def on_sample_custom(image: Image):
@@ -257,18 +271,80 @@ def on_sample_custom(image: Image):

torch_gc()

        current_epoch = train_progress.epoch
        epoch_dir = os.path.join(self.config.workspace_dir, "epochs", f"class_{current_epoch}")

        for i, sample_params in enumerate(sample_params_list):
            if not sample_params.enabled:
                continue

            safe_prompt = path_util.safe_filename(sample_params.prompt)
            sample_dir = os.path.join(
                self.config.workspace_dir,
                "samples",
                f"{str(i)} - {safe_prompt}",
            )
            if not os.path.isdir(sample_dir):
                continue  # no samples exist for this prompt yet

            # Find the most recent sample generated during the current epoch.
            # Names are compared as strings, which matches chronological order
            # as long as they begin with a sortable timestamp.
            latest_sample = None
            for filename in os.listdir(sample_dir):
                if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
                    continue
                parts = filename.split("-")
                try:
                    epoch = int(parts[-2])  # the epoch number is the second-to-last part
                except (IndexError, ValueError):
                    continue  # skip files that do not follow the sample naming scheme
                if epoch == current_epoch and (latest_sample is None or filename > latest_sample):
                    latest_sample = filename

            # Copy the most recent sample from the current epoch to the epoch-specific folder
            if latest_sample is not None:
                os.makedirs(epoch_dir, exist_ok=True)
                shutil.copy2(
                    os.path.join(sample_dir, latest_sample),
                    os.path.join(epoch_dir, f"{safe_prompt}_{latest_sample}"),
                )

def get_validation_and_epochs_paths(self):
# Get the path to the "training_concepts" directory
training_concepts_dir = os.path.join(os.path.dirname(__file__), "..", "..", "training_concepts")

# Construct the path to the "concepts.json" file
concepts_file = os.path.join(training_concepts_dir, "concepts.json")

# Read the concepts.json file
with open(concepts_file, "r") as f:
concepts = json.load(f)

# Find the concept named "validation_images"
validation_images_path = None
for concept in concepts:
if concept["name"] == "validation_images":
validation_images_path = concept["path"]
break

epochs_path = None
if validation_images_path is not None:
# Get the path to the "scripts" directory
scripts_dir = os.path.join(os.path.dirname(__file__), "..", "..", "scripts")
# Add the "scripts" directory to the Python module search path
sys.path.append(scripts_dir)
# Define the epochs_path variable pointing to the hidden "epochs" directory
epochs_path = os.path.join(self.config.workspace_dir, "epochs")
else:
print("No 'validation_images' concept found in concepts.json. Skipping FID score calculation.")

return validation_images_path, epochs_path

def __sample_during_training(
self,
train_progress: TrainProgress,
train_device: torch.device,
sample_params_list: list[SampleConfig] = None,
):
validation_images_path, epochs_path = self.get_validation_and_epochs_paths()

# Special case for schedule-free optimizers.
if self.config.optimizer.optimizer.is_schedule_free:
torch.clear_autocast_cache()
self.model.optimizer.eval()
torch_gc()

torch_gc()
self.callbacks.on_update_status("sampling")

is_custom_sample = False
@@ -287,6 +363,32 @@ def __sample_during_training(
if self.model.ema:
self.model.ema.copy_ema_to(self.parameters, store_temp=True)

# Create a hidden directory to save the samples for the current epoch
epoch_sample_dir = os.path.join(self.config.workspace_dir, "epochs", f"class_{train_progress.epoch}")
os.makedirs(epoch_sample_dir, exist_ok=True)

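        # Set the "Hidden" attribute (0x02) for the "epochs" folder.
        # ctypes.windll exists only on Windows, so this step is skipped on other platforms.
        if os.name == "nt":
            ctypes.windll.kernel32.SetFileAttributesW(os.path.join(self.config.workspace_dir, "epochs"), 0x02)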

for i, sample_params in enumerate(sample_params_list):
if sample_params.enabled:
try:
safe_prompt = path_util.safe_filename(sample_params.prompt)
sample_dir = os.path.join(
self.config.workspace_dir,
"samples",
f"{str(i)} - {safe_prompt}",
)

# Create the prompt-specific folder if it doesn't exist
os.makedirs(sample_dir, exist_ok=True)

except:
traceback.print_exc()
print("Error during sampling, proceeding without sampling")

torch_gc()

self.__sample_loop(
train_progress=train_progress,
train_device=train_device,
@@ -308,7 +410,26 @@ def __sample_during_training(
folder_postfix=" - no-ema",
)

        # Calculate the FID score for the latest epoch and log all scores to TensorBoard
        if epochs_path is not None:
            # calculate_fid_scores() also stores the scores in <epochs_path>/fid_scores.json
            # and returns the same dictionary, so there is no need to re-read the file.
            fid_scores = calculate_fid_scores(validation_images_path, epochs_path)
            for epoch, fid_score in fid_scores.items():
                self.tensorboard.add_scalar("loss/validation loss", fid_score, int(epoch))  # Log to TensorBoard
        else:
            print("No 'validation_images' concept found in concepts.json. Skipping FID score calculation.")

self.model_setup.setup_train_device(self.model, self.config)

# Special case for schedule-free optimizers.
if self.config.optimizer.optimizer.is_schedule_free:
torch.clear_autocast_cache()
@@ -422,10 +543,13 @@ def save(self, train_progress: TrainProgress):

torch_gc()

-    def __needs_sample(self, train_progress: TrainProgress):
-        return self.repeating_action_needed(
-            "sample", self.config.sample_after, self.config.sample_after_unit, train_progress
-        )
+    def __needs_sample(self, train_progress: TrainProgress, is_last_epoch: bool):
+        if is_last_epoch:
+            return True
+        else:
+            return self.repeating_action_needed(
+                "sample", self.config.sample_after, self.config.sample_after_unit, train_progress
+            )

def __needs_backup(self, train_progress: TrainProgress):
return self.repeating_action_needed(
@@ -540,11 +664,6 @@ def train(self):
step_tqdm = tqdm(self.data_loader.get_data_loader(), desc="step", total=current_epoch_length,
initial=train_progress.epoch_step)
for epoch_step, batch in enumerate(step_tqdm):
-                    if self.__needs_sample(train_progress) or self.commands.get_and_reset_sample_default_command():
-                        self.__enqueue_sample_during_training(
-                            lambda: self.__sample_during_training(train_progress, train_device)
-                        )

sample_commands = self.commands.get_and_reset_sample_custom_commands()
if sample_commands:
def create_sample_commands_fun(sample_commands):
@@ -635,10 +754,21 @@ def sample_commands_fun():
return

train_progress.next_epoch()

            # next_epoch() has already advanced the counter, so the final epoch has just
            # finished when the counter equals the configured number of epochs
            # (assuming the epoch counter starts at 0).
            is_last_epoch = train_progress.epoch == self.config.epochs

            # Check if sampling is needed after the epoch is completed
            if self.__needs_sample(train_progress, is_last_epoch) or self.commands.get_and_reset_sample_default_command():
                self.__sample_during_training(train_progress, train_device)  # Directly sample

self.callbacks.on_update_train_progress(train_progress, current_epoch_length, self.config.epochs)

if self.commands.get_stop_command():
return

# Ensure sampling after the training loop
self.__execute_sample_during_training()

def end(self):
if self.one_step_trained:
@@ -671,3 +801,7 @@ def end(self):

for handle in self.grad_hook_handles:
handle.remove()

epochs_dir = os.path.join(self.config.workspace_dir, "epochs")
if os.path.exists(epochs_dir):
shutil.rmtree(epochs_dir)
1 change: 1 addition & 0 deletions requirements-global.txt
@@ -12,6 +12,7 @@ accelerate==0.30.1
safetensors==0.4.3
tensorboard==2.16.2
pytorch-lightning==2.2.5
torch-fidelity==0.3.0

# stable diffusion
-e git+https://github.com/huggingface/diffusers.git@0ab63ff#egg=diffusers
93 changes: 93 additions & 0 deletions scripts/calculate_fid_scores.py
@@ -0,0 +1,93 @@
import json
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.datasets import ImageFolder

# Set up the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the transformations for the images.
# FrechetInceptionDistance(normalize=True) expects float images in [0, 1],
# so no ImageNet mean/std normalization is applied here.
transform = transforms.Compose([
    transforms.Resize((299, 299)),  # Resize images to the input size of Inception v3
    transforms.ToTensor(),  # Convert images to float tensors in [0, 1]
])


# Custom loader function
def loader(path):
    return Image.open(path).convert('RGB')


class EpochDataset(Dataset):
    def __init__(self, epoch_path, transform=None, loader=None):
        self.epoch_path = epoch_path
        self.transform = transform
        self.loader = loader
        self.image_files = [f for f in os.listdir(epoch_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    def __getitem__(self, index):
        image_file = self.image_files[index]
        image_path = os.path.join(self.epoch_path, image_file)
        image = self.loader(image_path)
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.image_files)


def calculate_fid_scores(validation_images_path, epochs_path):
    # Load the validation images using ImageFolder with the custom loader.
    # Note: ImageFolder expects the images to sit in at least one subfolder of validation_images_path.
    validation_dataset = ImageFolder(validation_images_path, transform=transform, loader=loader)

    # Create an instance of the FrechetInceptionDistance metric
    fid = FrechetInceptionDistance(normalize=True).to(device)

    # Load existing FID scores if the file exists
    fid_scores_file = os.path.join(epochs_path, "fid_scores.json")
    if os.path.exists(fid_scores_file):
        with open(fid_scores_file, "r") as f:
            epoch_fid_scores = json.load(f)
    else:
        epoch_fid_scores = {}

    # Get the list of epoch folders ("class_<epoch>")
    epoch_folders = [folder for folder in os.listdir(epochs_path) if folder.startswith("class_")]
    if not epoch_folders:
        print("No epoch folders found. Skipping FID calculation.")
        return epoch_fid_scores

    # Get the latest epoch folder by its number.
    # A plain string sort would place "class_10" before "class_2".
    latest_epoch_folder = max(epoch_folders, key=lambda folder: int(folder.split("_")[-1]))

    # Extract the epoch number from the latest epoch folder name
    latest_epoch_number = int(latest_epoch_folder.split("_")[-1])

    # Calculate FID score only for the latest epoch
    epoch_path = os.path.join(epochs_path, latest_epoch_folder)
    # Load the generated images for the latest epoch using the custom dataset
    generated_dataset = EpochDataset(epoch_path, transform=transform, loader=loader)

    # Check if both validation and generated datasets have at least two samples
    if len(validation_dataset) >= 2 and len(generated_dataset) >= 2:
        # Calculate the FID score for the latest epoch
        fid.reset()
        fid.update(torch.stack([img.to(device) for img, _ in validation_dataset]), real=True)
        fid.update(torch.stack([img.to(device) for img in generated_dataset]), real=False)
        fid_score = fid.compute()

        # Store the FID score for the latest epoch using the epoch number as the key
        epoch_fid_scores[str(latest_epoch_number)] = fid_score.item()
    else:
        print(f"Skipping FID calculation for epoch {latest_epoch_folder} due to insufficient samples.")

    # Print the FID scores for each epoch
    for epoch, fid_score in epoch_fid_scores.items():
        print(f"Epoch {epoch}: FID score = {fid_score}")

    # Store updated FID scores in the JSON file
    with open(fid_scores_file, "w") as f:
        json.dump(epoch_fid_scores, f)

    # Return the epoch_fid_scores dictionary
    return epoch_fid_scores
39 changes: 39 additions & 0 deletions wiki_additions/validation_loss.md
@@ -0,0 +1,39 @@
# Validation Loss: Monitoring Training Progress with FID Scores

Validation loss, implemented here as the FID (Fréchet Inception Distance) score, is a metric for monitoring the training progress of your model. It complements training loss and smooth loss by indicating how the model performs on images it was not trained on.

## How it Works

- FID scores are calculated at regular intervals (currently only supported after each epoch) using a validation image set.
- The validation set is a portion[^1] of the training image set that is not used for training the LoRA model.
- The FID score measures the distance between the validation images and the images generated by the model after each epoch, so a lower score means the two sets are more similar.
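
To make the metric concrete: FID fits a Gaussian to the Inception features of each image set and computes $\lVert \mu_r - \mu_g \rVert^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)$. The sketch below shows only the one-dimensional special case, where this reduces to $(\mu_r - \mu_g)^2 + (\sigma_r - \sigma_g)^2$. The helper name `fid_1d` and the toy feature values are illustrative assumptions and are not part of this PR; the real metric works on high-dimensional Inception features.

```python
import statistics


def fid_1d(real_features, generated_features):
    """Frechet distance between two 1-D Gaussians fitted to the inputs.

    Hypothetical helper for illustration only; not part of this PR.
    """
    mu_r = statistics.fmean(real_features)
    mu_g = statistics.fmean(generated_features)
    # Population standard deviation of each feature set
    sigma_r = statistics.pstdev(real_features)
    sigma_g = statistics.pstdev(generated_features)
    return (mu_r - mu_g) ** 2 + (sigma_r - sigma_g) ** 2


real = [0.0, 1.0, 2.0, 3.0]  # made-up feature values
same = fid_1d(real, real)  # identical sets give a distance of zero
shifted = fid_1d(real, [v + 2.0 for v in real])  # same spread, mean moved by 2
```

Shifting every value by 2 leaves the spread unchanged, so only the mean term contributes and the distance is the squared shift.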

## Interpreting Validation Loss

- Lower FID scores signify greater similarity between the generated images and the validation images, indicating that the model's capability of generalizing to unseen data is improving.
- Conversely, if the validation loss (FID score) rises over several consecutive epochs, the model may be overfitting[^2] to the training images and not adapting well to unseen data.
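
One way to apply these two rules to the scores stored in `fid_scores.json` is sketched below. It assumes the file's layout of epoch numbers as string keys mapped to FID values; the helper names `best_epoch` and `looks_overfit`, the `patience` parameter, and the example scores are illustrative assumptions. FID computed on a small number of images tends to be noisy, so treat a check like this as a hint rather than a verdict.

```python
def best_epoch(fid_scores):
    """Return the epoch number (int) with the lowest FID score."""
    return int(min(fid_scores, key=fid_scores.get))


def looks_overfit(fid_scores, patience=3):
    """Return True if the score rose in each of the last `patience` epochs."""
    # Order the scores by numeric epoch, since the keys are strings
    ordered = [fid_scores[key] for key in sorted(fid_scores, key=int)]
    if len(ordered) <= patience:
        return False  # not enough history to judge
    tail = ordered[-(patience + 1):]
    return all(later > earlier for earlier, later in zip(tail, tail[1:]))


# Made-up scores in the same shape as fid_scores.json
scores = {"1": 210.0, "2": 180.5, "3": 171.2, "4": 175.9, "5": 183.4, "6": 190.1}
lowest = best_epoch(scores)
rising = looks_overfit(scores, patience=3)
```

With these example values the lowest score is at epoch 3 and the last three epochs each rise, which is the pattern the second bullet describes.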

## Benefits of Monitoring Validation Loss

Validation loss provides insights into:

- When to stop training to prevent overfitting
- Which hyperparameters to tune for optimal performance
- How well the model generalizes to unseen data

By monitoring validation loss, you can make informed decisions to improve your model's performance and ensure it learns meaningful representations of the data.

## Implementation Considerations

To utilize validation loss effectively:

1. **Validation Image Set:** Create a separate set of validation images representative of the data distribution but not used during training.
2. **Concept Configuration:** In your `concepts.json`[^3] file, define a concept named "validation_images" with the path to your validation image set. Ensure that this concept is disabled by setting `"disabled": true`, otherwise it will be included as part of the training image set.
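3. **Epochs Folder:** FID scores and generated images for each epoch are stored in an "epochs" folder within your workspace directory. The folder is marked as hidden (a Windows file attribute) and is deleted when training starts and again when training ends; the scores logged to TensorBoard remain available.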
4. **FID Score Calculation:** FID scores are calculated after each epoch and logged to TensorBoard for visualization.
5. **Customization:** You can modify the code to calculate FID scores at different intervals if needed.
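
Step 2 can be sketched as follows. The lookup mirrors the one in `get_validation_and_epochs_paths`: the `name` and `path` keys are the ones that method reads, and `disabled` is the key named in step 2. The example paths and the helper name `find_validation_path` are illustrative assumptions.

```python
import json

# Illustrative concepts.json content; the paths are placeholders.
concepts_json = """
[
    {"name": "my_subject", "path": "/data/my_subject"},
    {"name": "validation_images", "path": "/data/validation", "disabled": true}
]
"""


def find_validation_path(concepts):
    """Return the path of the concept named "validation_images", or None."""
    for concept in concepts:
        if concept.get("name") == "validation_images":
            return concept.get("path")
    return None


concepts = json.loads(concepts_json)
validation_path = find_validation_path(concepts)
```

If no concept with that name exists, the helper returns `None`, which corresponds to the case where FID score calculation is skipped.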

[^1]: Allocating 15% of your dataset to the validation set can be a good middle ground.
[^2]: Overfitting means the model has learned the representations of the training images too well, which is detrimental to its ability to generalize to new, unseen data (e.g., the validation images).
[^3]: Modify the `concepts.json` file directly or through the concepts tab in the UI, which automatically updates the file.