greatly simplify artifact loading by using builtin download function
Sean Friedowitz committed Jan 19, 2024
1 parent 82c6b16 commit 3f8d359
Showing 5 changed files with 71 additions and 106 deletions.
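In short, the commit replaces the hand-rolled `WandbArtifactLoader` class, which parsed `file://` references out of an artifact's manifest, with W&B's built-in `Artifact.download()`. A minimal sketch of the new approach, assuming the `wandb` package is installed and authenticated (the artifact path below is illustrative, not from this repository):

```python
import wandb

# Fetch an artifact outside of a run and let wandb materialize its contents;
# download() returns the local directory containing the artifact files.
api = wandb.Api()
artifact = api.artifact("my-entity/my-project/my-dataset:v0")  # illustrative path
local_dir = artifact.download()
print(local_dir)
```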
3 changes: 1 addition & 2 deletions src/flamingo/integrations/wandb/__init__.py
@@ -1,10 +1,9 @@
from .artifact_config import WandbArtifactConfig, WandbArtifactLoader
from .artifact_config import WandbArtifactConfig
from .artifact_type import ArtifactType
from .run_config import WandbRunConfig

__all__ = [
"ArtifactType",
"WandbArtifactConfig",
"WandbArtifactLoader",
"WandbRunConfig",
]
69 changes: 0 additions & 69 deletions src/flamingo/integrations/wandb/artifact_config.py
@@ -1,7 +1,3 @@
from pathlib import Path

import wandb

from flamingo.types import BaseFlamingoConfig


@@ -18,68 +14,3 @@ def wandb_path(self) -> str:
path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
path = f"{path}:{self.version}"
return path


class WandbArtifactLoader:
"""Helper class for loading W&B artifacts and linking them to runs."""

def __init__(self, run: wandb.run):
self._run = run

def load_artifact(self, config: WandbArtifactConfig) -> wandb.Artifact:
"""Load an artifact from the provided config.
        If a W&B run is available, the artifact is loaded via the run as an input.
If not, the artifact is pulled from the W&B API outside of the run.
"""
if self._run is not None:
# Retrieves the artifact and links it as an input to the run
return self._run.use_artifact(config.wandb_path())
else:
# Retrieves the artifact outside of the run
api = wandb.Api()
return api.artifact(config.wandb_path())

def resolve_artifact_path(self, path: str | WandbArtifactConfig) -> str:
"""Resolve the actual filesystem path from an artifact/path reference.
If the provided path is just a string, return the value directly.
If an artifact, load it from W&B (and link it to an in-progress run)
and resolve the filesystem path from the artifact manifest.
"""
match path:
case str():
return path
case WandbArtifactConfig() as artifact_config:
artifact = self.load_artifact(artifact_config)
artifact_path = self._extract_base_path(artifact)
return str(artifact_path)
case _:
raise ValueError(f"Invalid artifact path: {path}")

def _extract_base_path(self, artifact: wandb.Artifact) -> Path:
"""Extract the base filesystem path from entries in an artifact.
        An error is raised if the artifact contains either zero filesystem references
        or references to more than one distinct directory.
"""
entry_paths = [
e.ref.replace("file://", "")
for e in artifact.manifest.entries.values()
if e.ref.startswith("file://")
]
dir_paths = {Path(e).parent.absolute() for e in entry_paths}
match len(dir_paths):
case 0:
raise ValueError(
f"Artifact {artifact.name} does not contain any filesystem references."
)
case 1:
return list(dir_paths)[0]
case _:
# TODO: Can this be resolved somehow else???
                dir_string = ",".join(str(p) for p in dir_paths)
raise ValueError(
f"Artifact {artifact.name} references multiple directories: {dir_string}. "
"Unable to determine which directory to load."
)
44 changes: 38 additions & 6 deletions src/flamingo/integrations/wandb/utils.py
@@ -4,7 +4,7 @@
import wandb
from wandb.apis.public import Run as ApiRun

from flamingo.integrations.wandb import ArtifactType, WandbRunConfig
from flamingo.integrations.wandb import ArtifactType, WandbArtifactConfig, WandbRunConfig


@contextlib.contextmanager
@@ -21,11 +21,6 @@ def wandb_init_from_config(config: WandbRunConfig, *, resume: str | None = None)
yield run


def default_artifact_name(name: str, artifact_type: ArtifactType) -> str:
"""A default name for an artifact based on the run name and type."""
return f"{name}-{artifact_type}"


def get_wandb_api_run(config: WandbRunConfig) -> ApiRun:
"""Retrieve a run from the W&B API."""
api = wandb.Api()
@@ -43,3 +38,40 @@ def update_wandb_summary(config: WandbRunConfig, metrics: dict[str, Any]) -> None:
run = get_wandb_api_run(config)
run.summary.update(metrics)
run.update()


def get_wandb_artifact(config: WandbArtifactConfig) -> wandb.Artifact:
"""Load an artifact from the provided config.
If a W&B run is active, the artifact is loaded via the run as an input.
If not, the artifact is pulled from the W&B API outside of the run.
"""
if wandb.run is not None:
# Retrieves the artifact and links it as an input to the run
return wandb.run.use_artifact(config.wandb_path())
else:
# Retrieves the artifact outside of the run
api = wandb.Api()
return api.artifact(config.wandb_path())


def resolve_artifact_path(path: str | WandbArtifactConfig) -> str:
"""Resolve the actual filesystem path from an artifact/path reference.
If the provided path is just a string, return the value directly.
If an artifact, download it from W&B (and link it to an in-progress run)
to retrieve the actual data directory.
"""
match path:
case str():
return path
case WandbArtifactConfig() as config:
artifact = get_wandb_artifact(config)
return artifact.download()
case _:
raise ValueError(f"Invalid artifact path: {path}")


def default_artifact_name(name: str, artifact_type: ArtifactType) -> str:
"""A default name for an artifact based on the run name and type."""
return f"{name}-{artifact_type}"
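For reference, a quick usage sketch of the new helpers added above; the `WandbArtifactConfig` keyword arguments are assumed from the `wandb_path` property shown earlier (entity/project/name/version) rather than taken from the class definition:

```python
from flamingo.integrations.wandb import WandbArtifactConfig
from flamingo.integrations.wandb.utils import resolve_artifact_path

# A plain filesystem path passes through unchanged.
local_path = resolve_artifact_path("/data/datasets/my-dataset")

# An artifact config is downloaded via W&B and the local download directory is returned.
config = WandbArtifactConfig(name="my-dataset", project="my-project", version="v0")
dataset_dir = resolve_artifact_path(config)
```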
39 changes: 20 additions & 19 deletions src/flamingo/jobs/finetuning/entrypoint.py
@@ -12,8 +12,12 @@
from trl import SFTTrainer

from flamingo.integrations.huggingface.utils import load_and_split_dataset
from flamingo.integrations.wandb import ArtifactType, WandbArtifactLoader
from flamingo.integrations.wandb.utils import default_artifact_name, wandb_init_from_config
from flamingo.integrations.wandb import ArtifactType
from flamingo.integrations.wandb.utils import (
default_artifact_name,
resolve_artifact_path,
wandb_init_from_config,
)
from flamingo.jobs.finetuning import FinetuningJobConfig


@@ -44,8 +48,8 @@ def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
)


def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> DatasetDict:
dataset_path = loader.resolve_artifact_path(config.dataset.path)
def load_datasets(config: FinetuningJobConfig) -> DatasetDict:
dataset_path = resolve_artifact_path(config.dataset.path)
# We need to specify a fixed seed to load the datasets on each worker
# Under the hood, HuggingFace uses `accelerate` to create a data loader shard for each worker
# If the datasets are not seeded here, the ordering will be inconsistent between workers
@@ -59,7 +63,7 @@ def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> DatasetDict:
)


def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreTrainedModel:
def load_model(config: FinetuningJobConfig) -> PreTrainedModel:
device_map, bnb_config = None, None
if config.quantization is not None:
bnb_config = config.quantization.as_huggingface()
@@ -70,7 +74,7 @@ def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreTrainedModel:
device_map = {"": current_device}
print(f"Setting model device_map = {device_map} to enable quantization")

model_path = loader.resolve_artifact_path(config.model.path)
model_path = resolve_artifact_path(config.model.path)
return AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
trust_remote_code=config.model.trust_remote_code,
@@ -80,8 +84,8 @@ def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreTrainedModel:
)


def load_tokenizer(config: FinetuningJobConfig, loader: WandbArtifactLoader):
tokenizer_path = loader.resolve_artifact_path(config.tokenizer.path)
def load_tokenizer(config: FinetuningJobConfig):
tokenizer_path = resolve_artifact_path(config.tokenizer.path)
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=tokenizer_path,
trust_remote_code=config.tokenizer.trust_remote_code,
@@ -93,14 +97,13 @@ def load_tokenizer(config: FinetuningJobConfig, loader: WandbArtifactLoader):
return tokenizer


def train_func_with_loader(config: FinetuningJobConfig, loader: WandbArtifactLoader):
training_args = get_training_arguments(config)

def load_and_train(config: FinetuningJobConfig):
# Load the input artifacts, potentially linking them to the active W&B run
datasets = load_datasets(config, loader)
model = load_model(config, loader)
tokenizer = load_tokenizer(config, loader)
datasets = load_datasets(config)
model = load_model(config)
tokenizer = load_tokenizer(config)

training_args = get_training_arguments(config)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
@@ -119,12 +122,10 @@ def train_func_with_loader(config: FinetuningJobConfig, loader: WandbArtifactLoader):
def train_func(config_data: dict):
config = FinetuningJobConfig(**config_data)
if is_tracking_enabled(config):
with wandb_init_from_config(config, resume="never") as run:
loader = WandbArtifactLoader(run=run)
train_func_with_loader(config, loader)
with wandb_init_from_config(config, resume="never"):
load_and_train(config)
else:
loader = WandbArtifactLoader(run=None)
train_func_with_loader(config, loader)
load_and_train(config)


def run_finetuning(config: FinetuningJobConfig):
22 changes: 12 additions & 10 deletions src/flamingo/jobs/lm_harness/entrypoint.py
@@ -6,8 +6,12 @@
from lm_eval.models.huggingface import HFLM
from peft import PeftConfig

from flamingo.integrations.wandb import ArtifactType, WandbArtifactLoader
from flamingo.integrations.wandb.utils import default_artifact_name, wandb_init_from_config
from flamingo.integrations.wandb import ArtifactType
from flamingo.integrations.wandb.utils import (
default_artifact_name,
resolve_artifact_path,
wandb_init_from_config,
)
from flamingo.jobs.lm_harness import LMHarnessJobConfig


@@ -23,8 +27,8 @@ def build_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]])
return artifact


def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> HFLM:
model_path = loader.resolve_artifact_path(config.model.path)
def load_harness_model(config: LMHarnessJobConfig) -> HFLM:
model_path = resolve_artifact_path(config.model.path)

# We don't know if the checkpoint is adapter weights or merged model weights
# Try to load as an adapter and fall back to the checkpoint containing the full model
@@ -53,11 +57,11 @@ def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> HFLM:
)


def evaluate_with_loader(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> dict[str, Any]:
def load_and_evaluate(config: LMHarnessJobConfig) -> dict[str, Any]:
print("Initializing lm-harness tasks...")
lm_eval.tasks.initialize_tasks()

llm = load_harness_model(config, loader)
llm = load_harness_model(config)
eval_results = lm_eval.simple_evaluate(
model=llm,
tasks=config.evaluator.tasks,
@@ -75,13 +79,11 @@ def evaluation_task(config: LMHarnessJobConfig) -> None:
def evaluation_task(config: LMHarnessJobConfig) -> None:
if config.tracking is not None:
with wandb_init_from_config(config.tracking, resume="never") as run:
artifact_loader = WandbArtifactLoader(run=run)
eval_results = evaluate_with_loader(config, artifact_loader)
eval_results = load_and_evaluate(config)
artifact = build_evaluation_artifact(run.name, eval_results)
run.log_artifact(artifact)
else:
artifact_loader = WandbArtifactLoader(run=None)
evaluate_with_loader(config, artifact_loader)
load_and_evaluate(config)


def run_lm_harness(config: LMHarnessJobConfig):
