temp commit
SLR722 committed Nov 26, 2024
1 parent 900b055 commit d7598c6
Showing 6 changed files with 491 additions and 3 deletions.
11 changes: 8 additions & 3 deletions llama_stack/apis/post_training/post_training.py
@@ -16,6 +16,7 @@
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
import torch


class OptimizerType(Enum):
@@ -30,18 +31,22 @@ class OptimizerConfig(BaseModel):
lr: float
lr_min: float
weight_decay: float
num_warmup_steps: int


@json_schema_type
class TrainingConfig(BaseModel):
dtype: torch.dtype
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
batch_size: int
shuffle: bool
n_iters: int
# n_iters: int

enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
memory_efficient_fsdp_wrap: Optional[bool]
fsdp_cpu_offload: Optional[bool]


@json_schema_type
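The new dtype: torch.dtype field puts a non-pydantic type on a BaseModel, which pydantic only accepts when arbitrary types are allowed. Below is a minimal standalone sketch of that requirement, assuming pydantic v2; the class name, defaults, and values are illustrative and not the repository's actual definitions.

# Illustrative sketch only: mirrors the TrainingConfig fields visible in the hunk above.
from typing import Optional

import torch
from pydantic import BaseModel, ConfigDict


class TrainingConfigSketch(BaseModel):
    # torch.dtype is not a pydantic-native type, so arbitrary types must be enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    dtype: torch.dtype
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    batch_size: int
    shuffle: bool
    enable_activation_checkpointing: bool
    memory_efficient_fsdp_wrap: Optional[bool] = None
    fsdp_cpu_offload: Optional[bool] = None


cfg = TrainingConfigSketch(
    dtype=torch.bfloat16,
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    batch_size=2,
    shuffle=True,
    enable_activation_checkpointing=False,
)
print(cfg.dtype)  # torch.bfloat16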
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional
from pydantic import BaseModel, Field


class MetaReferencePostTrainingConfig(BaseModel):
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
torch_seed: Optional[int] = None
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
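For reference, a hypothetical instantiation of the provider config above; its module path is not shown in this diff, so the sketch simply assumes the class is in scope.

# Hypothetical usage; assumes MetaReferencePostTrainingConfig (defined above) is importable.
config = MetaReferencePostTrainingConfig(
    model="Llama3.2-3B-Instruct",  # model descriptor from `llama model list`
    torch_seed=42,
    checkpoint_dir=None,  # default: look under ~/.llama/checkpoints/<model>
)
print(config.model)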
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform


class SFTDataset(Dataset):
def __init__(
self,
rows: List[Dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
) -> None:
self._rows = rows
self._message_transform = message_transform
self._model_transform = model_transform

def __len__(self):
return len(self._rows)

def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self._rows[index]
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])

tokenized_dict = self._model_transform(transformed_sample)

if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)

# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
tokenized_dict["labels"] = list(
np.where(
tokenized_dict["mask"],
CROSS_ENTROPY_IGNORE_IDX,
tokenized_dict["tokens"],
)
)
assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])

return tokenized_dict
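A toy usage sketch of SFTDataset follows. The two transforms are stand-ins written for illustration, not torchtune's real message and tokenizer transforms; they only need to satisfy the dict contract that _prepare_sample checks for ("tokens" and "mask" keys).

# Toy transforms for illustration only; real code would pass torchtune transforms.
rows = [{"dialog": "user: hi / assistant: hello"}]


class ToyMessageTransform:
    def __call__(self, sample):
        # A real message transform would return a "messages" list of torchtune
        # Message objects; this stand-in passes the row through unchanged.
        return dict(sample)


class ToyModelTransform:
    def __call__(self, sample):
        # Must return "tokens" and "mask"; mask=True marks positions whose
        # labels get replaced with CROSS_ENTROPY_IGNORE_IDX.
        return {"tokens": [1, 2, 3, 4], "mask": [True, True, False, False]}


ds = SFTDataset(rows, ToyMessageTransform(), ToyModelTransform())
item = ds[0]
print(item["labels"])  # the first two positions are ignored in the loss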
@@ -0,0 +1,33 @@
class MetaReferencePostTrainingImpl:
def __init__(self, config: MetaReferencePostTrainingConfig) -> None:
self.config = config

def supervised_fine_tune(
self,
job_uuid: str,
model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
logger_config: Dict[str, Any],
) -> PostTrainingJob:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = PostTrainingSFTRequest(
job_uuid=job_uuid,
model=model,
dataset_id=dataset_id,
validation_dataset_id=validation_dataset_id,
algorithm=algorithm,
algorithm_config=algorithm_config,
optimizer_config=optimizer_config,
training_config=training_config,
logger_config=logger_config,
)
if request.algorithm == FinetuningAlgorithm.lora:
recipe = LoraFinetuningRecipeSingleDevice(self.config, request)
recipe.train()
else:
raise NotImplementedError()
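The impl above wraps its arguments into an internal request object and then dispatches on the algorithm, with only the LoRA path wired up. The FinetuningAlgorithm enum, PostTrainingSFTRequest, and LoraFinetuningRecipeSingleDevice are not imported or defined in this diff, so the self-contained toy below uses stand-in types purely to show that dispatch shape.

from dataclasses import dataclass
from enum import Enum


class ToyFinetuningAlgorithm(Enum):
    # Stand-in for the real FinetuningAlgorithm enum.
    lora = "lora"
    full = "full"


@dataclass
class ToyLoraRecipe:
    # Stand-in for LoraFinetuningRecipeSingleDevice.
    job_uuid: str

    def train(self) -> str:
        return f"{self.job_uuid}: LoRA fine-tuning started"


def supervised_fine_tune_toy(job_uuid: str, algorithm: ToyFinetuningAlgorithm) -> str:
    # Same shape as the impl: only LoRA is supported, everything else raises.
    if algorithm == ToyFinetuningAlgorithm.lora:
        return ToyLoraRecipe(job_uuid).train()
    raise NotImplementedError(f"{algorithm} is not implemented yet")


print(supervised_fine_tune_toy("job-001", ToyFinetuningAlgorithm.lora))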