From b65afe6b6cc57a2726c4cdac4b3d5abbeb8953a7 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sat, 5 Feb 2022 13:10:30 +0100
Subject: [PATCH 01/25] TweetSegmenter for pysentimiento

---
 src/hashformers/segmenter.py | 155 ++++++++++++++++++++++++++++++-----
 1 file changed, 136 insertions(+), 19 deletions(-)

diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index 2412f1c..8d12f95 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -1,11 +1,50 @@
+import hashformers
 from hashformers.beamsearch.algorithm import Beamsearch
 from hashformers.beamsearch.reranker import Reranker
 from hashformers.beamsearch.data_structures import enforce_prob_dict
 from hashformers.ensemble.top2_fusion import top2_ensemble
 from typing import List, Union, Any
+from dataclasses import dataclass
+import pandas as pd
+from ttp import ttp
+import re
 
-class WordSegmenter(object):
+@dataclass
+class WordSegmenterOutput:
+    output: List[str]
+    segmenter_rank: Union[pd.DataFrame, None]
+    reranker_rank: Union[pd.DataFrame, None]
+    ensemble_rank: Union[pd.DataFrame, None]
+
+@dataclass
+class TweetSegmenterOutput:
+    output: List[str]
+    word_segmenter_output: hashformers.WordSegmenterOutput
+    hashtag_dict: dict
+
+class RegexWordSegmenter(object):
+
+    def __init__(self,regex_rules=None):
+        if not regex_rules:
+            regex_rules = [r'([A-Z]+)']
+        self.regex_rules = [
+            re.compile(x) for x in regex_rules
+        ]
+
+    def segment(self, word_list):
+        for rule in self.regex_rules:
+            for idx, word in enumerate(word_list):
+                word_list[idx] = rule.sub(r' \1', word).strip()
+        return WordSegmenterOutput(
+            segmenter_rank=None,
+            reranker_rank=None,
+            ensemble_rank=None,
+            output=word_list
+        )
+
+class WordSegmenter(object):
+    """A general-purpose word segmentation API.
+    """
     def __init__(
         self,
         segmenter_model_name_or_path = "gpt2",
@@ -54,12 +93,11 @@ def segment(
         alpha: float = 0.222,
         beta: float = 0.111,
         use_reranker: bool = True,
-        return_ranks: bool = False,
-        trim_hashtags: bool = True) -> Any :
-        """Segment a list of hashtags.
+        return_ranks: bool = False) -> Any :
+        """Segment a list of strings.
 
         Args:
-            word_list (List[str]): A list of hashtag strings.
+            word_list (List[str]): A list of strings.
             topk (int, optional):
                 top-k parameter for the Beamsearch algorithm.
                 A lower top-k value will speed up the algorithm.
@@ -86,18 +124,11 @@ def segment(
             return_ranks (bool, optional):
                 Return not just the segmented hashtags
                 but also a dictionary of the ranks.
                 Defaults to False.
-            trim_hashtags (bool, optional):
-                Automatically remove "#" characters from the beginning of the hashtags.
-                Defaults to True.
 
         Returns:
-            Any: A list of segmented hashtags if return_ranks == False. A dictionary of the ranks and the segmented hashtags if return_ranks == True.
+            Any: A list of segmented words if return_ranks == False. A dictionary of the ranks and the segmented words if return_ranks == True.
""" - if trim_hashtags: - word_list = \ - [ x.lstrip("#") for x in word_list ] - segmenter_run = self.segmenter_model.run( word_list, topk=topk, @@ -140,9 +171,95 @@ def segment( reranker_df = reranker_run.to_dataframe().reset_index(drop=True) else: reranker_df = None - return { - "segmenter": segmenter_df, - "reranker": reranker_df, - "ensemble": ensemble, - "segmentations": segs - } \ No newline at end of file + return WordSegmenterOutput( + segmenter_rank=segmenter_df, + reranker_rank=reranker_df, + ensemble_rank=ensemble, + output=segs + ) + +class TwitterTextMatcher(object): + + def __init__(self): + self.parser = ttp.Parser() + + def __call__(self, tweets): + return [ self.parser.parse(x).tags for x in tweets ] + +class TweetSegmenter(object): + + def __init__(self, matcher=None, word_segmenter=None): + + if matcher: + self.matcher = matcher + else: + self.matcher = TwitterTextMatcher() + + if word_segmenter: + self.word_segmenter = word_segmenter + else: + self.word_segmenter = RegexWordSegmenter() + + def extract_hashtags(self, tweets): + return self.matcher(tweets) + + def create_regex_pattern(self, replacement_dict, flags=0): + return re.compile("|".join(replacement_dict), flags) + + def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, separator=" ", hashtag_character="#"): + + if not hashtag_dict: + return tweets + + replacement_dict = {} + + for key, value in hashtag_dict.items(): + if not key.startswith(hashtag_character): + hashtag_key = hashtag_character + key + else: + hashtag_key = key + + if hashtag_token: + hashtag_value = hashtag_token + separator + value + else: + hashtag_value = value + + replacement_dict.update(hashtag_key, hashtag_value) + + replacement_dict = \ + map(re.escape, sorted(replacement_dict, key=len, reverse=True)) + + pattern = self.create_regex_pattern(replacement_dict) + + for idx, tweet in enumerate(tweets): + tweets[idx] = pattern.sub(lambda m: replacement_dict[m.group(0)], tweet) + + return tweets + + def segment_tweets(self, tweets, hashtag_dict=None, **kwargs): + + tweets = self.replace_hashtags(tweets, hashtag_dict) + + hashtags = self.extract_hashtags(tweets) + + word_segmenter_output = self.word_segmenter.segment(hashtags, **kwargs) + + segmentations = word_segmenter_output.output + + hashtag_dict.update({ + k:v for k,v in zip(hashtags, segmentations) + }) + + tweets = self.replace_hashtags(tweets, hashtag_dict) + + return TweetSegmenterOutput( + word_segmenter_output = word_segmenter_output, + hashtag_dict = hashtag_dict, + output = tweets + ) + + def predict(self, inputs, **kwargs): + if isinstance(inputs, str): + return self.segment_tweets([inputs], **kwargs)[0] + elif isinstance(inputs, list): + return self.segment_tweets(inputs, **kwargs) \ No newline at end of file From c3d511a178bc8f8f479d7c1b5199186cd85ecb7b Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sat, 5 Feb 2022 16:55:13 +0100 Subject: [PATCH 02/25] batch size bug fix --- src/hashformers/beamsearch/gpt2_lm.py | 77 ++++++++++++++++++++++++++- src/hashformers/segmenter.py | 13 +++-- 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/src/hashformers/beamsearch/gpt2_lm.py b/src/hashformers/beamsearch/gpt2_lm.py index e517e18..36f2ec7 100644 --- a/src/hashformers/beamsearch/gpt2_lm.py +++ b/src/hashformers/beamsearch/gpt2_lm.py @@ -1,9 +1,82 @@ -from lm_scorer.models.auto import GPT2LMScorer as LMScorer +from lm_scorer.models.auto import GPT2LMScorer +from typing import * # pylint: disable=wildcard-import,unused-wildcard-import +import torch +from 
transformers import AutoTokenizer, GPT2LMHeadModel +from transformers.tokenization_utils import BatchEncoding + +class PaddedGPT2LMScorer(GPT2LMScorer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _build(self, model_name: str, options: Dict[str, Any]) -> None: + super()._build(model_name, options) + + # pylint: disable=attribute-defined-outside-init + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=True, add_special_tokens=False + ) + # Add the pad token to GPT2 dictionary. + # len(tokenizer) = vocab_size + 1 + self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]}) + self.tokenizer.pad_token = "<|pad|>" + + self.model = GPT2LMHeadModel.from_pretrained(model_name) + # We need to resize the embedding layer because we added the pad token. + self.model.resize_token_embeddings(len(self.tokenizer)) + self.model.eval() + if "device" in options: + self.model.to(options["device"]) + + def _tokens_log_prob_for_batch( + self, text: List[str] + ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]: + outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = [] + if len(text) == 0: + return outputs + + # TODO: Handle overflowing elements for long sentences + text = list(map(self._add_special_tokens, text)) + encoding: BatchEncoding = self.tokenizer.batch_encode_plus( + text, return_tensors="pt", padding=True, truncation=True + ) + with torch.no_grad(): + ids = encoding["input_ids"].to(self.model.device) + attention_mask = encoding["attention_mask"].to(self.model.device) + nopad_mask = ids != self.tokenizer.pad_token_id + logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0] + + for sent_index in range(len(text)): + sent_nopad_mask = nopad_mask[sent_index] + # len(tokens) = len(text[sent_index]) + 1 + sent_tokens = [ + tok + for i, tok in enumerate(encoding.tokens(sent_index)) + if sent_nopad_mask[i] and i != 0 + ] + + # sent_ids.shape = [len(text[sent_index]) + 1] + sent_ids = ids[sent_index, sent_nopad_mask][1:] + # logits.shape = [len(text[sent_index]) + 1, vocab_size] + sent_logits = logits[sent_index, sent_nopad_mask][:-1, :] + sent_logits[:, self.tokenizer.pad_token_id] = float("-inf") + # ids_scores.shape = [seq_len + 1] + sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1) + # log_prob.shape = [seq_len + 1] + sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1) + + sent_log_probs = cast(torch.DoubleTensor, sent_log_probs) + sent_ids = cast(torch.LongTensor, sent_ids) + + output = (sent_log_probs, sent_ids, sent_tokens) + outputs.append(output) + + return outputs class GPT2LM(object): def __init__(self, model_name_or_path, device='cuda', gpu_batch_size=20): - self.scorer = LMScorer(model_name_or_path, device=device, batch_size=gpu_batch_size) + self.scorer = PaddedGPT2LMScorer(model_name_or_path, device=device, batch_size=gpu_batch_size) def get_probs(self, list_of_candidates): scores = self.scorer.sentence_score(list_of_candidates, log=True) diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py index 8d12f95..3685fc6 100644 --- a/src/hashformers/segmenter.py +++ b/src/hashformers/segmenter.py @@ -206,7 +206,7 @@ def extract_hashtags(self, tweets): def create_regex_pattern(self, replacement_dict, flags=0): return re.compile("|".join(replacement_dict), flags) - def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, separator=" ", hashtag_character="#"): + def replace_hashtags(self, tweets, hashtag_dict, 
hashtag_token=None, lower=False, separator=" ", hashtag_character="#"): if not hashtag_dict: return tweets @@ -214,6 +214,7 @@ def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, separator=" replacement_dict = {} for key, value in hashtag_dict.items(): + if not key.startswith(hashtag_character): hashtag_key = hashtag_character + key else: @@ -224,8 +225,12 @@ def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, separator=" else: hashtag_value = value + if lower: + hashtag_value = hashtag_value.lower() + replacement_dict.update(hashtag_key, hashtag_value) + # Treat edge case: overlapping hashtags replacement_dict = \ map(re.escape, sorted(replacement_dict, key=len, reverse=True)) @@ -236,13 +241,13 @@ def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, separator=" return tweets - def segment_tweets(self, tweets, hashtag_dict=None, **kwargs): + def segment_tweets(self, tweets, hashtag_dict = None, preprocessing_kwargs = {}, segmenter_kwargs = {} ): - tweets = self.replace_hashtags(tweets, hashtag_dict) + tweets = self.replace_hashtags(tweets, hashtag_dict, **preprocessing_kwargs) hashtags = self.extract_hashtags(tweets) - word_segmenter_output = self.word_segmenter.segment(hashtags, **kwargs) + word_segmenter_output = self.word_segmenter.segment(hashtags, **segmenter_kwargs) segmentations = word_segmenter_output.output From 489fca54743e7f0de884cb3ba62c397519116aa3 Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sat, 5 Feb 2022 17:04:25 +0100 Subject: [PATCH 03/25] ttp dependency --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b87ef40..48a38b7 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ package_dir={'': 'src'}, install_requires=[ "mlm-hashformers", - "lm-scorer-hashformers" + "lm-scorer-hashformers", + "twitter-text-python" ] ) From 04a86fef1c8cd8af81330f1b164148b393d9a924 Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sat, 5 Feb 2022 17:06:03 +0100 Subject: [PATCH 04/25] type fix --- src/hashformers/segmenter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py index 3685fc6..ab0f245 100644 --- a/src/hashformers/segmenter.py +++ b/src/hashformers/segmenter.py @@ -19,7 +19,7 @@ class WordSegmenterOutput: @dataclass class TweetSegmenterOutput: output: List[str] - word_segmenter_output: hashformers.WordSegmenterOutput + word_segmenter_output: Any hashtag_dict: dict class RegexWordSegmenter(object): From 5186745da348e0da401df4e7b47ef9b688034cef Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sat, 5 Feb 2022 22:09:13 +0100 Subject: [PATCH 05/25] EVALUATION README --- docs/EVALUATION.md | 30 +++++++++++++++++++++++++++++- src/hashformers/segmenter.py | 2 +- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md index 1126033..5141c4b 100644 --- a/docs/EVALUATION.md +++ b/docs/EVALUATION.md @@ -1,5 +1,7 @@ # Evaluation +## Accuracy +

@@ -34,4 +36,30 @@ A script to reproduce the evaluation of ekphrasis is available on [scripts/evalu
 |               |               |          |
 | average (all) | HashtagMaster | 58.35    |
 |               | ekphrasis     | 41.65    |
-| |**hashformers**| **68.06**|
\ No newline at end of file
+| |**hashformers**| **68.06**|
+
+## Speed
+
+| model         | hashtags/second | accuracy | topk | layers|
+|:--------------|:----------------|---------:|-----:|------:|
+| ekphrasis     | 4405.00         | 44.74    | -    | -     |
+| gpt2-large    | 12.04           | 63.86    | 2    | first |
+| distilgpt2    | 29.32           | 64.56    | 2    | first |
+|**distilgpt2** | **15.00**       | **80.48**|**2** |**all**|
+| gpt2          | 11.36           | -        | 2    | all   |
+| gpt2          | 3.48            | -        | 20   | all   |
+| gpt2 + bert   | 1.38            | 83.68    | 20   | all   |
+
+In this table we evaluate hashformers under different settings on the Dev-BOUN dataset and compare it with ekphrasis. As ekphrasis relies on n-grams, it is a few orders of magnitude faster than hashformers.
+
+All experiments were performed on Google Colab while connected to a Tesla T4 GPU with 15GB of RAM. We highlight `distilgpt2` at `topk = 2`, which provides the best speed-accuracy trade-off.
+
+* **model**: The name of the model. We evaluate ekphrasis under the default settings, and use the reranker only for the SOTA experiment at the bottom row.
+
+* **hashtags/second**: How many hashtags the model can segment per second. All experiments on hashformers had the `batch_size` parameter adjusted to take up close to 100% of GPU RAM. A sidenote: even at 100% of GPU memory usage, we get about 60% of GPU utilization. So you may get better results by adding more memory.
+
+* **accuracy**: Accuracy on the Dev-BOUN dataset.
+
+* **topk**: the `topk` parameter of the Beamsearch algorithm (passed as the `topk` argument to the `WordSegmenter.segment` method). The `steps` Beamsearch parameter was fixed at a default value of 13 for all experiments with hashformers, as it doesn't have as significant an impact on performance as `topk`.
+
+* **layers**: How many Transformer layers were utilized for language modeling: either all layers or just the bottom layer.
\ No newline at end of file
diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index ab0f245..fa6fc60 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -136,7 +136,7 @@ def segment(
         )
 
         ensemble = None
-        if use_reranker:
+        if use_reranker and self.reranker_model:
             reranker_run = self.reranker_model.rerank(segmenter_run)
 
             ensemble = top2_ensemble(

From 02b0fd985231cae59e7fb84a3efced92839173b7 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sat, 5 Feb 2022 22:14:36 +0100
Subject: [PATCH 06/25] EVALUATION README

---
 docs/EVALUATION.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md
index 5141c4b..c33bd98 100644
--- a/docs/EVALUATION.md
+++ b/docs/EVALUATION.md
@@ -58,7 +58,7 @@ All experiments were performed on Google Colab while connected to a Tesla T4 GPU
 
 * **hashtags/second**: How many hashtags the model can segment per second. All experiments on hashformers had the `batch_size` parameter adjusted to take up close to 100% of GPU RAM. A sidenote: even at 100% of GPU memory usage, we get about 60% of GPU utilization. So you may get better results by adding more memory.
 
-* **accuracy**: Accuracy on the Dev-BOUN dataset.
+* **accuracy**: Accuracy on the Dev-BOUN dataset. We don't evaluate the accuracy of `gpt2`, but we know [from the literature](https://arxiv.org/abs/2112.03213) that it is expected to be between `distilgpt2` (at 80%) and `gpt2 + bert` (the SOTA, at 83%).
 
 * **topk**: the `topk` parameter of the Beamsearch algorithm (passed as the `topk` argument to the `WordSegmenter.segment` method). The `steps` Beamsearch parameter was fixed at a default value of 13 for all experiments with hashformers, as it doesn't have as significant an impact on performance as `topk`.
 
 * **layers**: How many Transformer layers were utilized for language modeling: either all layers or just the bottom layer.
\ No newline at end of file

From 4b7e5bb95dbf13780c571058557958b0d94234cf Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sat, 5 Feb 2022 22:39:52 +0100
Subject: [PATCH 07/25] EVALUATION README

---
 docs/EVALUATION.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md
index c33bd98..9f78cba 100644
--- a/docs/EVALUATION.md
+++ b/docs/EVALUATION.md
@@ -1,5 +1,11 @@
 # Evaluation
 
+We provide a detailed evaluation of the accuracy and speed of the `hashformers` framework in comparison with alternative libraries.
+
+Although models based on n-grams such as `ekphrasis` are orders of magnitude faster than `hashformers`, they are remarkably unstable across different datasets.
+
+Research papers on word segmentation usually try to bring the best of both worlds together and combine deep learning with statistical methods. So it is possible that the best speed-accuracy trade-off may lie in building [ranking cascades](https://arxiv.org/abs/2010.06467) (a.k.a. "telescoping") where `hashformers` is used as a fallback for when less time-consuming methods score below a certain confidence threshold.
+
 ## Accuracy

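The configuration highlighted in the speed table can be made concrete with a short sketch. This is illustrative only, assuming the `WordSegmenter` and `prune_segmenter_layers` APIs that later patches in this series introduce; the batch size is a placeholder to be tuned until GPU memory is close to full.

    from hashformers import WordSegmenter, prune_segmenter_layers

    # The highlighted row of the speed table: "distilgpt2" at topk = 2,
    # with no reranker (the reranker is only used for the SOTA row).
    word_segmenter = WordSegmenter(
        segmenter_model_name_or_path="distilgpt2",
        segmenter_gpu_batch_size=1000,  # placeholder; raise until GPU RAM is ~100% used
        reranker_model_name_or_path=None
    )

    # For the "first" setting of the layers column, keep only the bottom
    # transformer layer:
    # word_segmenter = prune_segmenter_layers(word_segmenter, layer_list=[0])

    # topk as in the table; steps was fixed at its default of 13 in all experiments.
    segmentations = word_segmenter.predict(["ourmomentfragrance"], topk=2, steps=13).output
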
From e9da7ad4285b874fadf8faf6c7f66729ccffd412 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 17:35:22 +0100
Subject: [PATCH 08/25] extra segmenters, unit tests, docs

---
 docs/EVALUATION.md                  |   2 +-
 setup.py                            |   4 +-
 src/hashformers/segmenter.py        | 117 +++++++++++++++++++---------
 tests/fixtures/test_boun_sample.txt |  10 +++
 tests/fixtures/word_segmenters.json |  46 +++++++++++
 tests/test_segmenter.py             |  48 ++++++++++++
 6 files changed, 190 insertions(+), 37 deletions(-)
 create mode 100644 tests/fixtures/test_boun_sample.txt
 create mode 100644 tests/fixtures/word_segmenters.json
 create mode 100644 tests/test_segmenter.py

diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md
index 9f78cba..6ed41d3 100644
--- a/docs/EVALUATION.md
+++ b/docs/EVALUATION.md
@@ -62,7 +62,7 @@ All experiments were performed on Google Colab while connected to a Tesla T4 GPU
 
 * **model**: The name of the model. We evaluate ekphrasis under the default settings, and use the reranker only for the SOTA experiment at the bottom row.
 
-* **hashtags/second**: How many hashtags the model can segment per second. All experiments on hashformers had the `batch_size` parameter adjusted to take up close to 100% of GPU RAM. A sidenote: even at 100% of GPU memory usage, we get about 60% of GPU utilization. So you may get better results by adding more memory.
+* **hashtags/second**: How many hashtags the model can segment per second. All experiments on hashformers had the `batch_size` parameter adjusted to take up close to 100% of GPU RAM. A sidenote: even at 100% of GPU memory usage, we get about 60% of GPU utilization. So you may get better results by using GPUs with more memory than 16GB.
 
 * **accuracy**: Accuracy on the Dev-BOUN dataset. We don't evaluate the accuracy of `gpt2`, but we know [from the literature](https://arxiv.org/abs/2112.03213) that it is expected to be between `distilgpt2` (at 80%) and `gpt2 + bert` (the SOTA, at 83%).
diff --git a/setup.py b/setup.py
index 48a38b7..82540e1 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,8 @@
     install_requires=[
         "mlm-hashformers",
         "lm-scorer-hashformers",
-        "twitter-text-python"
+        "twitter-text-python",
+        "ekphrasis",
+        "methodtools"
     ]
 )
diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index fa6fc60..15f5628 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -7,7 +7,12 @@
 from dataclasses import dataclass
 import pandas as pd
 from ttp import ttp
+from ekphrasis.classes.segmenter import Segmenter as EkphrasisSegmenter
 import re
+import typing
+import inspect
+import copy
+import torch
 
 @dataclass
 class WordSegmenterOutput:
@@ -20,9 +25,52 @@ class WordSegmenterOutput:
 @dataclass
 class TweetSegmenterOutput:
     output: List[str]
     word_segmenter_output: Any
-    hashtag_dict: dict
 
-class RegexWordSegmenter(object):
+def prune_segmenter_layers(ws, layer_list=[0]):
+    ws.segmenter_model.model.scorer.model = \
+        deleteEncodingLayers(ws.segmenter_model.model.scorer.model, layer_list=layer_list)
+    return ws
+
+def deleteEncodingLayers(model, layer_list=[0]):
+    oldModuleList = model.transformer.h
+    newModuleList = torch.nn.ModuleList()
+
+    for index in layer_list:
+        newModuleList.append(oldModuleList[index])
+
+    copyOfModel = copy.deepcopy(model)
+    copyOfModel.transformer.h = newModuleList
+
+    return copyOfModel
+
+class BaseSegmenter(object):
+
+    def predict(self, input, *args, **kwargs):
+        first_argument = inspect.getfullargspec(self.segment).args[1]
+        first_argument_type = typing.get_type_hints(self.segment)[first_argument]
+        a = type(first_argument_type) == type(str)
+        b = type(input) == type(str)
+        if a and b:
+            return self.segment(input, *args, **kwargs)
+        elif not a and not b:
+            return self.segment(input, *args, **kwargs)
+        elif a and not b:
+            return [ self.segment(x, *args, **kwargs) for x in input ]
+        elif not a and b:
+            return self.segment([input], *args, **kwargs)[0]
+
+class EkphrasisWordSegmenter(EkphrasisSegmenter, BaseSegmenter):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def find_segment(self, *args, **kwargs):
+        return super().find_segment.__wrapped__(*args, **kwargs)
+
+    def segment(self, word: str) -> str:
+        return super().segment.__wrapped__(word)
+
+class RegexWordSegmenter(BaseSegmenter):
 
     def __init__(self,regex_rules=None):
         if not regex_rules:
@@ -31,10 +79,13 @@ def __init__(self,regex_rules=None):
             re.compile(x) for x in regex_rules
         ]
 
-    def segment(self, word_list):
+    def segment_word(self, rule, word):
+        return rule.sub(r' \1', word).strip()
+
+    def segment(self, word_list: List[str]):
         for rule in self.regex_rules:
             for idx, word in enumerate(word_list):
-                word_list[idx] = rule.sub(r' \1', word).strip()
+                word_list[idx] = self.segment_word(rule, word)
         return WordSegmenterOutput(
             segmenter_rank=None,
             reranker_rank=None,
@@ -42,7 +93,7 @@ def segment(self, word_list):
             output=word_list
         )
 
-class WordSegmenter(object):
+class WordSegmenter(BaseSegmenter):
     """A general-purpose word segmentation API.
""" def __init__( @@ -185,8 +236,8 @@ def __init__(self): def __call__(self, tweets): return [ self.parser.parse(x).tags for x in tweets ] - -class TweetSegmenter(object): + +class TweetSegmenter(BaseSegmenter): def __init__(self, matcher=None, word_segmenter=None): @@ -206,14 +257,15 @@ def extract_hashtags(self, tweets): def create_regex_pattern(self, replacement_dict, flags=0): return re.compile("|".join(replacement_dict), flags) - def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, lower=False, separator=" ", hashtag_character="#"): + def compile_dict(self, hashtags, segmentations, hashtag_token=None, lower=False, separator=" ", hashtag_character="#"): - if not hashtag_dict: - return tweets + hashtag_buffer = { + k:v for k,v in zip(hashtags, segmentations) + } replacement_dict = {} - for key, value in hashtag_dict.items(): + for key, value in hashtag_buffer.items(): if not key.startswith(hashtag_character): hashtag_key = hashtag_character + key @@ -224,7 +276,7 @@ def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, lower=False hashtag_value = hashtag_token + separator + value else: hashtag_value = value - + if lower: hashtag_value = hashtag_value.lower() @@ -234,37 +286,32 @@ def replace_hashtags(self, tweets, hashtag_dict, hashtag_token=None, lower=False replacement_dict = \ map(re.escape, sorted(replacement_dict, key=len, reverse=True)) - pattern = self.create_regex_pattern(replacement_dict) + return replacement_dict - for idx, tweet in enumerate(tweets): - tweets[idx] = pattern.sub(lambda m: replacement_dict[m.group(0)], tweet) - - return tweets + def replace_hashtags(self, tweet, regex_pattern, replacement_dict): - def segment_tweets(self, tweets, hashtag_dict = None, preprocessing_kwargs = {}, segmenter_kwargs = {} ): - - tweets = self.replace_hashtags(tweets, hashtag_dict, **preprocessing_kwargs) + if not replacement_dict: + return tweet + + tweet = regex_pattern.sub(lambda m: replacement_dict[m.group(0)], tweet) + return tweet + + def segment_tweets(self, tweets, regex_flag=0, preprocessing_kwargs = {}, segmenter_kwargs = {} ): + hashtags = self.extract_hashtags(tweets) - word_segmenter_output = self.word_segmenter.segment(hashtags, **segmenter_kwargs) - + word_segmenter_output = self.word_segmenter.predict(hashtags, **segmenter_kwargs) + segmentations = word_segmenter_output.output - - hashtag_dict.update({ - k:v for k,v in zip(hashtags, segmentations) - }) - tweets = self.replace_hashtags(tweets, hashtag_dict) + replacement_dict = self.compile_dict(hashtags, segmentations, **preprocessing_kwargs) + + regex_pattern = self.create_regex_pattern(replacement_dict, flag=regex_flag) + + tweets = [ self.replace_hashtags(tweet, regex_pattern, replacement_dict) for tweet in tweets] return TweetSegmenterOutput( word_segmenter_output = word_segmenter_output, - hashtag_dict = hashtag_dict, output = tweets - ) - - def predict(self, inputs, **kwargs): - if isinstance(inputs, str): - return self.segment_tweets([inputs], **kwargs)[0] - elif isinstance(inputs, list): - return self.segment_tweets(inputs, **kwargs) \ No newline at end of file + ) \ No newline at end of file diff --git a/tests/fixtures/test_boun_sample.txt b/tests/fixtures/test_boun_sample.txt new file mode 100644 index 0000000..f310663 --- /dev/null +++ b/tests/fixtures/test_boun_sample.txt @@ -0,0 +1,10 @@ +conceptiphone +haute +sneakerfiles +forevertired +specialfood +idhangonthat +amillionhoodies +minecraf +ourmomentfragrance +waybackwhen \ No newline at end of file diff --git 
a/tests/fixtures/word_segmenters.json b/tests/fixtures/word_segmenters.json new file mode 100644 index 0000000..52d525a --- /dev/null +++ b/tests/fixtures/word_segmenters.json @@ -0,0 +1,46 @@ +[ + { + "class": "WordSegmenter", + "init_kwargs": { + "segmenter_model_name_or_path": "distilgpt2", + "segmenter_gpu_batch_size": 1000, + "reranker_model_name_or_path": null + }, + "predict_kwargs": { + "topk": 2, + "steps": 2 + }, + "prune": true + }, + { + "class": "WordSegmenter", + "init_kwargs": { + "segmenter_model_name_or_path": "distilgpt2", + "segmenter_gpu_batch_size": 1000, + "reranker_model_name_or_path": "bert-base-uncased" + }, + "predict_kwargs": { + "topk": 2, + "steps": 2 + }, + "prune": true + }, + { + "class": "EkphrasisWordSegmenter", + "init_kwargs": { + "corpus": "twitter" + }, + "predict_kwargs": { + + } + }, + { + "class": "RegexWordSegmenter", + "init_kwargs": { + + }, + "predict_kwargs": { + + } + } +] \ No newline at end of file diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py new file mode 100644 index 0000000..dfe164f --- /dev/null +++ b/tests/test_segmenter.py @@ -0,0 +1,48 @@ +import hashformers +import pytest +import json +from hashformers import prune_segmenter_layers + +import hashformers + +with open("fixtures/test_boun_sample.txt", "r") as f1,\ + open("fixtures/word_Segmenters.json") as f2: + + test_boun_gold = f1.read().strip().split("\n") + test_boun_hashtags = [ x.replace(" ", "") for x in test_boun_gold] + word_segmenter_params = json.load(f2) + + +@pytest.fixture(scope="module", params=word_segmenter_params) +def word_segmenter(request): + + word_segmenter_class = request.param["class"] + word_segmenter_init_kwargs = request.param["init_kwargs"] + word_segmenter_predict_kwargs = request.param["predict_kwargs"] + + WordSegmenterClass = getattr(hashformers, word_segmenter_class) + + class PartialWordSegmenterClass(WordSegmenterClass): + + def __init__(self, **kwargs): + return super().__init__(**kwargs) + + def predict(self, *args): + super().predict(*args, **word_segmenter_predict_kwargs) + + ws = PartialWordSegmenterClass(**word_segmenter_init_kwargs) + + if request.param.get("prune", False): + ws = prune_segmenter_layers(ws, layer_list=[0]) + + return ws + +def test_word_segmenter_output_format(): + + model = word_segmenter() + + predictions = model.predict(test_boun_hashtags) + + predictions_chars = [ x.replace(" ", "") for x in predictions ] + + assert all([x == y for x,y in zip(test_boun_hashtags, predictions_chars)]) \ No newline at end of file From 0bad46730428e662c50458fad302f64890e0d268 Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 17:46:16 +0100 Subject: [PATCH 09/25] test bug fix --- tests/test_segmenter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index dfe164f..f51e94a 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -2,11 +2,13 @@ import pytest import json from hashformers import prune_segmenter_layers - +from pathlib import Path import hashformers +import os +test_data_dir = Path(__file__).parent.absolute() -with open("fixtures/test_boun_sample.txt", "r") as f1,\ - open("fixtures/word_Segmenters.json") as f2: +with open(os.path.join(test_data_dir,"fixtures/test_boun_sample.txt"), "r") as f1,\ + open(os.path.join(test_data_dir,"fixtures/word_segmenters.json"), "r") as f2: test_boun_gold = f1.read().strip().split("\n") test_boun_hashtags = [ x.replace(" ", "") for x in test_boun_gold] From 
68c4022d878ce72d1ea4b623aa7bbba51d418b43 Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 17:50:12 +0100 Subject: [PATCH 10/25] test bug fix --- tests/test_segmenter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index f51e94a..2e35fe0 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -39,11 +39,9 @@ def predict(self, *args): return ws -def test_word_segmenter_output_format(): +def test_word_segmenter_output_format(word_segmenter): - model = word_segmenter() - - predictions = model.predict(test_boun_hashtags) + predictions = word_segmenter.predict(test_boun_hashtags) predictions_chars = [ x.replace(" ", "") for x in predictions ] From 9efd0ae5223b1b4be6709bffdeaf103533959e26 Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 18:06:49 +0100 Subject: [PATCH 11/25] improve error messages --- tests/test_segmenter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index 2e35fe0..03a1436 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -14,7 +14,6 @@ test_boun_hashtags = [ x.replace(" ", "") for x in test_boun_gold] word_segmenter_params = json.load(f2) - @pytest.fixture(scope="module", params=word_segmenter_params) def word_segmenter(request): @@ -24,7 +23,7 @@ def word_segmenter(request): WordSegmenterClass = getattr(hashformers, word_segmenter_class) - class PartialWordSegmenterClass(WordSegmenterClass): + class WordSegmenterClassWrapper(WordSegmenterClass): def __init__(self, **kwargs): return super().__init__(**kwargs) @@ -32,7 +31,9 @@ def __init__(self, **kwargs): def predict(self, *args): super().predict(*args, **word_segmenter_predict_kwargs) - ws = PartialWordSegmenterClass(**word_segmenter_init_kwargs) + WordSegmenterClassWrapper.__name__ = request.param["class"] + "ClassWrapper" + + ws = WordSegmenterClassWrapper(**word_segmenter_init_kwargs) if request.param.get("prune", False): ws = prune_segmenter_layers(ws, layer_list=[0]) From 7c12b301bb3c6f3aec91ebf9d4a2522436450cbe Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 18:36:49 +0100 Subject: [PATCH 12/25] shorten and bug fix tests --- tests/fixtures/test_boun_sample.txt | 7 ------- tests/test_segmenter.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/fixtures/test_boun_sample.txt b/tests/fixtures/test_boun_sample.txt index f310663..57222e6 100644 --- a/tests/fixtures/test_boun_sample.txt +++ b/tests/fixtures/test_boun_sample.txt @@ -1,10 +1,3 @@ -conceptiphone -haute -sneakerfiles -forevertired -specialfood -idhangonthat -amillionhoodies minecraf ourmomentfragrance waybackwhen \ No newline at end of file diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index 03a1436..5a4fd53 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): return super().__init__(**kwargs) def predict(self, *args): - super().predict(*args, **word_segmenter_predict_kwargs) + return super().predict(*args, **word_segmenter_predict_kwargs) WordSegmenterClassWrapper.__name__ = request.param["class"] + "ClassWrapper" From 2b3a32c602525bfb76e1653230f3a83d1ef79b09 Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 18:53:33 +0100 Subject: [PATCH 13/25] predict in BaseSegmenter types --- src/hashformers/segmenter.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git 
a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py index 15f5628..b97ecea 100644 --- a/src/hashformers/segmenter.py +++ b/src/hashformers/segmenter.py @@ -50,14 +50,18 @@ def predict(self, input, *args, **kwargs): first_argument_type = typing.get_type_hints(self.segment)[first_argument] a = type(first_argument_type) == type(str) b = type(input) == type(str) + output = WordSegmenterOutput() if a and b: - return self.segment(input, *args, **kwargs) + output = self.segment(input, *args, **kwargs) elif not a and not b: - return self.segment(input, *args, **kwargs) + output = self.segment(input, *args, **kwargs) elif a and not b: - return [ self.segment(x, *args, **kwargs) for x in input ] + output = [ self.segment(x, *args, **kwargs) for x in input ] elif not a and b: - return self.segment([input], *args, **kwargs)[0] + output = self.segment([input], *args, **kwargs)[0] + if type(output) != type(WordSegmenterOutput): + output = WordSegmenterOutput(output=output) + return output class EkphrasisWordSegmenter(EkphrasisSegmenter, BaseSegmenter): @@ -86,12 +90,7 @@ def segment(self, word_list: List[str]): for rule in self.regex_rules: for idx, word in enumerate(word_list): word_list[idx] = self.segment_word(rule, word) - return WordSegmenterOutput( - segmenter_rank=None, - reranker_rank=None, - ensemble_rank=None, - output=word_list - ) + return word_list class WordSegmenter(BaseSegmenter): """A general-purpose word segmentation API. From f60916cc1712d9e4ee74a21569327421f2d14ab2 Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 18:54:09 +0100 Subject: [PATCH 14/25] test WordSegmenterOutput --- tests/test_segmenter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index 5a4fd53..5bf3515 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -42,7 +42,7 @@ def predict(self, *args): def test_word_segmenter_output_format(word_segmenter): - predictions = word_segmenter.predict(test_boun_hashtags) + predictions = word_segmenter.predict(test_boun_hashtags).output predictions_chars = [ x.replace(" ", "") for x in predictions ] From 3670650a74e3deb7603bf985530692f5c85f126d Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 18:59:24 +0100 Subject: [PATCH 15/25] test WordSegmenterOutput --- src/hashformers/segmenter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py index b97ecea..b00f059 100644 --- a/src/hashformers/segmenter.py +++ b/src/hashformers/segmenter.py @@ -17,9 +17,9 @@ @dataclass class WordSegmenterOutput: output: List[str] - segmenter_rank: Union[pd.DataFrame, None] - reranker_rank: Union[pd.DataFrame, None] - ensemble_rank: Union[pd.DataFrame, None] + segmenter_rank: Union[pd.DataFrame, None] = None + reranker_rank: Union[pd.DataFrame, None] = None + ensemble_rank: Union[pd.DataFrame, None] = None @dataclass class TweetSegmenterOutput: @@ -50,7 +50,7 @@ def predict(self, input, *args, **kwargs): first_argument_type = typing.get_type_hints(self.segment)[first_argument] a = type(first_argument_type) == type(str) b = type(input) == type(str) - output = WordSegmenterOutput() + output = None if a and b: output = self.segment(input, *args, **kwargs) elif not a and not b: From 40f329f1ba89475839c1ac22fcc3f876b81b954c Mon Sep 17 00:00:00 2001 From: ruanchaves Date: Sun, 6 Feb 2022 19:09:39 +0100 Subject: [PATCH 16/25] ekphrasis debugging --- src/hashformers/segmenter.py | 14 +++++++++++--- 
 tests/test_segmenter.py     | 11 +++++++++--
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index b00f059..61e6117 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -13,6 +13,7 @@
 import inspect
 import copy
 import torch
+from math import log10
 
 @dataclass
 class WordSegmenterOutput:
@@ -68,11 +69,18 @@ class EkphrasisWordSegmenter(EkphrasisSegmenter, BaseSegmenter):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def find_segment(self, *args, **kwargs):
-        return super().find_segment.__wrapped__(*args, **kwargs)
+    def find_segment(self, text, prev=''):
+        if not text:
+            return 0.0, []
+        candidates = [self.combine((log10(self.condProbWord(first, prev)), first), self.find_segment(rem, first))
+                      for first, rem in self.splits(text)]
+        return max(candidates)
 
     def segment(self, word: str) -> str:
-        return super().segment.__wrapped__(word)
+        if word.islower():
+            return " ".join(self.find_segment(word)[1])
+        else:
+            return self.case_split.sub(r' \1', word).lower()
 
 class RegexWordSegmenter(BaseSegmenter):
diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
index 5bf3515..6d6965f 100644
--- a/tests/test_segmenter.py
+++ b/tests/test_segmenter.py
@@ -13,8 +13,15 @@
     test_boun_gold = f1.read().strip().split("\n")
     test_boun_hashtags = [ x.replace(" ", "") for x in test_boun_gold]
     word_segmenter_params = json.load(f2)
-
-@pytest.fixture(scope="module", params=word_segmenter_params)
+    word_segmenter_test_ids = []
+    for row in word_segmenter_params:
+        class_name = row["class"]
+        segmenter = row["init_kwargs"].get("segmenter_model_name_or_path", "O")
+        reranker = row["init_kwargs"].get("reranker_model_name_or_path", "O")
+        id_string = "{0}_{1}_{2}".format(class_name, segmenter, reranker)
+        word_segmenter_test_ids.append(id_string)
+
+@pytest.fixture(scope="module", params=word_segmenter_params, ids=word_segmenter_test_ids)
 def word_segmenter(request):
 
     word_segmenter_class = request.param["class"]

From 1622c5d58cae6680201465d86b428d5cfb981d58 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 19:42:05 +0100
Subject: [PATCH 17/25] tweet segmenter tests

---
 tests/test_segmenter.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
index 6d6965f..3d3ca92 100644
--- a/tests/test_segmenter.py
+++ b/tests/test_segmenter.py
@@ -1,10 +1,13 @@
 import hashformers
+from hashformers.segmenter import TweetSegmenter
 import pytest
 import json
 from hashformers import prune_segmenter_layers
 from pathlib import Path
 import hashformers
 import os
+import torch
+
 test_data_dir = Path(__file__).parent.absolute()
 
 with open(os.path.join(test_data_dir,"fixtures/test_boun_sample.txt"), "r") as f1,\
@@ -21,7 +24,12 @@
         id_string = "{0}_{1}_{2}".format(class_name, segmenter, reranker)
         word_segmenter_test_ids.append(id_string)
 
+@pytest.fixture(scope="module")
+def tweet_segmenter():
+    return TweetSegmenter()
+
 @pytest.fixture(scope="module", params=word_segmenter_params, ids=word_segmenter_test_ids)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="A GPU is not available.")
 def word_segmenter(request):
 
     word_segmenter_class = request.param["class"]
@@ -49,8 +57,32 @@ def predict(self, *args):
 
 def test_word_segmenter_output_format(word_segmenter):
 
+    test_boun_hashtags = [
+        "minecraf",
+        "ourmomentfragrance",
+        "waybackwhen"
+    ]
+
     predictions = word_segmenter.predict(test_boun_hashtags).output
     predictions_chars = [ x.replace(" ", "") for x in predictions ]
 
-    assert all([x == y for x,y in zip(test_boun_hashtags, predictions_chars)])
\ No newline at end of file
+    assert all([x == y for x,y in zip(test_boun_hashtags, predictions_chars)])
+
+def test_tweet_segmenter_output_format(tweet_segmenter):
+
+    original_tweets = [
+        "esto es #UnaGenialidad"
+    ]
+
+    expected_tweets = [
+        "esto es Una Genialidad"
+    ]
+
+    output_tweets = tweet_segmenter.predict(original_tweets).output
+
+    assert len(original_tweets) == len(expected_tweets) == len(output_tweets)
+
+    for idx, tweet in enumerate(original_tweets):
+        assert expected_tweets[idx] == output_tweets[idx], \
+            "{0} != {1}".format(expected_tweets[idx], output_tweets[idx])
\ No newline at end of file

From 879d1a00e19178fb613399f0d69f33d8688614b7 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 19:57:00 +0100
Subject: [PATCH 18/25] skipif test

---
 tests/test_segmenter.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
index 3d3ca92..5ff2caf 100644
--- a/tests/test_segmenter.py
+++ b/tests/test_segmenter.py
@@ -9,6 +9,7 @@
 import torch
 
 test_data_dir = Path(__file__).parent.absolute()
+cuda_is_available = torch.cuda.is_available()
 
 with open(os.path.join(test_data_dir,"fixtures/test_boun_sample.txt"), "r") as f1,\
     open(os.path.join(test_data_dir,"fixtures/word_segmenters.json"), "r") as f2:
@@ -29,7 +30,6 @@ def tweet_segmenter():
     return TweetSegmenter()
 
 @pytest.fixture(scope="module", params=word_segmenter_params, ids=word_segmenter_test_ids)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="A GPU is not available.")
 def word_segmenter(request):
 
     word_segmenter_class = request.param["class"]
@@ -55,6 +55,7 @@ def predict(self, *args):
 
     return ws
 
+@pytest.mark.skipif(not cuda_is_available, reason="A GPU is not available.")
 def test_word_segmenter_output_format(word_segmenter):
 
     test_boun_hashtags = [

From caf3f3638f3932295ee9584ba3ff015ce3fecaa2 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 20:04:34 +0100
Subject: [PATCH 19/25] tweetsegment tests

---
 src/hashformers/segmenter.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index 61e6117..509eadd 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -60,8 +60,16 @@ def predict(self, input, *args, **kwargs):
             output = [ self.segment(x, *args, **kwargs) for x in input ]
         elif not a and b:
             output = self.segment([input], *args, **kwargs)[0]
+
+        if type(output) == type(WordSegmenterOutput):
+            return output
+
+        if type(output) == type(TweetSegmenterOutput):
+            return output
+
         if type(output) != type(WordSegmenterOutput):
             output = WordSegmenterOutput(output=output)
+
         return output
 
 class EkphrasisWordSegmenter(EkphrasisSegmenter, BaseSegmenter):
@@ -304,7 +312,7 @@ def replace_hashtags(self, tweet, regex_pattern, replacement_dict):
 
         return tweet
 
-    def segment_tweets(self, tweets, regex_flag=0, preprocessing_kwargs = {}, segmenter_kwargs = {} ):
+    def segment(self, tweets, regex_flag=0, preprocessing_kwargs = {}, segmenter_kwargs = {} ):
 
         hashtags = self.extract_hashtags(tweets)
 

From 43c3d98b069de4f386dfe7fd068cfe2ea0defa0e Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 20:10:37 +0100
Subject: [PATCH 20/25] tweetsegment tests

---
 src/hashformers/segmenter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index 509eadd..2107f4f 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -312,7 +312,7 @@ def replace_hashtags(self, tweet, regex_pattern, replacement_dict):
 
         return tweet
 
-    def segment(self, tweets, regex_flag=0, preprocessing_kwargs = {}, segmenter_kwargs = {} ):
+    def segment(self, tweets: str, regex_flag: Any = 0, preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
 
         hashtags = self.extract_hashtags(tweets)
 

From d32ce5cd9873ac101c163458cca4ce76c3520ec5 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 20:17:51 +0100
Subject: [PATCH 21/25] test matcher

---
 tests/test_segmenter.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
index 5ff2caf..9a18068 100644
--- a/tests/test_segmenter.py
+++ b/tests/test_segmenter.py
@@ -1,5 +1,5 @@
 import hashformers
-from hashformers.segmenter import TweetSegmenter
+from hashformers.segmenter import TweetSegmenter, TwitterTextMatcher
 import pytest
 import json
 from hashformers import prune_segmenter_layers
@@ -70,6 +70,11 @@ def test_word_segmenter_output_format(word_segmenter):
 
     assert all([x == y for x,y in zip(test_boun_hashtags, predictions_chars)])
 
+def test_matcher():
+    matcher = TwitterTextMatcher()
+    result = matcher(["esto es #UnaGenialidad"])
+    assert result == "UnaGenialidad"
+
 def test_tweet_segmenter_output_format(tweet_segmenter):
 
     original_tweets = [

From 9b8daa082c3ee92c0a8624a03f22a062414ad1b4 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 20:44:13 +0100
Subject: [PATCH 22/25] test tweets segmenter

---
 src/hashformers/segmenter.py | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index 2107f4f..4ab1fd4 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -4,7 +4,7 @@
 from hashformers.beamsearch.data_structures import enforce_prob_dict
 from hashformers.ensemble.top2_fusion import top2_ensemble
 from typing import List, Union, Any
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 import pandas as pd
 from ttp import ttp
@@ -14,6 +14,7 @@
 import copy
 import torch
 from math import log10
+from functools import reduce
 
 @dataclass
 class WordSegmenterOutput:
@@ -312,21 +313,38 @@ def replace_hashtags(self, tweet, regex_pattern, replacement_dict):
 
         return tweet
 
+    def segmented_tweet_generator(self, tweets, hashtags, hashtag_set, replacement_dict, flag=0):
+
+        hashtag_set_index = { value:idx for idx, value in enumerate(hashtag_set)}
+        replacement_pairs = [ (key, value) for key, value in replacement_dict.items() ]
+
+        for idx, tweet_hashtags in enumerate(hashtags):
+
+            tweet_dict = [ hashtag_set_index[hashtag] for hashtag in tweet_hashtags]
+            tweet_dict = [ replacement_pairs[index] for index in tweet_dict ]
+            tweet_dict = dict(tweet_dict)
+
+            regex_pattern = self.create_regex_pattern(tweet_dict, flag=flag)
+            tweet = self.replace_hashtags(tweets[idx], regex_pattern, tweet_dict)
+            yield tweet
+
     def segment(self, tweets: str, regex_flag: Any = 0, preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
 
         hashtags = self.extract_hashtags(tweets)
+
+        hashtag_set = list(set(reduce(lambda x, y: x + y, hashtags)))
 
-        word_segmenter_output = self.word_segmenter.predict(hashtags, **segmenter_kwargs)
+        word_segmenter_output = self.word_segmenter.predict(hashtag_set, **segmenter_kwargs)
 
         segmentations = word_segmenter_output.output
 
-        replacement_dict = self.compile_dict(hashtags, segmentations, **preprocessing_kwargs)
+        replacement_dict = self.compile_dict(hashtag_set, segmentations, **preprocessing_kwargs)
 
-        regex_pattern = self.create_regex_pattern(replacement_dict, flag=regex_flag)
+        output = [ tweet for tweet in \
+            self.segmented_tweet_generator(tweets, hashtags, hashtag_set, replacement_dict, flag=regex_flag)]
 
-        tweets = [ self.replace_hashtags(tweet, regex_pattern, replacement_dict) for tweet in tweets]
-
         return TweetSegmenterOutput(
             word_segmenter_output = word_segmenter_output,
-            output = tweets
+            output = output
         )
\ No newline at end of file

From 211192d758f755faeca66e8fc85098e2ddafe400 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Sun, 6 Feb 2022 21:14:45 +0100
Subject: [PATCH 23/25] test tweets segmenter

---
 src/hashformers/segmenter.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index 4ab1fd4..b3db03f 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -270,9 +270,6 @@ def __init__(self, matcher=None, word_segmenter=None):
     def extract_hashtags(self, tweets):
         return self.matcher(tweets)
 
-    def create_regex_pattern(self, replacement_dict, flags=0):
-        return re.compile("|".join(replacement_dict), flags)
-
     def compile_dict(self, hashtags, segmentations, hashtag_token=None, lower=False, separator=" ", hashtag_character="#"):
 
         hashtag_buffer = {
@@ -298,10 +295,6 @@ def compile_dict(self, hashtags, segmentations, hashtag_token=None, lower=False,
 
             replacement_dict.update(hashtag_key, hashtag_value)
 
-        # Treat edge case: overlapping hashtags
-        replacement_dict = \
-            map(re.escape, sorted(replacement_dict, key=len, reverse=True))
-
         return replacement_dict
 
     def replace_hashtags(self, tweet, regex_pattern, replacement_dict):
@@ -323,8 +316,12 @@ def segmented_tweet_generator(self, tweets, hashtags, hashtag_set, replacement_d
             tweet_dict = [ hashtag_set_index[hashtag] for hashtag in tweet_hashtags]
             tweet_dict = [ replacement_pairs[index] for index in tweet_dict ]
             tweet_dict = dict(tweet_dict)
+
+            # Treats edge case: overlapping hashtags
+            tweet_map = \
+                map(re.escape, sorted(tweet_dict, key=len, reverse=True))
 
-            regex_pattern = self.create_regex_pattern(tweet_dict, flag=flag)
+            regex_pattern = re.compile("|".join(tweet_map), flag)
             tweet = self.replace_hashtags(tweets[idx], regex_pattern, tweet_dict)
             yield tweet
 

From 45a0244f0004ef35ff19ba10ad322d58b621216e Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Mon, 7 Feb 2022 00:01:14 +0100
Subject: [PATCH 24/25] all tests passing

---
 setup.py                     |  2 +-
 src/hashformers/segmenter.py | 83 +++++++++++++++++++-----------------
 tests/test_segmenter.py      | 37 ++++++++++++++--
 3 files changed, 80 insertions(+), 42 deletions(-)

diff --git a/setup.py b/setup.py
index 82540e1..4cc8f8d 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,6 @@
         "lm-scorer-hashformers",
         "twitter-text-python",
         "ekphrasis",
-        "methodtools"
+        "pandas",
     ]
 )
diff --git a/src/hashformers/segmenter.py b/src/hashformers/segmenter.py
index b3db03f..6280a1f 100644
--- a/src/hashformers/segmenter.py
+++ b/src/hashformers/segmenter.py
@@ -4,17 +4,16 @@
 from hashformers.beamsearch.data_structures import enforce_prob_dict
 from hashformers.ensemble.top2_fusion import top2_ensemble
 from typing import List, Union, Any
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 import pandas as pd
 from ttp import ttp
 from ekphrasis.classes.segmenter import Segmenter as EkphrasisSegmenter
 import re
-import typing
-import inspect
 import copy
 import torch
 from math import log10
 from functools import reduce
+import dataclasses
 
 @dataclass
 class WordSegmenterOutput:
@@ -23,11 +22,33 @@ class WordSegmenterOutput:
     reranker_rank: Union[pd.DataFrame, None] = None
     ensemble_rank: Union[pd.DataFrame, None] = None
 
+@dataclass
+class HashtagContainer:
+    hashtags: List[List[str]]
+    hashtag_set: List[str]
+    replacement_dict: dict
+
 @dataclass
 class TweetSegmenterOutput:
     output: List[str]
     word_segmenter_output: Any
 
+def coerce_segmenter_objects(method):
+    def wrapper(inputs, *args, **kwargs):
+        if isinstance(inputs, str):
+            output = method([inputs], *args, **kwargs)
+        else:
+            output = method(inputs, *args, **kwargs)
+        for allowed_type in [
+            hashformers.segmenter.WordSegmenterOutput,
+            hashformers.segmenter.TweetSegmenterOutput
+        ]:
+            if isinstance(output, allowed_type):
+                return output
+        else:
+            return WordSegmenterOutput(output=output)
+    return wrapper
+
 def prune_segmenter_layers(ws, layer_list=[0]):
     ws.segmenter_model.model.scorer.model = \
         deleteEncodingLayers(ws.segmenter_model.model.scorer.model, layer_list=layer_list)
@@ -47,31 +68,9 @@ def deleteEncodingLayers(model, layer_list=[0]):
 
 class BaseSegmenter(object):
 
-    def predict(self, input, *args, **kwargs):
-        first_argument = inspect.getfullargspec(self.segment).args[1]
-        first_argument_type = typing.get_type_hints(self.segment)[first_argument]
-        a = type(first_argument_type) == type(str)
-        b = type(input) == type(str)
-        output = None
-        if a and b:
-            output = self.segment(input, *args, **kwargs)
-        elif not a and not b:
-            output = self.segment(input, *args, **kwargs)
-        elif a and not b:
-            output = [ self.segment(x, *args, **kwargs) for x in input ]
-        elif not a and b:
-            output = self.segment([input], *args, **kwargs)[0]
-
-        if type(output) == type(WordSegmenterOutput):
-            return output
-
-        if type(output) == type(TweetSegmenterOutput):
-            return output
-
-        if type(output) != type(WordSegmenterOutput):
-            output = WordSegmenterOutput(output=output)
-
-        return output
+    @coerce_segmenter_objects
+    def predict(self, *args, **kwargs):
+        return self.segment(*args, **kwargs)
 
 class EkphrasisWordSegmenter(EkphrasisSegmenter, BaseSegmenter):
 
@@ -85,12 +84,15 @@ def find_segment(self, text, prev=''):
                       for first, rem in self.splits(text)]
         return max(candidates)
 
-    def segment(self, word: str) -> str:
+    def segment_word(self, word) -> str:
         if word.islower():
             return " ".join(self.find_segment(word)[1])
         else:
             return self.case_split.sub(r' \1', word).lower()
 
+    def segment(self, inputs) -> List[str]:
+        return [ self.segment_word(word) for word in inputs ]
+
 class RegexWordSegmenter(BaseSegmenter):
@@ -103,11 +105,13 @@ def __init__(self,regex_rules=None):
 
     def segment_word(self, rule, word):
         return rule.sub(r' \1', word).strip()
 
-    def segment(self, word_list: List[str]):
+    def segmentation_generator(self, word_list):
         for rule in self.regex_rules:
             for idx, word in enumerate(word_list):
-                word_list[idx] = self.segment_word(rule, word)
-        return word_list
+                yield self.segment_word(rule, word)
+
+    def segment(self, inputs: List[str]):
+        return list(self.segmentation_generator(inputs))
 
 class WordSegmenter(BaseSegmenter):
     """A general-purpose word segmentation API.
@@ -293,7 +297,7 @@ def compile_dict(self, hashtags, segmentations, hashtag_token=None, lower=False,
             if lower:
                 hashtag_value = hashtag_value.lower()
 
-            replacement_dict.update(hashtag_key, hashtag_value)
+            replacement_dict.update({hashtag_key : hashtag_value})
 
         return replacement_dict
 
@@ -325,9 +329,8 @@ def segmented_tweet_generator(self, tweets, hashtags, hashtag_set, replacement_d
             tweet = self.replace_hashtags(tweets[idx], regex_pattern, tweet_dict)
             yield tweet
 
-    def segment(self, tweets: str, regex_flag: Any = 0, preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
-
+    def build_hashtag_container(self, tweets: str, preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
+
         hashtags = self.extract_hashtags(tweets)
 
         hashtag_set = list(set(reduce(lambda x, y: x + y, hashtags)))
@@ -338,8 +341,12 @@ def segment(self, tweets: str, regex_flag: Any = 0, preprocessing_kwargs: dict =
 
         replacement_dict = self.compile_dict(hashtag_set, segmentations, **preprocessing_kwargs)
 
-        output = [ tweet for tweet in \
-            self.segmented_tweet_generator(tweets, hashtags, hashtag_set, replacement_dict, flag=regex_flag)]
+        return HashtagContainer(hashtags, hashtag_set, replacement_dict), word_segmenter_output
+
+    def segment(self, tweets: List[str], regex_flag: Any = 0, preprocessing_kwargs: dict = {}, segmenter_kwargs: dict = {} ):
+
+        hashtag_container, word_segmenter_output = self.build_hashtag_container(tweets, preprocessing_kwargs, segmenter_kwargs)
+        output = list(self.segmented_tweet_generator(tweets, *dataclasses.astuple(hashtag_container), flag=regex_flag))
 
         return TweetSegmenterOutput(
             word_segmenter_output = word_segmenter_output,
diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
index 9a18068..5984076 100644
--- a/tests/test_segmenter.py
+++ b/tests/test_segmenter.py
@@ -1,5 +1,6 @@
+import dataclasses
 import hashformers
-from hashformers.segmenter import TweetSegmenter, TwitterTextMatcher
+from hashformers.segmenter import RegexWordSegmenter, TweetSegmenter, TwitterTextMatcher, WordSegmenterOutput
 import pytest
 import json
 from hashformers import prune_segmenter_layers
@@ -73,7 +74,33 @@ def test_matcher():
     matcher = TwitterTextMatcher()
     result = matcher(["esto es #UnaGenialidad"])
-    assert result == "UnaGenialidad"
+    assert result == [["UnaGenialidad"]]
+
+def test_regex_word_segmenter():
+    ws = RegexWordSegmenter()
+    test_case = ["UnaGenialidad"]
+    expected_output=["Una Genialidad"]
+    prediction = ws.predict(test_case)
+    error_message = "{0} != {1}".format(prediction, str(test_case))
+    assert prediction.output == expected_output, error_message
+
+def test_hashtag_container(tweet_segmenter):
+    original_tweets = [
+        "esto es #UnaGenialidad"
+    ]
+    hashtag_container, word_segmenter_output = tweet_segmenter.build_hashtag_container(original_tweets)
+    assert hashtag_container.hashtags == [['UnaGenialidad']]
+    assert hashtag_container.hashtag_set == ['UnaGenialidad']
+    assert hashtag_container.replacement_dict == {'#UnaGenialidad': 'Una Genialidad'}
+    assert isinstance(word_segmenter_output, hashformers.segmenter.WordSegmenterOutput)
+
+def test_tweet_segmenter_generator(tweet_segmenter):
+    original_tweets = [
+        "esto es #UnaGenialidad"
+    ]
+    hashtag_container, word_segmenter_output = tweet_segmenter.build_hashtag_container(original_tweets)
+    for tweet in tweet_segmenter.segmented_tweet_generator(original_tweets, *dataclasses.astuple(hashtag_container), flag=0):
+        assert tweet == "esto es Una Genialidad"
 
 def test_tweet_segmenter_output_format(tweet_segmenter):
 
@@ -85,7 +112,11 @@ def test_tweet_segmenter_output_format(tweet_segmenter):
         "esto es Una Genialidad"
     ]
 
-    output_tweets = tweet_segmenter.predict(original_tweets).output
+    output_tweets = tweet_segmenter.predict(original_tweets)
+
+    output_tweets = output_tweets.output
+
+    assert type(output_tweets) == type([])
 
     assert len(original_tweets) == len(expected_tweets) == len(output_tweets)
 
     for idx, tweet in enumerate(original_tweets):

From f5a622879e4f0ee4211977411567a8299b21e8f2 Mon Sep 17 00:00:00 2001
From: ruanchaves
Date: Mon, 7 Feb 2022 00:05:05 +0100
Subject: [PATCH 25/25] all tests passing

---
 tests/fixtures/test_boun_sample.txt |  3 ---
 tests/test_segmenter.py             | 20 +++++++++-----------
 2 files changed, 9 insertions(+), 14 deletions(-)
 delete mode 100644 tests/fixtures/test_boun_sample.txt

diff --git a/tests/fixtures/test_boun_sample.txt b/tests/fixtures/test_boun_sample.txt
deleted file mode 100644
index 57222e6..0000000
--- a/tests/fixtures/test_boun_sample.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-minecraf
-ourmomentfragrance
-waybackwhen
\ No newline at end of file
diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
index 5984076..27ac278 100644
--- a/tests/test_segmenter.py
+++ b/tests/test_segmenter.py
@@ -1,23 +1,21 @@
-import dataclasses
-import hashformers
-from hashformers.segmenter import RegexWordSegmenter, TweetSegmenter, TwitterTextMatcher, WordSegmenterOutput
 import pytest
+
+import dataclasses
 import json
-from hashformers import prune_segmenter_layers
+import os
 from pathlib import Path
+
 import hashformers
-import os
 import torch
+from hashformers import prune_segmenter_layers
+from hashformers.segmenter import (RegexWordSegmenter, TweetSegmenter,
+                                   TwitterTextMatcher)
 
 test_data_dir = Path(__file__).parent.absolute()
 cuda_is_available = torch.cuda.is_available()
 
-with open(os.path.join(test_data_dir,"fixtures/test_boun_sample.txt"), "r") as f1,\
-    open(os.path.join(test_data_dir,"fixtures/word_segmenters.json"), "r") as f2:
-
-    test_boun_gold = f1.read().strip().split("\n")
-    test_boun_hashtags = [ x.replace(" ", "") for x in test_boun_gold]
-    word_segmenter_params = json.load(f2)
+with open(os.path.join(test_data_dir,"fixtures/word_segmenters.json"), "r") as f:
+    word_segmenter_params = json.load(f)
     word_segmenter_test_ids = []
     for row in word_segmenter_params:
         class_name = row["class"]
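
With the series applied, end-to-end usage looks roughly like the sketch below. It is a minimal example rather than an official recipe, and it mirrors the expectations encoded in tests/test_segmenter.py: the default TweetSegmenter matcher is TwitterTextMatcher, the RegexWordSegmenter fallback runs without a GPU, and predict returns a dataclass whose output field holds the segmented tweets.

    from hashformers.segmenter import TweetSegmenter, RegexWordSegmenter

    # RegexWordSegmenter splits hashtags on runs of capital letters,
    # so no language model needs to be downloaded for this example.
    tweet_segmenter = TweetSegmenter(word_segmenter=RegexWordSegmenter())

    result = tweet_segmenter.predict(["esto es #UnaGenialidad"])
    print(result.output)  # ['esto es Una Genialidad'], as asserted in the tests

Any GPU-backed WordSegmenter from the earlier patches can be passed as the word_segmenter argument instead, trading the speed of the regex rules for the accuracy figures reported in docs/EVALUATION.md.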