diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md
index 1126033..6ed41d3 100644
--- a/docs/EVALUATION.md
+++ b/docs/EVALUATION.md
@@ -1,5 +1,13 @@
 # Evaluation
+We provide a detailed evaluation of the accuracy and speed of the `hashformers` framework in comparison with alternative libraries.
+
+Although n-gram models such as `ekphrasis` are orders of magnitude faster than `hashformers`, they are remarkably unstable across datasets.
+
+Research papers on word segmentation usually combine deep learning with statistical methods to get the best of both worlds. The best speed-accuracy trade-off may therefore lie in building [ranking cascades](https://arxiv.org/abs/2010.06467) (a.k.a. "telescoping"), where `hashformers` serves as a fallback whenever a faster method scores below a confidence threshold, as in the sketch below.
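+
+Below is a minimal sketch of such a cascade. It is illustrative rather than part of the `hashformers` API: the confidence score reuses the `log10` probability returned by `EkphrasisWordSegmenter.find_segment`, the threshold is a hypothetical tuning knob, and a CUDA GPU (the library default) is assumed.
+
+```python
+from hashformers import WordSegmenter
+from hashformers.segmenter import EkphrasisWordSegmenter
+
+fast = EkphrasisWordSegmenter(corpus="twitter")
+slow = WordSegmenter(segmenter_model_name_or_path="distilgpt2",
+                     reranker_model_name_or_path=None)
+
+def cascade(hashtags, threshold=-20.0):
+    results, fallback = {}, []
+    for tag in hashtags:
+        # find_segment returns a (log10 probability, list of words) tuple.
+        score, words = fast.find_segment(tag.lower())
+        if score > threshold:
+            results[tag] = " ".join(words)
+        else:
+            fallback.append(tag)
+    if fallback:
+        # Only the low-confidence hashtags pay the GPU cost.
+        results.update(zip(fallback, slow.predict(fallback, topk=2).output))
+    return [results[tag] for tag in hashtags]
+```
+
+## Accuracy
+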

@@ -34,4 +42,30 @@ A script to reproduce the evaluation of ekphrasis is available on [scripts/evalu
 | | | |
 | average (all) | HashtagMaster | 58.35 |
 | | ekphrasis | 41.65 |
-| |**hashformers**| **68.06**|
\ No newline at end of file
+| |**hashformers**| **68.06**|
+
+## Speed
+
+| model | hashtags/second | accuracy | topk | layers|
+|:--------------|:----------------|----------:|-----:|------:|
+| ekphrasis | 4405.00 | 44.74 | - | - |
+| gpt2-large | 12.04 | 63.86 | 2 | first |
+| distilgpt2 | 29.32 | 64.56 | 2 | first |
+|**distilgpt2** | **15.00** | **80.48** |**2** |**all**|
+| gpt2 | 11.36 | - | 2 | all |
+| gpt2 | 3.48 | - | 20 | all |
+| gpt2 + bert | 1.38 | 83.68 | 20 | all |
+
+In this table we evaluate `hashformers` under different settings on the Dev-BOUN dataset and compare it with `ekphrasis`. As `ekphrasis` relies on n-grams, it is a few orders of magnitude faster than `hashformers`.
+
+All experiments were performed on Google Colab connected to a Tesla T4 GPU with 15GB of RAM. We highlight `distilgpt2` at `topk = 2` with all layers, which provides the best speed-accuracy trade-off (see the usage sketch after this list).
+
+* **model**: The name of the model. We evaluate `ekphrasis` under its default settings and use the reranker only for the SOTA experiment in the bottom row.
+
+* **hashtags/second**: How many hashtags the model segments per second. For all `hashformers` experiments, the `batch_size` parameter was adjusted to take up close to 100% of GPU RAM. A side note: even at 100% memory usage we observe only about 60% GPU utilization, so GPUs with more than 16GB of memory may yield better results.
+
+* **accuracy**: Accuracy on the Dev-BOUN dataset. We don't evaluate the accuracy of `gpt2`, but we know [from the literature](https://arxiv.org/abs/2112.03213) that it is expected to fall between `distilgpt2` (at 80%) and `gpt2 + bert` (the SOTA, at 83%).
+
+* **topk**: The `topk` parameter of the Beamsearch algorithm (passed as the `topk` argument to the `WordSegmenter.segment` method). The `steps` Beamsearch parameter was fixed at its default value of 13 for all `hashformers` experiments, as it doesn't have as significant an impact on performance as `topk`.
+
+* **layers**: How many Transformer layers were used for language modeling: either all layers or just the first (bottom) layer, kept by pruning the model with `prune_segmenter_layers`.
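+
+Below is a minimal usage sketch of the highlighted configuration. It is illustrative: the example hashtag is a placeholder, the batch size should be tuned to your hardware, and a CUDA GPU (the library default) is assumed.
+
+```python
+from hashformers import WordSegmenter, prune_segmenter_layers
+
+ws = WordSegmenter(
+    segmenter_model_name_or_path="distilgpt2",
+    segmenter_gpu_batch_size=1000,    # tune until GPU RAM is close to full
+    reranker_model_name_or_path=None  # the reranker is used only in the SOTA row
+)
+
+# For the "first" rows in the table, prune the language model
+# down to its bottom Transformer layer:
+# ws = prune_segmenter_layers(ws, layer_list=[0])
+
+# steps stays at its default of 13; topk=2 trades a little accuracy for speed.
+segmentations = ws.predict(["weneedanationalpark"], topk=2, steps=13).output
+print(segmentations)  # e.g. ['we need a national park']
+```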
\ No newline at end of file diff --git a/setup.py b/setup.py index b87ef40..4cc8f8d 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,9 @@ package_dir={'': 'src'}, install_requires=[ "mlm-hashformers", - "lm-scorer-hashformers" + "lm-scorer-hashformers", + "twitter-text-python", + "ekphrasis", + "pandas", ] ) diff --git a/src/hashformers/beamsearch/gpt2_lm.py b/src/hashformers/beamsearch/gpt2_lm.py index e517e18..36f2ec7 100644 --- a/src/hashformers/beamsearch/gpt2_lm.py +++ b/src/hashformers/beamsearch/gpt2_lm.py @@ -1,9 +1,82 @@ -from lm_scorer.models.auto import GPT2LMScorer as LMScorer +from lm_scorer.models.auto import GPT2LMScorer +from typing import * # pylint: disable=wildcard-import,unused-wildcard-import +import torch +from transformers import AutoTokenizer, GPT2LMHeadModel +from transformers.tokenization_utils import BatchEncoding + +class PaddedGPT2LMScorer(GPT2LMScorer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _build(self, model_name: str, options: Dict[str, Any]) -> None: + super()._build(model_name, options) + + # pylint: disable=attribute-defined-outside-init + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=True, add_special_tokens=False + ) + # Add the pad token to GPT2 dictionary. + # len(tokenizer) = vocab_size + 1 + self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]}) + self.tokenizer.pad_token = "<|pad|>" + + self.model = GPT2LMHeadModel.from_pretrained(model_name) + # We need to resize the embedding layer because we added the pad token. + self.model.resize_token_embeddings(len(self.tokenizer)) + self.model.eval() + if "device" in options: + self.model.to(options["device"]) + + def _tokens_log_prob_for_batch( + self, text: List[str] + ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]: + outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = [] + if len(text) == 0: + return outputs + + # TODO: Handle overflowing elements for long sentences + text = list(map(self._add_special_tokens, text)) + encoding: BatchEncoding = self.tokenizer.batch_encode_plus( + text, return_tensors="pt", padding=True, truncation=True + ) + with torch.no_grad(): + ids = encoding["input_ids"].to(self.model.device) + attention_mask = encoding["attention_mask"].to(self.model.device) + nopad_mask = ids != self.tokenizer.pad_token_id + logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0] + + for sent_index in range(len(text)): + sent_nopad_mask = nopad_mask[sent_index] + # len(tokens) = len(text[sent_index]) + 1 + sent_tokens = [ + tok + for i, tok in enumerate(encoding.tokens(sent_index)) + if sent_nopad_mask[i] and i != 0 + ] + + # sent_ids.shape = [len(text[sent_index]) + 1] + sent_ids = ids[sent_index, sent_nopad_mask][1:] + # logits.shape = [len(text[sent_index]) + 1, vocab_size] + sent_logits = logits[sent_index, sent_nopad_mask][:-1, :] + sent_logits[:, self.tokenizer.pad_token_id] = float("-inf") + # ids_scores.shape = [seq_len + 1] + sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1) + # log_prob.shape = [seq_len + 1] + sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1) + + sent_log_probs = cast(torch.DoubleTensor, sent_log_probs) + sent_ids = cast(torch.LongTensor, sent_ids) + + output = (sent_log_probs, sent_ids, sent_tokens) + outputs.append(output) + + return outputs class GPT2LM(object): def __init__(self, model_name_or_path, device='cuda', gpu_batch_size=20): - self.scorer = LMScorer(model_name_or_path, 
device=device, batch_size=gpu_batch_size) + self.scorer = PaddedGPT2LMScorer(model_name_or_path, device=device, batch_size=gpu_batch_size) def get_probs(self, list_of_candidates): scores = self.scorer.sentence_score(list_of_candidates, log=True) diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py index 2412f1c..6280a1f 100644 --- a/src/hashformers/segmenter.py +++ b/src/hashformers/segmenter.py @@ -1,11 +1,121 @@ +import hashformers from hashformers.beamsearch.algorithm import Beamsearch from hashformers.beamsearch.reranker import Reranker from hashformers.beamsearch.data_structures import enforce_prob_dict from hashformers.ensemble.top2_fusion import top2_ensemble from typing import List, Union, Any +from dataclasses import dataclass +import pandas as pd +from ttp import ttp +from ekphrasis.classes.segmenter import Segmenter as EkphrasisSegmenter +import re +import copy +import torch +from math import log10 +from functools import reduce +import dataclasses -class WordSegmenter(object): +@dataclass +class WordSegmenterOutput: + output: List[str] + segmenter_rank: Union[pd.DataFrame, None] = None + reranker_rank: Union[pd.DataFrame, None] = None + ensemble_rank: Union[pd.DataFrame, None] = None +@dataclass +class HashtagContainer: + hashtags: List[List[str]] + hashtag_set: List[str] + replacement_dict: dict + +@dataclass +class TweetSegmenterOutput: + output: List[str] + word_segmenter_output: Any + +def coerce_segmenter_objects(method): + def wrapper(inputs, *args, **kwargs): + if isinstance(inputs, str): + output = method([inputs], *args, **kwargs) + else: + output = method(inputs, *args, **kwargs) + for allowed_type in [ + hashformers.segmenter.WordSegmenterOutput, + hashformers.segmenter.TweetSegmenterOutput + ]: + if isinstance(output, allowed_type): + return output + else: + return WordSegmenterOutput(output=output) + return wrapper + +def prune_segmenter_layers(ws, layer_list=[0]): + ws.segmenter_model.model.scorer.model = \ + deleteEncodingLayers(ws.segmenter_model.model.scorer.model, layer_list=layer_list) + return ws + +def deleteEncodingLayers(model, layer_list=[0]): + oldModuleList = model.transformer.h + newModuleList = torch.nn.ModuleList() + + for index in layer_list: + newModuleList.append(oldModuleList[index]) + + copyOfModel = copy.deepcopy(model) + copyOfModel.transformer.h = newModuleList + + return copyOfModel + +class BaseSegmenter(object): + + @coerce_segmenter_objects + def predict(self, *args, **kwargs): + return self.segment(*args, **kwargs) + +class EkphrasisWordSegmenter(EkphrasisSegmenter, BaseSegmenter): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def find_segment(self, text, prev=''): + if not text: + return 0.0, [] + candidates = [self.combine((log10(self.condProbWord(first, prev)), first), self.find_segment(rem, first)) + for first, rem in self.splits(text)] + return max(candidates) + + def segment_word(self, word) -> str: + if word.islower(): + return " ".join(self.find_segment(word)[1]) + else: + return self.case_split.sub(r' \1', word).lower() + + def segment(self, inputs) -> List[str]: + return [ self.segment_word(word) for word in inputs ] + +class RegexWordSegmenter(BaseSegmenter): + + def __init__(self,regex_rules=None): + if not regex_rules: + regex_rules = [r'([A-Z]+)'] + self.regex_rules = [ + re.compile(x) for x in regex_rules + ] + + def segment_word(self, rule, word): + return rule.sub(r' \1', word).strip() + + def segmentation_generator(self, word_list): + for rule in self.regex_rules: + for idx, word 
in enumerate(word_list):
+                yield self.segment_word(rule, word)
+
+    def segment(self, inputs: List[str]):
+        return list(self.segmentation_generator(inputs))
+
+class WordSegmenter(BaseSegmenter):
+    """A general-purpose word segmentation API.
+    """
     def __init__(
         self,
         segmenter_model_name_or_path = "gpt2",
@@ -54,12 +164,11 @@ def segment(
         alpha: float = 0.222,
         beta: float = 0.111,
         use_reranker: bool = True,
-        return_ranks: bool = False,
-        trim_hashtags: bool = True) -> Any :
-        """Segment a list of hashtags.
+        return_ranks: bool = False) -> Any:
+        """Segment a list of strings.
 
         Args:
-            word_list (List[str]): A list of hashtag strings.
+            word_list (List[str]): A list of strings.
             topk (int, optional):
                 top-k parameter for the Beamsearch algorithm.
                 A lower top-k value will speed up the algorithm.
@@ -86,18 +195,11 @@ def segment(
             return_ranks (bool, optional):
                 Return not only the segmented words but also a dictionary of the ranks.
                 Defaults to False.
-            trim_hashtags (bool, optional):
-                Automatically remove "#" characters from the beginning of the hashtags.
-                Defaults to True.
 
         Returns:
-            Any: A list of segmented hashtags if return_ranks == False. A dictionary of the ranks and the segmented hashtags if return_ranks == True.
+            Any: A list of segmented words if return_ranks == False. A WordSegmenterOutput with the ranks and the segmented words if return_ranks == True.
         """
-        if trim_hashtags:
-            word_list = \
-                [ x.lstrip("#") for x in word_list ]
-
         segmenter_run = self.segmenter_model.run(
             word_list,
             topk=topk,
@@ -105,7 +207,7 @@
         )
 
         ensemble = None
-        if use_reranker:
+        if use_reranker and self.reranker_model:
             reranker_run = self.reranker_model.rerank(segmenter_run)
 
             ensemble = top2_ensemble(
@@ -140,9 +242,113 @@
             reranker_df = reranker_run.to_dataframe().reset_index(drop=True)
         else:
             reranker_df = None
-        return {
-            "segmenter": segmenter_df,
-            "reranker": reranker_df,
-            "ensemble": ensemble,
-            "segmentations": segs
-        }
\ No newline at end of file
+        return WordSegmenterOutput(
+            segmenter_rank=segmenter_df,
+            reranker_rank=reranker_df,
+            ensemble_rank=ensemble,
+            output=segs
+        )
+
+class TwitterTextMatcher(object):
+
+    def __init__(self):
+        self.parser = ttp.Parser()
+
+    def __call__(self, tweets):
+        return [ self.parser.parse(x).tags for x in tweets ]
+
+class TweetSegmenter(BaseSegmenter):
+
+    def __init__(self, matcher=None, word_segmenter=None):
+
+        if matcher:
+            self.matcher = matcher
+        else:
+            self.matcher = TwitterTextMatcher()
+
+        if word_segmenter:
+            self.word_segmenter = word_segmenter
+        else:
+            self.word_segmenter = RegexWordSegmenter()
+
+    def extract_hashtags(self, tweets):
+        return self.matcher(tweets)
+
+    def compile_dict(self, hashtags, segmentations, hashtag_token=None, lower=False, separator=" ", hashtag_character="#"):
+
+        hashtag_buffer = {
+            k:v for k,v in zip(hashtags, segmentations)
+        }
+
+        replacement_dict = {}
+
+        for key, value in hashtag_buffer.items():
+
+            if not key.startswith(hashtag_character):
+                hashtag_key = hashtag_character + key
+            else:
+                hashtag_key = key
+
+            if hashtag_token:
+                hashtag_value = hashtag_token + separator + value
+            else:
+                hashtag_value = value
+
+            if lower:
+                hashtag_value = hashtag_value.lower()
+
+            replacement_dict.update({hashtag_key : hashtag_value})
+
+        return replacement_dict
+
+    def replace_hashtags(self, tweet, regex_pattern, replacement_dict):
+
+        if not replacement_dict:
+            return tweet
+
+        tweet = regex_pattern.sub(lambda m: replacement_dict[m.group(0)], tweet)
+
+        return tweet
+
+    def segmented_tweet_generator(self,
tweets, hashtags, hashtag_set, replacement_dict, flag=0):
+
+        hashtag_set_index = { value:idx for idx, value in enumerate(hashtag_set)}
+        replacement_pairs = [ (key, value) for key, value in replacement_dict.items() ]
+
+        for idx, tweet_hashtags in enumerate(hashtags):
+
+            tweet_dict = [ hashtag_set_index[hashtag] for hashtag in tweet_hashtags]
+            tweet_dict = [ replacement_pairs[index] for index in tweet_dict ]
+            tweet_dict = dict(tweet_dict)
+
+            # Edge case: sort by length so overlapping hashtags match longest-first.
+            tweet_map = \
+                map(re.escape, sorted(tweet_dict, key=len, reverse=True))
+
+            regex_pattern = re.compile("|".join(tweet_map), flag)
+            tweet = self.replace_hashtags(tweets[idx], regex_pattern, tweet_dict)
+            yield tweet
+
+    def build_hashtag_container(self, tweets: List[str], preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
+
+        hashtags = self.extract_hashtags(tweets)
+
+        hashtag_set = list(set(reduce(lambda x, y: x + y, hashtags, [])))
+
+        word_segmenter_output = self.word_segmenter.predict(hashtag_set, **segmenter_kwargs)
+
+        segmentations = word_segmenter_output.output
+
+        replacement_dict = self.compile_dict(hashtag_set, segmentations, **preprocessing_kwargs)
+
+        return HashtagContainer(hashtags, hashtag_set, replacement_dict), word_segmenter_output
+
+    def segment(self, tweets: List[str], regex_flag: Any = 0, preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
+
+        hashtag_container, word_segmenter_output = self.build_hashtag_container(tweets, preprocessing_kwargs, segmenter_kwargs)
+        output = list(self.segmented_tweet_generator(tweets, *dataclasses.astuple(hashtag_container), flag=regex_flag))
+
+        return TweetSegmenterOutput(
+            word_segmenter_output = word_segmenter_output,
+            output = output
+        )
\ No newline at end of file
diff --git a/tests/fixtures/word_segmenters.json b/tests/fixtures/word_segmenters.json
new file mode 100644
index 0000000..52d525a
--- /dev/null
+++ b/tests/fixtures/word_segmenters.json
@@ -0,0 +1,46 @@
+[
+    {
+        "class": "WordSegmenter",
+        "init_kwargs": {
+            "segmenter_model_name_or_path": "distilgpt2",
+            "segmenter_gpu_batch_size": 1000,
+            "reranker_model_name_or_path": null
+        },
+        "predict_kwargs": {
+            "topk": 2,
+            "steps": 2
+        },
+        "prune": true
+    },
+    {
+        "class": "WordSegmenter",
+        "init_kwargs": {
+            "segmenter_model_name_or_path": "distilgpt2",
+            "segmenter_gpu_batch_size": 1000,
+            "reranker_model_name_or_path": "bert-base-uncased"
+        },
+        "predict_kwargs": {
+            "topk": 2,
+            "steps": 2
+        },
+        "prune": true
+    },
+    {
+        "class": "EkphrasisWordSegmenter",
+        "init_kwargs": {
+            "corpus": "twitter"
+        },
+        "predict_kwargs": {
+
+        }
+    },
+    {
+        "class": "RegexWordSegmenter",
+        "init_kwargs": {
+
+        },
+        "predict_kwargs": {
+
+        }
+    }
+]
\ No newline at end of file
diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
new file mode 100644
index 0000000..27ac278
--- /dev/null
+++ b/tests/test_segmenter.py
@@ -0,0 +1,123 @@
+import pytest
+
+import dataclasses
+import json
+import os
+from pathlib import Path
+
+import hashformers
+import torch
+from hashformers import prune_segmenter_layers
+from hashformers.segmenter import (RegexWordSegmenter, TweetSegmenter,
+                                   TwitterTextMatcher)
+
+test_data_dir = Path(__file__).parent.absolute()
+cuda_is_available = torch.cuda.is_available()
+
+with open(os.path.join(test_data_dir,"fixtures/word_segmenters.json"), "r") as f:
+    word_segmenter_params = json.load(f)
+    word_segmenter_test_ids = []
+    for row in word_segmenter_params:
+        class_name = row["class"]
+        segmenter = row["init_kwargs"].get("segmenter_model_name_or_path", "O")
+ reranker = row["init_kwargs"].get("reranker_model_name_or_path", "O") + id_string = "{0}_{1}_{2}".format(class_name, segmenter, reranker) + word_segmenter_test_ids.append(id_string) + +@pytest.fixture(scope="module") +def tweet_segmenter(): + return TweetSegmenter() + +@pytest.fixture(scope="module", params=word_segmenter_params, ids=word_segmenter_test_ids) +def word_segmenter(request): + + word_segmenter_class = request.param["class"] + word_segmenter_init_kwargs = request.param["init_kwargs"] + word_segmenter_predict_kwargs = request.param["predict_kwargs"] + + WordSegmenterClass = getattr(hashformers, word_segmenter_class) + + class WordSegmenterClassWrapper(WordSegmenterClass): + + def __init__(self, **kwargs): + return super().__init__(**kwargs) + + def predict(self, *args): + return super().predict(*args, **word_segmenter_predict_kwargs) + + WordSegmenterClassWrapper.__name__ = request.param["class"] + "ClassWrapper" + + ws = WordSegmenterClassWrapper(**word_segmenter_init_kwargs) + + if request.param.get("prune", False): + ws = prune_segmenter_layers(ws, layer_list=[0]) + + return ws + +@pytest.mark.skipif(not cuda_is_available, reason="A GPU is not available.") +def test_word_segmenter_output_format(word_segmenter): + + test_boun_hashtags = [ + "minecraf", + "ourmomentfragrance", + "waybackwhen" + ] + + predictions = word_segmenter.predict(test_boun_hashtags).output + + predictions_chars = [ x.replace(" ", "") for x in predictions ] + + assert all([x == y for x,y in zip(test_boun_hashtags, predictions_chars)]) + +def test_matcher(): + matcher = TwitterTextMatcher() + result = matcher(["esto es #UnaGenialidad"]) + assert result == [["UnaGenialidad"]] + +def test_regex_word_segmenter(): + ws = RegexWordSegmenter() + test_case = ["UnaGenialidad"] + expected_output=["Una Genialidad"] + prediction = ws.predict(test_case) + error_message = "{0} != {1}".format(prediction, str(test_case)) + assert prediction.output == expected_output, error_message + +def test_hashtag_container(tweet_segmenter): + original_tweets = [ + "esto es #UnaGenialidad" + ] + hashtag_container, word_segmenter_output = tweet_segmenter.build_hashtag_container(original_tweets) + assert hashtag_container.hashtags == [['UnaGenialidad']] + assert hashtag_container.hashtag_set == ['UnaGenialidad'] + assert hashtag_container.replacement_dict == {'#UnaGenialidad': 'Una Genialidad'} + assert isinstance(word_segmenter_output, hashformers.segmenter.WordSegmenterOutput) + +def test_tweet_segmenter_generator(tweet_segmenter): + original_tweets = [ + "esto es #UnaGenialidad" + ] + hashtag_container, word_segmenter_output = tweet_segmenter.build_hashtag_container(original_tweets) + for tweet in tweet_segmenter.segmented_tweet_generator(original_tweets, *dataclasses.astuple(hashtag_container), flag=0): + assert tweet == "esto es Una Genialidad" + +def test_tweet_segmenter_output_format(tweet_segmenter): + + original_tweets = [ + "esto es #UnaGenialidad" + ] + + expected_tweets = [ + "esto es Una Genialidad" + ] + + output_tweets = tweet_segmenter.predict(original_tweets) + + output_tweets = output_tweets.output + + assert type(output_tweets) == type([]) + + assert len(original_tweets) == len(expected_tweets) == len(output_tweets) + + for idx, tweet in enumerate(original_tweets): + assert expected_tweets[idx] == output_tweets[idx], \ + "{0} != {1}".format(expected_tweets[idx], output_tweets[idx]) \ No newline at end of file