diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md
index 1126033..6ed41d3 100644
--- a/docs/EVALUATION.md
+++ b/docs/EVALUATION.md
@@ -1,5 +1,13 @@
# Evaluation
+We provide a detailed evaluation of the accuracy and speed of the `hashformers` framework in comparison with alternative libraries.
+
+Although n-gram models such as `ekphrasis` are orders of magnitude faster than `hashformers`, their accuracy is remarkably unstable across datasets.
+
+Research papers on word segmentation usually try to bring the best of both worlds together by combining deep learning with statistical methods. So the best speed-accuracy trade-off may lie in building [ranking cascades](https://arxiv.org/abs/2010.06467) (a.k.a. "telescoping"), where `hashformers` is used as a fallback when less time-consuming methods score below a certain confidence threshold.
+
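+A minimal sketch of such a cascade, assuming a hypothetical `fast_segmenter` callable that returns a segmentation together with a confidence score (neither `hashformers` nor `ekphrasis` ships one out of the box):
+
+```python
+def cascade_segment(hashtags, fast_segmenter, word_segmenter, threshold=0.9):
+    """Try the fast model first; fall back to hashformers for the
+    hashtags whose confidence falls below the threshold."""
+    results = {}
+    fallback = []
+    for tag in hashtags:
+        segmentation, confidence = fast_segmenter(tag)  # hypothetical API
+        if confidence >= threshold:
+            results[tag] = segmentation
+        else:
+            fallback.append(tag)
+    if fallback:
+        # WordSegmenter.predict returns a WordSegmenterOutput dataclass.
+        slow_segmentations = word_segmenter.predict(fallback).output
+        results.update(dict(zip(fallback, slow_segmentations)))
+    return [results[tag] for tag in hashtags]
+```
+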
+## Accuracy
+
@@ -34,4 +42,30 @@ A script to reproduce the evaluation of ekphrasis is available on [scripts/evalu
| | | |
| average (all) | HashtagMaster | 58.35 |
| | ekphrasis | 41.65 |
-| |**hashformers**| **68.06**|
\ No newline at end of file
+| |**hashformers**| **68.06**|
+
+## Speed
+
+| model | hashtags/second | accuracy | topk | layers|
+|:--------------|:----------------|----------:|-----:|------:|
+| ekphrasis | 4405.00 | 44.74 | - | - |
+| gpt2-large | 12.04 | 63.86 | 2 | first |
+| distilgpt2 | 29.32 | 64.56 | 2 | first |
+|**distilgpt2** | **15.00** | **80.48** |**2** |**all**|
+| gpt2 | 11.36 | - | 2 | all |
+| gpt2 | 3.48 | - | 20 | all |
+| gpt2 + bert | 1.38 | 83.68 | 20 | all |
+
+In the table above we evaluate `hashformers` under different settings on the Dev-BOUN dataset and compare it with `ekphrasis`. As `ekphrasis` relies on n-grams, it is a few orders of magnitude faster than `hashformers`.
+
+All experiments were performed on Google Colab while connected to a Tesla T4 GPU with 15GB of RAM. We highlight `distilgpt2` at `topk = 2` with all layers, which provides the best speed-accuracy trade-off.
+
+* **model**: The name of the model. We evaluate `ekphrasis` under its default settings, and use the reranker only for the SOTA experiment in the bottom row.
+
+* **hashtags/second**: How many hashtags the model can segment per second. In all `hashformers` experiments, the `batch_size` parameter was adjusted to take up close to 100% of GPU RAM. A side note: even at 100% memory usage we observed only about 60% GPU utilization, so you may get better throughput on GPUs with more than 16GB of memory.
+
+* **accuracy**: Accuracy on the Dev-BOUN dataset. We don't evaluate the accuracy of `gpt2`, but we know [from the literature](https://arxiv.org/abs/2112.03213) that it is expected to fall between `distilgpt2` (at 80%) and `gpt2 + bert` (the SOTA, at 83%).
+
+* **topk**: The `topk` parameter of the Beamsearch algorithm, passed as the `topk` argument to the `WordSegmenter.segment` method (see the usage sketch below). The `steps` Beamsearch parameter was fixed at its default value of 13 for all `hashformers` experiments, as it doesn't have as significant an impact on performance as `topk`.
+
+* **layers**: How many Transformer layers were used for language modeling: either all layers (`all`) or just the first layer (`first`).
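+
+As a usage sketch, the highlighted settings map onto the API as follows (the hashtag is an arbitrary example; model download and GPU setup are omitted):
+
+```python
+from hashformers import WordSegmenter
+
+# "distilgpt2" at topk = 2 over all layers: the best speed-accuracy trade-off.
+ws = WordSegmenter(
+    segmenter_model_name_or_path="distilgpt2",
+    reranker_model_name_or_path=None
+)
+
+segmentations = ws.segment(
+    ["weneedanationalpark"],
+    topk=2,
+    steps=13  # default value; affects the results less than topk
+)
+```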
\ No newline at end of file
diff --git a/setup.py b/setup.py
index b87ef40..4cc8f8d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,9 @@
package_dir={'': 'src'},
install_requires=[
"mlm-hashformers",
- "lm-scorer-hashformers"
+ "lm-scorer-hashformers",
+ "twitter-text-python",
+ "ekphrasis",
+ "pandas",
]
)
diff --git a/src/hashformers/beamsearch/gpt2_lm.py b/src/hashformers/beamsearch/gpt2_lm.py
index e517e18..36f2ec7 100644
--- a/src/hashformers/beamsearch/gpt2_lm.py
+++ b/src/hashformers/beamsearch/gpt2_lm.py
@@ -1,9 +1,82 @@
-from lm_scorer.models.auto import GPT2LMScorer as LMScorer
+from lm_scorer.models.auto import GPT2LMScorer
+from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
+import torch
+from transformers import AutoTokenizer, GPT2LMHeadModel
+from transformers.tokenization_utils import BatchEncoding
+
+class PaddedGPT2LMScorer(GPT2LMScorer):
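+    """GPT2LMScorer with a dedicated <|pad|> token, so that batches of
+    variable-length candidates can be scored in a single padded forward pass."""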
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _build(self, model_name: str, options: Dict[str, Any]) -> None:
+ super()._build(model_name, options)
+
+ # pylint: disable=attribute-defined-outside-init
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_name, use_fast=True, add_special_tokens=False
+ )
+ # Add the pad token to GPT2 dictionary.
+ # len(tokenizer) = vocab_size + 1
+ self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]})
+ self.tokenizer.pad_token = "<|pad|>"
+
+ self.model = GPT2LMHeadModel.from_pretrained(model_name)
+ # We need to resize the embedding layer because we added the pad token.
+ self.model.resize_token_embeddings(len(self.tokenizer))
+ self.model.eval()
+ if "device" in options:
+ self.model.to(options["device"])
+
+ def _tokens_log_prob_for_batch(
+ self, text: List[str]
+ ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
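+        # Same contract as the parent implementation: return, per sentence,
+        # the per-token log-probabilities, token ids and token strings.
+        # Pad positions are excluded from the results via nopad_mask.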
+ outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = []
+ if len(text) == 0:
+ return outputs
+
+ # TODO: Handle overflowing elements for long sentences
+ text = list(map(self._add_special_tokens, text))
+ encoding: BatchEncoding = self.tokenizer.batch_encode_plus(
+ text, return_tensors="pt", padding=True, truncation=True
+ )
+ with torch.no_grad():
+ ids = encoding["input_ids"].to(self.model.device)
+ attention_mask = encoding["attention_mask"].to(self.model.device)
+ nopad_mask = ids != self.tokenizer.pad_token_id
+ logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0]
+
+ for sent_index in range(len(text)):
+ sent_nopad_mask = nopad_mask[sent_index]
+ # len(tokens) = len(text[sent_index]) + 1
+ sent_tokens = [
+ tok
+ for i, tok in enumerate(encoding.tokens(sent_index))
+ if sent_nopad_mask[i] and i != 0
+ ]
+
+ # sent_ids.shape = [len(text[sent_index]) + 1]
+ sent_ids = ids[sent_index, sent_nopad_mask][1:]
+            # sent_logits.shape = [len(text[sent_index]) + 1, vocab_size]
+ sent_logits = logits[sent_index, sent_nopad_mask][:-1, :]
+ sent_logits[:, self.tokenizer.pad_token_id] = float("-inf")
+ # ids_scores.shape = [seq_len + 1]
+ sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1)
+ # log_prob.shape = [seq_len + 1]
+ sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1)
+
+ sent_log_probs = cast(torch.DoubleTensor, sent_log_probs)
+ sent_ids = cast(torch.LongTensor, sent_ids)
+
+ output = (sent_log_probs, sent_ids, sent_tokens)
+ outputs.append(output)
+
+        return outputs
+
+
class GPT2LM(object):
def __init__(self, model_name_or_path, device='cuda', gpu_batch_size=20):
- self.scorer = LMScorer(model_name_or_path, device=device, batch_size=gpu_batch_size)
+ self.scorer = PaddedGPT2LMScorer(model_name_or_path, device=device, batch_size=gpu_batch_size)
def get_probs(self, list_of_candidates):
scores = self.scorer.sentence_score(list_of_candidates, log=True)
diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index 2412f1c..6280a1f 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -1,11 +1,121 @@
+import hashformers
from hashformers.beamsearch.algorithm import Beamsearch
from hashformers.beamsearch.reranker import Reranker
from hashformers.beamsearch.data_structures import enforce_prob_dict
from hashformers.ensemble.top2_fusion import top2_ensemble
from typing import List, Union, Any
+from dataclasses import dataclass
+import pandas as pd
+from ttp import ttp
+from ekphrasis.classes.segmenter import Segmenter as EkphrasisSegmenter
+import re
+import copy
+import torch
+from math import log10
+from functools import reduce
+import dataclasses
-class WordSegmenter(object):
+@dataclass
+class WordSegmenterOutput:
+ output: List[str]
+ segmenter_rank: Union[pd.DataFrame, None] = None
+ reranker_rank: Union[pd.DataFrame, None] = None
+    ensemble_rank: Union[pd.DataFrame, None] = None
+
+@dataclass
+class HashtagContainer:
+ hashtags: List[List[str]]
+ hashtag_set: List[str]
+ replacement_dict: dict
+
+@dataclass
+class TweetSegmenterOutput:
+ output: List[str]
+ word_segmenter_output: Any
+
+def coerce_segmenter_objects(method):
+    def wrapper(self, inputs, *args, **kwargs):
+        # Accept a single string as well as a list of strings.
+        if isinstance(inputs, str):
+            output = method(self, [inputs], *args, **kwargs)
+        else:
+            output = method(self, inputs, *args, **kwargs)
+        # Pass through outputs that are already segmenter dataclasses;
+        # wrap plain lists in a WordSegmenterOutput.
+        for allowed_type in [
+            hashformers.segmenter.WordSegmenterOutput,
+            hashformers.segmenter.TweetSegmenterOutput
+        ]:
+            if isinstance(output, allowed_type):
+                return output
+        return WordSegmenterOutput(output=output)
+    return wrapper
+
+def prune_segmenter_layers(ws, layer_list=[0]):
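+    """Keep only the Transformer layers listed in `layer_list` in the word
+    segmenter's language model. `layer_list=[0]` corresponds to the "first"
+    setting in the layers column of docs/EVALUATION.md."""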
+ ws.segmenter_model.model.scorer.model = \
+ deleteEncodingLayers(ws.segmenter_model.model.scorer.model, layer_list=layer_list)
+ return ws
+
+def deleteEncodingLayers(model, layer_list=[0]):
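+    """Return a copy of `model` that keeps only the transformer
+    blocks listed in `layer_list`."""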
+ oldModuleList = model.transformer.h
+ newModuleList = torch.nn.ModuleList()
+
+ for index in layer_list:
+ newModuleList.append(oldModuleList[index])
+
+ copyOfModel = copy.deepcopy(model)
+ copyOfModel.transformer.h = newModuleList
+
+ return copyOfModel
+
+class BaseSegmenter(object):
+
+ @coerce_segmenter_objects
+ def predict(self, *args, **kwargs):
+ return self.segment(*args, **kwargs)
+
+class EkphrasisWordSegmenter(EkphrasisSegmenter, BaseSegmenter):
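+    """Wraps the ekphrasis `Segmenter` in the common `BaseSegmenter` interface,
+    so it can be compared directly with the Transformer-based segmenters."""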
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+    def find_segment(self, text, prev='<S>'):
+ if not text:
+ return 0.0, []
+ candidates = [self.combine((log10(self.condProbWord(first, prev)), first), self.find_segment(rem, first))
+ for first, rem in self.splits(text)]
+ return max(candidates)
+
+ def segment_word(self, word) -> str:
+ if word.islower():
+ return " ".join(self.find_segment(word)[1])
+ else:
+ return self.case_split.sub(r' \1', word).lower()
+
+ def segment(self, inputs) -> List[str]:
+ return [ self.segment_word(word) for word in inputs ]
+
+class RegexWordSegmenter(BaseSegmenter):
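+    """Rule-based word segmenter. The default rule splits words before
+    runs of capital letters, e.g. "UnaGenialidad" -> "Una Genialidad"."""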
+
+    def __init__(self, regex_rules=None):
+ if not regex_rules:
+ regex_rules = [r'([A-Z]+)']
+ self.regex_rules = [
+ re.compile(x) for x in regex_rules
+ ]
+
+ def segment_word(self, rule, word):
+ return rule.sub(r' \1', word).strip()
+
+    def segmentation_generator(self, word_list):
+        for rule in self.regex_rules:
+            for word in word_list:
+                yield self.segment_word(rule, word)
+
+ def segment(self, inputs: List[str]):
+ return list(self.segmentation_generator(inputs))
+
+class WordSegmenter(BaseSegmenter):
+ """A general-purpose word segmentation API.
+ """
def __init__(
self,
segmenter_model_name_or_path = "gpt2",
@@ -54,12 +164,11 @@ def segment(
alpha: float = 0.222,
beta: float = 0.111,
use_reranker: bool = True,
- return_ranks: bool = False,
- trim_hashtags: bool = True) -> Any :
- """Segment a list of hashtags.
+ return_ranks: bool = False) -> Any :
+ """Segment a list of strings.
Args:
- word_list (List[str]): A list of hashtag strings.
+ word_list (List[str]): A list of strings.
topk (int, optional):
top-k parameter for the Beamsearch algorithm.
A lower top-k value will speed up the algorithm.
@@ -86,18 +195,11 @@ def segment(
return_ranks (bool, optional):
-            Return not just the segmented hashtags but also the a dictionary of the ranks.
+            Return not just the segmented words but also a dictionary of the ranks.
Defaults to False.
- trim_hashtags (bool, optional):
- Automatically remove "#" characters from the beginning of the hashtags.
- Defaults to True.
Returns:
- Any: A list of segmented hashtags if return_ranks == False. A dictionary of the ranks and the segmented hashtags if return_ranks == True.
+            Any: A list of segmented words if return_ranks == False. A WordSegmenterOutput dataclass with the ranks and the segmented words if return_ranks == True.
"""
- if trim_hashtags:
- word_list = \
- [ x.lstrip("#") for x in word_list ]
-
segmenter_run = self.segmenter_model.run(
word_list,
topk=topk,
@@ -105,7 +207,7 @@ def segment(
)
ensemble = None
- if use_reranker:
+ if use_reranker and self.reranker_model:
reranker_run = self.reranker_model.rerank(segmenter_run)
ensemble = top2_ensemble(
@@ -140,9 +242,113 @@ def segment(
reranker_df = reranker_run.to_dataframe().reset_index(drop=True)
else:
reranker_df = None
- return {
- "segmenter": segmenter_df,
- "reranker": reranker_df,
- "ensemble": ensemble,
- "segmentations": segs
- }
\ No newline at end of file
+ return WordSegmenterOutput(
+ segmenter_rank=segmenter_df,
+ reranker_rank=reranker_df,
+ ensemble_rank=ensemble,
+ output=segs
+ )
+
+class TwitterTextMatcher(object):
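+    """Extracts the hashtag bodies (without the leading "#") from a list of
+    tweets using twitter-text-python (`ttp`)."""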
+
+ def __init__(self):
+ self.parser = ttp.Parser()
+
+ def __call__(self, tweets):
+ return [ self.parser.parse(x).tags for x in tweets ]
+
+class TweetSegmenter(BaseSegmenter):
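+    """Segments the hashtags inside full tweets: hashtags are extracted with
+    `matcher`, segmented with `word_segmenter`, and the segmentations are
+    substituted back into each tweet."""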
+
+ def __init__(self, matcher=None, word_segmenter=None):
+
+ if matcher:
+ self.matcher = matcher
+ else:
+ self.matcher = TwitterTextMatcher()
+
+ if word_segmenter:
+ self.word_segmenter = word_segmenter
+ else:
+ self.word_segmenter = RegexWordSegmenter()
+
+ def extract_hashtags(self, tweets):
+ return self.matcher(tweets)
+
+ def compile_dict(self, hashtags, segmentations, hashtag_token=None, lower=False, separator=" ", hashtag_character="#"):
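+        """Map each hashtag, prefixed with `hashtag_character`, to the string
+        that will replace it in the tweet: its segmentation, optionally
+        prefixed with `hashtag_token` and lowercased."""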
+
+ hashtag_buffer = {
+ k:v for k,v in zip(hashtags, segmentations)
+ }
+
+ replacement_dict = {}
+
+ for key, value in hashtag_buffer.items():
+
+ if not key.startswith(hashtag_character):
+ hashtag_key = hashtag_character + key
+ else:
+ hashtag_key = key
+
+ if hashtag_token:
+ hashtag_value = hashtag_token + separator + value
+ else:
+ hashtag_value = value
+
+ if lower:
+ hashtag_value = hashtag_value.lower()
+
+ replacement_dict.update({hashtag_key : hashtag_value})
+
+ return replacement_dict
+
+ def replace_hashtags(self, tweet, regex_pattern, replacement_dict):
+
+ if not replacement_dict:
+ return tweet
+
+ tweet = regex_pattern.sub(lambda m: replacement_dict[m.group(0)], tweet)
+
+ return tweet
+
+ def segmented_tweet_generator(self, tweets, hashtags, hashtag_set, replacement_dict, flag=0):
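+        """Yield each tweet with its hashtags replaced by their segmentations.
+        Candidate hashtags are matched longest-first so that overlapping
+        hashtags are handled correctly."""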
+
+ hashtag_set_index = { value:idx for idx, value in enumerate(hashtag_set)}
+ replacement_pairs = [ (key, value) for key, value in replacement_dict.items() ]
+
+ for idx, tweet_hashtags in enumerate(hashtags):
+
+ tweet_dict = [ hashtag_set_index[hashtag] for hashtag in tweet_hashtags]
+ tweet_dict = [ replacement_pairs[index] for index in tweet_dict ]
+ tweet_dict = dict(tweet_dict)
+
+ # Treats edge case: overlapping hashtags
+ tweet_map = \
+ map(re.escape, sorted(tweet_dict, key=len, reverse=True))
+
+ regex_pattern = re.compile("|".join(tweet_map), flag)
+ tweet = self.replace_hashtags(tweets[idx], regex_pattern, tweet_dict)
+ yield tweet
+
+    def build_hashtag_container(self, tweets: List[str], preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {}):
+
+ hashtags = self.extract_hashtags(tweets)
+
+        hashtag_set = list(set(reduce(lambda x, y: x + y, hashtags, [])))
+
+ word_segmenter_output = self.word_segmenter.predict(hashtag_set, **segmenter_kwargs)
+
+ segmentations = word_segmenter_output.output
+
+ replacement_dict = self.compile_dict(hashtag_set, segmentations, **preprocessing_kwargs)
+
+ return HashtagContainer(hashtags, hashtag_set, replacement_dict), word_segmenter_output
+
+ def segment(self, tweets: List[str], regex_flag: Any = 0, preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
+
+ hashtag_container, word_segmenter_output = self.build_hashtag_container(tweets, preprocessing_kwargs, segmenter_kwargs)
+ output = list(self.segmented_tweet_generator(tweets, *dataclasses.astuple(hashtag_container), flag=regex_flag))
+
+ return TweetSegmenterOutput(
+ word_segmenter_output = word_segmenter_output,
+ output = output
+ )
\ No newline at end of file
diff --git a/tests/fixtures/word_segmenters.json b/tests/fixtures/word_segmenters.json
new file mode 100644
index 0000000..52d525a
--- /dev/null
+++ b/tests/fixtures/word_segmenters.json
@@ -0,0 +1,46 @@
+[
+ {
+ "class": "WordSegmenter",
+ "init_kwargs": {
+ "segmenter_model_name_or_path": "distilgpt2",
+ "segmenter_gpu_batch_size": 1000,
+ "reranker_model_name_or_path": null
+ },
+ "predict_kwargs": {
+ "topk": 2,
+ "steps": 2
+ },
+ "prune": true
+ },
+ {
+ "class": "WordSegmenter",
+ "init_kwargs": {
+ "segmenter_model_name_or_path": "distilgpt2",
+ "segmenter_gpu_batch_size": 1000,
+ "reranker_model_name_or_path": "bert-base-uncased"
+ },
+ "predict_kwargs": {
+ "topk": 2,
+ "steps": 2
+ },
+ "prune": true
+ },
+ {
+ "class": "EkphrasisWordSegmenter",
+ "init_kwargs": {
+ "corpus": "twitter"
+ },
+ "predict_kwargs": {
+
+ }
+ },
+ {
+ "class": "RegexWordSegmenter",
+ "init_kwargs": {
+
+ },
+ "predict_kwargs": {
+
+ }
+ }
+]
\ No newline at end of file
diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
new file mode 100644
index 0000000..27ac278
--- /dev/null
+++ b/tests/test_segmenter.py
@@ -0,0 +1,123 @@
+import pytest
+
+import dataclasses
+import json
+import os
+from pathlib import Path
+
+import hashformers
+import torch
+from hashformers import prune_segmenter_layers
+from hashformers.segmenter import (RegexWordSegmenter, TweetSegmenter,
+ TwitterTextMatcher)
+
+test_data_dir = Path(__file__).parent.absolute()
+cuda_is_available = torch.cuda.is_available()
+
+with open(os.path.join(test_data_dir,"fixtures/word_segmenters.json"), "r") as f:
+ word_segmenter_params = json.load(f)
+ word_segmenter_test_ids = []
+ for row in word_segmenter_params:
+ class_name = row["class"]
+ segmenter = row["init_kwargs"].get("segmenter_model_name_or_path", "O")
+ reranker = row["init_kwargs"].get("reranker_model_name_or_path", "O")
+ id_string = "{0}_{1}_{2}".format(class_name, segmenter, reranker)
+ word_segmenter_test_ids.append(id_string)
+
+@pytest.fixture(scope="module")
+def tweet_segmenter():
+ return TweetSegmenter()
+
+@pytest.fixture(scope="module", params=word_segmenter_params, ids=word_segmenter_test_ids)
+def word_segmenter(request):
+
+ word_segmenter_class = request.param["class"]
+ word_segmenter_init_kwargs = request.param["init_kwargs"]
+ word_segmenter_predict_kwargs = request.param["predict_kwargs"]
+
+ WordSegmenterClass = getattr(hashformers, word_segmenter_class)
+
+ class WordSegmenterClassWrapper(WordSegmenterClass):
+
+ def __init__(self, **kwargs):
+ return super().__init__(**kwargs)
+
+ def predict(self, *args):
+ return super().predict(*args, **word_segmenter_predict_kwargs)
+
+ WordSegmenterClassWrapper.__name__ = request.param["class"] + "ClassWrapper"
+
+ ws = WordSegmenterClassWrapper(**word_segmenter_init_kwargs)
+
+ if request.param.get("prune", False):
+ ws = prune_segmenter_layers(ws, layer_list=[0])
+
+ return ws
+
+@pytest.mark.skipif(not cuda_is_available, reason="A GPU is not available.")
+def test_word_segmenter_output_format(word_segmenter):
+
+ test_boun_hashtags = [
+ "minecraf",
+ "ourmomentfragrance",
+ "waybackwhen"
+ ]
+
+ predictions = word_segmenter.predict(test_boun_hashtags).output
+
+ predictions_chars = [ x.replace(" ", "") for x in predictions ]
+
+ assert all([x == y for x,y in zip(test_boun_hashtags, predictions_chars)])
+
+def test_matcher():
+ matcher = TwitterTextMatcher()
+ result = matcher(["esto es #UnaGenialidad"])
+ assert result == [["UnaGenialidad"]]
+
+def test_regex_word_segmenter():
+ ws = RegexWordSegmenter()
+ test_case = ["UnaGenialidad"]
+ expected_output=["Una Genialidad"]
+ prediction = ws.predict(test_case)
+    error_message = "{0} != {1}".format(prediction, str(expected_output))
+ assert prediction.output == expected_output, error_message
+
+def test_hashtag_container(tweet_segmenter):
+ original_tweets = [
+ "esto es #UnaGenialidad"
+ ]
+ hashtag_container, word_segmenter_output = tweet_segmenter.build_hashtag_container(original_tweets)
+ assert hashtag_container.hashtags == [['UnaGenialidad']]
+ assert hashtag_container.hashtag_set == ['UnaGenialidad']
+ assert hashtag_container.replacement_dict == {'#UnaGenialidad': 'Una Genialidad'}
+ assert isinstance(word_segmenter_output, hashformers.segmenter.WordSegmenterOutput)
+
+def test_tweet_segmenter_generator(tweet_segmenter):
+ original_tweets = [
+ "esto es #UnaGenialidad"
+ ]
+ hashtag_container, word_segmenter_output = tweet_segmenter.build_hashtag_container(original_tweets)
+ for tweet in tweet_segmenter.segmented_tweet_generator(original_tweets, *dataclasses.astuple(hashtag_container), flag=0):
+ assert tweet == "esto es Una Genialidad"
+
+def test_tweet_segmenter_output_format(tweet_segmenter):
+
+ original_tweets = [
+ "esto es #UnaGenialidad"
+ ]
+
+ expected_tweets = [
+ "esto es Una Genialidad"
+ ]
+
+ output_tweets = tweet_segmenter.predict(original_tweets)
+
+ output_tweets = output_tweets.output
+
+    assert isinstance(output_tweets, list)
+
+ assert len(original_tweets) == len(expected_tweets) == len(output_tweets)
+
+ for idx, tweet in enumerate(original_tweets):
+ assert expected_tweets[idx] == output_tweets[idx], \
+ "{0} != {1}".format(expected_tweets[idx], output_tweets[idx])
\ No newline at end of file