add ZILN loss for ltv prediction task (#498)
* add ZILN loss for ltv prediction task & add documents
yangxudong authored Nov 12, 2024
1 parent 5996b18 commit f00d8a8
Showing 19 changed files with 543 additions and 27 deletions.
11 changes: 11 additions & 0 deletions docs/source/component/backbone.md
@@ -131,6 +131,7 @@ Comparison of results on the MovieLens-1M dataset:
- Some special `block`s are associated with a special module, including `lambda layer`, `sequential layers`, `repeated layer`, and `recurrent layer`. These special layers implement, respectively, custom expressions, sequential execution of multiple layers, repeated execution of a layer, and recurrent execution of a layer.
- The output nodes of the DAG are specified by the `concat_blocks` option; when multiple output nodes are configured, their tensors are concatenated automatically.
- If `concat_blocks` is not configured, the framework automatically concatenates all leaf nodes of the DAG and outputs the result.
- If the outputs of multiple `block`s should not be concatenated but instead passed downstream as a list (e.g. feeding the towers of a multi-task model), use `output_blocks` instead of `concat_blocks` (see Section 8, Output Components, below).
- An optional `MLP` module can be configured for the backbone network.

![](../../images/component/wide_deep.png)
@@ -1275,6 +1276,8 @@ message InputLayer {
optional bool only_output_3d_tensor = 6;
optional bool output_2d_tensor_and_feature_list = 7;
optional bool output_seq_and_normal_feature = 8;
optional uint32 wide_output_dim = 9;
optional bool concat_seq_feature = 10 [default = true];
}
```

@@ -1288,6 +1291,8 @@ message InputLayer {
- `only_output_3d_tensor`: output a single 3D tensor for the `feature group`; can be set when all features share the same `embedding_dim`
- `output_2d_tensor_and_feature_list`: whether to output both the 2D tensor and the feature list
- `output_seq_and_normal_feature`: whether to output a (sequence features, regular features) tuple
- `wide_output_dim`: dimension of the per-feature parameter weights in the wide part of the model; usually set to 1
- `concat_seq_feature`: whether to concatenate the embeddings of the sequence features (a sketch using the two new fields follows)
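
A hedged sketch of an input block using the two new fields; the block and feature-group names are hypothetical, and only the fields relevant here are shown:

```protobuf
blocks {
  name: 'wide_input'
  inputs {
    feature_group_name: 'wide'
  }
  input_layer {
    wide_output_dim: 1          # one learned weight per wide feature
    concat_seq_feature: false   # keep sequence-feature embeddings separate
  }
}
```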

## 3. Lambda Component Block

@@ -1437,6 +1442,12 @@ blocks {
}
```

## 8. Output Components

- Use `concat_blocks` or `output_blocks` to configure the output of the backbone network.
- The difference between the two: the former concatenates the results of the configured output blocks along the last axis; the latter does not concatenate them and instead outputs a list (see the sketch below).
- If neither option is configured, the framework automatically concatenates all leaf nodes of the DAG and outputs the result.
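
A minimal sketch contrasting the two options; the block names are hypothetical and the block bodies are omitted:

```protobuf
backbone {
  blocks { name: 'tower_a' }  # inputs/layers omitted
  blocks { name: 'tower_b' }  # inputs/layers omitted
  # concat_blocks concatenates both outputs along the last axis into one tensor:
  concat_blocks: ['tower_a', 'tower_b']
  # output_blocks would instead emit a list, one tensor per block:
  # output_blocks: ['tower_a', 'tower_b']
}
```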

## Parameter-Shared Subnetworks via a `Component Package`

A `component package` encapsulates a subnetwork DAG built from multiple `component blocks`; as a whole it can be invoked multiple times with shared parameters, and is typically used in *self-supervised learning* models.
43 changes: 34 additions & 9 deletions docs/source/component/component.md
@@ -4,10 +4,10 @@

| Class Name | Function | Description | Example |
| ----------------- | ----------------------------- | ------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------ |
| MLP | Multi-layer perceptron | Customizable activation functions, initializer, Dropout, BN, etc. | [Example 1](backbone.md#wide-deep) |
| MLP | Multi-layer perceptron | Customizable activation functions, initializer, Dropout, BN, etc. | [Example 1](backbone.html#wide-deep) |
| Highway | Residual-like connection | Can be used to incrementally fine-tune pretrained embeddings | [highway network](../models/highway.html) |
| Gate | Gating | Weighted sum of multiple inputs | [Cross Decoupling Network](../models/cdn.html#id2) |
| PeriodicEmbedding | Periodic activation functions | Embedding for numeric features | [Example 5](backbone.md#dlrm-embedding) |
| PeriodicEmbedding | Periodic activation functions | Embedding for numeric features | [Example 5](backbone.html#dlrm-embedding) |
| AutoDisEmbedding | Automatic discretization | Embedding for numeric features | [dlrm_on_criteo_with_autodis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_autodis.config) |
| NaryDisEmbedding | N-ary encoding | Embedding for numeric features | [dlrm_on_criteo_with_narydis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_narydis.config) |
| TextCNN | Text convolution | Extracts features from text sequences | [text_cnn_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/text_cnn_on_movielens.config) |
@@ -18,9 +18,9 @@

| Class Name | Function | Description | Example |
| -------------- | --------------------------------------- | ------------------------------ | ---------------------------------------------------------------------------------------------------------------------------- |
| FM | Second-order interaction | Component of the DeepFM model | [Example 2](backbone.md#deepfm) |
| DotInteraction | Second-order inner-product interaction | Component of the DLRM model | [Example 4](backbone.md#dlrm) |
| Cross | Bit-wise interaction | Component of the DCN v2 model | [Example 3](backbone.md#dcn) |
| FM | Second-order interaction | Component of the DeepFM model | [Example 2](backbone.html#deepfm) |
| DotInteraction | Second-order inner-product interaction | Component of the DLRM model | [Example 4](backbone.html#dlrm) |
| Cross | Bit-wise interaction | Component of the DCN v2 model | [Example 3](backbone.html#dcn) |
| BiLinear | Bilinear interaction | Component of the FiBiNet model | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |
| FiBiNet | SENet & BiLinear | The FiBiNet model | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |

@@ -50,14 +50,14 @@

| Class Name | Function | Description | Example |
| --------- | --------------------------- | --------------------------- | ------------------------------- |
| MMoE | Multiple Mixture of Experts | Component of the MMoE model | [Example 8](backbone.md#mmoe) |
| MMoE | Multiple Mixture of Experts | Component of the MMoE model | [Example 8](backbone.html#mmoe) |
| AITMTower | One tower of the AITM model | Component of the AITM model | [AITM](../models/aitm.md#id2) |

## 6. Auxiliary Loss Function Components

| Class Name | Function | Description | Example |
| ------------- | ------------------------- | ----------------------------------------- | ---------------------------- |
| AuxiliaryLoss | Computes auxiliary losses | Commonly used in self-supervised learning | [Example 7](backbone.md#id7) |
| Class Name | Function | Description | Example |
| ------------- | ------------------------- | ----------------------------------------- | ------------------------------ |
| AuxiliaryLoss | Computes auxiliary losses | Commonly used in self-supervised learning | [Example 7](backbone.html#id7) |

# Detailed Component Parameters

@@ -138,6 +138,31 @@

## 2. Feature Interaction Components

- FM

| Parameter | Type | Default | Description |
| ----------- | ---- | ------- | ---------------------------------------------------------------------------------------------------------- |
| use_variant | bool | false | Whether to use the FM variant: output all second-order interaction terms directly instead of summing them |

- DotInteraction

| Parameter | Type | Default | Description |
| ---------------- | ---- | ------- | ------------------------------------------------------------------------------------------- |
| self_interaction | bool | false | Whether to allow a feature to interact with itself |
| skip_gather | bool | false | An optimization switch; setting it to true speeds up execution at the cost of extra memory |

- Cross

| Parameter | Type | Default | Description |
| ------------------ | ------ | ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| projection_dim | uint32 | None | Reduces computation via low-rank factorization: the large weight matrix is split into the product of two smaller matrices; projection_dim is the number of columns of the first matrix and the number of rows of the second |
| diag_scale | float | 0 | Increases the diagonal of the kernel W by `diag_scale`, i.e. W + diag_scale * I, where I is an identity matrix |
| use_bias | bool | true | Whether to add a bias term to this layer |
| kernel_initializer | string | truncated_normal | Initializer for the kernel matrix |
| bias_initializer | string | zeros | Initializer for the bias vector |
| kernel_regularizer | string | None | Regularizer for the kernel matrix |
| bias_regularizer | string | None | Regularizer for the bias vector |
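
A hedged config sketch for a `Cross` block, assuming its parameters can be passed via `st_params` in the same style as the `AuxiliaryLoss` example later in this commit; the block and feature-group names are hypothetical:

```protobuf
blocks {
  name: 'dcn_cross'
  inputs {
    feature_group_name: 'all'
  }
  keras_layer {
    class_name: 'Cross'
    st_params {
      # factorize the kernel into two low-rank matrices to cut compute
      fields { key: 'projection_dim' value { number_value: 64 } }
      fields { key: 'diag_scale' value { number_value: 0.1 } }
    }
  }
}
```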

- Bilinear

| Parameter | Type | Default | Description |
36 changes: 36 additions & 0 deletions docs/source/component/custom_loss.md
@@ -0,0 +1,36 @@
# Custom Auxiliary Loss Function Component

Multiple auxiliary loss functions can be added with the following approach.

Add a new loss function in `easy_rec/python/layers/keras/auxiliary_loss.py`.
If the computation logic is complex, it is best implemented in a separate Python file, then imported and used in `auxiliary_loss.py`.

Note: the `loss_type` parameter that identifies the loss function type must be globally unique.
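
For illustration, a minimal sketch of such a loss function; the dispatch inside `AuxiliaryLoss` is assumed rather than taken from this commit, and `my_custom_loss` is a hypothetical name:

```python
import tensorflow as tf


def my_custom_loss(pred, logit):
  """Hypothetical auxiliary loss: penalizes disagreement between a
  probability `pred` and the sigmoid of a raw `logit`."""
  return tf.reduce_mean(tf.square(pred - tf.sigmoid(logit)))


# Assumed dispatch inside the AuxiliaryLoss keras layer, keyed on the
# globally unique `loss_type` configured via st_params:
#   if self.loss_type == 'my_custom_loss':
#     loss = my_custom_loss(*inputs)
```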

## Configuration

```protobuf
blocks {
name: 'custom_loss'
inputs {
block_name: 'pred'
}
inputs {
block_name: 'logit'
}
merge_inputs_into_list: true
keras_layer {
class_name: 'AuxiliaryLoss'
st_params {
fields {
key: "loss_type"
value { string_value: "my_custom_loss" }
}
}
}
}
```

Custom parameters can be appended under the `st_params` parameter list.

Remember to configure the list of output blocks using `concat_blocks` or `output_blocks` (excluding the current `custom_loss` node).
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -31,6 +31,7 @@ Welcome to easy_rec's documentation!
component/backbone
component/component
component/sequence
component/custom_loss
component/custom_op

.. toctree::
6 changes: 6 additions & 0 deletions docs/source/models/loss.md
@@ -25,7 +25,12 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
| ORDER_CALIBRATE_LOSS | Auxiliary loss that calibrates predictions using inter-target dependencies; see the [AITM](aitm.md) model |
| LISTWISE_RANK_LOSS | Listwise ranking loss |
| LISTWISE_DISTILL_LOSS | Loss for distilling a given list ranking; quite similar to the listwise rank loss |
| ZILN_LOSS | Loss for LTV prediction tasks (`num_class` must be set to 3) |

- ZILN_LOSS: when this loss is used, the model exposes three optional outputs (in multi-task models the output names carry a target-specific suffix); a config sketch follows this list
  - probs: the predicted conversion probability
  - y: the predicted LTV value
  - logits: a tensor of shape `[batch_size, 3]`; the first column is `probs`, and the second and third columns are the mean and standard deviation of the learned LogNormal distribution
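
A hedged config sketch for enabling this loss; the placement of `num_class` and `losses` follows the usual EasyRec model-parameter conventions, and all surrounding fields are omitted:

```protobuf
# The prediction head must emit [batch_size, 3] logits for ZILN.
num_class: 3
losses {
  loss_type: ZILN_LOSS
}
```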
- Note on SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
  - Supports parameter configuration; upgraded to [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317)
  - Currently available only in the DropoutNet model; see [冷启动推荐模型DropoutNet深度解析与改进](https://zhuanlan.zhihu.com/p/475117993)
@@ -184,3 +189,4 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
- [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603)
- [AITM: Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising](https://arxiv.org/pdf/2105.08489.pdf)
- [Pairwise Ranking Distillation for Deep Face Recognition](https://ceur-ws.org/Vol-2744/paper30.pdf)
- [A Deep Probabilistic Model for Customer Lifetime Value Prediction](https://arxiv.org/pdf/1912.07753)
8 changes: 8 additions & 0 deletions easy_rec/python/builders/loss_builder.py
@@ -2,6 +2,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import numpy as np
import tensorflow as tf

from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
@@ -14,6 +15,8 @@
from easy_rec.python.loss.pairwise_loss import pairwise_loss
from easy_rec.python.protos.loss_pb2 import LossType

from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA

from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA

if tf.__version__ >= '2.0':
@@ -46,6 +49,11 @@ def build(loss_type,
logging.info('%s is used' % LossType.Name(loss_type))
return tf.losses.mean_squared_error(
labels=label, predictions=pred, weights=loss_weight, **kwargs)
elif loss_type == LossType.ZILN_LOSS:
loss = zero_inflated_lognormal_loss(label, pred)
if np.isscalar(loss_weight) and loss_weight != 1.0:
return loss * loss_weight
return loss
elif loss_type == LossType.JRC_LOSS:
session = kwargs.get('session_ids', None)
if loss_param is None:
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
@@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
3 changes: 2 additions & 1 deletion easy_rec/python/loss/jrc_loss.py
@@ -1,6 +1,7 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import numpy as np
import tensorflow as tf

@@ -122,6 +123,6 @@ def jrc_loss(labels,
else:
raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' %
loss_weight_strategy)
if np.isscalar(sample_weights):
if np.isscalar(sample_weights) and sample_weights != 1.0:
return loss * sample_weights
return loss
76 changes: 76 additions & 0 deletions easy_rec/python/loss/zero_inflated_lognormal.py
@@ -0,0 +1,76 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Zero-inflated lognormal loss for lifetime value prediction."""
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

if tf.__version__ >= '2.0':
tf = tf.compat.v1


def zero_inflated_lognormal_pred(logits):
"""Calculates predicted mean of zero inflated lognormal logits.
Arguments:
logits: [batch_size, 3] tensor of logits.
Returns:
positive_probs: [batch_size, 1] tensor of positive probability.
preds: [batch_size, 1] tensor of predicted mean.
"""
logits = tf.convert_to_tensor(logits, dtype=tf.float32)
positive_probs = tf.keras.backend.sigmoid(logits[..., :1])
loc = logits[..., 1:2]
scale = tf.keras.backend.softplus(logits[..., 2:])
preds = (
positive_probs *
tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))
return positive_probs, preds


def zero_inflated_lognormal_loss(labels, logits, name=''):
"""Computes the zero inflated lognormal loss.
Usage with tf.keras API:
```python
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=zero_inflated_lognormal)
```
Arguments:
labels: True targets, tensor of shape [batch_size, 1].
logits: Logits of output layer, tensor of shape [batch_size, 3].
name: the name of loss
Returns:
Zero inflated lognormal loss value.
"""
loss_name = name if name else 'ziln_loss'
labels = tf.cast(labels, dtype=tf.float32)
if labels.shape.ndims == 1:
labels = tf.expand_dims(labels, 1) # [B, 1]
positive = tf.cast(labels > 0, tf.float32)

logits = tf.convert_to_tensor(logits, dtype=tf.float32)
logits.shape.assert_is_compatible_with(
tf.TensorShape(labels.shape[:-1].as_list() + [3]))

positive_logits = logits[..., :1]
classification_loss = tf.keras.backend.binary_crossentropy(
positive, positive_logits, from_logits=True)
classification_loss = tf.keras.backend.mean(classification_loss)
tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss)

loc = logits[..., 1:2]
scale = tf.math.maximum(
tf.keras.backend.softplus(logits[..., 2:]),
tf.math.sqrt(tf.keras.backend.epsilon()))
safe_labels = positive * labels + (
1 - positive) * tf.keras.backend.ones_like(labels)
regression_loss = -tf.keras.backend.mean(
positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels))
tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss)
return classification_loss + regression_loss
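
A quick usage sketch of the two functions above; the values are illustrative, `tensorflow_probability` must be installed, and eager execution is disabled because the module records TF1-style summaries:

```python
import tensorflow as tf

from easy_rec.python.loss.zero_inflated_lognormal import (
    zero_inflated_lognormal_loss, zero_inflated_lognormal_pred)

tf.compat.v1.disable_eager_execution()

labels = tf.constant([[0.0], [12.5]])  # LTV targets; 0 means no conversion
logits = tf.constant([[0.3, 1.2, -0.5],
                      [1.5, 2.0, 0.1]])  # [positive_logit, loc, raw_scale]

loss = zero_inflated_lognormal_loss(labels, logits)   # scalar loss
probs, preds = zero_inflated_lognormal_pred(logits)   # each of shape [2, 1]

with tf.compat.v1.Session() as sess:
  print(sess.run([loss, probs, preds]))
```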