
Commit

porting more stuff
Sean Friedowitz committed Jan 10, 2024
1 parent 661ede5 commit 639c466
Showing 14 changed files with 321 additions and 64 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -11,12 +11,12 @@ requires-python = ">=3.10,<3.11"

dependencies = [
"click==8.1.7",
"ray[default]==2.7.0",
"torch==2.1.0",
"scipy==1.10.1",
"wandb==0.16.1",
"pydantic-yaml==1.2.0",
"pydantic==1.10.8",
"ray[default]==2.7.0",
]

[project.optional-dependencies]
32 changes: 29 additions & 3 deletions src/flamingo/cli.py
@@ -1,14 +1,40 @@
import click

from flamingo.jobs import run_finetuning, run_lm_harness, run_simple
from flamingo.jobs.configs import FinetuningJobConfig, LMHarnessJobConfig, SimpleJobConfig


@click.group()
def main():
def cli():
pass


@click.group("simple")
@click.option("--config", type=str)
def run_simple_cli(config: str) -> None:
config = SimpleJobConfig.from_yaml_file(config)
run_simple.main(config)


@click.group("finetune")
@click.option("--config", type=str)
def run_finetuning_cli(config: str) -> None:
config = FinetuningJobConfig.from_yaml_file(config)
run_finetuning.main(config)


@click.group("finetune")
@click.option("--config", type=str)
def run_finetuning_cli(config: str) -> None:
config = FinetuningJobConfig.from_yaml_file(config)
run_finetuning.main(config)


# need to add the group / command function itself, not the module
main.add_command(simple.driver)
cli.add_command(run_simple.main)
cli.add_command(run_finetuning.main)
cli.add_command(run_lm_harness.main)


if __name__ == "__main__":
main()
cli()
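
The inline comment in the new cli.py notes that click's add_command expects the decorated command or group object itself, not the module that defines it. For readers unfamiliar with that click detail, here is a standalone, illustrative sketch of the pattern; the names below are hypothetical and not taken from this repository:

import click


@click.group()
def cli():
    pass


@click.command("hello")
@click.option("--name", type=str, default="world")
def hello_cmd(name: str) -> None:
    click.echo(f"hello, {name}")


# Register the command object on the group; passing a module here would fail.
cli.add_command(hello_cmd)

if __name__ == "__main__":
    cli()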
22 changes: 0 additions & 22 deletions src/flamingo/integrations/wandb/wandb_mixin.py

This file was deleted.

14 changes: 3 additions & 11 deletions src/flamingo/jobs/__init__.py
@@ -1,11 +1,3 @@
from .base_config import BaseJobConfig
from .evaluation_config import EvaluationJobConfig
from .finetuning_config import FinetuningJobConfig
from .simple_config import SimpleJobConfig

__all__ = [
"BaseJobConfig",
"SimpleJobConfig",
"FinetuningJobConfig",
"EvaluationJobConfig",
]
# ruff: noqa
from .configs import *
from .entrypoints import *
11 changes: 11 additions & 0 deletions src/flamingo/jobs/configs/__init__.py
@@ -0,0 +1,11 @@
from .base_config import BaseJobConfig
from .finetuning_config import FinetuningJobConfig
from .lm_harness_config import LMHarnessJobConfig
from .simple_config import SimpleJobConfig

__all__ = [
"BaseJobConfig",
"SimpleJobConfig",
"FinetuningJobConfig",
"LMHarnessJobConfig",
]
File renamed without changes.
src/flamingo/jobs/configs/finetuning_config.py
@@ -4,12 +4,12 @@

from flamingo.integrations.huggingface import QuantizationConfig
from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name
from flamingo.integrations.wandb import WandbEnvironmentMixin
from flamingo.integrations.wandb import WandbEnvironment
from flamingo.jobs import BaseJobConfig
from flamingo.types import SerializableTorchDtype


class FinetuningJobConfig(WandbEnvironmentMixin, BaseJobConfig):
class FinetuningJobConfig(BaseJobConfig):
"""Configuration to submit an LLM finetuning job."""

model: str
@@ -32,6 +32,7 @@ class FinetuningJobConfig(WandbEnvironmentMixin, BaseJobConfig):
logging_steps: float = 100
save_strategy: str = "steps"
save_steps: int = 500
wandb_env: WandbEnvironment | None = None
# Lora/quantization
lora_config: LoraConfig | None = None # TODO: Create our own config type
quantization_config: QuantizationConfig | None = None
@@ -45,7 +46,3 @@ def _validate_modelname(cls, v): # noqa: N805
return v
else:
raise (ValueError(f"`{v}` is not a valid HuggingFace model name."))

@property
def entrypoint_command(self) -> str:
return f"python run_finetuning.py --config_json '{self.json()}'"
src/flamingo/jobs/configs/lm_harness_config.py
@@ -8,7 +8,7 @@

from flamingo.integrations.huggingface import QuantizationConfig
from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name
from flamingo.integrations.wandb import WandbEnvironment, WandbEnvironmentMixin
from flamingo.integrations.wandb import WandbEnvironment
from flamingo.jobs import BaseJobConfig
from flamingo.types import SerializableTorchDtype

@@ -42,7 +42,7 @@ def __post_init__(self, checkpoint):
raise (ValueError(f"{self.name} is not a valid checkpoint path or HF model name"))


class EvaluationJobConfig(WandbEnvironmentMixin, BaseJobConfig):
class LMHarnessJobConfig(BaseJobConfig):
"""Configuration to run an lm-evaluation-harness evaluation job.
This job loads an existing checkpoint path
@@ -66,6 +66,7 @@ class Config:
trust_remote_code: bool = False
torch_dtype: SerializableTorchDtype = None
quantization_config: QuantizationConfig | None = None
wandb_env: WandbEnvironment | None = None
num_cpus: int = 1
num_gpus: int = 1
timeout: datetime.timedelta | None = None
@@ -112,7 +113,3 @@ def _validate_modelname_or_checkpoint(cls, values) -> Any:
raise (ValueError(f"{mnp} is not a valid HuggingFaceModel or checkpoint path."))

return values

@property
def entrypoint_command(self) -> str:
return f"python run_evaluation.py --config_json '{self.json()}'"
7 changes: 7 additions & 0 deletions src/flamingo/jobs/configs/simple_config.py
@@ -0,0 +1,7 @@
from flamingo.jobs import BaseJobConfig


class SimpleJobConfig(BaseJobConfig):
"""A simple job to demonstrate the submission interface."""

magic_number: int
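
As wired up in cli.py above, a SimpleJobConfig is loaded from YAML and handed to the run_simple entrypoint. A rough sketch of that flow, assuming BaseJobConfig exposes the from_yaml_file classmethod used in cli.py; the file name and contents here are illustrative:

from pathlib import Path

from flamingo.jobs.configs import SimpleJobConfig

# Write an illustrative config file, then load it the same way the CLI does.
Path("simple.yaml").write_text("magic_number: 42\n")
config = SimpleJobConfig.from_yaml_file("simple.yaml")
print(config.magic_number)  # -> 42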
5 changes: 5 additions & 0 deletions src/flamingo/jobs/entrypoints/__init__.py
@@ -0,0 +1,5 @@
import run_finetuning
import run_lm_harness
import run_simple

__all__ = ["run_finetuning", "run_lm_harness", "run_simple"]
136 changes: 136 additions & 0 deletions src/flamingo/jobs/entrypoints/run_finetuning.py
@@ -0,0 +1,136 @@
import json

import torch
import wandb
from accelerate import Accelerator
from datasets import DatasetDict
from ray import train
from ray.train import CheckpointConfig, RunConfig
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from ray.train.torch import TorchTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TrainingArguments
from trl import SFTTrainer

from flamingo.integrations.wandb import update_wandb_summary
from flamingo.jobs import FinetuningJobConfig


def is_wandb_enabled(config: FinetuningJobConfig):
# Only report to WandB on the rank 0 worker
# Reference: https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html
return config.wandb_env and train.get_context().get_world_rank() == 0


def get_training_args(config: FinetuningJobConfig) -> TrainingArguments:
"""Get TrainingArguments appropriate for the worker rank and job config."""
return TrainingArguments(
output_dir="out", # Local checkpoint path on a worker
num_train_epochs=config.num_train_epochs,
learning_rate=config.learning_rate,
per_device_train_batch_size=config.batch_size,
per_device_eval_batch_size=config.batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=config.gradient_checkpointing,
weight_decay=config.weight_decay,
evaluation_strategy=config.evaluation_strategy,
eval_steps=config.eval_steps,
logging_strategy=config.logging_strategy,
logging_steps=config.logging_steps,
save_strategy=config.save_strategy,
save_steps=config.save_steps,
run_name=config.wandb_name,
report_to="wandb" if is_wandb_enabled(config) else "none",
no_cuda=not config.scaling_config.use_gpu,
push_to_hub=False,
disable_tqdm=True,
logging_dir=None,
)


def get_datasets(config: FinetuningJobConfig) -> DatasetDict:
# TODO: Refactor me somehow
...


def get_model(config: FinetuningJobConfig) -> PreTrainedModel:
device_map, bnb_config = None, None
if config.quantization_config:
bnb_config = config.quantization_config.as_huggingface()
# When quantization is enabled, model must all be on same GPU to work with DDP
# If a device_map is not specified we will get accelerate errors downstream
# Reference: https://github.com/huggingface/accelerate/issues/1840#issuecomment-1683105994
current_device = Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
device_map = {"": current_device}
print(f"Setting model device_map = {device_map} to enable quantization")

return AutoModelForCausalLM.from_pretrained(
config.model,
trust_remote_code=config.trust_remote_code,
torch_dtype=config.torch_dtype,
quantization_config=bnb_config,
device_map=device_map,
)


def get_tokenizer(config: FinetuningJobConfig):
tokenizer = AutoTokenizer.from_pretrained(
config.tokenizer or config.model,
trust_remote_code=config.trust_remote_code,
use_fast=True,
)
if not tokenizer.pad_token_id:
# Pad token required for generating consistent batch sizes
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer


def train_func(config_data: dict):
config = FinetuningJobConfig(**config_data)
model = get_model(config)
tokenizer = get_tokenizer(config)

datasets = get_datasets(config)
training_args = get_training_args(config)

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=config.lora_config,
max_seq_length=config.max_seq_length,
train_dataset=datasets["train"],
eval_dataset=datasets["test"],
dataset_text_field="text",
)
trainer.add_callback(RayTrainReportCallback())
trainer = prepare_trainer(trainer)
trainer.train()

# Force WandB finish on rank 0 worker
if is_wandb_enabled(config):
wandb.finish()


def main(config: FinetuningJobConfig):
print(f"Received job configuration: {config}")

run_config = RunConfig(
name=config.wandb_name,
storage_path=config.storage_path,
checkpoint_config=CheckpointConfig(num_to_keep=1),
)
trainer = TorchTrainer(
train_loop_per_worker=train_func,
train_loop_config=json.loads(config.json()),
scaling_config=config.scaling_config,
run_config=run_config,
)
result = trainer.fit()
print(f"Training result: {result}")

# Log additional training metrics to completed WandB run
if config.wandb_env:
result_paths = {"ray/result_path": result.path}
if result.checkpoint:
result_paths["ray/checkpoint_path"] = f"{result.checkpoint.path}/checkpoint"
update_wandb_summary(config.wandb_env, result_paths)
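
One design point in the entrypoint above: the pydantic job config is serialized to a plain dict on the driver (json.loads(config.json())), shipped to each Ray worker as train_loop_config, and re-hydrated into a typed config inside train_func. A stripped-down sketch of that round trip, using a stand-in config class rather than FinetuningJobConfig:

import json

from pydantic import BaseModel


class DemoConfig(BaseModel):  # stand-in for FinetuningJobConfig
    learning_rate: float = 1e-4
    batch_size: int = 8


def train_func(config_data: dict):
    # Worker side: rebuild the typed config from the plain dict Ray passed in.
    config = DemoConfig(**config_data)
    print(config.learning_rate, config.batch_size)


# Driver side: serialize the config so it can be sent as train_loop_config.
driver_config = DemoConfig(batch_size=16)
train_func(json.loads(driver_config.json()))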
(Diffs for the remaining changed files were not loaded.)
