diff --git a/flagai/model/aquila2/modeling_aquila.py b/flagai/model/aquila2/modeling_aquila.py
index ea1083cc..2ebe79af 100755
--- a/flagai/model/aquila2/modeling_aquila.py
+++ b/flagai/model/aquila2/modeling_aquila.py
@@ -923,7 +923,7 @@ def predict(self, text, tokenizer=None,
                 sft=True, convo_template = "",
                 device = "cuda",
                 model_name="AquilaChat2-7B",
-                history=[],
+                history=None,
                 **kwargs):
 
         vocab = tokenizer.get_vocab()
@@ -1033,9 +1033,10 @@ def predict(self, text, tokenizer=None,
             convert_tokens = convert_tokens[1:]
             probs = probs[1:]
 
-        # Update history
-        history.insert(0, ('ASSISTANT', out))
-        history.insert(0, ('USER', text))
+        if isinstance(history, list):
+            # Update history
+            history.insert(0, ('ASSISTANT', out))
+            history.insert(0, ('USER', text))
 
         return out
 
diff --git a/flagai/model/aquila2/utils.py b/flagai/model/aquila2/utils.py
index 7e740c45..054db0d8 100755
--- a/flagai/model/aquila2/utils.py
+++ b/flagai/model/aquila2/utils.py
@@ -21,6 +21,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
 
     example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
 
+    if history is None or not isinstance(history, list):
+        history = []
+
     while(len(history) > 0 and (len(example) < max_token)):
         tmp = history.pop()
         if tmp[0] == 'ASSISTANT':
@@ -35,4 +38,4 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
     print('model in:', conv.get_prompt())
     example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
 
-    return example
\ No newline at end of file
+    return example
diff --git a/flagai/model/aquila2_hf/predict.py b/flagai/model/aquila2_hf/predict.py
index b2d19c8b..2e5143f4 100644
--- a/flagai/model/aquila2_hf/predict.py
+++ b/flagai/model/aquila2_hf/predict.py
@@ -310,6 +310,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
 
     example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
 
+    if history is None or not isinstance(history, list):
+        history = []
+
     while(len(history) > 0 and (len(example) < max_token)):
         tmp = history.pop()
         if tmp[0] == 'ASSISTANT':
@@ -333,7 +336,7 @@ def predict(model, text, tokenizer=None,
             sft=True, convo_template = "",
             device = "cuda",
             model_name="AquilaChat2-7B",
-            history=[],
+            history=None,
             **kwargs):
 
     vocab = tokenizer.get_vocab()
@@ -435,8 +438,9 @@ def predict(model, text, tokenizer=None,
         convert_tokens = convert_tokens[1:]
         probs = probs[1:]
 
-    # Update history
-    history.insert(0, ('ASSISTANT', out))
-    history.insert(0, ('USER', text))
+    if isinstance(history, list):
+        # Update history
+        history.insert(0, ('ASSISTANT', out))
+        history.insert(0, ('USER', text))
 
     return out
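
Why the change: a mutable default like `history=[]` is evaluated once at function definition time, so every call of `predict` that omits `history` shares the same list, and one caller's conversation leaks into the next. The patch switches to the standard `None`-sentinel idiom and guards the mutation sites. A minimal standalone sketch of the pitfall and the fix (the `predict_buggy`/`predict_fixed` names are illustrative, not from the patch):

```python
def predict_buggy(text, history=[]):
    # The [] default is created once, at def time: all calls that
    # omit `history` mutate this single shared list.
    history.insert(0, ('USER', text))
    return history

def predict_fixed(text, history=None):
    # None sentinel: allocate a fresh list per call when the caller
    # does not supply (and does not want back) a history.
    if history is None:
        history = []
    history.insert(0, ('USER', text))
    return history

print(predict_buggy("hi"))     # [('USER', 'hi')]
print(predict_buggy("again"))  # [('USER', 'again'), ('USER', 'hi')]  <- leaked state
print(predict_fixed("hi"))     # [('USER', 'hi')]
print(predict_fixed("again"))  # [('USER', 'again')]  <- independent calls
```

Note the `isinstance(history, list)` guard in the patch keeps the old contract intact: callers who pass their own list still get it updated in place, while `history=None` callers simply skip the write-back.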