From 8175af90e2e9c0b5956ff8dceb9fc41f70d118a1 Mon Sep 17 00:00:00 2001
From: paul
Date: Thu, 19 Oct 2023 07:03:20 +0100
Subject: [PATCH] refactor: Change initialize_encoder to LaserEncoderPipeline
 and set tokenize default to true

---
 laser_encoders/__init__.py             | 2 +-
 laser_encoders/download_models.py      | 4 ++--
 laser_encoders/test_laser_tokenizer.py | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/laser_encoders/__init__.py b/laser_encoders/__init__.py
index 75264c55..e355084e 100644
--- a/laser_encoders/__init__.py
+++ b/laser_encoders/__init__.py
@@ -12,4 +12,4 @@
 #
 # -------------------------------------------------------
 
-from laser_encoders.download_models import initialize_encoder, initialize_tokenizer
+from laser_encoders.download_models import LaserEncoderPipeline, initialize_tokenizer

diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py
index 452501d3..6813b30c 100644
--- a/laser_encoders/download_models.py
+++ b/laser_encoders/download_models.py
@@ -121,12 +121,12 @@ def main(self, args):
     )
 
 
-def initialize_encoder(
+def LaserEncoderPipeline(
     lang: str = None,
     model_dir: str = None,
     spm: bool = True,
     laser: str = None,
-    tokenize: bool = False,
+    tokenize: bool = True,
 ):
     downloader = LaserModelDownloader(model_dir)
     if laser is not None:

diff --git a/laser_encoders/test_laser_tokenizer.py b/laser_encoders/test_laser_tokenizer.py
index 867111cf..6f0787c5 100644
--- a/laser_encoders/test_laser_tokenizer.py
+++ b/laser_encoders/test_laser_tokenizer.py
@@ -21,7 +21,7 @@
 import numpy as np
 import pytest
 
-from laser_encoders import initialize_encoder, initialize_tokenizer
+from laser_encoders import LaserEncoderPipeline, initialize_tokenizer
 
 
 @pytest.fixture
@@ -168,7 +168,7 @@ def test_sentence_encoder(
     lang: str,
     input_text: str,
 ):
-    sentence_encoder = initialize_encoder(model_dir=tmp_path, laser=laser, lang=lang)
+    sentence_encoder = LaserEncoderPipeline(model_dir=tmp_path, laser=laser, lang=lang, tokenize=False)
     tokenized_text = tokenizer.tokenize(input_text)
     sentence_embedding = sentence_encoder.encode_sentences([tokenized_text])
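
Usage sketch (not part of the patch): a minimal illustration of the renamed entry point, assuming that with the new tokenize=True default the returned encoder tokenizes raw text internally, while tokenize=False preserves the manual-tokenization flow exercised by the updated test_sentence_encoder. The language code and sentence are placeholders.

from laser_encoders import LaserEncoderPipeline, initialize_tokenizer

# New default (tokenize=True): pass raw text straight to the pipeline.
encoder = LaserEncoderPipeline(lang="igbo")
embeddings = encoder.encode_sentences(["nnọọ, kedu ka ị mere"])

# Pre-patch behavior: disable internal tokenization and tokenize manually,
# mirroring the updated call in test_sentence_encoder.
tokenizer = initialize_tokenizer(lang="igbo")
raw_encoder = LaserEncoderPipeline(lang="igbo", tokenize=False)
embeddings = raw_encoder.encode_sentences([tokenizer.tokenize("nnọọ, kedu ka ị mere")])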