Permit loading of models at different precision at load time for sentence_transformers #331
Feature request
Pass torch_dtype in model_kwargs, as supported by sentence_transformers, when a dtype is specified in the infinity_emb v2 CLI and the InferenceEngine type is torch. This would allow the Transformer model to be loaded at a lower precision at load time rather than quantized post-loading, which can cause an OOM error.
Post-load quantization, e.g. self.half(), would still be needed, since the non-transformer PyTorch modules in some models appear to remain in fp32 and would cause downstream matrix computations to fail due to mixed-dtype operands.
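For illustration, a minimal sketch of the requested behavior against the public sentence_transformers API (assuming a version that exposes model_kwargs; the model name and device are placeholders):

```python
import torch
from sentence_transformers import SentenceTransformer

# Load the transformer weights directly in fp16, so a full fp32 copy of the
# model never has to fit in GPU memory.
model = SentenceTransformer(
    "BAAI/bge-small-en-v1.5",  # placeholder model
    device="cuda",
    model_kwargs={"torch_dtype": torch.float16},
)

# A post-load cast is still needed: non-transformer modules in some models
# may remain in fp32 and would otherwise cause mixed-dtype matmuls downstream.
model.half()
```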
Motivation
The current code quantizes a model after loading it, so if we load a full 32-bit float model on a GPU with a small amount of memory, loading can fail because the total size of the model exceeds the GPU's memory limit. A concrete case is a small multi-instance GPU, e.g. a 10 GB MIG instance on an NVIDIA A100.
By specifying the precision at load time, we could load the model without OOM errors and use it successfully.
Your contribution
Yes. Happy to submit a PR.
Current tests on my end use code modifications in the __init__ function of the SentenceTransformerPatched class along the lines of the sketch after the dtype notes below, but they would need further work for better support/handling of other dtypes.
Some thoughts on what can be done for each dtype, per the torch tensor attributes documented at https://pytorch.org/docs/stable/tensor_attributes.html:
- auto -> (leave blank)
- float32 -> torch.float
- float16 -> torch.half
- float8 -> (leave blank) ?
- int8 -> (leave blank) ?

The 8-bit options were left blank since I noticed that quantization is performed later for these types; unsure if there is a better approach here.
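To make the above concrete, a rough sketch of this kind of __init__ change (the _LOAD_DTYPES mapping, the keyword-only signature, and the engine_args attribute names are placeholders rather than the actual infinity_emb internals):

```python
import torch
from sentence_transformers import SentenceTransformer

# Placeholder mapping from the CLI dtype option (assumed to be a string here)
# to a load-time torch dtype. None means "do not pass torch_dtype", i.e. keep
# today's behavior; used for "auto" and the 8-bit options, which are quantized
# post-load.
_LOAD_DTYPES = {
    "auto": None,
    "float32": torch.float,
    "float16": torch.half,
    "float8": None,
    "int8": None,
}

class SentenceTransformerPatched(SentenceTransformer):
    """Sketch only: the real class in infinity_emb contains additional logic."""

    def __init__(self, *, engine_args, **kwargs):
        torch_dtype = _LOAD_DTYPES.get(engine_args.dtype)
        model_kwargs = {}
        if torch_dtype is not None:
            # Load the transformer weights directly at the requested precision.
            model_kwargs["torch_dtype"] = torch_dtype
        super().__init__(
            engine_args.model_name_or_path,  # attribute name assumed
            device=engine_args.device,       # attribute name assumed
            model_kwargs=model_kwargs,
            **kwargs,
        )
```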