Commit

refactor: modified the sentence encoder to tokenize text before encoding it
CaptainVee committed Sep 8, 2023
1 parent b97fd24 commit d8e6983
Showing 2 changed files with 17 additions and 1 deletion.
laser_encoders/download_models.py: 8 additions, 1 deletion
@@ -117,6 +117,7 @@ def initialize_encoder(
     model_dir: str = None,
     spm: bool = True,
     laser: str = None,
+    tokenize: bool = None,
 ):
     downloader = LaserModelDownloader(model_dir)
     if laser is not None:
@@ -147,11 +148,17 @@ def initialize_encoder(
     model_dir = downloader.model_dir
     model_path = os.path.join(model_dir, f"{file_path}.pt")
     spm_path = os.path.join(model_dir, f"{file_path}.cvocab")
+    spm_model = None
+    if tokenize:
+        spm_model = os.path.join(model_dir, f"{file_path}.spm")

     if not os.path.exists(spm_path):
         # if there is no cvocab for the laser3 lang use laser2 cvocab
         spm_path = os.path.join(model_dir, "laser2.cvocab")
-    return SentenceEncoder(model_path=model_path, spm_vocab=spm_path)
+        spm_model = os.path.join(model_dir, "laser2.spm")
+    return SentenceEncoder(
+        model_path=model_path, spm_vocab=spm_path, spm_model=spm_model
+    )


 def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = None):
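For context, a minimal usage sketch of the new tokenize flag, assuming the package layout at this commit; the language name and model directory below are placeholders, not values taken from the diff:

# Hedged usage sketch: "igbo" and "./laser_models" are hypothetical arguments.
from laser_encoders.download_models import initialize_encoder

# With tokenize=True, initialize_encoder also resolves the SentencePiece .spm file
# and passes it to the returned SentenceEncoder, so raw text can be encoded directly.
encoder = initialize_encoder(lang="igbo", model_dir="./laser_models", tokenize=True)
embeddings = encoder.encode_sentences(["A raw, untokenized sentence."])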
laser_encoders/models.py: 9 additions, 0 deletions
@@ -17,6 +17,7 @@
 import re
 import sys
 from collections import namedtuple
+from pathlib import Path

 import numpy as np
 import torch
@@ -25,6 +26,8 @@
 from fairseq.models.transformer import Embedding, TransformerEncoder
 from fairseq.modules import LayerNorm

+from laser_encoders.laser_tokenizer import LaserTokenizer
+
 SPACE_NORMALIZER = re.compile(r"\s+")
 Batch = namedtuple("Batch", "srcs tokens lengths")

@@ -43,13 +46,15 @@ def __init__(
         max_sentences=None,
         max_tokens=None,
         spm_vocab=None,
+        spm_model=None,
         cpu=False,
         fp16=False,
         verbose=False,
         sort_kind="quicksort",
     ):
         if verbose:
             logger.info(f"loading encoder: {model_path}")
+        self.spm_model = spm_model
         self.use_cuda = torch.cuda.is_available() and not cpu
         self.max_sentences = max_sentences
         self.max_tokens = max_tokens
@@ -148,6 +153,10 @@ def batch(tokens, lengths, indices):
             yield batch(batch_tokens, batch_lengths, batch_indices)

     def encode_sentences(self, sentences):
+        if self.spm_model:
+            tokenizer = LaserTokenizer(spm_model=Path(self.spm_model))
+            sentences = tokenizer(sentences)
+
         indices = []
         results = []
         for batch, batch_indices in self._make_batches(sentences):
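A minimal sketch of what the new spm_model hook does in models.py, assuming the LASER2 model, vocab, and SentencePiece files have already been downloaded; the paths below are placeholders:

# Hedged sketch: the paths are hypothetical locations of downloaded LASER2 files.
from laser_encoders.models import SentenceEncoder

encoder = SentenceEncoder(
    model_path="./laser_models/laser2.pt",
    spm_vocab="./laser_models/laser2.cvocab",
    spm_model="./laser_models/laser2.spm",  # setting this enables internal tokenization
)
# Because spm_model is set, encode_sentences builds a LaserTokenizer from it and
# tokenizes the raw sentences before batching and encoding them.
embeddings = encoder.encode_sentences(["Raw text goes straight in."])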
