Skip to content

Commit

Permalink
Add 3/6 bit quantization to args
Browse files: browse the repository at this point in the history
  • Loading branch information
armbues committed Nov 28, 2024
1 parent 168fd9d commit 9987307
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
6 changes: 6 additions & 0 deletions sillm/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
# Chat-template and output options.
parser.add_argument(
    "--system_prompt",
    type=str,
    default=None,
    help="System prompt for chat template",
)
parser.add_argument(
    "--ascii",
    action="store_true",
    default=False,
    help="Force output tokens to ASCII printable characters",
)
# One boolean switch per supported quantization width: -q2, -q3, -q4, -q6, -q8.
for quant_bits in (2, 3, 4, 6, 8):
    parser.add_argument(
        f"-q{quant_bits}",
        action="store_true",
        default=False,
        help=f"Quantize the model to {quant_bits} bits",
    )
# Verbosity starts at 1 and is incremented once per -v occurrence.
parser.add_argument(
    "-v",
    "--verbose",
    action="count",
    default=1,
    help="Increase output verbosity",
)
args = parser.parse_args()
Expand Down Expand Up @@ -70,8 +72,12 @@
# Quantize model
# Flags are checked in ascending bit width and only the first one set is
# applied; when none is set the model is left unquantized.
requested_bits = next(
    (
        width
        for flag, width in (
            (args.q2, 2),
            (args.q3, 3),
            (args.q4, 4),
            (args.q6, 6),
            (args.q8, 8),
        )
        if flag is True
    ),
    None,
)
if requested_bits is not None:
    model.quantize(bits=requested_bits)

Expand Down
11 changes: 10 additions & 1 deletion sillm/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# Output location and weight-loading options.
parser.add_argument(
    "output",
    type=str,
    help="The output model directory or file",
)
parser.add_argument(
    "-a",
    "--input_adapters",
    type=str,
    default=None,
    help="Load LoRA adapter weights from .safetensors file",
)
parser.add_argument(
    "-r",
    "--remap",
    action="store_true",
    default=False,
    help="Remap weights keys to native SiLLM format",
)
# One boolean switch per supported quantization width: -q2, -q3, -q4, -q6, -q8.
for quant_bits in (2, 3, 4, 6, 8):
    parser.add_argument(
        f"-q{quant_bits}",
        action="store_true",
        default=False,
        help=f"Quantize the model to {quant_bits} bits",
    )
parser.add_argument(
    "-dtype",
    type=str,
    default=None,
    help="Cast model weights to specified data type",
)
parser.add_argument(
    "-m",
    "--max_shard_size",
    type=int,
    default=5_368_709_120,  # 5 * 1024**3 bytes
    help="Maximum shard size in bytes",
)
Expand Down Expand Up @@ -42,8 +45,14 @@
model.merge_and_unload_lora()

# Quantize model or cast model weights
# Flags are checked in ascending bit width and only the first one set is
# applied. The stale pre-change `if args.q4 is True:` line that preceded this
# chain has been dropped: it had no body and duplicated the q4 branch below.
if args.q2 is True:
    model.quantize(bits=2)
elif args.q3 is True:
    model.quantize(bits=3)
elif args.q4 is True:
    model.quantize(bits=4)
elif args.q6 is True:
    model.quantize(bits=6)
elif args.q8 is True:
    model.quantize(bits=8)
elif args.dtype is not None:
Expand Down

0 comments on commit 9987307

Please sign in to comment.