-
Notifications
You must be signed in to change notification settings - Fork 553
/
prune_ckpt.py
26 lines (22 loc) · 891 Bytes
/
prune_ckpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
from ldm.pruner import prune_checkpoint
import torch
import argparse
# Command-line interface: a single optional --ckpt flag naming the checkpoint
# file to prune. When the flag is omitted, ckpt is None.
parser = argparse.ArgumentParser(description='Pruning')
parser.add_argument(
    '--ckpt',
    help='path to model ckpt',
    default=None,
    type=str,
)
# NOTE(review): arguments are parsed at import time, so importing this module
# consumes sys.argv; kept as-is because other code may read `args` / `ckpt`.
args = parser.parse_args()
ckpt = args.ckpt
def prune_it(checkpoint_path):
    """Prune a checkpoint file and write the result next to it.

    The checkpoint is loaded onto the CPU, passed through
    ``ldm.pruner.prune_checkpoint``, and saved as ``<stem>-pruned.ckpt``,
    where ``<stem>`` is ``checkpoint_path`` with its extension removed. The
    new file size and the difference from the original are printed in
    decimal gigabytes (10**9 bytes).

    Args:
        checkpoint_path: Path of the checkpoint file to prune.

    Returns:
        None. Progress and size information are printed to stdout.
    """
    print(f"Pruning checkpoint from path: {checkpoint_path}")
    bytes_before = os.path.getsize(checkpoint_path)

    # NOTE(review): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source.
    state = torch.load(checkpoint_path, map_location="cpu")
    # What exactly gets stripped is decided by prune_checkpoint; the summary
    # message below assumes it is the optimizer state.
    slim_state = prune_checkpoint(state)

    stem = os.path.splitext(checkpoint_path)[0]
    out_path = f"{stem}-pruned.ckpt"
    print(f"Saving pruned checkpoint at: {out_path}")
    torch.save(slim_state, out_path)

    bytes_after = os.path.getsize(out_path)
    print(
        f"New ckpt size: {bytes_after*1e-9:.2f} GB. "
        f"Saved {(bytes_before - bytes_after)*1e-9:.2f} GB by removing optimizer states"
    )
if __name__ == "__main__":
    # Script entry point: prune the checkpoint named by --ckpt.
    # --ckpt defaults to None, and prune_it(None) would fail inside
    # os.path.getsize with an unhelpful TypeError. Report a normal argparse
    # usage error instead (prints usage and exits with status 2). The check
    # lives here rather than as required=True so that importing the module
    # keeps its current behavior.
    if ckpt is None:
        parser.error("the following arguments are required: --ckpt")
    prune_it(ckpt)