diff --git a/src/flamingo/integrations/huggingface/trainer_config.py b/src/flamingo/integrations/huggingface/trainer_config.py
index 85a932e6..339be12f 100644
--- a/src/flamingo/integrations/huggingface/trainer_config.py
+++ b/src/flamingo/integrations/huggingface/trainer_config.py
@@ -25,7 +25,7 @@ class TrainerConfig(BaseFlamingoConfig):
     save_strategy: str | None = None
     save_steps: int | None = None
 
-    def get_training_args(self) -> dict[str, Any]:
+    def training_args(self) -> dict[str, Any]:
         """Return the arguments to the HuggingFace `TrainingArguments` class."""
         excluded_keys = ["max_seq_length"]
         return self.dict(exclude=excluded_keys)
diff --git a/src/flamingo/integrations/wandb/artifact_config.py b/src/flamingo/integrations/wandb/artifact_config.py
index 4b542f3d..c4ebaf10 100644
--- a/src/flamingo/integrations/wandb/artifact_config.py
+++ b/src/flamingo/integrations/wandb/artifact_config.py
@@ -11,7 +11,7 @@ class WandbArtifactConfig(BaseFlamingoConfig):
     project: str | None = None
     entity: str | None = None
 
-    def get_wandb_path(self) -> str:
+    def wandb_path(self) -> str:
         """String identifier for the asset on the W&B platform."""
         path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
         path = f"{path}:{self.version}"
diff --git a/src/flamingo/integrations/wandb/run_config.py b/src/flamingo/integrations/wandb/run_config.py
index d2c97639..dbf55b39 100644
--- a/src/flamingo/integrations/wandb/run_config.py
+++ b/src/flamingo/integrations/wandb/run_config.py
@@ -59,12 +59,12 @@ def from_run(cls, run: Run) -> "WandbRunConfig":
             run_id=run.id,
         )
 
-    def get_wandb_path(self) -> str:
+    def wandb_path(self) -> str:
         """String identifier for the asset on the W&B platform."""
         path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None)
         return path
 
-    def get_wandb_init_args(self) -> dict[str, str]:
+    def wandb_init_args(self) -> dict[str, str]:
         """Return the kwargs passed to `wandb.init` with proper naming."""
         return dict(
             id=self.run_id,
@@ -74,7 +74,7 @@
             group=self.run_group,
         )
 
-    def get_env_vars(self) -> dict[str, str]:
+    def env_vars(self) -> dict[str, str]:
        env_vars = {
            "WANDB_RUN_ID": self.run_id,
            "WANDB_NAME": self.name,
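
Reviewer note on the renames above: the dropped `get_` prefixes leave the behavior unchanged. For reference, a minimal sketch of what the two `wandb_path()` implementations produce (values are made up; it assumes `version` is a settable pydantic field, as the `f"{path}:{self.version}"` line suggests):

```python
# Illustration only -- hypothetical values, fields taken from the configs above.
run = WandbRunConfig(run_id="run-id", project="cortex", entity="twitter")
run.wandb_path()  # -> "twitter/cortex/run-id"

# Artifact paths additionally carry a ":<version>" suffix.
art = WandbArtifactConfig(name="my-model", project="cortex", entity="twitter", version="v0")
art.wandb_path()  # -> "twitter/cortex/my-model:v0"
```
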
diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py
index 5e82ab1b..eb12d27e 100644
--- a/src/flamingo/integrations/wandb/utils.py
+++ b/src/flamingo/integrations/wandb/utils.py
@@ -4,13 +4,13 @@
 import wandb
 from wandb.apis.public import Run as ApiRun
 
-from flamingo.integrations.wandb import WandbRunConfig
+from flamingo.integrations.wandb import WandbArtifactConfig, WandbArtifactLoader, WandbRunConfig
 
 
 def get_wandb_api_run(config: WandbRunConfig) -> ApiRun:
     """Retrieve a run from the W&B API."""
     api = wandb.Api()
-    return api.run(config.get_wandb_path())
+    return api.run(config.wandb_path())
 
 
 def get_wandb_summary(config: WandbRunConfig) -> dict[str, Any]:
@@ -26,10 +26,37 @@ def update_wandb_summary(config: WandbRunConfig, metrics: dict[str, Any]) -> Non
     run.update()
 
 
-def get_artifact_filesystem_path(artifact: wandb.Artifact) -> str:
+def get_artifact_directory(artifact: wandb.Artifact) -> str:
+    dir_paths = set()
     for entry in artifact.manifest.entries.values():
         if entry.ref.startswith("file://"):
-            # TODO: What if there are entries with different base paths in the artifact manifest?
             entry_path = Path(entry.ref.replace("file://", ""))
-            return str(entry_path.parent.absolute())
-    raise ValueError("Artifact does not contain a filesystem reference.")
+            dir_paths.add(str(entry_path.parent.absolute()))
+    match len(dir_paths):
+        case 0:
+            raise ValueError(
+                f"Artifact {artifact.name} does not contain any filesystem references."
+            )
+        case 1:
+            return list(dir_paths)[0]
+        case _:
+            dir_string = ",".join(dir_paths)
+            raise ValueError(
+                f"Artifact {artifact.name} references multiple directories: {dir_string}. "
+                "Unable to determine which directory to load."
+            )
+
+
+def resolve_artifact_path(path: str | WandbArtifactConfig, loader: WandbArtifactLoader) -> str:
+    """Resolve the actual filesystem path for a path/artifact asset.
+
+    The artifact loader internally handles linking the artifact-to-load to an in-progress run.
+    """
+    match path:
+        case str():
+            return path
+        case WandbArtifactConfig() as artifact_config:
+            artifact = loader.load_artifact(artifact_config)
+            return get_artifact_directory(artifact)
+        case _:
+            raise ValueError(f"Invalid artifact path: {path}")
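
With this move, `resolve_artifact_path` becomes the single entry point for both plain filesystem paths and W&B artifact references. A minimal sketch of the two branches (illustrative only; the loader instance is assumed to come from `WandbArtifactLoader(wandb_run)` as in the finetuning driver below, and the paths are hypothetical):

```python
loader = WandbArtifactLoader(wandb_run)  # wandb_run may be None outside rank zero

# Branch 1: plain strings pass through unchanged.
assert resolve_artifact_path("/data/my-dataset", loader) == "/data/my-dataset"

# Branch 2: artifact configs are loaded (and linked to the active run),
# then resolved to the single directory referenced by the artifact manifest.
config = WandbArtifactConfig(name="my-dataset", project="cortex", entity="twitter")
local_dir = resolve_artifact_path(config, loader)
```
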
- """ - match path: - case str(): - return path - case WandbArtifactConfig() as artifact_config: - artifact = loader.load_artifact(artifact_config) - return get_artifact_filesystem_path(artifact) - case _: - raise ValueError(f"Invalid artifact path: {path}") - - def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments: """Get TrainingArguments appropriate for the worker rank and job config.""" return TrainingArguments( @@ -47,7 +32,7 @@ def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments: push_to_hub=False, disable_tqdm=True, logging_dir=None, - **config.trainer.get_training_args(), + **config.trainer.training_args(), ) @@ -56,7 +41,7 @@ def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> D # We need to specify a fixed seed to load the datasets on each worker # Under the hood, HuggingFace uses `accelerate` to create a data loader shard for each worker # If the datasets are not seeded here, the ordering will be inconsistent between workers - # TODO: Get rid of this logic once data loading occurs once outside of the workers + # TODO: Get rid of this logic once data loading is done one time outside of the workers split_seed = config.dataset.seed or 0 return load_and_split_dataset( path=dataset_path, @@ -107,7 +92,7 @@ def train_func(config_data: dict): # Manually initialize run in order to set the run ID and link artifacts wandb_run = None if is_tracking_enabled(config): - wandb_run = wandb.init(**config.tracking.get_wandb_init_args(), resume="never") + wandb_run = wandb.init(**config.tracking.wandb_init_args(), resume="never") # Load the input artifacts, potentially linking them to the active W&B run artifact_loader = WandbArtifactLoader(wandb_run) @@ -157,7 +142,7 @@ def run_finetuning(config: FinetuningJobConfig): if config.tracking and result.checkpoint: # Must resume from the just-completed training run - with wandb.init(config.tracking.get_wandb_init_args(), resume="must") as run: + with wandb.init(**config.tracking.wandb_init_args(), resume="must") as run: artifact_type = ArtifactType.MODEL.value artifact_name = f"{config.tracking.name or config.tracking.run_id}-{artifact_type}" artifact = wandb.Artifact(artifact_name, type=artifact_type) diff --git a/src/flamingo/jobs/drivers/lm_harness.py b/src/flamingo/jobs/drivers/lm_harness.py index 03b3066e..4531dda1 100644 --- a/src/flamingo/jobs/drivers/lm_harness.py +++ b/src/flamingo/jobs/drivers/lm_harness.py @@ -5,8 +5,8 @@ from lm_eval.models.huggingface import HFLM from peft import PeftConfig -from flamingo.configs import LMHarnessJobConfig, ModelNameOrCheckpointPath from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary +from flamingo.jobs import LMHarnessJobConfig, ModelNameOrCheckpointPath def resolve_model_or_path(config: LMHarnessJobConfig) -> str: diff --git a/tests/integrations/wandb/test_run_config.py b/tests/integrations/wandb/test_run_config.py index 7e0bec62..6bf668b3 100644 --- a/tests/integrations/wandb/test_run_config.py +++ b/tests/integrations/wandb/test_run_config.py @@ -19,7 +19,7 @@ def test_serde_round_trip(wandb_run_config): def test_wandb_path(wandb_run_config): - assert wandb_run_config.get_wandb_path() == "twitter/cortex/run-id" + assert wandb_run_config.wandb_path() == "twitter/cortex/run-id" def test_ensure_run_id(): @@ -28,7 +28,7 @@ def test_ensure_run_id(): def test_env_vars(wandb_run_config): - env_vars = wandb_run_config.get_env_vars() + env_vars = wandb_run_config.env_vars() expected = ["WANDB_NAME", 
"WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] for key in expected: assert key in env_vars @@ -43,4 +43,4 @@ def test_disallowed_kwargs(): def test_missing_key_warning(mock_environment_without_keys): with pytest.warns(UserWarning): config = WandbRunConfig(name="I am missing an API key", project="I should warn the user") - assert "WANDB_API_KEY" not in config.get_env_vars() + assert "WANDB_API_KEY" not in config.env_vars()