Range-based 16b quantization #124
Conversation
Very elegant imho, a great first step!
I think as a next step it might be interesting to try storing the embeddings in memory in the more compact form as well (and de-quantize after lookup). This might speed up lookup, since there will be fewer memory pages and more embeddings per page (especially with 8-bit quantization).
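For illustration, a minimal sketch of keeping embeddings quantized in memory and dequantizing only at lookup time; all names here (QuantizedEmbeddings, lookup) are assumptions for the sketch, not code from this PR:

```rust
/// Hypothetical in-memory layout: one u16 bucket index per weight, plus the
/// two range parameters needed to dequantize. Only the looked-up embedding
/// is materialized as f32s.
struct QuantizedEmbeddings {
    buckets: Vec<u16>, // one bucket index per weight
    dim: usize,        // embedding dimensionality
    min: f32,          // smallest observed weight
    increment: f32,    // bucket width
}

impl QuantizedEmbeddings {
    /// Dequantize a single embedding on demand.
    fn lookup(&self, index: usize) -> Vec<f32> {
        let start = index * self.dim;
        self.buckets[start..start + self.dim]
            .iter()
            .map(|&b| self.min + b as f32 * self.increment)
            .collect()
    }
}

fn main() {
    let emb = QuantizedEmbeddings {
        buckets: vec![0, 32767, 65535, 100, 200, 300],
        dim: 3,
        min: -1.0,
        increment: 2.0 / 65535.0,
    };
    println!("{:?}", emb.lookup(1)); // second embedding, dequantized lazily
}
```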
src/quantization.rs
Outdated
}

// Uniform distribution within the relevant interval
let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS;
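For context, a minimal sketch of the uniform, range-based scheme the line above belongs to: buckets are spaced evenly between the observed min and max, and dequantization is min + index * increment. Helper names and the NUM_BUCKETS value are assumptions, not the PR's exact code; a non-degenerate range (max > min) is assumed.

```rust
const NUM_BUCKETS: f32 = 65535.0; // 16b -> 2^16 - 1 usable bucket indices

fn quantize(weights: &[f32]) -> (f32, f32, Vec<u16>) {
    let min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Uniform distribution within the relevant interval
    let weight_increment = (max - min) / NUM_BUCKETS;
    let buckets = weights
        .iter()
        .map(|&w| ((w - min) / weight_increment).round() as u16)
        .collect();
    (min, weight_increment, buckets)
}

fn dequantize(min: f32, weight_increment: f32, buckets: &[u16]) -> Vec<f32> {
    buckets
        .iter()
        .map(|&b| min + b as f32 * weight_increment)
        .collect()
}

fn main() {
    let weights = [-0.5_f32, 0.0, 0.25, 0.49];
    let (min, inc, q) = quantize(&weights);
    println!("{:?}", dequantize(min, inc, &q)); // close to the original weights
}
```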
Great! I guess if we want to reduce to 1 byte per weight in the future we'll need to try quantile-based buckets, right?
Correct + non-uniform range estimation
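A hedged sketch of what quantile-based buckets could look like for 1 byte per weight: bucket edges follow the empirical distribution rather than a uniform grid, so dense regions of the weight distribution get more of the 256 codes. All names are illustrative and not part of this PR.

```rust
/// Bucket boundaries taken at evenly spaced ranks of the sorted weights.
fn quantile_boundaries(weights: &[f32], num_buckets: usize) -> Vec<f32> {
    let mut sorted = weights.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    (1..num_buckets)
        .map(|i| sorted[i * sorted.len() / num_buckets])
        .collect()
}

/// Bucket id = number of boundaries below the weight (a binary search also works).
fn quantize_u8(w: f32, boundaries: &[f32]) -> u8 {
    boundaries.iter().take_while(|&&b| b <= w).count() as u8
}

fn main() {
    // Skewed toy distribution: most mass near zero.
    let weights: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0).powi(3)).collect();
    let bounds = quantile_boundaries(&weights, 256);
    println!("{}", quantize_u8(0.5, &bounds)); // a bucket id in 0..=255
}
```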
@@ -18,6 +18,7 @@ use std::time::Instant;

extern crate blas;
extern crate intel_mkl_src;
extern crate half;
You don't use half after all, right?
No, it's used when creating the 2-byte representations for a given quantized int (which we first make sure fits into the 16b range); no need to roll our own at this point imo.
My bad! I was looking for "half" instead of "f16". Do you think it will perform badly if we use u16 instead? Greater loss of accuracy for sure... it's just the bucket index, but won't we have to do it that way when we switch to 8-bit quantization (with a lookup table)?
There might be a small difference, but it's basically the same thing; f16 felt somewhat safer here.
On the upside, with f16 you get a more accurate representation for much of the range (lower values...), but for the far end of the range most buckets will go unrepresented, so precision will suffer there.
Probably f16 is still better overall (but worth testing, maybe).
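The trade-off can be illustrated with a small comparison, assuming the half crate the PR already pulls in via extern crate half; the values and encodings below are illustrative, not this PR's code:

```rust
// Encode the same weight either as an f16 (relative precision, denser near zero)
// or as a u16 bucket index over a fixed range (uniform absolute precision).
extern crate half;
use half::f16;

fn main() {
    let (min, max) = (-1.0_f32, 1.0_f32);
    let increment = (max - min) / 65535.0;

    for &w in &[0.000123_f32, 0.987654] {
        let as_f16 = f16::from_f32(w).to_f32(); // round-trip through half precision
        let idx = ((w - min) / increment).round() as u16; // uniform bucket index
        let as_u16 = min + idx as f32 * increment;
        println!(
            "w={w:+.6}  f16 err={:.2e}  u16 err={:.2e}",
            (w - as_f16).abs(),
            (w - as_u16).abs()
        );
    }
}
```

Near zero the f16 error is tiny, while toward the edge of the range the u16 bucket index wins, which is the asymmetry described above.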
Nice 👍
src/block_ffm.rs
Outdated
block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader)?;

if use_quantization {
    // in-place expand weights via dequantization (for inference)
Redundant comment
block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter)?;

if use_quantization {

Remove empty line
src/block_ffm.rs
Outdated
@@ -866,7 +883,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
        .as_any()
        .downcast_mut::<BlockFFM<optimizer::OptimizerSGD>>()
        .unwrap();
    block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader)?;
    block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?;
No additional tests?
) -> Result<(), Box<dyn Error>> {
    block_helpers::write_weights_to_buf(&self.weights, output_bufwriter)
    block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)
I don't know why, but I thought this was the same as the forward/forward_backward pass (i.e. iterating over everything). For the future, let's think about moving the writing/reading to a module meant for weight/file handling rather than the iteration/block_helpers part.
agreed
Turns out that by actually estimating the weight ranges and overloading the weight buffer with a mini header used for dequantization (we need two f32s passed onwards), we can do much better than just down-casting and re-casting. This appears to offer noise-level approximation quality on reasonable offline data.
This immediately surfaced an interesting side result that's not part of this PR: it seems we might be able to pull this off with just one byte too (minor decay in performance, nothing too drastic).
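A self-contained sketch of the mini-header idea: the two f32s needed for dequantization (min and increment) are written ahead of the quantized weights, so the reader can expand them without any external state. Function names and the exact byte layout are assumptions for the sketch; the PR's real implementation lives in src/quantization.rs.

```rust
use std::io::{Cursor, Read, Result, Write};

fn write_quantized(weights: &[f32], out: &mut impl Write) -> Result<()> {
    let min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let increment = (max - min) / 65535.0;
    // Mini header: the two f32s needed for dequantization
    out.write_all(&min.to_le_bytes())?;
    out.write_all(&increment.to_le_bytes())?;
    for &w in weights {
        let bucket = ((w - min) / increment).round() as u16;
        out.write_all(&bucket.to_le_bytes())?;
    }
    Ok(())
}

fn read_dequantized(input: &mut impl Read, n: usize) -> Result<Vec<f32>> {
    let mut f32_buf = [0u8; 4];
    input.read_exact(&mut f32_buf)?;
    let min = f32::from_le_bytes(f32_buf);
    input.read_exact(&mut f32_buf)?;
    let increment = f32::from_le_bytes(f32_buf);
    let mut u16_buf = [0u8; 2];
    let mut weights = Vec::with_capacity(n);
    for _ in 0..n {
        input.read_exact(&mut u16_buf)?;
        weights.push(min + u16::from_le_bytes(u16_buf) as f32 * increment);
    }
    Ok(weights)
}

fn main() -> Result<()> {
    let weights = [-0.3_f32, 0.0, 0.7, 1.2];
    let mut buf = Vec::new();
    write_quantized(&weights, &mut buf)?;
    let restored = read_dequantized(&mut Cursor::new(buf), weights.len())?;
    println!("{restored:?}"); // close to the original weights
    Ok(())
}
```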
Usage is via an additional (optional, backwards-compatible) param
--quantize_weights
that has to be passed during the weight conversion phase (pre-inference). Also added a test prior to each quantization that panics if the weight distribution is too skewed (which can happen with corrupted models); this wasn't really tested before, yet should be, as it's a cheap sanity check pre-prod.
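A hedged sketch of that pre-quantization sanity check; the skew measure and the threshold are assumptions for illustration, not the PR's actual values:

```rust
// Before quantizing, inspect the weight distribution and panic if it looks
// pathologically skewed (e.g. a corrupted model where a few weights blow up
// the range that the uniform buckets would have to cover).
fn assert_weights_not_skewed(weights: &[f32]) {
    let n = weights.len() as f32;
    let mean = weights.iter().sum::<f32>() / n;
    let var = weights.iter().map(|&w| (w - mean).powi(2)).sum::<f32>() / n;
    let std = var.sqrt();
    let max_abs_dev = weights
        .iter()
        .map(|&w| (w - mean).abs())
        .fold(0.0_f32, f32::max);
    // If a single weight sits dozens of standard deviations away, the uniform
    // range would be dominated by outliers: refuse to quantize.
    if std > 0.0 && max_abs_dev / std > 50.0 {
        panic!("weight distribution too skewed for quantization (possible corrupted model)");
    }
}

fn main() {
    let mut weights: Vec<f32> = (0..1000).map(|i| (i as f32) / 1000.0).collect();
    assert_weights_not_skewed(&weights); // fine
    weights.push(1.0e9); // corrupted-looking outlier
    // assert_weights_not_skewed(&weights); // would panic
}
```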