diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index f9d057c5..81395580 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -36,9 +36,6 @@ jobs: runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"] steps: - uses: actions/checkout@v3 - - name: Install Ubuntu dependencies - run: | - sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y - name: Install dependencies run: | pip install -e . @@ -54,7 +51,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ python3 -m pytest + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/README.md b/README.md index ea88ca8a..b7dde6aa 100644 --- a/README.md +++ b/README.md @@ -88,14 +88,14 @@ After installation completes, run the training script. ```bash export LIBTPU_INIT_ARGS="" - python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash + python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash ``` -- **Stable Diffusion 1.5** +- **Stable Diffusion 1.4** ```bash export LIBTPU_INIT_ARGS="" - python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base15.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash + python -m src.maxdiffusion.train src/maxdiffusion/configs/base14.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash ``` To generate images with a trained checkpoint, run: @@ -109,7 +109,7 @@ After installation completes, run the training script. 
**Stable Diffusion 1.x,2.x** ```bash - python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base15.yml class_data_dir= instance_data_dir= instance_prompt="a photo of ohwx dog" class_prompt="photo of a dog" max_train_steps=150 jax_cache_dir= class_prompt="a photo of a dog" activations_dtype=bfloat16 weights_dtype=float32 per_device_batch_size=1 enable_profiler=False precision=DEFAULT cache_dreambooth_dataset=False learning_rate=4e-6 num_class_images=100 run_name= output_dir=gs:// + python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base14.yml class_data_dir= instance_data_dir= instance_prompt="a photo of ohwx dog" class_prompt="photo of a dog" max_train_steps=150 jax_cache_dir= activations_dtype=bfloat16 weights_dtype=float32 per_device_batch_size=1 enable_profiler=False precision=DEFAULT cache_dreambooth_dataset=False learning_rate=4e-6 num_class_images=100 run_name= output_dir=gs:// ``` ## Inference @@ -143,7 +143,7 @@ To generate images, run the following command: Single and Multi host inference is supported with sharding annotations: ```bash - python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors" + python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl_lightning.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors" ``` ## ControlNet @@ -176,7 +176,7 @@ cd maxdiffusion pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip3 install -r requirements.txt pip3 install . -python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run output_dir=gs://your-bucket/" +python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run output_dir=gs://your-bucket/ ``` # Comparison to Alternatives diff --git a/docs/train_README.md b/docs/train_README.md index 5b6adcd8..74ff8ce4 100644 --- a/docs/train_README.md +++ b/docs/train_README.md @@ -28,7 +28,7 @@ In this session, we'll explain some of the core config parameters and how they a | config | model | supports | | ------ | ----- | -------- | -| [base15.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base15.yml) | [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) | training / inference +| [base14.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base14.yml) | [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) | training / inference | [base_2_base.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base_2_base.yml) | [stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) | training / inference | [base21.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base21.yml) | [stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) | training / inference | [base_xl.yml](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base_xl.yml) | [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | training / inference @@ -40,10 +40,10 @@ Let's start with a simple example.
After setting up your environment, create a t ```bash export LIBTPU_INIT_ARGS="" - python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base15.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash + python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash ``` -The job will use the predefined parameters in base15.yml and will overwrite any parameters that as passed into the cli. +The job will use the predefined parameters in base14.yml, and any parameters passed on the CLI will overwrite them. ### Changing The Base Model @@ -53,7 +53,7 @@ To load Pytorch weights, set `from_pt=True` and set `revision=main`. Let's look ```bash export LIBTPU_INIT_ARGS="" - python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base15.yml run_name="my_run" output_dir="gs://your-bucket/" pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 from_pt=True revision=main + python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml run_name="my_run" output_dir="gs://your-bucket/" pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 from_pt=True revision=main ``` After training, a new folder structure with weights and metrics has been created under the `output_dir` folder: @@ -71,7 +71,7 @@ It is recommended to use a Google Cloud Storage bucket as the `output_dir`. This To use the trained checkpoint, then run: ```bash - python src/maxdiffusion/generate.py src/maxdiffusion/configs/base15.yml output_dir="gs://your-bucket/" run_name="my_run" + python src/maxdiffusion/generate.py src/maxdiffusion/configs/base14.yml output_dir="gs://your-bucket/" run_name="my_run" ``` @@ -79,7 +79,7 @@ To use the trained checkpoint, then run: MaxDiffusion models use logical axis annotations, which allows users to explore different sharding layouts without making changes to the model code.
To learn more about distributed arrays and Flax partitioning, checkout JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and then FLAX's [Scale up Flax Modules on multiple devices](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html#flax-and-jax-jit-scaled-up) -The main [config values](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base15.yml#L74) for these are: +The main [config values](https://github.com/google/maxdiffusion/blob/main/src/maxdiffusion/configs/base14.yml#L74) for these are: - mesh_axes - logical_axis_rules diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index 4c72efda..d2d0ab61 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -226,7 +226,7 @@ def load_diffusers_checkpoint(self): return pipeline, params - def save_checkpoint(self, train_step, pipeline, params, train_states, save_inference_states=False): + def save_checkpoint(self, train_step, pipeline, params, train_states): def config_to_json(model_or_config): return json.loads(model_or_config.to_json_string()) @@ -248,35 +248,6 @@ def config_to_json(model_or_config): tokenizer_config = {"path" : self.config.tokenizer_model_name_or_path} items["tokenizer_config"] = ocp.args.JsonSave(tokenizer_config) - if save_inference_states: - inference_unet_state, _, _ = self.create_unet_state( - pipeline, - {"unet" : train_states["unet_state"].params}, - checkpoint_item_name="inference_unet_state", - is_training=False, - ) - items["inference_unet_state"] = ocp.args.StandardSave(inference_unet_state) - - if self.config.train_text_encoder: - inference_text_encoder_state, _ = self.create_text_encoder_state( - pipeline, - {"text_encoder" : train_states["text_encoder_state"].params}, - checkpoint_item_name="inference_text_encoder_state", - is_training=False) - - items["inference_text_encoder_state"] = ocp.args.StandardSave(inference_text_encoder_state) - - # TODO - this is broken since create_text_encoder_state will create a text_encoder init weights - # and not a text_encoder_2 init weights fn. 
- if hasattr(pipeline, "text_encoder_2"): - inference_text_encoder_2_state, _ = self.create_text_encoder_2_state( - pipeline, - {"text_encoder" : train_states["text_encoder_2_state"].params}, - checkpoint_item_name="inference_text_encoder_2_state", - is_training=False - ) - items["inference_text_encoder_2_state"] = ocp.args.StandardSave(inference_text_encoder_2_state) - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) def load_params(self, step = None): diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index 2783be20..9d3de8af 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -18,6 +18,7 @@ """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" from typing import Optional, Any +import os from maxdiffusion import max_logging from etils import epath from flax.training import train_state @@ -55,17 +56,14 @@ def create_orbax_checkpoint_manager( "text_encoder_config", "scheduler_config", "unet_state", - "inference_unet_state", "vae_state", "text_encoder_state", - "inference_text_encoder_state", "tokenizer_config" ) if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT: item_names+= ( "text_encoder_2_state", "text_encoder_2_config", - "inference_text_encoder_2_state" ) print("item_names: ", item_names) @@ -120,6 +118,34 @@ def load_stable_diffusion_configs( args=orbax.checkpoint.args.Composite(**restore_args) ),None) +def load_params_from_path( + config, + checkpoint_manager: CheckpointManager, + unboxed_abstract_params, + checkpoint_item : str, + step: Optional[int] = None, +): + ckptr = ocp.PyTreeCheckpointer() + + if step is None: + step = checkpoint_manager.latest_step() + if step is None: + return None + + ckpt_path = os.path.join(config.checkpoint_dir, str(step),checkpoint_item) + ckpt_path = epath.Path(ckpt_path) + + restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params) + restored = ckptr.restore( + ckpt_path, + item={"params": unboxed_abstract_params}, + transforms={}, + restore_args={"params" : restore_args} + ) + return restored["params"] + + + def load_state_if_possible( checkpoint_manager: CheckpointManager, abstract_unboxed_pre_state: train_state.TrainState, diff --git a/src/maxdiffusion/configs/base15.yml b/src/maxdiffusion/configs/base14.yml similarity index 98% rename from src/maxdiffusion/configs/base15.yml rename to src/maxdiffusion/configs/base14.yml index a950f39a..b5f938c1 100644 --- a/src/maxdiffusion/configs/base15.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -23,7 +23,7 @@ gcs_metrics: True save_config_to_gcs: False log_period: 10000000000 # Flushes Tensorboard -pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5' +pretrained_model_name_or_path: 'CompVis/stable-diffusion-v1-4' unet_checkpoint: '' revision: 'flax' @@ -105,6 +105,7 @@ logical_axis_rules: [ ['activation_kv', 'tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] @@ -143,7 +144,6 @@ enable_data_shuffling: True # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 -upload_ckpts_to_gcs: False # Prepare image latents and text encoder outputs # during dataset creation to reduce memory consumption. 
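The `logical_axis_rules` entries added in this config (for example `['conv_batch', ['data','fsdp']]` and `['out_channels', 'tensor']`) are name-based sharding rules: each pair maps a logical tensor axis onto one or more of the mesh axes declared in `mesh_axes: ['data', 'fsdp', 'tensor']`. The sketch below is illustrative only and not part of this patch; it assumes Flax's `flax.linen.logical_to_mesh_axes` helper and reuses the logical names that `resnet_flax.py` constrains after this change.

```python
# Illustrative sketch (not part of this diff): how logical_axis_rules resolve
# to a PartitionSpec via Flax's logical partitioning helpers.
from flax import linen as nn

# Subset of the rules added to base14.yml / base_2_base.yml / base21.yml in this change.
rules = (
    ("conv_batch", ("data", "fsdp")),  # conv activation batch dim sharded over data+fsdp
    ("out_channels", "tensor"),        # channel dim sharded over the tensor axis
)

# Logical names used by nn.with_logical_constraint in resnet_flax.py after this change.
logical_axes = ("conv_batch", "height", "keep_2", "out_channels")

spec = nn.logical_to_mesh_axes(logical_axes, rules)
print(spec)  # PartitionSpec(('data', 'fsdp'), None, None, 'tensor'); names without a rule stay unsharded
```

Under this mapping the convolution activations are split across the `data` and `fsdp` mesh axes and replicated along `tensor`, which is what the new `conv_batch` rule provides.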
diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index acd7c9d7..c8a154a3 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -102,11 +102,14 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_length', 'fsdp'], - ['out_channels', 'fsdp'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['length', 'fsdp'] ] data_sharding: [['data', 'fsdp', 'tensor']] @@ -143,7 +146,6 @@ enable_data_shuffling: True # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 -upload_ckpts_to_gcs: False # Prepare image latents and text encoder outputs # during dataset creation to reduce memory consumption. diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index d96c6bee..2c860109 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -115,11 +115,14 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_length', 'fsdp'], - ['out_channels', 'fsdp'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['length', 'fsdp'] ] data_sharding: [['data', 'fsdp', 'tensor']] @@ -156,7 +159,6 @@ enable_data_shuffling: True # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 -upload_ckpts_to_gcs: False # Prepare image latents and text encoder outputs # during dataset creation to reduce memory consumption. diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 3bce6d12..18c4e94b 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -108,7 +108,8 @@ logical_axis_rules: [ ['activation_kv', 'tensor'], ['embed','fsdp'], ['heads', 'tensor'], - ['out_channels', 'fsdp'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] data_sharding: [['data', 'fsdp', 'tensor']] @@ -146,7 +147,6 @@ enable_data_shuffling: True # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 -upload_ckpts_to_gcs: False # Prepare image latents and text encoder outputs # during dataset creation to reduce memory consumption. @@ -185,6 +185,7 @@ profiler_steps: 10 # Generation parameters prompt: "A magical castle in the middle of a forest, artistic drawing" negative_prompt: "purple, red" +do_classifier_free_guidance: True guidance_scale: 9 # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index d0ef99d1..8973abf9 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -82,11 +82,14 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_length', 'fsdp'], - ['out_channels', 'fsdp'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['length', 'fsdp'] ] data_sharding: [['data', 'fsdp', 'tensor']] @@ -146,6 +149,7 @@ profiler_steps: 5 # Generation parameters prompt: "portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors" negative_prompt: "purple, red" +do_classifier_free_guidance: False guidance_scale: 2 guidance_rescale: -1 num_inference_steps: 4 diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index 52024833..4c339773 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -23,10 +23,11 @@ import jax from jax.sharding import PartitionSpec as P import jax.numpy as jnp -from maxdiffusion import pyconfig from absl import app from maxdiffusion import ( + pyconfig, FlaxDDIMScheduler, + max_utils ) from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg @@ -36,6 +37,7 @@ StableDiffusionTrainer ) +from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import ( STABLE_DIFFUSION_CHECKPOINT ) @@ -164,11 +166,27 @@ def run(config): checkpoint_loader = GenerateSD(config, STABLE_DIFFUSION_CHECKPOINT) pipeline, params = checkpoint_loader.load_checkpoint() - unet_state, unet_state_shardings, _ = checkpoint_loader.create_unet_state( - pipeline, - params, - checkpoint_item_name="inference_unet_state", - is_training=False, + weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state(pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False) + unet_params = load_params_from_path( + config, + checkpoint_loader.checkpoint_manager, + unboxed_abstract_state.params, + "unet_state" + ) + if unet_params: + params["unet"] = unet_params + + # Don't restore the train state to save memory, just restore params + # and create an inference state. 
+ unet_state, unet_state_shardings = max_utils.setup_initial_state( + model=pipeline.unet, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params.get("unet", None), + training=False ) vae_state, vae_state_shardings = checkpoint_loader.create_vae_state( @@ -178,11 +196,35 @@ def run(config): is_training=False ) - text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state( - pipeline, - params, - checkpoint_item_name="inference_text_encoder_state" if config.train_text_encoder else "text_encoder_state", - is_training=False + weights_init_fn = functools.partial( + pipeline.text_encoder.init_weights, + rng=checkpoint_loader.rng, + input_shape=(checkpoint_loader.total_train_batch_size, pipeline.tokenizer.model_max_length)) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.text_encoder, + None, + config, + checkpoint_loader.mesh, + weights_init_fn, + False + ) + text_encoder_params = load_params_from_path( + config, + checkpoint_loader.checkpoint_manager, + unboxed_abstract_state.params, + "text_encoder_state" + ) + if text_encoder_params: + params["text_encoder"] = text_encoder_params + + text_encoder_state, text_encoder_state_shardings = max_utils.setup_initial_state( + model=pipeline.text_encoder, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params.get("text_encoder", None), + training=False ) states = {} diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index 5ca8a79b..42544ae1 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -30,7 +30,7 @@ ) -from maxdiffusion import pyconfig +from maxdiffusion import pyconfig, max_utils from maxdiffusion.image_processor import VaeImageProcessor from maxdiffusion.maxdiffusion_utils import ( get_add_time_ids, @@ -42,13 +42,19 @@ StableDiffusionXLTrainer ) +from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path + class GenerateSDXL(StableDiffusionXLTrainer): def __init__(self, config): super().__init__(config) -def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): +def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale, config): latents, scheduler_state, state = args - latents_input = jnp.concatenate([latents] * 2) + + if config.do_classifier_free_guidance: + latents_input = jnp.concatenate([latents] * 2) + else: + latents_input = latents t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) @@ -62,12 +68,15 @@ def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, gui added_cond_kwargs=added_cond_kwargs ).sample - noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale=guidance_rescale) + def apply_classifier_free_guidance(noise_pred, guidance_scale, guidance_rescale): + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale) + return noise_pred + if config.do_classifier_free_guidance: + noise_pred = apply_classifier_free_guidance(noise_pred, guidance_scale, guidance_rescale) latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() @@ -121,15 +130,26 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size): "text_encoder" : states["text_encoder_state"].params, "text_encoder_2" : states["text_encoder_2_state"].params} prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, text_encoder_params) + batch_size = prompt_embeds.shape[0] - negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, text_encoder_params) add_time_ids = get_add_time_ids( (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype ) - prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) - add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) - add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) + if config.do_classifier_free_guidance: + if negative_prompt_ids is None: + negative_prompt_embeds = jnp.zeros_like(prompt_embeds) + negative_pooled_embeds = jnp.zeros_like(pooled_embeds) + else: + negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, text_encoder_params) + + prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) + add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) + + else: + add_text_embeds = pooled_embeds + # Ensure model output will be `float32` before going into the scheduler guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) guidance_rescale = jnp.array([guidance_rescale], dtype=jnp.float32) @@ -186,7 +206,8 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size): added_cond_kwargs=added_cond_kwargs, prompt_embeds=prompt_embeds, guidance_scale=guidance_scale, - guidance_rescale=guidance_rescale) + guidance_rescale=guidance_rescale, + config=config) vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): @@ -199,14 +220,31 @@ def run(config): checkpoint_loader = GenerateSDXL(config) pipeline, params = checkpoint_loader.load_checkpoint() + weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state(pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False) + + unet_params = load_params_from_path( + config, + checkpoint_loader.checkpoint_manager, + unboxed_abstract_state.params, + "unet_state" + ) + if unet_params: + params["unet"] = unet_params + if config.lightning_repo: pipeline, params = load_sdxllightning_unet(config, pipeline, params) - unet_state, unet_state_shardings, _ = checkpoint_loader.create_unet_state( - 
pipeline, - params, - checkpoint_item_name="inference_unet_state", - is_training=False + # Don't restore the train state to save memory, just restore params + # and create an inference state. + unet_state, unet_state_shardings = max_utils.setup_initial_state( + model=pipeline.unet, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params.get("unet", None), + training=False ) vae_state, vae_state_shardings = checkpoint_loader.create_vae_state( @@ -244,7 +282,7 @@ def run(config): noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( config.pretrained_model_name_or_path, - revision=config.revision, subfolder="scheduler", dtype=jnp.float32 + revision=config.revision, subfolder="scheduler", dtype=jnp.float32, timestep_spacing="trailing" ) pipeline.scheduler = noise_scheduler diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index 14ef0eca..59c67642 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -97,15 +97,15 @@ def make_pokemon_train_iterator( global_batch_size, tokenize_fn, image_transforms_fn): - train_ds = load_dataset(config.dataset_name,split="train") - - captions_column = config.caption_column - image_column = config.image_column - cache_latents_text_encoder_outputs = config.cache_latents_text_encoder_outputs dataset_save_location = config.dataset_save_location if os.path.isdir(dataset_save_location): train_ds = load_from_disk(dataset_save_location) else: + train_ds = load_dataset(config.dataset_name,split="train") + + captions_column = config.caption_column + image_column = config.image_column + cache_latents_text_encoder_outputs = config.cache_latents_text_encoder_outputs train_ds = train_ds.map( function=tokenize_fn, batched=True, diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index f50ace29..0a60d0dc 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -285,7 +285,7 @@ def unbox_logicallypartioned_trainstate( else x, boxed_train_state, \ is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned)) -def init_train_state(model, tx, weights_init_fn, training=True, eval_only=False): +def init_train_state(model, tx, weights_init_fn, params=None, training=True, eval_only=False): """ We pass in "static" objects like model, tx, config, as JAX compares them by object hash, and instantiating them inside causes pjit top-level annotations @@ -293,16 +293,17 @@ def init_train_state(model, tx, weights_init_fn, training=True, eval_only=False) Args: model_params, model, tx, training """ - model_params = weights_init_fn(eval_only=eval_only) + if not params: + params = weights_init_fn(eval_only=eval_only) if training: state = train_state.TrainState.create( apply_fn=model.apply if hasattr(model, 'apply') else model.__call__, - params=model_params, + params=params, tx=tx) else: state = InferenceState( apply_fn=model.apply if hasattr(model, 'apply') else model.__call__, - params=model_params) + params=params) return state def get_abstract_state(model, tx, config, mesh, weights_init_fn, training=True): @@ -358,29 +359,25 @@ def setup_initial_state( state: the initialized train state state_mesh_annotations: the mesh annotations for the train state """ - max_logging.log(f"setup_initial_state for {checkpoint_item}") # Initialization state = None unboxed_abstract_state, _, 
state_mesh_shardings = get_abstract_state( model, tx, config, mesh, weights_init_fn, training) with nn_partitioning.axis_rules(config.logical_axis_rules): - if checkpoint_manager: + if checkpoint_manager and checkpoint_item: + max_logging.log(f"setup_initial_state for {checkpoint_item}") state = checkpointing_utils.load_state_if_possible(checkpoint_manager, unboxed_abstract_state, checkpoint_item) if state: state = state[checkpoint_item] if not state: - max_logging.log(f"Could not find {checkpoint_item} in orbax, creating state...") - # for DDP needs replication before jit state. - if config.ici_data_parallelism == -1 or config.dcn_data_parallelism == -1: - sharding = PositionalSharding(mesh.devices).replicate() - partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) - model_params = jax.tree_util.tree_map(partial_device_put_replicated, model_params) + max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial(init_train_state, model=model, tx=tx, weights_init_fn=weights_init_fn, + params=model_params, training=training, eval_only=False ) @@ -390,12 +387,6 @@ def setup_initial_state( in_shardings=None, out_shardings=state_mesh_shardings )() - if model_params: - state = state.replace(params=model_params) - else: - # this should only be the case when training a new model from scratch. - # else, its possible a model is being loaded from a wrong dir path. - max_logging.log(f"model_params is None, random init weights have been loaded...") state = unbox_logicallypartioned_trainstate(state) diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 10d44503..3b10a233 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -74,9 +74,10 @@ def transform_images( ds_length = tensor_list.shape[0] iters = ds_length // global_batch_size latents_list = [] - for i in range(0, iters * global_batch_size, global_batch_size): + local_batch_size = global_batch_size // jax.device_count() + for i in range(0, iters * global_batch_size, local_batch_size): sample_rng, rng = jax.random.split(rng) - latents = p_vae_apply(tensor_list[i:i+global_batch_size], sample_rng) + latents = p_vae_apply(tensor_list[i:i+local_batch_size], sample_rng) latents_list.append(latents) latents_list = np.stack(latents_list) @@ -85,15 +86,9 @@ def transform_images( # TODO (Juan Acevedo): do last iteration, its required for the Pyarrow dataset # to not break due to items being fewer than expected. Is there a better way? 
- if tensor_list[i+global_batch_size:].shape[0] != 0: + if tensor_list[i+local_batch_size:].shape[0] != 0: sample_rng, rng = jax.random.split(rng) - latents = p_vae_apply(tensor_list[i+global_batch_size:], sample_rng) - examples[pixel_ids_key] = np.append(latents_list, latents, axis=0) - else: - examples[pixel_ids_key] = latents_list - if tensor_list[i+global_batch_size:].shape[0] != 0: - sample_rng, rng = jax.random.split(rng) - latents = p_vae_apply(tensor_list[i+global_batch_size:], sample_rng) + latents = p_vae_apply(tensor_list[i+local_batch_size:], sample_rng) examples[pixel_ids_key] = np.append(latents_list, latents, axis=0) else: examples[pixel_ids_key] = latents_list diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 965d75cd..017d4a58 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -656,6 +656,10 @@ def setup(self): if self.use_linear_projection: self.proj_in = nn.Dense( inner_dim, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ('embed','hidden') + ), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision @@ -693,7 +697,15 @@ def setup(self): ] if self.use_linear_projection: - self.proj_out = nn.Dense(inner_dim, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision) + self.proj_out = nn.Dense( + inner_dim, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ('hidden','embed') + ), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision) else: self.proj_out = nn.Conv( inner_dim, diff --git a/src/maxdiffusion/models/resnet_flax.py b/src/maxdiffusion/models/resnet_flax.py index 7daaec91..48b093b0 100644 --- a/src/maxdiffusion/models/resnet_flax.py +++ b/src/maxdiffusion/models/resnet_flax.py @@ -66,13 +66,13 @@ def __call__(self, hidden_states): hidden_states = nn.with_logical_constraint( hidden_states, - ('batch', 'keep_1', 'keep_2', 'out_channels') + ('conv_batch', 'height', 'keep_2', 'out_channels') ) hidden_states = self.conv(hidden_states) hidden_states = nn.with_logical_constraint( hidden_states, - ('batch', 'keep_1', 'keep_2', 'out_channels') + ('conv_batch', 'height', 'keep_2', 'out_channels') ) return hidden_states @@ -102,7 +102,7 @@ def __call__(self, hidden_states): hidden_states = self.conv(hidden_states) hidden_states = nn.with_logical_constraint( hidden_states, - ('batch', 'keep_1', 'keep_2', 'out_channels') + ('conv_batch', 'height', 'keep_2', 'out_channels') ) return hidden_states @@ -184,7 +184,7 @@ def __call__(self, hidden_states, temb, deterministic=True): hidden_states = self.conv1(hidden_states) hidden_states = nn.with_logical_constraint( hidden_states, - ('batch', None, None, 'out_channels') + ('conv_batch', 'height', 'keep_2', 'out_channels') ) temb = self.time_emb_proj(nn.swish(temb)) @@ -197,7 +197,7 @@ def __call__(self, hidden_states, temb, deterministic=True): hidden_states = self.conv2(hidden_states) hidden_states = nn.with_logical_constraint( hidden_states, - ('batch', 'keep_1', 'keep_2', 'out_channels') + ('conv_batch', 'height', 'keep_2', 'out_channels') ) if self.conv_shortcut is not None: diff --git a/src/maxdiffusion/models/unet_2d_condition_flax.py b/src/maxdiffusion/models/unet_2d_condition_flax.py index 4c450d1c..ffb032d0 100644 --- a/src/maxdiffusion/models/unet_2d_condition_flax.py +++ b/src/maxdiffusion/models/unet_2d_condition_flax.py @@ -217,7 +217,11 @@ def setup(self): padding=((1, 1), (1, 1)), 
dtype=self.dtype, param_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + kernel_init = nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ('keep_1', 'keep_2', 'conv_in', 'conv_out') + ) ) # time @@ -452,6 +456,11 @@ def __call__( sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) + sample = nn.with_logical_constraint( + sample, + ('conv_batch', 'height', 'keep_2', 'out_channels') + ) + # 3. down down_block_res_samples = (sample,) for down_block in self.down_blocks: diff --git a/src/maxdiffusion/models/vae_flax.py b/src/maxdiffusion/models/vae_flax.py index 5884fc27..bdab5c89 100644 --- a/src/maxdiffusion/models/vae_flax.py +++ b/src/maxdiffusion/models/vae_flax.py @@ -28,15 +28,6 @@ from ..utils import BaseOutput from .modeling_flax_utils import FlaxModelMixin -from maxdiffusion import common_types - -AxisNames = common_types.AxisNames -BATCH = common_types.BATCH -LENGTH = common_types.LENGTH -HEAD = common_types.HEAD -KEEP_1 = common_types.KEEP_1 -KEEP_2 = common_types.KEEP_2 -CONV_OUT = common_types.CONV_OUT @flax.struct.dataclass class FlaxDecoderOutput(BaseOutput): @@ -91,10 +82,6 @@ def setup(self): padding=((1, 1), (1, 1)), dtype=self.dtype, param_dtype=self.weights_dtype, - kernel_init = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ('keep_1', 'keep_2', 'conv_in', 'conv_out') - ) ) def __call__(self, hidden_states): @@ -105,10 +92,6 @@ def __call__(self, hidden_states): method="nearest", ) hidden_states = self.conv(hidden_states) - hidden_states = nn.with_logical_constraint( - hidden_states, - (BATCH, KEEP_1, KEEP_2, 'out_channels') - ) return hidden_states @@ -135,20 +118,12 @@ def setup(self): padding="VALID", dtype=self.dtype, param_dtype=self.weights_dtype, - kernel_init = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ('keep_1', 'keep_2', 'conv_in', 'conv_out') - ) ) def __call__(self, hidden_states): pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim hidden_states = jnp.pad(hidden_states, pad_width=pad) hidden_states = self.conv(hidden_states) - hidden_states = nn.with_logical_constraint( - hidden_states, - (BATCH, 'keep_1', 'keep_2', 'out_channels') - ) return hidden_states @@ -190,10 +165,6 @@ def setup(self): padding=((1, 1), (1, 1)), dtype=self.dtype, param_dtype=self.weights_dtype, - kernel_init = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ('keep_1', 'keep_2', 'conv_in', 'conv_out') - ) ) self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6, dtype=self.dtype, param_dtype=self.weights_dtype) @@ -205,10 +176,6 @@ def setup(self): padding=((1, 1), (1, 1)), dtype=self.dtype, param_dtype=self.weights_dtype, - kernel_init = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ('keep_1', 'keep_2', 'conv_in', 'conv_out') - ) ) use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut @@ -229,19 +196,11 @@ def __call__(self, hidden_states, deterministic=True): hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.conv1(hidden_states) - hidden_states = nn.with_logical_constraint( - hidden_states, - (BATCH, KEEP_1, KEEP_2, CONV_OUT) - ) hidden_states = self.norm2(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic) hidden_states = self.conv2(hidden_states) - hidden_states = nn.with_logical_constraint( - hidden_states, - (BATCH, KEEP_1, KEEP_2, 
CONV_OUT) - ) if self.conv_shortcut is not None: residual = self.conv_shortcut(residual) @@ -275,23 +234,14 @@ def setup(self): dense = partial(nn.Dense, self.channels, dtype=self.dtype, param_dtype=self.weights_dtype) - qkv_init_kernel = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ('embed','heads') - ) - self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6, dtype=self.dtype, param_dtype=self.weights_dtype) self.query, self.key, self.value = ( - dense(kernel_init=qkv_init_kernel), - dense(kernel_init=qkv_init_kernel), - dense(kernel_init=qkv_init_kernel) + dense(), + dense(), + dense() ) - proj_attn_init_kernel = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ('heads','embed') - ) - self.proj_attn = dense(kernel_init=proj_attn_init_kernel) + self.proj_attn = dense() def transpose_for_scores(self, projection): new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) @@ -313,19 +263,6 @@ def __call__(self, hidden_states): key = self.key(hidden_states) value = self.value(hidden_states) - query = nn.with_logical_constraint( - query, - (BATCH, LENGTH, HEAD) - ) - key = nn.with_logical_constraint( - key, - (BATCH, LENGTH, HEAD) - ) - value = nn.with_logical_constraint( - value, - (BATCH, LENGTH, HEAD) - ) - # transpose query = self.transpose_for_scores(query) key = self.transpose_for_scores(key) @@ -347,7 +284,6 @@ def __call__(self, hidden_states): hidden_states = self.proj_attn(hidden_states) hidden_states = hidden_states.reshape((batch, height, width, channels)) hidden_states = hidden_states + residual - hidden_states = nn.with_logical_constraint(hidden_states,(BATCH, LENGTH, HEAD)) return hidden_states @@ -597,10 +533,6 @@ def setup(self): padding=((1, 1), (1, 1)), dtype=self.dtype, param_dtype=self.weights_dtype, - kernel_init = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ('keep_1', 'keep_2', 'conv_in', 'conv_out') - ) ) # downsampling @@ -958,10 +890,6 @@ def decode(self, latents, deterministic: bool = True, return_dict: bool = True): latents = jnp.transpose(latents, (0, 2, 3, 1)) hidden_states = self.post_quant_conv(latents) - hidden_states = nn.with_logical_constraint( - hidden_states, - (BATCH, 'keep_1', 'keep_2', 'out_channels') - ) hidden_states = self.decoder(hidden_states, deterministic=deterministic) hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) diff --git a/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py b/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py new file mode 100644 index 00000000..978a97c3 --- /dev/null +++ b/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py @@ -0,0 +1,77 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +'''This script is used as an example of how to restore params from an Orbax train_state ckpt.''' + +import os +import functools +from absl import app +from typing import Sequence +from etils import epath + +import jax +from jax.sharding import Mesh +import orbax.checkpoint as ocp +from maxdiffusion import pyconfig, max_utils +from maxdiffusion.models import FlaxUNet2DConditionModel + +def run(config): + rng = jax.random.PRNGKey(config.seed) + + # Creates mesh using number of devices available + # and ici/dcn parallelism rules + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + # Load the UNET from the checkpoint + unet, params = FlaxUNet2DConditionModel.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + subfolder="unet", + split_head_dim=True + ) + + weights_init_fn = functools.partial(unet.init_weights, rng=rng) + #max_utils.get_abstract_state(unet, None, config, mesh, weights_init_fn, training=False) + + unboxed_abstract_state, _, _ = max_utils.get_abstract_state(unet, None, config, mesh, weights_init_fn,False) + ckptr = ocp.PyTreeCheckpointer() + + #ckpt_path = config.checkpoint_dir + ckpt_path = os.path.join(config.checkpoint_dir,"11","unet_state") + ckpt_path = epath.Path(ckpt_path) + + + print(f"loading parameters from: {ckpt_path}") + + restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_state.params) + restored = ckptr.restore( + ckpt_path, + item={"params": unboxed_abstract_state.params}, + transforms={}, + restore_args={"params" : restore_args} + ) + return restored["params"] + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + +# Run via: +# python src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py src/maxdiffusion/configs/base_xl.yml +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/pedagogical_examples/model_flop_calculation.py b/src/maxdiffusion/pedagogical_examples/model_flop_calculation.py index 62c0420b..28bbf9d2 100644 --- a/src/maxdiffusion/pedagogical_examples/model_flop_calculation.py +++ b/src/maxdiffusion/pedagogical_examples/model_flop_calculation.py @@ -16,7 +16,7 @@ """ Run ex: -python src/maxdiffusion/pedagogical_examples/model_flop_calculation.py src/maxdiffusion/configs/base15.yml +python src/maxdiffusion/pedagogical_examples/model_flop_calculation.py src/maxdiffusion/configs/base14.yml """ from absl import app diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 357f9c15..38a722e2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -132,6 +132,9 @@ def user_init(raw_keys): raw_keys["unet_checkpoint"] = max_utils.download_blobs(raw_keys["unet_checkpoint"], "/tmp") if "gs://" in raw_keys["tokenizer_model_name_or_path"]: raw_keys["tokenizer_model_name_or_path"] = max_utils.download_blobs(raw_keys["tokenizer_model_name_or_path"],"/tmp") + if "gs://" in raw_keys["dataset_name"]: + raw_keys["dataset_name"] = max_utils.download_blobs(raw_keys["dataset_name"], raw_keys["dataset_save_location"]) + raw_keys["dataset_save_location"] = raw_keys["dataset_name"] def get_num_target_devices(raw_keys): return len(jax.devices()) diff --git a/src/maxdiffusion/tests/controlnet_tests.py b/src/maxdiffusion/tests/controlnet_tests.py index 964f288f..13ba19be 100644 --- a/src/maxdiffusion/tests/controlnet_tests.py +++ b/src/maxdiffusion/tests/controlnet_tests.py @@ -18,7 +18,7 @@ class
ControlNet(unittest.TestCase): def test_controlnet(self): img_url = os.path.join(THIS_DIR,'images','cnet_test.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) - pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base14.yml'), "prompt=best quality, extremely detailed","activations_dtype=bfloat16","weights_dtype=bfloat16", "negative_prompt=monochrome, lowres, bad anatomy, worst quality, low quality", "num_inference_steps=50","seed=0","split_head_dim=False"]) diff --git a/src/maxdiffusion/tests/dreambooth_tests.py b/src/maxdiffusion/tests/dreambooth_tests.py index 36255a3f..51d7d4c6 100644 --- a/src/maxdiffusion/tests/dreambooth_tests.py +++ b/src/maxdiffusion/tests/dreambooth_tests.py @@ -45,7 +45,7 @@ def test_prior_preservation(self): """Test prior preservation function generates images.""" num_class_images=16 - pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base14.yml'), "class_data_dir=/tmp/class_data_dir", "class_prompt=a photo of a dog", f"num_class_images={num_class_images}"]) @@ -75,7 +75,7 @@ def test_dreambooth_training(self): # so setting it here. jax.config.update("jax_compilation_cache_dir",cache_dir) - pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base14.yml'), "class_data_dir=test_dreambooth", f"instance_data_dir={instance_class_local_dir}", f"class_data_dir={class_class_local_dir}","instance_prompt=a photo of ohwx dog", "class_prompt=photo of a dog","max_train_steps=150",f"cache_dir={cache_dir}", @@ -90,7 +90,7 @@ def test_dreambooth_training(self): img_url = os.path.join(THIS_DIR,'images','dreambooth_test.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) - pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base14.yml'), "prompt=a photo of a ohwx dog", "revision=main", f"pretrained_model_name_or_path={output_dir}/{run_name}/checkpoints/final"]) diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py index 49ec7772..9a23817b 100644 --- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py +++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py @@ -1,20 +1,22 @@ import os import unittest +import pytest import numpy as np from ..import pyconfig from absl.testing import absltest from maxdiffusion.generate_sdxl import run as generate_run_xl -from maxdiffusion.controlnet.generate_controlnet_sdxl_replicated import run as generate_run_sdxl_controlnet from PIL import Image from skimage.metrics import structural_similarity as ssim +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" THIS_DIR = os.path.dirname(os.path.abspath(__file__)) class Generate(unittest.TestCase): """Smoke test.""" + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_sdxl_config(self): img_url = os.path.join(THIS_DIR,'images','test_sdxl.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) @@ -33,6 +35,7 @@ def test_sdxl_config(self): assert base_image.shape == test_image.shape assert ssim_compare >=0.80 + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_sdxl_from_gcs(self): """Verify load weights from gcs.""" img_url = 
os.path.join(THIS_DIR,'images','test_sdxl.png') @@ -52,7 +55,10 @@ def test_sdxl_from_gcs(self): assert base_image.shape == test_image.shape assert ssim_compare >=0.80 + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_controlnet_sdxl(self): + from maxdiffusion.controlnet.generate_controlnet_sdxl_replicated import run as generate_run_sdxl_controlnet + img_url = os.path.join(THIS_DIR,'images','cnet_test_sdxl.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base_xl.yml'), @@ -66,6 +72,7 @@ def test_controlnet_sdxl(self): assert base_image.shape == test_image.shape assert ssim_compare >=0.70 + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_sdxl_lightning(self): img_url = os.path.join(THIS_DIR,'images','test_lightning.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py index 6a0db755..b8e0e7e7 100644 --- a/src/maxdiffusion/tests/generate_smoke_test.py +++ b/src/maxdiffusion/tests/generate_smoke_test.py @@ -1,5 +1,6 @@ import os import unittest +import pytest import numpy as np from ..import pyconfig @@ -9,6 +10,7 @@ from PIL import Image from skimage.metrics import structural_similarity as ssim +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -19,10 +21,11 @@ def setUp(self): super().setUp() Generate.dummy_data = {} - def test_sd15_config(self): - img_url = os.path.join(THIS_DIR,'images','test_gen_sd15.png') + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + def test_sd14_config(self): + img_url = os.path.join(THIS_DIR,'images','test_gen_sd14.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) - pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base14.yml'), "seed=47","output_dir=gs://maxdiffusion-github-runner-test-assets", "run_name=gen-test-15-config"],unittest=True) images = generate_run(pyconfig.config) @@ -33,6 +36,7 @@ def test_sd15_config(self): assert base_image.shape == test_image.shape assert ssim_compare >=0.70 + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_sd_2_base_from_gcs(self): img_url = os.path.join(THIS_DIR,'images','test_2_base.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) @@ -48,10 +52,11 @@ def test_sd_2_base_from_gcs(self): assert base_image.shape == test_image.shape assert ssim_compare >=0.70 + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_controlnet(self): img_url = os.path.join(THIS_DIR,'images','cnet_test.png') base_image = np.array(Image.open(img_url)).astype(np.uint8) - pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base14.yml'), "prompt=best quality, extremely detailed","activations_dtype=bfloat16","weights_dtype=bfloat16", "negative_prompt=monochrome, lowres, bad anatomy, worst quality, low quality", "num_inference_steps=50","seed=0","split_head_dim=False"],unittest=True) diff --git a/src/maxdiffusion/tests/images/test_gen_sd14.png b/src/maxdiffusion/tests/images/test_gen_sd14.png new file mode 100644 index 00000000..5a2af6f5 Binary 
files /dev/null and b/src/maxdiffusion/tests/images/test_gen_sd14.png differ diff --git a/src/maxdiffusion/tests/images/test_gen_sd15.png b/src/maxdiffusion/tests/images/test_gen_sd15.png deleted file mode 100644 index fc87237d..00000000 Binary files a/src/maxdiffusion/tests/images/test_gen_sd15.png and /dev/null differ diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index d1a01a87..6270dfb6 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -84,7 +84,7 @@ def test_make_dreambooth_train_iterator(self): class_class_local_dir = max_utils.download_blobs(class_class_gcs_dir, local_dir) - pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base14.yml'), "cache_latents_text_encoder_outputs=True", "dataset_name=my_dreambooth_dataset", f"instance_data_dir={instance_class_local_dir}", diff --git a/src/maxdiffusion/tests/train_smoke_test.py b/src/maxdiffusion/tests/train_smoke_test.py index bd59c53d..5acd6e5f 100644 --- a/src/maxdiffusion/tests/train_smoke_test.py +++ b/src/maxdiffusion/tests/train_smoke_test.py @@ -16,6 +16,7 @@ """ Smoke test """ import os +import pytest import pathlib import shutil import unittest @@ -35,6 +36,8 @@ validate_train_config, ) +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + HOME_DIR = pathlib.Path.home() THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -55,6 +58,7 @@ class Train(unittest.TestCase): def setUp(self): Train.dummy_data = {} + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_sdxl_config(self): output_dir="gs://maxdiffusion-github-runner-test-assets" run_name="sdxl_train_smoke_test" @@ -70,12 +74,13 @@ def test_sdxl_config(self): "per_device_batch_size=1", 'timestep_bias={"strategy" : "later", "multiplier" : 2.0, "portion" : 0.25}', f"output_dir={output_dir}", - f"jax_cache_dir={cache_dir}"],unittest=True) + f"jax_cache_dir={cache_dir}"], unittest=True) train_sdxl(pyconfig.config) delete_blobs(os.path.join(output_dir,run_name)) + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_dreambooth_orbax(self): num_class_images=100 output_dir="gs://maxdiffusion-github-runner-test-assets" @@ -90,7 +95,7 @@ def test_dreambooth_orbax(self): instance_class_local_dir = max_utils.download_blobs(instance_class_gcs_dir, local_dir) class_class_local_dir = max_utils.download_blobs(class_class_gcs_dir, local_dir) - pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base15.yml'), + pyconfig.initialize([None, os.path.join(THIS_DIR,'..','configs','base14.yml'), f"instance_data_dir={instance_class_local_dir}", f"class_data_dir={class_class_local_dir}","instance_prompt=a photo of ohwx dog", "class_prompt=photo of a dog","max_train_steps=150",f"jax_cache_dir={cache_dir}", @@ -100,7 +105,7 @@ def test_dreambooth_orbax(self): "weights_dtype=float32","per_device_batch_size=1","enable_profiler=False","precision=DEFAULT", "cache_dreambooth_dataset=False","learning_rate=4e-6",f"output_dir={output_dir}", f"num_class_images={num_class_images}",f"run_name={run_name}", - "prompt=a photo of ohwx dog", "seed=47"],unittest=True) + "prompt=a photo of ohwx dog", "seed=47"], unittest=True) config = pyconfig.config validate_train_config(config) @@ -110,6 +115,7 @@ def test_dreambooth_orbax(self): 
     cleanup(class_class_local_dir)
     delete_blobs(os.path.join(output_dir,run_name))
 
+  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_sd15_orbax(self):
     output_dir="gs://maxdiffusion-github-runner-test-assets"
     run_name="sd15_orbax_smoke_test"
@@ -117,8 +123,8 @@ def test_sd15_orbax(self):
 
     delete_blobs(os.path.join(output_dir,run_name))
 
-    pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base15.yml'),
-      f"run_name={run_name}", "checkpoint_every=256","upload_ckpts_to_gcs=True",
+    pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base14.yml'),
+      f"run_name={run_name}", "checkpoint_every=256",
       "max_train_steps=21","per_device_batch_size=8", f"output_dir={output_dir}",
       "prompt=A magical castle in the middle of a forest, artistic drawing",
       "negative_prompt=purple, red","guidance_scale=7.5",
diff --git a/src/maxdiffusion/tests/unet_test.py b/src/maxdiffusion/tests/unet_test.py
index 6af6bc53..51ff814d 100644
--- a/src/maxdiffusion/tests/unet_test.py
+++ b/src/maxdiffusion/tests/unet_test.py
@@ -46,7 +46,7 @@ def setUp(self):
     UnetTest.dummy_data = {}
 
   def test_unet15_sharding_test(self):
-    pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base15.yml'),
+    pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base14.yml'),
       "activations_dtype=bfloat16","resolution=512"], unittest=True)
     config = pyconfig.config
     unet, params = FlaxUNet2DConditionModel.from_pretrained(
diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py
deleted file mode 100644
index 868e75da..00000000
--- a/src/maxdiffusion/tests/vae_test.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""
- Copyright 2024 Google LLC
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      https://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """ - -""" Smoke test """ -import os -import unittest -import functools -from absl.testing import absltest - -import jax -import jax.numpy as jnp -from ..import max_utils -from ..import pyconfig -from maxdiffusion import FlaxAutoencoderKL -from flax.training import train_state -import optax -from jax.sharding import Mesh, PartitionSpec, NamedSharding - - -THIS_DIR = os.path.dirname(os.path.abspath(__file__)) - -def init_fn(params, model, optimizer): - state = train_state.TrainState.create( - apply_fn=model.apply, - params=params, - tx=optimizer) - return state - -class VaeTest(unittest.TestCase): - """Test Unet sharding""" - def setUp(self): - VaeTest.dummy_data = {} - - def test_vae21_sharding_test(self): - pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base21.yml'), - "pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1", - "revision=bf16","activations_dtype=bfloat16","resolution=768"],unittest=True) - config = pyconfig.config - vae, vae_params = FlaxAutoencoderKL.from_pretrained( - config.pretrained_model_name_or_path, revision=config.revision, subfolder="vae", dtype=jnp.bfloat16, from_pt=config.from_pt - ) - devices_array = max_utils.create_device_mesh(config) - - rng = jax.random.PRNGKey(config.seed) - mesh = Mesh(devices_array, config.mesh_axes) - k = jax.random.key(0) - tx = optax.adam(learning_rate=0.001) - latents = jnp.ones((4,4,96,96), dtype=jnp.float32) - - variables = jax.jit(vae.init)(k, latents) - weights_init_fn = functools.partial(vae.init_weights, rng=rng) - _, state_mesh_annotations, _ = max_utils.get_abstract_state(vae, tx, config, mesh, weights_init_fn, False) - del variables - qkv_sharding = PartitionSpec(None, None) - conv_sharding = PartitionSpec(None, None, None, 'fsdp') - assert state_mesh_annotations.params['decoder']['mid_block']['resnets_0']['conv1']['kernel'] == conv_sharding - assert state_mesh_annotations.params['decoder']['mid_block']['resnets_0']['conv2']['kernel'] == conv_sharding - assert state_mesh_annotations.params['decoder']['mid_block']['attentions_0']['key']['kernel'] == qkv_sharding - assert state_mesh_annotations.params['decoder']['mid_block']['attentions_0']['query']['kernel'] == qkv_sharding - assert state_mesh_annotations.params['decoder']['mid_block']['attentions_0']['value']['kernel'] == qkv_sharding - - _, vae_state_state_mesh_shardings = max_utils.setup_initial_state( - vae, - tx, - config, - mesh, - weights_init_fn, - None, - None, - None, - False - ) - - conv_sharding = PartitionSpec(None, None, None, 'fsdp') - qkv_sharding = PartitionSpec(None, None) - qkv_named_sharding = NamedSharding(mesh, qkv_sharding) - conv_named_sharding = NamedSharding(mesh, conv_sharding) - - assert vae_state_state_mesh_shardings.params['decoder']['mid_block']['resnets_0']['conv1']['kernel'] == conv_named_sharding - assert vae_state_state_mesh_shardings.params['decoder']['mid_block']['resnets_0']['conv2']['kernel'] == conv_named_sharding - assert vae_state_state_mesh_shardings.params['decoder']['mid_block']['attentions_0']['key']['kernel'] == qkv_named_sharding - assert vae_state_state_mesh_shardings.params['decoder']['mid_block']['attentions_0']['query']['kernel'] == qkv_named_sharding - assert vae_state_state_mesh_shardings.params['decoder']['mid_block']['attentions_0']['value']['kernel'] == qkv_named_sharding - -if __name__ == '__main__': - absltest.main() diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index a5ebc90f..270fceeb 100644 --- a/src/maxdiffusion/train_utils.py +++ 
b/src/maxdiffusion/train_utils.py @@ -14,8 +14,6 @@ limitations under the License. """ -import os -import shutil import numpy as np import jax import jax.numpy as jnp @@ -165,33 +163,3 @@ def generate_timestep_weights(config, num_timesteps): weights /= weights.sum() return jnp.array(weights) -# def save_orbax_checkpoint(unet_state, pipeline, text_encoder_state = None): - - -def save_checkpoint(save_fn, params, config, output_dir): - """ - Save checkpoint. - - Args: - save_fn: Must be a save_pretrained fn, for example, pipeline.save_pretrained. - params: params to save. - config: pyconfig - output_dir: local directory path. - """ - user_dir = os.path.expanduser("~") - local_output_dir = output_dir.replace(os.path.join( - config.base_output_directory, - config.run_name - ), user_dir) - - save_fn( - local_output_dir, - params=params - ) - - if jax.process_index() == 0 and config.upload_ckpts_to_gcs: - max_utils.walk_and_upload_blobs(config, local_output_dir) - # delete files in output_dir to save space - max_logging.log(f"Deleting {local_output_dir} to save space.") - shutil.rmtree(local_output_dir) - diff --git a/src/maxdiffusion/trainers/dreambooth_trainer.py b/src/maxdiffusion/trainers/dreambooth_trainer.py index 0d79640f..6350d3ee 100644 --- a/src/maxdiffusion/trainers/dreambooth_trainer.py +++ b/src/maxdiffusion/trainers/dreambooth_trainer.py @@ -247,14 +247,14 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: train_states["unet_state"] = unet_state train_states["text_encoder_state"] = text_encoder_state - self.save_checkpoint(step, pipeline, params, train_states, save_inference_states=False) + self.save_checkpoint(step, pipeline, params, train_states) if self.config.write_metrics: train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) train_states["unet_state"] = unet_state train_states["text_encoder_state"] = text_encoder_state - self.save_checkpoint(step, pipeline, params, train_states, save_inference_states=True) + self.save_checkpoint(step, pipeline, params, train_states) self.checkpoint_manager.wait_until_finished() return train_states diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index 87c144fc..e746b1b9 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -116,7 +116,8 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size = self.total_train_batch_size mesh = self.mesh - if self.config.dataset_name == "diffusers/pokemon-gpt4-captions": + # ideally : diffusers/pokemon-gpt4-captions, but if loading from gcs, make sure the folder has pokemon in the name. 
+ if "pokemon" in self.config.dataset_name: p_encode = None p_vae_apply = None if config.cache_latents_text_encoder_outputs: @@ -230,7 +231,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera train_states["unet_state"] = unet_state train_states["text_encoder_state"] = text_encoder_state train_states["text_encoder_2_state"] = text_encoder_2_state - self.save_checkpoint(step, pipeline, params, train_states, save_inference_states=False) + self.save_checkpoint(step, pipeline, params, train_states) if self.config.write_metrics: write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) @@ -238,7 +239,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera train_states["unet_state"] = unet_state train_states["text_encoder_state"] = text_encoder_state train_states["text_encoder_2_state"] = text_encoder_2_state - self.save_checkpoint(step, pipeline, params, train_states, save_inference_states=True) + self.save_checkpoint(step, pipeline, params, train_states) self.checkpoint_manager.wait_until_finished() def _train_step(unet_state, batch, train_rng, pipeline, params, config): diff --git a/src/maxdiffusion/trainers/stable_diffusion_trainer.py b/src/maxdiffusion/trainers/stable_diffusion_trainer.py index 8f83fbdb..62408729 100644 --- a/src/maxdiffusion/trainers/stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/stable_diffusion_trainer.py @@ -100,7 +100,8 @@ def get_data_shardings(self): return data_sharding def load_dataset(self, pipeline, params, train_states): - if self.config.dataset_name == "diffusers/pokemon-gpt4-captions": + # ideally : diffusers/pokemon-gpt4-captions, but if loading from gcs, make sure the folder has pokemon in the name. + if "pokemon" in self.config.dataset_name: p_encode = None p_vae_apply = None if self.config.cache_latents_text_encoder_outputs: @@ -234,7 +235,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera train_states["unet_state"] = unet_state train_states["vae_state"] = vae_state train_states["text_encoder"] = text_encoder_state - self.save_checkpoint(step, pipeline, params, train_states, save_inference_states=False) + self.save_checkpoint(step, pipeline, params, train_states) if self.config.write_metrics: train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) @@ -243,7 +244,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera train_states["vae_state"] = vae_state train_states["text_encoder"] = text_encoder_state # save the inference states of the last checkpoint so they can be easily loaded during gen. 
-    self.save_checkpoint(step, pipeline, params, train_states, save_inference_states=True)
+    self.save_checkpoint(step, pipeline, params, train_states)
     self.checkpoint_manager.wait_until_finished()
 
 def _train_step(unet_state, vae_state, text_encoder_state, batch, train_rng, pipeline, params, config):
diff --git a/src/maxdiffusion/transformers/models/clip/modeling_flax_clip.py b/src/maxdiffusion/transformers/models/clip/modeling_flax_clip.py
index 38d2aa71..d6e9e1bd 100644
--- a/src/maxdiffusion/transformers/models/clip/modeling_flax_clip.py
+++ b/src/maxdiffusion/transformers/models/clip/modeling_flax_clip.py
@@ -606,6 +606,7 @@ def __call__(
 class FlaxCLIPVisionTransformer(nn.Module):
   config: CLIPVisionConfig
   dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
 
   def setup(self):
     self.embeddings = FlaxCLIPVisionEmbeddings(self.config, dtype=self.dtype)