# updaters.py (311 lines, 267 loc, 11.3 KB)
# Forked from google-deepmind/deepmind-research.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Data Parallel Updater for Graph2text data."""
import functools
import os
import pickle
from absl import logging
import haiku as hk
import jax
from jax.tree_util import tree_map
import numpy as np
import optax
def call_fn_with_state_keys(jit_fn, state, other_inputs, keys):
  """Executes `jit_fn` on the subset of `state` whose keys are in `keys`.

  Args:
    jit_fn: callable invoked as `jit_fn(filtered_state, *other_inputs)`.
    state: dict-like state. It is not modified; a shallow copy is filtered.
    other_inputs: sequence of extra positional arguments for `jit_fn`.
    keys: container of the state keys that `jit_fn` should receive.

  Returns:
    A pair `(jit_fn_result, extra_state)` where `extra_state` is a dict holding
    the entries of `state` whose keys are not in `keys`.
  """
  # Entries that must not be passed to `jit_fn`; handed back to the caller.
  extra_state = {k: v for k, v in state.items() if k not in keys}
  # Shallow copy so the caller's mapping is untouched and its type is kept.
  kept_state = state.copy()
  for k in extra_state:
    del kept_state[k]
  return jit_fn(kept_state, *other_inputs), extra_state
class Updater:
  """Graph2text model updater with multi-GPU support.

  Wraps a Haiku-transformed loss function and an optax optimizer into
  `jax.pmap`-ed init / update / eval functions that use the axis name 'i'.

  The updater state is a dict with the keys 'params', 'state', 'opt_state',
  'rng' and 'replicated_step' (all produced inside pmap), plus 'step', a
  host-side counter that is never passed to the pmapped functions.
  """

  def __init__(self, loss_fn, optimizer, devices=None, has_graph=False):
    """Builds the pmapped init, update and eval functions.

    Args:
      loss_fn: function to be transformed with `hk.transform_with_state`. It
        is called with the batch (and `max_graph_size` when `has_graph` is
        set) plus the keyword `is_training`, and must return a
        `(loss, metrics)` pair where `metrics` is a dict.
      optimizer: optax optimizer; `None` selects `optax.identity()`.
      devices: optional list of devices to pmap over. When `None`, the local
        devices of every process are used.
      has_graph: if True, every batch must contain a 'max_graph_size' entry,
        which is forwarded to the pmapped functions as a statically
        broadcasted argument.
    """
    self._net_init_fn, self._apply_fn = hk.transform_with_state(
        functools.partial(loss_fn, is_training=True))
    _, self._eval_apply_fn = hk.transform_with_state(
        functools.partial(loss_fn, is_training=False))

    if optimizer is None:
      optimizer = optax.identity()
    self._optimizer = optimizer

    self._num_devices = jax.local_device_count()
    if devices is None:
      devices = []
      for host_id in range(jax.process_count()):
        devices.extend(jax.local_devices(host_id))
    else:
      self._num_devices = min(self._num_devices, len(devices))

    def _pmap(f, static_broadcasted_argnums=()):
      return jax.pmap(f, axis_name='i', devices=devices,
                      static_broadcasted_argnums=static_broadcasted_argnums)

    def handle_graph_size(fn):
      """Moves batch['max_graph_size'] into a trailing positional argument."""
      def _fn(*args):
        # The batch is always the last positional argument; copy it so the
        # caller's dict is not modified.
        batch = args[-1].copy()
        max_graph_size = batch.pop('max_graph_size')
        return fn(*(args[:-1] + (batch, max_graph_size)))
      return _fn

    # Try to jit.
    if has_graph:
      # If the model contains full graphs, we need to set the max_graph_size
      # as a statically broadcasted argument.
      self._init_fn = handle_graph_size(_pmap(self._init, 4))
      self._update_fn = handle_graph_size(_pmap(self._update, 2))
      self._eval_fn = handle_graph_size(_pmap(self._eval, 2))
    else:
      self._init_fn = _pmap(self._init)
      self._update_fn = _pmap(self._update)
      self._eval_fn = _pmap(self._eval)

  def _init(self, master_rng, params, network_state, data, max_graph_size=None):
    """Per-device initialization; runs inside pmap.

    Args:
      master_rng: rng key, split into the key kept in the state and the key
        used for network initialization.
      params: existing parameters, or `None` to use freshly initialized ones.
      network_state: existing network state, or `None` to use the freshly
        initialized one.
      data: a batch used to initialize the network.
      max_graph_size: optional extra argument forwarded to the network.

    Returns:
      A dict with keys 'replicated_step', 'rng', 'state', 'opt_state' and
      'params'.
    """
    out_rng, init_rng = jax.random.split(master_rng)
    if max_graph_size is not None:
      new_params, new_network_state = self._net_init_fn(
          init_rng, data, max_graph_size)
    else:
      new_params, new_network_state = self._net_init_fn(init_rng, data)
    if params is None:
      params = new_params
    if network_state is None:
      network_state = new_network_state
    opt_state = self._optimizer.init(params)
    return dict(
        replicated_step=0,
        rng=out_rng,
        state=network_state,
        opt_state=opt_state,
        params=params,
    )

  def init(self, master_rng, data, params=None, network_state=None,
           replicated_params=False):
    """Initializes state of the updater.

    Args:
      master_rng: rng key; the same key is given to every device.
      data: a batch whose leading dimension is a multiple of the number of
        devices.
      params: optional parameters to start from.
      network_state: optional network state to start from. It is passed to
        the pmapped init function as given (it is not replicated here).
      replicated_params: set to True if `params` already has a leading device
        dimension; otherwise `params` is stacked once per device.

    Returns:
      The updater state dict, including a host-side 'step' counter set to 0.
    """
    data = self._preprocess(data)
    rngs = np.array([master_rng] * self._num_devices)
    if not replicated_params and params is not None:
      # Use the `jax.tree_util` function; the top-level `jax.tree_map` alias
      # is deprecated.
      params = tree_map(
          lambda x: np.array([x] * self._num_devices), params)
    state = self._init_fn(rngs, params, network_state, data)
    state['step'] = np.array(0, dtype=np.int64)

    # Wait for initialization to finish before starting training to keep
    # memory usage low.
    flat_params = jax.tree_util.tree_leaves(state['params'])
    if flat_params:
      flat_params[0].block_until_ready()
    return state

  def _update(self, state, data, max_graph_size=None):
    """Per-device parameter update; runs inside pmap.

    Args:
      state: dict with keys 'replicated_step', 'rng', 'opt_state', 'params'
        and 'state'.
      data: the per-device batch.
      max_graph_size: optional extra argument forwarded to the network.

    Returns:
      `(new_state, metrics)`. Gradients, loss and metrics are averaged over
      the 'i' axis; `metrics` additionally holds 'loss' and 'step'.
    """
    replicated_step = state['replicated_step']
    rng = state['rng']
    opt_state = state['opt_state']
    params = state['params']
    net_state = state['state']

    rng, new_rng = jax.random.split(rng)
    # Give every device along the 'i' axis a distinct key for this step.
    rng = jax.random.fold_in(rng, jax.lax.axis_index('i'))

    def _loss(params, state, batch, rng):
      if max_graph_size is not None:
        (loss, metrics), state = self._apply_fn(params, state, rng, batch,
                                                max_graph_size)
      else:
        (loss, metrics), state = self._apply_fn(params, state, rng, batch)
      return loss, (metrics, state)

    (loss, (metrics, new_net_state)), g = jax.value_and_grad(
        _loss, has_aux=True)(params, net_state, data, rng)
    g = jax.lax.pmean(g, axis_name='i')
    loss = jax.lax.pmean(loss, axis_name='i')
    metrics = jax.lax.pmean(metrics, axis_name='i')

    updates, new_opt_state = self._optimizer.update(g, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    new_state = dict(
        replicated_step=replicated_step + 1,
        rng=new_rng,
        state=new_net_state,
        opt_state=new_opt_state,
        params=new_params,
    )
    metrics['loss'] = loss
    metrics['step'] = replicated_step
    return new_state, metrics

  def update(self, state, data):
    """Updates the state using some data and returns metrics.

    Args:
      state: the updater state dict. It is not modified.
      data: a batch whose leading dimension is a multiple of the number of
        devices.

    Returns:
      `(new_state, metrics)` where `metrics` holds the first device's values.
    """
    data = self._preprocess(data)
    (state, out), extra_state = call_fn_with_state_keys(
        self._update_fn, state, [data],
        keys={'state', 'params', 'rng', 'replicated_step', 'opt_state'})
    state.update(extra_state)
    # Rebind instead of `+=`: 'step' may be a numpy array shared with the
    # caller's input state, and an in-place add would change that state too.
    state['step'] = state['step'] + 1
    return state, tree_map(lambda x: x[0], out)

  def _eval(self, state, data, max_graph_size=None):
    """Per-device evaluation; runs inside pmap.

    Args:
      state: dict with keys 'params', 'state', 'rng' and 'replicated_step'.
      data: the per-device batch.
      max_graph_size: optional extra argument forwarded to the network.

    Returns:
      `(state, metrics)` where `state['state']` is replaced by the network
      state returned by the forward pass, and `metrics` (averaged over the
      'i' axis) additionally holds 'loss' and 'step'.
    """
    if max_graph_size is not None:
      (loss, metrics), new_state = self._eval_apply_fn(
          state['params'], state['state'], state['rng'], data, max_graph_size)
    else:
      (loss, metrics), new_state = self._eval_apply_fn(
          state['params'], state['state'], state['rng'], data)
    state['state'] = new_state
    loss = jax.lax.pmean(loss, axis_name='i')
    metrics = jax.lax.pmean(metrics, axis_name='i')
    metrics['loss'] = loss
    metrics['step'] = state['replicated_step']
    return state, metrics

  def eval_return_state(self, state, data):
    """Evaluates on `data`; returns the resulting state and the metrics.

    The optimizer state is not passed to the evaluation function; it is
    carried over unchanged into the returned state.
    """
    data = self._preprocess(data)
    (state, out), extra_state = call_fn_with_state_keys(
        self._eval_fn, state, [data],
        keys={'state', 'params', 'rng', 'replicated_step'})
    state.update(extra_state)
    return state, tree_map(lambda x: x[0], out)

  def eval(self, state, data):
    """Returns metrics without updating the model."""
    _, out = self.eval_return_state(state, data)
    return out

  def _preprocess(self, data):
    """Reshapes input so that it can be distributed across multiple cores.

    Args:
      data: dict of batch entries. Scalars are passed through unchanged;
        every other leaf of shape `[B, ...]` becomes
        `[num_devices, B // num_devices, ...]`.

    Returns:
      A new dict with the reshaped entries.

    Raises:
      ValueError: if a leading dimension is not a multiple of the number of
        devices.
    """
    multi_inputs = data.copy()

    def add_core_dimension(x):
      if np.isscalar(x):
        return x
      if x.shape[0] % self._num_devices != 0:
        raise ValueError(f'The batch size must be a multiple of the number of'
                         f' devices. Got batch size = {x.shape[0]} and number'
                         f' of devices = {self._num_devices}.')
      prefix = (self._num_devices, x.shape[0] // self._num_devices)
      return np.reshape(x, prefix + x.shape[1:])

    multi_inputs = tree_map(add_core_dimension, multi_inputs)
    return multi_inputs

  def params(self, state):
    """Returns model parameters (the first device's copy)."""
    return tree_map(lambda x: x[0], state['params'])

  def opt_state(self, state):
    """Returns the state of the optimiser (the first device's copy)."""
    return tree_map(lambda x: x[0], state['opt_state'])

  def network_state(self, state):
    """Returns the model's state (the first device's copy)."""
    return tree_map(lambda x: x[0], state['state'])

  def to_checkpoint_state(self, state):
    """Transforms the updater state into a checkpointable state.

    Args:
      state: the updater state dict. It is not modified.

    Returns:
      A dict with the same structure in which every leaf is reduced to its
      first entry along the leading (device) axis.
    """
    checkpoint_state = state.copy()
    # Wrapper around checkpoint_state['step'] so we can get [0]. `asarray`
    # makes this work for scalar steps as well as numpy arrays.
    checkpoint_state['step'] = np.asarray(
        checkpoint_state['step'])[np.newaxis]
    # Unstack the replicated contents.
    checkpoint_state = tree_map(lambda x: x[0], checkpoint_state)
    return checkpoint_state

  def from_checkpoint_state(self, checkpoint_state):
    """Initializes the updater state from the checkpointed state.

    Args:
      checkpoint_state: a dict as produced by `to_checkpoint_state`.

    Returns:
      The updater state with every leaf stacked once per device, except
      'step', which is left unstacked.
    """
    # Expand the checkpoint so we have a copy for each device. Use the same
    # device count as `init` and `_preprocess` so the leading axis matches
    # the pmapped functions even when a subset of devices was requested.
    num_devices = self._num_devices
    state = tree_map(lambda x: np.stack(num_devices * [x]),
                     checkpoint_state)
    state['step'] = state['step'][0]  # Undo stacking for step.
    return state
class CheckpointingUpdater:
  """A checkpointing wrapper around an Updater.

  State is pickled to `<checkpoint_dir>/checkpoint.pkl`. On `init`, an
  existing checkpoint in that directory takes precedence over a fresh
  initialization.
  """

  def __init__(self,
               inner: 'Updater',
               checkpoint_dir: str):
    """Stores the wrapped updater and the checkpoint directory.

    Args:
      inner: the updater that performs the actual work.
      checkpoint_dir: directory that checkpoints are written to and read from.
    """
    self._inner = inner
    self._checkpoint_dir = checkpoint_dir

  def _checkpoint_paths(self):
    """Returns the sorted file names in the directory containing 'checkpoint'."""
    # `os.listdir` returns entries in arbitrary order; sort so that picking
    # the last entry in `load_checkpoint` is deterministic.
    return sorted(
        p for p in os.listdir(self._checkpoint_dir) if 'checkpoint' in p)

  def init(self, rng, data, params=None, network_state=None):
    """Initialize experiment state.

    If the checkpoint directory does not exist or holds no checkpoint file,
    the directory is created and the inner updater is initialized from
    scratch. Otherwise the state is loaded from the checkpoint and the
    arguments are ignored.
    """
    if not os.path.exists(self._checkpoint_dir) or not self._checkpoint_paths():
      os.makedirs(self._checkpoint_dir, exist_ok=True)
      return self._inner.init(rng, data, params, network_state)
    return self.load_checkpoint()

  def init_from_checkpoint(self, rng, data, checkpoint_state):
    """Initializes a fresh state that reuses the params of `checkpoint_state`.

    Only the parameters are taken from `checkpoint_state`; the network state
    is initialized from scratch.
    """
    params = self._inner.params(checkpoint_state)
    network_state = None
    return self._inner.init(rng, data, params, network_state)

  def eval_return_state(self, state, data):
    """Forwards to the inner updater's `eval_return_state`."""
    return self._inner.eval_return_state(state, data)

  def save_checkpoint(self, state):
    """Pickles `state` to `<checkpoint_dir>/checkpoint.pkl`."""
    # The directory may not exist yet if `init` was never called.
    os.makedirs(self._checkpoint_dir, exist_ok=True)
    path = os.path.join(self._checkpoint_dir, 'checkpoint.pkl')
    logging.info('Serializing experiment state to %s', path)
    checkpoint_state = self._inner.to_checkpoint_state(jax.device_get(state))
    # Write to a temporary file first and then rename it over the target, so
    # an interrupted write never leaves a truncated checkpoint behind. The
    # temporary name must not contain 'checkpoint', otherwise
    # `_checkpoint_paths` would pick it up.
    tmp_path = os.path.join(self._checkpoint_dir, 'ckpt.pkl.tmp')
    with open(tmp_path, 'wb') as f:
      pickle.dump(checkpoint_state, f)
    os.replace(tmp_path, path)

  def load_checkpoint(self):
    """Loads the last (in sorted order) checkpoint file in the directory."""
    checkpoint = os.path.join(self._checkpoint_dir,
                              self._checkpoint_paths()[-1])
    logging.info('Loading checkpoint from %s', checkpoint)
    # NOTE: `pickle.load` can execute arbitrary code; only load checkpoint
    # files from a trusted source.
    with open(checkpoint, 'rb') as f:
      state = pickle.load(f)
    return self._inner.from_checkpoint_state(state)

  def update(self, state, data):
    """Update experiment state."""
    state, out = self._inner.update(state, data)
    return state, out