diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 745b2dc4..957b1cd8 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -6,6 +6,7 @@ import traceback from pathlib import Path from typing import Callable +import ctypes import torch from PIL.Image import Image @@ -36,6 +37,7 @@ from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.time_util import get_string_timestamp from modules.util.torch_util import torch_gc +from scripts.calculate_fid_scores import calculate_fid_scores class GenericTrainer(BaseTrainer): @@ -83,7 +85,19 @@ def __init__(self, config: TrainConfig, callbacks: TrainCallbacks, commands: Tra self.grad_hook_handles = [] + def pre_training_check(self): + epochs_dir = "workspace/run/epochs" # Use relative path directly + if os.path.exists(epochs_dir): + try: + print(f"Found existing 'epochs' folder. Deleting...") + shutil.rmtree(epochs_dir) + print("Deletion successful.") + except Exception as e: + print(f"Error deleting 'epochs' folder: {e}") + def start(self): + self.pre_training_check() + if self.config.clear_cache_before_training and self.config.latent_caching: self.__clear_cache() @@ -230,7 +244,7 @@ def __sample_loop( def on_sample_default(image: Image): if self.config.samples_to_tensorboard: self.tensorboard.add_image(f"sample{str(i)} - {safe_prompt}", pil_to_tensor(image), - train_progress.global_step) + train_progress.global_step) self.callbacks.on_sample_default(image) def on_sample_custom(image: Image): @@ -257,18 +271,80 @@ def on_sample_custom(image: Image): torch_gc() + current_epoch = train_progress.epoch + + for i, sample_params in enumerate(sample_params_list): + if sample_params.enabled: + safe_prompt = path_util.safe_filename(sample_params.prompt) + sample_dir = os.path.join( + self.config.workspace_dir, + "samples", + f"{str(i)} - {safe_prompt}", + ) + + # Find the most recent sample generated during the current epoch + latest_sample = None + latest_timestamp = None + for filename in os.listdir(sample_dir): + if filename.lower().endswith((".png", ".jpg", ".jpeg")): + parts = filename.split("-") + epoch = int(parts[-2]) # Extract the epoch number from the second-to-last part + if epoch == current_epoch and (latest_timestamp is None or filename > latest_timestamp): + latest_timestamp = filename + latest_sample = filename + + # Copy the most recent sample from the current epoch to the epoch-specific folder + if latest_sample is not None: + src_path = os.path.join(sample_dir, latest_sample) + dst_path = os.path.join(os.path.join(self.config.workspace_dir, "epochs", f"class_{current_epoch}"), f"{safe_prompt}_{latest_sample}") + os.makedirs(os.path.dirname(dst_path), exist_ok=True) + shutil.copy2(src_path, dst_path) + + def get_validation_and_epochs_paths(self): + # Get the path to the "training_concepts" directory + training_concepts_dir = os.path.join(os.path.dirname(__file__), "..", "..", "training_concepts") + + # Construct the path to the "concepts.json" file + concepts_file = os.path.join(training_concepts_dir, "concepts.json") + + # Read the concepts.json file + with open(concepts_file, "r") as f: + concepts = json.load(f) + + # Find the concept named "validation_images" + validation_images_path = None + for concept in concepts: + if concept["name"] == "validation_images": + validation_images_path = concept["path"] + break + + epochs_path = None + if validation_images_path is not None: + # Get the path to the "scripts" directory + scripts_dir = os.path.join(os.path.dirname(__file__), "..", "..", "scripts") + # Add the "scripts" directory to the Python module search path + sys.path.append(scripts_dir) + # Define the epochs_path variable pointing to the hidden "epochs" directory + epochs_path = os.path.join(self.config.workspace_dir, "epochs") + else: + print("No 'validation_images' concept found in concepts.json. Skipping FID score calculation.") + + return validation_images_path, epochs_path + def __sample_during_training( self, train_progress: TrainProgress, train_device: torch.device, sample_params_list: list[SampleConfig] = None, ): + validation_images_path, epochs_path = self.get_validation_and_epochs_paths() + # Special case for schedule-free optimizers. if self.config.optimizer.optimizer.is_schedule_free: torch.clear_autocast_cache() self.model.optimizer.eval() - torch_gc() + torch_gc() self.callbacks.on_update_status("sampling") is_custom_sample = False @@ -287,6 +363,32 @@ def __sample_during_training( if self.model.ema: self.model.ema.copy_ema_to(self.parameters, store_temp=True) + # Create a hidden directory to save the samples for the current epoch + epoch_sample_dir = os.path.join(self.config.workspace_dir, "epochs", f"class_{train_progress.epoch}") + os.makedirs(epoch_sample_dir, exist_ok=True) + + # Set the "Hidden" attribute for the "epochs" folder + ctypes.windll.kernel32.SetFileAttributesW(os.path.join(self.config.workspace_dir, "epochs"), 0x02) + + for i, sample_params in enumerate(sample_params_list): + if sample_params.enabled: + try: + safe_prompt = path_util.safe_filename(sample_params.prompt) + sample_dir = os.path.join( + self.config.workspace_dir, + "samples", + f"{str(i)} - {safe_prompt}", + ) + + # Create the prompt-specific folder if it doesn't exist + os.makedirs(sample_dir, exist_ok=True) + + except: + traceback.print_exc() + print("Error during sampling, proceeding without sampling") + + torch_gc() + self.__sample_loop( train_progress=train_progress, train_device=train_device, @@ -308,7 +410,26 @@ def __sample_during_training( folder_postfix=" - no-ema", ) + # Define the relative path to the fid_scores.json file + fid_scores_file = os.path.join("workspace", "run", "epochs", "fid_scores.json") + + # Calculate and log FID score + if epochs_path is not None: + fid_scores = calculate_fid_scores(validation_images_path, epochs_path) + + # Read FID scores from the JSON file + if os.path.exists(fid_scores_file): + with open(fid_scores_file, "r") as f: + fid_scores = json.load(f) + for epoch, fid_score in fid_scores.items(): + self.tensorboard.add_scalar("loss/validation loss", fid_score, int(epoch)) # Log to TensorBoard + else: + print("FID scores JSON file not found. No scores to log to TensorBoard.") + else: + print("No 'validation_images' concept found in concepts.json. Skipping FID score calculation.") + self.model_setup.setup_train_device(self.model, self.config) + # Special case for schedule-free optimizers. if self.config.optimizer.optimizer.is_schedule_free: torch.clear_autocast_cache() @@ -422,10 +543,13 @@ def save(self, train_progress: TrainProgress): torch_gc() - def __needs_sample(self, train_progress: TrainProgress): - return self.repeating_action_needed( - "sample", self.config.sample_after, self.config.sample_after_unit, train_progress - ) + def __needs_sample(self, train_progress: TrainProgress, is_last_epoch: bool): + if is_last_epoch: + return True + else: + return self.repeating_action_needed( + "sample", self.config.sample_after, self.config.sample_after_unit, train_progress + ) def __needs_backup(self, train_progress: TrainProgress): return self.repeating_action_needed( @@ -540,11 +664,6 @@ def train(self): step_tqdm = tqdm(self.data_loader.get_data_loader(), desc="step", total=current_epoch_length, initial=train_progress.epoch_step) for epoch_step, batch in enumerate(step_tqdm): - if self.__needs_sample(train_progress) or self.commands.get_and_reset_sample_default_command(): - self.__enqueue_sample_during_training( - lambda: self.__sample_during_training(train_progress, train_device) - ) - sample_commands = self.commands.get_and_reset_sample_custom_commands() if sample_commands: def create_sample_commands_fun(sample_commands): @@ -635,10 +754,21 @@ def sample_commands_fun(): return train_progress.next_epoch() + + # Check if the current epoch is the last one + is_last_epoch = train_progress.epoch == self.config.epochs - 1 + + # Check if sampling is needed after the epoch is completed + if self.__needs_sample(train_progress, is_last_epoch) or self.commands.get_and_reset_sample_default_command(): + self.__sample_during_training(train_progress, train_device) # Directly sample + self.callbacks.on_update_train_progress(train_progress, current_epoch_length, self.config.epochs) if self.commands.get_stop_command(): return + + # Ensure sampling after the training loop + self.__execute_sample_during_training() def end(self): if self.one_step_trained: @@ -671,3 +801,7 @@ def end(self): for handle in self.grad_hook_handles: handle.remove() + + epochs_dir = os.path.join(self.config.workspace_dir, "epochs") + if os.path.exists(epochs_dir): + shutil.rmtree(epochs_dir) diff --git a/requirements-global.txt b/requirements-global.txt index 1ee92a64..a72b63dd 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -12,6 +12,7 @@ accelerate==0.30.1 safetensors==0.4.3 tensorboard==2.16.2 pytorch-lightning==2.2.5 +torch-fidelity==0.3.0 # stable diffusion -e git+https://github.com/huggingface/diffusers.git@0ab63ff#egg=diffusers diff --git a/scripts/calculate_fid_scores.py b/scripts/calculate_fid_scores.py new file mode 100644 index 00000000..0c033735 --- /dev/null +++ b/scripts/calculate_fid_scores.py @@ -0,0 +1,93 @@ +import os +import torch +import torchvision.transforms as transforms +from torchvision.datasets import ImageFolder +from torchmetrics.image.fid import FrechetInceptionDistance +from torch.utils.data import Dataset +from PIL import Image +import json + +# Set up the device (GPU if available, else CPU) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Define the transformations for the images +transform = transforms.Compose([ + transforms.Resize((299, 299)), # Resize images to the required input size of Inception v3 + transforms.ToTensor(), # Convert images to tensors + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize the images +]) + +# Custom loader function +def loader(path): + return Image.open(path).convert('RGB') + +class EpochDataset(Dataset): + def __init__(self, epoch_path, transform=None, loader=None): + self.epoch_path = epoch_path + self.transform = transform + self.loader = loader + self.image_files = [f for f in os.listdir(epoch_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + + def __getitem__(self, index): + image_file = self.image_files[index] + image_path = os.path.join(self.epoch_path, image_file) + image = self.loader(image_path) + if self.transform is not None: + image = self.transform(image) + return image + + def __len__(self): + return len(self.image_files) + +def calculate_fid_scores(validation_images_path, epochs_path): + # Load the validation images using ImageFolder with the custom loader + validation_dataset = ImageFolder(validation_images_path, transform=transform, loader=loader) + + # Create an instance of the FrechetInceptionDistance metric + fid = FrechetInceptionDistance(normalize=True).to(device) + + # Load existing FID scores if the file exists + fid_scores_file = os.path.join(epochs_path, "fid_scores.json") + if os.path.exists(fid_scores_file): + with open(fid_scores_file, "r") as f: + epoch_fid_scores = json.load(f) + else: + epoch_fid_scores = {} + + # Get the list of epoch folders sorted in ascending order + epoch_folders = sorted([folder for folder in os.listdir(epochs_path) if folder.startswith("class_")]) + + # Get the latest epoch folder + latest_epoch_folder = epoch_folders[-1] + + # Extract the epoch number from the latest epoch folder name + latest_epoch_number = int(latest_epoch_folder.split("_")[-1]) + + # Calculate FID score only for the latest epoch + epoch_path = os.path.join(epochs_path, latest_epoch_folder) + # Load the generated images for the latest epoch using the custom dataset + generated_dataset = EpochDataset(epoch_path, transform=transform, loader=loader) + + # Check if both validation and generated datasets have at least two samples + if len(validation_dataset) >= 2 and len(generated_dataset) >= 2: + # Calculate the FID score for the latest epoch + fid.reset() + fid.update(torch.stack([img.to(device) for img, _ in validation_dataset]), real=True) + fid.update(torch.stack([img.to(device) for img in generated_dataset]), real=False) + fid_score = fid.compute() + + # Store the FID score for the latest epoch using the epoch number as the key + epoch_fid_scores[str(latest_epoch_number)] = fid_score.item() + else: + print(f"Skipping FID calculation for epoch {latest_epoch_folder} due to insufficient samples.") + + # Print the FID scores for each epoch + for epoch, fid_score in epoch_fid_scores.items(): + print(f"Epoch {epoch}: FID score = {fid_score}") + + # Store updated FID scores in the JSON file + with open(fid_scores_file, "w") as f: + json.dump(epoch_fid_scores, f) + + # Return the epoch_fid_scores dictionary + return epoch_fid_scores diff --git a/wiki_additions/validation_loss.md b/wiki_additions/validation_loss.md new file mode 100644 index 00000000..7cf12d79 --- /dev/null +++ b/wiki_additions/validation_loss.md @@ -0,0 +1,39 @@ +# Validation Loss: Monitoring Training Progress with FID Scores + +Validation loss, implemented as FID (Fréchet Inception Distance) scores, is a valuable metric for monitoring the training progress of your model. It complements training loss and smooth loss by providing insights into the model's performance on unseen data. + +## How it Works + +- FID scores are calculated at regular intervals (currently only supported after each epoch) using a validation image set. +- The validation set is a portion[^1] of the training image set that is not used for training the LoRA model. +- FID scores measure the similarity between the validation images and the images generated by the model after each epoch. + +## Interpreting Validation Loss + +- Lower FID scores signify greater similarity between the generated images and the validation images, indicating that the model's capability of generalizing to unseen data is improving. +- Conversely, if the validation loss (FID score) continues to rise consistently, it suggests that the model may be overfitting[^2] to the training images and not adapting well to unseen data. + +## Benefits of Monitoring Validation Loss + +Validation loss provides insights into: + +- When to stop training to prevent overfitting +- Which hyperparameters to tune for optimal performance +- How well the model generalizes to unseen data + +By monitoring validation loss, you can make informed decisions to improve your model's performance and ensure it learns meaningful representations of the data. + +## Implementation Considerations + +To utilize validation loss effectively: + +1. **Validation Image Set:** Create a separate set of validation images representative of the data distribution but not used during training. +2. **Concept Configuration:** In your `concepts.json`[^3] file, define a concept named "validation_images" with the path to your validation image set. Ensure that this concept is disabled by setting `"disabled": true`, otherwise it will be included as part of the training image set. +3. **Epochs Folder:** FID scores and generated images for each epoch are stored in a hidden "epochs" folder within your workspace directory. +4. **FID Score Calculation:** FID scores are calculated after each epoch and logged to TensorBoard for visualization. +5. **Customization:** You can modify the code to calculate FID scores at different intervals if needed. + +[^1]: Allocating 15% of your dataset to the validation set can be a good middle ground. +[^2]: Overfitting means the model has learned the representations of the training images too well, which is detrimental to its ability to generalize to new, unseen data (e.g., the validation images). +[^3]: Modify the `concepts.json` file directly or through the concepts tab in the UI, which automatically updates the file. +