From 95763062350caeaf116163cc68fb0ecccd7dffc6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=A5=E7=85=A7=E4=B8=9C?=
Date: Thu, 12 Oct 2023 04:52:34 +0000
Subject: [PATCH 1/2] enabled auto convo selection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 严照东
---
 flagai/model/aquila2/modeling_aquila.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/flagai/model/aquila2/modeling_aquila.py b/flagai/model/aquila2/modeling_aquila.py
index 7963fa2d..aee64266 100755
--- a/flagai/model/aquila2/modeling_aquila.py
+++ b/flagai/model/aquila2/modeling_aquila.py
@@ -921,19 +921,24 @@ def predict(self, text, tokenizer=None,
                 seed=1234, topk=100,
                 temperature=0.9,
                 sft=True, convo_template = "aquila-chat",
-                device = "cuda"):
+                device = "cuda",
+                model_name="",
+                **kwargs):
 
         vocab = tokenizer.get_vocab()
-        #device = device
 
-        id2word = {v:k for k, v in vocab.items()}
+        id2word = {v:k for k, v in vocab.items()}
+        template_map = {"AquilaChat2-7B": "aquila-v1",
+                        "AquilaChat2-34B": "aquila-legacy",
+                        "AquilaChat2-7B-16K": "aquila",
+                        "AquilaChat2-34B-16K": "aquila-v1"}
 
         set_random_seed(seed)
         if temperature == 0:
             topk = 1
             temperature = 1.0
         if sft:
-            tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
+            tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=template_map[model_name])
             tokens = torch.tensor(tokens)[None,].to(device)
         else :
             tokens = tokenizer.encode_plus(text)["input_ids"]

From b537a4d285e8044f8facbc164b9a30da96f748ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=A5=E7=85=A7=E4=B8=9C?=
Date: Thu, 12 Oct 2023 05:12:55 +0000
Subject: [PATCH 2/2] changed prediction of aquila
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 严照东
---
 flagai/model/aquila2/conversation.py    | 17 +++++++++++++++++
 flagai/model/aquila2/modeling_aquila.py | 11 ++++++++---
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/flagai/model/aquila2/conversation.py b/flagai/model/aquila2/conversation.py
index c46d1d9a..74057329 100755
--- a/flagai/model/aquila2/conversation.py
+++ b/flagai/model/aquila2/conversation.py
@@ -162,6 +162,8 @@ def get_conversation_template(model_path: str) -> Conversation:
         return get_conv_template("aquila-v1")
     elif "aquila-chat" in model_path:
         return get_conv_template("aquila-chat")
+    elif "aquila-legacy" in model_path:
+        return get_conv_template("aquila-legacy")
     else:
         return get_conv_template("aquila")
 
@@ -182,6 +184,21 @@ def get_conversation_template(model_path: str) -> Conversation:
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="aquila-legacy",
+        system_message="A chat between a curious human and an artificial intelligence assistant. "
+                       "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+        roles=("### Human: ", "### Assistant: ", "System"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="\n",
+        sep2="</s>",
+        stop_str=["</s>", "[UNK]"],
+    )
+)
+
 register_conv_template(
     Conversation(
         name="aquila",
diff --git a/flagai/model/aquila2/modeling_aquila.py b/flagai/model/aquila2/modeling_aquila.py
index aee64266..89764f71 100755
--- a/flagai/model/aquila2/modeling_aquila.py
+++ b/flagai/model/aquila2/modeling_aquila.py
@@ -920,25 +920,30 @@ def predict(self, text, tokenizer=None,
                 max_gen_len=200, top_p=0.95,
                 seed=1234, topk=100,
                 temperature=0.9,
-                sft=True, convo_template = "aquila-chat",
+                sft=True, convo_template = "",
                 device = "cuda",
-                model_name="",
+                model_name="AquilaChat2-7B",
                 **kwargs):
 
         vocab = tokenizer.get_vocab()
 
         id2word = {v:k for k, v in vocab.items()}
+
+
+
         template_map = {"AquilaChat2-7B": "aquila-v1",
                         "AquilaChat2-34B": "aquila-legacy",
                         "AquilaChat2-7B-16K": "aquila",
                         "AquilaChat2-34B-16K": "aquila-v1"}
+        if not convo_template:
+            convo_template=template_map.get(model_name, "aquila-chat")
 
         set_random_seed(seed)
         if temperature == 0:
             topk = 1
             temperature = 1.0
         if sft:
-            tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=template_map[model_name])
+            tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
             tokens = torch.tensor(tokens)[None,].to(device)
         else :
             tokens = tokenizer.encode_plus(text)["input_ids"]