Early stop when all sequences reach EOS #57

Open
wants to merge 4 commits into base: main

Conversation


@je1lee je1lee commented Apr 9, 2024

With model.generate(), generation takes too long even when every sequence has already finished with an EOS token, because it currently keeps generating until it reaches output_len.

This PR fixes the generate method so that it stops once every sequence has generated an EOS token.
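
To illustrate the idea, here is a minimal sketch in plain Python. It is not the actual diff in this PR: `generate_with_early_stop`, `toy_sampler`, and `EOS_ID` are hypothetical names, and the sampler stands in for the real model forward pass. The loop tracks a per-sequence finished flag and breaks as soon as every sequence in the batch has produced EOS.

```python
from typing import Callable, List, Tuple

EOS_ID = 1  # assumed EOS token id, for illustration only


def generate_with_early_stop(
    sample_next: Callable[[List[List[int]]], List[int]],
    prompts: List[List[int]],
    output_len: int,
    eos_id: int = EOS_ID,
) -> Tuple[List[List[int]], int]:
    """Decode at most output_len tokens, stopping once all sequences hit EOS.

    Returns the sequences and the number of decode steps actually run.
    """
    sequences = [list(p) for p in prompts]
    finished = [False] * len(sequences)
    steps = 0
    for _ in range(output_len):
        next_tokens = sample_next(sequences)  # one token per sequence
        steps += 1
        for i, tok in enumerate(next_tokens):
            if finished[i]:
                # This sequence already ended; pad with EOS and ignore the sample.
                sequences[i].append(eos_id)
            else:
                sequences[i].append(tok)
                finished[i] = tok == eos_id
        if all(finished):
            break  # every sequence has emitted EOS, no need to reach output_len
    return sequences, steps


def toy_sampler(sequences: List[List[int]]) -> List[int]:
    # Hypothetical stand-in for the model: emit EOS once a sequence has 4 tokens.
    return [EOS_ID if len(s) >= 4 else 7 for s in sequences]


seqs, steps = generate_with_early_stop(toy_sampler, [[5, 6], [5, 6, 8]], output_len=100)
# With this toy sampler the loop runs 3 decode steps instead of 100.
```

In a tensor-based implementation the finished flags would presumably be a boolean tensor and the check a single reduction per step, but the control flow is the same.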

@je1lee
Author

je1lee commented Apr 16, 2024

@pengchongjin any thoughts on this?

@pengchongjin
Collaborator

Thanks for the change. Could you please paste a few example outputs before and after this change?

Also please make sure to test both run.py and run_xla.py. Thanks!

@je1lee
Author

je1lee commented Jun 3, 2024

@pengchongjin
Tested with both scripts, run.py and run_xla.py.

BEFORE
[Screenshot, 2024-06-03 2:59:25 PM]

The model keeps generating tokens regardless of the EOS token, so the time spent in generation increases quadratically as output_len increases.

AFTER
[Screenshot, 2024-06-03 2:58:04 PM]

The model stops generating once it samples an EOS token, so the time spent in generation stays flat as output_len increases.


2 participants