
Initial fix for token corruption when batching #665

Merged
renxida merged 2 commits into nod-ai:main from bslfix on Dec 10, 2024
Conversation

@renxida (Contributor) commented Dec 9, 2024

This PR fixes two problems with two code changes.

Cache over-allocation.

This is a small problem that causes us to over-allocate pages in the KV cache. Fully fixing it will require further work so that service.py and {Base,Trie}PagedAttentionCache allocate a precise and consistent amount of cache, but the change here is sufficient to solve the problem at hand.
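As a rough sketch of what a precise allocation could look like (illustrative only; `pages_needed`, `tokens_per_page`, and `max_completion_tokens` are hypothetical names, not the actual service.py interface):

```python
import math

def pages_needed(prompt_len: int, max_completion_tokens: int,
                 tokens_per_page: int) -> int:
    """Exactly enough KV-cache pages for a request's worst-case length."""
    total_tokens = prompt_len + max_completion_tokens
    return math.ceil(total_tokens / tokens_per_page)

# A 100-token prompt generating up to 30 tokens with 16-token pages
# needs ceil(130 / 16) = 9 pages, rather than a padded over-estimate.
print(pages_needed(100, 30, 16))  # -> 9
```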

Zero-padding of seq_len and start_position

For unused requests in a batch, seq_len and start_position are usually filled with 0. This injects NaNs that are written to page 0.
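One plausible way the zero fill produces NaNs (a guess for illustration, not taken from the PR): with seq_len == 0 there are no valid positions, so every attention score for the slot is masked and the softmax normalizes by a zero sum. A minimal numpy sketch:

```python
import numpy as np

# Hypothetical unused batch slot: seq_len == 0, so all 4 positions are masked.
scores = np.full((1, 4), -np.inf)
weights = np.exp(scores)                          # all zeros
attn = weights / weights.sum(-1, keepdims=True)   # 0 / 0 -> NaN
print(attn)  # [[nan nan nan nan]] -- NaNs that then get written to page 0
```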

Page index 0 serves a special padding role in our batching system. It's used to fill unused pages for shorter requests and to pad unused requests within a batch.

Under normal circumstances, NaNs in page 0 wouldn't be problematic since our masking system is designed to ignore values beyond the current token. For example, when generating token 17 with a page list of [255, 254, 0], we should never need to read from the padding page.
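For concreteness, a toy model of that lookup (the 16-token page size and the helper are made up for illustration):

```python
TOKENS_PER_PAGE = 16          # assumed page size, for illustration only
page_list = [255, 254, 0]     # page 0 pads the unused tail

def page_for_token(pos: int) -> int:
    return page_list[pos // TOKENS_PER_PAGE]

# Token 17 lives on page 254; the padding page 0 is never read.
print(page_for_token(17))  # -> 254
```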

The issue stems from our current masking implementation. Instead of directly ignoring masked values, we mask by adding negative infinity to the attention scores before applying the exponential: exp(-inf) evaluates to zero, so masked positions normally contribute nothing. This breaks down on NaNs, because NaN + (-inf) is still NaN and exp(NaN) is NaN. When that happens, NaN values read from page 0 leak into our calculations, resulting in token corruption.
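The failure is easy to reproduce in isolation, since additive masking cannot suppress a NaN. A minimal numpy demonstration:

```python
import numpy as np

scores = np.array([1.0, 2.0, np.nan])   # last score read from padding page 0
mask = np.array([0.0, 0.0, -np.inf])    # "ignore" the padding position
weights = np.exp(scores + mask)         # nan + -inf -> nan; exp(nan) -> nan
print(weights)                          # [2.718  7.389    nan]
probs = weights / weights.sum()         # the NaN poisons the normalization
print(probs)                            # [nan nan nan]
```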

@renxida renxida marked this pull request as ready for review December 9, 2024 21:20
@renxida renxida requested review from stbaione and rsuderman December 9, 2024 23:15
@stbaione (Contributor) left a comment:
LGTM, once pre-commit passes and leak is figured out in ASan

@renxida renxida force-pushed the bslfix branch 2 times, most recently from 6767c34 to dc9544a on December 10, 2024 17:14
@renxida renxida enabled auto-merge (squash) December 10, 2024 17:50
@renxida renxida merged commit 4c015d4 into nod-ai:main Dec 10, 2024
20 checks passed