diff --git a/README.md b/README.md
index f0a35bdfd..70285409a 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ Running Platform:
 ### A variety of models
 
 - [DSSM](docs/source/models/dssm.md) / [MIND](docs/source/models/mind.md) / [DropoutNet](docs/source/models/dropoutnet.md) / [CoMetricLearningI2I](docs/source/models/co_metric_learning_i2i.md) / [PDN](docs/source/models/pdn.md)
-- [W&D](docs/source/models/wide_and_deep.md) / [DeepFM](docs/source/models/deepfm.md) / [MultiTower](docs/source/models/multi_tower.md) / [DCN](docs/source/models/dcn.md) / [FiBiNet](docs/source/models/fibinet.md) / [MaskNet](docs/source/models/masknet.md) / [CDN](docs/source/models/cdn.md)
+- [W&D](docs/source/models/wide_and_deep.md) / [DeepFM](docs/source/models/deepfm.md) / [MultiTower](docs/source/models/multi_tower.md) / [DCN](docs/source/models/dcn.md) / [FiBiNet](docs/source/models/fibinet.md) / [MaskNet](docs/source/models/masknet.md) / [PPNet](docs/source/models/ppnet.md) / [CDN](docs/source/models/cdn.md)
 - [DIN](docs/source/models/din.md) / [BST](docs/source/models/bst.md) / [CL4SRec](docs/source/models/cl4srec.md)
 - [MMoE](docs/source/models/mmoe.md) / [ESMM](docs/source/models/esmm.md) / [DBMTL](docs/source/models/dbmtl.md) / [PLE](docs/source/models/ple.md)
 - [HighwayNetwork](docs/source/models/highway.md) / [CMBF](docs/source/models/cmbf.md) / [UNITER](docs/source/models/uniter.md)
diff --git a/docs/images/models/ppnet.jpg b/docs/images/models/ppnet.jpg
new file mode 100644
index 000000000..1c15c472e
Binary files /dev/null and b/docs/images/models/ppnet.jpg differ
diff --git a/docs/source/component/backbone.md b/docs/source/component/backbone.md
index 21c2c467a..32450362d 100644
--- a/docs/source/component/backbone.md
+++ b/docs/source/component/backbone.md
@@ -946,6 +946,127 @@ The DBMTL model needs a `relation_dnn` configured in `model_params` for each task tower
 As in the previous case, no `concat_blocks` is configured for the backbone; the framework automatically uses the leaf nodes of the DAG.
 
+## Case 10: MaskNet + PPNet + MMoE
+
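+In this case, a `repeat` block first runs three MaskBlocks in parallel over the
+`general` features; `input_fn` then feeds each of the three PPNet replicas one
+MaskBlock output as its MLP input, together with the `memorize` features as the
+gate input; finally the three PPNet outputs serve as the experts of an MMoE
+layer whose task gates read the `general` features.
+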
+```protobuf
+model_config: {
+  model_name: 'MaskNet + PPNet + MMoE'
+  model_class: 'RankModel'
+  feature_groups: {
+    group_name: 'memorize'
+    feature_names: 'user_id'
+    feature_names: 'adgroup_id'
+    feature_names: 'pid'
+    wide_deep: DEEP
+  }
+  feature_groups: {
+    group_name: 'general'
+    feature_names: 'age_level'
+    feature_names: 'shopping_level'
+    ...
+    wide_deep: DEEP
+  }
+  backbone {
+    blocks {
+      name: "mask_net"
+      inputs {
+        feature_group_name: "general"
+      }
+      repeat {
+        num_repeat: 3
+        keras_layer {
+          class_name: "MaskBlock"
+          mask_block {
+            output_size: 512
+            aggregation_size: 1024
+          }
+        }
+      }
+    }
+    blocks {
+      name: "ppnet"
+      inputs {
+        block_name: "mask_net"
+      }
+      inputs {
+        feature_group_name: "memorize"
+      }
+      merge_inputs_into_list: true
+      repeat {
+        num_repeat: 3
+        input_fn: "lambda x, i: [x[0][i], x[1]]"
+        keras_layer {
+          class_name: "PPNet"
+          ppnet {
+            mlp {
+              hidden_units: [256, 128, 64]
+            }
+            gate_params {
+              output_dim: 512
+            }
+            mode: "eager"
+            full_gate_input: false
+          }
+        }
+      }
+    }
+    blocks {
+      name: "mmoe"
+      inputs {
+        block_name: "ppnet"
+      }
+      inputs {
+        feature_group_name: "general"
+      }
+      keras_layer {
+        class_name: "MMoE"
+        mmoe {
+          num_task: 2
+          num_expert: 3
+        }
+      }
+    }
+  }
+  model_params {
+    l2_regularization: 0.0
+    task_towers {
+      tower_name: "ctr"
+      label_name: "is_click"
+      metrics_set {
+        auc {
+          num_thresholds: 20000
+        }
+      }
+      loss_type: CLASSIFICATION
+      num_class: 1
+      dnn {
+        hidden_units: 64
+        hidden_units: 32
+      }
+      weight: 1.0
+    }
+    task_towers {
+      tower_name: "cvr"
+      label_name: "is_train"
+      metrics_set {
+        auc {
+          num_thresholds: 20000
+        }
+      }
+      loss_type: CLASSIFICATION
+      num_class: 1
+      dnn {
+        hidden_units: 64
+        hidden_units: 32
+      }
+      weight: 1.0
+    }
+  }
+}
+```
+
+This case shows how to use a [repeat block](#id21).
+
 ## More examples
 
 Two new models:
@@ -1002,6 +1123,7 @@ Results on the MovieLens-1M dataset:
 | SENet | models feature importance | a component of the FiBiNet model | [MMoE](../models/mmoe.html#id4) |
 | MaskBlock | models feature importance | a component of the MaskNet model | [Cross Decoupling Network](../models/cdn.html#id2) |
 | MaskNet | several serial or parallel MaskBlocks | the MaskNet model | [DBMTL](../models/dbmtl.html#dbmtl-based-on-backbone) |
+| PPNet | parameter personalized network | the PPNet model | [PPNet](../models/ppnet.html#id2) |
 
 ## 4. Sequence feature encoding components
 
@@ -1310,6 +1432,10 @@ repeat {
 - `num_repeat` sets how many times the layer is executed
 - `output_concat_axis` sets the axis along which the results of the repeated executions are concatenated; if unset, the list of results is output
 - `keras_layer` sets the component to execute
+- `input_slice` sets an input slice per execution, e.g. `[i]` takes the i-th element of the input list as the input of the i-th execution; if unset, every execution receives the full input
+- `input_fn` sets an input function per execution, e.g. `input_fn: "lambda x, i: [x[0][i], x[1]]"`; a sketch of how it is applied follows below
+
+A usage example of the `repeat block`: [MaskNet+PPNet+MMoE](#masknet-ppnet-mmoe).
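+
+To make the semantics concrete, here is a minimal sketch of how the framework
+derives the input of each repeated execution (it mirrors the `call_layer` logic
+added to `easy_rec/python/layers/backbone.py` in this change, with field
+presence checks simplified to truthiness):
+
+```python
+def repeat_inputs(inputs, conf, num_repeat):
+  """Yields the input passed to the i-th repeated keras_layer."""
+  for i in range(num_repeat):
+    ly_inputs = inputs
+    if conf.input_slice:  # e.g. "[i]": index into the input list
+      fn = eval('lambda x, i: x' + conf.input_slice.strip())
+      ly_inputs = fn(ly_inputs, i)
+    if conf.input_fn:  # e.g. "lambda x, i: [x[0][i], x[1]]"
+      fn = eval(conf.input_fn)
+      ly_inputs = fn(ly_inputs, i)
+    yield ly_inputs
+
+# In the MaskNet + PPNet + MMoE case above, inputs == [mask_net_outputs, memorize],
+# so execution i receives [mask_net_outputs[i], memorize].
+```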
 
 ## 7. Sequence blocks
 
diff --git a/docs/source/component/component.md b/docs/source/component/component.md
index 80e9998e3..89d3819fa 100644
--- a/docs/source/component/component.md
+++ b/docs/source/component/component.md
@@ -96,6 +96,25 @@
 | use_parallel | bool | true | whether to use the parallel mode |
 | mlp | MLP | optional | top MLP |
 
+- PPNet
+
+| parameter | type | default | description |
+| --------------- | ------ | ------- | -------------------------------------------------------------------------------------------------- |
+| mlp | MLP | | MLP configuration |
+| gate_params | GateNN | | configuration of the parameter-personalization gate network |
+| mode | string | eager | whether personalization acts on the input or on the output of each MLP layer; one of \[eager, lazy\] |
+| full_gate_input | bool | true | whether to also feed the MLP input (behind a stop_gradient) to the gate network |
+
+The parameters of GateNN are as follows:
+
+| parameter | type | default | description |
+| ------------ | ------ | --------------------------------------- | ------------------------------------------------------------------------------------------------------------ |
+| output_dim | uint32 | output units of the preceding MLP layer | output dimension of the gate network; in eager mode it must be set to the input units of the MLP's first layer |
+| hidden_dim | uint32 | output_dim | number of hidden units |
+| dropout_rate | float | 0.0 | dropout rate of the hidden layer |
+| activation | str | relu | activation function of the hidden layer |
+| use_bn | bool | false | whether the hidden layer uses batch normalization |
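+
+For intuition, the gate computation in plain TensorFlow (a simplified sketch of
+`GateNN` in `easy_rec/python/layers/keras/ppnet.py`; batch norm and dropout are
+omitted and names are illustrative):
+
+```python
+import tensorflow as tf
+
+def gate_nn(gate_input, hidden_dim, output_dim):
+  """LHUC-style gate: 2 * sigmoid(...) yields scales in (0, 2), centered at 1."""
+  hidden = tf.layers.dense(gate_input, hidden_dim, activation=tf.nn.relu)
+  return 2.0 * tf.layers.dense(hidden, output_dim, activation=tf.nn.sigmoid)
+
+# PPNet modulates its MLP with these scales: in eager mode each layer computes
+#   y = dense(x * gate_nn(gate_input, ...))
+# and in lazy mode
+#   y = dense(x) * gate_nn(gate_input, ...)
+```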
 
 ## 4. Sequence feature encoding components
 
 - SeqAugment
diff --git a/docs/source/component/sequence.md b/docs/source/component/sequence.md
new file mode 100644
index 000000000..ddcbace7d
--- /dev/null
+++ b/docs/source/component/sequence.md
@@ -0,0 +1,79 @@
+# Configuring sequence components
+
+The component-based configuration of sequence models (DIN, BST) requires the input features to be placed in one and the same `feature_group`.
+
+A sequence model usually has two parts, the `history behavior sequence` and the `target item`, and each part may contain several attributes (sub-features).
+
+Inside the `feature_group` that feeds the sequence component, define the sub-features of the `history behavior sequence` and of the `target item` **in order**.
+
+The framework uses the `feature_type` of each feature definition to decide whether a feature belongs to the `history behavior sequence` or to the `target item`:
+every sub-feature of type `SequenceFeature` is treated as part of the `history behavior sequence`; every sub-feature of any other type is treated as part of the `target item`.
+
+**The sub-features of the two parts must be listed in the same order.** In the example below,
+
+- `concat([cate_id, brand], axis=-1)` is the final embedding of the `target item` (2D);
+- `concat([tag_category_list, tag_brand_list], axis=-1)` is the final embedding of the `history behavior sequence` (3D)
+
+```protobuf
+model_config: {
+  model_name: 'DIN'
+  model_class: 'RankModel'
+  ...
+  feature_groups: {
+    group_name: 'sequence'
+    feature_names: "cate_id"
+    feature_names: "brand"
+    feature_names: "tag_category_list"
+    feature_names: "tag_brand_list"
+    wide_deep: DEEP
+  }
+  backbone {
+    blocks {
+      name: 'seq_input'
+      inputs {
+        feature_group_name: 'sequence'
+      }
+      input_layer {
+        output_seq_and_normal_feature: true
+      }
+    }
+    blocks {
+      name: 'DIN'
+      inputs {
+        block_name: 'seq_input'
+      }
+      keras_layer {
+        class_name: 'DIN'
+        din {
+          attention_dnn {
+            hidden_units: 32
+            hidden_units: 1
+            activation: "dice"
+          }
+          need_target_feature: true
+        }
+      }
+    }
+    ...
+  }
+}
+```
+
+When using a sequence component, you must configure an `input_layer`-type `block` with `output_seq_and_normal_feature: true`, as follows.
+
+```protobuf
+blocks {
+  name: 'seq_input'
+  inputs {
+    feature_group_name: 'sequence'
+  }
+  input_layer {
+    output_seq_and_normal_feature: true
+  }
+}
+```
+
+## Complete examples
+
+- [DIN](../models/din.md)
+- [BST](../models/bst.md)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9cef0a0a5..2d64bf906 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -30,6 +30,7 @@ Welcome to easy_rec's documentation!
 
    component/backbone
    component/component
+   component/sequence
 
 .. toctree::
    :maxdepth: 3
diff --git a/docs/source/models/ppnet.md b/docs/source/models/ppnet.md
new file mode 100644
index 000000000..6fae79b19
--- /dev/null
+++ b/docs/source/models/ppnet.md
@@ -0,0 +1,95 @@
+# PPNet (Parameter Personalized Net)
+
+### Overview
+
+The core idea of PPNet comes from LHUC in speech recognition. The LHUC algorithm (learning hidden unit contributions), proposed in 2016,
+performs speaker adaptation: its key contribution is to learn, inside a DNN, a set of speaker-specific hidden unit contributions
+that improve recognition quality for individual speakers.
+
+Borrowing this idea, PPNet adds a gating mechanism that personalizes the DNN parameters and helps the model converge quickly.
+
+![ppnet](../../images/models/ppnet.jpg)
+
+### Configuration
+
+```protobuf
+model_config: {
+  model_name: 'PPNet'
+  model_class: 'RankModel'
+  feature_groups: {
+    group_name: 'memorize'
+    feature_names: 'user_id'
+    feature_names: 'adgroup_id'
+    feature_names: 'pid'
+    wide_deep: DEEP
+  }
+  feature_groups: {
+    group_name: 'general'
+    feature_names: 'cms_segid'
+    feature_names: 'cms_group_id'
+    feature_names: 'age_level'
+    feature_names: 'pvalue_level'
+    feature_names: 'shopping_level'
+    feature_names: 'occupation'
+    feature_names: 'new_user_class_level'
+    feature_names: 'cate_id'
+    feature_names: 'campaign_id'
+    feature_names: 'customer'
+    feature_names: 'brand'
+    feature_names: 'price'
+    feature_names: 'tag_category_list'
+    feature_names: 'tag_brand_list'
+    wide_deep: DEEP
+  }
+  backbone {
+    blocks {
+      name: "ppnet"
+      inputs {
+        feature_group_name: "general"
+      }
+      inputs {
+        feature_group_name: "memorize"
+      }
+      merge_inputs_into_list: true
+      keras_layer {
+        class_name: "PPNet"
+        ppnet {
+          mlp {
+            hidden_units: [512, 256]
+          }
+          mode: "lazy"
+          full_gate_input: true
+        }
+      }
+    }
+    top_mlp {
+      hidden_units: [128, 64]
+    }
+  }
+  model_params {
+    l2_regularization: 1e-6
+  }
+  embedding_regularization: 1e-5
+}
+```
+
+- model_name: an arbitrary string, used only as a comment
+- model_class: 'RankModel', do not change; every single-objective ranking model built from components uses this class
+- feature_groups: configures the feature groups.
+- backbone: the backbone network assembled from components, see the [reference](../component/backbone.md)
+  - blocks: a DAG made of `blocks`; the framework executes the logic of each `block` in topological order, building a subgraph of the TF Graph
+  - name/inputs: every `block` has a unique name and one or more inputs and outputs
+  - keras_layer: loads the custom or built-in keras layer given by `class_name` and executes its logic; see the [reference](../component/backbone.md#keraslayer)
+  - ppnet: the PPNet building block; its parameters are described in the [reference](../component/component.md#id4)
+  - concat_blocks: defines the output nodes of the DAG; if unset, the framework concatenates all leaf nodes of the DAG as the output.
+- model_params:
+  - l2_regularization: (optional) regularization of the DNN parameters, to reduce overfitting
+- embedding_regularization: regularization of the embedding part, to reduce overfitting
+
+### Example config
+
+[ppnet_on_taobao.config](https://github.com/alibaba/EasyRec/tree/master/samples/model_config/ppnet_on_taobao.config)
+
+### Reference paper
+
+[PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/pdf/2302.01115.pdf)
diff --git a/docs/source/models/rank.rst b/docs/source/models/rank.rst
index 91bd29680..888c8493e 100644
--- a/docs/source/models/rank.rst
+++ b/docs/source/models/rank.rst
@@ -17,6 +17,7 @@
    masknet
    fibinet
    cdn
+   ppnet
    cl4srec
    regression
    multi_cls
diff --git a/easy_rec/python/layers/backbone.py b/easy_rec/python/layers/backbone.py
index 4a510fa64..23c2c1735 100644
--- a/easy_rec/python/layers/backbone.py
+++ b/easy_rec/python/layers/backbone.py
@@ -43,7 +43,6 @@ def __init__(self, config, features, input_layer, l2_reg=None):
     self._l2_reg = l2_reg
     self._dag = DAG()
     self._name_to_blocks = {}
-    self.loss_dict = {}
     self._name_to_layer = {}
     self.reset_input_config(None)
     self._block_outputs = {}
@@ -159,15 +158,15 @@ def define_layers(self, layer, layer_cnf, name, reuse):
layer_obj = self.load_keras_layer(layer_cnf.keras_layer, name, reuse) self._name_to_layer[name] = layer_obj elif layer == 'recurrent': + keras_layer = layer_cnf.recurrent.keras_layer for i in range(layer_cnf.recurrent.num_steps): name_i = '%s_%d' % (name, i) - keras_layer = layer_cnf.recurrent.keras_layer layer_obj = self.load_keras_layer(keras_layer, name_i, reuse) self._name_to_layer[name_i] = layer_obj elif layer == 'repeat': + keras_layer = layer_cnf.repeat.keras_layer for i in range(layer_cnf.repeat.num_repeat): name_i = '%s_%d' % (name, i) - keras_layer = layer_cnf.repeat.keras_layer layer_obj = self.load_keras_layer(keras_layer, name_i, reuse) self._name_to_layer[name_i] = layer_obj @@ -183,7 +182,7 @@ def has_block(self, name): def block_outputs(self, name): return self._block_outputs.get(name, None) - def block_input(self, config, block_outputs, training=None): + def block_input(self, config, block_outputs, training=None, **kwargs): inputs = [] for input_node in config.inputs: input_type = input_node.WhichOneof('name') @@ -211,9 +210,7 @@ def block_input(self, config, block_outputs, training=None): fn = eval(input_node.package_input_fn) pkg_input = fn(pkg_input) package.set_package_input(pkg_input) - input_feature = package(training) - if len(package.loss_dict) > 0: - self.loss_dict.update(package.loss_dict) + input_feature = package(training, **kwargs) elif input_name in block_outputs: input_feature = block_outputs[input_name] else: @@ -258,16 +255,16 @@ def call(self, is_training, **kwargs): config = self._name_to_blocks[block] if config.layers: # sequential layers logging.info('call sequential %d layers' % len(config.layers)) - output = self.block_input(config, block_outputs, is_training) + output = self.block_input(config, block_outputs, is_training, **kwargs) for i, layer in enumerate(config.layers): name_i = '%s_l%d' % (block, i) - output = self.call_layer(output, layer, name_i, is_training) + output = self.call_layer(output, layer, name_i, is_training, **kwargs) block_outputs[block] = output continue # just one of layer layer = config.WhichOneof('layer') if layer is None: # identity layer - output = self.block_input(config, block_outputs, is_training) + output = self.block_input(config, block_outputs, is_training, **kwargs) block_outputs[block] = output elif layer == 'input_layer': input_fn = self._name_to_layer[block] @@ -277,18 +274,14 @@ def call(self, is_training, **kwargs): input_fn.reset(input_config, is_training) block_outputs[block] = input_fn(input_config, is_training) else: - inputs = self.block_input(config, block_outputs, is_training) - output = self.call_layer(inputs, config, block, is_training) + inputs = self.block_input(config, block_outputs, is_training, **kwargs) + output = self.call_layer(inputs, config, block, is_training, **kwargs) block_outputs[block] = output outputs = [] for output in self._config.concat_blocks: if output in block_outputs: temp = block_outputs[output] - # if type(temp) in (tuple, list): - # outputs.extend(temp) - # else: - # outputs.append(temp) outputs.append(temp) else: raise ValueError('No output `%s` of backbone to be concat' % output) @@ -345,11 +338,10 @@ def load_keras_layer(self, layer_conf, name, reuse=None): layer = layer_cls(*args, name=name) return layer, customize - def call_keras_layer(self, inputs, name, training): + def call_keras_layer(self, inputs, name, training, **kwargs): """Call predefined Keras Layer, which can be reused.""" layer, customize = self._name_to_layer[name] cls = layer.__class__.__name__ - kwargs = 
{'loss_dict': self.loss_dict} if customize: output = layer(inputs, training=training, **kwargs) else: @@ -361,10 +353,10 @@ def call_keras_layer(self, inputs, name, training): output = layer(inputs) return output - def call_layer(self, inputs, config, name, training): + def call_layer(self, inputs, config, name, training, **kwargs): layer_name = config.WhichOneof('layer') if layer_name == 'keras_layer': - return self.call_keras_layer(inputs, name, training) + return self.call_keras_layer(inputs, name, training, **kwargs) if layer_name == 'lambda': conf = getattr(config, 'lambda') fn = eval(conf.expression) @@ -375,7 +367,14 @@ def call_layer(self, inputs, config, name, training): outputs = [] for i in range(n_loop): name_i = '%s_%d' % (name, i) - output = self.call_keras_layer(inputs, name_i, training) + ly_inputs = inputs + if conf.HasField('input_slice'): + fn = eval('lambda x, i: x' + conf.input_slice.strip()) + ly_inputs = fn(ly_inputs, i) + if conf.HasField('input_fn'): + fn = eval(conf.input_fn) + ly_inputs = fn(ly_inputs, i) + output = self.call_keras_layer(ly_inputs, name_i, training, **kwargs) outputs.append(output) if len(outputs) == 1: return outputs[0] @@ -392,7 +391,7 @@ def call_layer(self, inputs, config, name, training): output = inputs for i in range(conf.num_steps): name_i = '%s_%d' % (name, i) - output_i = self.call_keras_layer(output, name_i, training) + output_i = self.call_keras_layer(output, name_i, training, **kwargs) if fixed_input_index >= 0: j = 0 for idx in range(len(output)): @@ -421,7 +420,6 @@ class Backbone(object): def __init__(self, config, features, input_layer, l2_reg=None): self._config = config self._l2_reg = l2_reg - self.loss_dict = {} main_pkg = backbone_pb2.BlockPackage() main_pkg.name = 'backbone' main_pkg.blocks.MergeFrom(config.blocks) @@ -432,8 +430,6 @@ def __init__(self, config, features, input_layer, l2_reg=None): def __call__(self, is_training, **kwargs): output = self._main_pkg(is_training, **kwargs) - if len(self._main_pkg.loss_dict) > 0: - self.loss_dict = self._main_pkg.loss_dict if self._config.HasField('top_mlp'): params = Parameter.make_from_pb(self._config.top_mlp) diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py index 3f22f511b..cbe36b5ca 100644 --- a/easy_rec/python/layers/keras/__init__.py +++ b/easy_rec/python/layers/keras/__init__.py @@ -16,3 +16,4 @@ from .multi_task import MMoE from .numerical_embedding import AutoDisEmbedding from .numerical_embedding import PeriodicEmbedding +from .ppnet import PPNet diff --git a/easy_rec/python/layers/keras/mask_net.py b/easy_rec/python/layers/keras/mask_net.py index 507a0020d..49318df3b 100644 --- a/easy_rec/python/layers/keras/mask_net.py +++ b/easy_rec/python/layers/keras/mask_net.py @@ -82,6 +82,9 @@ def call(self, inputs, **kwargs): mask, net.shape[-1], name='%s/mask' % self.name, reuse=self.reuse) masked_net = net * mask + if not self.config.HasField('output_size'): + return masked_net + output_size = self.config.output_size hidden = tf.layers.dense( masked_net, diff --git a/easy_rec/python/layers/keras/multi_task.py b/easy_rec/python/layers/keras/multi_task.py index 35607834f..ec9f1e5cf 100644 --- a/easy_rec/python/layers/keras/multi_task.py +++ b/easy_rec/python/layers/keras/multi_task.py @@ -25,15 +25,20 @@ class MMoE(tf.keras.layers.Layer): def __init__(self, params, name='MMoE', reuse=None, **kwargs): super(MMoE, self).__init__(name, **kwargs) - params.check_required(['num_expert', 'num_task', 'expert_mlp']) + 
params.check_required(['num_expert', 'num_task'])
     self._reuse = reuse
     self._num_expert = params.num_expert
     self._num_task = params.num_task
-    expert_params = params.expert_mlp
-    self._experts = [
-        MLP(expert_params, 'expert_%d' % i, reuse=reuse)
-        for i in range(self._num_expert)
-    ]
+    if params.has_field('expert_mlp'):
+      expert_params = params.expert_mlp
+      self._has_experts = True
+      self._experts = [
+          MLP(expert_params, 'expert_%d' % i, reuse=reuse)
+          for i in range(self._num_expert)
+      ]
+    else:
+      self._has_experts = False
+      # bind i eagerly; a bare `lambda x: x[i]` would capture the loop
+      # variable by reference and every expert would return the last input
+      self._experts = [lambda x, i=i: x[i] for i in range(self._num_expert)]
     self._l2_reg = params.l2_regularizer
 
   def __call__(self, inputs, **kwargs):
@@ -41,19 +46,21 @@ def __call__(self, inputs, **kwargs):
       logging.warning('num_expert of MMoE layer `%s` is 0' % self.name)
       return inputs
 
-    expert_fea_list = [expert(inputs) for expert in self._experts]
-    experts_fea = tf.stack(expert_fea_list, axis=1)
-
-    task_input_list = []
-    for task_id in range(self._num_task):
-      gate = gate_fn(
-          inputs,
-          self._num_expert,
-          name='gate_%d' % task_id,
-          l2_reg=self._l2_reg,
-          reuse=self._reuse)
-      gate = tf.expand_dims(gate, -1)
-      task_input = tf.multiply(experts_fea, gate)
-      task_input = tf.reduce_sum(task_input, axis=1)
-      task_input_list.append(task_input)
+    with tf.name_scope(self.name):
+      expert_fea_list = [expert(inputs) for expert in self._experts]
+      experts_fea = tf.stack(expert_fea_list, axis=1)
+
+      # without expert MLPs the block's input is a list whose first
+      # num_expert entries are the expert features and whose last entry
+      # feeds the task gates
+      gate_input = inputs if self._has_experts else inputs[self._num_expert]
+      task_input_list = []
+      for task_id in range(self._num_task):
+        gate = gate_fn(
+            gate_input,
+            self._num_expert,
+            name='gate_%d' % task_id,
+            l2_reg=self._l2_reg,
+            reuse=self._reuse)
+        gate = tf.expand_dims(gate, -1)
+        task_input = tf.multiply(experts_fea, gate)
+        task_input = tf.reduce_sum(task_input, axis=1)
+        task_input_list.append(task_input)
     return task_input_list
diff --git a/easy_rec/python/layers/keras/ppnet.py b/easy_rec/python/layers/keras/ppnet.py
new file mode 100644
index 000000000..71f5902d1
--- /dev/null
+++ b/easy_rec/python/layers/keras/ppnet.py
@@ -0,0 +1,196 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
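+# PEPNet-style parameter personalization: GateNN produces element-wise scales
+# in (0, 2) from personalized inputs (e.g. user id embeddings), and PPNet
+# multiplies them onto each MLP layer's input ("eager" mode) or output
+# ("lazy" mode).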
+"""Convenience blocks for building models.""" +import logging + +import tensorflow as tf + +from easy_rec.python.layers.keras.activation import activation_layer +from easy_rec.python.utils.tf_utils import add_elements_to_collection + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + +class GateNN(tf.keras.layers.Layer): + + def __init__(self, + params, + output_units=None, + name='gate_nn', + reuse=None, + **kwargs): + super(GateNN, self).__init__(name=name, **kwargs) + output_dim = output_units if output_units is not None else params.output_dim + hidden_dim = params.get_or_default('hidden_dim', output_dim) + initializer = params.get_or_default('initializer', 'he_uniform') + do_batch_norm = params.get_or_default('use_bn', False) + activation = params.get_or_default('activation', 'relu') + dropout_rate = params.get_or_default('dropout_rate', 0.0) + + self._sub_layers = [] + dense = tf.keras.layers.Dense( + units=hidden_dim, + use_bias=not do_batch_norm, + kernel_initializer=initializer, + name=name) + self._sub_layers.append(dense) + + if do_batch_norm: + bn = tf.keras.layers.BatchNormalization( + name='%s/bn' % name, trainable=True) + self._sub_layers.append(bn) + + act_layer = activation_layer(activation) + self._sub_layers.append(act_layer) + + if 0.0 < dropout_rate < 1.0: + dropout = tf.keras.layers.Dropout(dropout_rate, name='%s/dropout' % name) + self._sub_layers.append(dropout) + elif dropout_rate >= 1.0: + raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate) + + dense = tf.keras.layers.Dense( + units=output_dim, + activation='sigmoid', + use_bias=not do_batch_norm, + kernel_initializer=initializer, + name=name) + self._sub_layers.append(dense) + self._sub_layers.append(lambda x: x * 2) + + def call(self, x, training=None, **kwargs): + """Performs the forward computation of the block.""" + for layer in self._sub_layers: + cls = layer.__class__.__name__ + if cls in ('Dropout', 'BatchNormalization', 'Dice'): + x = layer(x, training=training) + if cls in ('BatchNormalization', 'Dice'): + add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS) + else: + x = layer(x) + return x + + +class PPNet(tf.keras.layers.Layer): + """PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information. + + Attributes: + units: Sequential list of layer sizes. + use_bias: Whether to include a bias term. + activation: Type of activation to use on all except the last layer. + final_activation: Type of activation to use on last layer. + **kwargs: Extra args passed to the Keras Layer base class. 
+ """ + + def __init__(self, params, name='ppnet', reuse=None, **kwargs): + super(PPNet, self).__init__(name=name, **kwargs) + params.check_required('mlp') + self.full_gate_input = params.get_or_default('full_gate_input', True) + mode = params.get_or_default('mode', 'lazy') + gate_params = params.gate_params + params = params.mlp + params.check_required('hidden_units') + use_bn = params.get_or_default('use_bn', True) + use_final_bn = params.get_or_default('use_final_bn', True) + use_bias = params.get_or_default('use_bias', False) + use_final_bias = params.get_or_default('use_final_bias', False) + dropout_rate = list(params.get_or_default('dropout_ratio', [])) + activation = params.get_or_default('activation', 'relu') + initializer = params.get_or_default('initializer', 'he_uniform') + final_activation = params.get_or_default('final_activation', None) + use_bn_after_act = params.get_or_default('use_bn_after_activation', False) + units = list(params.hidden_units) + logging.info( + 'MLP(%s) units: %s, dropout: %r, activate=%s, use_bn=%r, final_bn=%r,' + ' final_activate=%s, bias=%r, initializer=%s, bn_after_activation=%r' % + (name, units, dropout_rate, activation, use_bn, use_final_bn, + final_activation, use_bias, initializer, use_bn_after_act)) + assert len(units) > 0, 'MLP(%s) takes at least one hidden units' % name + self.reuse = reuse + + num_dropout = len(dropout_rate) + self._sub_layers = [] + + if mode != 'lazy': + self._sub_layers.append(GateNN(gate_params, None, 'gate_0')) + for i, num_units in enumerate(units[:-1]): + name = 'layer_%d' % i + drop_rate = dropout_rate[i] if i < num_dropout else 0.0 + self.add_rich_layer(num_units, use_bn, drop_rate, activation, initializer, + use_bias, use_bn_after_act, name, + params.l2_regularizer) + self._sub_layers.append( + GateNN(gate_params, num_units, 'gate_%d' % (i + 1))) + + n = len(units) - 1 + drop_rate = dropout_rate[n] if num_dropout > n else 0.0 + name = 'layer_%d' % n + self.add_rich_layer(units[-1], use_final_bn, drop_rate, final_activation, + initializer, use_final_bias, use_bn_after_act, name, + params.l2_regularizer) + if mode == 'lazy': + self._sub_layers.append( + GateNN(gate_params, units[-1], 'gate_%d' % (n + 1))) + + def add_rich_layer(self, + num_units, + use_bn, + dropout_rate, + activation, + initializer, + use_bias, + use_bn_after_activation, + name, + l2_reg=None): + act_layer = activation_layer(activation) + if use_bn and not use_bn_after_activation: + dense = tf.keras.layers.Dense( + units=num_units, + use_bias=use_bias, + kernel_initializer=initializer, + kernel_regularizer=l2_reg, + name=name) + self._sub_layers.append(dense) + bn = tf.keras.layers.BatchNormalization( + name='%s/bn' % name, trainable=True) + self._sub_layers.append(bn) + self._sub_layers.append(act_layer) + else: + dense = tf.keras.layers.Dense( + num_units, + use_bias=use_bias, + kernel_initializer=initializer, + kernel_regularizer=l2_reg, + name=name) + self._sub_layers.append(dense) + self._sub_layers.append(act_layer) + if use_bn and use_bn_after_activation: + bn = tf.keras.layers.BatchNormalization(name='%s/bn' % name) + self._sub_layers.append(bn) + + if 0.0 < dropout_rate < 1.0: + dropout = tf.keras.layers.Dropout(dropout_rate, name='%s/dropout' % name) + self._sub_layers.append(dropout) + elif dropout_rate >= 1.0: + raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate) + + def call(self, inputs, training=None, **kwargs): + """Performs the forward computation of the block.""" + x, gate_input = inputs + if self.full_gate_input: + with 
tf.name_scope(self.name):
+        gate_input = tf.concat([tf.stop_gradient(x), gate_input], axis=-1)
+
+    for layer in self._sub_layers:
+      cls = layer.__class__.__name__
+      if cls == 'GateNN':
+        gate = layer(gate_input)
+        x *= gate
+      elif cls in ('Dropout', 'BatchNormalization', 'Dice'):
+        x = layer(x, training=training)
+        if cls in ('BatchNormalization', 'Dice'):
+          add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+      else:
+        x = layer(x)
+    return x
diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py
index 37249949b..24f1a475c 100644
--- a/easy_rec/python/model/easy_rec_model.py
+++ b/easy_rec/python/model/easy_rec_model.py
@@ -61,6 +61,7 @@ def __init__(self,
     self._labels = labels
     self._prediction_dict = {}
     self._loss_dict = {}
+    self._metric_dict = {}
 
     # add sample weight from inputs
     self._sample_weight = 1.0
@@ -88,10 +89,12 @@ def backbone(self):
     if self._backbone_output:
       return self._backbone_output
     if self._backbone_net:
-      self._backbone_output = self._backbone_net(self._is_training)
-      loss_dict = self._backbone_net.loss_dict
-      self._loss_dict.update(loss_dict)
-      return self._backbone_output
+      kwargs = {
+          'loss_dict': self._loss_dict,
+          'metric_dict': self._metric_dict,
+          constant.SAMPLE_WEIGHT: self._sample_weight
+      }
+      # cache the output so repeated property accesses reuse the same subgraph
+      self._backbone_output = self._backbone_net(self._is_training, **kwargs)
+      return self._backbone_output
     return None
 
   @property
@@ -142,9 +145,8 @@ def build_predict_graph(self):
   def build_loss_graph(self):
     pass
 
-  @abstractmethod
   def build_metric_graph(self, eval_config):
-    pass
+    return self._metric_dict
 
   @abstractmethod
   def get_outputs(self):
diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py
index c683702ae..ade76d5ab 100644
--- a/easy_rec/python/model/multi_task_model.py
+++ b/easy_rec/python/model/multi_task_model.py
@@ -137,21 +137,20 @@ def _add_to_prediction_dict(self, output):
 
   def build_metric_graph(self, eval_config):
     """Build metric graph for multi task model."""
-    metric_dict = {}
     for task_tower_cfg in self._task_towers:
       tower_name = task_tower_cfg.tower_name
       for metric in task_tower_cfg.metrics_set:
         loss_types = {task_tower_cfg.loss_type}
         if len(task_tower_cfg.losses) > 0:
           loss_types = {loss.loss_type for loss in task_tower_cfg.losses}
-        metric_dict.update(
+        self._metric_dict.update(
            self._build_metric_impl(
                metric,
                loss_type=loss_types,
                label_name=self._label_name_dict[tower_name],
                num_class=task_tower_cfg.num_class,
                suffix='_%s' % tower_name))
-    return metric_dict
+    return self._metric_dict
 
   def build_loss_weight(self):
     loss_weights = OrderedDict()
diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py
index f8c7f10c3..79e271483 100644
--- a/easy_rec/python/model/rank_model.py
+++ b/easy_rec/python/model/rank_model.py
@@ -390,18 +390,17 @@ def _build_metric_impl(self,
     return metric_dict
 
   def build_metric_graph(self, eval_config):
-    metric_dict = {}
     loss_types = {self._loss_type}
     if len(self._losses) > 0:
       loss_types = {loss.loss_type for loss in self._losses}
     for metric in eval_config.metrics_set:
-      metric_dict.update(
+      self._metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=loss_types,
              label_name=self._label_name,
              num_class=self._num_class))
-    return metric_dict
+    return self._metric_dict
 
   def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
     binary_loss_set = {
diff --git a/easy_rec/python/protos/backbone.proto b/easy_rec/python/protos/backbone.proto
index 4b96fcd24..86589a297 100644
--- a/easy_rec/python/protos/backbone.proto
+++ b/easy_rec/python/protos/backbone.proto
@@ -47,6 +47,8 @@ message RepeatLayer {
   // default output the list of multiple outputs
   optional int32 output_concat_axis = 2;
   required KerasLayer keras_layer = 3;
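+  // Python expressions evaluated once per repeat step i (see call_layer in
+  // layers/backbone.py): input_slice, e.g. `[i]`, indexes into the block's
+  // input list; input_fn, e.g. `lambda x, i: [x[0][i], x[1]]`, maps
+  // (inputs, i) to the input of step i.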
+  optional string input_slice = 4;
+  optional string input_fn = 5;
 }
 
 message Layer {
diff --git a/easy_rec/python/protos/keras_layer.proto b/easy_rec/python/protos/keras_layer.proto
index 5f09f4515..4d2cf9213 100644
--- a/easy_rec/python/protos/keras_layer.proto
+++ b/easy_rec/python/protos/keras_layer.proto
@@ -24,5 +24,6 @@ message KerasLayer {
     BSTEncoder bst = 13;
     MMoELayer mmoe = 14;
     SequenceAugment seq_aug = 15;
+    PPNet ppnet = 16;
   }
 }
diff --git a/easy_rec/python/protos/layer.proto b/easy_rec/python/protos/layer.proto
index 52a1cbf30..5c54741f4 100644
--- a/easy_rec/python/protos/layer.proto
+++ b/easy_rec/python/protos/layer.proto
@@ -49,7 +49,7 @@ message FiBiNet {
 
 message MaskBlock {
   optional float reduction_factor = 1;
-  required uint32 output_size = 2;
+  optional uint32 output_size = 2;
   optional uint32 aggregation_size = 3;
   optional bool input_layer_norm = 4 [default = true];
   optional uint32 projection_dim = 5;
@@ -69,3 +69,21 @@ message MMoELayer {
   // number of mmoe experts
   optional uint32 num_expert = 3;
 }
+
+message GateNN {
+  optional uint32 output_dim = 1;
+  optional uint32 hidden_dim = 2;
+  // activation function
+  optional string activation = 3 [default = 'relu'];
+  // use batch normalization
+  optional bool use_bn = 4 [default = false];
+  optional float dropout_rate = 5;
+}
+
+message PPNet {
+  required MLP mlp = 1;
+  // may be omitted in lazy mode, where each gate's output size comes from
+  // the MLP layers (the sample configs rely on this)
+  optional GateNN gate_params = 2;
+  // run mode: eager, lazy
+  optional string mode = 3 [default = 'eager'];
+  optional bool full_gate_input = 4 [default = true];
+}
diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py
index b0b66d30c..859e2442c 100644
--- a/easy_rec/python/test/train_eval_test.py
+++ b/easy_rec/python/test/train_eval_test.py
@@ -407,6 +407,11 @@ def test_cdn(self):
         'samples/model_config/cdn_on_taobao.config', self._test_dir)
     self.assertTrue(self._success)
 
+  def test_ppnet(self):
+    self._success = test_utils.test_single_train_eval(
+        'samples/model_config/ppnet_on_taobao.config', self._test_dir)
+    self.assertTrue(self._success)
+
   def test_uniter_only_text_feature(self):
     self._success = test_utils.test_single_train_eval(
         'samples/model_config/uniter_on_movielens_only_text_feature.config',
diff --git a/easy_rec/python/tools/add_feature_info_to_config.py b/easy_rec/python/tools/add_feature_info_to_config.py
index a2df7744a..f1b4a4cfd 100644
--- a/easy_rec/python/tools/add_feature_info_to_config.py
+++ b/easy_rec/python/tools/add_feature_info_to_config.py
@@ -59,7 +59,7 @@ def main(argv):
       except common_io.exception.OutOfRangeException:
        reader.close()
        break
-
+
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
  if drop_feature_names:
    tmp_feature_configs = feature_configs[:]
diff --git a/easy_rec/version.py b/easy_rec/version.py
index edc79a5a5..235c9c2a6 100644
--- a/easy_rec/version.py
+++ b/easy_rec/version.py
@@ -1,3 +1,3 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
-__version__ = '0.7.5' +__version__ = '0.7.6' diff --git a/samples/model_config/ppnet_on_taobao.config b/samples/model_config/ppnet_on_taobao.config new file mode 100644 index 000000000..6bf4ba212 --- /dev/null +++ b/samples/model_config/ppnet_on_taobao.config @@ -0,0 +1,289 @@ +train_input_path: "data/test/tb_data/taobao_train_data" +eval_input_path: "data/test/tb_data/taobao_test_data" +model_dir: "experiments/ppnet_taobao_ckpt" + +train_config { + log_step_count_steps: 100 + optimizer_config: { + adam_optimizer: { + learning_rate: { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 0.00001 + } + } + } + use_moving_average: false + } + save_checkpoints_steps: 100 + sync_replicas: True + num_steps: 100 +} + +eval_config { + metrics_set: { + auc {} + } +} + +data_config { + input_fields { + input_name:'clk' + input_type: INT32 + } + input_fields { + input_name:'buy' + input_type: INT32 + } + input_fields { + input_name: 'pid' + input_type: STRING + } + input_fields { + input_name: 'adgroup_id' + input_type: STRING + } + input_fields { + input_name: 'cate_id' + input_type: STRING + } + input_fields { + input_name: 'campaign_id' + input_type: STRING + } + input_fields { + input_name: 'customer' + input_type: STRING + } + input_fields { + input_name: 'brand' + input_type: STRING + } + input_fields { + input_name: 'user_id' + input_type: STRING + } + input_fields { + input_name: 'cms_segid' + input_type: STRING + } + input_fields { + input_name: 'cms_group_id' + input_type: STRING + } + input_fields { + input_name: 'final_gender_code' + input_type: STRING + } + input_fields { + input_name: 'age_level' + input_type: STRING + } + input_fields { + input_name: 'pvalue_level' + input_type: STRING + } + input_fields { + input_name: 'shopping_level' + input_type: STRING + } + input_fields { + input_name: 'occupation' + input_type: STRING + } + input_fields { + input_name: 'new_user_class_level' + input_type: STRING + } + input_fields { + input_name: 'tag_category_list' + input_type: STRING + } + input_fields { + input_name: 'tag_brand_list' + input_type: STRING + } + input_fields { + input_name: 'price' + input_type: INT32 + } + + label_fields: 'clk' + batch_size: 4096 + num_epochs: 10000 + prefetch_size: 32 + input_type: CSVInput +} + +feature_config: { + features: { + input_names: 'pid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'adgroup_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'cate_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 + } + features: { + input_names: 'campaign_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'customer' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'brand' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'user_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'cms_segid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: 'cms_group_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: 'final_gender_code' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 
'age_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'pvalue_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'shopping_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'occupation' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'new_user_class_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'tag_category_list' + feature_type: TagFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 + } + features: { + input_names: 'tag_brand_list' + feature_type: TagFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 + } + features: { + input_names: 'price' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 + } +} +model_config: { + model_name: 'PPNet' + model_class: 'RankModel' + feature_groups: { + group_name: 'memorize' + feature_names: 'user_id' + feature_names: 'adgroup_id' + feature_names: 'pid' + wide_deep: DEEP + } + feature_groups: { + group_name: 'general' + feature_names: 'cms_segid' + feature_names: 'cms_group_id' + feature_names: 'age_level' + feature_names: 'pvalue_level' + feature_names: 'shopping_level' + feature_names: 'occupation' + feature_names: 'new_user_class_level' + feature_names: 'cate_id' + feature_names: 'campaign_id' + feature_names: 'customer' + feature_names: 'brand' + feature_names: 'price' + feature_names: 'tag_category_list' + feature_names: 'tag_brand_list' + wide_deep: DEEP + } + backbone { + blocks { + name: "ppnet" + inputs { + feature_group_name: "general" + } + inputs { + feature_group_name: "memorize" + } + merge_inputs_into_list: true + keras_layer { + class_name: "PPNet" + ppnet { + mlp { + hidden_units: [512, 256] + } + mode: "lazy" + full_gate_input: true + } + } + } + top_mlp { + hidden_units: [128, 64] + } + } + model_params { + l2_regularization: 1e-6 + } + embedding_regularization: 1e-5 +}
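+
+# How this wires up (see easy_rec/python/layers/keras/ppnet.py): the block's
+# input list is [general, memorize]; PPNet uses the 'general' embeddings as the
+# MLP input and the 'memorize' embeddings (user_id/adgroup_id/pid) as the gate
+# input, and full_gate_input: true additionally feeds a stop_gradient copy of
+# the MLP input to the gates. In lazy mode the output of every MLP layer is
+# scaled element-wise by the gate values 2*sigmoid(GateNN(...)) in (0, 2).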