Skip to content

Commit

Permalink
support loading safetensors format. (#123)
Browse files Browse the repository at this point in the history
Co-authored-by: Juan Acevedo <[email protected]>
  • Loading branch information
entrpn and jfacevedo-google authored Oct 17, 2024
1 parent 6271ab7 commit 47c7e92
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,21 +196,20 @@ def load_diffusers_checkpoint(self):
precision=precision,
)

if len(self.config.unet_checkpoint) > 0:
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
self.config.unet_checkpoint,
split_head_dim=self.config.split_head_dim,
norm_num_groups=self.config.norm_num_groups,
attention_kernel=self.config.attention,
flash_block_sizes=flash_block_sizes,
dtype=self.activations_dtype,
weights_dtype=self.weights_dtype,
mesh=self.mesh,
)
params["unet"] = unet_params
pipeline.unet = unet
params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)

if len(self.config.unet_checkpoint) > 0:
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
self.config.unet_checkpoint,
split_head_dim=self.config.split_head_dim,
norm_num_groups=self.config.norm_num_groups,
attention_kernel=self.config.attention,
flash_block_sizes=flash_block_sizes,
dtype=self.activations_dtype,
weights_dtype=self.weights_dtype,
mesh=self.mesh,
)
params["unet"] = unet_params
pipeline.unet = unet
params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
return pipeline, params

def save_checkpoint(self, train_step, pipeline, params, train_states):
Expand Down
8 changes: 6 additions & 2 deletions src/maxdiffusion/models/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
PushToHubMixin,
logging,
)
Expand Down Expand Up @@ -331,9 +332,12 @@ def from_pretrained(
)
if os.path.isdir(pretrained_path_with_subfolder):
if from_pt:
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
if os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, SAFETENSORS_WEIGHTS_NAME)):
model_file = os.path.join(pretrained_path_with_subfolder, SAFETENSORS_WEIGHTS_NAME)
else:
raise EnvironmentError(f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} ")
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
Expand Down

0 comments on commit 47c7e92

Please sign in to comment.