Initial fix for token corruption when batching #665
Merged
This PR fixes two problems with two code changes.
Cache over-allocation
This is a small problem that causes us to over-allocate cache pages in the KV cache. Further work is still needed so that service.py and {Base,Trie}PagedAttentionCache allocate a precise and consistent amount of cache, but this change is sufficient to solve the problem at hand.
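As a rough sketch of what "precise" allocation means here (the names below are hypothetical, not the actual service.py API): the page count should be derived directly from the token count and page size, rather than rounding up with extra slack.

```python
import math

def pages_needed(total_tokens: int, tokens_per_page: int) -> int:
    # Hypothetical helper: allocate exactly enough pages to hold
    # total_tokens, with no extra "just in case" page.
    return math.ceil(total_tokens / tokens_per_page)

# Example: a 100-token sequence with 16-token pages needs 7 pages;
# over-allocation would hand out 8 or more.
assert pages_needed(100, 16) == 7
```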
Zero-padding of seq_len and start_position
For unused requests in a batch, seq_len and start_position are usually filled with 0. Running attention for these zero-length padding requests produces NaNs, which are then written to page 0 of the KV cache.
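A minimal illustration of where those NaNs come from (this is a standalone sketch, not the model code): when every position in a row is masked out, the softmax normalizes zero by zero and the entire row becomes NaN.

```python
import torch

# A padding request contributes an attention row where every
# position is masked out (all scores are -inf).
scores = torch.full((4,), float("-inf"))

# Softmax over an all-masked row: exp(-inf) == 0 everywhere, so the
# normalization divides 0 by 0 and the whole row becomes NaN.
weights = torch.softmax(scores, dim=0)
print(weights)  # tensor([nan, nan, nan, nan])

# Anything computed from these weights, and then written back to the
# cache's padding page (page 0), is NaN as well.
values = torch.randn(4, 8)
out = weights @ values
print(out.isnan().all())  # True
```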
Page index 0 serves a special padding role in our batching system. It's used to fill unused pages for shorter requests and to pad unused requests within a batch.
Under normal circumstances, NaNs in page 0 wouldn't be problematic since our masking system is designed to ignore values beyond the current token. For example, when generating token 17 with a page list of [255, 254, 0], we should never need to read from the padding page.
The issue stems from our current masking implementation. Instead of directly ignoring masked values, we mask by adding negative infinity to the attention scores before applying the exponential. With ordinary finite values this works fine: exp(-inf) is 0, so masked positions contribute nothing. But NaN plus negative infinity is still NaN, and exp(NaN) is NaN, so the NaN values read from page 0 are not zeroed out by the mask. They leak into our calculations and corrupt the generated tokens.
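A small sketch of that failure mode (standalone, not the kernel itself): additive masking neutralizes finite garbage but not NaN.

```python
import torch

# The last slot is read from the padding page, which now holds NaN.
scores = torch.tensor([2.0, 1.5, float("nan")])
mask = torch.tensor([0.0, 0.0, float("-inf")])  # mask out the padding slot

# Additive masking: NaN + (-inf) is still NaN, and exp(NaN) is NaN,
# so the padding slot is not driven to zero.
masked = scores + mask
weights = torch.exp(masked) / torch.exp(masked).sum()
print(weights)  # tensor([nan, nan, nan]) -- the NaN poisons the whole row

# With an ordinary finite value in that slot instead, exp(-inf) == 0
# and the mask would have worked as intended.
```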