diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py index ccfc31e9..c4ace448 100644 --- a/laser_encoders/download_models.py +++ b/laser_encoders/download_models.py @@ -117,6 +117,7 @@ def initialize_encoder( model_dir: str = None, spm: bool = True, laser: str = None, + tokenize: bool = None, ): downloader = LaserModelDownloader(model_dir) if laser is not None: @@ -147,11 +148,17 @@ def initialize_encoder( model_dir = downloader.model_dir model_path = os.path.join(model_dir, f"{file_path}.pt") spm_path = os.path.join(model_dir, f"{file_path}.cvocab") + spm_model = None + if tokenize: + spm_model = os.path.join(model_dir, f"{file_path}.spm") if not os.path.exists(spm_path): # if there is no cvocab for the laser3 lang use laser2 cvocab spm_path = os.path.join(model_dir, "laser2.cvocab") - return SentenceEncoder(model_path=model_path, spm_vocab=spm_path) + spm_model = os.path.join(model_dir, "laser2.spm") + return SentenceEncoder( + model_path=model_path, spm_vocab=spm_path, spm_model=spm_model + ) def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = None): diff --git a/laser_encoders/models.py b/laser_encoders/models.py index 7ce0e326..0a36d49b 100644 --- a/laser_encoders/models.py +++ b/laser_encoders/models.py @@ -17,6 +17,7 @@ import re import sys from collections import namedtuple +from pathlib import Path import numpy as np import torch @@ -25,6 +26,8 @@ from fairseq.models.transformer import Embedding, TransformerEncoder from fairseq.modules import LayerNorm +from laser_encoders.laser_tokenizer import LaserTokenizer + SPACE_NORMALIZER = re.compile(r"\s+") Batch = namedtuple("Batch", "srcs tokens lengths") @@ -43,6 +46,7 @@ def __init__( max_sentences=None, max_tokens=None, spm_vocab=None, + spm_model=None, cpu=False, fp16=False, verbose=False, @@ -50,6 +54,7 @@ def __init__( ): if verbose: logger.info(f"loading encoder: {model_path}") + self.spm_model = spm_model self.use_cuda = torch.cuda.is_available() and not cpu self.max_sentences = max_sentences self.max_tokens = max_tokens @@ -148,6 +153,10 @@ def batch(tokens, lengths, indices): yield batch(batch_tokens, batch_lengths, batch_indices) def encode_sentences(self, sentences): + if self.spm_model: + tokenizer = LaserTokenizer(spm_model=Path(self.spm_model)) + sentences = tokenizer(sentences) + indices = [] results = [] for batch, batch_indices in self._make_batches(sentences):