diff --git a/.github/workflows/lint_and_tests.yml b/.github/workflows/lint_and_tests.yml new file mode 100644 index 00000000..2f43b652 --- /dev/null +++ b/.github/workflows/lint_and_tests.yml @@ -0,0 +1,32 @@ +name: lint_and_tests + +on: [push, pull_request] + +jobs: + build: + strategy: + max-parallel: 1 + matrix: + platform: [ubuntu-latest] + python-version: [3.8] + + runs-on: ${{ matrix.platform }} + + steps: + - uses: actions/checkout@v2 + + - name: Install dependencies + run: | + python --version + python -m pip install --upgrade 'pip>=23.2.1' + python -m pip show pip + python -m pip install -e '.[dev]' + + - name: isort + run: cd laser_encoders && isort --check --diff . + + - name: black + run: cd laser_encoders && black --check --diff . + + - name: pytest + run: pytest laser_encoders diff --git a/.gitignore b/.gitignore index 16290d9e..95098827 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,6 @@ tasks/xnli/XNLI-1.0* tasks/xnli/multinli_1.0* .??*swp .idea +__pycache__ +nllb +dist diff --git a/README.md b/README.md index 96d96ff0..526f9632 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ LASER is a library to calculate and use multilingual sentence embeddings. **NEWS** +* 2023/11/16 Released [**laser_encoders**](laser_encoders), a pip-installable package supporting LASER-2 and LASER-3 models * 2023/06/26 [**xSIM++**](https://arxiv.org/abs/2306.12907) evaluation pipeline and data [**released**](tasks/xsimplusplus/README.md) * 2022/07/06 Updated LASER models with support for over 200 languages are [**now available**](nllb/README.md) * 2022/07/06 Multilingual similarity search (**xsim**) evaluation pipeline [**released**](tasks/xsim/README.md) @@ -26,7 +27,27 @@ a language family which is covered by other languages. A detailed description of how the multilingual sentence embeddings are trained can be found [here](https://arxiv.org/abs/2205.12654), together with an experimental evaluation. 
-## Dependencies +## The core sentence embedding package: `laser_encoders` +We provide a package `laser_encoders` with minimal dependencies. +It supports LASER-2 (a single encoder for the languages listed [below](#supported-languages)) +and LASER-3 (147 language-specific encoders described [here](nllb/README.md)). + +The package can be installed simply with `pip install laser_encoders` and used as below: + +```python +from laser_encoders import LaserEncoderPipeline +encoder = LaserEncoderPipeline(lang="eng_Latn") +embeddings = encoder.encode_sentences(["Hi!", "This is a sentence encoder."]) +print(embeddings.shape) # (2, 1024) +``` + +The laser_encoders [readme file](laser_encoders) provides more examples of its installation and usage. + +## The full LASER kit +Apart from the `laser_encoders`, we provide support for LASER-1 (the original multilingual encoder) +and for various LASER applications listed below. + +### Dependencies * Python >= 3.7 * [PyTorch 1.0](http://pytorch.org/) * [NumPy](http://www.numpy.org/), tested with 1.15.4 @@ -42,7 +63,8 @@ be found [here](https://arxiv.org/abs/2205.12654), together with an experimental * [pandas](https://pypi.org/project/pandas), data analysis toolkit (`pip install pandas`) * [Sentencepiece](https://github.com/google/sentencepiece), subword tokenization (installed automatically) -## Installation +### Installation +* install the `laser_encoders` package by e.g. `pip install -e .` for installing it in the editable mode * set the environment variable 'LASER' to the root of the installation, e.g. `export LASER="${HOME}/projects/laser"` * download encoders from Amazon s3 by e.g. 
`bash ./nllb/download_models.sh` diff --git a/install_external_tools.sh b/install_external_tools.sh index 9fba8417..6aee045f 100755 --- a/install_external_tools.sh +++ b/install_external_tools.sh @@ -181,6 +181,10 @@ InstallMecab () { # ################################################################### +echo "Installing the laser_encoders package in editable mode" + +pip install -e . + echo "Installing external tools" InstallMosesTools diff --git a/laser_encoders/README.md b/laser_encoders/README.md new file mode 100644 index 00000000..a20ed62c --- /dev/null +++ b/laser_encoders/README.md @@ -0,0 +1,149 @@ +# LASER encoders + +LASER Language-Agnostic SEntence Representations Toolkit + +laser_encoders is the official Python package for the Facebook [LASER](https://github.com/facebookresearch/LASER) library. It provides a simple and convenient way to use LASER embeddings in Python. It allows you to calculate multilingual sentence embeddings using the LASER toolkit. These embeddings can be utilized for various natural language processing tasks, including document classification, bitext filtering, and mining. + +## Dependencies + +- Python `>= 3.8` +- [PyTorch `>= 1.10.0`](http://pytorch.org/) +- sacremoses `>=0.1.0` +- sentencepiece `>=0.1.99` +- numpy `>=1.21.3` +- fairseq `>=0.12.2` + +You can find a full list of requirements [here](https://github.com/facebookresearch/LASER/blob/main/pyproject.toml) + +## Installation + +You can install `laser_encoders` package from PyPI: + +```sh +pip install laser_encoders +``` + +Alternatively, you can install it from a local clone of this repository, in editable mode: +```sh +pip install -e . +``` + +## Usage + +Here's a simple example on how to obtain embeddings for sentences using the `LaserEncoderPipeline`: + +>**Note:** By default, the models will be downloaded to the `~/.cache/laser_encoders` directory. 
To specify a different download location, you can provide the argument `model_dir=path/to/model/directory` + +```py +from laser_encoders import LaserEncoderPipeline + +# Initialize the LASER encoder pipeline +encoder = LaserEncoderPipeline(lang="igbo") + +# Encode sentences into embeddings +embeddings = encoder.encode_sentences(["nnọọ, kedu ka ị mere"]) +# If you want the output embeddings to be L2-normalized, set normalize_embeddings to True +normalized_embeddings = encoder.encode_sentences(["nnọọ, kedu ka ị mere"], normalize_embeddings=True) + +``` + +If you prefer more control over the tokenization and encoding process, you can initialize the tokenizer and encoder separately: +```py +from laser_encoders import initialize_encoder, initialize_tokenizer + +# Initialize the LASER tokenizer +tokenizer = initialize_tokenizer(lang="igbo") +tokenized_sentence = tokenizer.tokenize("nnọọ, kedu ka ị mere") + +# Initialize the LASER sentence encoder +encoder = initialize_encoder(lang="igbo") + +# Encode tokenized sentences into embeddings +embeddings = encoder.encode_sentences([tokenized_sentence]) +``` +>By default, the `spm` flag is set to `True` when initializing the encoder, ensuring the accompanying spm model is downloaded. + +**Supported Languages:** You can specify any language from the [FLORES200](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200) dataset. This includes both languages identified by their full codes (like "ibo_Latn") and simpler alternatives (like "igbo"). + +## Downloading the pre-trained models + +If you prefer to download the models individually, you can use the following command: + +```sh +python -m laser_encoders.download_models --lang=your_preferred_language # e.g., --lang="igbo" +``` + +By default, the downloaded models will be stored in the `~/.cache/laser_encoders` directory. 
To specify a different download location, utilize the following command: + +```sh +python -m laser_encoders.download_models --model-dir=path/to/model/directory +``` + +> For a comprehensive list of available arguments, you can use the `--help` command with the download_models script. + +Once you have successfully downloaded the models, you can utilize the `SentenceEncoder` to tokenize and encode your text in your desired language. Here's an example of how you can achieve this: + +```py +from laser_encoders.models import SentenceEncoder +from pathlib import Path + +encoder = SentenceEncoder(model_path=path/to/downloaded/model, spm_model=Path(path/to/spm_model), spm_vocab=path/to/cvocab) +embeddings = encoder("This is a test sentence.") +``` +If you want to perform tokenization separately, you can do this below: +```py +from laser_encoders.laser_tokenizer import LaserTokenizer + +tokenizer = LaserTokenizer(spm_model=Path(path/to/spm_model)) + +tokenized_sentence = tokenizer.tokenize("This is a test sentence.") + +encoder = SentenceEncoder(model_path=path/to/downloaded/model, spm_vocab=path/to/cvocab) +embeddings = encoder.encode_sentences([tokenized_sentence]) +``` + +For tokenizing a file instead of a string, you can use the following: + +```py +tokenized_sentence = tokenizer.tokenize_file(inp_fname=Path(path/to/input_file.txt), out_fname=Path(path/to/output_file.txt)) +``` + +### Now you can use these embeddings for downstream tasks + +For more advanced usage and options, please refer to the official LASER repository documentation. + +## LASER Versions and Associated Packages + +For users familiar with the earlier version of LASER, you might have encountered the [`laserembeddings`](https://pypi.org/project/laserembeddings/) package. This package primarily dealt with LASER-1 model embeddings. + +For the latest LASER-2,3 models, use the newly introduced `laser_encoders` package, which offers better performance and support for a wider range of languages. 
+ + +## Contributing + +We welcome contributions from the developer community to enhance and improve laser_encoders. If you'd like to contribute, you can: + +1. Submit bug reports or feature requests through GitHub issues. +1. Fork the repository, make changes, and submit pull requests for review. + +Please follow our [Contribution Guidelines](https://github.com/facebookresearch/LASER/blob/main/CONTRIBUTING.md) to ensure a smooth process. + +### Code of Conduct + +We expect all contributors to adhere to our [Code of Conduct](https://github.com/facebookresearch/LASER/blob/main/CODE_OF_CONDUCT.md). + +### Contributors + +The following people have contributed to this project: + +- [Victor Joseph](https://github.com/CaptainVee) +- [Paul Okewunmi](https://github.com/Paulooh007) +- [Siddharth Singh Rana](https://github.com/NIXBLACK11) +- [David Dale](https://github.com/avidale/) +- [Holger Schwenk](https://github.com/hoschwenk) +- [Kevin Heffernan](https://github.com/heffernankevin) + +### License + +This package is released under the [LASER](https://github.com/facebookresearch/LASER/blob/main/LICENSE) BSD License. + diff --git a/laser_encoders/__init__.py b/laser_encoders/__init__.py new file mode 100644 index 00000000..05b46186 --- /dev/null +++ b/laser_encoders/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for document classification, bitext filtering +# and mining +# +# ------------------------------------------------------- + +from laser_encoders.laser_tokenizer import initialize_tokenizer +from laser_encoders.models import LaserEncoderPipeline, initialize_encoder diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py new file mode 100644 index 00000000..fbd731db --- /dev/null +++ b/laser_encoders/download_models.py @@ -0,0 +1,151 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for document classification, bitext filtering +# and mining +# +# ------------------------------------------------------- +# +# This python script installs NLLB LASER2 and LASER3 sentence encoders from Amazon s3 + +import argparse +import logging +import os +import shutil +import sys +import tempfile +from pathlib import Path + +import requests +from tqdm import tqdm + +from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE, SPM_LANGUAGE + +logging.basicConfig( +    stream=sys.stdout, +    level=logging.INFO, +    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", +) +logger = logging.getLogger(__name__) + + +class LaserModelDownloader: +    def __init__(self, model_dir: str = None): +        if model_dir is None: +            model_dir = os.path.expanduser("~/.cache/laser_encoders") +            os.makedirs(model_dir, exist_ok=True) + +        self.model_dir = Path(model_dir) +        self.base_url = "https://dl.fbaipublicfiles.com/nllb/laser" + +    def download(self, filename: str): +        url = os.path.join(self.base_url, filename) +        local_file_path 
= os.path.join(self.model_dir, filename) + +        if os.path.exists(local_file_path): +            logger.info(f" - {filename} already downloaded") +        else: +            logger.info(f" - Downloading {filename}") + +            tf = tempfile.NamedTemporaryFile(delete=False) +            temp_file_path = tf.name + +            with tf: +                response = requests.get(url, stream=True) +                total_size = int(response.headers.get("Content-Length", 0)) +                progress_bar = tqdm(total=total_size, unit_scale=True, unit="B") + +                for chunk in response.iter_content(chunk_size=1024): +                    tf.write(chunk) +                    progress_bar.update(len(chunk)) +                progress_bar.close() + +            shutil.move(temp_file_path, local_file_path) + +    def get_language_code(self, language_list: dict, lang: str) -> str: +        try: +            lang_3_4 = language_list[lang] +            if isinstance(lang_3_4, list): +                options = ", ".join(f"'{opt}'" for opt in lang_3_4) +                raise ValueError( +                    f"Language '{lang}' has multiple options: {options}. Please specify using the 'lang' argument." +                ) +            return lang_3_4 +        except KeyError: +            raise ValueError( +                f"language name: {lang} not found in language list. Specify a supported language name" +            ) + +    def download_laser2(self): +        self.download("laser2.pt") +        self.download("laser2.spm") +        self.download("laser2.cvocab") + +    def download_laser3(self, lang: str, spm: bool = False): +        result = self.get_language_code(LASER3_LANGUAGE, lang) + +        if isinstance(result, list): +            raise ValueError( +                f"There are script-specific models available for {lang}. 
Please choose one from the following: {result}" +            ) + +        lang = result +        self.download(f"laser3-{lang}.v1.pt") +        if spm: +            if lang in SPM_LANGUAGE: +                self.download(f"laser3-{lang}.v1.spm") +                self.download(f"laser3-{lang}.v1.cvocab") +            else: +                self.download(f"laser2.spm") +                self.download(f"laser2.cvocab") + +    def main(self, args): +        if args.laser: +            if args.laser == "laser2": +                self.download_laser2() +            elif args.laser == "laser3": +                self.download_laser3(lang=args.lang, spm=args.spm) +            else: +                raise ValueError( +                    f"Unsupported laser model: {args.laser}. Choose either laser2 or laser3." +                ) +        else: +            if args.lang in LASER3_LANGUAGE: +                self.download_laser3(lang=args.lang, spm=args.spm) +            elif args.lang in LASER2_LANGUAGE: +                self.download_laser2() +            else: +                raise ValueError( +                    f"Unsupported language name: {args.lang}. Please specify a supported language name using --lang." +                ) + + +if __name__ == "__main__": +    parser = argparse.ArgumentParser(description="LASER: Download Laser models") +    parser.add_argument( +        "--laser", +        type=str, +        help="Laser model to download", +    ) +    parser.add_argument( +        "--lang", +        type=str, +        help="The language name in FLORES200 format", +    ) +    parser.add_argument( +        "--spm", +        action="store_false", +        help="Do not download the SPM model?", +    ) +    parser.add_argument( +        "--model-dir", type=str, help="The directory to download the models to" +    ) +    args = parser.parse_args() +    downloader = LaserModelDownloader(args.model_dir) +    downloader.main(args) diff --git a/laser_encoders/language_list.py b/laser_encoders/language_list.py new file mode 100644 index 00000000..88342601 --- /dev/null +++ b/laser_encoders/language_list.py @@ -0,0 +1,564 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for document classification, bitext filtering +# and mining +# +# ------------------------------------------------------- +# Language mapping to handle different language codes and names + + +def build_language_names_dict(language_list: list, language_names: dict) -> dict: + """ + Build a dictionary mapping language names to their corresponding language codes. + + Parameters: + - language_list (list): A list of language codes. + - language_names (dict): A dictionary mapping language codes to language names. + + Returns: + - dict: A dictionary mapping language names to their corresponding language codes. + """ + result_dict = {} + + for lang_code in language_list: + if lang_code not in language_names: + raise ValueError( + f"Language code '{lang_code}' not found in the provided language_names dictionary." + ) + + names_list = language_names[lang_code] + + # Ensure names_list is always a list + if not isinstance(names_list, list): + names_list = [names_list] + + for name in names_list: + if name not in result_dict: + result_dict[name] = [] + result_dict[name].append(lang_code) + + # Remove single-element lists and convert them to the element itself + for key in result_dict: + if len(result_dict[key]) == 1: + result_dict[key] = result_dict[key][0] + + return result_dict + + +SPM_LANGUAGE = [ + "amh_Ethi", + "ayr_Latn", + "azj_Latn", + "bak_Cyrl", + "bel_Cyrl", + "bod_Tibt", + "ckb_Arab", + "crh_Latn", + "dik_Latn", + "dzo_Tibt", + "fur_Latn", + "fuv_Latn", + "grn_Latn", + "kab_Latn", + "kac_Latn", + "kaz_Cyrl", + "kir_Cyrl", + "kmr_Latn", + "lij_Latn", + "lim_Latn", + "lmo_Latn", + "ltg_Latn", + "mya_Mymr", + "pbt_Arab", + "pes_Arab", + "prs_Arab", + "sat_Beng", + "scn_Latn", + "srd_Latn", + "szl_Latn", + "taq_Latn", + "tgk_Cyrl", + "tir_Ethi", + "tzm_Tfng", + "vec_Latn", +] + + +################################## +###### LANGUAGE 
NAMES ############ +################################## + +LANGUAGE_NAMES = { + "ace_Arab": ["acehnese", "ace", "ace_Arab"], + "ace_Latn": ["acehnese", "ace", "ace_Latn"], + "acm_Arab": ["mesopotamian arabic", "acm", "acm_Arab"], + "acq_Arab": ["ta’izzi-adeni arabic", "acq", "acq_Arab"], + "aeb_Arab": ["tunisian arabic", "aeb", "aeb_Arab"], + "afr_Latn": ["afrikaans", "afr", "afr_Latn"], + "ajp_Arab": ["south levantine arabic", "ajp", "ajp_Arab"], + "aka_Latn": ["akan", "aka", "aka_Latn"], + "amh_Ethi": ["amharic", "amh", "amh_Ethi"], + "apc_Arab": ["north levantine arabic", "apc", "apc_Arab"], + "arb_Arab": ["modern standard arabic", "arb", "arb_Arab"], + "arb_Latn": ["modern standard arabic", "arb", "arb_Latn"], + "ars_Arab": ["najdi arabic", "ars", "ars_Arab"], + "ary_Arab": ["moroccan arabic", "ary", "ary_Arab"], + "arz_Arab": ["egyptian arabic", "arz", "arz_Arab"], + "asm_Beng": ["assamese", "asm", "asm_Beng"], + "ast_Latn": ["asturian", "ast", "ast_Latn"], + "awa_Deva": ["awadhi", "awa", "awa_Deva"], + "ayr_Latn": ["central aymara", "ayr", "ayr_Latn"], + "azb_Arab": ["south azerbaijani", "azb", "azb_Arab"], + "azj_Latn": ["north azerbaijani", "azj", "azj_Latn"], + "bak_Cyrl": ["bashkir", "bak", "bak_Cyrl"], + "bam_Latn": ["bambara", "bam", "bam_Latn"], + "ban_Latn": ["balinese", "ban", "ban_Latn"], + "bel_Cyrl": ["belarusian", "bel", "bel_Cyrl"], + "bem_Latn": ["bemba", "bem", "bem_Latn"], + "ben_Beng": ["bengali", "ben", "ben_Beng"], + "bho_Deva": ["bhojpuri", "bho", "bho_Deva"], + "bjn_Arab": ["banjar", "bjn", "bjn_Arab"], + "bjn_Latn": ["banjar", "bjn", "bjn_Latn"], + "bod_Tibt": ["standard tibetan", "bod", "bod_Tibt"], + "bos_Latn": ["bosnian", "bos", "bos_Latn"], + "bug_Latn": ["buginese", "bug", "bug_Latn"], + "bul_Cyrl": ["bulgarian", "bul", "bul_Cyrl"], + "cat_Latn": ["catalan", "cat", "cat_Latn"], + "ceb_Latn": ["cebuano", "ceb", "ceb_Latn"], + "ces_Latn": ["czech", "ces", "ces_Latn"], + "cjk_Latn": ["chokwe", "cjk", "cjk_Latn"], + "ckb_Arab": 
["central kurdish", "ckb", "ckb_Arab"], + "crh_Latn": ["crimean tatar", "crh", "crh_Latn"], + "cym_Latn": ["welsh", "cym", "cym_Latn"], + "dan_Latn": ["danish", "dan", "dan_Latn"], + "deu_Latn": ["german", "deu", "deu_Latn"], + "dik_Latn": ["southwestern dinka", "dik", "dik_Latn"], + "dyu_Latn": ["dyula", "dyu", "dyu_Latn"], + "dzo_Tibt": ["dzongkha", "dzo", "dzo_Tibt"], + "ell_Grek": ["greek", "ell", "ell_Grek"], + "eng_Latn": ["english", "eng", "eng_Latn"], + "epo_Latn": ["esperanto", "epo", "epo_Latn"], + "est_Latn": ["estonian", "est", "est_Latn"], + "eus_Latn": ["basque", "eus", "eus_Latn"], + "ewe_Latn": ["ewe", "ewe_Latn"], + "fao_Latn": ["faroese", "fao", "fao_Latn"], + "fij_Latn": ["fijian", "fij", "fij_Latn"], + "fin_Latn": ["finnish", "fin", "fin_Latn"], + "fon_Latn": ["fon", "fon_Latn"], + "fra_Latn": ["french", "fra", "fra_Latn"], + "fur_Latn": ["friulian", "fur", "fur_Latn"], + "fuv_Latn": ["nigerian fulfulde", "fuv", "fuv_Latn"], + "gla_Latn": ["scottish gaelic", "gla", "gla_Latn"], + "gle_Latn": ["irish", "gle", "gle_Latn"], + "glg_Latn": ["galician", "glg", "glg_Latn"], + "grn_Latn": ["guarani", "grn", "grn_Latn"], + "guj_Gujr": ["gujarati", "guj", "guj_Gujr"], + "hat_Latn": ["haitian creole", "hat", "hat_Latn"], + "hau_Latn": ["hausa", "hau", "hau_Latn"], + "heb_Hebr": ["hebrew", "heb", "heb_Hebr"], + "hin_Deva": ["hindi", "hin", "hin_Deva"], + "hne_Deva": ["chhattisgarhi", "hne", "hne_Deva"], + "hrv_Latn": ["croatian", "hrv", "hrv_Latn"], + "hun_Latn": ["hungarian", "hun", "hun_Latn"], + "hye_Armn": ["armenian", "hye", "hye_Armn"], + "ibo_Latn": ["igbo", "ibo", "ibo_Latn"], + "ilo_Latn": ["ilocano", "ilo", "ilo_Latn"], + "ind_Latn": ["indonesian", "ind", "ind_Latn"], + "isl_Latn": ["icelandic", "isl", "isl_Latn"], + "ita_Latn": ["italian", "ita", "ita_Latn"], + "jav_Latn": ["javanese", "jav", "jav_Latn"], + "jpn_Jpan": ["japanese", "jpn", "jpn_Jpan"], + "kab_Latn": ["kabyle", "kab", "kab_Latn"], + "kac_Latn": ["jingpho", "kac", "kac_Latn"], + 
"kam_Latn": ["kamba", "kam", "kam_Latn"], + "kan_Knda": ["kannada", "kan", "kan_Knda"], + "kas_Arab": ["kashmiri", "kas", "kas_Arab"], + "kas_Deva": ["kashmiri", "kas", "kas_Deva"], + "kat_Geor": ["georgian", "kat", "kat_Geor"], + "knc_Arab": ["central kanuri", "knc", "knc_Arab"], + "knc_Latn": ["central kanuri", "knc", "knc_Latn"], + "kaz_Cyrl": ["kazakh", "kaz", "kaz_Cyrl"], + "kbp_Latn": ["kabiyè", "kbp", "kbp_Latn"], + "kea_Latn": ["kabuverdianu", "kea", "kea_Latn"], + "khm_Khmr": ["khmer", "khm", "khm_Khmr"], + "kik_Latn": ["kikuyu", "kik", "kik_Latn"], + "kin_Latn": ["kinyarwanda", "kin", "kin_Latn"], + "kir_Cyrl": ["kyrgyz", "kir", "kir_Cyrl"], + "kmb_Latn": ["kimbundu", "kmb", "kmb_Latn"], + "kmr_Latn": ["northern kurdish", "kmr", "kmr_Latn"], + "kon_Latn": ["kikongo", "kon", "kon_Latn"], + "kor_Hang": ["korean", "kor", "kor_Hang"], + "lao_Laoo": ["lao", "lao_Laoo"], + "lij_Latn": ["ligurian", "lij", "lij_Latn"], + "lim_Latn": ["limburgish", "lim", "lim_Latn"], + "lin_Latn": ["lingala", "lin", "lin_Latn"], + "lit_Latn": ["lithuanian", "lit", "lit_Latn"], + "lmo_Latn": ["lombard", "lmo", "lmo_Latn"], + "ltg_Latn": ["latgalian", "ltg", "ltg_Latn"], + "ltz_Latn": ["luxembourgish", "ltz", "ltz_Latn"], + "lua_Latn": ["luba-kasai", "lua", "lua_Latn"], + "lug_Latn": ["ganda", "lug", "lug_Latn"], + "luo_Latn": ["luo", "luo_Latn"], + "lus_Latn": ["mizo", "lus", "lus_Latn"], + "lvs_Latn": ["standard latvian", "lvs", "lvs_Latn"], + "mag_Deva": ["magahi", "mag", "mag_Deva"], + "mai_Deva": ["maithili", "mai", "mai_Deva"], + "mal_Mlym": ["malayalam", "mal", "mal_Mlym"], + "mar_Deva": ["marathi", "mar", "mar_Deva"], + "min_Arab": ["minangkabau", "min", "min_Arab"], + "min_Latn": ["minangkabau", "min", "min_Latn"], + "mkd_Cyrl": ["macedonian", "mkd", "mkd_Cyrl"], + "plt_Latn": ["plateau malagasy", "plt", "plt_Latn"], + "mlt_Latn": ["maltese", "mlt", "mlt_Latn"], + "mni_Beng": ["meitei", "mni", "mni_Beng"], + "khk_Cyrl": ["halh mongolian", "khk", "khk_Cyrl"], + "mos_Latn": 
["mossi", "mos", "mos_Latn"], + "mri_Latn": ["maori", "mri", "mri_Latn"], + "mya_Mymr": ["burmese", "mya", "mya_Mymr"], + "nld_Latn": ["dutch", "nld", "nld_Latn"], + "nno_Latn": ["norwegian nynorsk", "nno", "nno_Latn"], + "nob_Latn": ["norwegian bokmål", "nob", "nob_Latn"], + "npi_Deva": ["nepali", "npi", "npi_Deva"], + "nso_Latn": ["northern sotho", "nso", "nso_Latn"], + "nus_Latn": ["nuer", "nus", "nus_Latn"], + "nya_Latn": ["nyanja", "nya", "nya_Latn"], + "oci_Latn": ["occitan", "oci", "oci_Latn"], + "gaz_Latn": ["west central oromo", "gaz", "gaz_Latn"], + "ory_Orya": ["odia", "ory", "ory_Orya"], + "pag_Latn": ["pangasinan", "pag", "pag_Latn"], + "pan_Guru": ["eastern panjabi", "pan", "pan_Guru"], + "pap_Latn": ["papiamento", "pap", "pap_Latn"], + "pes_Arab": ["western persian", "pes", "pes_Arab"], + "pol_Latn": ["polish", "pol", "pol_Latn"], + "por_Latn": ["portuguese", "por", "por_Latn"], + "prs_Arab": ["dari", "prs", "prs_Arab"], + "pbt_Arab": ["southern pashto", "pbt", "pbt_Arab"], + "quy_Latn": ["ayacucho quechua", "quy", "quy_Latn"], + "ron_Latn": ["romanian", "ron", "ron_Latn"], + "run_Latn": ["rundi", "run", "run_Latn"], + "rus_Cyrl": ["russian", "rus", "rus_Cyrl"], + "sag_Latn": ["sango", "sag", "sag_Latn"], + "san_Deva": ["sanskrit", "san", "san_Deva"], + "sat_Olck": ["santali", "sat", "sat_Olck"], + "scn_Latn": ["sicilian", "scn", "scn_Latn"], + "shn_Mymr": ["shan", "shn", "shn_Mymr"], + "sin_Sinh": ["sinhala", "sin", "sin_Sinh"], + "slk_Latn": ["slovak", "slk", "slk_Latn"], + "slv_Latn": ["slovenian", "slv", "slv_Latn"], + "smo_Latn": ["samoan", "smo", "smo_Latn"], + "sna_Latn": ["shona", "sna", "sna_Latn"], + "snd_Arab": ["sindhi", "snd", "snd_Arab"], + "som_Latn": ["somali", "som", "som_Latn"], + "sot_Latn": ["southern sotho", "sot", "sot_Latn"], + "spa_Latn": ["spanish", "spa", "spa_Latn"], + "als_Latn": ["tosk albanian", "als", "als_Latn"], + "srd_Latn": ["sardinian", "srd", "srd_Latn"], + "srp_Cyrl": ["serbian", "srp", "srp_Cyrl"], + "ssw_Latn": 
["swati", "ssw", "ssw_Latn"], + "sun_Latn": ["sundanese", "sun", "sun_Latn"], + "swe_Latn": ["swedish", "swe", "swe_Latn"], + "swh_Latn": ["swahili", "swh", "swh_Latn"], + "szl_Latn": ["silesian", "szl", "szl_Latn"], + "tam_Taml": ["tamil", "tam", "tam_Taml"], + "tat_Cyrl": ["tatar", "tat", "tat_Cyrl"], + "tel_Telu": ["telugu", "tel", "tel_Telu"], + "tgk_Cyrl": ["tajik", "tgk", "tgk_Cyrl"], + "tgl_Latn": ["tagalog", "tgl", "tgl_Latn"], + "tha_Thai": ["thai", "tha", "tha_Thai"], + "tir_Ethi": ["tigrinya", "tir", "tir_Ethi"], + "taq_Latn": ["tamasheq", "taq", "taq_Latn"], + "taq_Tfng": ["tamasheq", "taq", "taq_Tfng"], + "tpi_Latn": ["tok pisin", "tpi", "tpi_Latn"], + "tsn_Latn": ["tswana", "tsn", "tsn_Latn"], + "tso_Latn": ["tsonga", "tso", "tso_Latn"], + "tuk_Latn": ["turkmen", "tuk", "tuk_Latn"], + "tum_Latn": ["tumbuka", "tum", "tum_Latn"], + "tur_Latn": ["turkish", "tur", "tur_Latn"], + "twi_Latn": ["twi", "twi_Latn"], + "tzm_Tfng": ["central atlas tamazight", "tzm", "tzm_Tfng"], + "uig_Arab": ["uyghur", "uig", "uig_Arab"], + "ukr_Cyrl": ["ukrainian", "ukr", "ukr_Cyrl"], + "umb_Latn": ["umbundu", "umb", "umb_Latn"], + "urd_Arab": ["urdu", "urd", "urd_Arab"], + "uzn_Latn": ["northern uzbek", "uzn", "uzn_Latn"], + "vec_Latn": ["venetian", "vec", "vec_Latn"], + "vie_Latn": ["vietnamese", "vie", "vie_Latn"], + "war_Latn": ["waray", "war", "war_Latn"], + "wol_Latn": ["wolof", "wol", "wol_Latn"], + "xho_Latn": ["xhosa", "xho", "xho_Latn"], + "ydd_Hebr": ["eastern yiddish", "ydd", "ydd_Hebr"], + "yor_Latn": ["yoruba", "yor", "yor_Latn"], + "yue_Hant": ["yue chinese", "yue", "yue_Hant"], + "zho_Hans": ["chinese", "zho", "zho_Hans"], + "zho_Hant": ["chinese", "zho", "zho_Hant"], + "zsm_Latn": ["standard malay", "zsm", "zsm_Latn"], + "zul_Latn": ["zulu", "zul", "zul_Latn"], + "diq_Latn": ["southern zaza", "diq", "diq_Latn"], + "sat_Beng": ["santali", "sat", "sat_Beng"], +} + +################################## +###### LASER 3 ################### 
+################################## + +LASER3_LANGUAGES_LIST = [ + "ace_Latn", + "aka_Latn", + "als_Latn", + "amh_Ethi", + "asm_Beng", + "awa_Deva", + "ayr_Latn", + "azb_Arab", + "azj_Latn", + "bak_Cyrl", + "bam_Latn", + "ban_Latn", + "bel_Cyrl", + "bem_Latn", + "ben_Beng", + "bho_Deva", + "bjn_Latn", + "bod_Tibt", + "bug_Latn", + "ceb_Latn", + "cjk_Latn", + "ckb_Arab", + "crh_Latn", + "cym_Latn", + "dik_Latn", + "diq_Latn", + "dyu_Latn", + "dzo_Tibt", + "ewe_Latn", + "fao_Latn", + "fij_Latn", + "fon_Latn", + "fur_Latn", + "fuv_Latn", + "gaz_Latn", + "gla_Latn", + "gle_Latn", + "grn_Latn", + "guj_Gujr", + "hat_Latn", + "hau_Latn", + "hin_Deva", + "hne_Deva", + "hye_Armn", + "ibo_Latn", + "ilo_Latn", + "ind_Latn", + "jav_Latn", + "kab_Latn", + "kac_Latn", + "kam_Latn", + "kan_Knda", + "kas_Arab", + "kas_Deva", + "kat_Geor", + "kaz_Cyrl", + "kbp_Latn", + "kea_Latn", + "khk_Cyrl", + "khm_Khmr", + "kik_Latn", + "kin_Latn", + "kir_Cyrl", + "kmb_Latn", + "kmr_Latn", + "knc_Arab", + "knc_Latn", + "kon_Latn", + "lao_Laoo", + "lij_Latn", + "lim_Latn", + "lin_Latn", + "lmo_Latn", + "ltg_Latn", + "ltz_Latn", + "lua_Latn", + "lug_Latn", + "luo_Latn", + "lus_Latn", + "mag_Deva", + "mai_Deva", + "mal_Mlym", + "mar_Deva", + "min_Latn", + "mlt_Latn", + "mni_Beng", + "mos_Latn", + "mri_Latn", + "mya_Mymr", + "npi_Deva", + "nso_Latn", + "nus_Latn", + "nya_Latn", + "ory_Orya", + "pag_Latn", + "pan_Guru", + "pap_Latn", + "pbt_Arab", + "pes_Arab", + "plt_Latn", + "prs_Arab", + "quy_Latn", + "run_Latn", + "sag_Latn", + "san_Deva", + "sat_Beng", + "scn_Latn", + "shn_Mymr", + "sin_Sinh", + "smo_Latn", + "sna_Latn", + "snd_Arab", + "som_Latn", + "sot_Latn", + "srd_Latn", + "ssw_Latn", + "sun_Latn", + "swh_Latn", + "szl_Latn", + "tam_Taml", + "taq_Latn", + "tat_Cyrl", + "tel_Telu", + "tgk_Cyrl", + "tgl_Latn", + "tha_Thai", + "tir_Ethi", + "tpi_Latn", + "tsn_Latn", + "tso_Latn", + "tuk_Latn", + "tum_Latn", + "tur_Latn", + "twi_Latn", + "tzm_Tfng", + "uig_Arab", + "umb_Latn", + "urd_Arab", + 
"uzn_Latn", + "vec_Latn", + "war_Latn", + "wol_Latn", + "xho_Latn", + "ydd_Hebr", + "yor_Latn", + "zsm_Latn", + "zul_Latn", +] + + +LASER3_LANGUAGE = build_language_names_dict(LASER3_LANGUAGES_LIST, LANGUAGE_NAMES) + +################################## +###### LASER 2 ################### +################################## + +LASER2_LANGUAGES_LIST = [ + "acm_Arab", + "acq_Arab", + "aeb_Arab", + "afr_Latn", + "ajp_Arab", + "amh_Ethi", + "apc_Arab", + "arb_Arab", + "arb_Latn", + "ars_Arab", + "ary_Arab", + "arz_Arab", + "ayr_Latn", + "azb_Arab", + "azj_Latn", + "bel_Cyrl", + "ben_Beng", + "bos_Latn", + "bul_Cyrl", + "cat_Latn", + "ces_Latn", + "ckb_Arab", + "crh_Latn", + "dan_Latn", + "deu_Latn", + "ell_Grek", + "eng_Latn", + "epo_Latn", + "est_Latn", + "eus_Latn", + "fin_Latn", + "fra_Latn", + "gle_Latn", + "glg_Latn", + "hau_Latn", + "heb_Hebr", + "hin_Deva", + "hrv_Latn", + "hun_Latn", + "hye_Armn", + "ind_Latn", + "isl_Latn", + "ita_Latn", + "jpn_Jpan", + "kab_Latn", + "kat_Geor", + "kaz_Cyrl", + "khm_Khmr", + "kmr_Latn", + "kor_Hang", + "lit_Latn", + "lvs_Latn", + "mal_Mlym", + "mar_Deva", + "mkd_Cyrl", + "plt_Latn", + "mya_Mymr", + "nld_Latn", + "nob_Latn", + "oci_Latn", + "pes_Arab", + "pol_Latn", + "por_Latn", + "ron_Latn", + "rus_Cyrl", + "sin_Sinh", + "slk_Latn", + "slv_Latn", + "snd_Arab", + "som_Latn", + "spa_Latn", + "als_Latn", + "srp_Cyrl", + "swe_Latn", + "swh_Latn", + "tam_Taml", + "tat_Cyrl", + "tel_Telu", + "tgk_Cyrl", + "tgl_Latn", + "tha_Thai", + "tur_Latn", + "uig_Arab", + "ukr_Cyrl", + "urd_Arab", + "uzn_Latn", + "vie_Latn", + "yue_Hant", + "yue_Hant", + "zho_Hans", + "zho_Hant", + "zsm_Latn", +] + + +LASER2_LANGUAGE = build_language_names_dict(LASER2_LANGUAGES_LIST, LANGUAGE_NAMES) diff --git a/laser_encoders/laser_tokenizer.py b/laser_encoders/laser_tokenizer.py new file mode 100644 index 00000000..5cbd2a4e --- /dev/null +++ b/laser_encoders/laser_tokenizer.py @@ -0,0 +1,179 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. 
and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for document classification, bitext filtering +# and mining +# +# -------------------------------------------------------- +# +# Helper functions for tokenization + +import gzip +import logging +import os +import re +import sys +from pathlib import Path +from typing import IO, List + +import sentencepiece as spm +from sacremoses import MosesDetokenizer, MosesPunctNormalizer +from unicategories import categories + +from laser_encoders.download_models import LaserModelDownloader +from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE, SPM_LANGUAGE + +SPACE_NORMALIZER = re.compile(r"\s+") +NON_PRINT_CHARS = set(c for c in categories["C"].characters()) + +logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", +) +logger = logging.getLogger("preprocess") + + +class LaserTokenizer: + def __init__( + self, + spm_model: Path, + lang: str = "en", + lower_case: bool = True, + descape: bool = False, + verbose: bool = False, + over_write: bool = False, + normalize_punct: bool = True, + ): + self.spm_model = spm_model + self.lang = lang + self.lower_case = lower_case + self.descape = descape + self.verbose = verbose + self.over_write = over_write + self.normalize_punct = normalize_punct + + assert spm_model.exists(), f"spm model file: {spm_model} does not exist" + self.moses_punct_normalizer = MosesPunctNormalizer(self.lang, perl_parity=True) + # add parity with MOSES release-4.0 + self.moses_punct_normalizer.substitutions[21] = ("‘", r'"') + self.moses_punct_normalizer.substitutions[22] = ("‚", r'"') + self.moses_detokenizer = MosesDetokenizer() + self.spm_encoder = 
spm.SentencePieceProcessor(model_file=str(self.spm_model)) + + def open(self, file: Path, mode: str, encoding="utf-8") -> IO: + return ( + gzip.open(file, mode, encoding=encoding) + if file.name.endswith(".gz") + else open(file, mode, encoding=encoding) + ) + + def log(self, message: str) -> None: + if self.verbose: + logger.info(message) + + def tokenize(self, text: str) -> str: + # Preprocessing + sentence_text = "".join([c if c not in NON_PRINT_CHARS else " " for c in text]) + if self.normalize_punct: + sentence_text = self.moses_punct_normalizer.normalize(sentence_text) + if self.descape: + sentence_text = self.moses_detokenizer.unescape_xml(text=sentence_text) + if self.lower_case: + sentence_text = sentence_text.lower() + + # SentencePiece encoding + encoded_text = " ".join(self.spm_encoder.encode(sentence_text, out_type=str)) + return encoded_text + + def tokenize_file(self, inp_fname: Path, out_fname: Path) -> None: + if not self.over_write and out_fname.exists(): + self.log(f"tokenized file {out_fname.name} already exists") + return + else: + self.log( + f"tokenizing {inp_fname.name}" + + f"{' (de-escaped)' if self.descape else ''}" + + f"{' (lower-cased)' if self.lower_case else ' (cased)'} " + + f"(punctuation-normalization lang: {self.lang})" + ) + + with self.open(inp_fname, "rt") as file_in, open( + out_fname, "w" + ) as file_out: + for line in file_in: + tokens = self.tokenize(line.strip()) + file_out.write(tokens + "\n") + + def __call__(self, text_or_batch): + if isinstance(text_or_batch, str): + return self.tokenize(text_or_batch) + else: + return self.tokenize_batch(text_or_batch) + + def tokenize_batch(self, batch: List[str]) -> List[List[str]]: + return [self.tokenize(text) for text in batch] + + def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: + return [self.spm_encoder.DecodeIds(ids) for ids in ids] + + def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: + ids = [] + + for token in tokens: + # Apply the same 
tokenization logic as in _tokenize method + tokens = SPACE_NORMALIZER.sub(" ", token).strip().split() + + # Initialize an empty tensor for this token's IDs + token_ids = [] + + for i, token in enumerate(tokens): + token_id = self.spm_encoder.PieceToId(token) + if token_id == 0: # Handle out-of-vocabulary tokens + token_id = self.spm_encoder.PieceToId("<unk>") + token_ids.append(token_id) + + # Append token IDs to the final IDs tensor + ids.extend(token_ids) + + return ids + + +def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = None): + downloader = LaserModelDownloader(model_dir) + if laser is not None: + if laser == "laser3": + lang = downloader.get_language_code(LASER3_LANGUAGE, lang) + if lang in SPM_LANGUAGE: + filename = f"laser3-{lang}.v1.spm" + else: + filename = "laser2.spm" + elif laser == "laser2": + filename = "laser2.spm" + else: + raise ValueError( + f"Unsupported laser model: {laser}. Choose either laser2 or laser3." + ) + else: + if lang in LASER3_LANGUAGE: + lang = downloader.get_language_code(LASER3_LANGUAGE, lang) + if lang in SPM_LANGUAGE: + filename = f"laser3-{lang}.v1.spm" + else: + filename = "laser2.spm" + elif lang in LASER2_LANGUAGE: + filename = "laser2.spm" + else: + raise ValueError( + f"Unsupported language name: {lang}. Please specify a supported language name." + ) + + downloader.download(filename) + model_path = os.path.join(downloader.model_dir, filename) + return LaserTokenizer(spm_model=Path(model_path)) diff --git a/laser_encoders/models.py b/laser_encoders/models.py new file mode 100644 index 00000000..beaa6cbc --- /dev/null +++ b/laser_encoders/models.py @@ -0,0 +1,426 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for document classification, bitext filtering +# and mining +# +# -------------------------------------------------------- + + +import logging +import os +import re +import sys +import warnings +from collections import namedtuple +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from fairseq.data.dictionary import Dictionary +from fairseq.models.transformer import Embedding, TransformerEncoder +from fairseq.modules import LayerNorm + +from laser_encoders.download_models import LaserModelDownloader +from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE +from laser_encoders.laser_tokenizer import LaserTokenizer, initialize_tokenizer + +SPACE_NORMALIZER = re.compile(r"\s+") +Batch = namedtuple("Batch", "srcs tokens lengths") + +logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", +) +logger = logging.getLogger("embed") + + +class SentenceEncoder: + def __init__( + self, + model_path, + max_sentences=None, + max_tokens=None, + spm_vocab=None, + spm_model=None, + cpu=False, + fp16=False, + verbose=False, + sort_kind="quicksort", + ): + if verbose: + logger.info(f"loading encoder: {model_path}") + self.spm_model = spm_model + if self.spm_model: + self.tokenizer = LaserTokenizer(spm_model=Path(self.spm_model)) + + self.use_cuda = torch.cuda.is_available() and not cpu + self.max_sentences = max_sentences + self.max_tokens = max_tokens + if self.max_tokens is None and self.max_sentences is None: + self.max_sentences = 1 + + state_dict = torch.load(model_path) + if "params" in state_dict: + self.encoder = LaserLstmEncoder(**state_dict["params"]) + self.encoder.load_state_dict(state_dict["model"]) + self.dictionary = state_dict["dictionary"] + self.prepend_bos = False + self.left_padding = False + else: + 
self.encoder = LaserTransformerEncoder(state_dict, spm_vocab) + self.dictionary = self.encoder.dictionary.indices + self.prepend_bos = state_dict["cfg"]["model"].prepend_bos + self.left_padding = state_dict["cfg"]["model"].left_pad_source + del state_dict + self.bos_index = self.dictionary["<s>"] = 0 + self.pad_index = self.dictionary["<pad>"] = 1 + self.eos_index = self.dictionary["</s>"] = 2 + self.unk_index = self.dictionary["<unk>"] = 3 + + if fp16: + self.encoder.half() + if self.use_cuda: + if verbose: + logger.info("transfer encoder to GPU") + self.encoder.cuda() + self.encoder.eval() + self.sort_kind = sort_kind + + def __call__(self, text_or_batch): + if self.spm_model: + text_or_batch = self.tokenizer(text_or_batch) + if isinstance(text_or_batch, str): + text_or_batch = [text_or_batch] + return self.encode_sentences(text_or_batch) + else: + raise ValueError( + "Either initialize the encoder with an spm_model or pre-tokenize and use the encode_sentences method." + ) + + def _process_batch(self, batch): + tokens = batch.tokens + lengths = batch.lengths + if self.use_cuda: + tokens = tokens.cuda() + lengths = lengths.cuda() + + with torch.no_grad(): + sentemb = self.encoder(tokens, lengths)["sentemb"] + embeddings = sentemb.detach().cpu().numpy() + return embeddings + + def _tokenize(self, line): + tokens = SPACE_NORMALIZER.sub(" ", line).strip().split() + ntokens = len(tokens) + if self.prepend_bos: + ids = torch.LongTensor(ntokens + 2) + ids[0] = self.bos_index + for i, token in enumerate(tokens): + ids[i + 1] = self.dictionary.get(token, self.unk_index) + ids[ntokens + 1] = self.eos_index + else: + ids = torch.LongTensor(ntokens + 1) + for i, token in enumerate(tokens): + ids[i] = self.dictionary.get(token, self.unk_index) + ids[ntokens] = self.eos_index + return ids + + def _make_batches(self, lines): + tokens = [self._tokenize(line) for line in lines] + lengths = np.array([t.numel() for t in tokens]) + indices = np.argsort(-lengths, kind=self.sort_kind) + + def 
batch(tokens, lengths, indices): + toks = tokens[0].new_full((len(tokens), tokens[0].shape[0]), self.pad_index) + if not self.left_padding: + for i in range(len(tokens)): + toks[i, : tokens[i].shape[0]] = tokens[i] + else: + for i in range(len(tokens)): + toks[i, -tokens[i].shape[0] :] = tokens[i] + return ( + Batch(srcs=None, tokens=toks, lengths=torch.LongTensor(lengths)), + indices, + ) + + batch_tokens, batch_lengths, batch_indices = [], [], [] + ntokens = nsentences = 0 + for i in indices: + if nsentences > 0 and ( + (self.max_tokens is not None and ntokens + lengths[i] > self.max_tokens) + or (self.max_sentences is not None and nsentences == self.max_sentences) + ): + yield batch(batch_tokens, batch_lengths, batch_indices) + ntokens = nsentences = 0 + batch_tokens, batch_lengths, batch_indices = [], [], [] + batch_tokens.append(tokens[i]) + batch_lengths.append(lengths[i]) + batch_indices.append(i) + ntokens += tokens[i].shape[0] + nsentences += 1 + if nsentences > 0: + yield batch(batch_tokens, batch_lengths, batch_indices) + + def encode_sentences(self, sentences, normalize_embeddings=False): + indices = [] + results = [] + for batch, batch_indices in self._make_batches(sentences): + indices.extend(batch_indices) + encoded_batch = self._process_batch(batch) + if normalize_embeddings: + # Perform L2 normalization on the embeddings + norms = np.linalg.norm(encoded_batch, axis=1, keepdims=True) + encoded_batch = encoded_batch / norms + results.append(encoded_batch) + return np.vstack(results)[np.argsort(indices, kind=self.sort_kind)] + + +class LaserTransformerEncoder(TransformerEncoder): + def __init__(self, state_dict, vocab_path): + self.dictionary = Dictionary.load(vocab_path) + if any( + k in state_dict["model"] + for k in ["encoder.layer_norm.weight", "layer_norm.weight"] + ): + self.dictionary.add_symbol("<mask>") + cfg = state_dict["cfg"]["model"] + self.sentemb_criterion = cfg.sentemb_criterion + self.pad_idx = self.dictionary.pad_index + self.bos_idx = 
self.dictionary.bos_index + embed_tokens = Embedding( + len(self.dictionary), + cfg.encoder_embed_dim, + self.pad_idx, + ) + super().__init__(cfg, self.dictionary, embed_tokens) + if "decoder.version" in state_dict["model"]: + self._remove_decoder_layers(state_dict) + if "layer_norm.weight" in state_dict["model"]: + self.layer_norm = LayerNorm(cfg.encoder_embed_dim) + self.load_state_dict(state_dict["model"]) + + def _remove_decoder_layers(self, state_dict): + for key in list(state_dict["model"].keys()): + if not key.startswith( + ( + "encoder.layer_norm", + "encoder.layers", + "encoder.embed", + "encoder.version", + ) + ): + del state_dict["model"][key] + else: + renamed_key = key.replace("encoder.", "") + state_dict["model"][renamed_key] = state_dict["model"].pop(key) + + def forward(self, src_tokens, src_lengths): + encoder_out = super().forward(src_tokens, src_lengths) + if isinstance(encoder_out, dict): + x = encoder_out["encoder_out"][0] # T x B x C + else: + x = encoder_out[0] + if self.sentemb_criterion == "cls": + cls_indices = src_tokens.eq(self.bos_idx).t() + sentemb = x[cls_indices, :] + else: + padding_mask = src_tokens.eq(self.pad_idx).t().unsqueeze(-1) + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + sentemb = x.max(dim=0)[0] + return {"sentemb": sentemb} + + +class LaserLstmEncoder(nn.Module): + def __init__( + self, + num_embeddings, + padding_idx, + embed_dim=320, + hidden_size=512, + num_layers=1, + bidirectional=False, + left_pad=True, + padding_value=0.0, + ): + super().__init__() + + self.num_layers = num_layers + self.bidirectional = bidirectional + self.hidden_size = hidden_size + + self.padding_idx = padding_idx + self.embed_tokens = nn.Embedding( + num_embeddings, embed_dim, padding_idx=self.padding_idx + ) + + self.lstm = nn.LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + bidirectional=bidirectional, + ) + self.left_pad = left_pad + self.padding_value = 
padding_value + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward(self, src_tokens, src_lengths): + bsz, seqlen = src_tokens.size() + + # embed tokens + x = self.embed_tokens(src_tokens) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.data.new(*state_size).zero_() + c0 = x.data.new(*state_size).zero_() + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_value + ) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + + def combine_bidir(outs): + return torch.cat( + [ + torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view( + 1, bsz, self.output_units + ) + for i in range(self.num_layers) + ], + dim=0, + ) + + final_hiddens = combine_bidir(final_hiddens) + final_cells = combine_bidir(final_cells) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() + + # Set padded outputs to -inf so they are not selected by max-pooling + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + return { + "sentemb": sentemb, + "encoder_out": (x, final_hiddens, final_cells), + "encoder_padding_mask": encoder_padding_mask + if encoder_padding_mask.any() + else None, + } + + +def initialize_encoder( + lang: str = None, + model_dir: str = None, + spm: bool = True, + laser: str = None, +): + downloader = LaserModelDownloader(model_dir) + if laser 
is not None: + if laser == "laser3": + lang = downloader.get_language_code(LASER3_LANGUAGE, lang) + downloader.download_laser3(lang=lang, spm=spm) + file_path = f"laser3-{lang}.v1" + elif laser == "laser2": + downloader.download_laser2() + file_path = "laser2" + else: + raise ValueError( + f"Unsupported laser model: {laser}. Choose either laser2 or laser3." + ) + else: + if lang in LASER3_LANGUAGE: + lang = downloader.get_language_code(LASER3_LANGUAGE, lang) + downloader.download_laser3(lang=lang, spm=spm) + file_path = f"laser3-{lang}.v1" + elif lang in LASER2_LANGUAGE: + downloader.download_laser2() + file_path = "laser2" + else: + raise ValueError( + f"Unsupported language name: {lang}. Please specify a supported language name." + ) + + model_dir = downloader.model_dir + model_path = os.path.join(model_dir, f"{file_path}.pt") + spm_vocab = os.path.join(model_dir, f"{file_path}.cvocab") + + if not os.path.exists(spm_vocab): + # if there is no cvocab for the laser3 lang use laser2 cvocab + spm_vocab = os.path.join(model_dir, "laser2.cvocab") + + return SentenceEncoder(model_path=model_path, spm_vocab=spm_vocab, spm_model=None) + + +class LaserEncoderPipeline: + def __init__( + self, + lang: str = None, + model_dir: str = None, + spm: bool = True, + laser: str = None, + ): + + if laser == "laser2" and lang is not None: + warnings.warn( + "Warning: The 'lang' parameter is optional when using 'laser2'. It will be ignored." 
+ ) + + if laser == "laser3" and lang is None: + raise ValueError("For 'laser3', the 'lang' parameter is required.") + + if laser is None and lang is None: + raise ValueError("Either 'laser' or 'lang' should be provided.") + + self.tokenizer = initialize_tokenizer( + lang=lang, model_dir=model_dir, laser=laser + ) + self.encoder = initialize_encoder( + lang=lang, model_dir=model_dir, spm=spm, laser=laser + ) + + def encode_sentences( + self, sentences: list, normalize_embeddings: bool = False + ) -> list: + """ + Tokenizes and encodes a list of sentences. + + Args: + - sentences (list of str): List of sentences to tokenize and encode. + + Returns: + - List of embeddings for each sentence. + """ + tokenized_sentences = [ + self.tokenizer.tokenize(sentence) for sentence in sentences + ] + return self.encoder.encode_sentences(tokenized_sentences, normalize_embeddings) diff --git a/laser_encoders/test_laser_tokenizer.py b/laser_encoders/test_laser_tokenizer.py new file mode 100644 index 00000000..78a3aadd --- /dev/null +++ b/laser_encoders/test_laser_tokenizer.py @@ -0,0 +1,310 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for document classification, bitext filtering +# and mining +# +# -------------------------------------------------------- +# Tests for LaserTokenizer + +import os +import warnings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import List + +import numpy as np +import pytest + +from laser_encoders import ( + LaserEncoderPipeline, + initialize_encoder, + initialize_tokenizer, +) + + +@pytest.fixture +def tokenizer(tmp_path: Path): + tokenizer_instance = initialize_tokenizer(model_dir=tmp_path, laser="laser2") + return tokenizer_instance + + +@pytest.fixture +def input_text() -> str: + return "This is a test sentence." + + +@pytest.fixture +def test_readme_params() -> dict: + return { + "lang": "igbo", + "input_sentences": ["nnọọ, kedu ka ị mere"], + "expected_embedding_shape": (1, 1024), + "expected_array": [ + 0.3807628, + -0.27941525, + -0.17819545, + 0.44144684, + -0.38985375, + 0.04719935, + 0.20238206, + -0.03934783, + 0.0118901, + 0.28986093, + ], + } + + +def test_tokenize(tokenizer, input_text: str): + expected_output = "▁this ▁is ▁a ▁test ▁sent ence ." + assert tokenizer.tokenize(input_text) == expected_output + + +def test_tokenizer_call_method(tokenizer, input_text: str): + single_string = "This is a test sentence." + expected_output = "▁this ▁is ▁a ▁test ▁sent ence ." + assert tokenizer(single_string) == expected_output + + list_of_strings = ["This is a test sentence.", "This is another test sentence."] + expected_output = [ + "▁this ▁is ▁a ▁test ▁sent ence .", + "▁this ▁is ▁another ▁test ▁sent ence .", + ] + assert tokenizer(list_of_strings) == expected_output + + +def test_normalization(tokenizer): + test_data = "Hello!!! How are you??? I'm doing great." + expected_output = "▁hel lo !!! ▁how ▁are ▁you ??? ▁i ' m ▁do ing ▁great ." 
+ assert tokenizer.tokenize(test_data) == expected_output + + +def test_descape(tokenizer): + test_data = "I <3 Apple & Carrots!" + expected_output = "▁i ▁<3 ▁app le ▁& ▁car ro ts !" + tokenizer.descape = True + assert tokenizer.tokenize(test_data) == expected_output + + +def test_lowercase(tokenizer): + test_data = "THIS OUTPUT MUST BE UPPERCASE" + expected_output = "▁TH IS ▁ OU TP UT ▁ MU ST ▁BE ▁ UP PER CA SE" + tokenizer.lower_case = False + assert tokenizer.tokenize(test_data) == expected_output + + +def test_is_printable(tokenizer): + test_data = "Hello, \tWorld! ABC\x1f123" + expected_output = "▁hel lo , ▁world ! ▁ab c ▁12 3" + assert tokenizer.tokenize(test_data) == expected_output + + +def test_tokenize_file(tokenizer, input_text: str): + with TemporaryDirectory() as temp_dir: + input_file = os.path.join(temp_dir, "input.txt") + output_file = os.path.join(temp_dir, "output.txt") + + with open(input_file, "w") as file: + file.write(input_text) + + tokenizer.tokenize_file( + inp_fname=Path(input_file), + out_fname=Path(output_file), + ) + + with open(output_file, "r") as file: + output = file.read().strip() + + expected_output = "▁this ▁is ▁a ▁test ▁sent ence ." 
+ assert output == expected_output + + +def test_tokenize_file_overwrite(tokenizer, input_text: str): + with TemporaryDirectory() as temp_dir: + input_file = os.path.join(temp_dir, "input.txt") + output_file = os.path.join(temp_dir, "output.txt") + + with open(input_file, "w") as file: + file.write(input_text) + + with open(output_file, "w") as file: + file.write("Existing output") + + # Test when over_write is False + tokenizer.over_write = False + tokenizer.tokenize_file( + inp_fname=Path(input_file), + out_fname=Path(output_file), + ) + + with open(output_file, "r") as file: + output = file.read().strip() + + assert output == "Existing output" + + # Test when over_write is True + tokenizer.over_write = True + tokenizer.tokenize_file( + inp_fname=Path(input_file), + out_fname=Path(output_file), + ) + + with open(output_file, "r") as file: + output = file.read().strip() + + expected_output = "▁this ▁is ▁a ▁test ▁sent ence ." + assert output == expected_output + + +@pytest.mark.parametrize( + "laser, expected_array, lang", + [ + ( + "laser2", + [ + 1.042462512850761414e-02, + 6.325428839772939682e-03, + -3.032622225873637944e-05, + 9.033476933836936951e-03, + 2.937933895736932755e-04, + 4.489220678806304932e-03, + 2.334521152079105377e-03, + -9.427300537936389446e-04, + -1.571535394759848714e-04, + 2.095808042213320732e-03, + ], + None, + ), + ( + "laser3", + [ + 3.038274645805358887e-01, + 4.151830971240997314e-01, + -2.458990514278411865e-01, + 3.153458833694458008e-01, + -5.153598189353942871e-01, + -6.035178527235984802e-02, + 2.210616767406463623e-01, + -2.701394855976104736e-01, + -4.902199506759643555e-01, + -3.126966953277587891e-02, + ], + "zul_Latn", + ), + ], +) +def test_sentence_encoder( + tmp_path: Path, + tokenizer, + laser: str, + expected_array: List, + lang: str, + input_text: str, +): + sentence_encoder = initialize_encoder(model_dir=tmp_path, laser=laser, lang=lang) + tokenized_text = tokenizer.tokenize(input_text) + sentence_embedding = 
sentence_encoder.encode_sentences([tokenized_text]) + + assert isinstance(sentence_embedding, np.ndarray) + assert sentence_embedding.shape == (1, 1024) + assert np.allclose(expected_array, sentence_embedding[:, :10], atol=1e-3) + + +def test_laser_encoder_pipeline(tmp_path: Path, test_readme_params: dict): + lang = test_readme_params["lang"] + input_sentences = test_readme_params["input_sentences"] + expected_embedding_shape = test_readme_params["expected_embedding_shape"] + expected_array = test_readme_params["expected_array"] + + encoder = LaserEncoderPipeline(model_dir=tmp_path, lang=lang) + embeddings = encoder.encode_sentences(input_sentences) + + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == expected_embedding_shape + assert np.allclose(expected_array, embeddings[:, :10], atol=1e-3) + + +def test_separate_initialization_and_encoding( + tmp_path, tokenizer, test_readme_params: dict +): + lang = test_readme_params["lang"] + input_sentences = test_readme_params["input_sentences"] + expected_embedding_shape = test_readme_params["expected_embedding_shape"] + expected_array = test_readme_params["expected_array"] + + tokenized_sentence = tokenizer.tokenize(input_sentences[0]) + sentence_encoder = initialize_encoder(model_dir=tmp_path, lang=lang) + + # Encode tokenized sentences into embeddings + embeddings = sentence_encoder.encode_sentences([tokenized_sentence]) + + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == expected_embedding_shape + assert np.allclose(expected_array, embeddings[:, :10], atol=1e-3) + + +def test_encoder_normalization(tmp_path: Path, test_readme_params: dict): + lang = test_readme_params["lang"] + input_sentences = test_readme_params["input_sentences"] + + encoder = LaserEncoderPipeline(model_dir=tmp_path, lang=lang) + normalized_embeddings = encoder.encode_sentences( + input_sentences, normalize_embeddings=True + ) + norm = np.linalg.norm(normalized_embeddings[0]) + + assert np.allclose(norm, 
1.0, atol=1e-3) + + +def test_encoder_default_behaviour(tmp_path: Path, test_readme_params: dict): + lang = test_readme_params["lang"] + input_sentences = test_readme_params["input_sentences"] + + encoder = LaserEncoderPipeline(model_dir=tmp_path, lang=lang) + default_embeddings = encoder.encode_sentences(input_sentences) + non_normalized_embeddings = encoder.encode_sentences( + input_sentences, normalize_embeddings=False + ) + + assert np.allclose(default_embeddings, non_normalized_embeddings) + + +def test_encoder_non_normalization(tmp_path: Path, test_readme_params: dict): + lang = test_readme_params["lang"] + input_sentences = test_readme_params["input_sentences"] + + encoder = LaserEncoderPipeline(model_dir=tmp_path, lang=lang) + non_normalized_embeddings = encoder.encode_sentences( + input_sentences, normalize_embeddings=False + ) + norm = np.linalg.norm(non_normalized_embeddings[0]) + + assert not np.isclose(norm, 1) + + +def test_optional_lang_with_laser2(tmp_path: Path): + with pytest.warns( + UserWarning, + match="The 'lang' parameter is optional when using 'laser2'. It will be ignored.", + ): + encoder = LaserEncoderPipeline(lang="en", laser="laser2", model_dir=tmp_path) + + +def test_required_lang_with_laser3(tmp_path: Path): + with pytest.raises( + ValueError, match="For 'laser3', the 'lang' parameter is required." + ): + encoder = LaserEncoderPipeline(laser="laser3", model_dir=tmp_path) + + +def test_missing_lang_and_laser(tmp_path: Path): + with pytest.raises( + ValueError, match="Either 'laser' or 'lang' should be provided." 
+ ): + encoder = LaserEncoderPipeline(model_dir=tmp_path) diff --git a/laser_encoders/test_models_initialization.py b/laser_encoders/test_models_initialization.py new file mode 100644 index 00000000..88e898fa --- /dev/null +++ b/laser_encoders/test_models_initialization.py @@ -0,0 +1,57 @@ +import os +import tempfile + +import pytest + +from laser_encoders.download_models import LaserModelDownloader +from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE +from laser_encoders.laser_tokenizer import initialize_tokenizer +from laser_encoders.models import initialize_encoder + + +def test_validate_achnese_models_and_tokenize_laser3(lang="acehnese"): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + downloader.download_laser3(lang) + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +def test_validate_english_models_and_tokenize_laser2(lang="english"): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + downloader.download_laser2() + + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +def test_validate_kashmiri_models_and_tokenize_laser3(lang="kas"): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + with pytest.raises(ValueError): + 
downloader.download_laser3(lang) + + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") diff --git a/laser_encoders/validate_models.py b/laser_encoders/validate_models.py new file mode 100644 index 00000000..0748dfee --- /dev/null +++ b/laser_encoders/validate_models.py @@ -0,0 +1,108 @@ +import os +import tempfile + +import pytest + +from laser_encoders.download_models import LaserModelDownloader +from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE +from laser_encoders.laser_tokenizer import initialize_tokenizer +from laser_encoders.models import initialize_encoder + + +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER3_LANGUAGE) +def test_validate_language_models_and_tokenize_laser3(lang): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + if lang in ["kashmiri", "kas", "central kanuri", "knc"]: + with pytest.raises(ValueError) as excinfo: + downloader.download_laser3(lang) + assert "ValueError" in str(excinfo.value) + print(f"{lang} language model raised a ValueError as expected.") + else: + downloader.download_laser3(lang) + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER2_LANGUAGE) +def test_validate_language_models_and_tokenize_laser2(lang): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + 
downloader.download_laser2() + + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +class MockLaserModelDownloader(LaserModelDownloader): + def __init__(self, model_dir): + self.model_dir = model_dir + + def download_laser3(self, lang): + lang = self.get_language_code(LASER3_LANGUAGE, lang) + file_path = os.path.join(self.model_dir, f"laser3-{lang}.v1.pt") + if not os.path.exists(file_path): + raise FileNotFoundError(f"Could not find {file_path}.") + + def download_laser2(self): + files = ["laser2.pt", "laser2.spm", "laser2.cvocab"] + for file_name in files: + file_path = os.path.join(self.model_dir, file_name) + if not os.path.exists(file_path): + raise FileNotFoundError(f"Could not find {file_path}.") + + +CACHE_DIR = "/home/user/.cache/models" # Change this to the desired cache directory + +# This uses the mock downloader +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER3_LANGUAGE) +def test_validate_language_models_and_tokenize_mock_laser3(lang): + downloader = MockLaserModelDownloader(model_dir=CACHE_DIR) + + try: + downloader.download_laser3(lang) + except FileNotFoundError as e: + raise pytest.fail(str(e)) + + encoder = initialize_encoder(lang, model_dir=CACHE_DIR) + tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR) + + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +# This uses the mock downloader +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER2_LANGUAGE) +def test_validate_language_models_and_tokenize_mock_laser2(lang): + downloader = MockLaserModelDownloader(model_dir=CACHE_DIR) + + try: + downloader.download_laser2() + except FileNotFoundError as e: + raise pytest.fail(str(e)) + + encoder = initialize_encoder(lang, 
model_dir=CACHE_DIR) + tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR) + + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..6e82f8bd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,69 @@ +[build-system] +requires = ["flit_core >=3.2,<4", "setuptools"] +build-backend = "flit_core.buildapi" + +[project] +name = "laser_encoders" +version = "0.0.1" +authors = [{name = "Facebook AI Research"}] +description = "LASER Language-Agnostic SEntence Representations is a toolkit to calculate multilingual sentence embeddings and to use them for document classification, bitext filtering and mining" +readme = "laser_encoders/README.md" +requires-python = ">=3.8" + +dependencies = [ + 'sacremoses==0.1.0', + 'unicategories>=0.1.2', + 'sentencepiece>=0.1.99', + 'numpy>=1.21.3', + 'torch>=1.10.0', + 'fairseq>=0.12.2', +] + +classifiers=[ + "License :: OSI Approved :: BSD License", + "Topic :: Scientific/Engineering", + "Development Status :: 4 - Beta", +] + +[project.urls] +"Homepage" = "https://github.com/facebookresearch/LASER" +"Bug Tracker" = "https://github.com/facebookresearch/LASER/issues" + +[project.optional-dependencies] + dev = [ + # Test + "pytest>=4.3.0", + # Format + "black==22.3.0", + "isort>=5.10.1", + # Linters + "mypy>=0.782", + "pylint>=2.8.0", + # Release + "flit>=3.5.1" + ] + +[tool.black] +# Black defaults are great ! 
+ +[tool.isort] +profile = "black" +skip_gitignore = true +skip_glob = ["website/*", "*.pyx"] + +[tool.mypy] +python_version = "3.8" +show_error_codes = true +check_untyped_defs = true + +ignore_missing_imports = true + +files = [ + "laser_encoders/" +] + +[tool.pytest.ini_options] +testpaths = ["laser_encoders"] +python_files = [ + "test_*.py", +] \ No newline at end of file diff --git a/source/embed.py b/source/embed.py index 3737d7f9..9260a27c 100644 --- a/source/embed.py +++ b/source/embed.py @@ -16,31 +16,25 @@ # The functions can be also imported into another Python code -import re +import argparse +import logging import os -import tempfile +import re import sys +import tempfile import time -import argparse -import numpy as np -import logging from collections import namedtuple -from subprocess import run from pathlib import Path -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - +from subprocess import run +from typing import Optional, Union -from lib.text_processing import Token, BPEfastApply, SPMApply +assert os.environ.get("LASER"), "Please set the environment variable LASER" +LASER = os.environ["LASER"] +sys.path.append(LASER) -from fairseq.models.transformer import ( - Embedding, - TransformerEncoder, -) -from fairseq.data.dictionary import Dictionary -from fairseq.modules import LayerNorm +import numpy as np +from lib.text_processing import BPEfastApply, SPMApply, Token +from laser_encoders.models import SentenceEncoder SPACE_NORMALIZER = re.compile(r"\s+") Batch = namedtuple("Batch", "srcs tokens lengths") @@ -48,8 +42,10 @@ logging.basicConfig( stream=sys.stdout, level=logging.INFO, - format="%(asctime)s | %(levelname)s | %(name)s | %(message)s") -logger = logging.getLogger('embed') + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", +) +logger = logging.getLogger("embed") + def buffered_read(fp, buffer_size): buffer = [] @@ -63,129 +59,10 @@ def buffered_read(fp, buffer_size): yield buffer -class 
SentenceEncoder: - def __init__( - self, - model_path, - max_sentences=None, - max_tokens=None, - spm_vocab=None, - cpu=False, - fp16=False, - verbose=False, - sort_kind="quicksort", - ): - if verbose: - logger.info(f"loading encoder: {model_path}") - self.use_cuda = torch.cuda.is_available() and not cpu - self.max_sentences = max_sentences - self.max_tokens = max_tokens - if self.max_tokens is None and self.max_sentences is None: - self.max_sentences = 1 - - state_dict = torch.load(model_path) - if "params" in state_dict: - self.encoder = LaserLstmEncoder(**state_dict["params"]) - self.encoder.load_state_dict(state_dict["model"]) - self.dictionary = state_dict["dictionary"] - self.prepend_bos = False - self.left_padding = False - else: - self.encoder = LaserTransformerEncoder(state_dict, spm_vocab) - self.dictionary = self.encoder.dictionary.indices - self.prepend_bos = state_dict["cfg"]["model"].prepend_bos - self.left_padding = state_dict["cfg"]["model"].left_pad_source - del state_dict - self.bos_index = self.dictionary[""] = 0 - self.pad_index = self.dictionary[""] = 1 - self.eos_index = self.dictionary[""] = 2 - self.unk_index = self.dictionary[""] = 3 - - if fp16: - self.encoder.half() - if self.use_cuda: - if verbose: - logger.info("transfer encoder to GPU") - self.encoder.cuda() - self.encoder.eval() - self.sort_kind = sort_kind - - def _process_batch(self, batch): - tokens = batch.tokens - lengths = batch.lengths - if self.use_cuda: - tokens = tokens.cuda() - lengths = lengths.cuda() - - with torch.no_grad(): - sentemb = self.encoder(tokens, lengths)["sentemb"] - embeddings = sentemb.detach().cpu().numpy() - return embeddings - - def _tokenize(self, line): - tokens = SPACE_NORMALIZER.sub(" ", line).strip().split() - ntokens = len(tokens) - if self.prepend_bos: - ids = torch.LongTensor(ntokens + 2) - ids[0] = self.bos_index - for i, token in enumerate(tokens): - ids[i + 1] = self.dictionary.get(token, self.unk_index) - ids[ntokens + 1] = self.eos_index - 
else: - ids = torch.LongTensor(ntokens + 1) - for i, token in enumerate(tokens): - ids[i] = self.dictionary.get(token, self.unk_index) - ids[ntokens] = self.eos_index - return ids - - def _make_batches(self, lines): - tokens = [self._tokenize(line) for line in lines] - lengths = np.array([t.numel() for t in tokens]) - indices = np.argsort(-lengths, kind=self.sort_kind) - - def batch(tokens, lengths, indices): - toks = tokens[0].new_full((len(tokens), tokens[0].shape[0]), self.pad_index) - if not self.left_padding: - for i in range(len(tokens)): - toks[i, : tokens[i].shape[0]] = tokens[i] - else: - for i in range(len(tokens)): - toks[i, -tokens[i].shape[0] :] = tokens[i] - return ( - Batch(srcs=None, tokens=toks, lengths=torch.LongTensor(lengths)), - indices, - ) - - batch_tokens, batch_lengths, batch_indices = [], [], [] - ntokens = nsentences = 0 - for i in indices: - if nsentences > 0 and ( - (self.max_tokens is not None and ntokens + lengths[i] > self.max_tokens) - or (self.max_sentences is not None and nsentences == self.max_sentences) - ): - yield batch(batch_tokens, batch_lengths, batch_indices) - ntokens = nsentences = 0 - batch_tokens, batch_lengths, batch_indices = [], [], [] - batch_tokens.append(tokens[i]) - batch_lengths.append(lengths[i]) - batch_indices.append(i) - ntokens += tokens[i].shape[0] - nsentences += 1 - if nsentences > 0: - yield batch(batch_tokens, batch_lengths, batch_indices) - - def encode_sentences(self, sentences): - indices = [] - results = [] - for batch, batch_indices in self._make_batches(sentences): - indices.extend(batch_indices) - results.append(self._process_batch(batch)) - return np.vstack(results)[np.argsort(indices, kind=self.sort_kind)] - - -class HuggingFaceEncoder(): +class HuggingFaceEncoder: def __init__(self, encoder_name: str, verbose=False): from sentence_transformers import SentenceTransformer + encoder = f"sentence-transformers/{encoder_name}" if verbose: logger.info(f"loading HuggingFace encoder: {encoder}") @@ 
-195,165 +72,13 @@ def encode_sentences(self, sentences): return self.encoder.encode(sentences) -class LaserTransformerEncoder(TransformerEncoder): - def __init__(self, state_dict, vocab_path): - self.dictionary = Dictionary.load(vocab_path) - if any( - k in state_dict["model"] - for k in ["encoder.layer_norm.weight", "layer_norm.weight"] - ): - self.dictionary.add_symbol("") - cfg = state_dict["cfg"]["model"] - self.sentemb_criterion = cfg.sentemb_criterion - self.pad_idx = self.dictionary.pad_index - self.bos_idx = self.dictionary.bos_index - embed_tokens = Embedding( - len(self.dictionary), cfg.encoder_embed_dim, self.pad_idx, - ) - super().__init__(cfg, self.dictionary, embed_tokens) - if "decoder.version" in state_dict["model"]: - self._remove_decoder_layers(state_dict) - if "layer_norm.weight" in state_dict["model"]: - self.layer_norm = LayerNorm(cfg.encoder_embed_dim) - self.load_state_dict(state_dict["model"]) - - def _remove_decoder_layers(self, state_dict): - for key in list(state_dict["model"].keys()): - if not key.startswith( - ( - "encoder.layer_norm", - "encoder.layers", - "encoder.embed", - "encoder.version", - ) - ): - del state_dict["model"][key] - else: - renamed_key = key.replace("encoder.", "") - state_dict["model"][renamed_key] = state_dict["model"].pop(key) - - def forward(self, src_tokens, src_lengths): - encoder_out = super().forward(src_tokens, src_lengths) - if isinstance(encoder_out, dict): - x = encoder_out["encoder_out"][0] # T x B x C - else: - x = encoder_out[0] - if self.sentemb_criterion == "cls": - cls_indices = src_tokens.eq(self.bos_idx).t() - sentemb = x[cls_indices, :] - else: - padding_mask = src_tokens.eq(self.pad_idx).t().unsqueeze(-1) - if padding_mask.any(): - x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) - sentemb = x.max(dim=0)[0] - return {"sentemb": sentemb} - - -class LaserLstmEncoder(nn.Module): - def __init__( - self, - num_embeddings, - padding_idx, - embed_dim=320, - hidden_size=512, - 
num_layers=1, - bidirectional=False, - left_pad=True, - padding_value=0.0, - ): - super().__init__() - - self.num_layers = num_layers - self.bidirectional = bidirectional - self.hidden_size = hidden_size - - self.padding_idx = padding_idx - self.embed_tokens = nn.Embedding( - num_embeddings, embed_dim, padding_idx=self.padding_idx - ) - - self.lstm = nn.LSTM( - input_size=embed_dim, - hidden_size=hidden_size, - num_layers=num_layers, - bidirectional=bidirectional, - ) - self.left_pad = left_pad - self.padding_value = padding_value - - self.output_units = hidden_size - if bidirectional: - self.output_units *= 2 - - def forward(self, src_tokens, src_lengths): - bsz, seqlen = src_tokens.size() - - # embed tokens - x = self.embed_tokens(src_tokens) - - # B x T x C -> T x B x C - x = x.transpose(0, 1) - - # pack embedded source tokens into a PackedSequence - packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) - - # apply LSTM - if self.bidirectional: - state_size = 2 * self.num_layers, bsz, self.hidden_size - else: - state_size = self.num_layers, bsz, self.hidden_size - h0 = x.data.new(*state_size).zero_() - c0 = x.data.new(*state_size).zero_() - packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) - - # unpack outputs and apply dropout - x, _ = nn.utils.rnn.pad_packed_sequence( - packed_outs, padding_value=self.padding_value - ) - assert list(x.size()) == [seqlen, bsz, self.output_units] - - if self.bidirectional: - - def combine_bidir(outs): - return torch.cat( - [ - torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view( - 1, bsz, self.output_units - ) - for i in range(self.num_layers) - ], - dim=0, - ) - - final_hiddens = combine_bidir(final_hiddens) - final_cells = combine_bidir(final_cells) - - encoder_padding_mask = src_tokens.eq(self.padding_idx).t() - - # Set padded outputs to -inf so they are not selected by max-pooling - padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) - if padding_mask.any(): - x = 
x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) - - # Build the sentence embedding by max-pooling over the encoder outputs - sentemb = x.max(dim=0)[0] - - return { - "sentemb": sentemb, - "encoder_out": (x, final_hiddens, final_cells), - "encoder_padding_mask": encoder_padding_mask - if encoder_padding_mask.any() - else None, - } - - def load_model( encoder: str, spm_model: str, bpe_codes: str, hugging_face=False, verbose=False, - **encoder_kwargs + **encoder_kwargs, ) -> Union[SentenceEncoder, HuggingFaceEncoder]: if hugging_face: return HuggingFaceEncoder(encoder, verbose=verbose) @@ -411,7 +136,6 @@ def EncodeFilep( logger.info(f"encoded {n} sentences in {EncodeTime(t)}") - # Encode sentences (file names) def EncodeFile( encoder, @@ -428,7 +152,8 @@ def EncodeFile( if verbose: logger.info( "encoding {} to {}".format( - inp_fname if len(inp_fname) > 0 else "stdin", out_fname, + inp_fname if len(inp_fname) > 0 else "stdin", + out_fname, ) ) fin = ( @@ -469,7 +194,7 @@ def embed_sentences( output: str, encoder: Union[SentenceEncoder, HuggingFaceEncoder] = None, encoder_path: str = None, - hugging_face = False, + hugging_face=False, token_lang: Optional[str] = "--", bpe_codes: Optional[str] = None, spm_lang: Optional[str] = "en", @@ -522,7 +247,7 @@ def embed_sentences( if bpe_codes: if ifname == "": # stdin ifname = os.path.join(tmpdir, "no_tok") - run(f'cat > {ifname}', shell=True) + run(f"cat > {ifname}", shell=True) bpe_fname = os.path.join(tmpdir, "bpe") BPEfastApply( ifname, bpe_fname, bpe_codes, verbose=verbose, over_write=False @@ -556,7 +281,11 @@ def embed_sentences( if __name__ == "__main__": parser = argparse.ArgumentParser(description="LASER: Embed sentences") parser.add_argument( - "-i", "--input", type=str, default=None, help="Input text file", + "-i", + "--input", + type=str, + default=None, + help="Input text file", ) parser.add_argument("--encoder", type=str, required=True, help="encoder to be used") parser.add_argument( @@ -608,7 
+337,9 @@ def embed_sentences( help="Algorithm used to sort batch by length", ) parser.add_argument( - "--use-hugging-face", action="store_true", help="Use a HuggingFace sentence transformer" + "--use-hugging-face", + action="store_true", + help="Use a HuggingFace sentence transformer", ) args = parser.parse_args() diff --git a/utils/requirements.txt b/utils/requirements.txt index 00816939..a47ad724 100644 --- a/utils/requirements.txt +++ b/utils/requirements.txt @@ -3,6 +3,6 @@ sentence-splitter==1.4 botok==0.8.8 khmer-nltk==1.5 LaoNLP==0.6 -sacremoses==0.0.43 +sacremoses==0.1.0 xxhash==3.0.0 emoji==1.7.0 \ No newline at end of file diff --git a/utils/setup.py b/utils/setup.py index d1710211..cb657a0b 100644 --- a/utils/setup.py +++ b/utils/setup.py @@ -15,7 +15,7 @@ "botok==0.8.8", "khmer-nltk==1.5", "LaoNLP==0.6", - "sacremoses==0.0.43", + "sacremoses==0.1.0", "xxhash==3.0.0", "emoji==1.7.0", ], diff --git a/utils/src/cleaner_splitter.py b/utils/src/cleaner_splitter.py index 9fd08814..8d7bdb47 100644 --- a/utils/src/cleaner_splitter.py +++ b/utils/src/cleaner_splitter.py @@ -20,7 +20,7 @@ def __init__(self, splitter_lang: str, split_algo: str): self.splitter = get_split_algo(splitter_lang, split_algo=split_algo) # setup "moses" normalization - self.mpn = MosesPunctNormalizer(lang="en") # TODO + self.mpn = MosesPunctNormalizer(lang="en", perl_parity=True) # TODO self.replace_nonprint = non_printing_char_replacer(" ") def __call__(self, line):