Update implementation of rev_block to use new fn_with_custom_grad (which limits usage of Defun)

PiperOrigin-RevId: 165525242
Ryan Sepassi committed Aug 17, 2017
1 parent 3e295e7 commit f5d5405
Showing 3 changed files with 320 additions and 144 deletions.
237 changes: 168 additions & 69 deletions tensor2tensor/layers/rev_block.py
@@ -23,6 +23,7 @@
from __future__ import division
from __future__ import print_function

import random
import re

# Dependency imports
@@ -137,12 +138,112 @@ def _rev_block_forward(x1,
        if layer_scopes is not None:
          layer_scopes.append(layer_vs)
        out = _rev_layer_forward(
            out, f, g, f_side_input, g_side_input, gate_outputs=gate_outputs)
            out,
            f[i],
            g[i],
            f_side_input,
            g_side_input,
            gate_outputs=gate_outputs)

  y1, y2 = out
  return y1, y2


def _underlying_variable(t):
  """Find the underlying variable ref, ignoring Identity ops."""
  while t.op.type == "Identity":
    t = t.op.inputs[0]
  if t.dtype == dtypes.float32_ref and "Variable" in t.op.type:
    return t
  else:
    return None
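A rough illustration of what `_underlying_variable` does (a hypothetical snippet, not part of the commit; it assumes TF 1.x ref variables, where a variable read goes through one or more `Identity` ops):

```
import tensorflow as tf

# A variable read wrapped in extra Identity ops, as produced e.g. by a custom
# getter. _underlying_variable peels the Identity ops away and returns the
# float32_ref tensor of the variable itself.
v = tf.get_variable("w", shape=[3], dtype=tf.float32)
wrapped = tf.identity(tf.identity(v))
ref = _underlying_variable(wrapped)
assert ref is not None and "Variable" in ref.op.type
```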


def fn_with_custom_grad(grad_fn):
  """Decorator to create a subgraph with a custom gradient function.

  The subgraph created by the decorated function is NOT put in a Defun and so
  does not suffer from the limitations of the Defun (all subgraph ops on the
  same device, no summaries).

  Args:
    grad_fn: function with signature
      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors.

  Returns:
    Decorator for function such that the gradient is defined by grad_fn.
  """

  def dec(fn):

    def wrapped(*args):
      return _fn_with_custom_grad(fn, args, grad_fn)

    return wrapped

  return dec
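A hypothetical usage sketch of the decorator (the names `halved_grad` and `scale_by_three` are illustrative, not from the commit): the forward pass of the decorated function builds ordinary graph ops, and only the backward pass is replaced by `grad_fn`.

```
import tensorflow as tf

# grad_fn signature: (inputs, variables, outputs, output_grads) ->
#                    (grad_inputs, grad_vars)
def halved_grad(inputs, variables, outputs, output_grads):
  grad_inputs = [0.5 * g for g in output_grads]      # custom input gradients
  grad_vars = [tf.zeros_like(v) for v in variables]  # one grad per variable
  return grad_inputs, grad_vars


@fn_with_custom_grad(halved_grad)
def scale_by_three(x):
  return 3.0 * x


y = scale_by_three(tf.constant(2.0))
```

Calling `tf.gradients` on `y` would then route through `halved_grad` rather than TensorFlow's own gradient for the multiply.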


def _fn_with_custom_grad(fn, inputs, grad_fn):
  """Create a subgraph with a custom gradient.

  Args:
    fn: function that takes inputs as arguments and produces 1 or more Tensors.
    inputs: list<Tensor>, will be passed as fn(*inputs).
    grad_fn: function with signature
      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors.

  Returns:
    fn(*inputs)
  """
  with tf.variable_scope(None, default_name="fn_with_custom_grad") as vs:
    inputs = list(inputs)
    outputs = fn(*inputs)
    train_vars = list(vs.trainable_variables())

  if grad_fn is None:
    return outputs
  else:
    if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
      outputs = [outputs]
    outputs = list(outputs)

    in_types = [t.dtype for t in inputs]
    out_types = [t.dtype for t in outputs]
    var_types = [t.dtype for t in train_vars]

    def custom_grad_fn(op, *dys):
      """Custom grad fn applying grad_fn for identity Defun."""
      dys = list(dys)
      fn_inputs = op.inputs[:len(inputs)]
      fn_vars = op.inputs[len(inputs):len(inputs) + len(train_vars)]
      fn_outputs = op.inputs[len(inputs) + len(train_vars):]
      assert len(fn_outputs) == len(outputs)
      assert len(fn_outputs) == len(dys)

      grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
      grad_outputs = [None] * len(fn_outputs)
      return tuple(grad_inputs + grad_vars + grad_outputs)

    # The Defun takes as input the original inputs, the trainable variables
    # created in fn, and the outputs. In the forward pass it passes through the
    # outputs. In the backward pass, it produces gradients for the original
    # inputs and the trainable variables.
    @function.Defun(
        *(in_types + var_types + out_types),
        func_name="identity_custom_grad%d" % random.randint(1, 10**9),
        python_grad_func=custom_grad_fn,
        shape_func=lambda _: [t.get_shape() for t in outputs])
    def identity(*args):
      outs = args[len(inputs) + len(train_vars):]
      return tuple([tf.identity(t) for t in outs])

    id_out = identity(*(inputs + train_vars + outputs))
    return id_out
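To make the identity-Defun bookkeeping concrete, a small hedged sketch (TF 1.x graph mode; `times_five` and `doubled_grad` are illustrative names, not from the commit): because the Defun's `python_grad_func` delegates to `grad_fn` and returns `None` for the appended output slots, the gradient seen by `tf.gradients` comes entirely from `grad_fn`.

```
import tensorflow as tf

x = tf.constant(3.0)


def times_five(t):
  return 5.0 * t


def doubled_grad(inputs, variables, outputs, output_grads):
  # Double the incoming gradient; times_five creates no variables.
  return [2.0 * g for g in output_grads], [tf.zeros_like(v) for v in variables]


y = _fn_with_custom_grad(times_five, [x], doubled_grad)
dy_dx = tf.gradients(y, x)[0]  # evaluates to 2.0, not the usual d(5x)/dx = 5.0
```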


def rev_block(x1,
              x2,
              f,
@@ -156,19 +257,29 @@ def rev_block(x1,
  A reversible residual layer is defined as:

  ```
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  y1 = x1 + f(x2, f_side_input)
  y2 = x2 + g(y1, g_side_input)
  ```

  A reversible residual block, defined here, is a series of reversible residual
  layers.

  Limitations:
  * f and g must not close over any Tensors; all side inputs to f and g should
    be passed in with f_side_input and g_side_input which will be forwarded to
    f and g.
  * f and g must not change the dimensionality of their inputs in order for the
    addition in the equations above to work.

  Args:
    x1: a float Tensor.
    x2: a float Tensor.
    f: a function, (Tensor) -> (Tensor). Should not change the shape of the
      Tensor. Expected to create variables. See f_side_input if there are side
      inputs.
    g: a function, (Tensor) -> (Tensor). Should not change the shape of the
      Tensor. Expected to create variables. See g_side_input if there are side
      inputs.
    f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
      Should not change the shape of the Tensor. Expected to create variables.
      See f_side_input if there are side inputs.
    g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
      Should not change the shape of the Tensor. Expected to create variables.
      See g_side_input if there are side inputs.
    num_layers: int, number of reversible residual layers. Each layer will
      apply f and g according to the equations above, with new variables in each
      layer.
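The coupling equations above are what make the block reversible: given (y1, y2) and the same f and g, the inputs can be recovered exactly, so intermediate activations do not have to be stored for the backward pass. A minimal numerical sketch (plain Python, with illustrative stand-ins for f and g):

```
f = lambda t: 3.0 * t   # any shape-preserving functions work here
g = lambda t: t + 1.0

x1, x2 = 2.0, 5.0
y1 = x1 + f(x2)          # forward coupling
y2 = x2 + g(y1)

x2_rec = y2 - g(y1)      # invert: undo g first, then f
x1_rec = y1 - f(x2_rec)
assert (x1_rec, x2_rec) == (x1, x2)
```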
@@ -185,46 +296,43 @@ def rev_block(x1,
    f_side_input = []
  if g_side_input is None:
    g_side_input = []
  if isinstance(f, list):
    assert len(f) == num_layers
  else:
    f = [f] * num_layers
  if isinstance(g, list):
    assert len(g) == num_layers
  else:
    g = [g] * num_layers

  # Filled by the forward function below
  layer_scopes = []

  def rev_block_grad(op, *grad_ys):
  def custom_grad_fn(inputs, variables, ys, grad_ys):
    """Custom gradient fn for a block of reversible residual layers."""
    ys = (op.outputs[0], op.outputs[1])

    # The Defun will have as inputs the main inputs (x1, x2), the variables
    # created inside f and g, and the side inputs to f and g. The order of the
    # grads returned from this function must match the order of the inputs.
    # The code here partitions the hoisted inputs into f variables, f side
    # inputs, g variables, and g side inputs and keeps track of their positions
    # in hoisted_inputs.

    hoisted_inputs = op.inputs[2:]
    f_vars = [[] for _ in range(num_layers)]
    g_vars = [[] for _ in range(num_layers)]
    f_vars_idxs = [[] for _ in range(num_layers)]
    g_vars_idxs = [[] for _ in range(num_layers)]
    side_inputs = inputs[2:]
    f_side_idxs = [None] * len(f_side_input)
    g_side_idxs = [None] * len(g_side_input)
    assert len(side_inputs) == len(f_side_input) + len(g_side_input)

    for t in f_side_input + g_side_input:
      assert t in hoisted_inputs

    for i, t in enumerate(hoisted_inputs):
      # Side inputs
    for i, t in enumerate(side_inputs):
      if t in f_side_input:
        f_side_idxs[f_side_input.index(t)] = i
        continue
      if t in g_side_input:
      elif t in g_side_input:
        g_side_idxs[g_side_input.index(t)] = i
        continue
      else:
        assert False

      # Variables
      ref = t.op.inputs[0]
      assert ref.dtype == dtypes.float32_ref
    f_vars = [[] for _ in range(num_layers)]
    g_vars = [[] for _ in range(num_layers)]
    f_vars_idxs = [[] for _ in range(num_layers)]
    g_vars_idxs = [[] for _ in range(num_layers)]

    for i, t in enumerate(variables):
      ref = _underlying_variable(t)

      # Use the name to identify the layer number and function (f or g)
      regex = LAYER_RE.match(t.name)
      regex = LAYER_RE.match(ref.name)
      layer_no = int(regex.group(1))
      fn_name = regex.group(2)
      if fn_name == "f":
@@ -244,12 +352,15 @@ def rev_block_grad(op, *grad_ys):
    layer_scopes.reverse()
    f_vars.reverse()
    g_vars.reverse()
    f.reverse()
    g.reverse()

    for i in xrange(num_layers):
      with tf.variable_scope(layer_scopes[i], reuse=True):
        ys, grad_ys, f_ret, g_ret = (_rev_layer_backward(
            ys, grad_ys, f, g, f_vars[i], f_side_input, g_vars[i],
            g_side_input))

        ys, grad_ys, f_ret, g_ret = _rev_layer_backward(ys, grad_ys, f[i], g[i],
                                                        f_vars[i], f_side_input,
                                                        g_vars[i], g_side_input)

        grad_f_vars, grad_f_side = f_ret
        grad_g_vars, grad_g_side = g_ret
@@ -262,8 +373,9 @@ def rev_block_grad(op, *grad_ys):
    acc_f_side_grads = _acc_grads(*f_side_grads)
    acc_g_side_grads = _acc_grads(*g_side_grads)

    # Use the stored idxs to put gradients in the same order as hoisted_inputs.
    hoisted_inputs_grads = [None] * len(hoisted_inputs)
    # Use the stored idxs to put gradients in the passed-in order.
    side_input_grads = [None] * len(side_inputs)
    variable_grads = [None] * len(variables)

    # Variable gradients were collected in reverse layer order. Reverse to match
    # idxs.
@@ -272,43 +384,30 @@ def rev_block_grad(op, *grad_ys):
    for idxs, grads in zip(f_vars_idxs, f_var_grads) + zip(
        g_vars_idxs, g_var_grads):
      for i, grad in zip(idxs, grads):
        hoisted_inputs_grads[i] = grad
        variable_grads[i] = grad

    for i, grad in zip(f_side_idxs, acc_f_side_grads):
      hoisted_inputs_grads[i] = grad
      side_input_grads[i] = grad
    for i, grad in zip(g_side_idxs, acc_g_side_grads):
      hoisted_inputs_grads[i] = grad
      side_input_grads[i] = grad

    grad_x1, grad_x2 = grad_ys
    return [grad_x1, grad_x2] + hoisted_inputs_grads

  @function.Defun(
      tf.float32,
      tf.float32,
      python_grad_func=rev_block_grad,
      shape_func=lambda _: [x1.get_shape(), x2.get_shape()])
  def rev_block_defun(inp1, inp2):
    inp1.set_shape(x1.get_shape())
    inp2.set_shape(x2.get_shape())
    return _rev_block_forward(
        inp1,
        inp2,
        f,
        g,
        num_layers=num_layers,
        f_side_input=f_side_input,
        g_side_input=g_side_input,
        layer_scopes=layer_scopes,
        gate_outputs=True)
    return [grad_x1, grad_x2] + side_input_grads, variable_grads

  if is_training:
    return rev_block_defun(x1, x2)
  else:
  # Need a forward function with positional arguments
  @fn_with_custom_grad(custom_grad_fn if is_training else None)
  def forward(x1, x2, *side_inputs):
    f_side = side_inputs[:len(f_side_input)]
    g_side = side_inputs[len(f_side_input):]
    return _rev_block_forward(
        x1,
        x2,
        f,
        g,
        num_layers=num_layers,
        f_side_input=f_side_input,
        g_side_input=g_side_input)
        f_side_input=f_side,
        g_side_input=g_side,
        layer_scopes=layer_scopes,
        gate_outputs=is_training)

  return forward(x1, x2, *(f_side_input + g_side_input))
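A hedged usage sketch of the public entry point (layer width, batch size, and the use of `tf.layers.dense` are illustrative, not taken from the commit; assumes `rev_block` is imported from `tensor2tensor.layers.rev_block`):

```
import tensorflow as tf

# f and g must preserve the width of their input so the additions in the
# coupling equations are well defined; they may create variables.
def f(x):
  return tf.layers.dense(x, 64, use_bias=False)


def g(x):
  return tf.layers.dense(x, 64, use_bias=False)


x1 = tf.random_normal([16, 64])
x2 = tf.random_normal([16, 64])
y1, y2 = rev_block(x1, x2, f, g, num_layers=2, is_training=True)
```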
