# coding:utf-8
from abc import ABC, abstractmethod
from functools import partial
from sentence_transformers import SentenceTransformer, util
import copy
import json
import numpy as np
import re
import spacy
from typing import List
import logging
import sys
from GeneralLLM import LargeLanguageModel, Qwen, ChatGPT
def _add_left_parenthesis(s:str) -> str:
    return f"({s}"
def _add_left_bracket(s:str) -> str:
    return f"[{s}"
def _add_left_brace(s:str) -> str:
    return f"{{{s}"
def _add_left_wave(s:str) -> str:
    return f"~{s}"
def _add_right_parenthesis(s:str) -> str:
    return f"{s})"
def _add_right_bracket(s:str) -> str:
    return f"{s}]"
def _add_right_brace(s:str) -> str:
    return f"{s}}}"
def _add_right_wave(s:str) -> str:
    return f"{s}~"
def _add_right_eq(s:str) -> str:
    return f"{s}="
def _add_parentheses(s:str) -> str:
    return f"({s})"
def _add_brackets(s:str) -> str:
    return f"[{s}]"
def _add_braces(s:str) -> str:
    return f"{{{s}}}"
def _add_waves(s:str) -> str:
    return f"~{s}~"
def _caesar(s:str, delta:int = 10) -> str:
    '''Shift a single character forward by `delta` code points. Multi-character
    strings make ord() raise, which callers catch and treat as a no-op.'''
    return chr(ord(s) + delta)
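# For instance, _caesar('A', 10) returns 'K', so CaesarPerturbation below can
# relabel single-character option ids by a fixed alphabet offset.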
class MultipleChoiceQuestion:
'''
The class of multiple choice questions.
'''
    def __init__(self, question:str = '', option_ids:List[str] = None, options:List[str] = None,
                 correct:List[bool] = None, question_first:bool = True,
                 text_type = 'choice'):
        '''
        Args:
            question: the question text
            option_ids: the list of option indices (with formats), e.g., ['(A)', '(B)', '(C)', '(D)']
            options: the option contents, each of which appears after the corresponding option index.
            correct: the list indicating the correctness of each option. E.g., [True, False, False, True]
                indicates that only the first and the last options are correct answers.
            question_first: the switch controlling whether the question appears before the options.
            text_type: 'choice' or 'judgement', indicating whether the question is presented as a
                multiple choice question or a judgement question.
        '''
        # None defaults avoid Python's shared-mutable-default pitfall.
        self.question = question
        self.option_ids = option_ids if option_ids is not None else []
        self.options = options if options is not None else []
        self.question_first = question_first
        self.correct = correct
        self.text_type = text_type
        assert(len(self.option_ids) == len(self.options))
        if self.correct is not None:
            assert(len(self.option_ids) == len(self.correct))
        assert(text_type in ('choice', 'judgement'))
def to_dict(self)->dict:
'''Return the dictionary form of the current question.'''
result = {
"question": self.question,
"option_ids": self.option_ids,
"options": self.options,
"question_first": self.question_first,
"correct": self.correct,
"text_type": self.text_type
}
return result
def load_dict(self, data: dict):
'''Load information of the question from a dictionary.'''
self.__init__(
question = data["question"],
option_ids = data["option_ids"],
options = data["options"],
question_first = data["question_first"],
correct = data["correct"],
text_type = data["text_type"])
def __str__(self):
assert(len(self.option_ids) == len(self.options))
prompt = "Options:\n"
for key, value in zip(self.option_ids, self.options):
prompt += f"{key} {value}\n"
prompt = f"Question: {self.question}\n" + prompt
prompt += f"Answer:{self.correct}\n"
prompt += f"question_first:{self.question_first}\n"
prompt += f"text_type:{self.text_type}"
return prompt
def get_prompt(self):
'''Get the prompt of the current question for LLMs.'''
assert(len(self.option_ids) == len(self.options))
assert(self.text_type == 'choice' or self.text_type == 'judgement')
if self.text_type == 'choice':
prefix = "Please select the correct option(s) from the following options given the question:\n"
elif self.text_type == 'judgement':
prefix = "Please judge whether each of the options is correct or incorrect given the question:\n"
prompt = 'Options:\n'
for key, value in zip(self.option_ids, self.options):
prompt += f"{key} {value}\n"
if self.question_first:
prompt = f"Question: {self.question}\n" + prompt
else:
prompt = prompt + f"Question: {self.question}\n"
if self.text_type == 'choice':
option_or = ", ".join([f'"{option}"' for option in self.option_ids])
prompt += 'Your output must strictly follow this format:\n{"answer": <the list of selected options, e.g., [%s]>}\n'%option_or
elif self.text_type == 'judgement':
output_fmt = ', '.join([f'"{option}": <"True" or "False">' for option in self.option_ids])
output_fmt = "{" + output_fmt + "}"
prompt += f'Your output must strictly follow this format:\n{output_fmt}\n'
prompt = prefix + prompt
prompt += "Your output:"
return prompt
    def get_formatted_answer(self):
        '''Return the ground-truth answer in the JSON format that get_prompt() requests.'''
result = None
if self.text_type == 'choice':
answers = []
for option, correct in zip(self.option_ids, self.correct):
if correct:
answers.append(option)
content = ', '.join(['"'+option+'"' for option in answers])
result = f"\"answer\": [{content}]"
result = '{' + result + '}'
elif self.text_type == 'judgement':
result = ', '.join([f"\"{option}\":\"{correct}\"" for option, correct in zip(self.option_ids, self.correct)])
result = '{' + result + '}'
return result
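# A minimal usage sketch of MultipleChoiceQuestion (illustrative values only):
#   q = MultipleChoiceQuestion(
#       question="2 + 2 = ?",
#       option_ids=['A', 'B'],
#       options=['4', '5'],
#       correct=[True, False])
#   q.get_prompt()            # asks the model to pick from "A", "B"
#   q.get_formatted_answer()  # '{"answer": ["A"]}'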
def get_mcq_llm_answer(mcq: MultipleChoiceQuestion, llm: LargeLanguageModel) -> tuple:
    ''' Get the answer of an LLM to a multiple-choice question.
    Args:
        mcq:MultipleChoiceQuestion, the question to answer
        llm:LargeLanguageModel, the model that answers the question
    Returns:
        List[bool]: the list that indicates whether each of the
            options is selected by the model.
        str: the original response of the model.
    '''
prompt = mcq.get_prompt()
response_ok = False
max_retry = 3
n_retry = 1
    # Sized by option_ids so questions without a ground-truth `correct` list still work.
    result = [False] * len(mcq.option_ids)
while response_ok is False and n_retry <= max_retry:
try:
llm.refresh()
original_response = ''
original_response = llm.listen_and_response(prompt)
response = re.sub(r'\n', ' ', original_response)
response = re.findall(r'[{]\s*"[^{]*[}]', response)[0]
# print(response)
if mcq.text_type == 'choice':
response = json.loads(response)
for i in range(len(mcq.option_ids)):
if mcq.option_ids[i] in response['answer']:
result[i] = True
elif mcq.text_type == 'judgement':
response = re.sub(r'\s+True',' "True"', response)
response = re.sub(r'\s+False',' "False"', response)
response = json.loads(response)
oid2pos = {}
for i in range(len(mcq.option_ids)):
oid2pos[mcq.option_ids[i]] = i
                for key in response.keys():
                    # Plain string comparison instead of eval(), so model output is never executed.
                    result[oid2pos[key]] = (str(response[key]).strip() == "True")
else:
logging.error(f"get_mcq_llm_answer: Invalid text type '{mcq.text_type}'")
response_ok = True
        except Exception:
logging.error(f"original_llm_answer = {original_response}")
logging.error(f"get_mcq_llm_answer: Format error, try again. n_retry = {n_retry}")
n_retry += 1
return result, original_response
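# The parser above expects replies shaped like these (illustrative):
#   choice:    {"answer": ["A", "D"]}
#   judgement: {"A": "True", "B": "False", "C": "False", "D": "True"}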
class KPPerturbation(ABC):
    '''Base class for knowledge-invariant perturbations of multiple choice questions.'''
    def __init__(self):
        self.method = "default"
    @abstractmethod
    def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
        pass
class OptionFormatPerturbation(KPPerturbation):
    # Mapping from method name to the option-id formatter it applies.
    _FORMATTERS = {
        "add_left_parenthesis": _add_left_parenthesis,
        "add_left_bracket": _add_left_bracket,
        "add_left_brace": _add_left_brace,
        "add_left_wave": _add_left_wave,
        "add_right_parenthesis": _add_right_parenthesis,
        "add_right_bracket": _add_right_bracket,
        "add_right_brace": _add_right_brace,
        "add_right_wave": _add_right_wave,
        "add_right_eq": _add_right_eq,
        "add_parentheses": _add_parentheses,
        "add_brackets": _add_brackets,
        "add_braces": _add_braces,
        "add_waves": _add_waves,
    }
    def __init__(self, method:str = "add_parentheses"):
        '''
        Args:
            method:str, the perturbation method; one of the keys of _FORMATTERS.
        '''
        super().__init__()
        if method not in self._FORMATTERS:
            raise ValueError("Invalid option format perturbation method.")
        self.method = method
        self.formatter = self._FORMATTERS[method]
def __str__(self):
return f"OptionFormatPerturbation.{self.method}"
    def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
        assert(len(mcq.option_ids) == len(mcq.options))
        result = copy.deepcopy(mcq)
        try:
            result.option_ids = [self.formatter(elem) for elem in mcq.option_ids]
        except Exception:
            logging.error('OptionFormatPerturbation error. Keep the original result.')
        return result
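# For instance (illustrative), with method="add_parentheses" the option ids
# ['A', 'B', 'C', 'D'] become ['(A)', '(B)', '(C)', '(D)'] while options,
# answers, and question text stay untouched.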
class CaesarPerturbation(KPPerturbation):
    def __init__(self, delta:int = 20):
        '''
        Args:
            delta:int, the ASCII offset applied to each (single-character) option id.
        '''
        super().__init__()
        self.method = f"delta_{delta}"
        self.formatter = partial(_caesar, delta = delta)
    def __str__(self):
        return f"CaesarPerturbation.{self.method}"
    def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
        assert(len(mcq.option_ids) == len(mcq.options))
        result = copy.deepcopy(mcq)
        try:
            result.option_ids = [self.formatter(elem) for elem in mcq.option_ids]
        except Exception:
            logging.error('CaesarPerturbation error. Keep the original result.')
        return result
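# For instance (illustrative), delta=20 maps option ids 'A','B','C','D' to
# 'U','V','W','X'; a non-single-character id raises TypeError inside ord(),
# which the except branch above turns into a no-op.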
class OptionPermutationPerturbation(KPPerturbation):
    def __init__(self, permutation_map = {0:3,1:2,2:1,3:0}):
        '''
        Args:
            permutation_map:dict. The key denotes the original position of option contents,
                while the value denotes the target position to place option contents.
                If it is None, a random permutation map is generated for each question.
        '''
        super().__init__()
        self.map = permutation_map
        self.rmap = None
        if self.map is not None:
            # Reverse map: target position -> original position.
            self.rmap = {v: k for k, v in self.map.items()}
    def __str__(self):
        return f"OptionPermutationPerturbation.{self.method}"
    def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
        assert(len(mcq.option_ids) == len(mcq.options))
        result = copy.deepcopy(mcq)
        rmap = self.rmap
        if rmap is None:
            # No map given: draw a random permutation for this question.
            perm = np.random.permutation(len(mcq.options))
            rmap = {i: int(perm[i]) for i in range(len(perm))}
        result.options = [mcq.options[rmap.get(i, i)] for i in range(len(mcq.options))]
        if mcq.correct is not None:
            result.correct = [mcq.correct[rmap.get(i, i)] for i in range(len(mcq.correct))]
        return result
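# For instance (illustrative), with permutation_map={0:3,1:2,2:1,3:0} the
# option contents are reversed: positions [0,1,2,3] move to [3,2,1,0], and the
# correctness list is permuted identically so the answer key stays valid.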
class ChangeQuestionPosPerturbation(KPPerturbation):
def __init__(self):
super().__init__()
def __str__(self):
return f"ChangeQuestionPosPerturbation.{self.method}"
def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
result = copy.deepcopy(mcq)
result.question_first = not result.question_first
return result
class ChangeTypePerturbation(KPPerturbation):
def __init__(self):
super().__init__()
        self.type_dict = {0:'choice', 1:'judgement'}
        self.rev_type_dict = {v: k for k, v in self.type_dict.items()}
def __str__(self):
return f"ChangeTypePerturbation.{self.method}"
def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
type_id = self.rev_type_dict[mcq.text_type]
new_type_id = (type_id + 1) % len(self.type_dict)
result = copy.deepcopy(mcq)
result.text_type = self.type_dict[new_type_id]
return result
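# For instance (illustrative), a 'choice' question ("select the correct
# option(s)") becomes a 'judgement' question ("judge whether each option is
# correct"); applying the perturbation twice restores the original type.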
class QuestionRewriter:
def __init__(self):
self.nlp = spacy.load('en_core_web_sm')
self.n_requests = 0
    def rewrite(self, mcq: MultipleChoiceQuestion, rewriter: ChatGPT, n_candidates = 2, similarity_score = 0.6) -> List[List[str]]:
        '''Split the question into sentences and ask the rewriter to paraphrase each one.
        Returns, for each sentence, a list whose first element is the original sentence,
        followed by the n_candidates paraphrase candidates.'''
        assert(n_candidates > 1)
        self.n_requests = 0
        doc = self.nlp(mcq.question)
        sentences = [item.text for item in doc.sents]
        result = []
        for i in range(len(sentences)):
            context = ' '.join(sentences[:i])
            sentence = sentences[i]
            # Only include the context line when the question has more than one sentence.
            context_line = f"Context: {context}\n" if len(sentences) > 1 else ""
            prompt = f'''Here is a sentence in a multiple choice question. Please rewrite the sentence given its context and the expected similarity score. Here are necessary requirements:
[Requirements Start]
1. Be consistent with its context.
2. The rewritten sentence should keep the semantics of the original sentence.
3. If the sentence contains blanks/underlines to be filled, these blanks/underlines should be kept after paraphrasing.
4. You can utilize various rewriting skills (e.g., add/replace/delete words, paraphrase) to make it look different from the original.
[Requirements End]
[Meaning of Expected Similarity Score Start]
For the expected similarity score (0.0~1.0), 1.0 denotes that the rewritten sentence is exactly the same as the original; 0.8 denotes that there exist word-level differences between the rewritten sentence and the original; 0.6 denotes that there exist not only word-level but also many sentence structure-level differences; 0.4 denotes that you are allowed to entirely paraphrase the sentence on your own; 0.2 denotes that you are allowed to add misleading statements to the current sentence.
[Meaning of Expected Similarity Score End]
You should only output the rewritten sentence without any extra content.
Expected similarity score: {similarity_score}
{context_line}Sentence: {sentence}
Your output:'''
            rewriter.refresh()
            response = rewriter.listen_and_response(prompt, n_outputs = n_candidates)
            self.n_requests += 1
            result.append([sentence] + response)
        return result
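# The result is one list per sentence, e.g. (illustrative):
#   [['What is 2+2?', 'What does 2 plus 2 equal?', 'Compute the sum of 2 and 2.']]
# where element 0 is always the untouched original sentence.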
class QuestionGenerator:
def __init__(self):
self.transformer = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    def generate(self, candidates:List[List[str]]) -> tuple:
        '''Assemble full candidate questions from per-sentence paraphrase candidates and
        score each assembled question against the original question text.
        Returns (question_texts, question_similarities).'''
embeddings = []
similarities = []
priorities = []
question_texts = []
question_similarities = []
        # Strip any echoed "Sentence:" prefix from the rewriter's outputs.
        for i in range(len(candidates)):
            for j in range(len(candidates[i])):
                if "Sentence:" in candidates[i][j]:
                    candidates[i][j] = re.findall(r'.*Sentence:\s*(.*)', candidates[i][j])[0]
for elem in candidates:
emb = self.transformer.encode(elem, convert_to_tensor=True)
embeddings.append(emb)
similarity = []
for j in range(emb.shape[0]):
j_sim = util.pytorch_cos_sim(emb[0], emb[j]).detach().numpy()[0,0]
similarity.append(j_sim)
priority = np.argsort(similarity)[::-1]
similarities.append(similarity)
priorities.append(priority.tolist())
priorities = np.array(priorities)
original_text = ' '.join(candidates[i][0] for i in range(priorities.shape[0]))
embedding_1 = self.transformer.encode(original_text, convert_to_tensor=True)
for top_k in range(priorities.shape[1]):
question_text = ' '.join(
candidates[i][priorities[i][top_k]] for i in range(priorities.shape[0]))
embedding_2 = self.transformer.encode(question_text, convert_to_tensor=True)
q_sim = util.pytorch_cos_sim(embedding_1, embedding_2).detach().numpy()[0,0]
question_texts.append(question_text)
question_similarities.append(q_sim)
return question_texts, question_similarities
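# Sketch of the contract (illustrative): question_texts[k] joins, for every
# sentence, its (k+1)-th most similar candidate, so the texts are roughly
# ordered from most to least similar to the original; question_similarities
# holds the cosine similarity of each assembled question to the original.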
class ParaphrasingPerturbation(KPPerturbation):
def __init__(self, paraphrase_config:dict, rewriter: ChatGPT):
'''
Args:
paraphrase_config:dict, the configuration of paraphrasing.
<key:value> = {"n_candidates":int,
"similarity_score":float}
"n_candidates" denotes the number of generated candidates for each
question sentence. "similarity_score" denotes the expected similarity
score of the paraphrasing result, between 0 and 1. The larger, the similar.
rewriter:ChatGPT (gpt-4-turbo is recommended), the rewriter that paraphrases questions.
'''
super().__init__()
self.questionRewriter = QuestionRewriter()
self.questionGenerator = QuestionGenerator()
self.rewriterLM = rewriter
self.n_candidates = paraphrase_config["n_candidates"]
self.similarity_score = paraphrase_config["similarity_score"]
def __str__(self):
return f"ParaphrasingPerturbation_n_candidates_{self.n_candidates}_similarity_score_{self.similarity_score}"
def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
candidate_texts = self.questionRewriter.rewrite(
mcq = mcq,
rewriter = self.rewriterLM,
n_candidates = self.n_candidates,
similarity_score = self.similarity_score)
        candidate_questions, similarities = self.questionGenerator.generate(candidate_texts)
        # Use the middle-ranked candidate question.
        mid = len(candidate_questions) // 2
        result = copy.deepcopy(mcq)
        result.question = candidate_questions[mid]
return result
class MixedPerturbation(KPPerturbation):
''' The MixedPerturbation is the composite of atomic knowledge-invariant perturbations.'''
def __init__(self, perturbations:List[KPPerturbation] = None):
super().__init__()
self.perturbations = perturbations if isinstance(perturbations, list) else []
def __str__(self):
result = 'MixedPerturbation = [\n' + ',\n'.join(
' ' * 4 + elem.__str__() for elem in self.perturbations)+'\n]\n'
return result
def refresh(self):
self.perturbations = []
def push(self, elem:KPPerturbation):
self.perturbations.append(elem)
return
    def pop(self):
        if len(self.perturbations) > 0:
            del self.perturbations[-1]
        return
def perturb(self, mcq:MultipleChoiceQuestion) -> MultipleChoiceQuestion:
assert(len(mcq.option_ids) == len(mcq.options))
result = copy.deepcopy(mcq)
for pt_elem in self.perturbations:
result = pt_elem.perturb(result)
return result
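# A minimal composition sketch (illustrative):
#   mixed = MixedPerturbation()
#   mixed.push(OptionFormatPerturbation(method="add_parentheses"))
#   mixed.push(ChangeQuestionPosPerturbation())
#   perturbed = mixed.perturb(mcq)   # perturbations are applied in push order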
if __name__ == "__main__":
mcq = MultipleChoiceQuestion(
question="__________ memory is the aspect of memory that is involved in the recall of information acquired within the past few hours to days.",
option_ids=['A','B','C','D'],
options=['Working', 'Sensory', 'Long-term', 'Prospective'],
question_first = True,
correct = [False,False,False,True]
)
llm = ChatGPT(name = "chatgpt", description = "The chatgpt assistant.", model = "gpt-4-turbo", temperature = 1.0)
paraphrase_config = {
"n_candidates":3,
"similarity_score": 0.6
}
ptb = ParaphrasingPerturbation(rewriter = llm, paraphrase_config = paraphrase_config)
result = ptb.perturb(mcq)
print(f"Original question: {mcq.get_prompt()}")
print(f"Paraphrased question: {result.get_prompt()}")