
Range-based 16b quantization #124

Merged: 22 commits into main, Jan 15, 2024

Conversation

@SkBlaz (Collaborator) commented Dec 3, 2023:

Turns out that by actually estimating the weight ranges and prepending a mini header to the weight buffer that is used for dequantization (we only need two f32s passed onwards), we can do much better than simply down-casting and re-casting. This appears to offer noise-level approximation quality on reasonable offline data.
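
As a minimal sketch of the idea (all names here are hypothetical, and the u16 storage is only to keep the example dependency-free; the PR itself writes f16 values via the half crate, and the real helpers live in block_helpers):

```rust
// Illustrative sketch only: not the PR's actual code.
const NUM_BUCKETS: f32 = 65_535.0;

/// The mini header written ahead of the weight buffer: the two f32s that are
/// all we need for dequantization later.
struct QuantizationHeader {
    min: f32,
    max: f32,
}

fn quantize_weights_sketch(weights: &[f32]) -> (QuantizationHeader, Vec<u16>) {
    let min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Uniform buckets within the observed range.
    let increment = ((max - min) / NUM_BUCKETS).max(f32::MIN_POSITIVE);
    let quantized = weights
        .iter()
        .map(|&w| ((w - min) / increment).round() as u16)
        .collect();
    (QuantizationHeader { min, max }, quantized)
}

fn dequantize_weights_sketch(header: &QuantizationHeader, quantized: &[u16]) -> Vec<f32> {
    let increment = ((header.max - header.min) / NUM_BUCKETS).max(f32::MIN_POSITIVE);
    quantized
        .iter()
        .map(|&q| header.min + q as f32 * increment)
        .collect()
}
```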

This immediately opened up an interesting side result that is not part of this PR: it seems we might be able to pull this off with just one byte per weight too (a minor decay in performance, nothing too drastic).

Usage is via an additional (optional, backwards-compatible) parameter --quantize_weights that has to be passed during the weight conversion phase (pre-inference).

Also added a check prior to each quantization that panics if the weight distribution is too skewed (which can happen with corrupted models). This wasn't really validated before, but should be, as it's a cheap pre-production sanity check.
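
A hedged sketch of what such a check could look like (the exact statistic and threshold in the PR may differ; this version uses the third standardized moment):

```rust
/// Hypothetical pre-quantization sanity check: panic if the weight
/// distribution is too skewed (e.g. due to a corrupted model).
fn assert_weights_not_too_skewed(weights: &[f32]) {
    let n = weights.len() as f32;
    let mean = weights.iter().sum::<f32>() / n;
    let variance = weights.iter().map(|&w| (w - mean).powi(2)).sum::<f32>() / n;
    let std_dev = variance.sqrt().max(f32::MIN_POSITIVE);
    // Fisher-Pearson skewness: the third standardized moment.
    let skew = weights.iter().map(|&w| ((w - mean) / std_dev).powi(3)).sum::<f32>() / n;
    const MAX_ABS_SKEW: f32 = 10.0; // illustrative threshold
    if skew.abs() > MAX_ABS_SKEW {
        panic!("weight distribution too skewed ({}); refusing to quantize", skew);
    }
}
```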

@yonatankarni (Collaborator) left a comment:

very elegant imho, a great first step!


I think as a next step it might be interesting to try storing the embeddings in memory in the more compact form as well (and dequantize after lookup). This might speed up lookups, since there will be fewer memory pages and more embeddings per page (especially with 8-bit quantization).
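
A tiny sketch of how that follow-up could look, reusing the two-f32 header layout from this PR (all names hypothetical, not part of this change):

```rust
/// Embedding table kept in its quantized form; rows are dequantized lazily,
/// only when a lookup touches them.
struct QuantizedEmbeddings {
    min: f32,
    increment: f32,
    dim: usize,
    data: Vec<u16>, // row-major, 2 bytes per weight instead of 4
}

impl QuantizedEmbeddings {
    fn lookup(&self, row: usize) -> Vec<f32> {
        let start = row * self.dim;
        self.data[start..start + self.dim]
            .iter()
            .map(|&q| self.min + q as f32 * self.increment)
            .collect()
    }
}
```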

}

// Uniform distribution within the relevant interval
let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS;
Collaborator:

Great! I guess if we want to reduce to 1 byte per weight in the future we'll need to try quantile-based buckets, right?

Collaborator (Author):

Correct + non-uniform range estimation
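
For reference, a rough sketch of what quantile-based buckets could look like for a future 8-bit path (illustrative only, not part of this PR):

```rust
/// Hypothetical 8-bit scheme: instead of a uniform split of [min, max], each
/// of the 256 dequantization values is taken from the empirical distribution,
/// so dense regions of the weight distribution get more buckets.
fn quantile_bucket_centers(weights: &[f32]) -> Vec<f32> {
    let mut sorted = weights.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    (0..256)
        .map(|i| {
            // Midpoint of the i-th quantile slice becomes the lookup-table entry.
            let pos = ((i as f32 + 0.5) / 256.0 * sorted.len() as f32) as usize;
            sorted[pos.min(sorted.len() - 1)]
        })
        .collect()
}
```

Dequantization would then be a 256-entry lookup table indexed by the stored byte, matching the lookup-table idea raised later in this thread.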

@@ -18,6 +18,7 @@ use std::time::Instant;

extern crate blas;
extern crate intel_mkl_src;
extern crate half;
Collaborator:

You don't use half after all, right?

Collaborator (Author):

No, it's used when creating the 2-byte representations for a given quantized int (which we first make sure fits into the 16-bit range); no need to roll our own at this point IMO.

Collaborator:

My bad! I was looking for "half" instead of "f16". Do you think it will perform badly if we use u16 instead (just the bucket index)? A greater loss of accuracy for sure... but won't we have to do it that way anyway when we switch to 8-bit quantization (with a lookup table)?

@SkBlaz (Collaborator, Author) commented Dec 4, 2023:

There might be a small difference, but it's basically the same thing; f16 just felt somewhat safer here.

Collaborator:

On the upside, with f16 you get a more accurate representation for much of the range (the lower values), but at the far end of the range most buckets will be unrepresentable, so precision will suffer there.
f16 is probably still better overall (but maybe worth testing).
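
A toy illustration of that trade-off, assuming the 2-byte value holds the bucket index itself (it uses the half crate this PR already pulls in):

```rust
use half::f16;

fn main() {
    // Low bucket indices survive the f16 round trip exactly...
    assert_eq!(f16::from_f32(1000.0).to_f32(), 1000.0);
    // ...but near the far end of the range adjacent buckets collapse, since f16
    // only represents integers exactly up to 2048 (its significand has 11 bits).
    assert_eq!(f16::from_f32(40000.0), f16::from_f32(40007.0));
    // A plain u16 index keeps every bucket distinct across the whole range.
    assert_ne!(40000u16, 40007u16);
}
```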

@ggaspersic (Collaborator) left a comment:

Nice 👍

src/block_ffm.rs Outdated
block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader)?;

if use_quantization {
// in-place expand weights via dequantization (for inference)
Collaborator:

Redundant comment

block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter)?;

if use_quantization {

Collaborator:

Remove empty line

src/block_ffm.rs Outdated
@@ -866,7 +883,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
.as_any()
.downcast_mut::<BlockFFM<optimizer::OptimizerSGD>>()
.unwrap();
- block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader)?;
+ block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?;
Collaborator:

No additional tests?

) -> Result<(), Box<dyn Error>> {
- block_helpers::write_weights_to_buf(&self.weights, output_bufwriter)
+ block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)
Collaborator:

I don't know why, but I thought this was the same as the forward/forward_backward pass (i.e. iterating over everything). For the future, let's think about moving the writing/reading into a module meant for weights/file handling, rather than into the iteration-oriented block_helpers part.

Collaborator (Author):

agreed

@SkBlaz merged commit 067554a into main on Jan 15, 2024 (3 checks passed).
@SkBlaz deleted the fwq branch on January 15, 2024 at 06:33.