Skip to content

Commit

Permalink
Replace custom sharding with save_torch_state_dict from huggingface_h…
Browse files Browse the repository at this point in the history
…ub (#644)
  • Loading branch information
casper-hansen authored Nov 14, 2024
1 parent 419a242 commit a28c747
Showing 1 changed file with 7 additions and 26 deletions.
33 changes: 7 additions & 26 deletions awq/models/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import os
import gc
import json
import warnings
import logging
import torch
import transformers
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from typing_extensions import Doc, Annotated
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint
from huggingface_hub import snapshot_download, save_torch_state_dict

from awq.modules.linear import (
WQLinear_GEMM,
Expand Down Expand Up @@ -306,29 +302,14 @@ def forward(self, x):
if os.path.exists(path):
os.remove(path)

# model_name has no extension, add it when saving state_dict
model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

# shard checkpoint into chunks (10GB default)
shards, index = shard_checkpoint(
self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
save_torch_state_dict(
state_dict=self.model.state_dict(),
save_directory=save_dir,
max_shard_size=shard_size,
safe_serialization=safetensors,
force_contiguous=True,
)

for shard_file, shard in shards.items():
if safetensors:
# safetensors must be in the same memory, so we duplicate and use contiguous memory
shard = {k: v.clone().contiguous() for k, v in shard.items()}
save_file(
shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_dir, shard_file))

# save shard index
if index is not None:
with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
file.write(json.dumps(index, indent=4))

@classmethod
def from_pretrained(
self,
Expand Down

0 comments on commit a28c747

Please sign in to comment.