-
Notifications
You must be signed in to change notification settings - Fork 8
/
main_InContext.py
112 lines (89 loc) · 4.43 KB
/
main_InContext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import argparse
import sys
from ChatDrug.task_and_evaluation.Conversational_LLMs_utils import complete
from utils import load_retrieval_DB, construct_prompt_incontext, retrieve_and_feedback, load_thredhold
from ChatDrug.task_and_evaluation import task_to_drug, get_task_specification_dict, evaluate, parse
def main(args):
    """Run the in-context editing loop over every input drug of a task.

    For each input drug this retrieves a reference drug, sends an in-context
    prompt to the conversational LLM, parses the reply for a generated drug
    and evaluates it. A human-readable trace is written to the log file and
    the per-drug record is dumped as JSON once the loop has finished.

    Args:
        args: Mapping of run options. The keys read here are 'task', 'seed',
            'constraint', 'conversational_LLM', 'log_file' and 'record_file'.

    Returns:
        None. Results are written to args['log_file'] and args['record_file'].
    """
    record = {}
    num_correct = 0
    num_all = 0

    # The context manager guarantees the log is flushed and closed even when
    # an iteration raises. UTF-8 is set explicitly because LLM replies may
    # contain characters the platform default encoding cannot represent.
    with open(args['log_file'], 'w', encoding='utf-8') as f:
        # load dataset
        drug_type = task_to_drug(args['task'])
        task_specification_dict = get_task_specification_dict(args['task'])
        input_drug_list, retrieval_DB = load_retrieval_DB(args['task'], args['seed'])
        threshold_dict = load_thredhold(drug_type)

        for index, input_drug in enumerate(input_drug_list):
            print(f">>Sample {index}", file=f)
            entry = {'drug_skip': 0}
            record[input_drug] = entry

            # ChatGPT message
            messages = [{"role": "system", "content": "You are an expert in the field of molecular chemistry."}]

            print('Start Retrieval', file=f)
            try:
                closest_drug = retrieve_and_feedback(
                    args['task'], retrieval_DB, input_drug, input_drug,
                    args['constraint'], threshold_dict)
            except Exception as err:
                # Only Exception subclasses are handled, so Ctrl-C and
                # SystemExit are no longer swallowed as "invalid drug".
                if type(err) is Exception:
                    # An exception of exactly type Exception is counted as a
                    # failed sample; anything more specific is a skipped drug.
                    print('Cannot find a retrieval result.', file=f)
                    entry['answer'] = 'False'
                    num_all += 1
                else:
                    print('Invalid drug. Failed to evaluate. Skipped.', file=f)
                    entry['drug_skip'] = 1
                # NOTE(review): SOURCE indentation was lost; `continue` is
                # assumed to cover both branches since closest_drug is unset.
                continue

            print("Retrieval Result:" + closest_drug, file=f)
            entry['retrieval_drug'] = closest_drug

            prompt = construct_prompt_incontext(task_specification_dict, input_drug, drug_type, closest_drug, args['task'])
            messages.append({"role": "user", "content": prompt})
            generated_text = complete(messages, args['conversational_LLM'])
            messages.append({"role": "assistant", "content": generated_text})

            print("----------------", file=f)
            print("User:" + prompt, file=f)
            print("ChatGPT:" + generated_text, file=f)
            entry['user'] = prompt
            entry['chatgpt'] = generated_text

            generated_drug_list = parse(args['task'], input_drug, generated_text, closest_drug)

            # Check Parsing Results
            if generated_drug_list is None:
                # None from parse() means the sample is skipped, not scored.
                entry['drug_skip'] = 1
                continue
            if len(generated_drug_list) == 0:
                # An empty parse result counts as a wrong answer.
                entry['answer'] = 'False'
                num_all += 1
                continue

            generated_drug = generated_drug_list[0]
            print("Generated Result:" + str(generated_drug), file=f)
            entry['generated_drug'] = generated_drug

            answer = evaluate(input_drug, generated_drug, args['task'], args['constraint'], threshold_dict)
            if answer == -1:
                # -1 from evaluate() marks the sample as skipped.
                entry['drug_skip'] = 1
                continue

            print('Evaluation result is: ' + str(answer), file=f)
            entry['answer'] = str(answer)

            num_all += 1
            if answer:
                num_correct += 1

            # NOTE(review): running accuracy is assumed to be logged once per
            # evaluated sample (inside the loop) — confirm against upstream.
            print(f'Acc = {num_correct}/{num_all}', file=f)
            print("----------------", file=f)

        print("--------Final Acc--------", file=f)
        print(f'Acc = {num_correct}/{num_all}', file=f)
        print("----------------", file=f)

    # ensure_ascii=False emits raw non-ASCII characters, so the target file
    # must be opened with an encoding that can hold them.
    with open(args['record_file'], 'w', encoding='utf-8') as rf:
        json.dump(record, rf, ensure_ascii=False)
if __name__ == '__main__':
    # Command-line entry point: --task is the only mandatory option; every
    # other option falls back to the default given below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=int, required=True, help='task_id')
    parser.add_argument(
        '--conversational_LLM',
        type=str,
        default='chatgpt',
        help='only support chatgpt now',
    )
    parser.add_argument(
        '--log_file',
        type=str,
        default='results/ChatDrug.log',
        help='saved log file name',
    )
    parser.add_argument(
        '--record_file',
        type=str,
        default='results/ChatDrug.json',
        help='saved record file name',
    )
    parser.add_argument('--constraint', type=str, default='loose', help='loose or strict')
    parser.add_argument('--seed', type=int, default=0, help='seed for retrieval data base')

    # main() indexes its options by key, so pass the namespace as a dict.
    args = vars(parser.parse_args())
    main(args)