-
Notifications
You must be signed in to change notification settings - Fork 1
/
emojifinder.py
44 lines (36 loc) · 1.21 KB
/
emojifinder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
@dataclass
class Emoji:
    """A single emoji paired with the keywords that describe it."""

    symbol: str  # the emoji character itself
    keywords: list[str]  # descriptive terms; joined with spaces when embedded
def get_vectors(
    model: SentenceTransformer,
    emojis: list[Emoji],
    embeddings_path: Union[str, Path] = Path("embeddings.npy"),
) -> np.ndarray:
    """Return a keyword-embedding matrix for *emojis*, cached on disk.

    Each row of the returned matrix is the embedding of one emoji's
    keywords (joined with spaces), row-aligned with ``emojis``.

    Args:
        model: Sentence encoder used to embed the keyword strings.
        emojis: Emojis to embed; order determines row order.
        embeddings_path: Where the cached ``.npy`` matrix lives.

    Returns:
        Array of shape ``(len(emojis), dim)``.
    """
    embeddings_path = Path(embeddings_path)
    if embeddings_path.exists():
        # Reuse the cached matrix — but only if it still lines up with the
        # current emoji list. A stale cache (emoji list grew or shrank since
        # it was written) would silently misalign rows with `emojis`.
        embeddings = np.load(embeddings_path)
        if embeddings.shape[0] == len(emojis):
            return embeddings
    # No usable cache: embed the keyword strings and persist for next time.
    embeddings = model.encode(sentences=[" ".join(e.keywords) for e in emojis])
    np.save(embeddings_path, embeddings)
    return embeddings
def find_emoji(
    query: str,
    emojis: list[Emoji],
    model: SentenceTransformer,
    embeddings: Union[np.ndarray, torch.Tensor],
    n: int = 1,
) -> list[Emoji]:
    """Return the ``n`` emojis whose keyword embeddings best match *query*.

    Args:
        query: Free-text description of the desired emoji.
        emojis: Candidate emojis, row-aligned with ``embeddings``.
        model: Sentence encoder used to embed ``query``.
        embeddings: Matrix of keyword embeddings, one row per emoji.
        n: Number of best matches to return (best first).

    Returns:
        The top-``n`` most similar emojis.
    """
    # Embed the query as a tensor so cos_sim works directly against it.
    embedded_query: torch.Tensor = model.encode(query, convert_to_tensor=True)  # type: ignore
    # cos_sim yields a (1, len(emojis)) similarity matrix.
    sims = cos_sim(embedded_query, embeddings)
    # Rank candidates by similarity, highest first, and keep the top n.
    top_n = sims.argsort(descending=True)[0][:n]
    # int(i) makes the tensor -> list-index conversion explicit.
    return [emojis[int(i)] for i in top_n]