diff --git a/laser_encoders/__init__.py b/laser_encoders/__init__.py
index 75264c55..bd01969b 100644
--- a/laser_encoders/__init__.py
+++ b/laser_encoders/__init__.py
@@ -12,4 +12,8 @@
 #
 # -------------------------------------------------------

-from laser_encoders.download_models import initialize_encoder, initialize_tokenizer
+from laser_encoders.download_models import (
+    LaserEncoderPipeline,
+    initialize_encoder,
+    initialize_tokenizer,
+)
diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py
index 0a19bca2..0f585f2f 100644
--- a/laser_encoders/download_models.py
+++ b/laser_encoders/download_models.py
@@ -160,10 +160,8 @@ def initialize_encoder(
     if not os.path.exists(spm_vocab):
         # if there is no cvocab for the laser3 lang use laser2 cvocab
         spm_vocab = os.path.join(model_dir, "laser2.cvocab")
-
-    return SentenceEncoder(
-        model_path=model_path, spm_vocab=spm_vocab, spm_model=None
-    )
+
+    return SentenceEncoder(model_path=model_path, spm_vocab=spm_vocab, spm_model=None)


 def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = None):
@@ -199,24 +197,32 @@ def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = N


 class LaserEncoderPipeline:
-    def __init__(self, lang: str, model_dir: str = None, spm: bool = True, laser: str = None):
+    def __init__(
+        self, lang: str, model_dir: str = None, spm: bool = True, laser: str = None
+    ):
+        self.tokenizer = initialize_tokenizer(
+            lang=lang, model_dir=model_dir, laser=laser
+        )
+        self.encoder = initialize_encoder(
+            lang=lang, model_dir=model_dir, spm=spm, laser=laser
+        )

-        self.tokenizer = initialize_tokenizer(lang=lang, model_dir=model_dir, laser=laser)
-        self.encoder = initialize_encoder(lang=lang, model_dir=model_dir, spm=spm,laser=laser)
-
     def encode_sentences(self, sentences: list) -> list:
         """
         Tokenizes and encodes a list of sentences.
-
+
         Args:
         - sentences (list of str): List of sentences to tokenize and encode.

         Returns:
         - List of embeddings for each sentence.
         """
-        tokenized_sentences = [self.tokenizer.tokenize(sentence) for sentence in sentences]
+        tokenized_sentences = [
+            self.tokenizer.tokenize(sentence) for sentence in sentences
+        ]
         return self.encoder.encode_sentences(tokenized_sentences)

+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="LASER: Download Laser models")
     parser.add_argument(
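
For context, a minimal usage sketch of the `LaserEncoderPipeline` that this diff starts exporting from the package root. The language code `"eng_Latn"` and the sample sentences are illustrative assumptions, not taken from the diff:

```python
# Sketch only: assumes laser_encoders is installed and the model for the
# requested language can be downloaded or is already cached locally.
from laser_encoders import LaserEncoderPipeline

# The constructor wires the two previously separate entry points together,
# calling initialize_tokenizer() and initialize_encoder() for one language.
encoder = LaserEncoderPipeline(lang="eng_Latn")  # language code is an assumption

# encode_sentences() tokenizes each input and returns one embedding per sentence.
embeddings = encoder.encode_sentences(["Hello world.", "LASER embeds sentences."])
print(len(embeddings))  # 2: one embedding per input sentence
```

Before this change, callers had to invoke `initialize_tokenizer` and `initialize_encoder` separately and pass tokenized text between them; the pipeline class folds that two-step flow into a single object.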