-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_util.py
400 lines (332 loc) · 19.4 KB
/
train_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
import os
import logging
import tensorflow as tf
import numpy as np
import time
import copy
from datetime import timedelta
from data.qm9_loader import QM9Loader
from data.molecules import TFMolBatch
class CurveSmoother:
    """Running exponential average that mimics TensorBoard's smoothing slider.

    Feed the points of a curve (e.g. a loss curve) to `smooth` one at a time, in order; every call returns the
    smoothed value at that point.

    :param smoothing_factor: Weight of the previous smoothed value, in [0, 1]. 0 disables smoothing,
        1 keeps the output fixed at the first value that was passed in.
    :raises ValueError: If smoothing_factor lies outside [0, 1] (NaN is rejected as well).
    """

    def __init__(self, smoothing_factor):
        within_bounds = 0 <= smoothing_factor <= 1
        if not within_bounds:
            raise ValueError('Smoothing factor must lie between 0 and 1.')
        self._smoothing_factor = smoothing_factor
        self._last_smoothed_value = None  # stays None until the first value has been seen

    def smooth(self, new_value):
        """Incorporate new_value into the running average and return the result.

        The very first value is returned unchanged; afterwards the result is
        previous_smoothed * smoothing_factor + (1 - smoothing_factor) * new_value.

        :param new_value: Next raw value of the curve.
        :return: The smoothed value (also remembered for the next call).
        """
        previous = self._last_smoothed_value
        if previous is not None:
            weight = self._smoothing_factor
            new_value = previous * weight + (1 - weight) * new_value
        self._last_smoothed_value = new_value
        return new_value
class QM9Trainer:
    """Abstract base class for training and evaluating models on the QM9 data set.

    All the logic for loading data and performing training is defined here, while the model and the evaluation
    are to be defined in the concrete subclass: it must set self.featurizer in its constructor and implement
    _build_model, _write_train_log, _write_val_log and _eval_results.

    :param data_dir: directory containing the QM9 files *.sdf, *_labels.csv (*=[training|validation|test])
    :param train_log_interval: Write training log after this many steps.
    :param val_log_interval: Write validation log after this many steps.
    :param name: Name of the experiment/training that is performed. Configs, checkpoints, logs and results are
        stored in the directory '<name>_results' next to this file.
    :param implicit_hydrogen: If True, hydrogen atoms will be treated implicitly.
    :param patience: Stop training early if the (smoothed) validation loss has not improved for this many steps.
    :param loss_smoothing: Early stopping is decided based on a running average of the validation loss.
        This parameter controls the amount of smoothing and corresponds to the TensorBoard smoothing slider.
    :param property_names: List of QM9 properties that should be used for training
        (default: QM9Loader.all_property_names).
    """

    def __init__(self, data_dir, train_log_interval=250, val_log_interval=10000, name='', implicit_hydrogen=True,
                 patience=float('inf'), loss_smoothing=0.8, property_names=None):
        self.data_dir = data_dir
        self.name = name
        self.results_dir = os.path.join(os.path.dirname(__file__), name + '_results')
        os.makedirs(self.results_dir, exist_ok=True)
        self.train_log_interval = train_log_interval
        self.val_log_interval = val_log_interval
        self.patience = patience
        self.loss_smoothing = loss_smoothing
        self.property_names = QM9Loader.all_property_names if property_names is None else property_names
        self.implicit_hydrogen = implicit_hydrogen

        # initialized in subclass constructor
        self.featurizer = None

        # updated by run_trainings; the counters accumulate over repeated run_trainings calls
        self._config_name = None  # name of currently trained hyperparameter configuration
        self._current_config_number = 0
        self._num_configs = 0

        # initialized by _prepare_data
        self._standardization = None
        self._train_iterator, self._val_iterator, self._test_iterator = None, None, None
        self._train_mols, self._val_mols, self._test_mols = None, None, None

        # initialized by _build_model (subclass)
        self._global_step = None  # step-counter tf.Variable; _train assigns the start step to it
        self._train_op = None

        # _sess is created by run_trainings, _summary_writer by _train
        self._sess = None
        self._summary_writer = None

        # initialized by _init_saver and _restore_saved_model
        self._step = 0
        self._saver = None
        self._checkpoint_dir = None

    def run_trainings(self, hparam_configs, num_steps):
        """Run training and evaluation for different hyperparameter configurations.

        If a configuration with the same name has been trained before, its weights are restored from the checkpoint
        and training is started from the step saved in the checkpoint.
        If that step is not lower than num_steps, training is skipped. To only run evaluation, e.g. use num_steps=0.

        :param hparam_configs: dict of tf.contrib.training.HParams objects (config name => HParams)
        :param num_steps: Number of steps (=batches) to train.
        """
        self._num_configs += len(hparam_configs)
        self._save_hparam_configs(hparam_configs)
        for config_name, hparam_config in hparam_configs.items():
            self._current_config_number += 1
            self._config_name = config_name
            # Pass config_name as a logging argument instead of concatenating it into the format string:
            # a '%' in a configuration's name would otherwise break logging's %-formatting.
            logging.info('==== Configuration %d / %d: %s ====', self._current_config_number, self._num_configs,
                         config_name)
            tf.reset_default_graph()
            self._sess = tf.Session()
            try:
                logging.info('Preparing data.')
                self._prepare_data(hparam_config.batch_size)
                logging.info('Building model.')
                self._build_model(hparam_config)
                # run before restoring the model, otherwise the initializer would overwrite the restored weights
                self._sess.run(tf.global_variables_initializer())
                self._init_saver()
                self._restore_saved_model()  # restore previously trained model from disk
                if num_steps > self._step:
                    logging.info('Starting training.')
                    self._train(num_steps)
                    logging.info('Training complete.')
                    self._restore_saved_model()  # restore model with best validation loss during current training
                logging.info('Starting evaluation.')
                self._eval_results()
            finally:
                # Release resources even if preparation, training or evaluation raised: flush and close the
                # TensorBoard writer created by _train (closing twice is harmless) and close the session.
                if self._summary_writer is not None:
                    self._summary_writer.close()
                self._sess.close()
        logging.info('All trainings complete.')

    def _save_hparam_configs(self, hparam_configs):
        """Save the hyperparameter configurations to json files in '<results_dir>/configs'.

        Storing the hyperparameter configurations along with the results they produced keeps experiments
        reproducible. Existing files with the same configuration name are overwritten.

        :param hparam_configs: dict of tf.contrib.training.HParams objects (config name => HParams)
        """
        config_save_dir = os.path.join(self.results_dir, 'configs')
        os.makedirs(config_save_dir, exist_ok=True)
        for config_name, hparam_config in hparam_configs.items():
            with open(os.path.join(config_save_dir, config_name + '.json'), 'w') as f:
                f.write(hparam_config.to_json(2))

    def _prepare_data(self, batch_size):
        """Create data iterators and TFMolBatches for the QM9 training, validation and test data.

        The training iterator shuffles and repeats indefinitely; the validation and test iterators make a single
        pass and have to be initialized via their `initializer` before use.

        :param batch_size: Number of molecules per batch.
        """
        qm9_loader = QM9Loader(self.data_dir, self.featurizer, self.property_names, standardize_labels=True)
        self._standardization = qm9_loader.standardization

        def create_iterator(data_set, training=True):
            """Create a data iterator from the given tf.data.Dataset."""
            data_set = data_set.cache()
            if training:
                data_set = data_set.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
                data_set = data_set.repeat()
            data_set = data_set.batch(batch_size)
            data_set = data_set.prefetch(buffer_size=1)
            if training:
                return data_set.make_one_shot_iterator()
            return data_set.make_initializable_iterator()

        def create_mol_batch(iterator):
            """Wrap the iterator's next element in a TFMolBatch."""
            data = iterator.get_next()
            return TFMolBatch(data['atoms'], labels=data['labels'],
                              distance_matrix=data['interactions'][..., 0],  # squeeze interaction dim
                              coordinates=data['coordinates'])

        self._train_iterator = create_iterator(qm9_loader.train_data, training=True)
        self._val_iterator = create_iterator(qm9_loader.val_data, training=False)
        self._test_iterator = create_iterator(qm9_loader.test_data, training=False)

        with tf.name_scope('train_data'):
            self._train_mols = create_mol_batch(self._train_iterator)
        with tf.name_scope('val_data'):
            self._val_mols = create_mol_batch(self._val_iterator)
        with tf.name_scope('test_data'):
            self._test_mols = create_mol_batch(self._test_iterator)

    def _init_saver(self):
        """Initialize the tf.train.Saver (keeping only the latest checkpoint) and the checkpoint directory."""
        self._saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
        self._checkpoint_dir = os.path.join(self.results_dir, 'checkpoints', 'checkpoints_' + self._config_name)
        os.makedirs(self._checkpoint_dir, exist_ok=True)

    def _restore_saved_model(self):
        """Restore the model from the latest checkpoint (if available) and set the current training step accordingly.

        Without a checkpoint, the current variable values are kept and the step is reset to 0.
        """
        latest_checkpoint = tf.train.latest_checkpoint(self._checkpoint_dir)
        if latest_checkpoint is None:
            self._step = 0
        else:
            self._saver.restore(self._sess, latest_checkpoint)
            # _train saves with global_step=self._step, so the checkpoint prefix ends in '-<step>'
            self._step = int(latest_checkpoint.split('-')[-1])

    def _train(self, num_steps):
        """Perform training until num_steps is reached or the validation loss stops improving (early stopping).

        Training starts at self._step with an initial validation via self._write_val_log().
        At each step, self._train_op is run, except on steps where self._write_train_log() is called
        (every self.train_log_interval steps), which performs the training step itself.
        Every self.val_log_interval steps, self._write_val_log() is called.
        If the validation loss has improved, the model is saved to disk.
        Early stopping is done when the smoothed validation loss curve has not improved for self.patience steps.

        :param num_steps: Training step at which training stops.
        """
        sess = self._sess
        self._summary_writer = tf.summary.FileWriter(
            os.path.join(self.results_dir, 'logs', 'logs_' + self._config_name), sess.graph)
        start_step = self._step
        sess.run(self._global_step.assign(start_step))
        # make the graph read-only so no ops can accidentally be added inside the training loop
        sess.graph.finalize()

        early_stopping_smoother = CurveSmoother(smoothing_factor=self.loss_smoothing)
        best_val_loss = self._write_val_log()
        best_val_loss_smoothed = early_stopping_smoother.smooth(best_val_loss)
        best_step_smoothed = start_step
        logging.info('%d / %d: Initial validation yields loss %f', self._step, num_steps, best_val_loss)

        start_time = time.time()
        for self._step in range(start_step + 1, num_steps + 1):
            # check early stopping
            if (self._step - best_step_smoothed) > self.patience:
                logging.info('Out of patience.')
                break

            # validate
            if self._step % self.val_log_interval == 0:
                val_loss = self._write_val_log()
                best_log = ''
                if val_loss < best_val_loss:
                    # Only improvements are checkpointed. NOTE(review): if the initial validation loss is never
                    # beaten, this call writes no checkpoint at all.
                    self._saver.save(sess, os.path.join(self._checkpoint_dir, 'checkpoint'), self._step)
                    best_val_loss = val_loss
                    best_log = ' (BEST so far)'
                logging.info('%d / %d: Validation yields loss %f' + best_log, self._step, num_steps, val_loss)
                # smooth validation loss for early stopping
                val_loss_smoothed = early_stopping_smoother.smooth(val_loss)
                if val_loss_smoothed < best_val_loss_smoothed:
                    best_val_loss_smoothed = val_loss_smoothed
                    best_step_smoothed = self._step

            # train (on logging steps, _write_train_log performs the training step)
            if self._step % self.train_log_interval == 0:
                self._write_train_log()
                logging.info('%d / %d: Training summary written', self._step, num_steps)
            else:
                sess.run(self._train_op)

            # estimate remaining time, once at least val_log_interval steps have been trained in this call
            num_steps_since_start = self._step - start_step
            if self._step % (10 * self.train_log_interval) == 0 and num_steps_since_start >= self.val_log_interval:
                seconds_since_start = time.time() - start_time
                num_steps_to_train = num_steps - start_step
                remaining_seconds = seconds_since_start * (num_steps_to_train / num_steps_since_start - 1)
                formatted_remaining_time = str(timedelta(seconds=int(remaining_seconds)))
                logging.info('%s remaining for configuration %d / %d', formatted_remaining_time,
                             self._current_config_number, self._num_configs)

    def _build_model(self, hparams):
        """Build the model given the hyperparameter configuration.

        Needs to be implemented in subclass and initialize self._train_op and self._global_step.

        :param hparams: tf.contrib.training.HParams object
        """
        raise NotImplementedError('Model must be defined in child class.')

    def _write_train_log(self):
        """Perform a training step and write the training log. Needs to be implemented in subclass."""
        raise NotImplementedError('Must be implemented in child class.')

    def _write_val_log(self):
        """Perform validation, write validation log and return validation loss. Needs to be implemented in subclass."""
        raise NotImplementedError('Must be implemented in child class.')

    def _eval_results(self):
        """Evaluate results after training is complete. Needs to be implemented in subclass."""
        raise NotImplementedError('Must be implemented in child class.')

    def _write_eval_results_to_file(self, result_dict):
        """Append results of post-training evaluation as one tab-separated row to '<results_dir>/results.txt'.

        The header row is written from result_dict's keys when the file does not exist yet; later calls are
        assumed to pass the same keys in the same order.

        :param result_dict: Dictionary containing evaluation results.
        """
        results_filename = os.path.join(self.results_dir, 'results.txt')
        if not os.path.isfile(results_filename):  # file does not exist yet
            with open(results_filename, 'w') as f:
                header = 'config' + '\t' + '\t'.join(result_dict.keys()) + '\n'
                f.write(header)
        with open(results_filename, 'a') as f:
            data = self._config_name + '\t' + '\t'.join([str(v) for v in result_dict.values()]) + '\n'
            f.write(data)
        logging.info('Evaluation results for config %s written to %s', self._config_name, results_filename)

    def _average_over_dataset(self, data_iterator, eval_tensors):
        """Calculate the average values of eval_tensors across the specified data set.

        :param data_iterator: The initializable_iterator of the relevant data set; it is (re-)initialized here.
        :param eval_tensors: The one-dimensional tensors to be evaluated and averaged.
        :return: The average values of eval_tensors (mean over all evaluated batches).
        """
        self._sess.run(data_iterator.initializer)
        values = []
        while True:
            try:
                values.append(self._sess.run(eval_tensors))
            except tf.errors.OutOfRangeError:  # iterator exhausted: the whole data set has been seen
                break
        # NOTE(review): every batch gets the same weight. If eval_tensors are per-batch averages and the last batch
        # is smaller, this is not exactly the per-molecule mean — TODO confirm. An empty data set yields NaN.
        return np.mean(np.array(values), axis=0)
class ConfigReader:
    """Read hyperparameter configurations from json files.

    Keeps track of which files have already been read, such that after training completion, training can continue
    directly with new configurations that have been added in the meantime.

    :param config_dir: All files in this directory ending with .json will be read.
    :param default_hparams: The tf.contrib.training.HParams object with the default hyperparameter values. Each
        configuration starts from a deep copy of it, so the defaults themselves are never modified.
    """

    def __init__(self, config_dir, default_hparams):
        self.config_dir = config_dir
        self.previous_config_files = set()  # paths already returned by _get_new_config_files
        self.default_hparams = default_hparams

    def _get_new_config_files(self):
        """List all json files in config_dir that have not been returned by a previous call to this method.

        :return: Sorted list of file paths. Sorting makes the order in which configurations are read (and hence
            trained) reproducible; set iteration order is not.
        """
        config_files = {os.path.join(self.config_dir, f) for f in os.listdir(self.config_dir) if f.endswith('.json')}
        new_config_files = config_files - self.previous_config_files
        self.previous_config_files |= new_config_files
        return sorted(new_config_files)

    def get_new_hparam_configs(self):
        """Get all hyperparameter configs in config_dir that have not been there at the previous call to this method.

        :return: dict of hyperparameter configs; config_name => tf.contrib.training.HParams object, in sorted
            file-name order. The config name is the file name without the .json extension.
        :raises KeyError: If a hyperparameter in the config file does not match any of the default hyperparameters.
        """
        hparam_configs = {}
        for config_file in self._get_new_config_files():
            config_name = os.path.splitext(os.path.basename(config_file))[0]
            with open(config_file, 'r') as f:
                config_json = f.read()
            hparams = copy.deepcopy(self.default_hparams)  # keep the shared defaults untouched
            try:
                hparam_configs[config_name] = hparams.parse_json(config_json)
            except KeyError as e:
                raise KeyError('There is a parameter in the configuration ' + config_name +
                               ' which does not match any of the default parameters: ' + str(e)) from None
        return hparam_configs
def read_configs_and_train(trainer, default_hparams, num_steps, config_dir=None):
    """Train every hyperparameter configuration found in config_dir, or the defaults if no directory is given.

    With a config_dir, configurations are read from its json files and trained; afterwards the directory is
    checked again, and any configurations added while training was running are trained as well, until no new
    ones appear.
    Without a config_dir, a single training named 'default' is run with default_hparams.

    :param trainer: Subclass of QM9Trainer to perform the training.
    :param default_hparams: tf.contrib.training.HParams with the default hyperparameter configuration.
    :param num_steps: Number of steps/batches to train.
    :param config_dir: Directory containing json files with HParams, or None.
    """
    if config_dir is None:
        logging.info('No directory with hyperparameter configurations specified. Using default values.')
        trainer.run_trainings({'default': default_hparams}, num_steps)
    else:
        reader = ConfigReader(config_dir, default_hparams)
        while True:
            pending_configs = reader.get_new_hparam_configs()
            if not pending_configs:
                break
            logging.info('Found %d new hyperparameter configurations.', len(pending_configs))
            trainer.run_trainings(pending_configs, num_steps)
            logging.info('Checking for new hyperparameter configurations that have been added during training.')
        logging.info('No new hyperparameter configurations found.')
    logging.info('Done.')