This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Merge pull request #75 from mozilla-ai/davide/prometheus
Added v0 of prometheus lm-buddy entrypoint
aittalam authored Mar 12, 2024
2 parents 22f3a80 + e36a1fd commit 88e606a
Showing 11 changed files with 355 additions and 6 deletions.
35 changes: 35 additions & 0 deletions examples/configs/prometheus/prometheus_config.yaml
@@ -0,0 +1,35 @@
dataset:
  load_from:
    name: "wandb_file_artifact_name.json"
    version: "latest"
    project: "lm-buddy-prometheus"
    entity: "mozilla-ai"
  # field containing scoring instructions in the json file
  text_field: "instruction"

prometheus:
  inference:
    base_url: "http://your.vllm.server:8000/v1"
    engine: "kaist-ai/prometheus-13b-v1.0"
  best_of: 1
  max_tokens: 512
  frequency_penalty: 1.03
  temperature: 1.0
  top_p: 0.9

evaluation:
  # number of times a model is evaluated per sample
  num_answers: 3
  # max number of retries if a communication error
  # with the server occurs
  max_retries: 5
  # min and max scores as defined in the scoring rubric
  min_score: 1
  max_score: 5
  # enable/disable tqdm to track eval progress
  enable_tqdm: True

tracking:
  name: "lm-buddy-prometheus"
  project: "lm-buddy-examples"
  entity: "mozilla-ai"
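Note: this example config is consumed by the new `prometheus` CLI command added below in src/lm_buddy/cli/run.py. A minimal programmatic sketch of the same flow (the file path simply points at this example) is:

import lm_buddy
from lm_buddy.jobs.configs import PrometheusJobConfig

# Parse the example YAML above into a validated job config and run it,
# exactly as the new CLI command does.
config = PrometheusJobConfig.from_yaml_file(
    "examples/configs/prometheus/prometheus_config.yaml"
)
lm_buddy.run_job(config)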
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -4,11 +4,12 @@ build-backend = "setuptools.build_meta"

[project]
name = "lm-buddy"
version = "0.2.8"
version = "0.3.0"
authors = [
{ name = "Sean Friedowitz", email = "[email protected]" },
{ name = "Aaron Gonzales", email = "[email protected]" },
{ name = "Vicki Boykis", email = "[email protected]" },
{ name = "Davide Eynard", email = "[email protected]" },
]
description = "Ray-centric library for finetuning and evaluation of (large) language models."
readme = "README.md"
@@ -37,6 +38,8 @@ dependencies = [
# Evaluation frameworks
"lm-eval==0.4.1",
"einops==0.7.0",
"fschat==0.2.36",
"openai==1.3.9",
]

[project.optional-dependencies]
14 changes: 13 additions & 1 deletion src/lm_buddy/cli/run.py
@@ -1,7 +1,12 @@
import click

import lm_buddy
from lm_buddy.jobs.configs import FinetuningJobConfig, LMHarnessJobConfig, SimpleJobConfig
from lm_buddy.jobs.configs import (
FinetuningJobConfig,
LMHarnessJobConfig,
PrometheusJobConfig,
SimpleJobConfig,
)

# TODO(RD2024-125): We should probably collapse all these commands into a single CLI command
# - Need to figure out best way to polymorphically deserialize the job config classes
@@ -32,3 +37,10 @@ def run_finetuning(config: str) -> None:
def run_lm_harness(config: str) -> None:
config = LMHarnessJobConfig.from_yaml_file(config)
lm_buddy.run_job(config)


@group.command("prometheus", help="Run the prometheus evaluation job.")
@click.option("--config", type=str)
def run_prometheus(config: str) -> None:
config = PrometheusJobConfig.from_yaml_file(config)
lm_buddy.run_job(config)
13 changes: 12 additions & 1 deletion src/lm_buddy/cli/schema.py
@@ -2,7 +2,12 @@

import click

from lm_buddy.jobs.configs import FinetuningJobConfig, LMHarnessJobConfig, SimpleJobConfig
from lm_buddy.jobs.configs import (
FinetuningJobConfig,
LMHarnessJobConfig,
PrometheusJobConfig,
SimpleJobConfig,
)


@click.group(name="schema", help="Get a job configuration schema.")
@@ -26,3 +31,9 @@ def schema_finetuning() -> None:
def schema_lm_harness() -> None:
schema = LMHarnessJobConfig.model_json_schema()
click.secho(json.dumps(schema, indent=2))


@group.command("prometheus", help="Schema for the prometheus job configuration.")
def schema_prometheus() -> None:
schema = PrometheusJobConfig.model_json_schema()
click.secho(json.dumps(schema, indent=2))
25 changes: 24 additions & 1 deletion src/lm_buddy/integrations/vllm.py
@@ -12,8 +12,31 @@ class InferenceServerConfig(BaseLMBuddyConfig):
Note: This configuration is intended to be generic and not bound to the interface
of any specific training/evaluation framework. See `LocalChatCompletionConfig`
for intended usage alongside a third-party framework.
or `VLLMCompletionsConfig` for intended usage alongside a third-party framework.
"""

base_url: str
engine: str | HuggingFaceAssetPath | None = None


class VLLMCompletionsConfig(BaseLMBuddyConfig):
"""Configuration for a vLLM-based completions service
The "local-chat-completions" model is powered by a self-hosted inference server,
specified as an `InferenceServerConfig`. Additional arguments are also provided
to control the tokenizer type and generation parameters.
Note that this is just a subset of the parameters allowed by a vLLM server (see
https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py). If we
choose to use this configuration to cover for more use cases, it will make sense
to add the other supported configuration parameters too.
"""

inference: InferenceServerConfig

# vLLM-specific params
best_of: int | None = None
max_tokens: int | None = None
frequency_penalty: float | None = None
temperature: float | None = None
top_p: float | None = None
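As a rough usage sketch (values borrowed from the example config above; this snippet is not part of the diff), the two configs compose as follows:

from lm_buddy.integrations.vllm import InferenceServerConfig, VLLMCompletionsConfig

# Point the completions config at a self-hosted vLLM server and set the
# subset of sampling parameters exposed above.
prometheus = VLLMCompletionsConfig(
    inference=InferenceServerConfig(
        base_url="http://your.vllm.server:8000/v1",
        engine="kaist-ai/prometheus-13b-v1.0",
    ),
    best_of=1,
    max_tokens=512,
    frequency_penalty=1.03,
    temperature=1.0,
    top_p=0.9,
)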
35 changes: 35 additions & 0 deletions src/lm_buddy/integrations/wandb/artifact_utils.py
@@ -110,3 +110,38 @@ def build_table_artifact(
table = wandb.Table(data=table_data, columns=columns)
artifact.add(table, name=table_name)
return artifact


def build_file_artifact(
artifact_name: str,
artifact_type: ArtifactType,
file_path: str | Path,
*,
reference: bool = False,
entry_name: str | None = None,
) -> wandb.Artifact:
"""Build an artifact containing a single file
Args:
artifact_name (str): Name of the artifact
artifact_type (ArtifactType): Type of artifact
file_path (str | Path): The full path (including filename) of the file
Keyword Args:
reference (bool): Only reference the file, do not copy contents. Defaults to False.
entry_name (str | None): Name for the file within the artifact. If None, defaults
to the original filename.
Returns:
wandb.Artifact: The generated artifact.
"""
artifact = wandb.Artifact(name=artifact_name, type=artifact_type)

if reference:
artifact.add_reference(
uri=f"{ArtifactURIScheme.FILE}://{file_path}",
name=entry_name,
)
else:
artifact.add_file(str(file_path), name=entry_name)
return artifact
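A hedged usage sketch of the new helper (the artifact name and file path are illustrative; import paths follow the file layout in this diff):

from lm_buddy.integrations.wandb import ArtifactType
from lm_buddy.integrations.wandb.artifact_utils import build_file_artifact

# Wrap a single local results file in a W&B artifact; with reference=True the
# file is linked via a file:// URI instead of being copied into the artifact.
artifact = build_file_artifact(
    artifact_name="prometheus-results",
    artifact_type=ArtifactType.DATASET,
    file_path="/tmp/output.json",
    reference=True,
)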
10 changes: 9 additions & 1 deletion src/lm_buddy/jobs/__init__.py
@@ -1,9 +1,15 @@
from lm_buddy.integrations.wandb import ArtifactLoader, WandbArtifactLoader
from lm_buddy.jobs._entrypoints import run_finetuning, run_lm_harness, run_simple
from lm_buddy.jobs._entrypoints import (
run_finetuning,
run_lm_harness,
run_prometheus,
run_simple,
)
from lm_buddy.jobs.configs import (
FinetuningJobConfig,
LMBuddyJobConfig,
LMHarnessJobConfig,
PrometheusJobConfig,
SimpleJobConfig,
)

@@ -26,5 +32,7 @@ def run_job(
run_finetuning(finetuning_config, artifact_loader)
case LMHarnessJobConfig() as lm_harness_config:
run_lm_harness(lm_harness_config, artifact_loader)
case PrometheusJobConfig() as prometheus_config:
run_prometheus(prometheus_config, artifact_loader)
case _:
raise ValueError(f"Received invalid job configuration: {config}")
3 changes: 2 additions & 1 deletion src/lm_buddy/jobs/_entrypoints/__init__.py
@@ -1,5 +1,6 @@
from lm_buddy.jobs._entrypoints.finetuning import run_finetuning
from lm_buddy.jobs._entrypoints.lm_harness import run_lm_harness
from lm_buddy.jobs._entrypoints.prometheus import run_prometheus
from lm_buddy.jobs._entrypoints.simple import run_simple

__all__ = ["run_finetuning", "run_lm_harness", "run_simple"]
__all__ = ["run_finetuning", "run_lm_harness", "run_prometheus", "run_simple"]
180 changes: 180 additions & 0 deletions src/lm_buddy/jobs/_entrypoints/prometheus.py
@@ -0,0 +1,180 @@
"""
lm-buddy entrypoint to run evaluations using a Prometheus inference server
see https://github.com/kaistAI/prometheus/blob/main/evaluation/benchmark/run_absolute_scoring.py
"""

import copy
import json
from dataclasses import dataclass
from pathlib import Path

from datasets import load_dataset
from fastchat.conversation import get_conv_template
from openai import Completion, OpenAI, OpenAIError
from tqdm import tqdm

from lm_buddy.integrations.huggingface import HuggingFaceAssetLoader
from lm_buddy.integrations.huggingface.tokenizer_config import AutoTokenizerConfig
from lm_buddy.integrations.wandb import (
ArtifactLoader,
ArtifactType,
build_directory_artifact,
wandb_init_from_config,
)
from lm_buddy.jobs.common import LMBuddyJobType
from lm_buddy.jobs.configs import PrometheusJobConfig


@dataclass
class BadResponseError(Exception):
def __init__(self, message, error=None):
self.message = message
self.error = error


def openai_completion(config: PrometheusJobConfig, client: OpenAI, prompt: str) -> Completion:
"""Connects to a remote OpenAI-API-compatible Prometheus endpoint
and returns a Completion holding the model's response.
"""

return client.completions.create(
model=config.prometheus.inference.engine,
prompt=prompt,
best_of=config.prometheus.best_of,
max_tokens=config.prometheus.max_tokens,
frequency_penalty=config.prometheus.frequency_penalty,
temperature=config.prometheus.temperature,
top_p=config.prometheus.top_p,
)


def parse_response(config: PrometheusJobConfig, response: Completion) -> tuple[str, str]:
"""Given a Prometheus eval response as returned by the OpenAI API
endpoint (i.e. in Completion format), extract feedback
and score.
"""

if response is None:
raise BadResponseError("Server returned an empty response")

try:
response_text = response.choices[0].text
# note: this can raise a ValueError if the message is malformed
feedback, score = response_text.split("[RESULT]")
feedback = feedback.strip()
score = score.strip()
if score not in [str(s) for s in config.evaluation.scores]:
raise BadResponseError(f"Score {score} is not in range")
except (ValueError, BadResponseError) as e:
raise BadResponseError(f"Server returned a malformed response ({e})", e)

return feedback, score


def instruction_to_prompt(config: PrometheusJobConfig, instruction: str) -> str:
"""Given some text containing Prometheus instructions, a conversation
template (e.g. "llama-2") and a system message (e.g. "You are a
fair evaluator language model"), generate an actual prompt.
"""
conv = get_conv_template(config.evaluation.conversation_template)
conv.set_system_message(config.evaluation.conversation_system_message)
conv.append_message(conv.roles[0], instruction)
conv.append_message(conv.roles[1], None)
return conv.get_prompt()


def get_response_with_retries(
config: PrometheusJobConfig, client: OpenAI, prompt: str, max_retries: int
) -> tuple[str, str]:
current_retry_attempt = 1
while current_retry_attempt <= config.evaluation.max_retries:
try:
response = openai_completion(config, client, prompt)
feedback, score = parse_response(config, response)
break
except (OpenAIError, BadResponseError) as e:
print(
f"[w] {e.message}, "
f"retrying ({current_retry_attempt}/{config.evaluation.max_retries})"
)
current_retry_attempt += 1
if current_retry_attempt > config.evaluation.max_retries:
raise e
return (feedback, score)


def run_eval(
config: PrometheusJobConfig,
artifact_loader: ArtifactLoader,
client: OpenAI,
) -> str:
# load dataset from W&B artifact
hf_loader = HuggingFaceAssetLoader(artifact_loader)
data = hf_loader.load_dataset(config.dataset)

# get the tokenizer
tokenizer_config = AutoTokenizerConfig(load_from=config.prometheus.inference.engine)
tokenizer = hf_loader.load_pretrained_tokenizer(tokenizer_config)

# enable / disable tqdm
dataset_iterable = tqdm(data) if config.evaluation.enable_tqdm else data

# open the output file for writing and iterate on samples
tracking_name = config.tracking.name if config.tracking is not None else "output.json"
output_fname = Path(config.evaluation.output_folder) / tracking_name
with output_fname.open("w") as file:
for sample in dataset_iterable:
# convert instructions from the dataset (`text_field` in a dict) to
# prompts that prometheus accepts
prompt = instruction_to_prompt(config, sample[config.dataset.text_field])

# skip those examples which are too long
tokenized_prompt = tokenizer(prompt, truncation=False)
if len(tokenized_prompt["input_ids"]) > 3072:
continue

# prepare output
result = copy.deepcopy(sample)
result["prometheus_output"] = []
result["prometheus_score"] = []

for idx in range(config.evaluation.num_answers):
(feedback, score) = get_response_with_retries(
config, client, prompt, config.evaluation.max_retries
)
result["prometheus_output"].append(feedback)
result["prometheus_score"].append(score)

# dump sample results incrementally
file.write(json.dumps(result) + "\n")

# convert plain json dataset in HF format
output_hf_name = str(Path(config.evaluation.output_folder) / "hf" / tracking_name)
ds = load_dataset("json", data_files=str(output_fname), split="train")
ds.save_to_disk(output_hf_name)

return str(output_hf_name)


def run_prometheus(config: PrometheusJobConfig, artifact_loader: ArtifactLoader):
# instantiate OpenAI client to speak with the vLLM endpoint
client = OpenAI(base_url=config.prometheus.inference.base_url)

# Register a dataset file artifact if tracking is enabled
if config.tracking:
with wandb_init_from_config(config.tracking, job_type=LMBuddyJobType.EVALUATION):
# run eval and store output in local filename
output_dataset_name = run_eval(config, artifact_loader, client)

# store HF dataset as a directory artifact
artifact = build_directory_artifact(
dir_path=output_dataset_name,
artifact_name=config.tracking.name,
artifact_type=ArtifactType.DATASET,
reference=False,
)
print("Logging artifact for evaluation results...")
artifact_loader.log_artifact(artifact)
else:
output_dataset_name = run_eval(config, artifact_loader, client)
print(f"Evaluation results stored in {output_dataset_name}")