find_relevant_abstracts.py
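# NOTE: the module docstring below is an added sketch; the paths, checkpoint
# name and column names in the example invocation are hypothetical and only
# illustrate how the options defined in this script fit together.
"""Find the k most relevant patent abstracts for a set of free-text queries.

Each positional INPUT_FILES argument is a plain-text file whose entire
content is treated as one query; OUTPUT_PATH is the CSV file where one row
per (query, retrieved abstract) pair is written, with columns "query",
"query_id", "response", "response_id" and "score".

Example (hypothetical paths and column names):

    python find_relevant_abstracts.py \
        -mt huggingface \
        -mc models/my_checkpoint \
        -p mean \
        -i indexes/abstracts.faiss \
        -k 10 \
        -d data/patent_abstracts.csv \
        --id-column patent_id \
        --abstract-column abstract \
        queries/query_1.txt queries/query_2.txt results.csv
"""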
from pathlib import Path

import click
import pandas as pd

from patents4IPPC.embedders.utils import get_embedder
from patents4IPPC.similarity_search.faiss_ import FaissDocumentRetriever


@click.command()
@click.option(
    '-mt', '--model-type',
    type=click.Choice(['tfidf', 'glove', 'use', 'huggingface', 'dual']),
    required=True,
    help='Type of model to use for indexing the corpus.'
)
@click.option(
    '-mc', '--model-checkpoint', 'path_to_model_checkpoint',
    type=click.Path(exists=True),
    required=True,
    help='Path to a pre-trained model to use for finding relevant abstracts.'
)
@click.option(
    '-p', '--pooling-mode',
    type=click.Choice(['cls', 'max', 'mean']),
    default=None,
    help=('Pooling strategy for aggregating token embeddings into sentence '
          'embeddings. Required only when "--model-type" is "huggingface" '
          'or "dual".')
)
@click.option(
    '-i', '--index', 'path_to_faiss_index',
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help='Path to a FAISS index containing pre-computed response embeddings.'
)
@click.option(
    '-k', '--top-k', 'k',
    type=int,
    required=True,
    help='Number of relevant abstracts to retrieve for each query.'
)
@click.option(
    '-d', '--dataset', 'path_to_dataset',
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help=('Path to a dataset that will be used to map the retrieved IDs to '
          'the corresponding abstracts.')
)
@click.option(
    '--id-column',
    required=True,
    help='Name of the dataset column that represents patent IDs.'
)
@click.option(
    '--abstract-column',
    required=True,
    help='Name of the dataset column that represents patent abstracts.'
)
@click.option(
    '-g', '--use-gpu',
    is_flag=True,
    help='Use a GPU for retrieving relevant abstracts with FAISS.'
)
@click.argument(
    'input_files',
    type=click.Path(exists=True, dir_okay=False),
    nargs=-1
)
@click.argument(
    'output_path',
    type=click.Path(exists=False)
)
def main(
    model_type,
    path_to_model_checkpoint,
    pooling_mode,
    path_to_faiss_index,
    k,
    path_to_dataset,
    id_column,
    abstract_column,
    use_gpu,
    input_files,
    output_path
):
    # Read the query files (one query per input file)
    queries = [Path(f).read_text() for f in input_files]

    # Load an embedder
    print('Loading the embedder...')
    embedder = get_embedder(model_type, path_to_model_checkpoint, pooling_mode)

    # Embed the queries
    print('Embedding the queries...')
    query_embeddings = embedder.embed_documents(queries)
    # NOTE: "queries" is still needed below to build the results dataframe,
    # so only the embedder is released here
    del embedder

    # Find the k closest abstracts for each query. "scores" and "ids" are
    # arrays of shape (num_queries, k)
    retriever = FaissDocumentRetriever(
        path_to_faiss_index, use_gpu=use_gpu, verbose=True
    )
    scores, ids = retriever.find_closest_matches(query_embeddings, k=k)
    del retriever

    # Use "ids" to retrieve the actual abstracts from the dataset
    print('Retrieving abstracts from IDs...')
    ids_flat = ids.reshape((-1,))
    dataset = pd.read_csv(
        path_to_dataset,
        index_col=id_column,
        encoding='latin1'
    )
    closest_abstracts = dataset.loc[ids_flat, abstract_column].values
    del dataset

    # Save the results to disk, one row per (query, retrieved abstract) pair
    queries_flat = [q for q in queries for _ in range(k)]
    query_names_flat = [Path(f).name for f in input_files for _ in range(k)]
    scores_flat = scores.reshape((-1,))
    results = pd.DataFrame({
        'query': queries_flat,
        'query_id': query_names_flat,
        'response': closest_abstracts,
        'response_id': ids_flat,
        'score': scores_flat
    })
    results.to_csv(output_path, index=False)


if __name__ == '__main__':
    main()  # pylint: disable=no-value-for-parameter