
Commit

add support for separate optimizers for the embedding layer
chengmengli06 committed Oct 8, 2023
1 parent 563268c commit d2f81ae
Showing 10 changed files with 492 additions and 38 deletions.
35 changes: 33 additions & 2 deletions docs/source/models/wide_and_deep.md
@@ -66,9 +66,40 @@ model_config:{

- embedding_regularization: applies regularization to the embedding part to prevent overfitting

- input_type: if the job is submitted to a pai-tf cluster and reads max compute tables as input data, data_config: input_type must be set to OdpsInputV2.
- input_type: if the job is submitted to a pai-tf cluster and reads MaxCompute tables as input data, data_config: input_type must be set to OdpsInputV2.

#### 2. Multiple Optimizers

- The WideAndDeep model can be configured with 2 or 3 optimizers
- With 2 optimizers, the wide parameters use the first optimizer and all other parameters use the second
- With 3 optimizers, the wide parameters use the first optimizer, the deep embeddings use the second, and all other parameters use the third (see the sketch at the end of this section)
- Example configuration (2 optimizers, samples/model_config/wide_and_deep_two_opti.config):
```protobuf
optimizer_config: {
  ftrl_optimizer: {
    l1_reg: 10
    learning_rate: {
      constant_learning_rate {
        learning_rate: 0.0005
      }
    }
  }
}
optimizer_config {
  adam_optimizer {
    learning_rate {
      constant_learning_rate {
        learning_rate: 0.0001
      }
    }
  }
}
```
- Code reference: easy_rec/python/model/wide_and_deep.py
- WideAndDeep.get_grouped_vars overrides EasyRecModel.get_grouped_vars
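
The grouping rules above can be illustrated with a short sketch. This is not part of EasyRec: the `describe_group` helper and the variable names are hypothetical, and the routing simply mirrors the name patterns used by WideAndDeep.get_grouped_vars for the 3-optimizer case.

```python
# Small, self-contained sketch (not EasyRec code) of the 3-optimizer grouping rule.
def describe_group(var_name):
  # wide parameters: under 'input_layer' but not 'input_layer_1'
  if var_name.startswith('input_layer') and not var_name.startswith('input_layer_1'):
    return 'optimizer 1 (wide parameters)'
  # deep embeddings: under 'input_layer_1' or any '/embedding_weights' variable
  if var_name.startswith('input_layer') or '/embedding_weights' in var_name:
    return 'optimizer 2 (deep embeddings)'
  # everything else: dense layers, biases, ...
  return 'optimizer 3 (other parameters)'

for name in [
    'input_layer/user_id/embedding_weights:0',    # hypothetical wide embedding
    'input_layer_1/user_id/embedding_weights:0',  # hypothetical deep embedding
    'dnn/hidden_1/kernel:0',                      # hypothetical dense weight
]:
  print(name, '->', describe_group(name))
```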

#### 2. Componentized Model
#### 3. Componentized Model

```protobuf
model_config: {
4 changes: 4 additions & 0 deletions docs/source/train.md
@@ -20,6 +20,10 @@
}
```

- Multi-optimizer support:
  - Two optimizers can be configured, one for the embedding weights and one for the dense weights
  - See get_grouped_vars in EasyRecModel for the implementation

- sync_replicas: true # whether to train synchronously, defaults to false

- Uses SyncReplicasOptimizer for distributed training (synchronous mode)
4 changes: 3 additions & 1 deletion easy_rec/python/builders/optimizer_builder.py
@@ -103,7 +103,9 @@ def build(optimizer_config):
config = optimizer_config.adagrad_optimizer
learning_rate = _create_learning_rate(config.learning_rate)
summary_vars.append(learning_rate)
optimizer = tf.train.AdagradOptimizer(learning_rate)
optimizer = tf.train.AdagradOptimizer(
learning_rate,
initial_accumulator_value=config.initial_accumulator_value)

if optimizer_type == 'adam_async_optimizer':
config = optimizer_config.adam_async_optimizer
2 changes: 1 addition & 1 deletion easy_rec/python/model/easy_rec_estimator.py
@@ -199,7 +199,7 @@ def _train_model_fn(self, features, labels, run_config):
opt, learning_rate = optimizer_builder.build(tmp_config)
tf.summary.scalar('learning_rate', learning_rate[0])
all_opts.append(opt)
grouped_vars = model.get_grouped_vars()
grouped_vars = model.get_grouped_vars(len(all_opts))
assert len(grouped_vars) == len(optimizer_config), \
'the number of var group(%d) != the number of optimizers(%d)' \
% (len(grouped_vars), len(optimizer_config))
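
To make the pairing concrete, here is a minimal sketch (not the actual EasyRec implementation) of how each configured optimizer could be applied only to its own variable group; `loss` is assumed to be the training loss tensor, and global-step handling is omitted.

```python
# Sketch only: build one train op per (optimizer, variable group) pair,
# so each optimizer updates just the variables in its group.
train_ops = []
for opt, var_group in zip(all_opts, grouped_vars):
  grads_and_vars = opt.compute_gradients(loss, var_list=var_group)
  train_ops.append(opt.apply_gradients(grads_and_vars))
train_op = tf.group(*train_ops)
```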
31 changes: 23 additions & 8 deletions easy_rec/python/model/easy_rec_model.py
@@ -8,7 +8,7 @@
import six
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops.variables import PartitionedVariable
from tensorflow.python.ops import variables

from easy_rec.python.compat import regularizers
from easy_rec.python.layers import input_layer
@@ -114,7 +114,7 @@ def l2_regularization(self):
if hasattr(model_config, 'dense_regularization') and \
model_config.HasField('dense_regularization'):
# backward compatibility
tf.logging.warn(
logging.warn(
'dense_regularization is deprecated, please use l2_regularization')
l2_regularization = model_config.dense_regularization
elif hasattr(model_config, 'l2_regularization'):
@@ -229,7 +229,7 @@ def restore(self,
for x in shape_arr[1:]:
var_shape[0] += x[0]
var_shape = tensor_shape.TensorShape(var_shape)
variable = PartitionedVariable(
variable = variables.PartitionedVariable(
variable_name,
var_shape,
variable[0].dtype,
@@ -239,7 +239,7 @@
var_shape = variable.shape.as_list()
if ckpt_var_shape == var_shape:
vars_in_ckpt[variable_name] = list(variable) if isinstance(
variable, PartitionedVariable) else variable
variable, variables.PartitionedVariable) else variable
elif len(ckpt_var_shape) == len(var_shape):
if force_restore_shape_compatible:
# create a variable compatible with checkpoint to restore
@@ -394,10 +394,25 @@ def get_restore_filter(self):
return restore_filter.CombineFilter(all_filters,
restore_filter.Logical.AND), None

def get_grouped_vars(self):
"""Get grouped variables, each group will be optimized by a separate optimizer.
def get_grouped_vars(self, opt_num):
"""Group the vars into different optimization groups.
Each group will be optimized by a separate optimizer.
Args:
opt_num: number of optimizers from easyrec config.
Return:
grouped_vars: list of list of variables
list of list of variables.
"""
raise NotImplementedError()
assert opt_num == 2, 'could only support 2 optimizers, one for embedding, one for the other layers'

embedding_vars = []
deep_vars = []
for tmp_var in variables.trainable_variables():
if tmp_var.name.startswith(
'input_layer') or '/embedding_weights' in tmp_var.name:
embedding_vars.append(tmp_var)
else:
deep_vars.append(tmp_var)
return [embedding_vars, deep_vars]
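
A model that needs a different split can override this method. Below is only a sketch of that extension point: the import path follows the file layout shown in this diff, while `MyModel` and its '/bias' routing rule are hypothetical, and every other method a concrete model must implement is omitted.

```python
# Hypothetical sketch of overriding get_grouped_vars in a custom model;
# only the grouping logic is shown, all other required model methods are omitted.
import tensorflow as tf

from easy_rec.python.model.easy_rec_model import EasyRecModel


class MyModel(EasyRecModel):

  def get_grouped_vars(self, opt_num):
    assert opt_num == 2, 'this toy grouping expects exactly 2 optimizers'
    bias_vars, other_vars = [], []
    for tmp_var in tf.trainable_variables():
      if '/bias' in tmp_var.name:
        bias_vars.append(tmp_var)   # second optimizer
      else:
        other_vars.append(tmp_var)  # first optimizer
    return [other_vars, bias_vars]
```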
46 changes: 33 additions & 13 deletions easy_rec/python/model/wide_and_deep.py
@@ -79,23 +79,43 @@ def build_predict_graph(self):

return self._prediction_dict

def get_grouped_vars(self):
def get_grouped_vars(self, opt_num):
"""Group the vars into different optimization groups.
Each group will be optimized by a separate optimizer.
Args:
opt_num: number of optimizers from easyrec config.
Return:
list of list of variables.
"""
assert len(self._model_config.final_dnn.hidden_units) == 0, \
'if use different optimizers for wide group and deep group, '\
+ ' final_dnn should not be set.'
wide_vars = []
deep_vars = []
for tmp_var in tf.trainable_variables():
if tmp_var.name.startswith('input_layer') and \
(not tmp_var.name.startswith('input_layer_1')):
wide_vars.append(tmp_var)
else:
deep_vars.append(tmp_var)
return [wide_vars, deep_vars]
assert opt_num <= 3, 'could only support 2 or 3 optimizers, ' + \
'if opt_num = 2, one for the wide , and one for the others, ' + \
'if opt_num = 3, one for the wide, second for the deep embeddings, ' + \
'and third for the other layers.'

if opt_num == 2:
wide_vars = []
deep_vars = []
for tmp_var in tf.trainable_variables():
if tmp_var.name.startswith('input_layer') and \
(not tmp_var.name.startswith('input_layer_1')):
wide_vars.append(tmp_var)
else:
deep_vars.append(tmp_var)
return [wide_vars, deep_vars]
elif opt_num == 3:
wide_vars = []
embedding_vars = []
deep_vars = []
for tmp_var in tf.trainable_variables():
if tmp_var.name.startswith('input_layer') and \
(not tmp_var.name.startswith('input_layer_1')):
wide_vars.append(tmp_var)
elif tmp_var.name.startswith(
'input_layer') or '/embedding_weights' in tmp_var.name:
embedding_vars.append(tmp_var)
else:
deep_vars.append(tmp_var)
return [wide_vars, embedding_vars, deep_vars]
3 changes: 2 additions & 1 deletion easy_rec/python/protos/optimizer.proto
@@ -68,7 +68,8 @@ message AdamAsyncWOptimizer {
// Configuration message for the AdagradOptimizer
// See: https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
message AdagradOptimizer {
optional LearningRate learning_rate = 1;
optional LearningRate learning_rate = 1;
optional float initial_accumulator_value = 2 [default = 0.1];
}
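
For orientation, a minimal numeric sketch in plain Python (not EasyRec code) of the Adagrad update, showing what initial_accumulator_value controls: the starting value of the per-parameter squared-gradient accumulator, so a larger value damps the first updates.

```python
# Plain-Python sketch of the Adagrad update rule:
#   accum += grad^2;  var -= lr * grad / sqrt(accum)
# initial_accumulator_value is the starting value of accum.
def adagrad_step(var, grad, accum, lr=0.01):
  accum = accum + grad * grad
  var = var - lr * grad / (accum ** 0.5)
  return var, accum

for init_accum in (0.1, 1.0):
  var, accum = adagrad_step(var=1.0, grad=0.5, accum=init_accum)
  print('initial_accumulator_value=%.1f -> first step moves var to %.5f'
        % (init_accum, var))
```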

// Only available on pai-tf, which has better performance than AdamOptimizer
6 changes: 6 additions & 0 deletions easy_rec/python/test/train_eval_test.py
@@ -1005,6 +1005,12 @@ def test_multi_optimizer(self):
'samples/model_config/wide_and_deep_two_opti.config', self._test_dir)
self.assertTrue(self._success)

def test_embedding_separate_optimizer(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config',
self._test_dir)
self.assertTrue(self._success)

def test_expr_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_on_taobao_for_expr.config',