diff --git a/chatllms/configs/gen_args.py b/chatllms/configs/gen_args.py
index aefe292..ca5df19 100644
--- a/chatllms/configs/gen_args.py
+++ b/chatllms/configs/gen_args.py
@@ -4,11 +4,16 @@
 @dataclass
 class GenerationArguments:
+    """
+    Arguments pertaining to the model's generation parameters.
+    """
     # generation parameters
+    # Whether to use the KV cache during generation
+    use_cache: Optional[bool] = field(default=True)
     # Length arguments
     # Maximum number of newly generated tokens
     max_new_tokens: Optional[int] = field(
-        default=512,
+        default=1024,
         metadata={
             'help':
             'Maximum number of new tokens to be generated in evaluation or prediction loops'
@@ -16,7 +21,7 @@ class GenerationArguments:
         })
     # Minimum number of newly generated tokens
     min_new_tokens: Optional[int] = field(
-        default=None,
+        default=0,
         metadata={'help': 'Minimum number of new tokens to generate.'})
     # Maximum total number of tokens; overridden by max_new_tokens
     max_length: Optional[int] = field(
@@ -27,31 +32,61 @@ class GenerationArguments:
         })
     # Generation strategy
     # Whether to sample; greedy decoding otherwise
-    do_sample: Optional[bool] = field(default=True)
+    do_sample: Optional[bool] = field(
+        default=True,
+        metadata={
+            'help':
+            'Whether or not to use sampling, use greedy decoding otherwise.'
+        })
     # Number of beams for beam search
-    num_beams: Optional[int] = field(default=1)
+    num_beams: Optional[int] = field(
+        default=1,
+        metadata={
+            'help': 'Number of beams for beam search. 1 means no beam search.'
+        })
     # Number of beam groups (for diverse beam search)
     num_beam_groups: Optional[int] = field(default=1)
     # Penalty factor for contrastive search
     penalty_alpha: Optional[float] = field(default=None)
-    # Whether to use the KV cache during generation
-    use_cache: Optional[bool] = field(default=True)
-
     # Hyperparameters for logit manipulation
     # Temperature of the softmax, used to modulate the next-token distribution
-    temperature: Optional[float] = field(default=1.0)
+    temperature: Optional[float] = field(
+        default=1.0,
+        metadata={
+            'help': 'The value used to modulate the next token probabilities.'
+        })
    # Keep the k highest-probability tokens in top-k sampling
-    top_k: Optional[int] = field(default=50)
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={
+            'help':
+            'The number of highest probability vocabulary tokens to keep for top-k filtering.'
+        })
     # Nucleus sampling: sample from the smallest (variable-size) set of tokens whose probabilities sum to top_p
-    top_p: Optional[float] = field(default=1.0)
+    top_p: Optional[float] = field(
+        default=1.0,
+        metadata={
+            'help':
+            'The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept.'
+        })
     # Typical-p value
     typical_p: Optional[float] = field(default=1.0)
     # Diversity penalty factor
     diversity_penalty: Optional[float] = field(default=0.0)
     # Repetition penalty factor
-    repetition_penalty: Optional[float] = field(default=1.0)
+    repetition_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={
+            'help':
+            'The parameter for repetition penalty. 1.0 means no penalty.'
+        })
     # Length penalty factor
-    length_penalty: Optional[float] = field(default=1.0)
+    length_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={
+            'help':
+            'Exponential penalty to the length that is used with beam-based generation.'
+        })
     # Size of n-grams that must not repeat
     # Sampling is usually diverse enough that this stays unset; if outputs repeat a lot, 2 is a good choice
     no_repeat_ngram_size: Optional[int] = field(default=0)
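The help strings above document each field, but not how the dataclass reaches generation. Below is a minimal sketch of wiring the parsed arguments into `model.generate()`; the `HfArgumentParser` entry point and the `gpt2` checkpoint are assumptions for illustration, not anything fixed by this diff.

```python
# Sketch only: assumes every GenerationArguments field maps onto a
# transformers generate() keyword, and uses a placeholder checkpoint.
from dataclasses import asdict

from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from chatllms.configs.gen_args import GenerationArguments

(gen_args, ) = HfArgumentParser(
    GenerationArguments).parse_args_into_dataclasses()

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder model
model = AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer('The quick brown fox', return_tensors='pt')
# Drop unset (None) fields so generate() falls back to its own defaults,
# e.g. max_length stays unset and max_new_tokens=1024 takes effect.
gen_kwargs = {k: v for k, v in asdict(gen_args).items() if v is not None}
output_ids = model.generate(**inputs, **gen_kwargs)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Filtering out the `None` entries keeps unset fields from overriding the library defaults.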
diff --git a/chatllms/utils/template.py b/chatllms/utils/template.py
index 5c4e6b3..a96b8fa 100644
--- a/chatllms/utils/template.py
+++ b/chatllms/utils/template.py
@@ -20,6 +20,10 @@ class PromptTemplate(object):
     """
     name: str
+    prefix: str = ''
+    prompt: Optional[str] = None
+    sep: Optional[str] = None
+    use_history: bool = False
 
     def get_prompt(
         self,
@@ -88,12 +92,14 @@ def format_example(self,
             convs.append(bot_resp)
         return convs[:-1]  # drop last
 
-    def register_template(self,
-                          name: str,
-                          prefix: str,
-                          prompt: str,
-                          sep: str,
-                          use_history: Optional[bool] = True) -> None:
+    def register_template(
+        self,
+        name: str,
+        prefix: str,
+        prompt: str,
+        sep: str,
+        use_history: Optional[bool] = True,
+    ) -> None:
         """
         Registers a new conversation template.
 
diff --git a/cli_demo.py b/cli_demo.py
index 3c37014..959be76 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -16,7 +16,7 @@
 def generate_response(
     query: str,
     history: List[Tuple[str, str]],
-    source_prefix: str,
+    prefix: str,
     prompt_template: PromptTemplate,
     tokenizer: PreTrainedTokenizer,
     model: PreTrainedModel,
@@ -28,7 +28,7 @@
     Args:
         query (str): The input query for which a response is to be generated.
         history (List[Tuple[str, str]]): A list of previous queries and their responses.
-        source_prefix (str): The prefix string added to the beginning of each input sequence.
+        prefix (str): The prefix string added to the beginning of each input sequence.
         prompt_template (PromptTemplate): The prompt template used to generate the input sequence to the model.
         tokenizer (PreTrainedTokenizer): The tokenizer used to convert the raw text into input tokens.
         model (PreTrainedModel): The pre-trained language model used to generate the response.
@@ -39,7 +39,7 @@
     """
     # Convert the query and history into input IDs
-    input_text = prompt_template.get_prompt(query, history, source_prefix)
+    input_text = prompt_template.get_prompt(query, history, prefix)
     inputs = tokenizer(input_text, return_tensors='pt')
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
@@ -100,7 +100,7 @@ def main():
     )
     prompt_template = PromptTemplate(model_server_args.prompt_template)
-    source_prefix = model_server_args.source_prefix if model_server_args.source_prefix else ''
+    prefix = model_server_args.source_prefix if model_server_args.source_prefix else ''
     history: List[Tuple[str, str]] = []
     print('Welcome to the CLI chat. Type a message to talk, "clear" to reset the history, or "stop" to quit.')
     while True:
@@ -122,9 +122,8 @@ def main():
             continue
 
         # Perform prediction and printing
-        history = generate_response(query, history, source_prefix,
-                                    prompt_template, tokenizer, model,
-                                    generation_args)
+        history = generate_response(query, history, prefix, prompt_template,
+                                    tokenizer, model, generation_args)
 
 
 if __name__ == '__main__':
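To illustrate the renamed `prefix` parameter in use, here is a hedged round trip through the refactored `generate_response()`. The checkpoint, the `'default'` template name, and the shape of the returned history are placeholder assumptions; only the parameter names come from the diff above.

```python
# Illustrative only: checkpoint and template name are assumed, and the
# history is treated as (query, response) pairs per the type hints above.
from transformers import AutoModelForCausalLM, AutoTokenizer

from chatllms.configs.gen_args import GenerationArguments
from chatllms.utils.template import PromptTemplate
from cli_demo import generate_response

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')  # placeholder
model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')

history = []
history = generate_response(
    query='Hello!',
    history=history,
    prefix='',  # mirrors main() when source_prefix is unset
    prompt_template=PromptTemplate('default'),  # assumed template name
    tokenizer=tokenizer,
    model=model,
    generation_args=GenerationArguments(),
)
```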
diff --git a/data/README.md b/data/README.md
index 00cfcb1..85b1950 100644
--- a/data/README.md
+++ b/data/README.md
@@ -5,15 +5,10 @@ We provide the following datasets for the experiments in this framework.
+### English Instruction Datasets
+
 - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
-- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
 - [Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3)
-- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
-- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
-- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
-- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
-- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
-- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
 - [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k)
 - [mosaicml/dolly_hhrlhf](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf)
 - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -21,10 +16,36 @@ We provide the following datasets for the experiments in this framework.
 - [UltraChat](https://github.com/thunlp/UltraChat)
 - [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1)
 - [ShareGPT_Vicuna_unfiltered](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered)
-- [BIAI/OL-CC](https://data.baai.ac.cn/details/OL-CC)
 - [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco)
 - [Evol-Instruct](https://huggingface.co/datasets/victor123/evol_instruct_70k)
 
+### Chinese Instruction Datasets
+
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+- [Alpaca-GPT-4 (zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [ShareChat (a community effort to translate high-quality ShareGPT data)](https://paratranz.cn/projects/6725)
+- [InstructionWild (zh)](https://github.com/XueFuzhao/InstructionWild)
+- [SmileConv (multi-turn mental-health support dialogues, rewritten from real peer-support QA with ChatGPT)](https://github.com/qiuhuachuan/smile)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [OL-CC (OpenLabel-Chinese Conversations Dataset), an open-source Chinese dialogue instruction set written by crowdsourced human contributors](https://data.baai.ac.cn/details/OL-CC)
+- [CValues-Comparison, a values-comparison dataset for Chinese LLMs](https://modelscope.cn/datasets/damo/CValues-Comparison/summary)
+- [100PoisonMpts ("100 bottles of poison for AI"), a Chinese LLM governance dataset](https://modelscope.cn/datasets/damo/100PoisonMpts/summary)
+- [COIG (Chinese Open Instruction Generalist project)](https://huggingface.co/datasets/BAAI/COIG)
+- [COIG-PC (Prompt Collection), phase two of the COIG dataset](https://huggingface.co/datasets/BAAI/COIG-PC)
+- [HuatuoGPT SFT data, a Chinese medical instruction dataset](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)
+
+### RLHF Datasets
+
+- [CValues](https://github.com/X-PLUG/CValues)
+  Dataset description: an open-source value-alignment dataset of 145k samples. For each prompt it provides three types of responses, ranked as rejection with constructive guidance (safe and responsible) > plain rejection (safe) > risky response (unsafe). It can be used to strengthen the safety of SFT models or to train reward models.
+
+
 ## Dataset formation
 
 The `dataset_info.yaml` file contains the metadata of the datasets, mainly including the following fields.
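Since the README section ends by pointing at `dataset_info.yaml`, here is a small sketch of reading that registry without assuming its field names (the README documents those separately); the `data/dataset_info.yaml` path is an assumption.

```python
# Hedged sketch: list registered datasets and their declared fields.
import yaml

with open('data/dataset_info.yaml') as f:  # path assumed, relative to repo root
    dataset_info = yaml.safe_load(f)

for name, fields in dataset_info.items():
    # Each top-level key names a dataset; its value holds the fields
    # described in the README.
    print(f'{name}: {sorted(fields)}')
```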