Skip to content

Commit

Permalink
Merge pull request #203 from thunlp/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
a710128 authored May 1, 2021
2 parents 3dc04ab + 3298e91 commit cb7a8e1
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 18 deletions.
2 changes: 1 addition & 1 deletion OpenAttack/attackers/fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __call__(self, clsf, x_orig, target=None):
if iter_cnt > 5 * len(sent): # Failed to find a substitute word
return None
try:
reps = list(map(lambda x:x[0], self.substitute(sent[idx], pos=None, threshold=self.config["threshold"])))
reps = list(map(lambda x:x[0], self.substitute(sent[idx], None, threshold=self.config["threshold"])))
except WordNotInDictionaryException:
continue
reps = list(filter(lambda x: x in self.wordid, reps))
Expand Down
2 changes: 1 addition & 1 deletion OpenAttack/attackers/hotflip.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_neighbours(self, word, POS, num):
sub_words = list(
map(
lambda x: x[0],
self.config["substitute"](word, pos=POS, threshold=threshold)[:num],
self.config["substitute"](word, POS, threshold=threshold)[:num],
)
)
neighbours = []
Expand Down
2 changes: 1 addition & 1 deletion OpenAttack/attackers/textbugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def generateBugs(self, word, glove_vectors, sub_w_enabled=False, typo_enabled=Fa

def bug_sub_W(self, word):
try:
res = self.counterfit.__call__(word, pos=None, threshold=0.5)
res = self.counterfit(word, None, threshold=0.5)
if len(res) == 0:
return word
return res[0][0]
Expand Down
5 changes: 3 additions & 2 deletions OpenAttack/attackers/textfooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"import_score_threshold": -1.,
"sim_score_threshold": 0.5,
"sim_score_window": 15,
"threshold": 0.5,
"synonym_num": 50,
"processor": DefaultTextProcessor(),
"substitute": None,
Expand Down Expand Up @@ -190,12 +191,12 @@ def __call__(self, clsf, x_orig, target=None):


def get_neighbours(self, word, pos):
threshold = 0.5
threshold = self.config["threshold"]
try:
return list(
map(
lambda x: x[0],
self.config["substitute"](word, pos=pos, threshold=threshold)[1 : self.config["synonym_num"] + 1],
self.config["substitute"](word, pos, threshold=threshold)[1 : self.config["synonym_num"] + 1],
)
)
except WordNotInDictionaryException:
Expand Down
12 changes: 7 additions & 5 deletions OpenAttack/substitutes/chinese_hownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ def __init__(self):
self.hownet_dict = DataManager.load("AttackAssist.HowNet")
self.zh_word_list = self.hownet_dict.get_zh_words()

def __call__(self, word, pos_tag, threshold=None):
if pos_tag[:2] == "JJ":
def __call__(self, word, pos, threshold=None):
if pos is None:
pp = None
elif pos[:2] == "JJ":
pp = "adj"
elif pos_tag[:2] == "VB":
elif pos[:2] == "VB":
pp = "verb"
elif pos_tag[:2] == "NN":
elif pos[:2] == "NN":
pp = "noun"
elif pos_tag[:2] == "RB":
elif pos[:2] == "RB":
pp = "adv"
else:
pp = None
Expand Down
4 changes: 3 additions & 1 deletion OpenAttack/substitutes/chinese_wordnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def __init__(self):

def __call__(self, word, pos_tag, threshold=None):
from nltk.corpus import wordnet as wn
if pos_tag[:2] == "JJ":
if pos_tag is None:
pp = None
elif pos_tag[:2] == "JJ":
pp = "adj"
elif pos_tag[:2] == "VB":
pp = "verb"
Expand Down
4 changes: 3 additions & 1 deletion OpenAttack/substitutes/hownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def __init__(self):

def __call__(self, word, pos_tag, threshold=None):
pp = "noun"
if pos_tag[:2] == "JJ":
if pos_tag is None:
pp = None
elif pos_tag[:2] == "JJ":
pp = "adj"
elif pos_tag[:2] == "VB":
pp = "verb"
Expand Down
2 changes: 1 addition & 1 deletion examples/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def main():
# The load operation returns a PytorchClassifier that can be further used for Attacker and AttackEval.

dataset = datasets.load_dataset("sst", split="train[:20]").map(function=dataset_mapping)
# Dataset.SST.sample is a list of 1k sentences sampled from test dataset of Dataset.SST.
# We load the SST dataset using the `datasets` package and map its fields.

attacker = OpenAttack.attackers.GeneticAttacker()
# After this step, we’ve initialized a GeneticAttacker that uses the default configuration during the attack process.
Expand Down
8 changes: 4 additions & 4 deletions slow_tests/attackers_chinese.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ def get_attackers_on_chinese(dataset, clsf):
chinese_substitute = OpenAttack.substitutes.ChineseHowNetSubstitute()
# Attackers that currently support Chinese: SememePSO, TextFooler, PWWS, Genetic, FD, TextBugger
attackers = [
OpenAttack.attackers.FDAttacker(word2id=clsf.word2id, embedding=clsf.embedding, processor=chinese_processor, substitute=chinese_substitute),
OpenAttack.attackers.FDAttacker(word2id=clsf.word2id, embedding=clsf.embedding, token_unk="[UNK]", processor=chinese_processor, substitute=chinese_substitute),
OpenAttack.attackers.TextBuggerAttacker(processor=chinese_processor),
OpenAttack.attackers.TextFoolerAttacker(processor=chinese_processor, substitute=chinese_substitute),
OpenAttack.attackers.GeneticAttacker(processor=chinese_processor, substitute=chinese_substitute, skip_words=["的", "了", "着"]),
OpenAttack.attackers.PWWSAttacker(processor=chinese_processor, substitute=chinese_substitute),
OpenAttack.attackers.TextFoolerAttacker(processor=chinese_processor, substitute=chinese_substitute, threshold=10),
OpenAttack.attackers.GeneticAttacker(processor=chinese_processor, substitute=chinese_substitute, skip_words=["的", "了", "着"], neighbour_threshold=10),
OpenAttack.attackers.PWWSAttacker(processor=chinese_processor, substitute=chinese_substitute, threshold=10),
OpenAttack.attackers.PSOAttacker(processor=chinese_processor, substitute=chinese_substitute)
]
return attackers
2 changes: 1 addition & 1 deletion slow_tests/test_chinese_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main():
try:
st = time.perf_counter()
print(
OpenAttack.attack_evals.ChineseAttackEval(attacker, time_clsf, progress_bar=True).eval(dataset, visualize=True),
OpenAttack.attack_evals.ChineseAttackEval(attacker, time_clsf).eval(dataset, progress_bar=True),
time_clsf.total_time,
time.perf_counter() - st
)
Expand Down

0 comments on commit cb7a8e1

Please sign in to comment.