Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized K-means for 1D case (flash1dkmeans integration for faster quantization) #72

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ dependencies = [
"tokenizers>=0.12.1",
"torch",
"transformers==4.29.0",
"datasets"
"datasets",
"flash1dkmeans==0.1.4",
]

[tool.setuptools.packages.find]
32 changes: 19 additions & 13 deletions quantization/nuq.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os
os.environ["OMP_NUM_THREADS"] = "1" # this is necessary to parallelize the kmeans

import argparse
import json
import pickle

import numpy as np
import torch
from sklearn.cluster import KMeans
from squeezellm.model_parse import get_module_names, parse_model
from squeezellm.outliers import remove_outliers
from tqdm import tqdm
from multiprocessing import Pool
import flash1dkmeans

parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -49,13 +47,22 @@
# Define the helper function for parallel k-means
def kmeans_fit(row_data):
weights_np, sample_weight, n_cluster = row_data
kmeans = KMeans(
n_clusters=n_cluster,
random_state=0,
n_init="auto",
max_iter=50,
).fit(weights_np, sample_weight=sample_weight)
return kmeans.cluster_centers_.reshape(-1), np.cast["byte"](kmeans.labels_)
if n_cluster > 2:
centers, labels = flash1dkmeans.kmeans_1d(
X=weights_np,
n_clusters=n_cluster,
sample_weights=sample_weight,
random_state=0,
max_iter=50,
)
else:
# When n_cluster is 2, random_state and max_iter are not used
centers, labels = flash1dkmeans.kmeans_1d(
X=weights_np,
n_clusters=n_cluster,
sample_weights=sample_weight,
)
return centers, np.cast["byte"](labels)


if __name__ == "__main__":
Expand Down Expand Up @@ -167,9 +174,8 @@ def kmeans_fit(row_data):
n_cluster = 2**args.bit

for i in range(module_weight.shape[0]):
weights_np_temp = _weights_np[i, :]
weights_np = weights_np_temp.reshape(-1, 1)
weight_mask = weights_np_temp != 0
weights_np = _weights_np[i, :]
weight_mask = weights_np != 0
sample_weight = g[i, :] * weight_mask
if np.sum(sample_weight) == 0:
sample_weight = np.ones_like(sample_weight)
Expand Down