style: Reformat code using black
Paulooh007 committed Oct 22, 2023
1 parent 049f2e2 commit 67ba8bb
Showing 2 changed files with 21 additions and 11 deletions.
6 changes: 5 additions & 1 deletion laser_encoders/__init__.py
@@ -12,4 +12,8 @@
 #
 # -------------------------------------------------------

-from laser_encoders.download_models import initialize_encoder, initialize_tokenizer
+from laser_encoders.download_models import (
+    LaserEncoderPipeline,
+    initialize_encoder,
+    initialize_tokenizer,
+)
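
With this change, LaserEncoderPipeline is re-exported at the package level alongside the two initializer helpers, so all three names can be imported straight from laser_encoders. A minimal sketch of the resulting import path (illustrative, not part of the diff):

    from laser_encoders import (
        LaserEncoderPipeline,
        initialize_encoder,
        initialize_tokenizer,
    )
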
26 changes: 16 additions & 10 deletions laser_encoders/download_models.py
@@ -160,10 +160,8 @@ def initialize_encoder(
     if not os.path.exists(spm_vocab):
         # if there is no cvocab for the laser3 lang use laser2 cvocab
         spm_vocab = os.path.join(model_dir, "laser2.cvocab")

-    return SentenceEncoder(
-        model_path=model_path, spm_vocab=spm_vocab, spm_model=None
-    )
+    return SentenceEncoder(model_path=model_path, spm_vocab=spm_vocab, spm_model=None)


 def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = None):
@@ -199,24 +197,32 @@ def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = None):


 class LaserEncoderPipeline:
-    def __init__(self, lang: str, model_dir: str = None, spm: bool = True, laser: str = None):
+    def __init__(
+        self, lang: str, model_dir: str = None, spm: bool = True, laser: str = None
+    ):
+        self.tokenizer = initialize_tokenizer(
+            lang=lang, model_dir=model_dir, laser=laser
+        )
+        self.encoder = initialize_encoder(
+            lang=lang, model_dir=model_dir, spm=spm, laser=laser
+        )

-        self.tokenizer = initialize_tokenizer(lang=lang, model_dir=model_dir, laser=laser)
-        self.encoder = initialize_encoder(lang=lang, model_dir=model_dir, spm=spm,laser=laser)

     def encode_sentences(self, sentences: list) -> list:
         """
         Tokenizes and encodes a list of sentences.
         Args:
         - sentences (list of str): List of sentences to tokenize and encode.
         Returns:
         - List of embeddings for each sentence.
         """
-        tokenized_sentences = [self.tokenizer.tokenize(sentence) for sentence in sentences]
+        tokenized_sentences = [
+            self.tokenizer.tokenize(sentence) for sentence in sentences
+        ]
         return self.encoder.encode_sentences(tokenized_sentences)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="LASER: Download Laser models")
     parser.add_argument(
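
The reformatting does not change behavior: the constructor still builds a tokenizer and an encoder through initialize_tokenizer and initialize_encoder, and encode_sentences tokenizes its inputs before handing them to the encoder. A minimal usage sketch, assuming the package is installed and the model files can be downloaded; the language code and sentences below are placeholders, not part of this commit:

    from laser_encoders import LaserEncoderPipeline

    # Builds the tokenizer/encoder pair shown in the diff above
    # (fetching model files if they are not already present).
    pipeline = LaserEncoderPipeline(lang="igbo")  # hypothetical language code

    # One embedding per input sentence.
    embeddings = pipeline.encode_sentences(["Hello world.", "Good morning."])
    print(len(embeddings))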
