Skip to content

Commit

Permalink
Fix predict_chat_history.
Browse files Browse the repository at this point in the history
Signed-off-by: ldwang <[email protected]>
  • Loading branch information
ldwang committed Oct 24, 2023
1 parent 1b0a7b3 commit 2c3811b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
9 changes: 5 additions & 4 deletions flagai/model/aquila2/modeling_aquila.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def predict(self, text, tokenizer=None,
sft=True, convo_template = "",
device = "cuda",
model_name="AquilaChat2-7B",
history=[],
history=None,
**kwargs):

vocab = tokenizer.get_vocab()
Expand Down Expand Up @@ -1033,9 +1033,10 @@ def predict(self, text, tokenizer=None,
convert_tokens = convert_tokens[1:]
probs = probs[1:]

# Update history
history.insert(0, ('ASSISTANT', out))
history.insert(0, ('USER', text))
if isinstance(history, list):
# Update history
history.insert(0, ('ASSISTANT', out))
history.insert(0, ('USER', text))

return out

Expand Down
5 changes: 4 additions & 1 deletion flagai/model/aquila2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,

example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

if history is None or not isinstance(history, list):
history = []

while(len(history) > 0 and (len(example) < max_token)):
tmp = history.pop()
if tmp[0] == 'ASSISTANT':
Expand All @@ -35,4 +38,4 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
print('model in:', conv.get_prompt())
example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

return example
return example
12 changes: 8 additions & 4 deletions flagai/model/aquila2_hf/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,

example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

if history is None or not isinstance(history, list):
history = []

while(len(history) > 0 and (len(example) < max_token)):
tmp = history.pop()
if tmp[0] == 'ASSISTANT':
Expand All @@ -333,7 +336,7 @@ def predict(model, text, tokenizer=None,
sft=True, convo_template = "",
device = "cuda",
model_name="AquilaChat2-7B",
history=[],
history=None,
**kwargs):

vocab = tokenizer.get_vocab()
Expand Down Expand Up @@ -435,8 +438,9 @@ def predict(model, text, tokenizer=None,
convert_tokens = convert_tokens[1:]
probs = probs[1:]

# Update history
history.insert(0, ('ASSISTANT', out))
history.insert(0, ('USER', text))
if isinstance(history, list):
# Update history
history.insert(0, ('ASSISTANT', out))
history.insert(0, ('USER', text))

return out

0 comments on commit 2c3811b

Please sign in to comment.