diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index 640730864..0e305ef54 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -76,7 +76,7 @@ def shakeshake2_py(x, y, equal=False, individual=False): """The shake-shake sum of 2 tensors, python version.""" if equal: alpha = 0.5 - if individual: + elif individual: alpha = tf.random_uniform(tf.get_shape(x)[:1]) else: alpha = tf.random_uniform([])