Optimized K-means for 1D case (flash1dkmeans integration for faster quantization) #72
Hi again, I previously hinted at faster 1D-specific K-means optimizations in #60, and mentioned that my library flash1dkmeans achieves this in #67.
Here I propose a simple integration of my library into `nuq.py`. This yields a modest 5x speedup on top of the previous 22.7x speedup from #60, and with this integration each Llama 2 7B layer can be quantized in 2 or 3 seconds. Excluding file I/O time, this brings quantization of the whole model down to roughly 1 minute, from 6 minutes (which was, in turn, originally down from 2 hours!)
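For reference, here is a minimal sketch of what the swap inside `nuq.py` could look like. The `kmeans_1d` entry point and its return values are assumptions made for illustration only; the actual function name and signature should be taken from the flash1dkmeans README.

```python
import numpy as np

# Hypothetical drop-in replacement for sklearn.cluster.KMeans on 1D data.
# `flash1dkmeans.kmeans_1d` is an assumed entry point for illustration;
# consult the flash1dkmeans documentation for the real API.
try:
    from flash1dkmeans import kmeans_1d  # assumed name

    def cluster_1d(values, n_clusters):
        x = np.asarray(values, dtype=np.float64)
        centroids, labels = kmeans_1d(x, n_clusters)  # assumed signature
        return np.asarray(centroids), np.asarray(labels)

except ImportError:
    from sklearn.cluster import KMeans

    def cluster_1d(values, n_clusters):
        # Fallback: the original sklearn path, with 1D data reshaped to (n, 1).
        x = np.asarray(values, dtype=np.float64).reshape(-1, 1)
        km = KMeans(n_clusters=n_clusters, n_init=10).fit(x)
        return km.cluster_centers_.ravel(), km.labels_
```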
In our Any-Precision LLM codebase we actually managed to bring this time down to around 30 seconds by using Numba multithreading (possible because `flash1dkmeans` exposes its underlying Numba functions) and by pipelining the disk I/O. However, these are separate additional optimizations; in this PR I focus on providing a drop-in replacement for sklearn's K-means.

The main speedup comes from reducing the time complexity of K-means++ initialization and of Lloyd's algorithm iterations by exploiting prefix sum arrays over the sorted data, which is only possible with 1D data.
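To illustrate the idea (this is not flash1dkmeans' actual code): once the 1D data is sorted, every cluster is a contiguous segment, and prefix sums of the values and of their squares let you evaluate the within-cluster squared error of any segment in O(1). That constant-time cost query is what makes both the initialization and the Lloyd iterations cheap.

```python
import numpy as np

# Illustrative sketch only: prefix sums over sorted 1D data give the
# squared error of any contiguous segment in O(1).
x = np.sort(np.random.default_rng(0).normal(size=1024))

prefix = np.concatenate(([0.0], np.cumsum(x)))         # prefix[i] = sum of x[:i]
prefix_sq = np.concatenate(([0.0], np.cumsum(x * x)))  # prefix_sq[i] = sum of x[:i]**2

def segment_sse(i, j):
    """Sum of squared errors of x[i:j] around its own mean, in O(1)."""
    n = j - i
    s = prefix[j] - prefix[i]
    sq = prefix_sq[j] - prefix_sq[i]
    return sq - s * s / n  # sum((x - mean)^2) = sum(x^2) - (sum(x))^2 / n

# e.g. cost of putting the first 100 sorted values into a single cluster
print(segment_sse(0, 100))
```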
If you are interested in further speeding up the quantization, please consider testing this code.
Questions are welcome!
I noticed that scikit-learn was not in the original `pyproject.toml` dependencies, even though it is used in `nuq.py`. If dependencies exclusive to the quantization pipeline are not meant to be included in `pyproject.toml`, you may want to exclude that part of this PR.