diff --git a/awq/models/base.py b/awq/models/base.py
index c70b0363..3a525f82 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -1,18 +1,14 @@
 import os
 import gc
-import json
 import warnings
-import logging
 import torch
 import transformers
 import torch.nn as nn
 
 from tqdm import tqdm
 from typing import List, Union, Dict
-from safetensors.torch import save_file
 from typing_extensions import Doc, Annotated
-from huggingface_hub import snapshot_download
-from transformers.modeling_utils import shard_checkpoint
+from huggingface_hub import snapshot_download, save_torch_state_dict
 
 from awq.modules.linear import (
     WQLinear_GEMM,
@@ -306,29 +302,14 @@ def forward(self, x):
         if os.path.exists(path):
             os.remove(path)
 
-        # model_name has no extension, add it when saving state_dict
-        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"
-
-        # shard checkpoint into chunks (10GB default)
-        shards, index = shard_checkpoint(
-            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
-        )
+        save_torch_state_dict(
+            state_dict=self.model.state_dict(),
+            save_directory=save_dir,
+            max_shard_size=shard_size,
+            safe_serialization=safetensors,
+            force_contiguous=True,
+        )
 
-        for shard_file, shard in shards.items():
-            if safetensors:
-                # safetensors must be in the same memory, so we duplicate and use contiguous memory
-                shard = {k: v.clone().contiguous() for k, v in shard.items()}
-                save_file(
-                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
-                )
-            else:
-                torch.save(shard, os.path.join(save_dir, shard_file))
-
-        # save shard index
-        if index is not None:
-            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
-                file.write(json.dumps(index, indent=4))
-
     @classmethod
     def from_pretrained(
         self,
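Note (not part of the diff): the second hunk swaps the hand-rolled sharding path, transformers' `shard_checkpoint` plus `safetensors.torch.save_file` / `torch.save` and a manually written index file, for `huggingface_hub.save_torch_state_dict`, which shards the state dict, serializes each shard, and writes the index JSON in one call; it also drops the dependency on `shard_checkpoint`, which has since been deprecated in transformers. Below is a minimal standalone sketch of the replacement call, assuming a recent huggingface_hub (the function was added around v0.24); the toy `nn.Linear` and the save path are placeholders standing in for the quantized model and `save_dir`.

# Illustrative sketch only, not taken from the PR.
import os

import torch.nn as nn
from huggingface_hub import save_torch_state_dict

model = nn.Linear(16, 16)  # placeholder standing in for self.model
save_dir = "./quantized_checkpoint"  # placeholder save directory
os.makedirs(save_dir, exist_ok=True)

save_torch_state_dict(
    state_dict=model.state_dict(),
    save_directory=save_dir,
    max_shard_size="10GB",    # same 10GB chunking noted in the deleted comment
    safe_serialization=True,  # safetensors output; False falls back to pickle (.bin)
    force_contiguous=True,    # replaces the manual {k: v.clone().contiguous()} loop
)

Sharding is size-dependent: a state dict under max_shard_size is written as a single model.safetensors, while a larger one produces numbered shard files plus a model.safetensors.index.json, which is exactly the bookkeeping the deleted for-loop and index-writing code did by hand.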