
Commit

No public description
PiperOrigin-RevId: 696882315
tensorflower-gardener committed Nov 15, 2024
1 parent 58dbbe2 commit 694f4e4
Showing 3 changed files with 111 additions and 166 deletions.
official/projects/pix2seq/modeling/pix2seq_model.py: 218 changes (92 additions, 126 deletions)
@@ -21,7 +21,7 @@
"""

import math
from typing import Any, List, Mapping, Optional, Sequence, Union
from typing import Any, List, Mapping, Optional, Union

import tensorflow as tf, tf_keras

@@ -214,7 +214,7 @@ class Pix2Seq(tf_keras.Model):

def __init__(
self,
backbones: Sequence[tf_keras.Model],
backbone,
backbone_endpoint_name,
max_seq_len,
vocab_size,
@@ -229,10 +229,10 @@ def __init__(
top_k=0,
top_p=0.4,
early_stopping_token: int | None = None,
**kwargs,
**kwargs
):
super().__init__(**kwargs)
self._backbones = backbones
self._backbone = backbone
self._backbone_endpoint_name = backbone_endpoint_name
self._max_seq_len = max_seq_len
self._vocab_size = vocab_size
@@ -247,22 +247,17 @@ def __init__(
raise ValueError("hidden_size must be a multiple of 2.")

self._dropout = tf_keras.layers.Dropout(self._drop_units)
# Separate projections and learned layer normalization for each image.
num_backbones = len(self._backbones)
self._stem_projections = [
tf_keras.layers.Dense(self._hidden_size, name="stem_projection")
for _ in range(num_backbones)
]
self._stem_lns = [
tf_keras.layers.LayerNormalization(epsilon=1e-6, name="stem_ln")
for _ in range(num_backbones)
]
self._stem_projection = tf_keras.layers.Dense(
self._hidden_size, name="stem_projection"
)
self._stem_ln = tf_keras.layers.LayerNormalization(
epsilon=1e-6, name="stem_ln"
)

self._transformer = Pix2SeqTransformer(
max_seq_len=self._max_seq_len,
vocab_size=self._vocab_size,
hidden_size=self._hidden_size,
num_images=num_backbones,
pos_encoding="learned",
num_encoder_layers=self._num_encoder_layers,
num_decoder_layers=self._num_decoder_layers,
Expand All @@ -277,8 +272,8 @@ def __init__(
self._early_stopping_token = early_stopping_token

@property
def backbones(self) -> Sequence[tf_keras.Model]:
return self._backbones
def backbone(self) -> tf_keras.Model:
return self._backbone

@property
def transformer(self) -> tf_keras.Model:
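
One side of the hunk above builds a per-image list of stem projections and layer norms; the other uses a single Dense + LayerNormalization pair. In both, backbone features are flattened to [B, H*W, C], passed through dropout, projected to hidden_size, and layer-normalized. A minimal sketch of that shared stem pathway, using tf.keras layers in place of the file's tf_keras import and toy dimensions chosen only for illustration:

import tensorflow as tf

hidden_size = 256   # illustrative; the real value comes from the model config
drop_units = 0.1

dropout = tf.keras.layers.Dropout(drop_units)
stem_projection = tf.keras.layers.Dense(hidden_size, name="stem_projection")
stem_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="stem_ln")

# Fake backbone features with shape [batch, H, W, C].
features = tf.random.normal([2, 20, 20, 2048])
b, h, w, c = features.shape
features = tf.reshape(features, [b, h * w, c])   # [B, H*W, C]
features = stem_ln(stem_projection(dropout(features, training=True)))
print(features.shape)   # (2, 400, 256)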
@@ -312,12 +307,7 @@ def checkpoint_items(
self,
) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
# For backward-compatibility with prior checkpoints, the first backbone
# should be named "backbone" and the second one should be named
# "backbone_2", etc.
items = dict(backbone=self.backbones[0], transformer=self.transformer)
for i in range(1, len(self.backbones)):
items[f"backbone_{i+1}"] = self.backbones[i]
items = dict(backbone=self.backbone, transformer=self.transformer)
return items

def _generate_image_mask(
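
The checkpoint_items comment above describes a backward-compatible naming scheme: the first backbone keeps the checkpoint key "backbone" and subsequent ones become "backbone_2", "backbone_3", and so on. A tiny sketch of just that key-naming rule (the helper name is made up for illustration):

def checkpoint_item_names(num_backbones: int) -> list[str]:
  """Returns the checkpoint keys used for a list of backbones."""
  names = ["backbone"]
  for i in range(1, num_backbones):
    names.append(f"backbone_{i + 1}")
  return names

print(checkpoint_item_names(1))   # ['backbone']
print(checkpoint_item_names(3))   # ['backbone', 'backbone_2', 'backbone_3']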
@@ -343,41 +333,36 @@ def call(
use_teacher_forcing_for_eval: bool = False,
use_input_as_backbone_features=False,
) -> List[Any]:
transformer_inputs = {
"tokens": targets,
"inputs": [], # List of [B, H*W, C] tensors, one per image modality.
"pos_emb": [], # List of positional embeddings for each image modality.
}
# Inputs has shape [B, N, H, W, C] where N is the number of images.
for i in range(len(self.backbones)):
inputs_i = inputs[:, i, :, :, :]
if use_input_as_backbone_features:
features = inputs_i
else:
features = self._backbones[i](inputs_i)[self._backbone_endpoint_name]
mask = tf.ones_like(features)
batch_size, h, w, num_channels = get_shape(features)
features = tf.reshape(features, [batch_size, h * w, num_channels])
features = self._stem_lns[i](
self._stem_projections[i](self._dropout(features, training))
)
if use_input_as_backbone_features:
features = inputs
else:
features = self._backbone(inputs)[self._backbone_endpoint_name]
mask = tf.ones_like(features)
batch_size, h, w, num_channels = get_shape(features)
features = tf.reshape(features, [batch_size, h * w, num_channels])
features = self._stem_ln(
self._stem_projection(self._dropout(features, training))
)

pos_emb = position_embedding_sine(
mask[:, :, :, 0], num_pos_features=self._hidden_size
)
pos_emb = tf.reshape(pos_emb, [batch_size, -1, self._hidden_size])
pos_emb = tf.cast(pos_emb, features.dtype)
transformer_inputs["inputs"].append(features)
transformer_inputs["pos_emb"].append(pos_emb)
pos_emb = position_embedding_sine(
mask[:, :, :, 0], num_pos_features=self._hidden_size
)
pos_emb = tf.reshape(pos_emb, [batch_size, -1, self._hidden_size])
pos_emb = tf.cast(pos_emb, features.dtype)

tokens = None
inputs = {
"inputs": features,
"tokens": targets,
"pos_emb": pos_emb,
}
if training:
logits = self._transformer(transformer_inputs, training=True)
logits = self._transformer(inputs, training=True)
elif use_teacher_forcing_for_eval:
logits = self._transformer(transformer_inputs, training=False)
logits = self._transformer(inputs, training=False)
else:
tokens, logits = self._transformer.infer(
transformer_inputs,
inputs,
temperature=self._temperature,
top_k=self._top_k,
top_p=self._top_p,
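
Both versions of call() in the hunk above reduce backbone features to a [B, H*W, C] sequence paired with a sine positional embedding; the list-based version does this once per image after slicing the [B, N, H, W, C] input, as its comments note. A sketch of just the slicing, flattening, and input-dict assembly, with a zeros tensor standing in (purely for illustration) where the file calls position_embedding_sine:

import tensorflow as tf

hidden_size = 256                                          # illustrative
batch, num_images, h, w, c = 2, 2, 10, 10, 64
images = tf.random.normal([batch, num_images, h, w, c])    # [B, N, H, W, C]

features_list, pos_emb_list = [], []
for i in range(num_images):
  feats = images[:, i, :, :, :]                  # [B, H, W, C]
  feats = tf.reshape(feats, [batch, h * w, c])   # [B, H*W, C]
  # The file derives a sine positional embedding from the feature mask here;
  # a zeros tensor of matching shape stands in for it in this sketch.
  pos = tf.zeros([batch, h * w, hidden_size])
  features_list.append(feats)
  pos_emb_list.append(pos)

transformer_inputs = {
    "tokens": tf.zeros([batch, 5], dtype=tf.int64),   # placeholder target tokens
    "inputs": features_list,                          # N tensors of [B, H*W, C]
    "pos_emb": pos_emb_list,                          # N tensors of [B, H*W, hidden]
}
print(transformer_inputs["inputs"][0].shape)   # (2, 100, 64)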
@@ -423,7 +408,6 @@ def __init__(
max_seq_len,
vocab_size,
hidden_size,
num_images,
pos_encoding="learned",
num_encoder_layers=6,
num_decoder_layers=6,
@@ -432,13 +416,12 @@
drop_att=0.0,
output_bias=True,
num_heads=8,
**kwargs,
**kwargs
):
super().__init__(**kwargs)
self._max_seq_len = max_seq_len
self._vocab_size = vocab_size
self._hidden_size = hidden_size
self._num_images = num_images
self._pos_encoding = pos_encoding
self._num_encoder_layers = num_encoder_layers
self._num_decoder_layers = num_decoder_layers
@@ -459,45 +442,34 @@ def __init__(
)

if self._num_encoder_layers > 0:
self._encoders = [
transformer.TransformerEncoder(
num_layers=self._num_encoder_layers,
dim=self._hidden_size,
mlp_ratio=4,
num_heads=self._num_heads,
drop_path=self._drop_path,
drop_units=self._drop_units,
drop_att=self._drop_att,
)
for _ in range(self._num_images)
]
self._encoder = transformer.TransformerEncoder(
num_layers=self._num_encoder_layers,
dim=self._hidden_size,
mlp_ratio=4,
num_heads=self._num_heads,
drop_path=self._drop_path,
drop_units=self._drop_units,
drop_att=self._drop_att,
)
else:
self._encoders = None

self._output_ln_encs = [
tf_keras.layers.LayerNormalization(epsilon=1e-6, name="output_ln_enc")
for _ in range(self._num_images)
]

self._projs = [
tf_keras.layers.Dense(self._hidden_size, name="proj/linear")
for _ in range(self._num_images)
]
self._proj_lns = [
tf_keras.layers.LayerNormalization(epsilon=1e-6, name="proj/ln")
for _ in range(self._num_images)
]
self._proj_mlps = [
transformer.MLP(
num_layers=1,
dim=self._hidden_size,
mlp_ratio=4,
drop_path=self._drop_path,
drop_units=self._drop_units,
name="proj/mlp",
)
for _ in range(self._num_images)
]
self._encoder = None

self._output_ln_enc = tf_keras.layers.LayerNormalization(
epsilon=1e-6, name="output_ln_enc"
)

self._proj = tf_keras.layers.Dense(self._hidden_size, name="proj/linear")
self._proj_ln = tf_keras.layers.LayerNormalization(
epsilon=1e-6, name="proj/ln"
)
self._proj_mlp = transformer.MLP(
num_layers=1,
dim=self._hidden_size,
mlp_ratio=4,
drop_path=self._drop_path,
drop_units=self._drop_units,
name="proj/mlp",
)

self._decoder = transformer.TransformerDecoder(
num_layers=self._num_decoder_layers,
@@ -527,40 +499,22 @@ def get_config(self):
"num_heads": self._num_heads,
}

def encode_sources(
self,
sources: Sequence[tf.Tensor],
mem_pos_embeds: Sequence[tf.Tensor],
training: bool,
):
"""Encodes and concatenates sources for the decoder."""
encoded_sources = []
for i in range(self._num_images):
source = sources[i]
mem_pos_embed = mem_pos_embeds[i]
source = source + mem_pos_embed
if self._encoders is not None:
encoded = self._encoders[i](
source, None, training=training, ret_list=False
)
else:
encoded = source

encoded = self._output_ln_encs[i](encoded)
encoded = self._proj_lns[i](self._projs[i](encoded))
encoded = encoded + mem_pos_embed
encoded = self._proj_mlps[i](encoded, training=training)
encoded_sources.append(encoded)
def call(self, inputs: tf.Tensor, training: bool = None):
sources = inputs["inputs"]
targets = inputs["tokens"]
mem_pos_embed = inputs["pos_emb"]

# encoded_sources is of length N, each item having shape
# [B, H*W, self._hidden_size]. Reshape to [B, N*H*W, self._hidden_size]
# before passing to decoder.
return tf.concat(encoded_sources, axis=1)
sources = sources + mem_pos_embed
if self._encoder is not None:
encoded = self._encoder(sources, None, training=training, ret_list=False)
else:
encoded = sources
encoded = self._output_ln_enc(encoded)

def call(self, inputs: dict[str, tf.Tensor], training: bool = None):
encoded = self.encode_sources(inputs["inputs"], inputs["pos_emb"], training)
encoded = self._proj_ln(self._proj(encoded))
encoded = encoded + mem_pos_embed
encoded = self._proj_mlp(encoded, training=training)

targets = inputs["tokens"]
seq_len = tf.shape(targets)[1]
seq_pos_emb = tf.expand_dims(self.seq_pos_emb[:seq_len], 0)
inp_embedding = outp_embedding = self.token_embedding
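
The encode_sources comment above spells out the shape contract of the list-based path: each encoded source is [B, H*W, hidden_size], and the N of them are concatenated along the sequence axis into [B, N*H*W, hidden_size] before the decoder. A minimal sketch of that concatenation on dummy tensors:

import tensorflow as tf

batch, hw, hidden = 2, 100, 256   # illustrative sizes
num_images = 3
encoded_sources = [
    tf.random.normal([batch, hw, hidden]) for _ in range(num_images)
]

# Concatenating along the sequence axis lets the decoder cross-attend over
# all images at once.
encoded = tf.concat(encoded_sources, axis=1)
print(encoded.shape)   # (2, 300, 256)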
@@ -623,10 +577,22 @@ def infer(
logits (temperature-scaled) associated with sampled token, in shape of
(bsz, max_seq_len-prompt_len, vocab_size).
"""
encoded = self.encode_sources(
inputs["inputs"], inputs["pos_emb"], training=False
)

sources = inputs["inputs"]
prompt = inputs["tokens"]
mem_pos_embed = inputs["pos_emb"]

sources = sources + mem_pos_embed
if self._encoder is not None:
encoded = self._encoder(sources, None, training=False, ret_list=False)
else:
encoded = sources
encoded = self._output_ln_enc(encoded)

encoded = self._proj_ln(self._proj(encoded))
encoded = encoded + mem_pos_embed
encoded = self._proj_mlp(encoded, training=False)

bsz = tf.shape(prompt)[0]
prompt_len = tf.shape(prompt)[1]

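The infer() docstring above mentions temperature-scaled logits, and the constructor exposes temperature, top_k, and top_p sampling knobs. As a rough illustration of the general sampling idea only, not the repository's implementation, a self-contained sketch of temperature scaling plus top-k filtering for a single decoding step:

import tensorflow as tf

def sample_step(logits, temperature=1.0, top_k=0):
  """Samples one token id per batch element from temperature-scaled logits."""
  logits = logits / temperature
  if top_k > 0:
    # Keep only the top-k logits; everything else gets a large negative value.
    kth_value = tf.math.top_k(logits, k=top_k).values[:, -1:]
    logits = tf.where(logits < kth_value, tf.fill(tf.shape(logits), -1e9), logits)
  return tf.random.categorical(logits, num_samples=1)[:, 0]

step_logits = tf.random.normal([2, 3000])   # [batch, vocab_size], illustrative
print(sample_step(step_logits, temperature=1.0, top_k=40).numpy())
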
