diff --git a/laser_encoders/models.py b/laser_encoders/models.py
index 037a4f9f..8b25bf50 100644
--- a/laser_encoders/models.py
+++ b/laser_encoders/models.py
@@ -167,15 +167,21 @@ def batch(tokens, lengths, indices):
             if nsentences > 0:
                 yield batch(batch_tokens, batch_lengths, batch_indices)
 
-    def encode_sentences(self, sentences):
+    def encode_sentences(self, sentences, normalize_embeddings=False):
         indices = []
         results = []
         for batch, batch_indices in self._make_batches(sentences):
             indices.extend(batch_indices)
-            results.append(self._process_batch(batch))
+            encoded_batch = self._process_batch(batch)
+            if normalize_embeddings:
+                # Perform L2 normalization on the embeddings
+                norms = np.linalg.norm(encoded_batch, axis=1, keepdims=True)
+                encoded_batch = encoded_batch / norms
+            results.append(encoded_batch)
         return np.vstack(results)[np.argsort(indices, kind=self.sort_kind)]
 
 
+
 class LaserTransformerEncoder(TransformerEncoder):
     def __init__(self, state_dict, vocab_path):
         self.dictionary = Dictionary.load(vocab_path)
@@ -384,7 +390,7 @@ def __init__(
             lang=lang, model_dir=model_dir, spm=spm, laser=laser
         )
 
-    def encode_sentences(self, sentences: list) -> list:
+    def encode_sentences(self, sentences: list, normalize_embeddings: bool = False) -> list:
         """
         Tokenizes and encodes a list of sentences.
 
@@ -397,4 +403,4 @@ def encode_sentences(self, sentences: list) -> list:
         tokenized_sentences = [
             self.tokenizer.tokenize(sentence) for sentence in sentences
         ]
-        return self.encoder.encode_sentences(tokenized_sentences)
+        return self.encoder.encode_sentences(tokenized_sentences, normalize_embeddings)
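
For reviewers, a minimal usage sketch of the new flag. The wrapper class name LaserEncoderPipeline and the lang value are assumptions (only the __init__ arguments and encode_sentences are visible in the hunks above); with normalize_embeddings=True every returned row has unit L2 norm, so cosine similarity between two sentences reduces to a plain dot product.

    # Usage sketch, not part of the patch. LaserEncoderPipeline and
    # lang="eng_Latn" are assumptions; only __init__'s arguments and
    # encode_sentences appear in the diff above.
    import numpy as np

    from laser_encoders import LaserEncoderPipeline

    encoder = LaserEncoderPipeline(lang="eng_Latn")
    sentences = ["Hello world.", "Bonjour le monde."]

    raw = encoder.encode_sentences(sentences)                              # as before
    unit = encoder.encode_sentences(sentences, normalize_embeddings=True)  # new flag

    # Each normalized row has unit L2 norm, so cosine similarity is just a
    # dot product between rows.
    assert np.allclose(np.linalg.norm(unit, axis=1), 1.0)
    cosine_similarity = unit[0] @ unit[1]

The default of normalize_embeddings=False keeps the existing behavior, so current callers of encode_sentences are unaffected.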