Refactor LoRA argument names
armbues committed Dec 2, 2024
1 parent 9987307 commit 42e7e88
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions sillm/lora.py
@@ -17,11 +17,11 @@
 parser.add_argument("-m", "--save_merge", default=None, type=str, help="Save merged model weights to .safetensors file")
 parser.add_argument("--max_length", default=1024, type=int, help="Max token length per training dataset entry (default: 1024)")
 parser.add_argument("--template", type=str, default=None, help="Chat template (chatml, llama-2, alpaca, etc.)")
-parser.add_argument("--layers", default=0, type=int, help="Layers to use for LoRA (default: 0 for all layers)")
-parser.add_argument("--target_modules", default="query_value", type=str, help="Target modules to use for LoRA: query_value, all_linear")
-parser.add_argument("--rank", default=8, type=int, help="Rank to use for LoRA (default: 8)")
-parser.add_argument("--dropout", default=0.0, type=int, help="Dropout to use for LoRA (default: 0.0)")
-parser.add_argument("--scale", default=10.0, type=float, help="Scale to use for LoRA (default: 10.0)")
+parser.add_argument("--lora_layers", default=0, type=int, help="Layers to use for LoRA (default: 0 for all layers)")
+parser.add_argument("--lora_modules", default="query_value", type=str, help="Target modules to use for LoRA: query_value, all_linear")
+parser.add_argument("--lora_rank", default=8, type=int, help="Rank to use for LoRA (default: 8)")
+parser.add_argument("--lora_dropout", default=0.0, type=int, help="Dropout to use for LoRA (default: 0.0)")
+parser.add_argument("--lora_scale", default=10.0, type=float, help="Scale to use for LoRA (default: 10.0)")
 parser.add_argument("--optimizer", type=str, default="adam", help="Optimizer type (default: adam)")
 parser.add_argument("--grad_checkpoint", default=False, action="store_true", help="Use gradient checkpointing")
 parser.add_argument("--grad_accu_steps", type=int, default=1, help="Gradient accumulation steps (default: 1)")
@@ -81,11 +81,11 @@
 
 # Initialize LoRA layers
 lora_config = {
-    "num_layers": args.layers,
-    "target_modules": args.target_modules,
-    "rank": args.rank,
-    "dropout": args.dropout,
-    "scale": args.scale
+    "num_layers": args.lora_layers,
+    "target_modules": args.lora_modules,
+    "rank": args.lora_rank,
+    "dropout": args.lora_dropout,
+    "scale": args.lora_scale
 }
 model.init_lora(**lora_config)
 
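For reference, a minimal sketch of how the renamed flags feed into init_lora after this change, assuming the argparse setup shown in the diff above. The example argv values and the float type for --lora_dropout are illustrative choices, not part of the commit:

# Sketch only: reproduces the renamed CLI flags and the lora_config mapping.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lora_layers", default=0, type=int)
parser.add_argument("--lora_modules", default="query_value", type=str)
parser.add_argument("--lora_rank", default=8, type=int)
parser.add_argument("--lora_dropout", default=0.0, type=float)  # the script declares type=int; float matches the 0.0 default
parser.add_argument("--lora_scale", default=10.0, type=float)

# Illustrative command-line values for the new flag names.
args = parser.parse_args(["--lora_rank", "16", "--lora_modules", "all_linear"])

# Only the CLI flags were renamed; the keyword names expected by init_lora stay the same.
lora_config = {
    "num_layers": args.lora_layers,
    "target_modules": args.lora_modules,
    "rank": args.lora_rank,
    "dropout": args.lora_dropout,
    "scale": args.lora_scale,
}
print(lora_config)
# model.init_lora(**lora_config)  # as in sillm/lora.py; model construction omitted here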
