Commit

Merge pull request #37 from ruanchaves/minicons
Hashformers v2.0
ruanchaves authored Jun 3, 2023
2 parents 8649376 + c63ccc4 commit 144b8f8
Showing 10 changed files with 198 additions and 188 deletions.
53 changes: 37 additions & 16 deletions README.md
@@ -5,18 +5,19 @@

Hashtag segmentation is the task of automatically adding spaces between the words in a hashtag.

-[Hashformers](https://github.com/ruanchaves/hashformers) is the current **state-of-the-art** for hashtag segmentation. On average, hashformers is **10% more accurate** than the second best hashtag segmentation library ( [Learn More](https://github.com/ruanchaves/hashformers/blob/master/tutorials/EVALUATION.md) ).
+[Hashformers](https://github.com/ruanchaves/hashformers) is the current **state-of-the-art** for hashtag segmentation, as demonstrated in [this paper accepted at LREC 2022](https://aclanthology.org/2022.lrec-1.782.pdf).

-Hashformers is also **language-agnostic**: you can use it to segment hashtags not just in English, but also in any language with a GPT-2 model on the [Hugging Face Model Hub](https://huggingface.co/models).
+Hashformers is also **language-agnostic**: you can use it to segment hashtags not just with English models, but with any language model available on the [Hugging Face Model Hub](https://huggingface.co/models).

<p align="center">

<h3> <a href="https://ruanchaves-hashtag-segmentation.hf.space/"> ✂️ Segment hashtags on Hugging Face Spaces </a> </h3>

<h3> <a href="https://colab.research.google.com/github/ruanchaves/hashformers/blob/master/hashformers.ipynb"> ✂️ Get started - Google Colab tutorial </a> </h3>

-</p>
+<h3> <a href="https://github.com/ruanchaves/hashformers/wiki"> ✂️ Read the Docs </a> </h3>
+
+</p>


## Basic usage
@@ -26,7 +27,9 @@ from hashformers import TransformerWordSegmenter as WordSegmenter

ws = WordSegmenter(
    segmenter_model_name_or_path="gpt2",
-    reranker_model_name_or_path="bert-base-uncased"
+    segmenter_model_type="incremental",
+    reranker_model_name_or_path="google/flan-t5-base",
+    reranker_model_type="seq2seq"
)

segmentations = ws.segment([
@@ -40,30 +43,44 @@ print(segmentations)
# 'ice cold' ]
```

-## Installation
+It is also possible to use hashformers without a reranker by setting both `reranker_model_name_or_path` and `reranker_model_type` to `None`.
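
For example, a segmenter-only setup looks like this (a minimal sketch; `gpt2` mirrors the basic-usage example above, and the same pattern appears under `incremental` below):

```python
from hashformers import TransformerWordSegmenter as WordSegmenter

# Segmenter-only setup: candidate segmentations are ranked
# by the causal language model alone, with no reranking pass.
ws = WordSegmenter(
    segmenter_model_name_or_path="gpt2",
    segmenter_model_type="incremental",
    reranker_model_name_or_path=None,
    reranker_model_type=None
)
```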

-Hashformers is compatible with Python 3.7.
+## Installation

```
pip install hashformers
```

-It is possible to use **hashformers** without a reranker:
+## What models can I use?
+
+Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose your models for the `WordSegmenter` class.
+
+You can use any model supported by the [minicons](https://github.com/kanishkamisra/minicons) library. Currently `hashformers` supports the following model types as the `segmenter_model_type` or `reranker_model_type`:
+
+### `incremental`
+
+Auto-regressive models like GPT-2 and XLNet, or any model that can be loaded with `AutoModelForCausalLM`. This includes large language models (LLMs) such as Alpaca-LoRA ( `chainyo/alpaca-lora-7b` ) and GPT-J ( `EleutherAI/gpt-j-6b` ).

```python
from hashformers import TransformerWordSegmenter as WordSegmenter
ws = WordSegmenter(
-    segmenter_model_name_or_path="gpt2",
-    reranker_model_name_or_path=None
+    segmenter_model_name_or_path="EleutherAI/gpt-j-6b",
+    segmenter_model_type="incremental",
+    reranker_model_name_or_path=None,
+    reranker_model_type=None
)
```

-If you want to use a BERT model as a reranker, you must install [mxnet](https://pypi.org/project/mxnet/). Here we install **hashformers** with `mxnet-cu110`, which is compatible with Hugging Face Spaces. If installing in another environment, replace it by the [mxnet package](https://pypi.org/project/mxnet/) compatible with your CUDA version.
+### `masked`

-```
-pip install mxnet-cu110
-pip install hashformers
-```
+Masked language models like BERT, or any model that can be loaded with `AutoModelForMaskedLM`.
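
For illustration, a masked reranker pairs with an incremental segmenter as follows (a minimal sketch assuming the v2.0 `WordSegmenter` signature shown above; the distilled model names are example choices, and the same pairing appears in the updated notebook):

```python
from hashformers import TransformerWordSegmenter as WordSegmenter

# The incremental (causal) segmenter proposes candidate segmentations;
# the masked LM rescores them with pseudo-log-likelihoods.
ws = WordSegmenter(
    segmenter_model_name_or_path="distilgpt2",
    segmenter_model_type="incremental",
    reranker_model_name_or_path="distilbert-base-uncased",
    reranker_model_type="masked"
)

print(ws.segment(["#icecold"]))  # should print something like ['ice cold']
```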

+### `seq2seq`
+
+Seq2Seq models like FLAN-T5 ( `google/flan-t5-base` ), or any model that can be loaded with `AutoModelForSeq2SeqLM`.

+Best results are usually achieved by using an `incremental` model as the `segmenter_model_name_or_path` and a `masked` or `seq2seq` model as the `reranker_model_name_or_path`.
+
+A segmenter is always required; a reranker is optional.

## Contributing

@@ -81,6 +98,10 @@ pip install -e .

This is a collection of papers that have utilized the *hashformers* library as a tool in their research.

+### hashformers v1.3
+
+These papers have utilized `hashformers` version 1.3 or below.

* [Zero-shot hashtag segmentation for multilingual sentiment analysis](https://arxiv.org/abs/2112.03213)

* [HashSet -- A Dataset For Hashtag Segmentation (LREC 2022)](https://aclanthology.org/2022.lrec-1.782/)
@@ -104,4 +125,4 @@ This is a collection of papers that have utilized the *hashformers* library as a
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
48 changes: 39 additions & 9 deletions hashformers.ipynb
@@ -29,10 +29,7 @@
"id": "geWaMgWXu1f5"
},
"source": [
"Here we install `mxnet-cu110` and `hashformers`. \n",
"\n",
"\n",
"**Deprecation Notice**: Support for `mxnet-cu110` has been discontinued on Google Colab. If you intend to execute cells involving the reranker, please consider using an alternative environment."
"Here we install `hashformers`. "
]
},
{
@@ -45,7 +42,6 @@
"source": [
"%%capture\n",
"\n",
"!pip install mxnet-cu110 \n",
"!pip install hashformers"
]
},
@@ -81,6 +77,7 @@
"\n",
"ws = WordSegmenter(\n",
" segmenter_model_name_or_path=\"distilgpt2\",\n",
" segmenter_model_type=\"incremental\",\n",
" reranker_model_name_or_path=None\n",
")"
]
@@ -137,9 +134,36 @@
"id": "1F0rTjzQWY6q"
},
"source": [
"You can use **hashformers** to segment hashtags in any language, not just English. Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose any GPT-2 and a BERT models for the WordSegmenter class.\n",
"## What models can I use?\n",
"\n",
"You can use **hashformers** to segment hashtags in any language, not just English. Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose your models for the `WordSegmenter` class.\n",
"\n",
"You can use any model supported by the [minicons](https://github.com/kanishkamisra/minicons) library. Currently `hashformers` supports the following model types as the `segmenter_model_type` or `reranker_model_type`:\n",
"\n",
"### `incremental`\n",
"\n",
"Auto-regressive models like GPT-2 and XLNet, or any model that can be loaded with `AutoModelForCausalLM`. This includes recent large language models (LLMs) such as Alpaca-LoRA ( `chainyo/alpaca-lora-7b` ) and GPT-J ( `EleutherAI/gpt-j-6b` ).\n",
"\n",
"### `masked`\n",
"\n",
"Masked language models like BERT, or any model that can be loaded with `AutoModelForMaskedLM`.\n",
"\n",
"### `seq2seq`\n",
"\n",
"Seq2Seq models like FLAN-T5 ( `google/flan-t5-base` ), or any model that can be loaded with `AutoModelForSeq2SeqLM`.\n",
"\n",
"The GPT-2 model should be informed as `segmenter_model_name_or_path` and the BERT model as `reranker_model_name_or_path`. A segmenter is required, however a reranker is optional. "
"\n",
"Best results are usually achieved by using an `incremental` model as the `segmenter_model_name_or_path` and a `masked` or `seq2seq` model as the `reranker_model_name_or_path`. \n",
"\n",
"A segmenter is required, however a reranker is optional. "
]
},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Here we segment hashtags in Portuguese using a GPT-2 model and a BERT model pretrained on Portuguese data."
+]
+},
{
@@ -156,7 +180,9 @@
"\n",
"portuguese_ws = WordSegmenter(\n",
" segmenter_model_name_or_path=\"pierreguillou/gpt2-small-portuguese\",\n",
" reranker_model_name_or_path=\"neuralmind/bert-base-portuguese-cased\"\n",
" segmenter_model_type=\"incremental\",\n",
" reranker_model_name_or_path=\"neuralmind/bert-base-portuguese-cased\",\n",
" segmenter_model_type=\"masked\"\n",
")"
]
},
@@ -235,7 +261,9 @@
"\n",
"ws = WordSegmenter(\n",
" segmenter_model_name_or_path=\"distilgpt2\",\n",
" segmenter_model_type=\"incremental\",\n",
" reranker_model_name_or_path=\"distilbert-base-uncased\",\n",
" reranker_model_type=\"masked\",\n",
" segmenter_gpu_batch_size=1,\n",
" reranker_gpu_batch_size=2000\n",
")"
@@ -1121,7 +1149,9 @@
"\n",
"ws = TransformerWordSegmenter(\n",
" segmenter_model_name_or_path=\"distilgpt2\",\n",
" reranker_model_name_or_path=None\n",
" segmenter_model_type=\"incremental\",\n",
" reranker_model_name_or_path=None,\n",
" reranker_model_type=None,\n",
")\n",
"\n",
"def generate_experiments(datasets, splits, samples=100):\n",
5 changes: 1 addition & 4 deletions requirements.txt
@@ -1,7 +1,4 @@
-sphinx_markdown_tables
-recommonmark
-mlm-hashformers
-lm-scorer-hashformers
+minicons
twitter-text-python
ekphrasis
pandas
5 changes: 2 additions & 3 deletions setup.py
@@ -2,15 +2,14 @@

setup(
    name='hashformers',
-    version='1.2.8',
+    version='2.0.0',
    author='Ruan Chaves Rodrigues',
    author_email='[email protected]',
    description='Word segmentation with transformers',
    packages=find_packages('src'),
    package_dir={'': 'src'},
    install_requires=[
-        "mlm-hashformers",
-        "lm-scorer-hashformers",
+        "minicons",
        "twitter-text-python",
        "pandas"
    ]
45 changes: 8 additions & 37 deletions src/hashformers/beamsearch/bert_lm.py
@@ -1,10 +1,6 @@
-import mxnet as mx
-import numpy as np
-import pandas as pd
-from mlm.models import get_pretrained
-from mlm.scorers import MLMScorerPT
+from hashformers.beamsearch.minicons_lm import MiniconsLM

-class BertLM(object):
+class BertLM(MiniconsLM):
    """
    Implements a BERT-based language model scorer, to compute sentence probabilities.
    This class uses a transformer-based Masked Language Model (MLM) for scoring.
@@ -18,34 +14,9 @@ class BertLM(object):
"""
def __init__(self, model_name_or_path, gpu_batch_size=1, gpu_id=0):
mx_device = [mx.gpu(gpu_id)]
self.scorer = MLMScorerPT(*get_pretrained(mx_device, model_name_or_path), mx_device)
self.gpu_batch_size = gpu_batch_size

def get_probs(self, list_of_candidates):
"""
Returns probabilities for a list of candidate sentences.
Args:
list_of_candidates (list): A list of sentences for which the probability is to be
calculated. Each sentence should be a string.
Returns:
list: A list of probabilities corresponding to the input sentences. If an exception is encountered
while computing the probability for a sentence (e.g., if the sentence is not a string or
is NaN), the corresponding score in the output list is NaN.
"""
scores = []
try:
scores = self.scorer.score_sentences(list_of_candidates, split_size=self.gpu_batch_size)
scores = [ x * -1 for x in scores ]
return scores
except:
for candidate in list_of_candidates:
try:
score = self.scorer.score_sentences([candidate])[0] if not pd.isna(candidate) else np.nan
score = score * -1
except IndexError:
score = np.nan
scores.append(score)
return scores
super().__init__(
model_name_or_path=model_name_or_path,
device='cuda',
gpu_batch_size=gpu_batch_size,
model_type='MaskedLMScorer'
)
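
The `MiniconsLM` wrapper itself is not part of this excerpt, but the `model_type='MaskedLMScorer'` argument maps onto the scorer classes in minicons. As a rough sketch of the kind of scoring `BertLM` now delegates (assuming minicons' documented `scorer` API; this is not code from the commit):

```python
from minicons import scorer

# Pseudo-log-likelihood scoring with a masked LM, the same kind of
# sentence score BertLM previously obtained from MLMScorerPT.
mlm = scorer.MaskedLMScorer("bert-base-uncased", "cpu")

candidates = ["ice cold", "icec old"]
scores = mlm.sequence_score(candidates, reduction=lambda x: x.sum(0).item())
print(dict(zip(candidates, scores)))  # higher (less negative) = more plausible
```

Presumably the `incremental` and `seq2seq` model types map to minicons' `IncrementalLMScorer` and `Seq2SeqScorer` in the same way.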