Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Explicit api for LMBuddy tasks #83

Closed
wants to merge 14 commits into from
2 changes: 0 additions & 2 deletions src/lm_buddy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from lm_buddy.jobs import run_job

__all__ = ["run_job"]
1 change: 0 additions & 1 deletion src/lm_buddy/integrations/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# ruff: noqa: I001
from lm_buddy.integrations.huggingface.repo_config import *
from lm_buddy.integrations.huggingface.adapter_config import *
from lm_buddy.integrations.huggingface.dataset_config import *
from lm_buddy.integrations.huggingface.model_config import *
Expand Down
37 changes: 16 additions & 21 deletions src/lm_buddy/integrations/huggingface/asset_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from pathlib import Path

import torch
from accelerate import Accelerator
Expand All @@ -17,17 +18,15 @@
AutoModelConfig,
AutoTokenizerConfig,
DatasetConfig,
HuggingFaceRepoConfig,
HuggingFaceRepoID,
QuantizationConfig,
)
from lm_buddy.integrations.wandb import (
ArtifactLoader,
WandbArtifactConfig,
get_artifact_filesystem_path,
)

HuggingFaceAssetPath = HuggingFaceRepoConfig | WandbArtifactConfig
"""Config that can be resolved to a HuggingFace name/path."""
from lm_buddy.paths import LoadableAssetPath


def resolve_peft_and_pretrained(path: str) -> tuple[str, str | None]:
Expand Down Expand Up @@ -66,23 +65,23 @@ class HuggingFaceAssetLoader:
def __init__(self, artifact_loader: ArtifactLoader):
self._artifact_loader = artifact_loader

def resolve_asset_path(self, path: HuggingFaceAssetPath) -> tuple[str, str | None]:
def resolve_asset_path(self, path: LoadableAssetPath) -> str:
"""Resolve the actual HuggingFace name/path from a config.

Currently, two config types contain references to a loadable HuggingFace path:
(1) A `HuggingFaceRepoConfig` that contains the repo path directly
(2) A `WandbArtifactConfig` where the filesystem path is resolved from the artifact
"""
match path:
case HuggingFaceRepoConfig(repo_id, revision):
load_path, revision = repo_id, revision
case Path() as x:
return str(x)
case HuggingFaceRepoID(repo_id):
return repo_id
case WandbArtifactConfig() as artifact_config:
artifact = self._artifact_loader.use_artifact(artifact_config)
load_path = get_artifact_filesystem_path(artifact)
revision = None
return str(get_artifact_filesystem_path(artifact))
case unknown_path:
raise ValueError(f"Unable to resolve asset path from {unknown_path}.")
return str(load_path), revision

def load_pretrained_config(
self,
Expand All @@ -92,10 +91,8 @@ def load_pretrained_config(

An exception is raised if the HuggingFace repo does not contain a `config.json` file.
"""
model_path, revision = self.resolve_asset_path(config.load_from)
return AutoConfig.from_pretrained(
pretrained_model_name_or_path=model_path, revision=revision
)
config_path = self.resolve_asset_path(config.path)
return AutoConfig.from_pretrained(pretrained_model_name_or_path=config_path)

def load_pretrained_model(
self,
Expand All @@ -122,10 +119,9 @@ def load_pretrained_model(

# TODO: HuggingFace has many AutoModel classes with different "language model heads"
# Can we abstract this to load with any type of AutoModel class?
model_path, revision = self.resolve_asset_path(config.load_from)
model_path = self.resolve_asset_path(config.path)
return AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
revision=revision,
trust_remote_code=config.trust_remote_code,
torch_dtype=config.torch_dtype,
quantization_config=bnb_config,
Expand All @@ -137,10 +133,9 @@ def load_pretrained_tokenizer(self, config: AutoTokenizerConfig) -> PreTrainedTo

An exception is raised if the HuggingFace repo does not contain a `tokenizer.json` file.
"""
tokenizer_path, revision = self.resolve_asset_path(config.load_from)
tokenizer_path = self.resolve_asset_path(config.path)
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=tokenizer_path,
revision=revision,
trust_remote_code=config.trust_remote_code,
use_fast=config.use_fast,
)
Expand All @@ -156,10 +151,10 @@ def load_dataset(self, config: DatasetConfig) -> Dataset:
When loading from HuggingFace directly, the `Dataset` is for the provided split.
When loading from disk, the saved files must be for a dataset else an exception is raised.
"""
dataset_path, revision = self.resolve_asset_path(config.load_from)
dataset_path = self.resolve_asset_path(config.path)
# Dataset loading requires a different method if from a HF vs. disk
if isinstance(config.load_from, HuggingFaceRepoConfig):
return load_dataset(dataset_path, revision=revision, split=config.split)
if isinstance(config.path, HuggingFaceRepoID):
return load_dataset(dataset_path, split=config.split)
else:
match load_from_disk(dataset_path):
case Dataset() as dataset:
Expand Down
15 changes: 5 additions & 10 deletions src/lm_buddy/integrations/huggingface/dataset_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pydantic import field_validator, model_validator
from pydantic import model_validator

from lm_buddy.integrations.huggingface import HuggingFaceRepoConfig, convert_string_to_repo_config
from lm_buddy.integrations.wandb import WandbArtifactConfig
from lm_buddy.paths import HuggingFaceRepoID, LoadableAssetPath
from lm_buddy.types import BaseLMBuddyConfig

DEFAULT_TEXT_FIELD: str = "text"
Expand All @@ -10,25 +9,21 @@
class DatasetConfig(BaseLMBuddyConfig):
"""Base configuration to load a HuggingFace dataset."""

load_from: HuggingFaceRepoConfig | WandbArtifactConfig
path: LoadableAssetPath
split: str | None = None
test_size: float | None = None
seed: int | None = None

_validate_load_from_string = field_validator("load_from", mode="before")(
convert_string_to_repo_config
)

@model_validator(mode="after")
def validate_split_if_huggingface_repo(cls, config: "DatasetConfig"):
"""
Ensure a `split` is provided when loading a HuggingFace dataset directly from HF Hub.
This makes it such that the `load_dataset` function returns the type `Dataset`
instead of `DatasetDict`, which makes some of the downstream logic easier.
"""
load_from = config.load_from
path = config.path
split = config.split
if split is None and isinstance(load_from, HuggingFaceRepoConfig):
if split is None and isinstance(path, HuggingFaceRepoID):
raise ValueError(
"A `split` must be specified when loading a dataset directly from HuggingFace."
)
Expand Down
11 changes: 2 additions & 9 deletions src/lm_buddy/integrations/huggingface/model_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from pydantic import field_validator

from lm_buddy.integrations.huggingface import HuggingFaceRepoConfig, convert_string_to_repo_config
from lm_buddy.integrations.wandb import WandbArtifactConfig
from lm_buddy.paths import LoadableAssetPath
from lm_buddy.types import BaseLMBuddyConfig, SerializableTorchDtype


Expand All @@ -11,10 +8,6 @@ class AutoModelConfig(BaseLMBuddyConfig):
The model to load can either be a HuggingFace repo or an artifact reference on W&B.
"""

load_from: HuggingFaceRepoConfig | WandbArtifactConfig
path: LoadableAssetPath
trust_remote_code: bool = False
torch_dtype: SerializableTorchDtype | None = None

_validate_load_from_string = field_validator("load_from", mode="before")(
convert_string_to_repo_config
)
11 changes: 2 additions & 9 deletions src/lm_buddy/integrations/huggingface/tokenizer_config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
from pydantic import field_validator

from lm_buddy.integrations.huggingface import HuggingFaceRepoConfig, convert_string_to_repo_config
from lm_buddy.integrations.wandb import WandbArtifactConfig
from lm_buddy.paths import LoadableAssetPath
from lm_buddy.types import BaseLMBuddyConfig


class AutoTokenizerConfig(BaseLMBuddyConfig):
"""Settings passed to a HuggingFace AutoTokenizer instantiation."""

load_from: HuggingFaceRepoConfig | WandbArtifactConfig
path: LoadableAssetPath
trust_remote_code: bool | None = None
use_fast: bool | None = None

_validate_load_from_string = field_validator("load_from", mode="before")(
convert_string_to_repo_config
)
4 changes: 2 additions & 2 deletions src/lm_buddy/integrations/vllm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from lm_buddy.integrations.huggingface import HuggingFaceAssetPath
from lm_buddy.integrations.huggingface import LoadableAssetPath
from lm_buddy.types import BaseLMBuddyConfig


Expand All @@ -16,7 +16,7 @@ class InferenceServerConfig(BaseLMBuddyConfig):
"""

base_url: str
engine: str | HuggingFaceAssetPath | None = None
engine: str | LoadableAssetPath | None = None


class VLLMCompletionsConfig(BaseLMBuddyConfig):
Expand Down
4 changes: 2 additions & 2 deletions src/lm_buddy/integrations/wandb/artifact_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def use_artifact(self, config: WandbArtifactConfig) -> wandb.Artifact:
"""
pass

def log_artifact(self, artifact: wandb.Artifact) -> None:
def log_artifact(self, artifact: wandb.Artifact) -> wandb.Artifact:
"""Log an artifact, declaring it as an output of the currently active W&B run."""
pass

Expand All @@ -40,5 +40,5 @@ def use_artifact(self, config: WandbArtifactConfig) -> wandb.Artifact:
api = wandb.Api()
return api.artifact(config.wandb_path())

def log_artifact(self, artifact: wandb.Artifact) -> None:
def log_artifact(self, artifact: wandb.Artifact) -> wandb.Artifact:
return wandb.log_artifact(artifact)
38 changes: 0 additions & 38 deletions src/lm_buddy/jobs/__init__.py

This file was deleted.

6 changes: 0 additions & 6 deletions src/lm_buddy/jobs/_entrypoints/__init__.py

This file was deleted.

114 changes: 0 additions & 114 deletions src/lm_buddy/jobs/_entrypoints/finetuning.py

This file was deleted.

Loading
Loading