From 3f8d3596141f14a99333f6d3b0b7fefb805dece2 Mon Sep 17 00:00:00 2001
From: Sean Friedowitz
Date: Thu, 18 Jan 2024 16:09:19 -0800
Subject: [PATCH] greatly simplify artifact loading by using builtin download function

---
 src/flamingo/integrations/wandb/__init__.py  |  3 +-
 .../integrations/wandb/artifact_config.py    | 69 -------------------
 src/flamingo/integrations/wandb/utils.py     | 44 ++++++++++--
 src/flamingo/jobs/finetuning/entrypoint.py   | 39 ++++++-----
 src/flamingo/jobs/lm_harness/entrypoint.py   | 22 +++---
 5 files changed, 71 insertions(+), 106 deletions(-)

diff --git a/src/flamingo/integrations/wandb/__init__.py b/src/flamingo/integrations/wandb/__init__.py
index 86144c6a..6b37141b 100644
--- a/src/flamingo/integrations/wandb/__init__.py
+++ b/src/flamingo/integrations/wandb/__init__.py
@@ -1,10 +1,9 @@
-from .artifact_config import WandbArtifactConfig, WandbArtifactLoader
+from .artifact_config import WandbArtifactConfig
 from .artifact_type import ArtifactType
 from .run_config import WandbRunConfig
 
 __all__ = [
     "ArtifactType",
     "WandbArtifactConfig",
-    "WandbArtifactLoader",
     "WandbRunConfig",
 ]
diff --git a/src/flamingo/integrations/wandb/artifact_config.py b/src/flamingo/integrations/wandb/artifact_config.py
index bbc835c1..797cacde 100644
--- a/src/flamingo/integrations/wandb/artifact_config.py
+++ b/src/flamingo/integrations/wandb/artifact_config.py
@@ -1,7 +1,3 @@
-from pathlib import Path
-
-import wandb
-
 from flamingo.types import BaseFlamingoConfig
 
 
@@ -18,68 +14,3 @@ def wandb_path(self) -> str:
         path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
         path = f"{path}:{self.version}"
         return path
-
-
-class WandbArtifactLoader:
-    """Helper class for loading W&B artifacts and linking them to runs."""
-
-    def __init__(self, run: wandb.run):
-        self._run = run
-
-    def load_artifact(self, config: WandbArtifactConfig) -> wandb.Artifact:
-        """Load an artifact from the provided config.
-
-        If a a W&B run is available, the artifact is loaded via the run as an input.
-        If not, the artifact is pulled from the W&B API outside of the run.
-        """
-        if self._run is not None:
-            # Retrieves the artifact and links it as an input to the run
-            return self._run.use_artifact(config.wandb_path())
-        else:
-            # Retrieves the artifact outside of the run
-            api = wandb.Api()
-            return api.artifact(config.wandb_path())
-
-    def resolve_artifact_path(self, path: str | WandbArtifactConfig) -> str:
-        """Resolve the actual filesystem path from an artifact/path reference.
-
-        If the provided path is just a string, return the value directly.
-        If an artifact, load it from W&B (and link it to an in-progress run)
-        and resolve the filesystem path from the artifact manifest.
-        """
-        match path:
-            case str():
-                return path
-            case WandbArtifactConfig() as artifact_config:
-                artifact = self.load_artifact(artifact_config)
-                artifact_path = self._extract_base_path(artifact)
-                return str(artifact_path)
-            case _:
-                raise ValueError(f"Invalid artifact path: {path}")
-
-    def _extract_base_path(self, artifact: wandb.Artifact) -> Path:
-        """Extract the base filesystem path from entries in an artifact.
-
-        An error is raised if the artifact contains ether zero or more than one references
-        to distinct filesystem directories.
-        """
-        entry_paths = [
-            e.ref.replace("file://", "")
-            for e in artifact.manifest.entries.values()
-            if e.ref.startswith("file://")
-        ]
-        dir_paths = {Path(e).parent.absolute() for e in entry_paths}
-        match len(dir_paths):
-            case 0:
-                raise ValueError(
-                    f"Artifact {artifact.name} does not contain any filesystem references."
-                )
-            case 1:
-                return list(dir_paths)[0]
-            case _:
-                # TODO: Can this be resolved somehow else???
-                dir_string = ",".join(dir_paths)
-                raise ValueError(
-                    f"Artifact {artifact.name} references multiple directories: {dir_string}. "
-                    "Unable to determine which directory to load."
-                )
diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py
index 94452cd9..cda0c1b1 100644
--- a/src/flamingo/integrations/wandb/utils.py
+++ b/src/flamingo/integrations/wandb/utils.py
@@ -4,7 +4,7 @@
 import wandb
 from wandb.apis.public import Run as ApiRun
 
-from flamingo.integrations.wandb import ArtifactType, WandbRunConfig
+from flamingo.integrations.wandb import ArtifactType, WandbArtifactConfig, WandbRunConfig
 
 
 @contextlib.contextmanager
@@ -21,11 +21,6 @@ def wandb_init_from_config(config: WandbRunConfig, *, resume: str | None = None)
         yield run
 
 
-def default_artifact_name(name: str, artifact_type: ArtifactType) -> str:
-    """A default name for an artifact based on the run name and type."""
-    return f"{name}-{artifact_type}"
-
-
 def get_wandb_api_run(config: WandbRunConfig) -> ApiRun:
     """Retrieve a run from the W&B API."""
     api = wandb.Api()
@@ -43,3 +38,40 @@ def update_wandb_summary(config: WandbRunConfig, metrics: dict[str, Any]) -> Non
     run = get_wandb_api_run(config)
     run.summary.update(metrics)
     run.update()
+
+
+def get_wandb_artifact(config: WandbArtifactConfig) -> wandb.Artifact:
+    """Load an artifact from the provided config.
+
+    If a W&B run is active, the artifact is loaded via the run as an input.
+    If not, the artifact is pulled from the W&B API outside of the run.
+    """
+    if wandb.run is not None:
+        # Retrieves the artifact and links it as an input to the run
+        return wandb.run.use_artifact(config.wandb_path())
+    else:
+        # Retrieves the artifact outside of the run
+        api = wandb.Api()
+        return api.artifact(config.wandb_path())
+
+
+def resolve_artifact_path(path: str | WandbArtifactConfig) -> str:
+    """Resolve the actual filesystem path from an artifact/path reference.
+
+    If the provided path is just a string, return the value directly.
+    If an artifact, download it from W&B (and link it to an in-progress run)
+    to retrieve the actual data directory.
+    """
+    match path:
+        case str():
+            return path
+        case WandbArtifactConfig() as config:
+            artifact = get_wandb_artifact(config)
+            return artifact.download()
+        case _:
+            raise ValueError(f"Invalid artifact path: {path}")
+
+
+def default_artifact_name(name: str, artifact_type: ArtifactType) -> str:
+    """A default name for an artifact based on the run name and type."""
+    return f"{name}-{artifact_type}"
diff --git a/src/flamingo/jobs/finetuning/entrypoint.py b/src/flamingo/jobs/finetuning/entrypoint.py
index e9597c35..e7329cfe 100644
--- a/src/flamingo/jobs/finetuning/entrypoint.py
+++ b/src/flamingo/jobs/finetuning/entrypoint.py
@@ -12,8 +12,12 @@
 from trl import SFTTrainer
 
 from flamingo.integrations.huggingface.utils import load_and_split_dataset
-from flamingo.integrations.wandb import ArtifactType, WandbArtifactLoader
-from flamingo.integrations.wandb.utils import default_artifact_name, wandb_init_from_config
+from flamingo.integrations.wandb import ArtifactType
+from flamingo.integrations.wandb.utils import (
+    default_artifact_name,
+    resolve_artifact_path,
+    wandb_init_from_config,
+)
 from flamingo.jobs.finetuning import FinetuningJobConfig
 
 
@@ -44,8 +48,8 @@ def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
     )
 
 
-def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> DatasetDict:
-    dataset_path = loader.resolve_artifact_path(config.dataset.path)
+def load_datasets(config: FinetuningJobConfig) -> DatasetDict:
+    dataset_path = resolve_artifact_path(config.dataset.path)
     # We need to specify a fixed seed to load the datasets on each worker
     # Under the hood, HuggingFace uses `accelerate` to create a data loader shard for each worker
     # If the datasets are not seeded here, the ordering will be inconsistent between workers
@@ -59,7 +63,7 @@ def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> D
     )
 
 
-def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreTrainedModel:
+def load_model(config: FinetuningJobConfig) -> PreTrainedModel:
     device_map, bnb_config = None, None
     if config.quantization is not None:
         bnb_config = config.quantization.as_huggingface()
@@ -70,7 +74,7 @@ def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreT
         device_map = {"": current_device}
         print(f"Setting model device_map = {device_map} to enable quantization")
 
-    model_path = loader.resolve_artifact_path(config.model.path)
+    model_path = resolve_artifact_path(config.model.path)
     return AutoModelForCausalLM.from_pretrained(
         pretrained_model_name_or_path=model_path,
         trust_remote_code=config.model.trust_remote_code,
@@ -80,8 +84,8 @@ def load_model(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> PreT
     )
 
 
-def load_tokenizer(config: FinetuningJobConfig, loader: WandbArtifactLoader):
-    tokenizer_path = loader.resolve_artifact_path(config.tokenizer.path)
+def load_tokenizer(config: FinetuningJobConfig):
+    tokenizer_path = resolve_artifact_path(config.tokenizer.path)
     tokenizer = AutoTokenizer.from_pretrained(
         pretrained_model_name_or_path=tokenizer_path,
         trust_remote_code=config.tokenizer.trust_remote_code,
@@ -93,14 +97,13 @@ def load_tokenizer(config: FinetuningJobConfig, loader: WandbArtifactLoader):
     return tokenizer
 
 
-def train_func_with_loader(config: FinetuningJobConfig, loader: WandbArtifactLoader):
-    training_args = get_training_arguments(config)
-
+def load_and_train(config: FinetuningJobConfig):
     # Load the input artifacts, potentially linking them to the active W&B run
-    datasets = load_datasets(config, loader)
-    model = load_model(config, loader)
-    tokenizer = load_tokenizer(config, loader)
+    datasets = load_datasets(config)
+    model = load_model(config)
+    tokenizer = load_tokenizer(config)
 
+    training_args = get_training_arguments(config)
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
@@ -119,12 +122,10 @@
 def train_func(config_data: dict):
     config = FinetuningJobConfig(**config_data)
     if is_tracking_enabled(config):
-        with wandb_init_from_config(config, resume="never") as run:
-            loader = WandbArtifactLoader(run=run)
-            train_func_with_loader(config, loader)
+        with wandb_init_from_config(config, resume="never"):
+            load_and_train(config)
     else:
-        loader = WandbArtifactLoader(run=None)
-        train_func_with_loader(config, loader)
+        load_and_train(config)
 
 
 def run_finetuning(config: FinetuningJobConfig):
diff --git a/src/flamingo/jobs/lm_harness/entrypoint.py b/src/flamingo/jobs/lm_harness/entrypoint.py
index 4269ce88..e6fde546 100644
--- a/src/flamingo/jobs/lm_harness/entrypoint.py
+++ b/src/flamingo/jobs/lm_harness/entrypoint.py
@@ -6,8 +6,12 @@
 from lm_eval.models.huggingface import HFLM
 from peft import PeftConfig
 
-from flamingo.integrations.wandb import ArtifactType, WandbArtifactLoader
-from flamingo.integrations.wandb.utils import default_artifact_name, wandb_init_from_config
+from flamingo.integrations.wandb import ArtifactType
+from flamingo.integrations.wandb.utils import (
+    default_artifact_name,
+    resolve_artifact_path,
+    wandb_init_from_config,
+)
 from flamingo.jobs.lm_harness import LMHarnessJobConfig
 
 
@@ -23,8 +27,8 @@ def build_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]])
     return artifact
 
 
-def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> HFLM:
-    model_path = loader.resolve_artifact_path(config.model.path)
+def load_harness_model(config: LMHarnessJobConfig) -> HFLM:
+    model_path = resolve_artifact_path(config.model.path)
 
     # We don't know if the checkpoint is adapter weights or merged model weights
     # Try to load as an adapter and fall back to the checkpoint containing the full model
@@ -53,11 +57,11 @@ def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader)
         )
 
 
-def evaluate_with_loader(config: LMHarnessJobConfig, loader: WandbArtifactLoader) -> dict[str, Any]:
+def load_and_evaluate(config: LMHarnessJobConfig) -> dict[str, Any]:
     print("Initializing lm-harness tasks...")
     lm_eval.tasks.initialize_tasks()
 
-    llm = load_harness_model(config, loader)
+    llm = load_harness_model(config)
     eval_results = lm_eval.simple_evaluate(
         model=llm,
         tasks=config.evaluator.tasks,
@@ -75,13 +79,11 @@
 def evaluation_task(config: LMHarnessJobConfig) -> None:
     if config.tracking is not None:
         with wandb_init_from_config(config.tracking, resume="never") as run:
-            artifact_loader = WandbArtifactLoader(run=run)
-            eval_results = evaluate_with_loader(config, artifact_loader)
+            eval_results = load_and_evaluate(config)
            artifact = build_evaluation_artifact(run.name, eval_results)
             run.log_artifact(artifact)
     else:
-        artifact_loader = WandbArtifactLoader(run=None)
-        evaluate_with_loader(config, artifact_loader)
+        load_and_evaluate(config)
 
 
 def run_lm_harness(config: LMHarnessJobConfig):
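
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the simplified helpers from flamingo.integrations.wandb.utils would be called after this change. Only resolve_artifact_path, WandbArtifactConfig, and the reliance on the builtin wandb.Artifact.download() come from the diff above; the concrete field values and the local dataset path are made-up placeholders, and the exact WandbArtifactConfig keyword arguments are assumed from its wandb_path() implementation.

    from flamingo.integrations.wandb import WandbArtifactConfig
    from flamingo.integrations.wandb.utils import resolve_artifact_path

    # Plain string paths pass through unchanged.
    local_path = resolve_artifact_path("/data/datasets/my_dataset")  # placeholder path

    # An artifact config is resolved by fetching the wandb.Artifact (via the active
    # run when one exists, otherwise through wandb.Api()) and calling its builtin
    # download() method, which returns the local directory holding the artifact files.
    artifact_config = WandbArtifactConfig(  # placeholder values
        name="example-dataset",
        project="example-project",
        version="latest",
    )
    dataset_dir = resolve_artifact_path(artifact_config)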