Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
working on w&b artifact links
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Friedowitz committed Jan 16, 2024
1 parent 0306f3b commit a03d272
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 65 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ evaluate = ["lm-eval==0.4.0", "einops"]

test = ["ruff==0.1.4", "pytest==7.4.3", "pytest-cov==4.1.0"]

all = ["flamingo[finetune,ludwig,evaluate,test]"]
all = ["flamingo[finetune,evaluate,test]"]

[project.scripts]
flamingo = "flamingo.cli:main"
Expand Down
10 changes: 5 additions & 5 deletions src/flamingo/integrations/huggingface/model_config.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from peft import LoraConfig
from pydantic import validator

from flamingo.integrations.huggingface import QuantizationConfig
from flamingo.integrations.huggingface.utils import repo_name_validator
from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype


class AutoModelConfig(BaseFlamingoConfig):
    """Settings passed to a HuggingFace AutoModel instantiation.

    The model path can either be a string corresponding to a HuggingFace repo ID,
    or an artifact link to a reference artifact on W&B.
    """

    # Repo ID string or W&B artifact link identifying the model weights
    path: str | WandbArtifactLink
    trust_remote_code: bool = False
    torch_dtype: SerializableTorchDtype = None
    quantization: QuantizationConfig | None = None
    lora: LoraConfig | None = None  # TODO: Create own dataclass here

    # pre=True so plain strings are normalized before pydantic type coercion
    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_name_validator)
8 changes: 6 additions & 2 deletions src/flamingo/integrations/huggingface/trainer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@


class TrainerConfig(BaseFlamingoConfig):
"""Configuration for a HuggingFace trainer/training arguments."""
"""Configuration for a HuggingFace trainer/training arguments.
This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class,
but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes.
"""

max_seq_length: int | None = None
num_train_epochs: int | None = None
Expand Down Expand Up @@ -37,5 +41,5 @@ def get_training_args(self) -> dict[str, Any]:
save_strategy=self.save_strategy,
save_steps=self.save_steps,
)
# Only return non-None values so we get HuggingFace defaults when not specified
# Only return non-None values so we use the HuggingFace defaults when not specified
return {k: v for k, v in args.items() if v is not None}
45 changes: 20 additions & 25 deletions src/flamingo/integrations/wandb/utils.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,42 @@
from pathlib import Path
from typing import Any

import wandb
from wandb.apis.public import Run

from flamingo.integrations.wandb import WandbEnvironment
from flamingo.integrations.wandb.wandb_artifact_link import WandbArtifactLink
from flamingo.integrations.wandb import WandbArtifactLink, WandbEnvironment


def get_wandb_run(env: WandbEnvironment) -> Run:
    """Retrieve a run from the W&B API.

    Args:
        env: settings identifying the W&B run (entity/project/run_id).

    Returns:
        The `wandb.apis.public.Run` object for `env.wandb_path`.
    """
    api = wandb.Api()
    return api.run(env.wandb_path)


def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]:
    """Get the summary dictionary attached to a W&B run."""
    run = get_wandb_run(env)
    return dict(run.summary)


def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None:
    """Update a run's summary with the provided metrics and persist the change."""
    run = get_wandb_run(env)
    run.summary.update(metrics)
    # Persist the summary mutation back to the W&B server
    run.update()


def get_wandb_artifact(link: WandbArtifactLink) -> wandb.Artifact:
    """Retrieve an artifact from the W&B API.

    Args:
        link: artifact link whose `wandb_path` identifies the artifact
            (entity/project/name, optionally with an alias).
    """
    api = wandb.Api()
    return api.artifact(link.wandb_path)


def get_artifact_filesystem_path(link: WandbArtifactLink) -> str:
    """Resolve the local directory referenced by a W&B artifact.

    Scans the artifact manifest for the first entry whose reference uses the
    ``file://`` scheme and returns the absolute path of its parent directory.

    Raises:
        ValueError: if no manifest entry references the local filesystem.
    """
    # TODO: What if there are multiple folder paths in the artifact manifest?
    artifact = get_wandb_artifact(link)
    # NOTE(review): assumes every manifest entry has a string `ref` — verify
    # this holds for artifacts that contain non-reference (uploaded) files.
    for manifest_entry in artifact.manifest.entries.values():
        ref = manifest_entry.ref
        if ref.startswith("file://"):
            local_dir = Path(ref.replace("file://", "")).parent
            return str(local_dir.absolute())
    raise ValueError("Artifact does not contain reference to filesystem files.")
4 changes: 3 additions & 1 deletion src/flamingo/integrations/wandb/wandb_artifact_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ class WandbArtifactLink(BaseFlamingoConfig):
project: str | None = None
entity: str | None = None

def artifact_path(self) -> str:
@property
def wandb_path(self) -> str:
"""String identifier for retrieving the asset from the W&B platform."""
path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
if self.alias:
path = f"{path}:{self.alias}"
Expand Down
32 changes: 19 additions & 13 deletions src/flamingo/integrations/wandb/wandb_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,13 @@ def warn_missing_api_key(cls, values):
)
return values

@validator("run_id", post=True, always=True)
@validator("run_id", always=True)
def ensure_run_id(cls, run_id):
if run_id is None:
# Generates an 8-digit random hexadecimal string, analogous to W&B platform
run_id = secrets.token_hex(nbytes=4)
return run_id

@property
def env_vars(self) -> dict[str, str]:
env_vars = {
"WANDB_RUN_ID": self.run_id,
"WANDB_NAME": self.name,
"WANDB_PROJECT": self.project,
"WANDB_RUN_GROUP": self.run_group,
"WANDB_ENTITY": self.entity,
"WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None),
}
return {k: v for k, v in env_vars.items() if v is not None}

@classmethod
def from_run(cls, run: Run) -> "WandbEnvironment":
"""Extract environment settings from a W&B Run object.
Expand All @@ -69,3 +57,21 @@ def from_run(cls, run: Run) -> "WandbEnvironment":
entity=run.entity,
run_id=run.id,
)

@property
def env_vars(self) -> dict[str, str]:
    """W&B environment variables for this run, omitting any unset values."""
    candidates = (
        ("WANDB_RUN_ID", self.run_id),
        ("WANDB_NAME", self.name),
        ("WANDB_PROJECT", self.project),
        ("WANDB_RUN_GROUP", self.run_group),
        ("WANDB_ENTITY", self.entity),
        ("WANDB_API_KEY", os.environ.get("WANDB_API_KEY", None)),
    )
    return {key: value for key, value in candidates if value is not None}

@property
def wandb_path(self) -> str:
    """String identifier for retrieving the asset from the W&B platform."""
    components = (self.entity, self.project, self.run_id)
    return "/".join(part for part in components if part is not None)
3 changes: 1 addition & 2 deletions src/flamingo/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from .finetuning_config import FinetuningJobConfig
from .lm_harness_config import LMHarnessJobConfig
from .simple_config import SimpleJobConfig

# Public API of the jobs package
__all__ = [
    "SimpleJobConfig",
    "FinetuningJobConfig",
    "LMHarnessJobConfig",
]
28 changes: 17 additions & 11 deletions src/flamingo/jobs/drivers/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,24 @@
from accelerate import Accelerator
from datasets import DatasetDict
from ray import train
from ray.train import CheckpointConfig, RunConfig
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from ray.train.torch import TorchTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
from trl import SFTTrainer

from flamingo.integrations.wandb import WandbArtifactLink
from flamingo.integrations.wandb.utils import get_artifact_filesystem_path
from flamingo.jobs import FinetuningJobConfig


def resolve_file_path(name_or_artifact: str | WandbArtifactLink) -> str:
    """Resolve a plain path string or a W&B artifact link to a filesystem path.

    Strings are returned unchanged; artifact links are resolved via the
    artifact manifest to a local directory.
    """
    if not isinstance(name_or_artifact, str):
        return get_artifact_filesystem_path(name_or_artifact)
    return name_or_artifact


def is_wandb_enabled(config: FinetuningJobConfig):
# Only report to WandB on the rank 0 worker
# Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
Expand All @@ -22,15 +31,15 @@ def is_wandb_enabled(config: FinetuningJobConfig):

def get_training_arguments(config: FinetuningJobConfig) -> TrainingArguments:
    """Get TrainingArguments appropriate for the worker rank and job config.

    Explicit job-level settings are applied first; any arguments provided by
    the trainer config override/extend them via keyword expansion.
    """
    provided_args = config.trainer.get_training_args()
    return TrainingArguments(
        output_dir="out",  # Local checkpoint path on a worker
        # Only the rank-0 worker reports to W&B
        report_to="wandb" if is_wandb_enabled(config) else "none",
        # NOTE(review): other code in this change reads `config.ray.use_gpu`;
        # confirm `config.scaling` is not a stale attribute name.
        no_cuda=not config.scaling.use_gpu,
        push_to_hub=False,
        disable_tqdm=True,
        logging_dir=None,
        **provided_args,
    )


Expand Down Expand Up @@ -60,7 +69,7 @@ def get_model(config: FinetuningJobConfig) -> PreTrainedModel:


def get_tokenizer(config: FinetuningJobConfig):
tokenizer_name = config.tokenizer.name if config.tokenizer else config.model.name
tokenizer_name = config.tokenizer.path if config.tokenizer else config.model.name
tokenizer_args = config.tokenizer.get_tokenizer_args() if config.tokenizer else {}
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_args)
if not tokenizer.pad_token_id:
Expand Down Expand Up @@ -92,8 +101,8 @@ def train_func(config_data: dict):
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=config.lora,
max_seq_length=config.trainer.max_seq_length if config.trainer else None,
peft_config=config.adapter,
max_seq_length=config.trainer.max_seq_length,
train_dataset=datasets["train"],
eval_dataset=datasets["test"],
dataset_text_field="text",
Expand All @@ -110,10 +119,7 @@ def train_func(config_data: dict):
def run_finetuning(config: FinetuningJobConfig):
print(f"Received job configuration: {config}")

if config.tracking:
# Ensure the run_id is set so that the W&B run can be initialized deterministically
config.tracking.ensure_run_id()

scaling_config = ScalingConfig(**config.ray.get_scaling_args())
run_config = RunConfig(
name=config.tracking.name if config.tracking else None,
storage_path=config.ray.storage_path,
Expand All @@ -122,7 +128,7 @@ def run_finetuning(config: FinetuningJobConfig):
trainer = TorchTrainer(
train_loop_per_worker=train_func,
train_loop_config=json.loads(config.json()),
scaling_config=config.ray.get_scaling_config(),
scaling_config=scaling_config,
run_config=run_config,
)
result = trainer.fit()
Expand Down
35 changes: 31 additions & 4 deletions src/flamingo/jobs/finetuning_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pydantic import Field
from ray.train import ScalingConfig
from typing import Any

from peft import LoraConfig
from pydantic import Field, validator

from flamingo.integrations.huggingface import (
AutoModelConfig,
AutoTokenizerConfig,
DatasetConfig,
QuantizationConfig,
TrainerConfig,
)
from flamingo.integrations.wandb import WandbEnvironment
Expand All @@ -21,8 +24,9 @@ class RayTrainConfig(BaseFlamingoConfig):
num_workers: int | None = None
storage_path: str | None = None

def get_scaling_args(self) -> dict[str, Any]:
    """Return keyword arguments for Ray's ScalingConfig, omitting unset values."""
    args = dict(use_gpu=self.use_gpu, num_workers=self.num_workers)
    # Drop None entries so Ray's own defaults apply when a field is unset
    return {k: v for k, v in args.items() if v is not None}


class FinetuningJobConfig(BaseFlamingoConfig):
Expand All @@ -31,6 +35,29 @@ class FinetuningJobConfig(BaseFlamingoConfig):
model: AutoModelConfig
dataset: DatasetConfig
tokenizer: AutoTokenizerConfig | None = None
quantization: QuantizationConfig | None = None
adapter: LoraConfig | None = None # TODO: Create own dataclass here
tracking: WandbEnvironment | None = None
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
ray: RayTrainConfig = Field(default_factory=RayTrainConfig)

@validator("model", pre=True, always=True)
def validate_model_arg(cls, x):
"""Allow for passing just a path string as the model argument."""
if isinstance(x, str):
return AutoModelConfig(path=x)
return x

@validator("dataset", pre=True, always=True)
def validate_dataset_arg(cls, x):
"""Allow for passing just a path string as the dataset argument."""
if isinstance(x, str):
return DatasetConfig(path=x)
return x

@validator("tokenizer", pre=True, always=True)
def validate_tokenizer_arg(cls, x):
"""Allow for passing just a path string as the tokenizer argument."""
if isinstance(x, str):
return AutoTokenizerConfig(path=x)
return x
3 changes: 2 additions & 1 deletion src/flamingo/jobs/lm_harness_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import Field

from flamingo.integrations.huggingface import AutoModelConfig
from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig
from flamingo.integrations.wandb import WandbEnvironment
from flamingo.types import BaseFlamingoConfig

Expand All @@ -29,5 +29,6 @@ class LMHarnessJobConfig(BaseFlamingoConfig):

model: AutoModelConfig
evaluator: LMHarnessEvaluatorSettings
quantization: QuantizationConfig | None = None
tracking: WandbEnvironment | None = None
ray: RayComputeSettings = Field(default_factory=RayComputeSettings)
File renamed without changes.

0 comments on commit a03d272

Please sign in to comment.