This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit 4da64d1: rename some args
Sean Friedowitz committed Jan 18, 2024
1 parent fcc33ea
Showing 7 changed files with 48 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/flamingo/integrations/huggingface/trainer_config.py
@@ -25,7 +25,7 @@ class TrainerConfig(BaseFlamingoConfig):
     save_strategy: str | None = None
     save_steps: int | None = None

-    def get_training_args(self) -> dict[str, Any]:
+    def training_args(self) -> dict[str, Any]:
        """Return the arguments to the HuggingFace `TrainingArguments` class."""
        excluded_keys = ["max_seq_length"]
        return self.dict(exclude=excluded_keys)
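For context, `training_args()` drops `max_seq_length` (which `TrainingArguments` does not accept) and forwards everything else. A minimal sketch of the intended consumption, with hypothetical field values standing in for a real `TrainerConfig`:

```python
from transformers import TrainingArguments

# Stand-in for `trainer_config.training_args()`: hypothetical TrainerConfig
# fields minus "max_seq_length", ready to splat into TrainingArguments.
hf_kwargs = {"save_strategy": "steps", "save_steps": 100}
args = TrainingArguments(output_dir="/tmp/example-run", **hf_kwargs)
```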
2 changes: 1 addition & 1 deletion src/flamingo/integrations/wandb/artifact_config.py
@@ -11,7 +11,7 @@ class WandbArtifactConfig(BaseFlamingoConfig):
     project: str | None = None
     entity: str | None = None

-    def get_wandb_path(self) -> str:
+    def wandb_path(self) -> str:
        """String identifier for the asset on the W&B platform."""
        path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
        path = f"{path}:{self.version}"
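A sketch of the identifier `wandb_path()` builds; the constructor values are hypothetical, and `version` is assumed to be a field on the truncated class body since the method reads `self.version`:

```python
from flamingo.integrations.wandb import WandbArtifactConfig

config = WandbArtifactConfig(name="my-model", project="cortex", entity="twitter", version="latest")
# Non-None parts are joined with "/" and the version is appended after ":".
assert config.wandb_path() == "twitter/cortex/my-model:latest"
```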
6 changes: 3 additions & 3 deletions src/flamingo/integrations/wandb/run_config.py
@@ -59,12 +59,12 @@ def from_run(cls, run: Run) -> "WandbRunConfig":
             run_id=run.id,
         )

-    def get_wandb_path(self) -> str:
+    def wandb_path(self) -> str:
        """String identifier for the asset on the W&B platform."""
        path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None)
        return path

-    def get_wandb_init_args(self) -> dict[str, str]:
+    def wandb_init_args(self) -> dict[str, str]:
        """Return the kwargs passed to `wandb.init` with proper naming."""
        return dict(
            id=self.run_id,
@@ -74,7 +74,7 @@ def get_wandb_init_args(self) -> dict[str, str]:
             group=self.run_group,
         )

-    def get_env_vars(self) -> dict[str, str]:
+    def env_vars(self) -> dict[str, str]:
        env_vars = {
            "WANDB_RUN_ID": self.run_id,
            "WANDB_NAME": self.name,
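A hedged sketch of how the three renamed accessors fit together; the constructor values are hypothetical, and the `wandb.init`/`os.environ` pattern mirrors the drivers below:

```python
import os

import wandb

from flamingo.integrations.wandb import WandbRunConfig

run_config = WandbRunConfig(name="finetune-demo", project="cortex", entity="twitter")
os.environ.update(run_config.env_vars())          # WANDB_RUN_ID, WANDB_NAME, ...
run = wandb.init(**run_config.wandb_init_args())  # id/name/project/entity/group kwargs
run.finish()
```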
39 changes: 33 additions & 6 deletions src/flamingo/integrations/wandb/utils.py
@@ -4,13 +4,13 @@
 import wandb
 from wandb.apis.public import Run as ApiRun

-from flamingo.integrations.wandb import WandbRunConfig
+from flamingo.integrations.wandb import WandbArtifactConfig, WandbArtifactLoader, WandbRunConfig


 def get_wandb_api_run(config: WandbRunConfig) -> ApiRun:
     """Retrieve a run from the W&B API."""
     api = wandb.Api()
-    return api.run(config.get_wandb_path())
+    return api.run(config.wandb_path())


 def get_wandb_summary(config: WandbRunConfig) -> dict[str, Any]:
@@ -26,10 +26,37 @@ def update_wandb_summary(config: WandbRunConfig, metrics: dict[str, Any]) -> None:
     run.update()


-def get_artifact_filesystem_path(artifact: wandb.Artifact) -> str:
+def get_artifact_directory(artifact: wandb.Artifact) -> str:
+    dir_paths = set()
     for entry in artifact.manifest.entries.values():
         if entry.ref.startswith("file://"):
-            # TODO: What if there are entries with different base paths in the artifact manifest?
             entry_path = Path(entry.ref.replace("file://", ""))
-            return str(entry_path.parent.absolute())
-    raise ValueError("Artifact does not contain a filesystem reference.")
+            dir_paths.add(str(entry_path.parent.absolute()))
+    match len(dir_paths):
+        case 0:
+            raise ValueError(
+                f"Artifact {artifact.name} does not contain any filesystem references."
+            )
+        case 1:
+            return list(dir_paths)[0]
+        case _:
+            dir_string = ",".join(dir_paths)
+            raise ValueError(
+                f"Artifact {artifact.name} references multiple directories: {dir_string}. "
+                "Unable to determine which directory to load."
+            )
+
+
+def resolve_artifact_path(path: str | WandbArtifactConfig, loader: WandbArtifactLoader) -> str:
+    """Resolve the actual filesystem path for a path/artifact asset.
+
+    The artifact loader internally handles linking the artifact-to-load to an in-progress run.
+    """
+    match path:
+        case str():
+            return path
+        case WandbArtifactConfig() as artifact_config:
+            artifact = loader.load_artifact(artifact_config)
+            return get_artifact_directory(artifact)
+        case _:
+            raise ValueError(f"Invalid artifact path: {path}")
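A usage sketch for the relocated helper. The loader construction mirrors `WandbArtifactLoader(wandb_run)` in finetuning.py; passing `None` (no active run) is an assumption here:

```python
from flamingo.integrations.wandb import WandbArtifactConfig, WandbArtifactLoader
from flamingo.integrations.wandb.utils import resolve_artifact_path

loader = WandbArtifactLoader(None)  # assumed: linking is skipped without an active run

# Plain strings pass through untouched...
assert resolve_artifact_path("/data/train.jsonl", loader) == "/data/train.jsonl"

# ...while artifact configs are loaded and resolved to a local directory.
dataset_dir = resolve_artifact_path(WandbArtifactConfig(name="my-dataset", project="cortex"), loader)
```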
27 changes: 6 additions & 21 deletions src/flamingo/jobs/drivers/finetuning.py
@@ -12,8 +12,8 @@
 from trl import SFTTrainer

 from flamingo.integrations.huggingface.utils import load_and_split_dataset
-from flamingo.integrations.wandb import ArtifactType, WandbArtifactConfig, WandbArtifactLoader
-from flamingo.integrations.wandb.utils import get_artifact_filesystem_path
+from flamingo.integrations.wandb import ArtifactType, WandbArtifactLoader
+from flamingo.integrations.wandb.utils import resolve_artifact_path
 from flamingo.jobs import FinetuningJobConfig


@@ -23,21 +23,6 @@ def is_tracking_enabled(config: FinetuningJobConfig):
     return config.tracking is not None and train.get_context().get_world_rank() == 0


-def resolve_artifact_path(path: str | WandbArtifactConfig, loader: WandbArtifactLoader) -> str:
-    """Resolve the actual filesystem path for a path/artifact asset.
-
-    The artifact loader internally handles linking the artifact-to-load to an in-progress run.
-    """
-    match path:
-        case str():
-            return path
-        case WandbArtifactConfig() as artifact_config:
-            artifact = loader.load_artifact(artifact_config)
-            return get_artifact_filesystem_path(artifact)
-        case _:
-            raise ValueError(f"Invalid artifact path: {path}")
-
-
 def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
     """Get TrainingArguments appropriate for the worker rank and job config."""
     return TrainingArguments(
@@ -47,7 +32,7 @@ def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
         push_to_hub=False,
         disable_tqdm=True,
         logging_dir=None,
-        **config.trainer.get_training_args(),
+        **config.trainer.training_args(),
     )


@@ -56,7 +41,7 @@ def load_datasets(config: FinetuningJobConfig, loader: WandbArtifactLoader) -> D
     # We need to specify a fixed seed to load the datasets on each worker
     # Under the hood, HuggingFace uses `accelerate` to create a data loader shard for each worker
     # If the datasets are not seeded here, the ordering will be inconsistent between workers
-    # TODO: Get rid of this logic once data loading occurs once outside of the workers
+    # TODO: Get rid of this logic once data loading is done one time outside of the workers
     split_seed = config.dataset.seed or 0
     return load_and_split_dataset(
         path=dataset_path,
@@ -107,7 +92,7 @@ def train_func(config_data: dict):
     # Manually initialize run in order to set the run ID and link artifacts
     wandb_run = None
     if is_tracking_enabled(config):
-        wandb_run = wandb.init(**config.tracking.get_wandb_init_args(), resume="never")
+        wandb_run = wandb.init(**config.tracking.wandb_init_args(), resume="never")

     # Load the input artifacts, potentially linking them to the active W&B run
     artifact_loader = WandbArtifactLoader(wandb_run)
@@ -157,7 +142,7 @@ def run_finetuning(config: FinetuningJobConfig):

     if config.tracking and result.checkpoint:
         # Must resume from the just-completed training run
-        with wandb.init(config.tracking.get_wandb_init_args(), resume="must") as run:
+        with wandb.init(**config.tracking.wandb_init_args(), resume="must") as run:
             artifact_type = ArtifactType.MODEL.value
             artifact_name = f"{config.tracking.name or config.tracking.run_id}-{artifact_type}"
             artifact = wandb.Artifact(artifact_name, type=artifact_type)
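The seeding comment in `load_datasets` is the subtle point in this file: every Ray worker re-loads the data, so the split must be deterministic across workers. A self-contained illustration of the property relied on (a public dataset, not the job's own data):

```python
from datasets import load_dataset

dataset = load_dataset("imdb", split="train")

# The same seed yields the identical split on every worker; an unseeded split
# would order examples differently per worker, breaking shard consistency.
split_a = dataset.train_test_split(test_size=0.1, seed=0)
split_b = dataset.train_test_split(test_size=0.1, seed=0)
assert split_a["test"]["text"][0] == split_b["test"]["text"][0]
```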
2 changes: 1 addition & 1 deletion src/flamingo/jobs/drivers/lm_harness.py
@@ -5,8 +5,8 @@
 from lm_eval.models.huggingface import HFLM
 from peft import PeftConfig

-from flamingo.configs import LMHarnessJobConfig, ModelNameOrCheckpointPath
 from flamingo.integrations.wandb import get_wandb_summary, update_wandb_summary
+from flamingo.jobs import LMHarnessJobConfig, ModelNameOrCheckpointPath


 def resolve_model_or_path(config: LMHarnessJobConfig) -> str:
6 changes: 3 additions & 3 deletions tests/integrations/wandb/test_run_config.py
@@ -19,7 +19,7 @@ def test_serde_round_trip(wandb_run_config):


 def test_wandb_path(wandb_run_config):
-    assert wandb_run_config.get_wandb_path() == "twitter/cortex/run-id"
+    assert wandb_run_config.wandb_path() == "twitter/cortex/run-id"


 def test_ensure_run_id():
@@ -28,7 +28,7 @@ def test_ensure_run_id():


 def test_env_vars(wandb_run_config):
-    env_vars = wandb_run_config.get_env_vars()
+    env_vars = wandb_run_config.env_vars()
     expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"]
     for key in expected:
         assert key in env_vars
@@ -43,4 +43,4 @@ def test_disallowed_kwargs():
 def test_missing_key_warning(mock_environment_without_keys):
     with pytest.warns(UserWarning):
         config = WandbRunConfig(name="I am missing an API key", project="I should warn the user")
-        assert "WANDB_API_KEY" not in config.get_env_vars()
+        assert "WANDB_API_KEY" not in config.env_vars()