chore: Handle empty tokenization in perplexity and allow ad-hoc Advan… #802

Open, wants to merge 1 commit into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -50,3 +50,5 @@ checkpoints/
!tests/sample_outputs/csv_attack_log.csv
tests/test_command_line/attack_log.txt
textattack/=22.3.0

venv/
97 changes: 94 additions & 3 deletions tests/test_metric_api.py
@@ -1,5 +1,8 @@
import pytest


def test_perplexity():
from textattack.attack_results import SuccessfulAttackResult
from textattack.attack_results import FailedAttackResult, SuccessfulAttackResult
from textattack.goal_function_results.classification_goal_function_result import (
ClassificationGoalFunctionResult,
)
@@ -15,14 +18,88 @@ def test_perplexity():
AttackedText(sample_text), None, None, None, None, None, None
),
ClassificationGoalFunctionResult(
AttackedText(sample_atck_text), None, None, None, None, None, None
AttackedText(
sample_atck_text), None, None, None, None, None, None
),
)
]
ppl = Perplexity(model_name="distilbert-base-uncased").calculate(results)

assert int(ppl["avg_original_perplexity"]) == int(81.95)

results = [
FailedAttackResult(
ClassificationGoalFunctionResult(
AttackedText(sample_text), None, None, None, None, None, None
),
)
]

Perplexity(model_name="distilbert-base-uncased").calculate(results)

ppl = Perplexity(model_name="distilbert-base-uncased")
texts = [sample_text]
ppl.ppl_tokenizer.encode(" ".join(texts), add_special_tokens=True)

encoded = ppl.ppl_tokenizer.encode(" ".join([]), add_special_tokens=True)
assert len(encoded) > 0


def test_perplexity_empty_results():
from textattack.metrics.quality_metrics import Perplexity

ppl = Perplexity()
with pytest.raises(ValueError):
ppl.calculate([])

ppl = Perplexity("gpt2")
with pytest.raises(ValueError):
ppl.calculate([])

ppl = Perplexity(model_name="distilbert-base-uncased")
ppl_values = ppl.calculate([])

assert "avg_original_perplexity" in ppl_values
assert "avg_attack_perplexity" in ppl_values


def test_perplexity_no_model():
from textattack.attack_results import FailedAttackResult, SuccessfulAttackResult
from textattack.goal_function_results.classification_goal_function_result import (
ClassificationGoalFunctionResult,
)
from textattack.metrics.quality_metrics import Perplexity
from textattack.shared.attacked_text import AttackedText

sample_text = "hide new secretions from the parental units "
sample_atck_text = "Ehide enw secretions from the parental units "

results = [
SuccessfulAttackResult(
ClassificationGoalFunctionResult(
AttackedText(sample_text), None, None, None, None, None, None
),
ClassificationGoalFunctionResult(
AttackedText(
sample_atck_text), None, None, None, None, None, None
),
)
]

ppl = Perplexity()
ppl_values = ppl.calculate(results)

assert "avg_original_perplexity" in ppl_values
assert "avg_attack_perplexity" in ppl_values


def test_perplexity_calc_ppl():
from textattack.metrics.quality_metrics import Perplexity

ppl = Perplexity("gpt2")
with pytest.raises(ValueError):
ppl.calc_ppl([])


def test_use():
import transformers
@@ -85,5 +162,19 @@ def test_metric_recipe():
attacker = Attacker(attack, dataset, attack_args)
results = attacker.attack_dataset()

adv_score = AdvancedAttackMetric(["meteor_score", "perplexity"]).calculate(results)
adv_score = AdvancedAttackMetric(
["meteor_score", "perplexity"]).calculate(results)
assert adv_score["avg_attack_meteor_score"] == 0.71


def test_metric_ad_hoc():
from textattack.metrics.quality_metrics import Perplexity
from textattack.metrics.recipe import AdvancedAttackMetric

metrics = AdvancedAttackMetric()
metrics.add_metric("perplexity", Perplexity(
model_name="distilbert-base-uncased"))

metric_results = metrics.calculate([])

assert "perplexity" in metric_results
5 changes: 4 additions & 1 deletion textattack/metrics/quality_metrics/perplexity.py
@@ -100,8 +100,11 @@ def calc_ppl(self, texts):
input_ids = torch.tensor(
self.ppl_tokenizer.encode(text, add_special_tokens=True)
).unsqueeze(0)
if not (input_ids_size := input_ids.size(1)):
raise ValueError("No tokens recognized for input text")

# Strided perplexity calculation from huggingface.co/transformers/perplexity.html
for i in range(0, input_ids.size(1), self.stride):
for i in range(0, input_ids_size, self.stride):
begin_loc = max(i + self.stride - self.max_length, 0)
end_loc = min(i + self.stride, input_ids.size(1))
trg_len = end_loc - i
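The guard added above makes empty tokenization fail loudly: when the tokenizer returns no ids for the joined input, calc_ppl raises a ValueError before entering the strided loop instead of failing obscurely later. A minimal sketch of the intended behavior, mirroring test_perplexity_calc_ppl above (it assumes GPT-2's tokenizer produces no ids for an empty string, which is why an empty text list trips the guard):

import pytest

from textattack.metrics.quality_metrics import Perplexity

# GPT-2 adds no special tokens by default, so joining an empty list of texts
# yields zero input ids and the new guard raises.
ppl = Perplexity("gpt2")
with pytest.raises(ValueError):
    ppl.calc_ppl([])
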
35 changes: 23 additions & 12 deletions textattack/metrics/recipe.py
@@ -17,19 +17,30 @@ class AdvancedAttackMetric(Metric):
"""Calculate a suite of advanced metrics to evaluate attackResults'
quality."""

def __init__(self, choices=["use"]):
def __init__(self, choices: list[str] = ["use"]):
self.achoices = choices
available_metrics = {
"use": USEMetric,
"perplexity": Perplexity,
"bert_score": BERTScoreMetric,
"meteor_score": MeteorMetric,
"sbert_score": SBERTMetric,
}
self.selected_metrics = {}
for choice in self.achoices:
if choice not in available_metrics:
raise KeyError(f"'{choice}' is not a valid metric name")
metric = available_metrics[choice]()
self.selected_metrics.update({choice: metric})

def calculate(self, results):
def add_metric(self, name: str, metric: Metric):
if not isinstance(metric, Metric):
raise ValueError(f"Object {metric} must be a subtype of Metric")
self.selected_metrics.update({name: metric})

def calculate(self, results) -> dict[str, float]:
advanced_metrics = {}
if "use" in self.achoices:
advanced_metrics.update(USEMetric().calculate(results))
if "perplexity" in self.achoices:
advanced_metrics.update(Perplexity().calculate(results))
if "bert_score" in self.achoices:
advanced_metrics.update(BERTScoreMetric().calculate(results))
if "meteor_score" in self.achoices:
advanced_metrics.update(MeteorMetric().calculate(results))
if "sbert_score" in self.achoices:
advanced_metrics.update(SBERTMetric().calculate(results))
# TODO: Would like to guarantee unique keys from calls to calculate()
for metric in self.selected_metrics.values():
advanced_metrics.update(metric.calculate(results))
return advanced_metrics
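
Together with add_metric, the rewritten calculate loop lets callers assemble a metric suite ad hoc instead of relying on the fixed if-chain. A minimal usage sketch under the assumptions in the diff above (starting from an empty choices list is illustrative, the new test starts from the defaults, and the empty results list stands in for what Attacker.attack_dataset() normally returns):

from textattack.metrics.quality_metrics import Perplexity
from textattack.metrics.recipe import AdvancedAttackMetric

# Start with no preselected metrics, then register one ad hoc;
# add_metric rejects anything that is not a Metric instance.
metrics = AdvancedAttackMetric(choices=[])
metrics.add_metric("perplexity", Perplexity(model_name="distilbert-base-uncased"))

# Each registered metric contributes its own keys to the combined dict.
# An empty list stands in for the results returned by Attacker.attack_dataset().
scores = metrics.calculate([])
print(scores)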