Merge pull request #75 from mozilla-ai/davide/prometheus

Added v0 of the Prometheus lm-buddy entrypoint.
Showing 11 changed files with 355 additions and 6 deletions.
@@ -0,0 +1,35 @@
dataset:
  load_from:
    name: "wandb_file_artifact_name.json"
    version: "latest"
    project: "lm-buddy-prometheus"
    entity: "mozilla-ai"
  # field containing scoring instructions in the json file
  text_field: "instruction"

prometheus:
  inference:
    base_url: "http://your.vllm.server:8000/v1"
    engine: "kaist-ai/prometheus-13b-v1.0"
  best_of: 1
  max_tokens: 512
  frequency_penalty: 1.03
  temperature: 1.0
  top_p: 0.9

evaluation:
  # number of times a model is evaluated per sample
  num_answers: 3
  # max number of retries if a communication error
  # with the server occurs
  max_retries: 5
  # min and max scores as defined in the scoring rubric
  min_score: 1
  max_score: 5
  # enable/disable tqdm to track eval progress
  enable_tqdm: True

tracking:
  name: "lm-buddy-prometheus"
  project: "lm-buddy-examples"
  entity: "mozilla-ai"
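A minimal sketch (not part of the diff) showing how the YAML above maps onto the fields the entrypoint reads; PyYAML and the config file name are assumptions for illustration:

    # load the example config and sanity-check a few of the nested fields
    import yaml

    with open("prometheus_eval_config.yaml") as f:  # hypothetical file name
        cfg = yaml.safe_load(f)

    print(cfg["prometheus"]["inference"]["base_url"])  # http://your.vllm.server:8000/v1
    print(cfg["evaluation"]["num_answers"])            # 3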
@@ -4,11 +4,12 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "lm-buddy"
-version = "0.2.8"
+version = "0.3.0"
 authors = [
     { name = "Sean Friedowitz", email = "[email protected]" },
     { name = "Aaron Gonzales", email = "[email protected]" },
     { name = "Vicki Boykis", email = "[email protected]" },
+    { name = "Davide Eynard", email = "[email protected]" },
 ]
 description = "Ray-centric library for finetuning and evaluation of (large) language models."
 readme = "README.md"

@@ -37,6 +38,8 @@ dependencies = [
     # Evaluation frameworks
     "lm-eval==0.4.1",
     "einops==0.7.0",
+    "fschat==0.2.36",
+    "openai==1.3.9",
 ]

 [project.optional-dependencies]
[Diffs for the remaining changed files are not shown here.]
@@ -1,5 +1,6 @@
 from lm_buddy.jobs._entrypoints.finetuning import run_finetuning
 from lm_buddy.jobs._entrypoints.lm_harness import run_lm_harness
+from lm_buddy.jobs._entrypoints.prometheus import run_prometheus
 from lm_buddy.jobs._entrypoints.simple import run_simple

-__all__ = ["run_finetuning", "run_lm_harness", "run_simple"]
+__all__ = ["run_finetuning", "run_lm_harness", "run_prometheus", "run_simple"]
@@ -0,0 +1,180 @@
"""
lm-buddy entrypoint to run evaluations using a Prometheus inference server
see https://github.com/kaistAI/prometheus/blob/main/evaluation/benchmark/run_absolute_scoring.py
"""

import copy
import json
from pathlib import Path

from datasets import load_dataset
from fastchat.conversation import get_conv_template
from openai import OpenAI, OpenAIError
from openai.types import Completion
from tqdm import tqdm

from lm_buddy.integrations.huggingface import HuggingFaceAssetLoader
from lm_buddy.integrations.huggingface.tokenizer_config import AutoTokenizerConfig
from lm_buddy.integrations.wandb import (
    ArtifactLoader,
    ArtifactType,
    build_directory_artifact,
    wandb_init_from_config,
)
from lm_buddy.jobs.common import LMBuddyJobType
from lm_buddy.jobs.configs import PrometheusJobConfig


class BadResponseError(Exception):
    """Raised when the inference server returns an empty or malformed response."""

    def __init__(self, message, error=None):
        super().__init__(message)
        self.message = message
        self.error = error


def openai_completion(config: PrometheusJobConfig, client: OpenAI, prompt: str) -> Completion:
    """Connects to a remote OpenAI-API-compatible Prometheus endpoint
    and returns a Completion holding the model's response.
    """
    # vLLM serves the plain /v1/completions route, so the completions API
    # is used here rather than chat completions
    return client.completions.create(
        model=config.prometheus.inference.engine,
        prompt=prompt,
        best_of=config.prometheus.best_of,
        max_tokens=config.prometheus.max_tokens,
        frequency_penalty=config.prometheus.frequency_penalty,
        temperature=config.prometheus.temperature,
        top_p=config.prometheus.top_p,
    )


def parse_response(config: PrometheusJobConfig, response: Completion) -> tuple[str, str]:
    """Given a Prometheus eval response as returned by the OpenAI API
    endpoint (i.e. in Completion format), extract feedback
    and score.
    """
    if response is None:
        raise BadResponseError("Server returned an empty response")

    try:
        response_text = response.choices[0].text
        # note: this can raise a ValueError if the message is malformed
        feedback, score = response_text.split("[RESULT]")
        feedback = feedback.strip()
        score = score.strip()
        # valid scores are the integers between min_score and max_score
        # (inclusive), as defined in the scoring rubric
        valid_scores = [
            str(s) for s in range(config.evaluation.min_score, config.evaluation.max_score + 1)
        ]
        if score not in valid_scores:
            raise BadResponseError(f"Score {score} is not in range")
    except (ValueError, BadResponseError) as e:
        raise BadResponseError(f"Server returned a malformed response ({e})", e) from e

    return feedback, score
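
# Example (annotation added here, not part of the original file): a well-formed
# Prometheus completion looks like
#   "The response follows the rubric closely ... [RESULT] 4"
# and parse_response() splits it into the pair
#   ("The response follows the rubric closely ...", "4")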


def instruction_to_prompt(config: PrometheusJobConfig, instruction: str) -> str:
    """Given some text containing Prometheus instructions, a conversation
    template (e.g. "llama-2") and a system message (e.g. "You are a
    fair evaluator language model"), generate an actual prompt.
    """
    conv = get_conv_template(config.evaluation.conversation_template)
    conv.set_system_message(config.evaluation.conversation_system_message)
    conv.append_message(conv.roles[0], instruction)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()
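
# For reference (hedged: the exact rendering depends on the fastchat version),
# the "llama-2" conversation template produces a prompt roughly of the form
#   [INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{instruction} [/INST]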


def get_response_with_retries(
    config: PrometheusJobConfig, client: OpenAI, prompt: str, max_retries: int
) -> tuple[str, str]:
    current_retry_attempt = 1
    while current_retry_attempt <= max_retries:
        try:
            response = openai_completion(config, client, prompt)
            feedback, score = parse_response(config, response)
            break
        except (OpenAIError, BadResponseError) as e:
            print(f"[w] {e}, retrying ({current_retry_attempt}/{max_retries})")
            current_retry_attempt += 1
            if current_retry_attempt > max_retries:
                raise
    return (feedback, score)


def run_eval(
    config: PrometheusJobConfig,
    artifact_loader: ArtifactLoader,
    client: OpenAI,
) -> str:
    # load dataset from W&B artifact
    hf_loader = HuggingFaceAssetLoader(artifact_loader)
    data = hf_loader.load_dataset(config.dataset)

    # get the tokenizer
    tokenizer_config = AutoTokenizerConfig(load_from=config.prometheus.inference.engine)
    tokenizer = hf_loader.load_pretrained_tokenizer(tokenizer_config)

    # enable / disable tqdm
    dataset_iterable = tqdm(data) if config.evaluation.enable_tqdm else data

    # open the output file for writing and iterate on samples
    tracking_name = config.tracking.name if config.tracking is not None else "output.json"
    output_fname = Path(config.evaluation.output_folder) / tracking_name
    output_fname.parent.mkdir(parents=True, exist_ok=True)
    with output_fname.open("w") as file:
        for sample in dataset_iterable:
            # convert instructions from the dataset (`text_field` in a dict) to
            # prompts that prometheus accepts
            prompt = instruction_to_prompt(config, sample[config.dataset.text_field])

            # skip those examples which are too long
            tokenized_prompt = tokenizer(prompt, truncation=False)
            if len(tokenized_prompt["input_ids"]) > 3072:
                continue

            # prepare output
            result = copy.deepcopy(sample)
            result["prometheus_output"] = []
            result["prometheus_score"] = []

            for _ in range(config.evaluation.num_answers):
                (feedback, score) = get_response_with_retries(
                    config, client, prompt, config.evaluation.max_retries
                )
                result["prometheus_output"].append(feedback)
                result["prometheus_score"].append(score)

            # dump sample results incrementally, one JSON record per line
            file.write(json.dumps(result) + "\n")

    # convert the plain JSON-lines output to a HF dataset and save it to disk
    output_hf_name = str(Path(config.evaluation.output_folder) / "hf" / tracking_name)
    ds = load_dataset("json", data_files=str(output_fname), split="train")
    ds.save_to_disk(output_hf_name)

    return str(output_hf_name)


def run_prometheus(config: PrometheusJobConfig, artifact_loader: ArtifactLoader):
    # instantiate OpenAI client to speak with the vLLM endpoint
    client = OpenAI(base_url=config.prometheus.inference.base_url)

    # Register a dataset file artifact if tracking is enabled
    if config.tracking:
        with wandb_init_from_config(config.tracking, job_type=LMBuddyJobType.EVALUATION):
            # run eval and store output in local filename
            output_dataset_name = run_eval(config, artifact_loader, client)

            # store HF dataset as a directory artifact
            artifact = build_directory_artifact(
                dir_path=output_dataset_name,
                artifact_name=config.tracking.name,
                artifact_type=ArtifactType.DATASET,
                reference=False,
            )
            print("Logging artifact for evaluation results...")
            artifact_loader.log_artifact(artifact)
    else:
        output_dataset_name = run_eval(config, artifact_loader, client)
        print(f"Evaluation results stored in {output_dataset_name}")