This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Merge pull request #93 from mozilla-ai/sfriedowitz/refactor-module-layout

Refactor module layout to separate configs and functionality
Sean Friedowitz authored Apr 4, 2024
2 parents e1d7f5b + 244555e commit 404a35c
Showing 52 changed files with 348 additions and 377 deletions.
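At a glance, the refactor moves configuration classes out of lm_buddy.jobs.configs and the lm_buddy.integrations packages into a dedicated lm_buddy.configs package. The mapping below is inferred from the file diffs that follow, not stated anywhere in the commit itself; the import lines are taken from paths that appear in this diff.

# Where config-related imports move in this refactor (mapping inferred from the diffs below):
#   lm_buddy.jobs.configs.*           -> lm_buddy.configs.jobs.*
#   lm_buddy.integrations.huggingface -> lm_buddy.configs.huggingface
#   lm_buddy.integrations.wandb       -> lm_buddy.configs.wandb (run configs)
#                                        lm_buddy.tracking.run_utils (WandbResumeMode)
#   lm_buddy.types                    -> lm_buddy.configs.common
from lm_buddy.configs.jobs import FinetuningJobConfig, LMHarnessJobConfig
from lm_buddy.configs.huggingface import AutoModelConfig, DatasetConfig
from lm_buddy.configs.wandb import WandbRunConfig
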
16 changes: 6 additions & 10 deletions examples/notebooks/direct_job_execution.ipynb
@@ -37,19 +37,15 @@
"outputs": [],
"source": [
"from lm_buddy import LMBuddy\n",
"from lm_buddy.jobs.configs import (\n",
" FinetuningJobConfig,\n",
" FinetuningRayConfig,\n",
" LMHarnessJobConfig,\n",
" LMHarnessEvaluationConfig,\n",
")\n",
"from lm_buddy.integrations.huggingface import (\n",
"from lm_buddy.configs.jobs.finetuning import FinetuningJobConfig, FinetuningRayConfig\n",
"from lm_buddy.configs.jobs.lm_harness import LMHarnessJobConfig, LMHarnessEvaluationConfig\n",
"from lm_buddy.configs.huggingface import (\n",
" AutoModelConfig,\n",
" TextDatasetConfig,\n",
" DatasetConfig,\n",
" TrainerConfig,\n",
" AdapterConfig,\n",
")\n",
"from lm_buddy.integrations.wandb import WandbRunConfig"
"from lm_buddy.configs.wandb import WandbRunConfig"
]
},
{
@@ -69,7 +65,7 @@
"model_config = AutoModelConfig(path=\"hf://distilgpt2\")\n",
"\n",
"# Text dataset for finetuning\n",
"dataset_config = TextDatasetConfig(\n",
"dataset_config = DatasetConfig(\n",
" path=\"hf://imdb\",\n",
" split=\"train[:100]\",\n",
" text_field=\"text\",\n",
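Putting the updated notebook imports together, a finetuning run now looks roughly like the sketch below. The model and dataset values come from the notebook cells above; the FinetuningJobConfig field names (model, dataset, trainer) are assumptions, since the cell that assembles the job config is outside this hunk.

from lm_buddy import LMBuddy
from lm_buddy.configs.jobs.finetuning import FinetuningJobConfig
from lm_buddy.configs.huggingface import AutoModelConfig, DatasetConfig, TrainerConfig

# Values shown in the notebook diff above
model_config = AutoModelConfig(path="hf://distilgpt2")
dataset_config = DatasetConfig(path="hf://imdb", split="train[:100]", text_field="text")

# Field names here are assumed; only the config classes appear in this hunk
job_config = FinetuningJobConfig(
    model=model_config,
    dataset=dataset_config,
    trainer=TrainerConfig(num_train_epochs=1),
)

buddy = LMBuddy()
result = buddy.finetune(job_config)
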
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "lm-buddy"
version = "0.8.0"
version = "0.9.0"
authors = [
{ name = "Sean Friedowitz", email = "[email protected]" },
{ name = "Aaron Gonzales", email = "[email protected]" },
22 changes: 11 additions & 11 deletions src/lm_buddy/buddy.py
@@ -1,17 +1,20 @@
import wandb

from lm_buddy.integrations.wandb import WandbResumeMode
from lm_buddy.jobs._entrypoints import run_finetuning, run_lm_harness, run_prometheus, run_ragas
from lm_buddy.jobs.common import EvaluationResult, FinetuningResult, LMBuddyJobType
from lm_buddy.jobs.configs import (
from lm_buddy.configs.jobs import (
EvaluationJobConfig,
FinetuningJobConfig,
LMBuddyJobConfig,
JobConfig,
LMHarnessJobConfig,
PrometheusJobConfig,
RagasJobConfig,
)
from lm_buddy.jobs.common import EvaluationResult, FinetuningResult, JobType
from lm_buddy.jobs.evaluation.lm_harness import run_lm_harness
from lm_buddy.jobs.evaluation.prometheus import run_prometheus
from lm_buddy.jobs.evaluation.ragas import run_ragas
from lm_buddy.jobs.finetuning import run_finetuning
from lm_buddy.paths import strip_path_prefix
from lm_buddy.tracking.run_utils import WandbResumeMode


class LMBuddy:
@@ -25,10 +28,7 @@ def __init__(self):
pass

def _generate_artifact_lineage(
self,
config: LMBuddyJobConfig,
results: list[wandb.Artifact],
job_type: LMBuddyJobType,
self, config: JobConfig, results: list[wandb.Artifact], job_type: JobType
) -> None:
"""Link input artifacts and log output artifacts to a run.
@@ -51,7 +51,7 @@ def _generate_artifact_lineage(
def finetune(self, config: FinetuningJobConfig) -> FinetuningResult:
"""Run a supervised finetuning job with the provided configuration."""
result = run_finetuning(config)
self._generate_artifact_lineage(config, result.artifacts, LMBuddyJobType.FINETUNING)
self._generate_artifact_lineage(config, result.artifacts, JobType.FINETUNING)
return result

def evaluate(self, config: EvaluationJobConfig) -> EvaluationResult:
@@ -68,5 +68,5 @@ def evaluate(self, config: EvaluationJobConfig) -> EvaluationResult:
result = run_ragas(ragas_config)
case _:
raise ValueError(f"Invalid configuration for evaluation: {type(config)}")
self._generate_artifact_lineage(config, result.artifacts, LMBuddyJobType.EVALUATION)
self._generate_artifact_lineage(config, result.artifacts, JobType.EVALUATION)
return result
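From the caller's side, evaluate() dispatches on the concrete config type, so the same entry point serves lm-harness, Prometheus, and Ragas jobs. A minimal sketch of that usage follows; the YAML path is hypothetical, and from_yaml_file is the helper added to the config base class in configs/common.py later in this commit.

from lm_buddy import LMBuddy
from lm_buddy.configs.jobs import LMHarnessJobConfig

buddy = LMBuddy()
eval_config = LMHarnessJobConfig.from_yaml_file("lm_harness_config.yaml")  # hypothetical path
eval_result = buddy.evaluate(eval_config)  # routed to run_lm_harness by the match statement above
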
2 changes: 1 addition & 1 deletion src/lm_buddy/cli/evaluate.py
@@ -2,7 +2,7 @@

from lm_buddy import LMBuddy
from lm_buddy.cli.utils import parse_config_option
from lm_buddy.jobs.configs import LMHarnessJobConfig, PrometheusJobConfig, RagasJobConfig
from lm_buddy.configs.jobs import LMHarnessJobConfig, PrometheusJobConfig, RagasJobConfig


@click.group(name="evaluate", help="Run an LM Buddy evaluation job.")
2 changes: 1 addition & 1 deletion src/lm_buddy/cli/finetune.py
@@ -2,7 +2,7 @@

from lm_buddy import LMBuddy
from lm_buddy.cli.utils import parse_config_option
from lm_buddy.jobs.configs import FinetuningJobConfig
from lm_buddy.configs.jobs import FinetuningJobConfig


@click.command(name="finetune", help="Run an LM Buddy finetuning job.")
4 changes: 2 additions & 2 deletions src/lm_buddy/cli/utils.py
@@ -1,9 +1,9 @@
from pathlib import Path
from typing import TypeVar

from lm_buddy.jobs.configs.base import LMBuddyJobConfig
from lm_buddy.configs.jobs.common import JobConfig

ConfigType = TypeVar("ConfigType", bound=LMBuddyJobConfig)
ConfigType = TypeVar("ConfigType", bound=JobConfig)


def parse_config_option(config_cls: type[ConfigType], config: str) -> ConfigType:
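The body of parse_config_option is collapsed in this diff. A plausible sketch, assuming it accepts either a path to a YAML config file or an inline JSON payload, is shown below; the inline-JSON branch is an assumption, while from_yaml_file comes from the config base class added in this commit.

from pathlib import Path
from typing import TypeVar

from lm_buddy.configs.jobs.common import JobConfig

ConfigType = TypeVar("ConfigType", bound=JobConfig)


def parse_config_option(config_cls: type[ConfigType], config: str) -> ConfigType:
    # Assumption: treat the option as a YAML file path if it exists on disk,
    # otherwise parse it as an inline JSON string (pydantic v2 API).
    if Path(config).exists():
        return config_cls.from_yaml_file(config)
    return config_cls.model_validate_json(config)
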
File renamed without changes.
29 changes: 28 additions & 1 deletion src/lm_buddy/types.py → src/lm_buddy/configs/common.py
@@ -1,7 +1,11 @@
import contextlib
import tempfile
from pathlib import Path
from typing import Annotated, Any

import torch
from pydantic import BaseModel, BeforeValidator, PlainSerializer, WithJsonSchema
from pydantic_yaml import parse_yaml_file_as, to_yaml_file


def validate_torch_dtype(x: Any) -> torch.dtype:
@@ -28,7 +32,7 @@ def validate_torch_dtype(x: Any) -> torch.dtype:
"""


class BaseLMBuddyConfig(
class LMBuddyConfig(
BaseModel,
extra="forbid",
arbitrary_types_allowed=True,
@@ -38,3 +42,26 @@ class BaseLMBuddyConfig(
Defines some common settings used by all subclasses.
"""

@classmethod
def from_yaml_file(cls, path: Path | str):
return parse_yaml_file_as(cls, path)

def to_yaml_file(self, path: Path | str):
to_yaml_file(path, self, exclude_none=True)

@contextlib.contextmanager
def to_tempfile(self, *, name: str = "config.yaml", dir: str | Path | None = None):
"""Enter a context manager with the config written to a temporary YAML file.
Keyword Args:
name (str): Name of the config file in the tmp directory. Defaults to "config.yaml".
dir (str | Path | None): Root path of the temporary directory. Defaults to None.
Returns:
Path to the temporary config file.
"""
with tempfile.TemporaryDirectory(dir=dir) as tmpdir:
config_path = Path(tmpdir) / name
self.to_yaml_file(config_path)
yield config_path
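The new serialization helpers on LMBuddyConfig apply to every config subclass. A small usage sketch with AutoModelConfig (defined later in this commit); the file names and the hf://distilgpt2 value are illustrative, the latter taken from the notebook diff above.

from lm_buddy.configs.huggingface import AutoModelConfig

config = AutoModelConfig(path="hf://distilgpt2")

# Round-trip the config through YAML
config.to_yaml_file("model_config.yaml")
restored = AutoModelConfig.from_yaml_file("model_config.yaml")
assert restored == config

# Write the config to a throwaway file, e.g. for handing off to a job submission
with config.to_tempfile(name="model_config.yaml") as config_path:
    print(config_path)  # temporary path, cleaned up when the block exits
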
158 changes: 158 additions & 0 deletions src/lm_buddy/configs/huggingface.py
@@ -0,0 +1,158 @@
import dataclasses
from typing import Any

from peft import PeftConfig, PeftType, TaskType
from pydantic import field_validator, model_validator
from transformers import BitsAndBytesConfig

from lm_buddy.configs.common import LMBuddyConfig, SerializableTorchDtype
from lm_buddy.paths import AssetPath, PathPrefix

DEFAULT_TEXT_FIELD: str = "text"


class AutoModelConfig(LMBuddyConfig):
"""Settings passed to a HuggingFace AutoModel instantiation.
The model to load can either be a HuggingFace repo or an artifact reference on W&B.
"""

path: AssetPath
trust_remote_code: bool = False
torch_dtype: SerializableTorchDtype | None = None


class AutoTokenizerConfig(LMBuddyConfig):
"""Settings passed to a HuggingFace AutoTokenizer instantiation."""

path: AssetPath
trust_remote_code: bool | None = None
use_fast: bool | None = None


class DatasetConfig(LMBuddyConfig):
"""Settings passed to load a HuggingFace text dataset.
The dataset can either contain a single text column named by the `text_field` parameter,
or a `prompt_template` can be provided to format columns of the dataset as the `text_field`.
"""

path: AssetPath
text_field: str = DEFAULT_TEXT_FIELD
prompt_template: str | None = None
split: str | None = None
test_size: float | None = None
seed: int | None = None

@model_validator(mode="after")
def validate_split_if_huggingface_path(cls, config: "DatasetConfig"):
"""
Ensure a `split` is provided when loading a HuggingFace dataset directly from HF Hub.
This makes it such that the `load_dataset` function returns the type `Dataset`
instead of `DatasetDict`, which makes some of the downstream logic easier.
"""
if config.split is None and config.path.startswith(PathPrefix.HUGGINGFACE):
raise ValueError(
"A `split` must be specified when loading a dataset directly from HuggingFace."
)
return config


class AdapterConfig(LMBuddyConfig, extra="allow"):
"""Configuration containing PEFT adapter settings.
The type of adapter is controlled by the required `peft_type` field,
which must be one of the allowed values from the PEFT `PeftType` enumeration.
Extra arguments are allowed and are passed down to the HuggingFace `PeftConfig`
class determined by the `peft_type` argument.
The `task_type` for the adapter is also required.
By default, this is set to `TaskType.CAUSAL_LM`
which is appropriate for causal language model finetuning.
See the allowed values in the PEFT `TaskType` enumeration.
"""

peft_type: PeftType
task_type: TaskType = TaskType.CAUSAL_LM

@staticmethod
def _get_peft_config_class(peft_type: PeftType) -> type[PeftConfig]:
# Internal import to avoid bringing the global variable from peft into module scope
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

return PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]

@field_validator("peft_type", "task_type", mode="before")
def sanitize_enum_args(cls, x):
if isinstance(x, str):
x = x.strip().upper()
return x

@model_validator(mode="after")
def validate_adapter_args(cls, config: "AdapterConfig"):
peft_type = config.peft_type

# PeftConfigs are standard dataclasses so can extract their allowed field names
adapter_cls = cls._get_peft_config_class(peft_type)
allowed_fields = {x.name for x in dataclasses.fields(adapter_cls)}

# Filter fields to those found on the PeftConfig
extra_fields = config.model_fields_set.difference(allowed_fields)
if extra_fields:
raise ValueError(f"Unknown arguments for {peft_type} adapter: {extra_fields}")

return config

def as_huggingface(self) -> PeftConfig:
adapter_cls = self._get_peft_config_class(self.peft_type)
return adapter_cls(**self.model_dump())


class QuantizationConfig(LMBuddyConfig):
"""Basic quantization settings to pass to training and evaluation jobs.
Note that in order to use BitsAndBytes quantization on Ray,
you must ensure that the runtime environment is installed with GPU support.
This can be configured by setting the `entrypoint_num_gpus > 0` when submitting a job
to the cluster.
"""

load_in_8bit: bool | None = None
load_in_4bit: bool | None = None
bnb_4bit_quant_type: str = "fp4"
bnb_4bit_compute_dtype: SerializableTorchDtype | None = None

def as_huggingface(self) -> BitsAndBytesConfig:
return BitsAndBytesConfig(
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
bnb_4bit_compute_dtype=self.bnb_4bit_compute_dtype,
bnb_4bit_quant_type=self.bnb_4bit_quant_type,
)


class TrainerConfig(LMBuddyConfig):
"""Configuration for a HuggingFace trainer/training arguments.
This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class,
but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes.
"""

max_seq_length: int | None = None
num_train_epochs: float | None = None
per_device_train_batch_size: int | None = None
per_device_eval_batch_size: int | None = None
learning_rate: float | None = None
weight_decay: float | None = None
gradient_accumulation_steps: int | None = None
gradient_checkpointing: bool | None = None
evaluation_strategy: str | None = None
eval_steps: float | None = None
logging_strategy: str | None = None
logging_steps: float | None = None
save_strategy: str | None = None
save_steps: int | None = None

def training_args(self) -> dict[str, Any]:
"""Return the arguments to the HuggingFace `TrainingArguments` class."""
return self.model_dump(exclude={"max_seq_length"}, exclude_none=True)
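To illustrate how these configs hand off to the underlying HuggingFace and PEFT objects, here is a short sketch. All field values are made up; only the class names, fields, and as_huggingface/training_args methods come from the file above.

from lm_buddy.configs.huggingface import (
    AdapterConfig,
    DatasetConfig,
    QuantizationConfig,
    TrainerConfig,
)

# DatasetConfig requires a split for hf:// paths (enforced by the validator above)
dataset_config = DatasetConfig(path="hf://imdb", split="train[:100]")

# Extra fields (r, lora_alpha) are validated against the fields of peft's LoraConfig
adapter_config = AdapterConfig(peft_type="LORA", r=8, lora_alpha=16)
peft_config = adapter_config.as_huggingface()  # -> peft.LoraConfig

quant_config = QuantizationConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
bnb_config = quant_config.as_huggingface()  # -> transformers.BitsAndBytesConfig

trainer_config = TrainerConfig(max_seq_length=512, num_train_epochs=1, learning_rate=2e-4)
training_kwargs = trainer_config.training_args()  # TrainingArguments kwargs, max_seq_length excluded
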
16 changes: 16 additions & 0 deletions src/lm_buddy/configs/jobs/__init__.py
@@ -0,0 +1,16 @@
from lm_buddy.configs.jobs.common import JobConfig
from lm_buddy.configs.jobs.finetuning import FinetuningJobConfig
from lm_buddy.configs.jobs.lm_harness import LMHarnessJobConfig
from lm_buddy.configs.jobs.prometheus import PrometheusJobConfig
from lm_buddy.configs.jobs.ragas import RagasJobConfig

EvaluationJobConfig = LMHarnessJobConfig | PrometheusJobConfig | RagasJobConfig

__all__ = [
"JobConfig",
"FinetuningJobConfig",
"LMHarnessJobConfig",
"PrometheusJobConfig",
"RagasJobConfig",
"EvaluationJobConfig",
]
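EvaluationJobConfig is a plain union alias rather than a class, so it can be used directly in type hints and in structural pattern matching, which is how LMBuddy.evaluate() dispatches. A minimal sketch of that pattern (the helper function is hypothetical):

from lm_buddy.configs.jobs import EvaluationJobConfig, LMHarnessJobConfig, PrometheusJobConfig


def job_kind(config: EvaluationJobConfig) -> str:
    # Mirrors the dispatch style used in LMBuddy.evaluate()
    match config:
        case LMHarnessJobConfig():
            return "lm-harness"
        case PrometheusJobConfig():
            return "prometheus"
        case _:
            return "ragas"
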
src/lm_buddy/jobs/configs/base.py → src/lm_buddy/configs/jobs/common.py
@@ -6,12 +6,12 @@
from pydantic import Field
from pydantic_yaml import parse_yaml_file_as, to_yaml_file

from lm_buddy.integrations.wandb import WandbRunConfig
from lm_buddy.configs.common import LMBuddyConfig
from lm_buddy.configs.wandb import WandbRunConfig
from lm_buddy.paths import AssetPath, PathPrefix
from lm_buddy.types import BaseLMBuddyConfig


class LMBuddyJobConfig(BaseLMBuddyConfig):
class JobConfig(LMBuddyConfig):
"""Configuration that comprises the entire input to an LM Buddy job.
This class implements helper methods for de/serializing the configuration from file.