Support Eleuther LM-Eval-Harness in Levanter #675
Conversation
whole_enc = self.tokenizer(context + completion)
context_enc = self.tokenizer(context)
Must we run the tokenizer twice?
It's how it's done in the LM harness and inside alpaca; it's the easiest thing and not a bottleneck.
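(For context, a minimal sketch of the two-pass pattern being discussed, using a Hugging Face tokenizer; the model name and helper are illustrative, not the Levanter code. Tokenizing the concatenation and the context separately gives the boundary after which tokens belong to the completion, which is what the harness scores.)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice

def split_context_completion(context: str, completion: str):
    # Tokenize the full string once and the context alone once more.
    whole_enc = tokenizer(context + completion)["input_ids"]
    context_enc = tokenizer(context)["input_ids"]
    # Tokens past the context's length are treated as completion tokens;
    # those are the positions whose log-likelihoods get summed for the score.
    n_ctx = len(context_enc)
    return whole_enc[:n_ctx], whole_enc[n_ctx:]

ctx_tokens, completion_tokens = split_context_completion("The capital of France is", " Paris")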
task: str
task_alias: str | None = None
num_fewshot: int | None = None
What does None represent that's not representable by an integer?
None means the default for the task, whatever it is.
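(A small sketch of the Optional semantics described above; the effective_fewshot helper is hypothetical, not part of the PR.)

from dataclasses import dataclass

@dataclass
class TaskConfig:
    task: str
    task_alias: str | None = None
    # None means "use whatever default the task itself defines"; an explicit int overrides it.
    num_fewshot: int | None = None

    def effective_fewshot(self, task_default: int) -> int:
        # hypothetical helper showing how None falls back to the task's own default
        return self.num_fewshot if self.num_fewshot is not None else task_default

assert TaskConfig("hellaswag").effective_fewshot(10) == 10
assert TaskConfig("hellaswag", num_fewshot=0).effective_fewshot(10) == 0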
return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec]

def to_task_dict(self) -> dict:
    import lm_eval.tasks as tasks
Add a docstring on what this function is doing / what it's for?
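(One possible docstring, as a concrete suggestion; the wording is inferred from how the function is used later in this file, so treat it as a draft.)

def to_task_dict(self) -> dict:
    """Resolve this config's task spec into lm_eval task objects.

    Each entry in task_spec (a bare task name or a TaskConfig) is looked up through
    lm_eval's task registry, and the result is a mapping of task name -> task object
    that can be handed directly to the harness evaluator.
    """
    import lm_eval.tasks as tasks
    ...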
src/levanter/eval_harness.py
EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
# we always log_samples here and filter out the samples later if we don't want them
log_samples is a verb?
I mean, it's a verb phrase :-)
Also, we don't actually need this behavior, because I can't do the metrics I want with the samples anyhow.
Oh, I lied, we do.
NAT_TO_BIT = 1 / np.log(2)

# eval_harness isn't consistent enough for this to actually be workable |
What does this mean? The LM evaluation harness is just a framework, so whether it's meaningful or not depends on the actual eval and the model size?
Well, the samples for multiple-choice tasks don't all have the answer in a standard format, even though they easily could (and must internally at some point).
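(For reference on the NAT_TO_BIT constant in the snippet above: model losses come out in nats, and multiplying by 1/ln(2) converts them to bits, e.g. for bits-per-byte style metrics. The nats_to_bits helper below is illustrative, not part of the PR.)

import numpy as np

NAT_TO_BIT = 1 / np.log(2)

def nats_to_bits(loss_nats: float) -> float:
    # log-base change: x nats == x / ln(2) bits
    return loss_nats * NAT_TO_BIT

assert abs(nats_to_bits(np.log(2)) - 1.0) < 1e-9  # ln(2) nats is exactly one bit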
return self.trainer.EvalBatch

@cached_property
def the_tokenizer(self):
It's not obvious what this function is doing from the name... maybe tokenizer_object or something?
I use this convention throughout Levanter...
to_log = {}
for task_name, task_results in report["results"].items():
    for metric_name, metric_value in task_results.items():
        if metric_name.endswith(",none"):
What is this hackery? Put the assumptions in a comment.
They just add ",none" to all the default metrics for some reason, e.g. it's "acc,none", "acc_norm,none", etc.
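(A sketch of the normalization being described, with a toy report dict; the logged key format is illustrative, not necessarily what the PR does.)

def clean_metric_name(metric_name: str) -> str:
    # lm_eval suffixes default-filter metrics with ",none" (e.g. "acc,none", "acc_norm,none");
    # strip that suffix so the logged names are just "acc", "acc_norm", etc.
    suffix = ",none"
    return metric_name[: -len(suffix)] if metric_name.endswith(suffix) else metric_name

report = {"results": {"hellaswag": {"alias": "hellaswag", "acc,none": 0.62, "acc_norm,none": 0.80}}}  # toy example

to_log = {}
for task_name, task_results in report["results"].items():
    for metric_name, metric_value in task_results.items():
        if isinstance(metric_value, (int, float)):
            to_log[f"{task_name}/{clean_metric_name(metric_name)}"] = metric_value

print(to_log)  # {'hellaswag/acc': 0.62, 'hellaswag/acc_norm': 0.8}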
)

if jax.process_index() == 0:
    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
A tmp file that doesn't get deleted doesn't sound good - I'd delete it or put the file in the evaluation directory (which I guess would have to be passed in as an explicit location).
It's delete=False so that the wandb process can still upload it.
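(Sketch of the pattern under discussion; it assumes an active wandb run and uses a stand-in report dict. delete=False keeps the file on disk after the with block so the asynchronous wandb upload can still read it.)

import json
import tempfile

import wandb

report = {"results": {}}  # stand-in for the harness report dict

with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
    json.dump(report, f)
    f.flush()
    report_path = f.name  # file survives the with-block because delete=False

wandb.save(report_path)  # uploaded alongside the run; the temp file itself is left behind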
@@ -122,6 +126,14 @@ def main(config: TrainLmConfig):
    Pos = config.model.Pos
    KeyPos = config.model.KeyPos

    # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
Why is this moved up?
ignore_id: Optional[int] = None,
all_causal: bool = True,
) -> "LmExample":
    # mask out the prompt tokens
Needs a docstring.
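(Illustrative numpy sketch of what "mask out the prompt tokens" means here: zero loss weight on prompt positions, and optionally on ignore_id tokens. The real Levanter code uses named arrays and also has to account for the next-token shift.)

from typing import Optional

import numpy as np

def prompt_loss_mask(tokens: np.ndarray, prompt_length: int, ignore_id: Optional[int] = None) -> np.ndarray:
    positions = np.arange(len(tokens))
    # positions before prompt_length belong to the prompt and contribute no loss
    mask = (positions >= prompt_length).astype(np.float32)
    if ignore_id is not None:
        # also exclude padding / ignored tokens from the loss
        mask = mask * (tokens != ignore_id).astype(np.float32)
    return mask

print(prompt_loss_mask(np.array([5, 6, 7, 8, 0, 0]), prompt_length=3, ignore_id=0))  # [0. 0. 0. 1. 0. 0.]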
src/levanter/eval_harness.py
task_dict = tasks.get_task_dict([task], manager)
this_task = task_dict.popitem()[1]
# hacky, but this allows us to run multiple instances of the same task with different fewshot settings
this_task.config.task = our_name
I think this works well: since the result file is saved as task.jsonl, using the alias helps us distinguish the number of shots. It might be worth moving this logic into a helper function create_task_with_unique_name() for readability...
Good idea!
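(Rough sketch of the suggested helper; the signature is a guess, and the body is just the lines from the diff above factored out.)

def create_task_with_unique_name(task, our_name, manager):
    """Resolve task via lm_eval and rename it to our_name, so the same task can be
    run several times with different fewshot settings without the results colliding."""
    import lm_eval.tasks as tasks

    task_dict = tasks.get_task_dict([task], manager)
    this_task = task_dict.popitem()[1]
    # hacky, but this allows us to run multiple instances of the same task with different fewshot settings
    this_task.config.task = our_name
    return this_task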
src/levanter/eval_harness.py
return outputs


def _actually_run_eval_harness(config: LmEvalHarnessConfig, model, tasks_to_run, tokenizer, EvalBatch, axis_resources): |
Typing + a docstring for the outputs would be useful here.
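(Something along these lines might cover it; the return description is inferred from how the output is consumed elsewhere, so treat it as a suggestion.)

def _actually_run_eval_harness(
    config: LmEvalHarnessConfig,
    model,
    tasks_to_run: dict,
    tokenizer,
    EvalBatch,
    axis_resources,
) -> dict:
    """Run the lm_eval harness over tasks_to_run with model.

    Returns the report dict produced by the harness, whose "results" mapping
    (task name -> metric name -> value) is what gets logged upstream.
    """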
return outputs


def _compute_averages(outputs): |
Unrelated to this PR specifically, but we could add more ways to aggregate results. The DCLM paper, for example, reports centered accuracy. Maybe we could have this aggregation function be something we pass into LmEvalHarnessConfig, or something like subtract_random_baseline: true in the YAML config...
IMHO we should modify the LM harness to do that, but it's a great idea.
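(Sketch of the centered-accuracy idea from the DCLM paper as I understand it; the helper and the numbers are illustrative, not part of the PR.)

def centered_accuracy(acc: float, random_baseline: float) -> float:
    # rescale so random guessing maps to 0.0 and perfect accuracy maps to 1.0
    return (acc - random_baseline) / (1.0 - random_baseline)

# e.g. a 4-way multiple-choice task has a random baseline of 0.25
print(centered_accuracy(0.62, 0.25))  # ~0.493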
Adds Eleuther's LM Eval Harness as a callback in Levanter. It's much slower than it needs to be because I'm not doing any sequence packing, but it gets the job done. Scores on Llama 3 seem reasonable, so I think this is right.
Closes #564