Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Added v0 of prometheus lm-buddy entrypoint #75

Merged
merged 19 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/configs/prometheus/prometheus_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
dataset:
load_from:
name: "wandb_file_artifact_name.json"
version: "latest"
project: "lm-buddy-prometheus"
entity: "mozilla-ai"
# field containing scoring instructions in the json file
text_field: "instruction"

prometheus:
inference:
base_url: "http://your.vllm.server:8000/v1"
sfriedowitz marked this conversation as resolved.
Show resolved Hide resolved
tokenizer:
load_from: "meta-llama/Llama-2-7b-chat-hf"
max_tokens: 256
# number of times the model is called per sample
num_answers: 3

tracking:
name: "lm-buddy-prometheus"
project: "lm-buddy-examples"
entity: "mozilla-ai"
14 changes: 13 additions & 1 deletion src/lm_buddy/cli/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import click

import lm_buddy
from lm_buddy.jobs.configs import FinetuningJobConfig, LMHarnessJobConfig, SimpleJobConfig
from lm_buddy.jobs.configs import (
FinetuningJobConfig,
LMHarnessJobConfig,
PrometheusJobConfig,
SimpleJobConfig,
)

# TODO(RD2024-125): We should probably collapse all these commands into a single CLI command
# - Need to figure out best way to polymorphically deserialize the job config classes
Expand Down Expand Up @@ -32,3 +37,10 @@ def run_finetuning(config: str) -> None:
def run_lm_harness(config: str) -> None:
config = LMHarnessJobConfig.from_yaml_file(config)
lm_buddy.run_job(config)


@group.command("prometheus", help="Run the prometheus evaluation job.")
sfriedowitz marked this conversation as resolved.
Show resolved Hide resolved
@click.option("--config", type=str)
def run_prometheus(config: str) -> None:
config = PrometheusJobConfig.from_yaml_file(config)
lm_buddy.run_job(config)
13 changes: 12 additions & 1 deletion src/lm_buddy/cli/schema.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import json

import click

from lm_buddy.jobs.configs import FinetuningJobConfig, LMHarnessJobConfig, SimpleJobConfig
from lm_buddy.jobs.configs import (
FinetuningJobConfig,
LMHarnessJobConfig,
PrometheusJobConfig,
SimpleJobConfig,
)


@click.group(name="schema", help="Get a job configuration schema.")

Check failure on line 13 in src/lm_buddy/cli/schema.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (I001)

src/lm_buddy/cli/schema.py:1:1: I001 Import block is un-sorted or un-formatted
def group():
pass

Expand All @@ -26,3 +31,9 @@
def schema_lm_harness() -> None:
schema = LMHarnessJobConfig.model_json_schema()
click.secho(json.dumps(schema, indent=2))


@group.command("prometheus", help="Schema for the prometheus job configuration.")
def schema_prometheus() -> None:
schema = PrometheusJobConfig.model_json_schema()
click.secho(json.dumps(schema, indent=2))
37 changes: 36 additions & 1 deletion src/lm_buddy/integrations/wandb/artifact_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from enum import Enum
from pathlib import Path
from typing import Any
from urllib.parse import ParseResult, urlparse

import wandb

import os

Check failure on line 7 in src/lm_buddy/integrations/wandb/artifact_utils.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (F401)

src/lm_buddy/integrations/wandb/artifact_utils.py:7:8: F401 `os` imported but unused

class ArtifactType(str, Enum):

Check failure on line 9 in src/lm_buddy/integrations/wandb/artifact_utils.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (I001)

src/lm_buddy/integrations/wandb/artifact_utils.py:1:1: I001 Import block is un-sorted or un-formatted
"""Enumeration of artifact types used by the LM Buddy."""
Expand Down Expand Up @@ -110,3 +110,38 @@
table = wandb.Table(data=table_data, columns=columns)
artifact.add(table, name=table_name)
return artifact


def build_file_artifact(
aittalam marked this conversation as resolved.
Show resolved Hide resolved
artifact_name: str,
artifact_type: ArtifactType,
file_path: str | Path,
*,
reference: bool = False,
entry_name: str | None = None,
) -> wandb.Artifact:
"""Build an artifact containing a single file

Args:
artifact_name (str): Name of the artifact
artifact_type (ArtifactType): Type of artifact
file_path (str | Path): The full path (including filename) of the file

Keyword Args:
reference (bool): Only reference the file, do not copy contents. Defaults to False.
entry_name (str | None): Name for the file within the artifact. If None, defaults
to the original filename.

Returns:
wandb.Artifact: The generated artifact.
"""
artifact = wandb.Artifact(name=artifact_name, type=artifact_type)

if reference:
artifact.add_reference(
uri=f"{ArtifactURIScheme.FILE}://{file_path}",
name=entry_name,
)
else:
artifact.add_file(str(file_path), name=entry_name)
return artifact
10 changes: 9 additions & 1 deletion src/lm_buddy/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from lm_buddy.integrations.wandb import ArtifactLoader, WandbArtifactLoader
from lm_buddy.jobs._entrypoints import run_finetuning, run_lm_harness, run_simple
from lm_buddy.jobs._entrypoints import (
run_finetuning,
run_lm_harness,
sfriedowitz marked this conversation as resolved.
Show resolved Hide resolved
run_prometheus,
run_simple,
)
from lm_buddy.jobs.configs import (
FinetuningJobConfig,
LMBuddyJobConfig,
LMHarnessJobConfig,
PrometheusJobConfig,
SimpleJobConfig,
)

Expand All @@ -26,5 +32,7 @@ def run_job(
run_finetuning(finetuning_config, artifact_loader)
case LMHarnessJobConfig() as lm_harness_config:
run_lm_harness(lm_harness_config, artifact_loader)
case PrometheusJobConfig() as prometheus_config:
run_prometheus(prometheus_config, artifact_loader)
case _:
raise ValueError(f"Received invalid job configuration: {config}")
3 changes: 2 additions & 1 deletion src/lm_buddy/jobs/_entrypoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from lm_buddy.jobs._entrypoints.finetuning import run_finetuning
from lm_buddy.jobs._entrypoints.lm_harness import run_lm_harness
from lm_buddy.jobs._entrypoints.prometheus import run_prometheus
from lm_buddy.jobs._entrypoints.simple import run_simple

__all__ = ["run_finetuning", "run_lm_harness", "run_simple"]
__all__ = ["run_finetuning", "run_lm_harness", "run_prometheus", "run_simple"]
128 changes: 128 additions & 0 deletions src/lm_buddy/jobs/_entrypoints/prometheus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from lm_buddy.jobs.configs import PrometheusJobConfig
from lm_buddy.integrations.huggingface import HuggingFaceAssetLoader
from lm_buddy.integrations.wandb import (
ArtifactType,
ArtifactLoader,
build_file_artifact,
wandb_init_from_config,
)
from fastchat.conversation import get_conv_template
from transformers import AutoTokenizer

Check failure on line 10 in src/lm_buddy/jobs/_entrypoints/prometheus.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (F401)

src/lm_buddy/jobs/_entrypoints/prometheus.py:10:26: F401 `transformers.AutoTokenizer` imported but unused
from openai import OpenAIError, OpenAI

from tqdm import tqdm
import os
import json
import copy

class BadResponseException(Exception):

Check failure on line 18 in src/lm_buddy/jobs/_entrypoints/prometheus.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (I001)

src/lm_buddy/jobs/_entrypoints/prometheus.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 18 in src/lm_buddy/jobs/_entrypoints/prometheus.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (N818)

src/lm_buddy/jobs/_entrypoints/prometheus.py:18:7: N818 Exception name `BadResponseException` should be named with an Error suffix
aittalam marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, message, error):
self.message = message
self.error = error


def openai_completion(config, client, prompt):
aittalam marked this conversation as resolved.
Show resolved Hide resolved
return client.completions.create(
model = "kaist-ai/prometheus-13b-v1.0",
aittalam marked this conversation as resolved.
Show resolved Hide resolved
prompt = prompt,
best_of = config.prometheus.best_of,
max_tokens = config.prometheus.max_tokens,
frequency_penalty = config.prometheus.frequency_penalty,
temperature = config.prometheus.temperature,
top_p = config.prometheus.top_p
)


def parse_response(response):
try:
aittalam marked this conversation as resolved.
Show resolved Hide resolved
assert response is not None
response_text = response.choices[0].text
feedback, score = response_text.split('[RESULT]')
feedback = feedback.strip()
score = score.strip()
assert score in ["1","2","3","4","5"]
except (ValueError, AssertionError) as e:
raise BadResponseException("Server returned a bad response", e)

return feedback, score


def instruction_to_prompt(instruction):
conv = get_conv_template("llama-2")
aittalam marked this conversation as resolved.
Show resolved Hide resolved
conv.set_system_message("You are a fair evaluator language model.")
conv.append_message(conv.roles[0], instruction)
conv.append_message(conv.roles[1], None)
return conv.get_prompt()


def run_prometheus(config: PrometheusJobConfig, artifact_loader: ArtifactLoader):

# load dataset from W&B artifact
hf_loader = HuggingFaceAssetLoader(artifact_loader)
artifact_path,_ = hf_loader.resolve_asset_path(config.dataset.load_from)
dataset_fname = os.path.join(artifact_path, config.dataset.load_from.name)

Check failure on line 63 in src/lm_buddy/jobs/_entrypoints/prometheus.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (PTH118)

src/lm_buddy/jobs/_entrypoints/prometheus.py:63:21: PTH118 `os.path.join()` should be replaced by `Path` with `/` operator
aittalam marked this conversation as resolved.
Show resolved Hide resolved

with open(dataset_fname,'r') as f:

Check failure on line 65 in src/lm_buddy/jobs/_entrypoints/prometheus.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (UP015)

src/lm_buddy/jobs/_entrypoints/prometheus.py:65:10: UP015 Unnecessary open mode parameters

Check failure on line 65 in src/lm_buddy/jobs/_entrypoints/prometheus.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (PTH123)

src/lm_buddy/jobs/_entrypoints/prometheus.py:65:10: PTH123 `open()` should be replaced by `Path.open()`
# eval samples are JSON-encoded, each takes one line in the dataset file
data = [json.loads(line) for line in f.readlines()]

# get the tokenizer
tokenizer = hf_loader.load_pretrained_tokenizer(config.prometheus.tokenizer)

# instantiate OpenAI client to speak with the vLLM endpoint
client = OpenAI(
base_url = config.prometheus.inference.base_url
)

# open the output file for writing and iterate on samples
output_fname = os.path.join("/tmp", config.tracking.name)

Check failure on line 78 in src/lm_buddy/jobs/_entrypoints/prometheus.py

View workflow job for this annotation

GitHub Actions / PR Checks

Ruff (PTH118)

src/lm_buddy/jobs/_entrypoints/prometheus.py:78:20: PTH118 `os.path.join()` should be replaced by `Path` with `/` operator
aittalam marked this conversation as resolved.
Show resolved Hide resolved
with open(output_fname,'w') as file:
for sample in tqdm(data):
aittalam marked this conversation as resolved.
Show resolved Hide resolved
# convert instructions from the dataset (`text_field` in a dict) to
# prompts that prometheus accepts
prompt = instruction_to_prompt(sample[config.dataset.text_field])

# skip those examples which are too long
tokenized_prompt = tokenizer(prompt, truncation=False)
if(len(tokenized_prompt['input_ids'])>3072):
continue
aittalam marked this conversation as resolved.
Show resolved Hide resolved

# prepare output
result = copy.deepcopy(sample)
result['prometheus_output'] = []
result['prometheus_score'] = []

for idx in range(config.prometheus.num_answers):

i = 0
aittalam marked this conversation as resolved.
Show resolved Hide resolved
while i < config.prometheus.max_retries:
try:
response = openai_completion(config, client, prompt)
feedback, score = parse_response(response)
print(feedback, score)
aittalam marked this conversation as resolved.
Show resolved Hide resolved
break
except (OpenAIError, BadResponseException) as e:
print(f"[w] {e.message}, retrying ({i+1}/{config.prometheus.max_retries})")
i += 1
if i == config.prometheus.max_retries:
raise e

result['prometheus_output'].append(feedback)
result['prometheus_score'].append(score)

# dump sample results
file.write(json.dumps(result)+"\n")


# Register a dataset file artifact if tracking is enabled
if config.tracking:

with wandb_init_from_config(config.tracking) as run:
aittalam marked this conversation as resolved.
Show resolved Hide resolved
file_artifact = build_file_artifact(
artifact_name = config.tracking.name,
artifact_type = ArtifactType.DATASET,
file_path = output_fname,
reference = False,
)
print("[i] Logging artifact for evaluation results...")
aittalam marked this conversation as resolved.
Show resolved Hide resolved
artifact_loader.log_artifact(file_artifact)
3 changes: 3 additions & 0 deletions src/lm_buddy/jobs/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LMHarnessJobConfig,
LocalChatCompletionsConfig,
)
from lm_buddy.jobs.configs.prometheus import PrometheusCompletionsConfig, PrometheusJobConfig
from lm_buddy.jobs.configs.simple import SimpleJobConfig

__all__ = [
Expand All @@ -15,5 +16,7 @@
"LMHarnessEvaluatorConfig",
"LMHarnessJobConfig",
"LocalChatCompletionsConfig",
"PrometheusCompletionsConfig",
"PrometheusJobConfig",
"SimpleJobConfig",
]
42 changes: 42 additions & 0 deletions src/lm_buddy/jobs/configs/prometheus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Literal

from pydantic import conlist, model_validator

from lm_buddy.types import BaseLMBuddyConfig
from lm_buddy.jobs.configs import LMBuddyJobConfig
from lm_buddy.integrations.wandb import WandbRunConfig
from lm_buddy.integrations.vllm import InferenceServerConfig
from lm_buddy.integrations.huggingface import TextDatasetConfig, AutoTokenizerConfig

class PrometheusCompletionsConfig(BaseLMBuddyConfig):
"""Configuration for a "local-completions" prometheus model.

The prometheus model is powered by a self-hosted inference server, specified
as an `InferenceServerConfig`. Additional arguments are also provided
to control the tokenizer type and generation parameters.
"""

inference: InferenceServerConfig

# vLLM-served model params
best_of: int = 1
max_tokens: int = 512
frequency_penalty: float = 1.03
temperature: float = 1.0
top_p: float = 0.9

# evaluation script params
aittalam marked this conversation as resolved.
Show resolved Hide resolved
tokenizer: AutoTokenizerConfig | None = None
num_answers: int = 3
max_retries: int = 5


class PrometheusJobConfig(LMBuddyJobConfig):
"""Configuration to run a prometheus evaluation job."""

# dataset (json artifact from which we'll extract `text_field`)
dataset: TextDatasetConfig
aittalam marked this conversation as resolved.
Show resolved Hide resolved
# details for our self-hosted prometheus endpoint
prometheus: PrometheusCompletionsConfig
# wandb experiment tracking details
tracking: WandbRunConfig | None = None