-
Notifications
You must be signed in to change notification settings - Fork 157
/
transformer_network.py
684 lines (608 loc) · 28.2 KB
/
transformer_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow based methods for sequence agents."""
from typing import Optional, Tuple, Union, Any
from absl import logging
import numpy as np
from robotics_transformer import transformer
from robotics_transformer.film_efficientnet import preprocessors
from robotics_transformer.tokenizers import action_tokenizer
from robotics_transformer.tokenizers import image_tokenizer
from tensor2robot.utils import tensorspec_utils
import tensorflow as tf
from tf_agents.networks import network
from tf_agents.specs import tensor_spec
from tf_agents.utils import nest_utils
class TransformerNetwork(network.Network):
  """A transformer based actor network."""

  def __init__(
      self,
      input_tensor_spec: tensorspec_utils.TensorSpecStruct,
      output_tensor_spec: tensorspec_utils.TensorSpecStruct,
      train_step_counter: int = 0,
      vocab_size: int = 256,
      token_embedding_size: int = 512,
      num_layers: int = 1,
      layer_size: int = 4096,
      num_heads: int = 8,
      feed_forward_size: int = 512,
      dropout_rate: float = 0.1,
      time_sequence_length: int = 1,
      crop_size: int = 236,
      policy_info_spec: Optional[dict[Any,
                                      tensor_spec.BoundedTensorSpec]] = None,
      action_order: Optional[list[str]] = None,
      use_token_learner: Optional[bool] = True,
      return_attention_scores: bool = False,
      **kwargs):
    """Creates a transformer network.

    Args:
      input_tensor_spec: Nested list/tuple/dict of TensorSpecs, describing the
        shape of input tensor.
      output_tensor_spec: Nested list/tuple/dict of TensorSpecs, describing the
        shape of output tensor. Also used to build the action tokenizer.
      train_step_counter: Step value passed as `step=` to the `tf.summary`
        calls made in `add_summaries`.
      vocab_size: Dimensionality of tokens from the output layer.
      token_embedding_size: Dimensionality of tokens from the embedding layer.
      num_layers: Number of transformer layers.
      layer_size: Size of the multiple head attention layer.
      num_heads: Number of heads for the multiple head attention layer.
      feed_forward_size: Dimensionality of the feed_forward layer.
      dropout_rate: Dropout rate.
      time_sequence_length: Length of the time sequence.
      crop_size: Height and width of the square crop, where original image will
        be padded to allow full field of view to be extracted.
      policy_info_spec: Spec on return value given return type of the return
        tokenizer. NOTE(review): currently not referenced by this class.
      action_order: Order of actions for the action tokenizer.
      use_token_learner: Whether to use token learner. See
        https://arxiv.org/abs/2106.11297
      return_attention_scores: show attention scores in tensorboard.
      **kwargs: Keyword parameter arguments forwarded to `network.Network`.
    """
    self._input_tensor_spec = input_tensor_spec
    self._output_tensor_spec = output_tensor_spec
    self._train_step_counter = train_step_counter
    self._actions = None
    self._returns = None
    self._vocab_size = vocab_size
    self._token_embedding_size = token_embedding_size
    self._time_sequence_length = time_sequence_length
    self._crop_size = crop_size

    self._transformer = transformer.Transformer(
        num_layers=num_layers,
        layer_size=layer_size,
        num_heads=num_heads,
        feed_forward_size=feed_forward_size,
        dropout_rate=dropout_rate,
        vocab_size=self._vocab_size,
        return_attention_scores=return_attention_scores)

    # create tokenizers
    self._image_tokenizer = image_tokenizer.RT1ImageTokenizer(
        embedding_output_dim=self._token_embedding_size,
        use_token_learner=use_token_learner)
    self._action_tokenizer = action_tokenizer.RT1ActionTokenizer(
        output_tensor_spec,
        vocab_size=self._vocab_size,
        action_order=action_order)

    self._tokens_per_action = self._action_tokenizer.tokens_per_action
    self._tokens_per_context_image = self._image_tokenizer.tokens_per_context_image
    # generate loss and attention masks (needs the token counts above)
    self._generate_masks()

    # define mappings to token embedding size
    self._action_token_emb = tf.keras.layers.Dense(self._token_embedding_size)

    # define loss function; per-token losses are reduced manually in `call`.
    self._loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    self._attention_scores = []
    self._use_token_learner = use_token_learner

    super(TransformerNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec, **kwargs)
    self._state_spec = {
        # Force this to be 4 dimension due to b/254902773.
        # Otherwise can be dimension 3.
        'context_image_tokens':
            tensor_spec.TensorSpec(
                shape=(time_sequence_length, self._tokens_per_context_image, 1,
                       token_embedding_size),
                dtype=tf.float32,
                name='context_image_tokens'),
        'action_tokens':
            tensor_spec.TensorSpec(
                shape=(time_sequence_length, self._tokens_per_action, 1, 1),
                dtype=tf.int32,
                name='action_tokens'),
        # Stores where in the window we are.
        # This value is within range [0, time_sequence_length]; `call` clamps
        # the incremented value to time_sequence_length.
        # When seq_idx == time_sequence_length, context_image_tokens and
        # action_tokens need to be shifted to the left.
        'seq_idx':
            tensor_spec.TensorSpec(
                shape=(1, 1, 1, 1), dtype=tf.int32, name='seq_idx')
    }

  @property
  def attention_scores(self) -> list[tf.Tensor]:
    """Return attention score. This is for debugging/visualization purpose."""
    return self._attention_scores

  def _get_action_index_for_token(self, k: int) -> int:
    """Returns action associated with the token at given position `k`.

    If k is not an action token then it returns -1.
    If k is part of the first action in the sequence then returns 0 etc.

    Args:
      k: an int that represents the position in the sequence.

    Returns:
      The index of the action (time step) that this position belongs to, or -1
      if this position is out of range or part of an image token.
    """
    if k < 0 or k >= self._all_num_tokens:
      return -1
    # Within each time step the image tokens come first, then action tokens.
    if k % self._single_time_step_num_tokens < self._tokens_per_context_image:
      return -1
    return k // self._single_time_step_num_tokens

  def _generate_masks(self):
    """Generate mask for action prediction loss and attention visualization.

    Sets `self._single_time_step_num_tokens`, `self._all_num_tokens`,
    `self._action_tokens_mask` (positions of all action tokens in the flattened
    sequence) and `self._default_attention_mask` (causal mask with action-to-
    action attention removed as described below).
    """
    # each time step = [image, action]
    self._single_time_step_num_tokens = (
        self._tokens_per_action + self._tokens_per_context_image)
    # full sequence = time_sequence_length consecutive time steps
    self._all_num_tokens = (
        self._time_sequence_length * self._single_time_step_num_tokens)

    # Positions of every action token, in sequence order; used to gather the
    # logits for the action prediction loss.
    self._action_tokens_mask = tf.constant([
        step_start + self._tokens_per_context_image + x
        for step_start in range(0, self._all_num_tokens,
                                self._single_time_step_num_tokens)
        for x in range(self._tokens_per_action)
    ],
                                           dtype=tf.int32)

    # The look ahead mask ensures causality.
    self._default_attention_mask = tf.linalg.band_part(
        tf.ones((self._all_num_tokens, self._all_num_tokens)), -1, 0)

    # action_mask[i, j] == 1 marks an action->action pair to remove from the
    # causal mask; every other entry stays 0.
    action_mask = np.zeros(
        shape=(self._all_num_tokens, self._all_num_tokens), dtype=int)
    for i in range(self._all_num_tokens):
      action_i = self._get_action_index_for_token(i)
      if action_i == -1:
        continue
      for j in range(self._all_num_tokens):
        action_j = self._get_action_index_for_token(j)
        if action_j == -1:
          continue
        # Ignore actions of previous steps.
        if action_j < action_i:
          action_mask[i, j] = 1
        # If we're not auto-regression, ignore action dimensions of current
        # step (positions at or before i, including i itself).
        if action_j == action_i and j <= i:
          action_mask[i, j] = 1
    self._default_attention_mask -= action_mask

  def _transformer_call(
      self,
      context_image_tokens: tf.Tensor,
      action_tokens: tf.Tensor,
      batch_size: int,
      training: bool,
      attention_mask: tf.Tensor,
  ) -> tf.Tensor:
    """Calls the transformer.

    Also stores the transformer's attention scores in `self._attention_scores`.

    Args:
      context_image_tokens: Tokenized context and image, shaped
        `(B, T, num tokens, 1, emb)` by `_tokenize_images`.
      action_tokens: Discrete action tokens of shape
        `(B, T, tokens_per_action)`.
      batch_size: Batch size as when reshaping all tokens.
      training: Whether to run the transformer in training mode.
      attention_mask: Mask tensor for the transformer's attention (this class
        passes `self._default_attention_mask`).

    Returns:
      The transformer output tensor, indexed as `[batch, token position, ...]`.
      Callers argmax over its last axis (presumably vocab logits).
    """
    input_token_sequence = self._assemble_input_token_sequence(
        context_image_tokens, action_tokens, batch_size)
    # run transformer
    output_tokens, self._attention_scores = self._transformer(
        input_token_sequence, training, attention_mask)
    return output_tokens

  def _get_tokens_and_mask(self,
                           observations: dict[str, tf.Tensor],
                           network_state: dict[str, tf.Tensor],
                           training: bool = False):
    """Tokenizes images and actions and returns them with the attention mask.

    Note: `_tokenize_images` writes `context_image_tokens` into
    `network_state`.

    Returns:
      A tuple `(context_image_tokens, action_tokens, attention_mask)`.
    """
    # tokenize all inputs
    context_image_tokens, network_state = self._tokenize_images(
        observations, network_state, training)
    action_tokens = self._tokenize_actions(observations, network_state)

    # generate transformer attention mask
    attention_mask = self._default_attention_mask

    return (context_image_tokens, action_tokens, attention_mask)

  def _transformer_call_and_slice(self,
                                  *args,
                                  slice_start: int = 0,
                                  slice_length: int = 1,
                                  **kwargs) -> Tuple[tf.Tensor, tf.Tensor]:
    """Runs `_transformer_call` and slices a window of output positions.

    Args:
      *args: Positional arguments forwarded to `_transformer_call`.
      slice_start: First token position to keep.
      slice_length: Number of consecutive positions to keep.
      **kwargs: Keyword arguments forwarded to `_transformer_call`.

    Returns:
      A tuple `(token, token_logits)` where `token_logits` is the sliced output
      and `token` is its argmax over the last axis as int32.
    """
    output_tokens = self._transformer_call(*args, **kwargs)

    slice_end = slice_start + slice_length
    token_logits = output_tokens[:, slice_start:slice_end, :]
    token = tf.argmax(token_logits, axis=-1, output_type=tf.int32)

    return token, token_logits

  def call(self,
           observations: dict[str, tf.Tensor],
           network_state: dict[str, tf.Tensor],
           training: bool = False):
    """Calls the transformer network.

    With outer rank 1 (inference) the action tokens for the current time step
    are decoded one at a time and written back into `network_state`. With
    outer rank 2 (training) a single forward pass is run and the action loss
    is stored in `self._loss`.

    Args:
      observations: Observation data including image and natural language
        embedding in dict of Tensors.
      network_state: Network state data including time step, image, action
        tokens, step number in dict of Tensors.
      training: Whether to call transformer network in training mode.

    Returns:
      A tuple `(Detokenized output actions, network state)`.
    """
    # used to determine training vs inference call
    # outer_rank will be 2 -> [b, t] during training and
    # outer_rank will be 1 -> [b] during inference
    outer_rank = self._get_outer_rank(observations)
    assert outer_rank in (1, 2)

    b, t = self._get_batch_size_and_seq_len(network_state)

    context_image_tokens, action_tokens, attention_mask = (
        self._get_tokens_and_mask(observations, network_state, training))

    self._aux_info = {'action_labels': action_tokens}

    if outer_rank == 1:  # This is an inference call
      # run transformer in loop to produce action tokens one-by-one
      seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
      action_t = tf.minimum(seq_idx, self._time_sequence_length - 1)
      # Transformer shifts all to the left by one step by default (it's usually
      # predicting the next token as default training task...).
      transformer_shift = -1
      # We only want to get the action predicted at time_step.
      start_index = (
          transformer_shift + self._tokens_per_context_image + action_t *
          (self._single_time_step_num_tokens))
      current_action_tokens = []
      action_predictions_logits = []
      for k in range(self._tokens_per_action):
        action_index = start_index + k
        token, token_logits = self._transformer_call_and_slice(
            context_image_tokens,
            action_tokens,
            attention_mask=attention_mask,
            batch_size=b,
            training=training,
            slice_start=action_index  # slicing single action dimension
        )
        action_predictions_logits.append(token_logits)
        current_action_tokens.append(token)
        # Write the decoded token back so the next dimension conditions on it.
        # action_tokens is [b, t * self._tokens_per_action]
        action_tokens = tf.reshape(action_tokens, [b, -1])
        action_start_index = (action_t * self._tokens_per_action) + k
        action_tokens = tf.concat([
            action_tokens[:, :action_start_index], token,
            action_tokens[:, action_start_index + 1:]
        ],
                                  axis=1)
        # action_tokens is [b, t, self._tokens_per_action]
        action_tokens = tf.reshape(action_tokens,
                                   [b, t, self._tokens_per_action])
      self._aux_info.update({
          # action_predictions_logits is
          # [b, self._tokens_per_action, self._vocab_size]
          'action_predictions_logits': tf.concat(action_predictions_logits, 1)
      })
      # predicted_tokens_for_output is [b, self._tokens_per_action]
      predicted_tokens_for_output = tf.concat(current_action_tokens, 1)

      # state_action_tokens is [b, 1, self._tokens_per_action, 1, 1]
      one_state_action_tokens = predicted_tokens_for_output[:, tf.newaxis, :,
                                                            tf.newaxis,
                                                            tf.newaxis]

      state_action_tokens = network_state['action_tokens']
      network_state['action_tokens'] = tf.concat([
          state_action_tokens[:, :action_t, ...], one_state_action_tokens,
          state_action_tokens[:, action_t + 1:, ...]
      ],
                                                 axis=1)

      # Increment the time_step for the next inference call.
      network_state['seq_idx'] = tf.reshape(
          tf.minimum(seq_idx + 1, self._time_sequence_length), [-1, 1, 1, 1, 1])

      self._loss = tf.constant(0.0)
    else:
      # training call --> simply run one transformer forward pass
      output_tokens = self._transformer_call(
          context_image_tokens,
          action_tokens,
          attention_mask=attention_mask,
          batch_size=b,
          training=training)

      # Gather all predicted actions for the action loss. The `- 1` accounts
      # for the transformer predicting the next token at each position.
      action_logits = tf.gather(
          output_tokens, self._action_tokens_mask - 1, axis=1)
      action_logits_for_training = tf.reshape(
          action_logits, [b, t, self._tokens_per_action, -1])

      # Only take the last action as the action.
      # action_logits_for_output is [b, self._tokens_per_action, emb]
      action_logits_for_output = action_logits_for_training[:, -1]

      # predicted_tokens_for_output is [b, self._tokens_per_action]
      predicted_tokens_for_output = tf.argmax(
          action_logits_for_output, axis=-1, output_type=tf.int32)

      num_items = (
          tf.cast(b * t, tf.float32) * self._single_time_step_num_tokens)
      action_loss = tf.reduce_mean(
          self._loss_object(action_tokens, action_logits_for_training) /
          num_items,
          axis=-1)

      self._loss = action_loss

      # store action labels and predictions for visualization
      self._aux_info.update({
          'action_predictions':
              tf.argmax(
                  action_logits_for_training, axis=-1, output_type=tf.int32),
          'action_loss':
              action_loss,
          'actor_loss_mask':
              tf.ones([b], dtype=tf.float32)
      })

    output_actions = self._action_tokenizer.detokenize(
        predicted_tokens_for_output)
    return output_actions, network_state

  def add_summaries(self, observations: dict[str, tf.Tensor],
                    logging_info: dict[str, tf.Tensor], debug_summaries: bool,
                    training: bool) -> None:
    """Adds summaries.

    Args:
      observations: Observation data including image and natural language
        instruction in dict of Tensors.
      logging_info: Dict with all data stored for logging during training pass
        (reads 'action_labels', 'action_predictions', 'actor_loss_mask' and
        'action_loss').
      debug_summaries: Whether to include debug summaries.
      training: Whether this function is called during training or inference.
        Currently unused.
    """
    num_params = 0
    for weight in self.trainable_weights:
      weight_params = 1
      for dim in weight.shape:
        weight_params *= dim
      num_params += weight_params
    # Pass the step explicitly like every other summary here; without it
    # tf.summary raises if no default step has been set.
    tf.compat.v2.summary.scalar(
        name='num_params', data=num_params, step=self._train_step_counter)
    # debug_summaries are for the non-tpu worker, train_summary.
    if debug_summaries:
      image = observations['image']  # [b, t, h, w, c]
      image_h = image.shape[2]
      image_w = image.shape[3]
      batch_size = image.shape[0]
      num_ts = image.shape[1]
      logging.info('image shape %s', image.shape)
      # Concat images for different timesteps across width.
      image = tf.concat(tf.unstack(image, axis=1), 2)
      # Concat images for different batches (up to 8) across height.
      image = tf.expand_dims(tf.concat(tf.unstack(image, axis=0)[0:8], 0), 0)
      tf.summary.image(
          'observations/image',
          image,
          step=self._train_step_counter,
          # Single output since we have concatenated images along batch.
          max_outputs=1)

      # [b, t], strings
      if 'natural_language_instruction' in observations:
        task = observations['natural_language_instruction'][:, 0]
        tf.summary.text(
            'natural_language_instruction', task, step=self._train_step_counter)
      if self.attention_scores and not self._use_token_learner:
        for l_idx, layer_attention_score in enumerate(self.attention_scores):
          logging.info('Attention score shape: %s, %s', l_idx,
                       layer_attention_score.shape)
          for head_idx in range(layer_attention_score.shape[1]):
            pairwise_attention = tf.expand_dims(
                layer_attention_score[:, head_idx], -1)
            # pairwise attention shape (16, 552, 552, 1)
            # make attention from different time steps comparable
            pairwise_attention = pairwise_attention * np.arange(
                1, pairwise_attention.shape[1] + 1)[None, :, None, None]

            # visualize spatial attention, note this only supports
            # mk1_500tasks_transformer pipeline with no token learner
            # NOTE(review): the 9x9 reshape assumes 81 image tokens per step.
            img_tf_ts = tf.reshape(
                tf.transpose(
                    tf.reshape(
                        tf.reduce_sum(pairwise_attention, axis=1) / np.arange(
                            pairwise_attention.shape[1], 0, -1)[None, :, None],
                        [batch_size, num_ts, -1]),
                    [0, 2, 1])[:, :-self._tokens_per_action, :],
                [-1, 9, 9, num_ts])

            img_tf_ts = tf.image.resize(
                img_tf_ts, [image_h, image_w],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
            img_tf_ts_concat = tf.concat(tf.unstack(img_tf_ts, axis=3), 2)
            # Min-max normalize each example's attention map to [0, 1].
            img_tf_ts_concat_min = tf.reduce_min(
                img_tf_ts_concat, axis=[1, 2], keepdims=True)
            img_tf_ts_concat = (img_tf_ts_concat - img_tf_ts_concat_min) / (
                tf.reduce_max(img_tf_ts_concat, axis=[1, 2], keepdims=True) -
                img_tf_ts_concat_min)
            img_tf_ts_concat = tf.concat(
                tf.unstack(img_tf_ts_concat, axis=0)[:8], 0)
            img_tf_ts_concat = tf.expand_dims(
                tf.expand_dims(img_tf_ts_concat, 0), -1)
            tf.summary.image(
                'attention/layer_{}/head_{}'.format(l_idx, head_idx),
                img_tf_ts_concat,
                step=self._train_step_counter,
                # Single output since we have concatenated images along batch.
                max_outputs=1)

            if img_tf_ts_concat.shape[1] == image.shape[
                1] and img_tf_ts_concat.shape[2] == image.shape[2]:
              # can overlay
              overlay_viz = tf.cast(
                  (tf.cast(image, tf.float32) * (0.2 + img_tf_ts_concat) / 1.2),
                  tf.uint8)
              tf.summary.image(
                  'overlay_attention/layer_{}/head_{}'.format(
                      l_idx, head_idx),
                  overlay_viz,
                  step=self._train_step_counter,
                  # Single output since we have concatenated images along batch.
                  max_outputs=1)

    # log action info
    # tf.boolean_mask documents a boolean mask; `call` stores a float32 one.
    actor_loss_mask = tf.cast(logging_info['actor_loss_mask'], tf.bool)
    action_labels = tf.boolean_mask(logging_info['action_labels'],
                                    actor_loss_mask)
    action_predictions = tf.boolean_mask(logging_info['action_predictions'],
                                         actor_loss_mask)
    with tf.name_scope('ActionTokens'):
      token_accuracy = (
          tf.cast(tf.equal(action_labels, action_predictions), tf.float32))
      accuracy = tf.reduce_mean(token_accuracy)
      tf.compat.v2.summary.scalar(
          name='accuracy', data=accuracy, step=self._train_step_counter)
      # Accuracy across timesteps
      for t in range(self._time_sequence_length):
        tf.compat.v2.summary.scalar(
            name='accuracy/time_step/{}'.format(t),
            data=tf.reduce_mean(token_accuracy[:, t, :]),
            step=self._train_step_counter)
      token_index = 0
      for k in self._action_tokenizer.action_order:
        spec = self._action_tokenizer.action_spec[k]
        # int32 actions occupy one token; others one token per element.
        if spec.dtype == tf.int32:
          n_tokens = 1
        else:
          n_tokens = spec.shape[0]
        action_token_accuracy = tf.reduce_mean(
            token_accuracy[:, :, token_index:token_index + n_tokens])
        tf.compat.v2.summary.scalar(
            name='accuracy/action_type/{}'.format(k),
            data=action_token_accuracy,
            step=self._train_step_counter)
        for n in range(n_tokens):
          tf.summary.histogram(
              'tokens/{}_{}/labels'.format(k, n + 1),
              action_labels[:, :, token_index],
              step=self._train_step_counter)
          tf.summary.histogram(
              'tokens/{}_{}/predictions'.format(k, n + 1),
              action_predictions[:, :, token_index],
              step=self._train_step_counter)
          token_index += 1

    # log loss components
    with tf.name_scope('TokenLosses'):
      tf.compat.v2.summary.scalar(
          name='action_loss',
          data=tf.reduce_mean(logging_info['action_loss']),
          step=self._train_step_counter)

  def _tokenize_images(self, observations, network_state, training):
    """Crops and tokenizes the images, merging with cached state at inference.

    Returns:
      A tuple `(context_image_tokens, network_state)`; `context_image_tokens`
      is `[b, t, num_tokens, 1, emb]` and is also stored in `network_state`.
    """
    image = observations['image']  # [b, t, h, w, c]
    outer_rank = self._get_outer_rank(observations)
    if outer_rank == 1:  # This is an inference call
      seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
      time_step = tf.minimum(seq_idx, self._time_sequence_length - 1)
      # Inference images carry no time axis; add a length-1 one.
      image = tf.expand_dims(image, 1)

    image_shape = tf.shape(image)
    b = image_shape[0]
    input_t = image_shape[1]
    h = image_shape[2]
    w = image_shape[3]
    c = image_shape[4]

    context = self._extract_context_from_observation(observations, input_t)

    image = tf.reshape(image, [b * input_t, h, w, c])
    seed = tf.random.uniform(shape=(2,), maxval=2**30, dtype=tf.int32)
    image = preprocessors.convert_dtype_and_crop_images(
        image,
        crop_size=self._crop_size,
        training=training,
        pad_then_crop=True,
        convert_dtype=True,
        seed=seed)
    image = tf.reshape(image, [b, input_t, h, w, c])
    context_image_tokens = self._image_tokenizer(
        image, context=context, training=training)
    num_tokens = tf.shape(context_image_tokens)[2]
    context_image_tokens = tf.reshape(context_image_tokens,
                                      [b, input_t, num_tokens, 1, -1])
    if outer_rank == 1:  # This is an inference call
      network_state['context_image_tokens'] = tf.reshape(
          network_state['context_image_tokens'], [
              b, self._time_sequence_length, self._tokens_per_context_image, 1,
              -1
          ])
      state_image_tokens = network_state['context_image_tokens']
      # network_state as input for this call is the output from the last call.
      # Therefore, we need to shift all images to the left by 1 in the time axis
      # to align w/ the time dim in this call.
      state_image_tokens = tf.cond(
          seq_idx == self._time_sequence_length,
          lambda: tf.roll(state_image_tokens, -1, axis=1),
          lambda: state_image_tokens)

      # Replace the slot for the current time step with the new tokens.
      context_image_tokens = tf.concat([
          state_image_tokens[:, :time_step, ...], context_image_tokens,
          state_image_tokens[:, time_step + 1:, ...]
      ],
                                       axis=1)

    network_state['context_image_tokens'] = context_image_tokens

    return context_image_tokens, network_state

  def _tokenize_actions(self, observations, network_state):
    """Returns action tokens of shape `[b, t, tokens_per_action]`.

    At inference they come from `network_state`; during training they are
    tokenized from `set_actions` input, or zeros if no actions were set.
    """
    outer_rank = self._get_outer_rank(observations)

    if outer_rank == 1:  # This is an inference call
      action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4])
      seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
      # network_state as input for this call is the output from the last call.
      # Therefore, we need to shift all actions by 1 to the left.
      action_tokens = tf.cond(seq_idx == self._time_sequence_length,
                              lambda: tf.roll(action_tokens, -1, axis=1),
                              lambda: action_tokens)
    else:
      assert outer_rank == 2
      if self._actions is None:
        b, t = self._get_batch_size_and_seq_len(network_state)
        action_tokens = tf.zeros(
            shape=[b, t, self._tokens_per_action], dtype=tf.int32)
      else:
        action_tokens = self._action_tokenizer.tokenize(self._actions)
    return action_tokens

  def _assemble_input_token_sequence(self, context_image_tokens, action_tokens,
                                     batch_size):
    """Interleaves image and action tokens into `[batch, -1, emb]`."""
    # embed action tokens
    action_tokens = tf.one_hot(action_tokens, self._vocab_size)
    action_tokens = self._action_token_emb(action_tokens)
    # Action inputs are zeroed out (b/260260205); the Dense layer above is
    # still called.
    action_tokens = tf.zeros_like(action_tokens)  # b/260260205

    # Because of b/254902773, we need to add 1 extra dimension.
    action_tokens = tf.expand_dims(action_tokens, axis=-2)

    # assemble token sequence: per time step, image tokens then action tokens
    input_token_sequence = tf.concat([context_image_tokens, action_tokens],
                                     axis=2)

    input_token_sequence = tf.reshape(
        input_token_sequence, [batch_size, -1, self._token_embedding_size])
    return input_token_sequence

  def _extract_context_from_observation(self, observations, seq_len):
    """Extract context from observation.

    Returns the 'natural_language_embedding' observation (tiled over `seq_len`
    time steps at inference), or None if it is absent.
    """
    context = None
    if 'natural_language_embedding' in observations:
      outer_rank = self._get_outer_rank(observations)

      context = observations['natural_language_embedding']  # [b, t, emb-size]
      if outer_rank == 1:
        context = tf.tile(context[:, None], [1, seq_len, 1])
    return context

  def set_actions(self, actions: tensorspec_utils.TensorSpecStruct):
    """Sets actions that will be tokenized and used in transformer network.

    Args:
      actions: actions to be tokenized and used in transformer network. example
        actions are terminate = [0, 1] world_vector = [0.9, 0.8, -0.3]
        rotation_delta = [-0.1, 0.2, .6] gripper_closedness = 0.9
    """
    self._actions = actions

  def _get_outer_rank(self, observations):
    # used to determine training vs inference call
    # outer_rank will be 2 -> [b, t] during training and
    # outer_rank will be 1 -> [b] during inference
    return nest_utils.get_outer_rank(observations, self._input_tensor_spec)

  def _get_batch_size_and_seq_len(self, network_state):
    """Returns the first two dims of `network_state['context_image_tokens']`."""
    image_shape = tf.shape(network_state['context_image_tokens'])
    b = image_shape[0]
    t = image_shape[1]
    return b, t

  def get_actor_loss(self) -> tf.Tensor:
    """Returns the loss computed by the most recent `call`."""
    return self._loss

  def get_aux_info(self) -> dict[str, Any]:
    """Returns auxiliary tensors recorded by the most recent `call`."""
    return self._aux_info