Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #95 from mozilla-ai/sfriedowitz/patch-prometheus
Browse files Browse the repository at this point in the history
Move client inside data generator because it's not serializable
  • Loading branch information
Sean Friedowitz authored Apr 9, 2024
2 parents 706ac29 + 2fb8a05 commit fa3a054
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "lm-buddy"
version = "0.10.0"
version = "0.10.1"
authors = [
{ name = "Sean Friedowitz", email = "[email protected]" },
{ name = "Aaron Gonzales", email = "[email protected]" },
Expand Down
15 changes: 8 additions & 7 deletions src/lm_buddy/jobs/evaluation/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_response_with_retries(
feedback, score = parse_response(config, response)
break
except (OpenAIError, BadResponseError) as e:
logger.warn(
logger.warning(
f"{e.message}: "
f"Retrying ({current_retry_attempt}/{config.evaluation.max_retries})"
)
Expand All @@ -111,12 +111,8 @@ def get_response_with_retries(


def run_eval(config: PrometheusJobConfig) -> Path:
# Instantiate OpenAI client to speak with the vLLM endpoint
client = OpenAI(base_url=config.prometheus.inference.base_url)

hf_loader = HuggingFaceAssetLoader()

# Resolve the engine model
hf_loader = HuggingFaceAssetLoader()
engine_path = hf_loader.resolve_asset_path(config.prometheus.inference.engine)

# Load dataset from W&B artifact
Expand All @@ -135,6 +131,11 @@ def run_eval(config: PrometheusJobConfig) -> Path:

# Generator that iterates over samples and yields new rows with the prometheus outputs
def data_generator():
# Instantiate OpenAI client to speak with the vLLM endpoint
# Client is non-serializable so must be instantiated internal to this method
# Reference: https://huggingface.co/docs/datasets/en/troubleshoot#pickling-issues
client = OpenAI(base_url=config.prometheus.inference.base_url)

for sample in dataset_iterable:
# convert instructions from the dataset (`text_field` in a dict) to
# prompts that prometheus accepts
Expand All @@ -143,7 +144,7 @@ def data_generator():
# skip those examples which are too long
tokenized_prompt = tokenizer(prompt, truncation=False)
if len(tokenized_prompt["input_ids"]) > 3072:
logger.warn(f"Skipping row due to prompt exceeding token limit: {prompt=}")
logger.warning(f"Skipping row due to prompt exceeding token limit: {prompt=}")
continue

# prepare output
Expand Down

0 comments on commit fa3a054

Please sign in to comment.