Support finetuning with LoRA #431

Merged (9 commits) on Sep 6, 2023
45 changes: 42 additions & 3 deletions optimum/graphcore/modeling_utils.py
@@ -20,6 +20,7 @@
import poptorch
import torch
import torch.nn.functional as F
from peft import PeftModel, PeftType, get_peft_model
from torch import nn
from transformers import PreTrainedModel

@@ -53,9 +54,11 @@ def wrapper(cls):


def to_pipelined(model: nn.Module, ipu_config: IPUConfig, force: bool = False):
model_cls = model.__class__
model_cls = model.get_base_model().__class__ if isinstance(model, PeftModel) else model.__class__
pipelined_cls = _PRETRAINED_TO_PIPELINED_REGISTRY.get(model_cls, None)
if pipelined_cls is not None:
if pipelined_cls is not None and isinstance(model, PeftModel):
return pipelined_cls.from_peft(model, ipu_config)
elif pipelined_cls is not None:
return pipelined_cls.from_transformers(model, ipu_config)
# If the user defined his/her own model and already subclassed from PipelineMixin. I.e., the model is already pipelined.
elif isinstance(model, PipelineMixin):
@@ -92,9 +95,9 @@ def from_transformers(cls, model: PreTrainedModel, ipu_config: IPUConfig):
config = copy.deepcopy(model.config)
generation_config = copy.deepcopy(model.generation_config)
pipelined_model = cls(config)
pipelined_model.generation_config = generation_config
pipelined_model.load_state_dict(model.state_dict())
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
pipelined_model.generation_config = generation_config
pipelined_model.training = model.training
return pipelined_model

@@ -120,6 +123,42 @@ def from_pretrained_transformers(cls, model_name_or_path: str, ipu_config: IPUCo
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
return pipelined_model

@classmethod
def from_peft(cls, model: PeftModel, ipu_config: IPUConfig):
"""
Creates a pipelined version of model from a [`~peft.PeftModel`] instance.

Currently, only `peft.PeftType.LORA` is supported.

Args:
model ([`~peft.PeftModel`]):
The model to convert to a pipelined model.
ipu_config ([`IPUConfig`]):
The `IPUConfig` instance of the pipelined model.

Returns:
An instance of `peft.PeftModel` wrapping a pipelined version of the base model.
"""
# Technically speaking, instead of returning an instance of a `PipelineMixin`, such as Pipelined<Model>For<Task>,
# we return an instance of a `peft.PeftModel` which wraps such a pipelined model and delegates attribute access to it.
if len(model.peft_config) > 1 or model.active_adapter != "default":
raise ValueError("Currently only `PeftModel` instances with the `'default'` adapter are supported.")
if model.peft_type != PeftType.LORA:
raise ValueError(f"Currently only LoRA is supported, received {model.peft_type}.")

pretrained = model.get_base_model()
config = copy.deepcopy(pretrained.config)
generation_config = copy.deepcopy(pretrained.generation_config)
peft_config = model.active_peft_config

pipelined_model = cls(config)
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
pipelined_model.generation_config = generation_config
peft_pipelined_model = get_peft_model(pipelined_model, peft_config)
peft_pipelined_model.load_state_dict(model.state_dict())
peft_pipelined_model.training = model.training
return peft_pipelined_model

@classmethod
def from_model(cls, model: nn.Module):
clone = copy.deepcopy(model)
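For context, a hedged sketch of how the new `from_peft` path might be exercised end to end. The checkpoint name, the `Graphcore/bert-base-ipu` IPU config, and the LoRA hyperparameters below are illustrative assumptions, not part of this diff; only `to_pipelined`, `IPUConfig`, and the peft calls come from the code above.

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

from optimum.graphcore import IPUConfig
from optimum.graphcore.modeling_utils import to_pipelined

# Wrap a supported transformers model with a single "default" LoRA adapter.
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
lora_config = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16, lora_dropout=0.0)
peft_model = get_peft_model(base_model, lora_config)

# to_pipelined() now detects the PeftModel, looks up the pipelined class registered for
# the *base* model's class, and rebuilds the PEFT wrapper around it via from_peft().
ipu_config = IPUConfig.from_pretrained("Graphcore/bert-base-ipu")  # assumed Hub IPU config
pipelined_peft_model = to_pipelined(peft_model, ipu_config)
```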
7 changes: 5 additions & 2 deletions optimum/graphcore/models/whisper/modeling_whisper.py
@@ -136,7 +136,7 @@ def forward(
if layer_head_mask is not None:
raise ValueError("layer_head_mask is not supported yet with serialized attention.")

if self.dropout or self.training:
if self.dropout and self.training:
raise ValueError("dropout is not supported yet with serialized attention.")

if attention_mask is not None:
@@ -594,10 +594,13 @@ def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=Fal
)
logger.info(f"Decoder Embedding --> IPU {decoder_embedding_ipu}")

prev_ipu = decoder_layer_ipu[0]
for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
if ipu != prev_ipu:
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
prev_ipu = ipu
logger.info(f"Decoder {index:<2} --> IPU {ipu}")

self.model.decoder.layer_norm = poptorch.BeginBlock(
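A minimal, poptorch-free sketch of the boundary logic added above: a new pipeline block should begin only where the layer-to-IPU assignment changes, rather than once per decoder layer. The layer/IPU mapping is made up for illustration.

```python
# Hypothetical mapping of six decoder layers to IPUs.
decoder_layer_ipu = [0, 0, 0, 1, 1, 2]

prev_ipu = decoder_layer_ipu[0]
stage_starts = []
for index, ipu in enumerate(decoder_layer_ipu):
    if ipu != prev_ipu:
        # This is where poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu) would be applied.
        stage_starts.append((index, ipu))
        prev_ipu = ipu

print(stage_starts)  # [(3, 1), (5, 2)] -- layers 0-2 stay in the block opened before the loop
```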
6 changes: 6 additions & 0 deletions optimum/graphcore/pipelines/__init__.py
@@ -18,6 +18,7 @@
import poptorch
import torch
import transformers.pipelines
from peft import PeftModel
from transformers import (
AudioClassificationPipeline,
AutoFeatureExtractor,
@@ -375,6 +376,11 @@ def pipeline(
break
except ValueError:
continue
elif isinstance(model, PeftModel):
raise TypeError(
"Instead of providing `model` as an instance of `PeftModel`, please call `merge_and_unload()` if LoRA "
"or equivalent to obtain the original `PreTrainedModel` back with adapter weights merged in."
)
elif isinstance(model, PreTrainedModel):
if tokenizer is None and load_tokenizer:
raise ValueError("If you pass a model as a PreTrainedModel, you must pass a tokenizer as well")
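Since the IPU `pipeline()` factory now rejects `PeftModel` instances outright, the intended workflow is presumably to merge the adapter first. A hedged sketch, where the Whisper checkpoint and the adapter directory are placeholders and any IPU-specific pipeline arguments are left at their defaults:

```python
from peft import PeftModel
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

from optimum.graphcore import pipeline

base = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
peft_model = PeftModel.from_pretrained(base, "./my-lora-adapter")  # hypothetical adapter dir

# Fold the LoRA weights back into the base model so a plain PreTrainedModel is passed in.
merged_model = peft_model.merge_and_unload()

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
pipe = pipeline(
    "automatic-speech-recognition",
    model=merged_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
```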
83 changes: 66 additions & 17 deletions optimum/graphcore/trainer.py
@@ -35,6 +35,7 @@
import torch
from huggingface_hub import Repository
from packaging import version
from peft import PeftModel
from poptorch import DataLoaderMode, PoplarExecutor
from poptorch.optim import LAMB, AdamW
from torch import nn, optim
@@ -125,6 +126,9 @@
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# TODO: Import from transformers.utils when updating transformers version.
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"


@dataclass
class IPUTrainerState(TrainerState):
@@ -841,20 +845,24 @@ def create_optimizer(self):
bias_parameters = {n for n, _ in self.model.named_parameters() if "bias" in n}
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
"params": [
p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
# Disable LAMB updates for bias parameters
"params": [p for n, p in self.model.named_parameters() if n in bias_parameters],
"params": [
p for n, p in self.model.named_parameters() if (n in bias_parameters and p.requires_grad)
],
"weight_decay": 0.0,
"max_weight_norm": 0.0,
},
{
"params": [
p
for n, p in self.model.named_parameters()
if n not in decay_parameters and n not in bias_parameters
if n not in decay_parameters and n not in bias_parameters and p.requires_grad
],
"weight_decay": 0.0,
},
@@ -868,11 +876,17 @@
else:
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
"params": [
p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
"params": [
p
for n, p in self.model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
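A simplified, self-contained sketch of why the `p.requires_grad` filter matters for LoRA: `get_peft_model` freezes the base weights, so the filter leaves only the adapter (and any `modules_to_save`) parameters in the optimizer groups. The decay rule here is deliberately simplified relative to the trainer's `get_parameter_names` logic, and the checkpoint name is only an example.

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

model = get_peft_model(
    AutoModelForSequenceClassification.from_pretrained("bert-base-uncased"),
    LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16),
)

# Simplified decay rule: everything except biases gets weight decay.
decay_parameters = {n for n, _ in model.named_parameters() if "bias" not in n}
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters and p.requires_grad],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters and p.requires_grad],
        "weight_decay": 0.0,
    },
]

# Only the small set of trainable LoRA/classifier tensors survives the requires_grad filter.
print(sum(p.numel() for group in optimizer_grouped_parameters for p in group["params"]))
```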
@@ -1326,15 +1340,25 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
if model is None:
model = self.model

if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)

if not any(
os.path.isfile(f)
for f in [
weights_file,
weights_index_file,
adapter_weights_file,
]
):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

logger.info(f"Loading model from {resume_from_checkpoint}.")

if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
if os.path.isfile(config_file):
config = PretrainedConfig.from_json_file(config_file)
checkpoint_version = config.transformers_version
if checkpoint_version is not None and checkpoint_version != __version__:
logger.warning(
@@ -1343,23 +1367,46 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
"yield to errors or unwanted behavior."
)

if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
if os.path.isfile(weights_file):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
state_dict = torch.load(weights_file, map_location="cpu")
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
# release memory
del state_dict
self._issue_warnings_after_load(load_result)

# Load adapters following PR # 24096 (> 4.29.2)
elif isinstance(model, PeftModel):
# If training a model using PEFT & LoRA, assume that adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(resume_from_checkpoint):
model.load_adapter(resume_from_checkpoint, model.active_adapter)
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
"Check some examples here: https://github.com/huggingface/peft/issues/96"
)
else:
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")

def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = self.model
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
if os.path.exists(best_model_path):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
self._load_state_dict_in_model(state_dict)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
if os.path.exists(best_model_path) or os.path.exists(best_adapter_model_path):
if isinstance(model, PeftModel):
# If training a model using PEFT & LoRA, assume that adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(best_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
self._load_state_dict_in_model(state_dict)
else:
logger.warning(
f"Could not locate the best model at {best_model_path}. If you are running a distributed training "
@@ -1677,8 +1724,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):

# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `transformers.PreTrainedModel`, only saving its state dict.")
if not isinstance(self.model, (PreTrainedModel, PeftModel)):
logger.info(
"Trainer.model is not a `transformers.PreTrainedModel` or `peft.PeftModel`, only saving its state dict."
)
if state_dict is None:
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
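Putting the trainer changes together, a hedged end-to-end sketch of LoRA finetuning with `IPUTrainer`. The toy dataset, output directory, and the `Graphcore/bert-base-ipu` config are assumptions, and the training-argument names are assumed to follow the standard `TrainingArguments` API; checkpoints written this way should contain `adapter_model.bin`, which the new resume and best-model paths above can reload via `load_adapter`.

```python
import torch
from torch.utils.data import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments


class ToyDataset(Dataset):
    """Tiny stand-in dataset so the sketch is self-contained."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {
            "input_ids": torch.ones(32, dtype=torch.long),
            "attention_mask": torch.ones(32, dtype=torch.long),
            "labels": torch.tensor(0),
        }


peft_model = get_peft_model(
    AutoModelForSequenceClassification.from_pretrained("bert-base-uncased"),
    LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16),
)

trainer = IPUTrainer(
    model=peft_model,
    ipu_config=IPUConfig.from_pretrained("Graphcore/bert-base-ipu"),
    args=IPUTrainingArguments(output_dir="lora-out", num_train_epochs=1),
    train_dataset=ToyDataset(),
)
trainer.train()       # requires IPU hardware; shown for completeness
trainer.save_model()  # for a PeftModel, save_pretrained() writes the adapter weights (adapter_model.bin)
```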
1 change: 1 addition & 0 deletions setup.py
@@ -32,6 +32,7 @@
"optimum==1.6.1",
"diffusers[torch]==0.12.1",
"cppimport==22.8.2",
"peft==0.3.0",
"datasets",
"tokenizers",
"typeguard",