Skip to content

Commit

Permalink
Add aquila-v2 template.
Browse files Browse the repository at this point in the history
Signed-off-by: ldwang <[email protected]>
  • Loading branch information
ldwang committed Nov 2, 2023
1 parent b0cee8d commit 994ee63
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
27 changes: 27 additions & 0 deletions flagai/model/aquila2/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def get_conversation_template(model_path: str) -> Conversation:
"""Get the default conversation template."""
if "aquila-v1" in model_path:
return get_conv_template("aquila-v1")
elif "aquila-v2" in model_path:
return get_conv_template("aquila-v2")
elif "aquila-chat" in model_path:
return get_conv_template("aquila-chat")
elif "aquila-legacy" in model_path:
Expand Down Expand Up @@ -227,6 +229,21 @@ def get_conversation_template(model_path: str) -> Conversation:
)
)

# Conversation template for Aquila v2 chat models.
# Turns are wrapped in <|startofpiece|>/<|endofpiece|> sentinels with no
# role-colon separator (NO_COLON_TWO): user turns end with sep (""),
# assistant turns end with sep2 ("</s>").
_aquila_v2_template = Conversation(
    name="aquila-v2",
    # System prompt prepended once at the start of the rendered dialogue.
    system_message="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("<|startofpiece|>", "<|endofpiece|>", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.NO_COLON_TWO,
    sep="",
    sep2="</s>",
    # Generation stops on either end-of-sequence token.
    stop_str=["</s>", "<|endoftext|>"],
)
register_conv_template(_aquila_v2_template)


if __name__ == "__main__":
print("aquila template:")
Expand Down Expand Up @@ -269,3 +286,13 @@ def get_conversation_template(model_path: str) -> Conversation:

print("\n")

# Smoke-test the "aquila-v2" template: render a short two-turn dialogue
# and print the resulting prompt string.
print("aquila-v2 template:")
conv = get_conv_template("aquila-v2")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], "Hi!")
conv.append_message(conv.roles[0], "How are you?")
# Appending None leaves the final assistant turn open for generation.
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())

print("\n")

28 changes: 28 additions & 0 deletions flagai/model/aquila2_hf/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def get_conversation_template(model_path: str) -> Conversation:
"""Get the default conversation template."""
if "aquila-v1" in model_path:
return get_conv_template("aquila-v1")
elif "aquila-v2" in model_path:
return get_conv_template("aquila-v2")
elif "aquila-chat" in model_path:
return get_conv_template("aquila-chat")
elif "aquila-legacy" in model_path:
Expand Down Expand Up @@ -252,6 +254,21 @@ def get_conversation_template(model_path: str) -> Conversation:
)
)

# Register the Aquila v2 chat template (mirrors the entry in
# flagai/model/aquila2/conversation.py).  Turn delimiters are the
# <|startofpiece|>/<|endofpiece|> sentinels; NO_COLON_TWO style means the
# user separator is sep ("") and the assistant separator is sep2 ("</s>").
_AQUILA_V2_SPEC = dict(
    name="aquila-v2",
    system_message="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("<|startofpiece|>", "<|endofpiece|>", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.NO_COLON_TWO,
    sep="",
    sep2="</s>",
    # Either end token terminates generation.
    stop_str=["</s>", "<|endoftext|>"],
)
register_conv_template(Conversation(**_AQUILA_V2_SPEC))


if __name__ == "__main__":
print("aquila template:")
Expand Down Expand Up @@ -294,6 +311,17 @@ def get_conversation_template(model_path: str) -> Conversation:

print("\n")

# Smoke-test the "aquila-v2" template: render a short two-turn dialogue
# and print the resulting prompt string.
print("aquila-v2 template:")
conv = get_conv_template("aquila-v2")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], "Hi!")
conv.append_message(conv.roles[0], "How are you?")
# Appending None leaves the final assistant turn open for generation.
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())

print("\n")


def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
Expand Down

0 comments on commit 994ee63

Please sign in to comment.