-
Notifications
You must be signed in to change notification settings - Fork 2
/
Step2_T5_judge.py
103 lines (91 loc) · 4.9 KB
/
Step2_T5_judge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from sentence_transformers import SentenceTransformer, util
import numpy as np
import os, json, glob
import copy
import pprint
import argparse
# Sentence-T5 encoder used by T5_similarity. It is loaded lazily on first use
# and then reused: reloading sentence-t5-large for every question is wasteful.
_T5_MODEL_NAME = 'sentence-transformers/sentence-t5-large'
_T5_CACHE_FOLDER = '/remote-home/share/VideoBenchmark/Video_Benchmark/T5_evaluation'
_T5_MODEL = None


def _get_t5_model():
    """Return the shared sentence-T5 model, loading it onto the GPU on the first call."""
    global _T5_MODEL
    if _T5_MODEL is None:
        _T5_MODEL = SentenceTransformer(_T5_MODEL_NAME, cache_folder=_T5_CACHE_FOLDER).cuda()
    return _T5_MODEL


def T5_similarity(output_sequence=None, chocies_list=None):
    """Return the index of the choice most similar to ``output_sequence``.

    The model output and every choice are embedded with sentence-T5 and
    compared by cosine similarity; ties resolve to the earliest choice.

    Args:
        output_sequence: The model's free-form answer text.
        chocies_list: Candidate answer strings. (The misspelled parameter
            name is kept so existing keyword callers keep working.)

    Returns:
        int: Position in ``chocies_list`` of the highest-scoring choice.

    Raises:
        ValueError: If ``output_sequence`` is None or ``chocies_list`` is
            None/empty.
    """
    if output_sequence is None:
        raise ValueError('output_sequence must not be None')
    if chocies_list is None or len(chocies_list) == 0:
        raise ValueError('chocies_list must contain at least one choice')
    model = _get_t5_model()
    embeddings = model.encode([output_sequence])
    embeddings2 = model.encode(chocies_list)
    # One row (the single output) by one column per choice, so the flat
    # argmax is the choice index.
    cosine_scores = util.cos_sim(embeddings, embeddings2)
    return int(np.argmax(cosine_scores))
import traceback
def json_T5_eval(T5_save_folder=None, jsonfile=None, args=None):
    """Score one model's answers for one dataset with sentence-T5 similarity.

    For every entry of ``jsonfile`` (a ``<Dataset>_eval.json`` chat-result
    file), the entry's ``output_sequence`` is compared with the question's
    multiple-choice options. The closest option's letter is stored under
    ``'t5-answer'``, and the formatted options are stored under ``'choices'``.

    Results are written under ``T5_save_folder`` in two places:
      * one ``<qid_vid>.json`` per question in a sub-folder named after
        ``jsonfile``. These act as a resume cache: existing files are reused,
        not recomputed.
      * one aggregate file with the same basename as ``jsonfile``. It holds
        every question's result, both cached and newly computed.

    Args:
        T5_save_folder: Output directory for per-question and aggregate results.
        jsonfile: Path to the chat-result file. Its basename must be
            ``<Dataset>_eval.json`` for one of the datasets listed below.
        args: Parsed CLI namespace; only ``args.Eval_QA_root`` is read.

    Raises:
        KeyError: If the dataset name derived from ``jsonfile`` is unknown.
        Exception: Any error while scoring is printed with its traceback and
            re-raised, so the caller decides whether to continue.
    """
    # Question/choices/answer file for every supported dataset.
    dataset_names = (
        "Ucfcrime", "Youcook2", "TVQA", "MSVD", "MSRVTT",
        "Driving-decision-making", "NBA", "SQA3D", "Driving-exam",
        "MV", "MOT", "ActivityNet", "TGIF",
    )
    qa_root = str(args.Eval_QA_root)
    dataset_qajson = {
        name: os.path.join(qa_root, f"Eval_QA/{name}_QA_new.json")
        for name in dataset_names
    }

    eval_basename = os.path.basename(jsonfile)
    dataset_name = eval_basename.split('_eval.json')[0]
    print(f'Dataset name: {dataset_name}')
    if dataset_name not in dataset_qajson:
        raise KeyError(
            f'Unknown dataset {dataset_name!r} for {jsonfile}; '
            f'expected one of {sorted(dataset_qajson)}')
    with open(dataset_qajson[dataset_name], 'r', encoding='utf-8') as f:
        qa_choice_data = json.load(f)
    # Model chat results; each item provides 'output_sequence' and 'video_id'.
    with open(jsonfile, 'r', encoding='utf-8') as f:
        data = json.load(f)

    candidates = ['A', 'B', 'C', 'D', 'E', 'F']
    # Per-question results live in a sub-folder named after the chat file.
    per_item_folder = os.path.join(T5_save_folder, eval_basename.split('.')[0])
    os.makedirs(per_item_folder, exist_ok=True)
    try:
        new_data = {}
        for qid_vid, item in data.items():
            T5_qidvid_jsonfile = os.path.join(per_item_folder, qid_vid + '.json')
            if os.path.exists(T5_qidvid_jsonfile):
                try:
                    with open(T5_qidvid_jsonfile, 'r', encoding='utf-8') as f:
                        # Reuse the cached result so the aggregate file stays
                        # complete on reruns instead of losing these items.
                        new_data.update(json.load(f))
                except json.JSONDecodeError:
                    # An interrupted write leaves a truncated file; recompute it.
                    print(f'{T5_qidvid_jsonfile} is corrupt, recomputing.')
                else:
                    print(f'{T5_qidvid_jsonfile} is existing!')
                    continue

            new_item = copy.deepcopy(item)
            output_sequence = item['output_sequence']
            video_id = item['video_id']
            # Presumably keys look like "<qid>_<video_id>"; stripping the video
            # id yields the key into the QA file.
            qid = qid_vid.replace(f'_{video_id}', '')
            choices = [f'{alpha}. {choice}'
                       for alpha, choice in qa_choice_data[qid]['choices'].items()]
            answer_index = T5_similarity(str(output_sequence), choices)
            new_item['t5-answer'] = candidates[answer_index]
            new_item['choices'] = choices
            pprint.pprint(new_item)
            new_data[qid_vid] = new_item
            with open(T5_qidvid_jsonfile, 'w', encoding='utf-8') as f:
                json.dump({qid_vid: new_item}, f, indent=2)
            print(T5_qidvid_jsonfile, 'is saved!')

        # All questions of this model/dataset pair, cached and newly computed.
        T5_dataset_jsonfile = os.path.join(T5_save_folder, eval_basename)
        with open(T5_dataset_jsonfile, 'w', encoding='utf-8') as f:
            json.dump(new_data, f, indent=2)
    except Exception:
        # Report here (with context) and let the caller decide; never drop
        # into an interactive debugger in a batch job.
        traceback.print_exc()
        raise
def main(args):
    """Run the T5 judge over every ``*_eval.json`` file in the chat-results folder.

    Files are processed in sorted order for reproducible runs. A failure in
    one file is reported together with that file's path, and the remaining
    files are still processed.

    Args:
        args: Parsed CLI namespace providing ``model_chat_files_folder``,
            ``T5_judge_output_folder`` and ``Eval_QA_root``.
    """
    # Escape the folder so characters like '[' in the path are not treated
    # as glob wildcards (which would silently match nothing).
    pattern = os.path.join(glob.escape(args.model_chat_files_folder), '*_eval.json')
    evaljson_list = sorted(glob.glob(pattern))
    print(f'{len(evaljson_list)}')
    for evaljson in evaljson_list:
        try:
            json_T5_eval(args.T5_judge_output_folder, evaljson, args)
        except Exception as e:
            print(f'Failed to evaluate {evaljson}: {e!r}')
if __name__ == "__main__":
    # Command-line entry point: judge every chat-result file in one folder.
    parser = argparse.ArgumentParser(
        description="Map free-form model answers onto multiple-choice options "
                    "using sentence-T5 cosine similarity.")
    parser.add_argument("--model_chat_files_folder", type=str, default="./Chat_results",
                        help="folder containing the model's <Dataset>_eval.json chat-result files")
    parser.add_argument("--T5_judge_output_folder", type=str, default="./T5_Judge",
                        help="folder where per-question and per-dataset T5 judgements are written")
    parser.add_argument("--Eval_QA_root", type=str, default="/remote-home/share/VideoBenchmark/Video_Benchmark",
                        help="root directory containing Eval_QA/<Dataset>_QA_new.json question files")
    args = parser.parse_args()
    main(args)