diff --git a/docs/source/models/wide_and_deep.md b/docs/source/models/wide_and_deep.md
index 7fc0276de..ac64faa4a 100644
--- a/docs/source/models/wide_and_deep.md
+++ b/docs/source/models/wide_and_deep.md
@@ -66,9 +66,40 @@ model_config:{

 - embedding_regularization: applies regularization to the embedding part to prevent overfitting

-- input_type: when the job is submitted to a pai-tf cluster and reads max compute tables as input data, data_config.input_type must be set to OdpsInputV2.
+- input_type: when the job is submitted to a pai-tf cluster and reads MaxCompute tables as input data, data_config.input_type must be set to OdpsInputV2.
+
+#### 2. Multiple Optimizers
+
+- The WideAndDeep model can be configured with 2 or 3 optimizers.
+- With 2 optimizers, the wide parameters use the first optimizer and all other parameters use the second.
+- With 3 optimizers, the wide parameters use the first optimizer, the deep embeddings use the second, and all other parameters use the third (a sketch follows the example below).
+- Example configuration (2 optimizers, samples/model_config/wide_and_deep_two_opti.config):
+  ```protobuf
+  optimizer_config: {
+    ftrl_optimizer: {
+      l1_reg: 10
+      learning_rate: {
+        constant_learning_rate {
+          learning_rate: 0.0005
+        }
+      }
+    }
+  }
+
+  optimizer_config {
+    adam_optimizer {
+      learning_rate {
+        constant_learning_rate {
+          learning_rate: 0.0001
+        }
+      }
+    }
+  }
+  ```
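+- A 3-optimizer configuration might look like the following minimal sketch; the optimizer types and learning-rate values here are illustrative, not prescribed:
+  ```protobuf
+  optimizer_config {  # first optimizer: wide parameters
+    ftrl_optimizer {
+      l1_reg: 10
+      learning_rate {
+        constant_learning_rate {
+          learning_rate: 0.0005
+        }
+      }
+    }
+  }
+
+  optimizer_config {  # second optimizer: deep embedding parameters
+    adagrad_optimizer {
+      learning_rate {
+        constant_learning_rate {
+          learning_rate: 0.05
+        }
+      }
+      initial_accumulator_value: 1.0
+    }
+  }
+
+  optimizer_config {  # third optimizer: the remaining (dense) parameters
+    adam_optimizer {
+      learning_rate {
+        constant_learning_rate {
+          learning_rate: 0.0001
+        }
+      }
+    }
+  }
+  ```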
+- Code reference: easy_rec/python/model/wide_and_deep.py
+  - WideAndDeep.get_grouped_vars overrides EasyRecModel.get_grouped_vars

-#### 2. Componentized Model
+#### 3. Componentized Model

 ```protobuf
 model_config: {
diff --git a/docs/source/train.md b/docs/source/train.md
index 42b04373f..9c2b53795 100644
--- a/docs/source/train.md
+++ b/docs/source/train.md
@@ -20,6 +20,10 @@
     }
   }

+  - Multi-optimizer support:
+    - Two optimizers can be configured, one for the embedding weights and one for the dense weights
+    - See get_grouped_vars in EasyRecModel for the reference implementation
+
   - sync_replicas: true  # whether to train synchronously, defaults to false

     - uses SyncReplicasOptimizer for distributed training (synchronous mode)
diff --git a/easy_rec/python/builders/optimizer_builder.py b/easy_rec/python/builders/optimizer_builder.py
index 7a2331b32..c7e3aca49 100644
--- a/easy_rec/python/builders/optimizer_builder.py
+++ b/easy_rec/python/builders/optimizer_builder.py
@@ -103,7 +103,9 @@ def build(optimizer_config):
     config = optimizer_config.adagrad_optimizer
     learning_rate = _create_learning_rate(config.learning_rate)
     summary_vars.append(learning_rate)
-    optimizer = tf.train.AdagradOptimizer(learning_rate)
+    optimizer = tf.train.AdagradOptimizer(
+        learning_rate,
+        initial_accumulator_value=config.initial_accumulator_value)

   if optimizer_type == 'adam_async_optimizer':
     config = optimizer_config.adam_async_optimizer
diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py
index f6ba00d7e..69e578d2b 100644
--- a/easy_rec/python/model/easy_rec_estimator.py
+++ b/easy_rec/python/model/easy_rec_estimator.py
@@ -199,7 +199,7 @@ def _train_model_fn(self, features, labels, run_config):
         opt, learning_rate = optimizer_builder.build(tmp_config)
         tf.summary.scalar('learning_rate', learning_rate[0])
         all_opts.append(opt)
-      grouped_vars = model.get_grouped_vars()
+      grouped_vars = model.get_grouped_vars(len(all_opts))
       assert len(grouped_vars) == len(optimizer_config), \
           'the number of var group(%d) != the number of optimizers(%d)' \
           % (len(grouped_vars), len(optimizer_config))
diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py
index 6fb8fa60a..37249949b 100644
--- a/easy_rec/python/model/easy_rec_model.py
+++ b/easy_rec/python/model/easy_rec_model.py
@@ -8,7 +8,7 @@
 import six
 import tensorflow as tf
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops.variables import PartitionedVariable
+from tensorflow.python.ops import variables

 from easy_rec.python.compat import regularizers
 from easy_rec.python.layers import input_layer
@@ -114,7 +114,7 @@ def l2_regularization(self):
     if hasattr(model_config, 'dense_regularization') and \
         model_config.HasField('dense_regularization'):
       # backward compatibility
-      tf.logging.warn(
+      logging.warn(
           'dense_regularization is deprecated, please use l2_regularization')
       l2_regularization = model_config.dense_regularization
     elif hasattr(model_config, 'l2_regularization'):
@@ -229,7 +229,7 @@ def restore(self,
         for x in shape_arr[1:]:
           var_shape[0] += x[0]
         var_shape = tensor_shape.TensorShape(var_shape)
-        variable = PartitionedVariable(
+        variable = variables.PartitionedVariable(
             variable_name,
             var_shape,
             variable[0].dtype,
@@ -239,7 +239,7 @@ def restore(self,
           var_shape = variable.shape.as_list()
           if ckpt_var_shape == var_shape:
             vars_in_ckpt[variable_name] = list(variable) if isinstance(
-                variable, PartitionedVariable) else variable
+                variable, variables.PartitionedVariable) else variable
           elif len(ckpt_var_shape) == len(var_shape):
             if force_restore_shape_compatible:
               # create a variable compatible with checkpoint to restore
@@ -394,10 +394,25 @@ def get_restore_filter(self):
     return restore_filter.CombineFilter(all_filters,
                                         restore_filter.Logical.AND), None

-  def get_grouped_vars(self):
-    """Get grouped variables, each group will be optimized by a separate optimizer.
+  def get_grouped_vars(self, opt_num):
+    """Group the vars into different optimization groups.
+
+    Each group will be optimized by a separate optimizer.
+
+    Args:
+      opt_num: number of optimizers from the easyrec config.

     Return:
-      grouped_vars: list of list of variables
+      list of list of variables.
     """
-    raise NotImplementedError()
+    assert opt_num == 2, 'only 2 optimizers are supported: one for the ' \
+        'embeddings and one for the other layers'
+
+    embedding_vars = []
+    deep_vars = []
+    for tmp_var in variables.trainable_variables():
+      if tmp_var.name.startswith(
+          'input_layer') or '/embedding_weights' in tmp_var.name:
+        embedding_vars.append(tmp_var)
+      else:
+        deep_vars.append(tmp_var)
+    return [embedding_vars, deep_vars]
diff --git a/easy_rec/python/model/wide_and_deep.py b/easy_rec/python/model/wide_and_deep.py
index f841ed049..48b620bd7 100755
--- a/easy_rec/python/model/wide_and_deep.py
+++ b/easy_rec/python/model/wide_and_deep.py
@@ -79,23 +79,43 @@ def build_predict_graph(self):

     return self._prediction_dict

-  def get_grouped_vars(self):
+  def get_grouped_vars(self, opt_num):
     """Group the vars into different optimization groups.

     Each group will be optimized by a separate optimizer.

+    Args:
+      opt_num: number of optimizers from the easyrec config.
+
     Return:
       list of list of variables.
     """
-    assert len(self._model_config.final_dnn.hidden_units) == 0, \
-        'if use different optimizers for wide group and deep group, '\
-        + ' final_dnn should not be set.'
-    wide_vars = []
-    deep_vars = []
-    for tmp_var in tf.trainable_variables():
-      if tmp_var.name.startswith('input_layer') and \
-          (not tmp_var.name.startswith('input_layer_1')):
-        wide_vars.append(tmp_var)
-      else:
-        deep_vars.append(tmp_var)
-    return [wide_vars, deep_vars]
+    assert opt_num in (2, 3), 'only 2 or 3 optimizers are supported: ' \
+        'with 2 optimizers, the first is for the wide parameters and the ' \
+        'second for the others; with 3 optimizers, the first is for the ' \
+        'wide parameters, the second for the deep embeddings, and the ' \
+        'third for the other layers.'
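+
+    # The grouping below relies on variable naming: wide-group variables are
+    # created under the 'input_layer' scope, the deep group's inputs under
+    # 'input_layer_1', and feature-column embedding tables carry
+    # '/embedding_weights' in their names.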
+ + if opt_num == 2: + wide_vars = [] + deep_vars = [] + for tmp_var in tf.trainable_variables(): + if tmp_var.name.startswith('input_layer') and \ + (not tmp_var.name.startswith('input_layer_1')): + wide_vars.append(tmp_var) + else: + deep_vars.append(tmp_var) + return [wide_vars, deep_vars] + elif opt_num == 3: + wide_vars = [] + embedding_vars = [] + deep_vars = [] + for tmp_var in tf.trainable_variables(): + if tmp_var.name.startswith('input_layer') and \ + (not tmp_var.name.startswith('input_layer_1')): + wide_vars.append(tmp_var) + elif tmp_var.name.startswith( + 'input_layer') or '/embedding_weights' in tmp_var.name: + embedding_vars.append(tmp_var) + else: + deep_vars.append(tmp_var) + return [wide_vars, embedding_vars, deep_vars] diff --git a/easy_rec/python/protos/optimizer.proto b/easy_rec/python/protos/optimizer.proto index 4be0c6cc6..ee10500cd 100644 --- a/easy_rec/python/protos/optimizer.proto +++ b/easy_rec/python/protos/optimizer.proto @@ -68,7 +68,8 @@ message AdamAsyncWOptimizer { // Configuration message for the AdagradOptimizer // See: https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer message AdagradOptimizer { - optional LearningRate learning_rate = 1; + optional LearningRate learning_rate = 1; + optional float initial_accumulator_value = 2 [default = 0.1]; } // Only available on pai-tf, which has better performance than AdamOptimizer diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 7ba75b462..b0b66d30c 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -1005,6 +1005,12 @@ def test_multi_optimizer(self): 'samples/model_config/wide_and_deep_two_opti.config', self._test_dir) self.assertTrue(self._success) + def test_embedding_separate_optimizer(self): + self._success = test_utils.test_distributed_train_eval( + 'samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config', + self._test_dir) + self.assertTrue(self._success) + def test_expr_feature(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/multi_tower_on_taobao_for_expr.config', diff --git a/samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config b/samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config new file mode 100644 index 000000000..a4a920137 --- /dev/null +++ b/samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config @@ -0,0 +1,383 @@ +train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv" +eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv" +model_dir: "experiments/dwd_avazu_out_test_combo_embedding_adagrad" + +train_config { + log_step_count_steps: 200 + # fine_tune_checkpoint: "" + optimizer_config { + adagrad_optimizer { + learning_rate { + constant_learning_rate { + learning_rate: 0.05 + } + } + initial_accumulator_value: 1.0 + } + } + + optimizer_config: { + adam_optimizer: { + learning_rate: { + exponential_decay_learning_rate { + initial_learning_rate: 0.0001 + decay_steps: 10000 + decay_factor: 0.5 + min_learning_rate: 0.0000001 + } + } + } + use_moving_average: false + } + + sync_replicas: true + save_checkpoints_steps: 500 + num_steps: 1000 +} + +eval_config { + metrics_set: { + auc {} + } +} + +data_config { + separator: "," + input_fields: { + input_name: "label" + input_type: INT64 + default_val:"0" + } + input_fields: { + input_name: "hour" + input_type: INT64 + default_val:"0" + } + input_fields: { + input_name: "c1" + input_type: INT64 + default_val:"0" + } + input_fields: { + input_name: 
"banner_pos" + input_type: INT64 + default_val:"0" + } + input_fields: { + input_name: "site_id" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "site_domain" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "site_category" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "app_id" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "app_domain" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "app_category" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "device_id" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "device_ip" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "device_model" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "device_type" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "device_conn_type" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "c14" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "c15" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "c16" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "c17" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "c18" + input_type: STRING + default_val:"0" + } + input_fields: { + input_name: "c19" + input_type: INT64 + default_val:"0" + } + input_fields: { + input_name: "c20" + input_type: INT64 + default_val:"0" + } + input_fields: { + input_name: "c21" + input_type: INT64 + default_val:"0" + } + label_fields: "label" + + batch_size: 1024 + prefetch_size: 32 + input_type: CSVInput +} + +feature_config: { + features: { + input_names: "hour" + feature_type: IdFeature + num_buckets: 24 + embedding_dim: 16 + } + features: { + input_names: "c1" + feature_type: RawFeature + boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0] + embedding_dim: 16 + } + features: { + input_names: "banner_pos" + feature_type: RawFeature + boundaries: [1,2,3,4,5,6] + embedding_dim: 16 + } + features: { + input_names: "site_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 + } + features: { + input_names: "site_domain" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: "site_category" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: "app_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 + } + features: { + input_names: "app_domain" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 1000 + } + features: { + input_names: "app_category" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: "device_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: "device_ip" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: "device_model" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 + } + features: { + input_names: "device_type" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: "device_conn_type" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: "c14" + 
feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 500 + } + features: { + input_names: "c15" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 500 + } + features: { + input_names: "c16" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 500 + } + features: { + input_names: "c17" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 500 + } + features: { + input_names: "c18" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 500 + } + features: { + input_names: "c19" + feature_type: RawFeature + boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190] + embedding_dim: 16 + } + features: { + input_names: "c20" + feature_type: RawFeature + boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0] + embedding_dim: 16 + } + features: { + input_names: "c21" + feature_type: RawFeature + boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25] + embedding_dim: 16 + } + features: { + input_names: ["site_id", "app_id"] + feature_name: "site_id_app_id" + feature_type: ComboFeature + hash_bucket_size: 1000 + embedding_dim: 16 + } + +} +model_config:{ + model_class: "DeepFM" + feature_groups: { + group_name: "deep" + feature_names: "hour" + feature_names: "c1" + feature_names: "banner_pos" + feature_names: "site_id" + feature_names: "site_domain" + feature_names: "site_category" + feature_names: "app_id" + feature_names: "app_domain" + feature_names: "app_category" + feature_names: "device_id" + feature_names: "device_ip" + feature_names: "device_model" + feature_names: "device_type" + feature_names: "device_conn_type" + feature_names: "c14" + feature_names: "c15" + feature_names: "c16" + feature_names: "c17" + feature_names: "c18" + feature_names: "c19" + feature_names: "c20" + feature_names: "c21" + feature_names: "site_id_app_id" + wide_deep:DEEP + } + feature_groups: { + group_name: "wide" + feature_names: "hour" + feature_names: "c1" + feature_names: "banner_pos" + feature_names: "site_id" + feature_names: "site_domain" + feature_names: "site_category" + feature_names: "app_id" + feature_names: "app_domain" + feature_names: "app_category" + feature_names: "device_id" + feature_names: "device_ip" + feature_names: "device_model" + feature_names: "device_type" + feature_names: "device_conn_type" + feature_names: "c14" + feature_names: "c15" + feature_names: "c16" + feature_names: "c17" + feature_names: "c18" + feature_names: "c19" + feature_names: "c20" + feature_names: "c21" + wide_deep:WIDE + } + + deepfm { + wide_output_dim: 16 + + dnn { + hidden_units: [128, 64, 32] + } + + final_dnn { + hidden_units: [128, 64] + } + l2_regularization: 1e-5 + } + # embedding_regularization: 1e-7 +} + +export_config { + multi_placeholder: false +} diff --git a/samples/model_config/wide_and_deep_two_opti.config b/samples/model_config/wide_and_deep_two_opti.config index 283a6d0c8..fa4ac3b01 100755 --- a/samples/model_config/wide_and_deep_two_opti.config +++ b/samples/model_config/wide_and_deep_two_opti.config @@ -10,30 +10,22 @@ train_config { ftrl_optimizer: { l1_reg: 10 learning_rate: { - exponential_decay_learning_rate { - initial_learning_rate: 0.0001 - decay_steps: 10000 - decay_factor: 0.5 - min_learning_rate: 0.0000001 + constant_learning_rate { + learning_rate: 0.0005 } } } - use_moving_average: false } optimizer_config: { adam_optimizer: { learning_rate: { - exponential_decay_learning_rate { - initial_learning_rate: 0.0001 - decay_steps: 10000 
- decay_factor: 0.5 - min_learning_rate: 0.0000001 + constant_learning_rate { + learning_rate: 0.0001 } } } - use_moving_average: false }