Add normalize_embeddings option to encode_sentences
Paulooh007 committed Nov 7, 2023
1 parent 7d4c469 commit 2cc5434
Showing 1 changed file with 10 additions and 4 deletions.
laser_encoders/models.py (10 additions & 4 deletions)
@@ -167,15 +167,21 @@ def batch(tokens, lengths, indices):
         if nsentences > 0:
             yield batch(batch_tokens, batch_lengths, batch_indices)

-    def encode_sentences(self, sentences):
+    def encode_sentences(self, sentences, normalize_embeddings=False):
         indices = []
         results = []
         for batch, batch_indices in self._make_batches(sentences):
             indices.extend(batch_indices)
-            results.append(self._process_batch(batch))
+            encoded_batch = self._process_batch(batch)
+            if normalize_embeddings:
+                # Perform L2 normalization on the embeddings
+                norms = np.linalg.norm(encoded_batch, axis=1, keepdims=True)
+                encoded_batch = encoded_batch / norms
+            results.append(encoded_batch)
         return np.vstack(results)[np.argsort(indices, kind=self.sort_kind)]



 class LaserTransformerEncoder(TransformerEncoder):
     def __init__(self, state_dict, vocab_path):
         self.dictionary = Dictionary.load(vocab_path)
@@ -384,7 +390,7 @@ def __init__(
             lang=lang, model_dir=model_dir, spm=spm, laser=laser
         )

-    def encode_sentences(self, sentences: list) -> list:
+    def encode_sentences(self, sentences: list, normalize_embeddings: bool = False) -> list:
         """
         Tokenizes and encodes a list of sentences.
@@ -397,4 +403,4 @@ def encode_sentences(self, sentences: list) -> list:
         tokenized_sentences = [
             self.tokenizer.tokenize(sentence) for sentence in sentences
         ]
-        return self.encoder.encode_sentences(tokenized_sentences)
+        return self.encoder.encode_sentences(tokenized_sentences, normalize_embeddings)
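
For reference, a standalone sketch (illustrative NumPy values, not part of the commit) of the row-wise L2 normalization the first hunk adds: each embedding is divided by its own L2 norm, so every row becomes a unit vector and a plain dot product between two normalized embeddings equals their cosine similarity.

    # Standalone sketch of the row-wise L2 normalization added above (illustrative values).
    import numpy as np

    encoded_batch = np.array([[3.0, 4.0], [1.0, 1.0]])  # stand-in for a batch of embeddings

    norms = np.linalg.norm(encoded_batch, axis=1, keepdims=True)  # shape (n_sentences, 1)
    normalized = encoded_batch / norms

    print(np.linalg.norm(normalized, axis=1))    # [1. 1.] -- every row now has unit length
    print(float(normalized[0] @ normalized[1]))  # dot product of unit vectors == cosine similarity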

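A hedged usage sketch of the new flag, assuming the pipeline class changed around line 390 is the package's LaserEncoderPipeline and that laser_encoders is installed; the language code and printed checks are illustrative only.

    # Hypothetical usage of the new normalize_embeddings flag (names assumed, not from this diff).
    import numpy as np
    from laser_encoders import LaserEncoderPipeline

    encoder = LaserEncoderPipeline(lang="eng_Latn")  # downloads/loads the LASER model for English

    sentences = ["Hello world.", "How are you today?"]
    embeddings = encoder.encode_sentences(sentences, normalize_embeddings=True)

    # With normalization enabled every embedding has unit L2 norm,
    # so cosine similarity reduces to a dot product.
    print(np.linalg.norm(embeddings, axis=1))    # approximately [1. 1.]
    print(float(embeddings[0] @ embeddings[1]))  # cosine similarity between the two sentences

Returning unit-length vectors lets downstream code compare sentences with a plain dot product instead of a separate cosine-similarity step.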