lora.py
from argparse import ArgumentParser
import gc
import os

# Silence fork-related warnings from the HF tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Import unsloth before trl so its trainer patches are applied.
from unsloth import FastLanguageModel, is_bfloat16_supported, PatchDPOTrainer

PatchDPOTrainer()

import torch
from torch import nn
from typing import Any, Dict, Union

from datasets import load_dataset
from trl import ORPOConfig, ORPOTrainer
parser = ArgumentParser()
# argparse's type=bool is a footgun (bool("False") is True), so expose
# --resume as a store_true flag instead.
parser.add_argument("--resume", action="store_true")
parser.add_argument("--batch_size", type=int, default=1)
args = parser.parse_args()
# Allow TF32 matmuls and cuDNN autotuning for faster training on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

max_seq_length = 8192
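# Load the instruct-tuned Gemma 2 9B in 4-bit; dtype=None lets Unsloth pick
# bfloat16 or float16 automatically based on hardware support.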
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-9b-it-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
tokenizer.padding_side = "right"
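# Attach LoRA adapters to every attention and MLP projection. With
# lora_alpha == r == 16, the effective adapter scaling (alpha / r) is 1.0.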
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
    max_seq_length=max_seq_length,
    use_rslora=False,
    loftq_config=None,
)
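# ORPO folds the preference objective into the SFT loss, so no separate
# reference model is needed; the 4-bit base plus LoRA adapters is all that
# trains. adamw_8bit keeps optimizer state in 8 bits to save GPU memory.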
cfg = ORPOConfig(
    num_train_epochs=2,
    learning_rate=5e-5,
    do_train=True,
    logging_steps=5,
    save_strategy="steps",
    save_steps=250,
    output_dir="model-result",
    save_total_limit=2,
    push_to_hub=False,
    per_device_train_batch_size=args.batch_size,
    optim="adamw_8bit",
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    remove_unused_columns=False,
    max_length=max_seq_length,
    max_prompt_length=max_seq_length,
)
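# neody/null-instruct-ja provides user/model/reject columns: the prompt, the
# preferred completion, and the rejected completion.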
dataset = load_dataset("neody/null-instruct-ja", split="train")
def func(example):
    # Batched map: rename user/model/reject to the prompt/chosen/rejected
    # columns that ORPOTrainer expects.
    prompts = []
    chosens = []
    rejecteds = []
    for i, user in enumerate(example["user"]):
        prompts.append(user)
        chosens.append(example["model"][i])
        rejecteds.append(example["reject"][i])
    return {
        "prompt": prompts,
        "chosen": chosens,
        "rejected": rejecteds,
    }
dataset = dataset.map(func, batched=True, remove_columns=list(dataset.features))
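# Subclass ORPOTrainer to empty the CUDA cache and run Python garbage
# collection after every step, trading a little throughput for lower peak
# GPU memory at this 8k sequence length.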
class Trainer(ORPOTrainer):
    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        loss_step = super().training_step(model, inputs)
        torch.cuda.empty_cache()
        gc.collect()
        return loss_step
trainer = Trainer(model=model, train_dataset=dataset, args=cfg, tokenizer=tokenizer)
# resume_from_checkpoint=True resolves to the latest checkpoint in output_dir.
trainer.train(resume_from_checkpoint=args.resume)
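# Example invocations (a sketch, assuming a single CUDA GPU and checkpoints
# saved under model-result/):
#   python lora.py --batch_size 2
#   python lora.py --batch_size 2 --resume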