diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index be750fb..f7d1194 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -225,6 +225,7 @@ def run(config): pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False ) + # load unet params from orbax checkpoint unet_params = load_params_from_path( config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state" ) @@ -253,14 +254,14 @@ def run(config): vae_state, vae_state_shardings = checkpoint_loader.create_vae_state( pipeline, params, checkpoint_item_name="vae_state", is_training=False ) - text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state( - pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False - ) - - text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state( - pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False - ) + with nn.intercept_methods(lora_interceptor): + text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state( + pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False + ) + text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state( + pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False + ) states = {} state_shardings = {} diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 4b50743..fe4d6e9 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -121,21 +121,35 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas): def rename_for_interceptor(params_keys, network_alphas): new_params_keys = [] + new_network_alphas = {} for layer_lora in params_keys: if "lora" in layer_lora: new_layer_lora = layer_lora[: layer_lora.index("lora")] if new_layer_lora not in new_params_keys: new_params_keys.append(new_layer_lora) network_alpha = network_alphas[layer_lora] - del network_alphas[layer_lora] - network_alphas[new_layer_lora] = network_alpha - return new_params_keys, network_alphas + new_network_alphas[new_layer_lora] = network_alpha + return new_params_keys, new_network_alphas @classmethod def make_lora_interceptor(cls, params, rank, network_alphas): # Only unet interceptor supported for now. + network_alphas_for_interceptor = {} + unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys() - unet_lora_keys, network_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas) + lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas) + network_alphas_for_interceptor.update(unet_alphas) + + text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys() + text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas) + lora_keys.extend(text_encoder_keys) + network_alphas_for_interceptor.update(text_encoder_alphas) + + if "text_encoder_2" in params.keys(): + text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys() + text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas) + lora_keys.extend(text_encoder_2_keys) + network_alphas_for_interceptor.update(text_encoder_2_alphas) def _intercept(next_fn, args, kwargs, context): mod = context.module @@ -146,8 +160,8 @@ def _intercept(next_fn, args, kwargs, context): h = next_fn(*args, **kwargs) if context.method_name == "__call__": module_path = context.module.path - if module_path in unet_lora_keys: - lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas) + if module_path in lora_keys: + lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor) return lora_layer(h, *args, **kwargs) return h diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index fb1a827..029641e 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -49,11 +49,13 @@ def _noop_interceptor(next_fn, args, kwargs, context): if len(lora_config["lora_model_name_or_path"]) > 0: # For now only first lora supported. In the future, they will be merged # before being loaded. + # TODO - merge LoRAs here. params, rank, network_alphas = pipeline.load_lora_weights( lora_config["lora_model_name_or_path"][0], weight_name=lora_config["weight_name"][0], params=params, adapter_name=lora_config["adapter_name"][0], + unet_config=pipeline.unet.config, ) interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas) diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 5f02ec8..3fa5417 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -137,7 +137,15 @@ def create_flax_params_from_pytorch_state( # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): network_alpha_value = get_network_alpha_value(pt_key, network_alphas) - renamed_pt_key = rename_key(pt_key) + + # rename text encoders fc1 lora layers. + pt_key = pt_key.replace("lora_linear_layer", "lora") + + # only rename the unet keys, text encoders are already correct. + if "unet" in pt_key: + renamed_pt_key = rename_key(pt_key) + else: + renamed_pt_key = pt_key pt_tuple_key = tuple(renamed_pt_key.split(".")) # conv if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: @@ -147,13 +155,24 @@ def create_flax_params_from_pytorch_state( flax_tensor = pt_tensor else: flax_key_list = [*pt_tuple_key] - for rename_from, rename_to in ( - ("to_k_lora", ("to_k", "lora")), - ("to_q_lora", ("to_q", "lora")), - ("to_v_lora", ("to_v", "lora")), - ("to_out_lora", ("to_out_0", "lora")), - ("weight", "kernel"), - ): + if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key: + rename_from_to = ( + ("to_k_lora", ("k_proj", "lora")), + ("to_q_lora", ("q_proj", "lora")), + ("to_v_lora", ("v_proj", "lora")), + ("to_out_lora", ("out_proj", "lora")), + ("weight", "kernel"), + ) + # the unet + else: + rename_from_to = ( + ("to_k_lora", ("to_k", "lora")), + ("to_q_lora", ("to_q", "lora")), + ("to_v_lora", ("to_v", "lora")), + ("to_out_lora", ("to_out_0", "lora")), + ("weight", "kernel"), + ) + for rename_from, rename_to in rename_from_to: tmp = [] for s in flax_key_list: if s == rename_from: