diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py index 726dc340..77831231 100755 --- a/flagai/auto_model/auto_loader.py +++ b/flagai/auto_model/auto_loader.py @@ -173,6 +173,7 @@ def __init__(self, qlora_dir=None, inference_mode=True, model_max_length=None, + all_devices=False, **kwargs): """ Args: @@ -277,7 +278,12 @@ def __init__(self, quantization_config=quantization_config) model.eval() if not quantization_config: - model.to(device) + if all_devices is True: + from accelerate import load_checkpoint_and_dispatch + model = load_checkpoint_and_dispatch( + model, download_path, device_map="balanced", no_split_module_classes=["AquilaDecoderLayer"]) + else: + model.to(device) if lora_dir: from flagai.model.tools.peft import PeftModel model = PeftModel.from_pretrained(model, lora_dir)