Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558604894
  • Loading branch information
tensorflower-gardener committed Aug 20, 2023
1 parent 2394a73 commit 812caaf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
14 changes: 10 additions & 4 deletions official/nlp/modeling/layers/masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,16 @@ def call(self, sequence_data, masked_positions):
lm_data = self.layer_norm(lm_data)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
masked_positions_length = masked_positions.shape.as_list()[1] or tf.shape(
masked_positions)[1]
logits = tf.reshape(logits,
[-1, masked_positions_length, self._vocab_size])
masked_positions_length = (
masked_positions.shape.as_list()[1] or tf.shape(masked_positions)[1]
)
batch_size = (
masked_positions.shape.as_list()[0] or tf.shape(masked_positions)[0]
)
logits = tf.reshape(
logits,
[batch_size, masked_positions_length, self._vocab_size],
)
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
Expand Down
21 changes: 16 additions & 5 deletions official/nlp/modeling/layers/masked_lm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# limitations under the License.

"""Tests for masked language model network."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import masked_lm
from official.nlp.modeling.networks import bert_encoder


class MaskedLMTest(tf.test.TestCase):
class MaskedLMTest(tf.test.TestCase, parameterized.TestCase):

def create_layer(self,
vocab_size,
Expand Down Expand Up @@ -110,11 +110,20 @@ def test_layer_invocation_with_external_logits(self):
self.assertEqual(expected_output_shape, outputs.shape)
self.assertAllClose(ref_outputs, outputs)

def test_layer_invocation(self):
@parameterized.named_parameters(
dict(
testcase_name='default',
num_predictions=21,
),
dict(
testcase_name='zero_predictions',
num_predictions=0,
),
)
def test_layer_invocation(self, num_predictions):
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
test_layer = self.create_layer(
vocab_size=vocab_size, hidden_size=hidden_size)

Expand All @@ -131,7 +140,9 @@ def test_layer_invocation(self):
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
_ = model.predict([lm_input_data, masked_position_data])
res = model.predict([lm_input_data, masked_position_data])
expected_shape = (batch_size, num_predictions, vocab_size)
self.assertEqual(expected_shape, res.shape)

def test_unknown_output_type_fails(self):
with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
Expand Down

0 comments on commit 812caaf

Please sign in to comment.