From 3c0cb59942ed8a770e3c1d7b818cdee57db4fc6f Mon Sep 17 00:00:00 2001
From: ftgreat
Date: Mon, 9 Oct 2023 11:35:57 +0800
Subject: [PATCH 1/3] enabled auto torchdtype detecting

Signed-off-by: ftgreat
---
 flagai/auto_model/auto_loader.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py
index 68521601..3f2a0945 100755
--- a/flagai/auto_model/auto_loader.py
+++ b/flagai/auto_model/auto_loader.py
@@ -210,7 +210,13 @@ def __init__(self,
         if task_name == "aquila2":
             from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM
             download_path = os.path.join(model_dir, model_name)
-
+
+            if not torch_dtype:
+                if model_name.lower() == "aquilachat2-34b":
+                    torch_dtype = torch.bfloat16
+                else:
+                    torch_dtype = torch.float16
+
             if not os.path.exists(download_path):
                 # Try to download from ModelHub
                 try:

From 336b31f23f952f4b5e3464f4cea35619529d09f3 Mon Sep 17 00:00:00 2001
From: 严照东
Date: Tue, 10 Oct 2023 08:30:44 +0000
Subject: [PATCH 2/3] enabled aquila2 quantize
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 严照东
---
 flagai/auto_model/auto_loader.py | 19 ++++++++++---------
 setup.py                         |  8 ++++----
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py
index 3f2a0945..ea9add65 100755
--- a/flagai/auto_model/auto_loader.py
+++ b/flagai/auto_model/auto_loader.py
@@ -261,9 +261,10 @@ def __init__(self,
                 for file_to_load in model_files:
                     if "pytorch_model-0" in file_to_load:
                         _get_checkpoint_path(download_path, file_to_load,
-                                             model_id)
-
-            if qlora_dir:
+                                             model_id)
+            if 'quantization_config' in kwargs:
+                quantization_config = kwargs['quantization_config']
+            elif qlora_dir:
                 from transformers import BitsAndBytesConfig
                 quantization_config=BitsAndBytesConfig(
                     load_in_4bit=True,
@@ -271,14 +272,14 @@ def __init__(self,
                     bnb_4bit_quant_type="nf4",
                     bnb_4bit_compute_dtype=torch_dtype,
                 )
+            else:
+                quantization_config = None
             if inference_mode:
-                if qlora_dir:
-                    model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,
-                                                              quantization_config=quantization_config)
-                else:
-                    model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,)
+
+                model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,
+                                                          quantization_config=quantization_config)
                 model.eval()
-                if not qlora_dir:
+                if not quantization_config:
                     model.to(device)
                 if lora_dir:
                     from flagai.model.tools.peft import PeftModel
diff --git a/setup.py b/setup.py
index 3a929b7b..da686a57 100755
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setup(
     name="flagai",
-    version="v1.8.0",
+    version="v1.8.1",
     description="FlagAI aims to help researchers and developers to freely train and test large-scale models for NLP/CV/VL tasks.",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
@@ -19,19 +19,19 @@
     install_requires=[
         'nltk>=3.6.7',
         'sentencepiece>=0.1.96',
-        'boto3==1.17.32',
+        'boto3>=1.17.32',
         'pandas>=1.3.5',
         'jieba>=0.42.1',
         'scikit-learn>=1.0.2',
         'tensorboard>=2.9.0',
         'transformers>=4.31.0',
         'datasets>=2.0.0',
-        'setuptools==66.0.0',
+        'setuptools>=66.0.0',
         'protobuf==3.19.6',
         'ftfy',
         'Pillow>=9.3.0',
         'einops>=0.3.0',
-        'diffusers==0.7.2',
+        'diffusers>=0.7.2',
         'pytorch-lightning>=1.6.5',
         'taming-transformers-rom1504==0.0.6',
         'rouge-score',

From dade181dd2722967570b2dd344d85963d7f7be5e Mon Sep 17 00:00:00 2001
From: 严照东
Date: Tue, 10 Oct 2023 08:43:19 +0000
Subject: [PATCH 3/3] updated loader
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 严照东
---
 flagai/auto_model/auto_loader.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py
index ea9add65..726dc340 100755
--- a/flagai/auto_model/auto_loader.py
+++ b/flagai/auto_model/auto_loader.py
@@ -211,12 +211,9 @@ def __init__(self,
             from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM
             download_path = os.path.join(model_dir, model_name)
 
-            if not torch_dtype:
-                if model_name.lower() == "aquilachat2-34b":
-                    torch_dtype = torch.bfloat16
-                else:
-                    torch_dtype = torch.float16
-
+            if not torch_dtype and '34b' in model_name.lower():
+                torch_dtype = torch.bfloat16
+
             if not os.path.exists(download_path):
                 # Try to download from ModelHub
                 try:
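
Note that after patch 3/3 the loader no longer falls back to float16: only model
names containing "34b" get a bfloat16 default, and any caller-supplied
torch_dtype is left untouched. A minimal sketch of the resulting behavior
(pick_default_dtype is an illustrative helper name, not part of the loader):

    import torch

    def pick_default_dtype(model_name: str, torch_dtype=None):
        # Mirrors the final auto_loader.py logic: default to bfloat16 for
        # any 34B variant, otherwise keep the caller-supplied dtype.
        if not torch_dtype and '34b' in model_name.lower():
            torch_dtype = torch.bfloat16
        return torch_dtype

    assert pick_default_dtype("AquilaChat2-34B") is torch.bfloat16
    assert pick_default_dtype("AquilaChat2-7B") is None  # float16 fallback removed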
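
With patch 2/3, a caller can pass any BitsAndBytesConfig to the loader through
kwargs instead of relying on the QLoRA default, and the model is only moved to
the device when no quantization is active. A usage sketch, assuming
bitsandbytes is installed and using a placeholder checkpoint directory (the
keyword names follow the parameters visible in the diff, retrieved via
AutoLoader's standard get_model()/get_tokenizer() accessors):

    import torch
    from transformers import BitsAndBytesConfig
    from flagai.auto_model.auto_loader import AutoLoader

    # Explicit 4-bit config; takes precedence over the qlora_dir default.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    loader = AutoLoader(
        "aquila2",
        model_dir="./checkpoints",       # placeholder path
        model_name="AquilaChat2-34B",
        inference_mode=True,
        quantization_config=bnb_config,  # picked up from **kwargs by the loader
    )
    model = loader.get_model()
    tokenizer = loader.get_tokenizer()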