-
Notifications
You must be signed in to change notification settings - Fork 9
/
sensebert.py
93 lines (74 loc) · 3.81 KB
/
sensebert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
from collections import namedtuple
import tensorflow as tf
from tokenization import FullTokenizer
# Lightweight container for the tensors needed to drive a restored SenseBERT
# graph: the two input placeholders plus the three output tensors fetched in
# SenseBert.run().
# NOTE(review): 'supersense_losits' looks like a typo for 'supersense_logits',
# but the field name is referenced elsewhere in this file, so it is preserved.
_SenseBertGraph = namedtuple(
    'SenseBertGraph',
    ('input_ids', 'input_mask', 'contextualized_embeddings', 'mlm_logits', 'supersense_losits')
)
# Known model names mapped to their public Google Cloud Storage locations.
_MODEL_PATHS = {
    'sensebert-base-uncased': 'gs://ai21-public-models/sensebert-base-uncased',
    'sensebert-large-uncased': 'gs://ai21-public-models/sensebert-large-uncased'
}
# Tensor holding the contextualized token embeddings. It is not part of the
# serving signature, so _load_model fetches it from the graph by name.
# assumes the exported graph names this op 'bert/encoder/Reshape_13' — TODO confirm
_CONTEXTUALIZED_EMBEDDINGS_TENSOR_NAME = "bert/encoder/Reshape_13:0"
def _get_model_path(name_or_path, is_tokenizer=False):
    """Resolve a model/tokenizer identifier to a loadable location.

    Known names are looked up in _MODEL_PATHS; anything else is assumed to
    already be a filesystem path or URL and is returned unchanged.
    """
    kind = 'tokenizer' if is_tokenizer else 'model'
    try:
        resolved = _MODEL_PATHS[name_or_path]
    except KeyError:
        print(f"This is not a known {kind}. "
              f"Assuming {name_or_path} is a path or a url...")
        return name_or_path
    print(f"Loading the known {kind} '{name_or_path}'")
    return resolved
def load_tokenizer(name_or_path):
    """Build a FullTokenizer from the vocab files shipped alongside a model.

    Expects 'vocab.txt' and 'supersense_vocab.txt' to live under the resolved
    model directory.
    """
    base = _get_model_path(name_or_path, is_tokenizer=True)
    return FullTokenizer(
        vocab_file=os.path.join(base, "vocab.txt"),
        senses_file=os.path.join(base, "supersense_vocab.txt"),
    )
def _load_model(name_or_path, session=None):
    """Restore a SenseBERT SavedModel into a TF session and return a
    _SenseBertGraph holding its input placeholders and output tensors.

    NOTE(review): this uses the TF1-style SavedModel loader
    (tf.saved_model.load with sess/tags keyword arguments) and, when
    `session` is None, assumes a default session is active — confirm
    against the project's pinned TensorFlow version.
    """
    if session is None:
        session = tf.get_default_session()
    # Restores the graph definition and variables into `session`.
    model = tf.saved_model.load(export_dir=_get_model_path(name_or_path), sess=session, tags=[tf.saved_model.SERVING])
    # The default serving signature maps logical input/output names to
    # concrete tensor names in the restored graph.
    serve_def = model.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    # Resolve both the inputs map and the outputs map to actual graph tensors.
    inputs, outputs = ({key: session.graph.get_tensor_by_name(info.name) for key, info in puts.items()}
                       for puts in (serve_def.inputs, serve_def.outputs))
    return _SenseBertGraph(
        input_ids=inputs['input_ids'],
        input_mask=inputs['input_mask'],
        # Not exposed through the signature; fetched directly by tensor name.
        contextualized_embeddings=session.graph.get_tensor_by_name(_CONTEXTUALIZED_EMBEDDINGS_TENSOR_NAME),
        # 'losits' is a typo inherited from the namedtuple field name.
        supersense_losits=outputs['ss'],
        mlm_logits=outputs['masked_lm']
    )
class SenseBert:
    """High-level wrapper around a SenseBERT SavedModel.

    Bundles the restored TF graph, a TF session, and the matching tokenizer,
    and exposes `tokenize` to turn raw strings into padded model inputs and
    `run` to fetch the model's outputs for those inputs.
    """

    def __init__(self, name_or_path, max_seq_length=512, session=None):
        """
        :param name_or_path: a known model name (see _MODEL_PATHS) or a
            path/URL to a SavedModel directory.
        :param max_seq_length: maximum token count (incl. start/end symbols)
            allowed per input; longer inputs raise ValueError in tokenize().
        :param session: TF session to use; defaults to the current default
            session when omitted.
        """
        self.max_seq_length = max_seq_length
        self.session = session if session else tf.get_default_session()
        self.model = _load_model(name_or_path, session=self.session)
        self.tokenizer = load_tokenizer(name_or_path)

    def tokenize(self, inputs):
        """
        Gets a string or a list of strings, and returns a tuple (input_ids, input_mask) to use as inputs for SenseBERT.
        Both share the same shape: [batch_size, sequence_length] where sequence_length is the maximal sequence length.

        Raises ValueError on an empty batch or on an input whose tokenized
        length (including start/end symbols) exceeds max_seq_length.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        if not inputs:
            # Without this guard, max() below fails with a cryptic error.
            raise ValueError("tokenize() requires at least one input string")

        # Tokenize every input, wrapping each in the start/end symbols.
        all_token_ids = []
        for inp in inputs:
            tokens = [self.tokenizer.start_sym] + self.tokenizer.tokenize(inp)[0] + [self.tokenizer.end_sym]
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O` and give no message.
            if len(tokens) > self.max_seq_length:
                raise ValueError(
                    f"Input of {len(tokens)} tokens exceeds max_seq_length={self.max_seq_length}"
                )
            all_token_ids.append(self.tokenizer.convert_tokens_to_ids(tokens))

        # Pad every sequence to the longest one in this batch.
        max_len = max(len(token_ids) for token_ids in all_token_ids)
        pad_sym_id = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_sym])
        input_ids, input_mask = [], []
        for token_ids in all_token_ids:
            to_pad = max_len - len(token_ids)
            input_ids.append(token_ids + pad_sym_id * to_pad)
            input_mask.append([1] * len(token_ids) + [0] * to_pad)

        return input_ids, input_mask

    def run(self, input_ids, input_mask):
        """Feed tokenized inputs to the graph and return
        (contextualized_embeddings, mlm_logits, supersense_logits)."""
        return self.session.run(
            [self.model.contextualized_embeddings, self.model.mlm_logits, self.model.supersense_losits],
            feed_dict={self.model.input_ids: input_ids, self.model.input_mask: input_mask}
        )