Skip to content

Commit

Permalink
refactor: Change initialize_encoder to LaserEncoderPipeline and set t…
Browse files Browse the repository at this point in the history
…okenize default to true
  • Loading branch information
Paulooh007 committed Oct 19, 2023
1 parent e3257c1 commit 8175af9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion laser_encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
#
# -------------------------------------------------------

from laser_encoders.download_models import initialize_encoder, initialize_tokenizer
from laser_encoders.download_models import LaserEncoderPipeline, initialize_tokenizer
4 changes: 2 additions & 2 deletions laser_encoders/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ def main(self, args):
)


def initialize_encoder(
def LaserEncoderPipeline(
lang: str = None,
model_dir: str = None,
spm: bool = True,
laser: str = None,
tokenize: bool = False,
tokenize: bool = True,
):
downloader = LaserModelDownloader(model_dir)
if laser is not None:
Expand Down
4 changes: 2 additions & 2 deletions laser_encoders/test_laser_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import pytest

from laser_encoders import initialize_encoder, initialize_tokenizer
from laser_encoders import LaserEncoderPipeline, initialize_tokenizer


@pytest.fixture
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_sentence_encoder(
lang: str,
input_text: str,
):
sentence_encoder = initialize_encoder(model_dir=tmp_path, laser=laser, lang=lang)
sentence_encoder = LaserEncoderPipeline(model_dir=tmp_path, laser=laser, lang=lang, tokenize=False)
tokenized_text = tokenizer.tokenize(input_text)
sentence_embedding = sentence_encoder.encode_sentences([tokenized_text])

Expand Down

0 comments on commit 8175af9

Please sign in to comment.