diff --git a/sillm/chat.py b/sillm/chat.py
index 3ae9329..6c22e8e 100644
--- a/sillm/chat.py
+++ b/sillm/chat.py
@@ -26,7 +26,9 @@
     parser.add_argument("--system_prompt", type=str, default=None, help="System prompt for chat template")
     parser.add_argument("--ascii", default=False, action="store_true", help="Force output tokens to ASCII printable characters")
     parser.add_argument("-q2", default=False, action="store_true", help="Quantize the model to 2 bits")
+    parser.add_argument("-q3", default=False, action="store_true", help="Quantize the model to 3 bits")
     parser.add_argument("-q4", default=False, action="store_true", help="Quantize the model to 4 bits")
+    parser.add_argument("-q6", default=False, action="store_true", help="Quantize the model to 6 bits")
     parser.add_argument("-q8", default=False, action="store_true", help="Quantize the model to 8 bits")
     parser.add_argument("-v", "--verbose", default=1, action="count", help="Increase output verbosity")
     args = parser.parse_args()
@@ -70,8 +72,12 @@
     # Quantize model
     if args.q2 is True:
         model.quantize(bits=2)
+    elif args.q3 is True:
+        model.quantize(bits=3)
     elif args.q4 is True:
         model.quantize(bits=4)
+    elif args.q6 is True:
+        model.quantize(bits=6)
     elif args.q8 is True:
         model.quantize(bits=8)
 
diff --git a/sillm/convert.py b/sillm/convert.py
index ca72c68..4cac815 100644
--- a/sillm/convert.py
+++ b/sillm/convert.py
@@ -10,7 +10,10 @@
     parser.add_argument("output", type=str, help="The output model directory or file")
     parser.add_argument("-a", "--input_adapters", default=None, type=str, help="Load LoRA adapter weights from .safetensors file")
     parser.add_argument("-r", "--remap", default=False, action="store_true", help="Remap weights keys to native SiLLM format")
+    parser.add_argument("-q2", default=False, action="store_true", help="Quantize the model to 2 bits")
+    parser.add_argument("-q3", default=False, action="store_true", help="Quantize the model to 3 bits")
     parser.add_argument("-q4", default=False, action="store_true", help="Quantize the model to 4 bits")
+    parser.add_argument("-q6", default=False, action="store_true", help="Quantize the model to 6 bits")
     parser.add_argument("-q8", default=False, action="store_true", help="Quantize the model to 8 bits")
     parser.add_argument("-dtype", type=str, default=None, help="Cast model weights to specified data type")
     parser.add_argument("-m", "--max_shard_size", type=int, default=5368709120, help="Maximum shard size in bytes")
@@ -42,8 +45,14 @@
     model.merge_and_unload_lora()
 
     # Quantize model or cast model weights
-    if args.q4 is True:
+    if args.q2 is True:
+        model.quantize(bits=2)
+    elif args.q3 is True:
+        model.quantize(bits=3)
+    elif args.q4 is True:
         model.quantize(bits=4)
+    elif args.q6 is True:
+        model.quantize(bits=6)
     elif args.q8 is True:
         model.quantize(bits=8)
     elif args.dtype is not None:
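
A note on the pattern: with five -qN flags, the if/elif ladder now repeats the same two lines in both scripts. Below is a minimal standalone sketch of a table-driven alternative; it is not part of this PR, and DummyModel is a hypothetical stand-in for the loaded SiLLM model, while the -qN flag names and the quantize(bits=N) call match the diff above.

import argparse

# Hypothetical sketch, not part of the PR: replaces the if/elif ladder
# with a loop over the supported bit widths. DummyModel stands in for
# the model object loaded by chat.py/convert.py.
QUANT_BITS = (2, 3, 4, 6, 8)

class DummyModel:
    def quantize(self, bits):
        print(f"quantizing to {bits} bits")

parser = argparse.ArgumentParser()
for bits in QUANT_BITS:
    parser.add_argument(f"-q{bits}", default=False, action="store_true",
                        help=f"Quantize the model to {bits} bits")
args = parser.parse_args()

model = DummyModel()
for bits in QUANT_BITS:
    if getattr(args, f"q{bits}"):   # argparse stores -q3 as args.q3
        model.quantize(bits=bits)
        break                       # first matching flag wins, like the elif chain

As in the patched scripts, passing several flags quantizes only to the smallest requested width, since the loop stops at the first flag that is set.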