# Copyright 2021, Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence model functions for research baselines."""
import functools
import sys
from typing import Optional
import tensorflow as tf


class TransposableEmbedding(tf.keras.layers.Layer):
  """A Keras layer that implements a transposed projection for output."""

  def __init__(self, embedding_layer: tf.keras.layers.Embedding):
    super().__init__()
    self.embeddings = embedding_layer.embeddings

  # Placing `tf.matmul` under the `call` method is important for
  # backpropagating the gradients of `self.embeddings` in graph mode.
  def call(self, inputs):
    return tf.matmul(inputs, self.embeddings, transpose_b=True)
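
# A minimal usage sketch (an assumed illustration, not part of the original
# module): `TransposableEmbedding` reuses an already-built embedding matrix to
# produce vocabulary logits, tying the input and output projections:
#
#   embedding = tf.keras.layers.Embedding(input_dim=100, output_dim=8)
#   hidden = embedding(tf.constant([[1, 2, 3]]))        # shape (1, 3, 8)
#   logits = TransposableEmbedding(embedding)(hidden)   # shape (1, 3, 100)
#
# The embedding layer must be built (i.e. called at least once) before its
# `embeddings` variable exists.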


def create_recurrent_model(vocab_size: int = 10000,
                           num_oov_buckets: int = 1,
                           embedding_size: int = 96,
                           latent_size: int = 670,
                           num_layers: int = 1,
                           name: str = 'rnn',
                           shared_embedding: bool = False,
                           seed: Optional[int] = 0,
                           cell_type: str = 'LSTM'):
"""Constructs zero-padded keras model with the given parameters and cell.
Args:
vocab_size: Size of vocabulary to use.
num_oov_buckets: Number of out of vocabulary buckets.
embedding_size: The size of the embedding.
latent_size: The size of the recurrent state.
num_layers: The number of layers.
name: (Optional) string to name the returned `tf.keras.Model`.
shared_embedding: (Optional) Whether to tie the input and output
embeddings.
seed: A random seed governing the model initialization and layer randomness.
If set to `None`, No random seed is used.
cell_type: The cell type to be used in the recurrent LM.
Returns:
`tf.keras.Model`.
"""
  extended_vocab_size = vocab_size + 3 + num_oov_buckets  # For pad/bos/eos/oov.
  inputs = tf.keras.layers.Input(shape=(None,))
  input_embedding = tf.keras.layers.Embedding(
      input_dim=extended_vocab_size,
      output_dim=embedding_size,
      mask_zero=True,
      embeddings_initializer=tf.keras.initializers.RandomUniform(seed=seed),
  )
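  # `mask_zero=True` above treats token id 0 as padding; Keras propagates the
  # resulting mask so the recurrent layers skip zero-padded timesteps.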
  embedded = input_embedding(inputs)
  projected = embedded
  if cell_type == 'LSTM':
    cell = tf.keras.layers.LSTMCell
  else:
    raise ValueError(f'Unsupported cell type: {cell_type}.')
  lstm_layer_builder = functools.partial(
      cell,
      units=latent_size,
      recurrent_initializer=tf.keras.initializers.Orthogonal(seed=seed),
      kernel_initializer=tf.keras.initializers.HeNormal(seed=seed))
  dense_layer_builder = functools.partial(
      tf.keras.layers.Dense,
      kernel_initializer=tf.keras.initializers.GlorotNormal(seed=seed))
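  # Stack `num_layers` recurrent blocks. Each RNN output is projected back to
  # `embedding_size` so the layers compose uniformly and, when
  # `shared_embedding` is set, the final features match the embedding width.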
  for _ in range(num_layers):
    layer = tf.keras.layers.RNN(lstm_layer_builder(), return_sequences=True)
    processed = layer(projected)
    # A projection changes dimension from `latent_size` to `embedding_size`.
    dense_layer = dense_layer_builder(units=embedding_size)
    projected = dense_layer(processed)
  if shared_embedding:
    transposed_embedding = TransposableEmbedding(input_embedding)
    logits = transposed_embedding(projected)
  else:
    final_dense_layer = dense_layer_builder(
        units=extended_vocab_size, activation=None)
    logits = final_dense_layer(projected)
  return tf.keras.Model(inputs=inputs, outputs=logits, name=name)
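

if __name__ == '__main__':
  # A minimal smoke test (an assumed addition, not part of the original
  # baselines): build a small tied-embedding model and run one forward pass
  # on dummy zero-padded token ids to check the output shape.
  model = create_recurrent_model(
      vocab_size=50,
      num_oov_buckets=1,
      embedding_size=8,
      latent_size=16,
      num_layers=2,
      shared_embedding=True)
  dummy_batch = tf.constant([[1, 5, 9, 0, 0], [2, 4, 0, 0, 0]])
  logits = model(dummy_batch)
  # Expected shape: (batch=2, time=5, extended_vocab=50 + 3 + 1).
  print(logits.shape)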