From 994ee6387642abcfe6f62977097cc8ad91a6e77e Mon Sep 17 00:00:00 2001 From: ldwang Date: Thu, 2 Nov 2023 09:34:00 +0800 Subject: [PATCH] Add aquila-v2 template. Signed-off-by: ldwang --- flagai/model/aquila2/conversation.py | 27 +++++++++++++++++++++++++++ flagai/model/aquila2_hf/predict.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/flagai/model/aquila2/conversation.py b/flagai/model/aquila2/conversation.py index 496c1f5a..aba87932 100755 --- a/flagai/model/aquila2/conversation.py +++ b/flagai/model/aquila2/conversation.py @@ -160,6 +160,8 @@ def get_conversation_template(model_path: str) -> Conversation: """Get the default conversation template.""" if "aquila-v1" in model_path: return get_conv_template("aquila-v1") + elif "aquila-v2" in model_path: + return get_conv_template("aquila-v2") elif "aquila-chat" in model_path: return get_conv_template("aquila-chat") elif "aquila-legacy" in model_path: @@ -227,6 +229,21 @@ def get_conversation_template(model_path: str) -> Conversation: ) ) +register_conv_template( + Conversation( + name="aquila-v2", + system_message="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("<|startofpiece|>", "<|endofpiece|>", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="", + sep2="", + stop_str=["", "<|endoftext|>"], + ) +) + if __name__ == "__main__": print("aquila template:") @@ -269,3 +286,13 @@ def get_conversation_template(model_path: str) -> Conversation: print("\n") + print("aquila-v2 template:") + conv = get_conv_template("aquila-v2") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("\n") + diff --git a/flagai/model/aquila2_hf/predict.py b/flagai/model/aquila2_hf/predict.py index d711e4d9..a54c45dd 100644 --- a/flagai/model/aquila2_hf/predict.py +++ b/flagai/model/aquila2_hf/predict.py @@ -185,6 +185,8 @@ def get_conversation_template(model_path: str) -> Conversation: """Get the default conversation template.""" if "aquila-v1" in model_path: return get_conv_template("aquila-v1") + elif "aquila-v2" in model_path: + return get_conv_template("aquila-v2") elif "aquila-chat" in model_path: return get_conv_template("aquila-chat") elif "aquila-legacy" in model_path: @@ -252,6 +254,21 @@ def get_conversation_template(model_path: str) -> Conversation: ) ) +register_conv_template( + Conversation( + name="aquila-v2", + system_message="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("<|startofpiece|>", "<|endofpiece|>", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="", + sep2="", + stop_str=["", "<|endoftext|>"], + ) +) + if __name__ == "__main__": print("aquila template:") @@ -294,6 +311,17 @@ def get_conversation_template(model_path: str) -> Conversation: print("\n") + print("aquila-v2 template:") + conv = get_conv_template("aquila-v2") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("\n") + + def set_random_seed(seed): """Set random seed for reproducability.""" if seed is not None and seed > 0: