
Commit

starting eval job updates
Sean Friedowitz committed Jan 18, 2024
1 parent 4da64d1 commit 6b3248e
Showing 7 changed files with 40 additions and 67 deletions.
4 changes: 2 additions & 2 deletions examples/configs/lm_harness_config.yaml
@@ -23,6 +23,6 @@ tracking:
entity: "another-entity"

ray:
use_gpu: True
num_workers: 4
num_cpus: 1
num_gpus: 4
timeout: 3600
1 change: 1 addition & 0 deletions src/flamingo/integrations/wandb/utils.py
@@ -40,6 +40,7 @@ def get_artifact_directory(artifact: wandb.Artifact) -> str:
case 1:
return list(dir_paths)[0]
case _:
# TODO: Can this be resolved some other way?
dir_string = ",".join(dir_paths)
raise ValueError(
f"Artifact {artifact.name} references multiple directories: {dir_string}. "
2 changes: 2 additions & 0 deletions src/flamingo/jobs/drivers/finetuning.py
@@ -144,6 +144,8 @@ def run_finetuning(config: FinetuningJobConfig):
# Must resume from the just-completed training run
with wandb.init(**config.tracking.wandb_init_args(), resume="must") as run:
artifact_type = ArtifactType.MODEL.value
print(f"Generating {artifact_type} artifact of training results...")

artifact_name = f"{config.tracking.name or config.tracking.run_id}-{artifact_type}"
artifact = wandb.Artifact(artifact_name, type=artifact_type)
artifact.add_reference(f"file://{result.checkpoint.path}/checkpoint")
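For context on the artifact block above: it follows the usual W&B pattern of resuming the training run, creating an artifact, attaching a `file://` reference to the checkpoint rather than uploading it, and logging the artifact to the run. A minimal standalone sketch of that pattern, with hypothetical project, run id, and checkpoint path values standing in for the job config:

```python
import wandb

# Hypothetical stand-ins for config.tracking.wandb_init_args() and the Ray training result
init_args = {"project": "my-project", "id": "abc123xy", "name": "my-finetuning-run"}
checkpoint_path = "/tmp/ray_results/my-finetuning-run"

# resume="must" requires that the run id above already exists (the just-completed training run)
with wandb.init(**init_args, resume="must") as run:
    artifact = wandb.Artifact("my-finetuning-run-model", type="model")
    # add_reference records a pointer to the checkpoint directory instead of copying files into W&B
    artifact.add_reference(f"file://{checkpoint_path}/checkpoint")
    run.log_artifact(artifact)
```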
89 changes: 29 additions & 60 deletions src/flamingo/jobs/drivers/lm_harness.py
@@ -1,114 +1,83 @@
from pathlib import Path

import lm_eval
import ray
import wandb
from lm_eval.models.huggingface import HFLM
from peft import PeftConfig

from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary
from flamingo.jobs import LMHarnessJobConfig, ModelNameOrCheckpointPath

from flamingo.integrations.wandb import WandbArtifactLoader
from flamingo.integrations.wandb.utils import resolve_artifact_path
from flamingo.jobs import LMHarnessJobConfig

def resolve_model_or_path(config: LMHarnessJobConfig) -> str:
mn_or_path = None
match config.model_name_or_path:
case None:
print("Attempting to resolve checkpoint path from existing W&B run...")
run_summary = get_wandb_summary(config.wandb_env)
cp = Path(run_summary["ray/checkpoint_path"])
print(f"Using checkpoint path from wandb run: {cp}")
if not cp.exists():
raise (FileNotFoundError(f"{mn_or_path} cannot be found."))
mn_or_path = str(cp)
case ModelNameOrCheckpointPath(checkpoint=None) as x:
print("No checkpoint; will attempt to load model from HuggingFace")
mn_or_path = x.name
case ModelNameOrCheckpointPath(checkpoint=ckpt):
print(f"Checkpoint found; will attempt to load model from {ckpt}")
mn_or_path = ckpt
case _:
raise (
ValueError(
"Something is wrong with the passed "
f"model_name_or_path: {config.model_name_or_path}"
)
)
return mn_or_path

def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> HFLM:
model_path = resolve_artifact_path(config.model.path, loader)

def load_harness_model(config: LMHarnessJobConfig, model_to_load: str) -> HFLM:
# We don't know if the checkpoint is adapter weights or merged model weights
# Try to load as an adapter and fall back to the checkpoint containing the full model
try:
adapter_config = PeftConfig.from_pretrained(model_to_load)
adapter_config = PeftConfig.from_pretrained(model_path)
pretrained = adapter_config.base_model_name_or_path
peft = model_to_load
peft = model_path
except ValueError as e:
print(
f"Unable to load model as adapter: {e}. "
"This is expected if the checkpoint does not contain adapter weights."
)
pretrained = model_to_load
pretrained = model_path
peft = None

# Return lm-harness model wrapper class
quantization_kwargs = config.quantization_config.dict() if config.quantization_config else {}
quantization_kwargs = config.quantization.dict() if config.quantization else {}
return HFLM(
pretrained=pretrained,
tokenizer=pretrained,
peft=peft,
device="cuda" if config.num_gpus > 0 else None,
trust_remote_code=config.trust_remote_code,
dtype=config.torch_dtype if config.torch_dtype else "auto",
device="cuda" if config.ray.use_gpu else None,
trust_remote_code=config.model.trust_remote_code,
dtype=config.model.torch_dtype if config.model.torch_dtype else "auto",
**quantization_kwargs,
)


@ray.remote
@ray.remote(num_cpus=)

Check failure on line 42 (GitHub Actions / pytest_ruff): Ruff E999 at src/flamingo/jobs/drivers/lm_harness.py:42:22: SyntaxError: Unexpected token ')'
def evaluation_task(config: LMHarnessJobConfig, model_to_load: str) -> None:
print("Initializing lm-harness tasks...")
lm_eval.tasks.initialize_tasks()

print("Running lm-harness evaluation inside remote function...")
llm = load_harness_model(config, model_to_load)
raw_results = lm_eval.simple_evaluate(
wandb_run = None
if config.tracking is not None:
wandb_run = wandb.init(**config.tracking.wandb_init_args(), resume="never")
artifact_loader = WandbArtifactLoader(wandb_run)

llm = load_harness_model(config, artifact_loader)
eval_results = lm_eval.simple_evaluate(
model=llm,
tasks=config.tasks,
num_fewshot=config.num_fewshot,
batch_size=config.batch_size,
limit=config.limit,
log_samples=False,
)
print("Finished lm-harness evaluation inside remote function")
eval_results = eval_results["results"]
print(f"Obtained evaluation results: {eval_results}")

formatted_results = {}
for task_name, metrics in raw_results["results"].items():
task_metrics = {
f"{task_name}/{metric.replace(',', '_')}": value for metric, value in metrics.items()
}
formatted_results.update(task_metrics)
print(f"Obtained evaluation results: {formatted_results}")

if config.wandb_env:
print("Logging results to W&B...")
update_wandb_summary(config.wandb_env, formatted_results)
if config.tracking is not None:
print("Generating table artifact of evaluation results...")
pass


def run_lm_harness(config: LMHarnessJobConfig):
print(f"Received job configuration: {config}")

# Resolve path and ensure exists
model_to_load = resolve_model_or_path(config)

# Using .options() to dynamically specify resource requirements
eval_func = evaluation_task.options(num_cpus=config.num_cpus, num_gpus=config.num_gpus)
eval_future = eval_func.remote(config, model_to_load)
eval_func = evaluation_task.options(num_cpus=config.ray.num_cpus, num_gpus=config.ray.num_gpus)
eval_future = eval_func.remote(config)

timeout_seconds = config.timeout.seconds if config.timeout else None
try:
print("Waiting on evaluation task...")
ray.get(eval_future, timeout=timeout_seconds)
print("Evaluation successfully completed")
print("Evaluation successfully completed!")
except TimeoutError:
print(
f"Evaluation task timed out after {timeout_seconds} sec. "
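The evaluation driver above still stubs out result logging (the `pass` after "Generating table artifact of evaluation results..."). One way that stub could later be filled in, assuming the flat task -> metric -> value mapping returned by `lm_eval.simple_evaluate(...)["results"]` and a `wandb.Table` for display; this is a sketch of a possible approach, not the repository's implementation:

```python
import wandb

def log_results_table(run: "wandb.sdk.wandb_run.Run", eval_results: dict) -> None:
    """Flatten lm-harness results into a W&B table with one row per (task, metric) pair."""
    table = wandb.Table(columns=["task", "metric", "value"])
    for task_name, metrics in eval_results.items():
        for metric_name, value in metrics.items():
            table.add_data(task_name, metric_name, value)
    run.log({"lm_harness/results": table})

# Possible usage inside evaluation_task, after eval_results = eval_results["results"]:
#   if wandb_run is not None:
#       log_results_table(wandb_run, eval_results)
```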
4 changes: 2 additions & 2 deletions src/flamingo/jobs/lm_harness_config.py
@@ -10,8 +10,8 @@
class LMHarnessRayConfig(BaseFlamingoConfig):
"""Misc settings for Ray compute in the LM harness job."""

use_gpu: bool = True
num_workers: int = 1
num_cpus: int | float = 1
num_gpus: int | float = 1
timeout: datetime.timedelta | None = None


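For reference, the `num_cpus`/`num_gpus` fields above replace the earlier `use_gpu`/`num_workers` pair and line up with the updated `ray:` block in `examples/configs/lm_harness_config.yaml`. A standalone sketch of how those values populate the config, assuming `BaseFlamingoConfig` behaves like a pydantic model (the base class is replaced by a plain `BaseModel` here purely for illustration):

```python
import datetime
from pydantic import BaseModel

class LMHarnessRayConfig(BaseModel):
    """Illustrative stand-in for the real class, which extends BaseFlamingoConfig."""
    num_cpus: int | float = 1
    num_gpus: int | float = 1
    timeout: datetime.timedelta | None = None

# Values from the example YAML: num_cpus: 1, num_gpus: 4, timeout: 3600
ray_config = LMHarnessRayConfig(num_cpus=1, num_gpus=4, timeout=3600)
print(ray_config.timeout)  # pydantic coerces the integer 3600 into timedelta(seconds=3600)
```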
2 changes: 1 addition & 1 deletion tests/integrations/wandb/test_artifact_config.py
@@ -18,4 +18,4 @@ def test_serde_round_trip(wandb_artifact_config):


def test_wandb_path(wandb_artifact_config):
assert wandb_artifact_config.get_wandb_path() == "twitter/cortex/artifact-name:latest"
assert wandb_artifact_config.wandb_path() == "twitter/cortex/artifact-name:latest"
5 changes: 3 additions & 2 deletions tests/jobs/test_lm_harness_config.py
@@ -17,8 +17,9 @@ def lm_harness_evaluator_config():
@pytest.fixture
def lm_harness_ray_config():
return LMHarnessRayConfig(
num_workers=4,
use_gpu=True,
num_cpus=2,
num_gpus=4,
timeout=3600,
)


