update chatllms #80

Merged · 2 commits · Aug 2, 2023

59 changes: 47 additions & 12 deletions chatllms/configs/gen_args.py
@@ -4,19 +4,24 @@

@dataclass
class GenerationArguments:
"""
Arguments specifying the model generation parameters.
"""
# generation parameters
# Whether to use cache
use_cache: Optional[bool] = field(default=True)
# Length arguments
# Maximum number of new tokens to generate
max_new_tokens: Optional[int] = field(
default=512,
default=1024,
metadata={
'help':
'Maximum number of new tokens to be generated in evaluation or prediction loops '
'if predict_with_generate is set.'
})
# Minimum number of new tokens to generate
min_new_tokens: Optional[int] = field(
default=None,
default=0,
metadata={'help': 'Minimum number of new tokens to generate.'})
# Maximum total number of tokens; overridden by max_new_tokens
max_length: Optional[int] = field(
@@ -27,31 +32,61 @@ class GenerationArguments:
})
# Generation strategy
# Whether to sample
do_sample: Optional[bool] = field(default=True)
do_sample: Optional[bool] = field(
default=True,
metadata={
'help':
'Whether or not to use sampling, use greedy decoding otherwise.'
})
# Number of beams for beam search
num_beams: Optional[int] = field(default=1)
num_beams: Optional[int] = field(
default=1,
metadata={
'help': 'Number of beams for beam search. 1 means no beam search.'
})
# Number of beam groups for beam search
num_beam_groups: Optional[int] = field(default=1)
# Penalty factor
penalty_alpha: Optional[float] = field(default=None)
# Whether to use cache
use_cache: Optional[bool] = field(default=True)

# Hyperparameters for logit manipulation
# Temperature factor for the softmax, used to shape the output token distribution
temperature: Optional[float] = field(default=1.0)
temperature: Optional[float] = field(
default=1.0,
metadata={
'help': 'The value used to modulate the next token probabilities.'
})
# The k highest-probability choices kept in top-k sampling
top_k: Optional[int] = field(default=50)
top_k: Optional[int] = field(
default=50,
metadata={
'help':
'The number of highest probability vocabulary tokens to keep for top-k filtering.'
})
# Nucleus sampling: keep the smallest (variable-size) set of top tokens whose probabilities sum to top_p, and sample from those candidates
top_p: Optional[float] = field(default=1.0)
top_p: Optional[float] = field(
default=1.0,
metadata={
'help':
'The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept.'
})
# Typical-p value
typical_p: Optional[float] = field(default=1.0)
# Diversity penalty factor
diversity_penalty: Optional[float] = field(default=0.0)
# Repetition penalty factor
repetition_penalty: Optional[float] = field(default=1.0)
repetition_penalty: Optional[float] = field(
default=1.0,
metadata={
'help':
'The parameter for repetition penalty. 1.0 means no penalty.'
})
# Length penalty factor
length_penalty: Optional[float] = field(default=1.0)
length_penalty: Optional[float] = field(
default=1.0,
metadata={
'help':
'Exponential penalty to the length that is used with beam-based generation.'
})
# Size of n-grams that must not repeat
# Usually left unset, since sampling is diverse enough on its own; if the output repeats heavily, 2 is a good choice
no_repeat_ngram_size: Optional[int] = field(default=0)
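As context for the new defaults and help strings above, here is a minimal sketch of how a dataclass like `GenerationArguments` is typically consumed: parsed from the command line with `transformers.HfArgumentParser`, then unpacked into `model.generate`. The parser usage is standard `transformers` API; the `gpt2` checkpoint and the `asdict` unpacking are illustrative assumptions, not code from this PR.

```python
# Minimal usage sketch (assumes transformers is installed and that
# GenerationArguments is importable from chatllms.configs.gen_args).
from dataclasses import asdict

from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from chatllms.configs.gen_args import GenerationArguments

# Parses CLI flags such as: --max_new_tokens 256 --temperature 0.7
gen_args, = HfArgumentParser(GenerationArguments).parse_args_into_dataclasses()

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer('Hello, how are you?', return_tensors='pt')
# Every dataclass field maps onto a keyword argument of generate().
outputs = model.generate(**inputs, **asdict(gen_args))
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```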
18 changes: 12 additions & 6 deletions chatllms/utils/template.py
@@ -20,6 +20,10 @@ class PromptTemplate(object):
"""

name: str
prefix: str = ''
prompt: str = None
sep: str = None
use_history: bool = False

def get_prompt(
self,
@@ -88,12 +92,14 @@ def format_example(self,
convs.append(bot_resp)
return convs[:-1] # drop last

def register_template(self,
name: str,
prefix: str,
prompt: str,
sep: str,
use_history: Optional[bool] = True) -> None:
def register_template(
self,
name: str,
prefix: str,
prompt: str,
sep: str,
use_history: Optional[bool] = True,
) -> None:
"""
Registers a new conversation template.

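The new class-level defaults make a `PromptTemplate` constructible from just a name. Below is a short sketch of how registration and prompt construction might fit together; the template strings and the `{query}` placeholder are assumptions for illustration, not templates shipped with the repo.

```python
# Hypothetical registration sketch; the field values are illustrative only.
from chatllms.utils.template import PromptTemplate

template = PromptTemplate('vicuna')  # other fields fall back to the new defaults
template.register_template(
    name='vicuna',
    prefix='A chat between a curious user and an AI assistant.',
    prompt='USER: {query} ASSISTANT: ',  # assumed placeholder syntax
    sep='</s>',
    use_history=True,
)

# Mirrors the call pattern used in cli_demo.py: (query, history, prefix).
input_text = template.get_prompt('What does top_p control?', [], '')
```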
13 changes: 6 additions & 7 deletions cli_demo.py
@@ -16,7 +16,7 @@
def generate_response(
query: str,
history: List[Tuple[str, str]],
source_prefix: str,
prefix: str,
prompt_template: PromptTemplate,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
@@ -28,7 +28,7 @@ def generate_response(
Args:
query (str): The input query for which a response is to be generated.
history (List[Tuple[str, str]]): A list of previous queries and their responses.
source_prefix (str): The prefix string added to the beginning of each input sequence.
prefix (str): The prefix string added to the beginning of each input sequence.
prompt_template (PromptTemplate): The prompt template used to generate the input sequence to the model.
tokenizer (PreTrainedTokenizer): The tokenizer used to convert the raw text into input tokens.
model (PreTrainedModel): The pretrained language model used to generate the response.
@@ -39,7 +39,7 @@
"""

# Convert the query and history into input IDs
input_text = prompt_template.get_prompt(query, history, source_prefix)
input_text = prompt_template.get_prompt(query, history, prefix)
inputs = tokenizer(input_text, return_tensors='pt')
inputs = {k: v.to(model.device) for k, v in inputs.items()}

@@ -100,7 +100,7 @@ def main():
)

prompt_template = PromptTemplate(model_server_args.prompt_template)
source_prefix = model_server_args.source_prefix if model_server_args.source_prefix else ''
prefix = model_server_args.source_prefix if model_server_args.source_prefix else ''
history: List[str] = []
print('Welcome to the CLI chat demo. Type to chat; enter "clear" to clear the history, "stop" to exit.')
while True:
@@ -122,9 +122,8 @@
continue

# Perform prediction and printing
history = generate_response(query, history, source_prefix,
prompt_template, tokenizer, model,
generation_args)
history = generate_response(query, history, prefix, prompt_template,
tokenizer, model, generation_args)


if __name__ == '__main__':
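The tail of `generate_response` is collapsed in this diff. As a rough sketch of what such a tail usually does — run `generate`, decode only the newly produced tokens, and extend the history — written here as a hypothetical helper (`finish_generation` is not a function in this PR):

```python
from dataclasses import asdict
from typing import Any, Dict, List, Tuple

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer


def finish_generation(
    inputs: Dict[str, torch.Tensor],
    query: str,
    history: List[Tuple[str, str]],
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    generation_args: Any,
) -> List[Tuple[str, str]]:
    """Hypothetical continuation of generate_response (an assumption)."""
    with torch.no_grad():
        output_ids = model.generate(**inputs, **asdict(generation_args))
    # Drop the prompt tokens so only the model's reply is decoded.
    prompt_length = inputs['input_ids'].shape[1]
    response = tokenizer.decode(output_ids[0][prompt_length:],
                                skip_special_tokens=True)
    history.append((query, response))
    return history
```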
37 changes: 29 additions & 8 deletions data/README.md
@@ -5,26 +5,47 @@

We provide the following datasets for the experiments in this framework.

### English Instruction Datasets

- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k)
- [mosaicml/dolly_hhrlhf](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [UltraChat](https://github.com/thunlp/UltraChat)
- [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT_Vicuna_unfiltered](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered)
- [BAAI/OL-CC](https://data.baai.ac.cn/details/OL-CC)
- [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco)
- [Evol-Instruct](https://huggingface.co/datasets/victor123/evol_instruct_70k)

### Chinese Instruction Datasets
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca-GPT-4 (zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [ShareChat (a community project to translate high-quality ShareGPT data into Chinese)](https://paratranz.cn/projects/6725)
- [InstructionWild (zh)](https://github.com/XueFuzhao/InstructionWild)
- [SmileConv (multi-turn mental-health support dialogues, rewritten from real peer-support QA with ChatGPT)](https://github.com/qiuhuachuan/smile)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [OL-CC (OpenLabel-Chinese Conversations Dataset): an open-source Chinese dialogue instruction set, crowd-sourced and human-written](https://data.baai.ac.cn/details/OL-CC)
- [CValues-Comparison: a Chinese LLM values-comparison dataset](https://modelscope.cn/datasets/damo/CValues-Comparison/summary)
- [100PoisonMpts ("100 bottles of poison for AI"): a Chinese LLM governance dataset](https://modelscope.cn/datasets/damo/100PoisonMpts/summary)
- [COIG (Chinese Open Instruction Generalist project)](https://huggingface.co/datasets/BAAI/COIG)
- [COIG-PC (Prompt Collection): phase two of the COIG dataset](https://huggingface.co/datasets/BAAI/COIG-PC)
- [Huatuo: a Chinese medical instruction dataset](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)


### RLHF Datasets

- [CValues](https://github.com/X-PLUG/CValues)
  Dataset description: an open-source value-alignment dataset of 145k examples. For each prompt it provides three response types, ranked rejection with positive guidance (safe and responsible) > plain rejection (safe) > risky reply (unsafe); it can be used to strengthen the safety of SFT models or to train reward models.



## Dataset format

The `dataset_info.yaml` file contains the information for the datasets, mainly including the following fields.
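The field list itself is collapsed in this diff, so the sketch below only shows how such a registry is typically loaded; it deliberately assumes no concrete field names.

```python
# Hypothetical loading sketch; the real schema is the one the README describes.
import yaml  # PyYAML

# Path assumed relative to the repo root.
with open('data/dataset_info.yaml', encoding='utf-8') as f:
    dataset_info = yaml.safe_load(f)

# Each top-level key names a dataset; its value holds the per-dataset fields.
for name, fields in dataset_info.items():
    print(name, '->', sorted(fields))
```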