-
Notifications
You must be signed in to change notification settings - Fork 1
/
NER.py
152 lines (127 loc) · 6.21 KB
/
NER.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Shamelessly stollen and modified from Chunyu via: https://github.com/RTXteam/NL2TRAPI/blob/main/srcs/TRAPI_NER.py
## Import libraries
import os, sys
from utils import get_logger
from typing import Union, List
import spacy
import scispacy
from scispacy.linking import EntityLinker
from scispacy.abbreviation import AbbreviationDetector
import json
class TRAPI_NER:
def __init__(self, synonymizer_dir: str, synonymizer_dbname: str,
spacy_model: str = 'en_core_sci_lg',
linker_name: Union[str, List] = 'umls',
threshold: float = 0.99,
num_neighbors: int = 1,
max_entities_per_mention: int = 1):
# Setup logger
self.logger = get_logger()
# Import Node Synonymizer
sys.path.append(synonymizer_dir)
from node_synonymizer import NodeSynonymizer
self.synonymizer = NodeSynonymizer(synonymizer_dir, synonymizer_dbname)
# Check if the input linker_name is valid
if type(linker_name) not in [str, list]:
raise ValueError("linker_name must be a string or list of strings")
if type(linker_name) is str:
linker_name = [linker_name]
self.logger.info("Input linker_name: {}".format(linker_name))
intersect_link_names = list(set([x.lower() for x in linker_name]).intersection(set(['umls', 'mesh', 'go',
'hpo', 'rxnorm'])))
if len(intersect_link_names) == 0:
raise ValueError("linker_name must be one of: umls, mesh or both of them")
self.available_linker_names = intersect_link_names
# Load spacy models
for x in self.available_linker_names:
self.logger.info("Setting up NER with linker: {}".format(x))
model_name = f"nlp_{x}"
setattr(self, model_name, spacy.load(spacy_model))
nlp = getattr(self, model_name)
nlp.add_pipe("abbreviation_detector")
nlp.add_pipe("scispacy_linker",
config={"resolve_abbreviations": True, "linker_name": x, "threshold": threshold,
"k": num_neighbors, "max_entities_per_mention": max_entities_per_mention})
def _get_preferred_curies_info(self, query: Union[str, List]) -> dict:
"""
Helper function to fetch information on preferred curies.
Args:
- query (list or str): List of curies or an entity to get information for.
Returns:
- dict: Information on preferred curies.
"""
# query node synonymizer
if isinstance(query, list):
res = self.synonymizer.get_canonical_curies(query)
else:
res = self.synonymizer.get_canonical_curies(names=query)
temp_dict = {}
for curie, content in res.items():
if content:
preferred_curie = content['preferred_curie']
preferred_info = {'preferred_name': content['preferred_name'],
'preferred_category': content['preferred_category'],
'matched_synonyms': temp_dict.get(preferred_curie, {}).get('matched_synonyms', []) + [
curie]}
temp_dict[preferred_curie] = preferred_info
# Sort based on the number of matched synonyms
return sorted(temp_dict.items(), key=lambda x: len(x[1]['matched_synonyms']), reverse=True)
def get_kg2_match(self, sentence: str, remove_mark: bool = True):
"""
Extract entities from a sentence and return the KG2 matched preferred curies
Args:
- sentence (str): The input sentence to extract entities from.
- remove_mark (bool): Flag to decide whether to remove punctuation marks.
Returns:
- dict: Matched entities to their preferred curies.
"""
# Validate the input sentence
if not isinstance(sentence, str) or not sentence:
return {}
# Remove punctuation marks if needed
if remove_mark:
sentence = sentence.translate(str.maketrans("", "", ".,;:?!"))
# Extract entities using available linkers
self._detected_entities = {}
for linker_name in self.available_linker_names:
doc = getattr(self, f"nlp_{linker_name}")(sentence)
for ent in doc.ents:
self._detected_entities.setdefault(ent.text, set()).update(
[f"{linker_name.upper()}:{x[0]}" for x in ent._.kb_ents])
# Obtain preferred curies for detected entities
matched_dict = {}
if len(self._detected_entities) > 0:
for entity, curies in self._detected_entities.items():
curie_to_info = self._get_preferred_curies_info(list(curies) if curies else entity)
if curie_to_info:
matched_dict[entity] = curie_to_info
else:
curie_to_info = self._get_preferred_curies_info(sentence)
if curie_to_info:
matched_dict[sentence] = curie_to_info
return matched_dict
def check_equivalent(self, name: str, curie: str):
"""
Check if the given name of UMLS/MESH curie is equivalent can match to its preferred curie.
Args:
- name (str): Name to be checked.
- curie (str): UMLS/MESH curie.
Returns:
- bool: True if name matches the preferred curie, otherwise False.
"""
# Check if the given curie is UMLS or MESH
if not (type(curie) is str and (curie.startswith('UMLS:') or curie.startswith('MESH:')) and type(name) is str):
return False
# Get preferred curie of the given curie
res = self.synonymizer.get_canonical_curies(curie)
res = [content['preferred_curie'] for _, content in res.items() if content]
if len(res) == 0:
return False
else:
preferred_curie = res[0]
# Check if the given name can match to the preferred curie
res = self.get_kg2_match(name, use_synonymizer=False, remove_mark=True)
if len(res) == 0:
return False
else:
return preferred_curie in [res[key][0][0] for key in res]