add ZILN loss for ltv prediction task #498

Merged
merged 9 commits on Nov 12, 2024

Changes from all commits

11 changes: 11 additions & 0 deletions docs/source/component/backbone.md
@@ -131,6 +131,7 @@ Performance comparison on the MovieLens-1M dataset:
- There are also special `block`s tied to special modules, including the `lambda layer`, `sequential layers`, `repeated layer`, and `recurrent layer`. These implement, respectively, custom expressions, running several layers in sequence, running one layer repeatedly, and running one layer recurrently (a hedged sketch of a lambda block follows this list).
- The output nodes of the DAG are specified by the `concat_blocks` option; when several output nodes are configured, their tensors are concatenated automatically.
- If `concat_blocks` is not configured, the framework automatically concatenates all leaf nodes of the DAG and outputs the result.
- If the outputs of several `block`s should not be concatenated but passed downstream as a list (e.g. feeding the towers of a multi-task model), use `output_blocks` instead of `concat_blocks`.
- An optional `MLP` module can be configured for the backbone network.
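
A minimal sketch of a `lambda layer` block, assuming the `lambda`/`expression` fields of EasyRec backbone blocks; the block name `scale_by_half` and its input wiring are made up for illustration:

```protobuf
blocks {
  name: 'scale_by_half'
  inputs {
    block_name: 'mlp'
  }
  lambda {
    # custom expression applied to the input tensor
    expression: 'lambda x: x * 0.5'
  }
}
```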

![](../../images/component/wide_deep.png)
@@ -1275,6 +1276,8 @@ message InputLayer {
optional bool only_output_3d_tensor = 6;
optional bool output_2d_tensor_and_feature_list = 7;
optional bool output_seq_and_normal_feature = 8;
optional uint32 wide_output_dim = 9;
optional bool concat_seq_feature = 10 [default = true];
}
```

@@ -1288,6 +1291,8 @@ message InputLayer {
- `only_output_3d_tensor`: output a single 3d tensor for the `feature group`; may be configured when all features share the same `embedding_dim`
- `output_2d_tensor_and_feature_list`: whether to output both the 2d tensor and the feature list
- `output_seq_and_normal_feature`: whether to output a (sequence features, regular features) tuple
- `wide_output_dim`: dimension of the per-feature parameter weights of the wide model, usually set to 1
- `concat_seq_feature`: whether to concatenate the embeddings of the sequence features (a config sketch follows this list)
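
A hedged sketch of the two new fields inside an `input_layer` block; the block name `wide_input` and the feature group name `wide` are assumptions for illustration:

```protobuf
blocks {
  name: 'wide_input'
  inputs {
    feature_group_name: 'wide'
  }
  input_layer {
    wide_output_dim: 1        # per-feature weight dimension of the wide part
    concat_seq_feature: true  # keep sequence-feature embeddings concatenated (default)
  }
}
```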

## 3. Lambda component block

@@ -1437,6 +1442,12 @@ blocks {
}
```

## 8. Output components

- Use `concat_blocks` or `output_blocks` to configure the output of the backbone network.
- The difference between the two: the former concatenates the results of the listed output blocks along the last axis; the latter performs no concatenation and outputs them as a list (as sketched below).
- If neither option is configured, the framework automatically concatenates all leaf nodes of the DAG and outputs the result.
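
A minimal sketch of the list-style output, assuming two tower blocks named `tower1` and `tower2` (the names are made up; the other backbone fields are omitted):

```protobuf
backbone {
  # ... block definitions omitted ...
  output_blocks: ['tower1', 'tower2']  # outputs a list, e.g. feeding multi-task towers
}
```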

## Parameter-shared sub-networks via a `component package`

A `component package` encapsulates a sub-network DAG built from multiple `component blocks`. As a whole, it can be invoked multiple times with its parameters shared, and is typically used in *self-supervised learning* models.
43 changes: 34 additions & 9 deletions docs/source/component/component.md
@@ -4,10 +4,10 @@

| Class | Function | Notes | Example |
| ----------------- | ----------------------------- | ------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------ |
| MLP | Multi-layer perceptron | Customizable activation functions, initializers, Dropout, BN, etc. | [Example 1](backbone.md#wide-deep) |
| MLP | Multi-layer perceptron | Customizable activation functions, initializers, Dropout, BN, etc. | [Example 1](backbone.html#wide-deep) |
| Highway | Residual-like connection | Can be used for incremental fine-tuning of pretrained embeddings | [highway network](../models/highway.html) |
| Gate | Gating | Weighted sum of multiple inputs | [Cross Decoupling Network](../models/cdn.html#id2) |
| PeriodicEmbedding | Periodic activation functions | Embedding for numeric features | [Example 5](backbone.md#dlrm-embedding) |
| PeriodicEmbedding | Periodic activation functions | Embedding for numeric features | [Example 5](backbone.html#dlrm-embedding) |
| AutoDisEmbedding | Automatic discretization | Embedding for numeric features | [dlrm_on_criteo_with_autodis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_autodis.config) |
| NaryDisEmbedding | N-ary encoding | Embedding for numeric features | [dlrm_on_criteo_with_narydis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_narydis.config) |
| TextCNN | Text convolution | Extracts features from text sequences | [text_cnn_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/text_cnn_on_movielens.config) |
@@ -18,9 +18,9 @@

| Class | Function | Notes | Example |
| -------------- | ---------------------------------------- | ------------------------------ | ------------------------------------------------------------------------------------------------------------------------ |
| FM | Second-order interactions | Component of the DeepFM model | [Example 2](backbone.md#deepfm) |
| DotInteraction | Second-order inner-product interactions | Component of the DLRM model | [Example 4](backbone.md#dlrm) |
| Cross | Bit-wise interactions | Component of the DCN v2 model | [Example 3](backbone.md#dcn) |
| FM | Second-order interactions | Component of the DeepFM model | [Example 2](backbone.html#deepfm) |
| DotInteraction | Second-order inner-product interactions | Component of the DLRM model | [Example 4](backbone.html#dlrm) |
| Cross | Bit-wise interactions | Component of the DCN v2 model | [Example 3](backbone.html#dcn) |
| BiLinear | Bilinear interaction | Component of the FiBiNet model | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |
| FiBiNet | SENet & BiLinear | The FiBiNet model | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |

@@ -50,14 +50,14 @@

| Class | Function | Notes | Example |
| --------- | ----------------------------- | --------------------------- | ------------------------------- |
| MMoE | Multi-gate Mixture-of-Experts | Component of the MMoE model | [Example 8](backbone.md#mmoe) |
| MMoE | Multi-gate Mixture-of-Experts | Component of the MMoE model | [Example 8](backbone.html#mmoe) |
| AITMTower | One tower of the AITM model | Component of the AITM model | [AITM](../models/aitm.md#id2) |

## 6. Auxiliary loss function components

| Class | Function | Notes | Example |
| ------------- | --------------------------------- | ----------------------------------------- | ---------------------------- |
| AuxiliaryLoss | Computes auxiliary loss functions | Commonly used in self-supervised learning | [Example 7](backbone.md#id7) |
| Class | Function | Notes | Example |
| ------------- | --------------------------------- | ----------------------------------------- | ------------------------------ |
| AuxiliaryLoss | Computes auxiliary loss functions | Commonly used in self-supervised learning | [Example 7](backbone.html#id7) |

# Detailed component parameters

@@ -138,6 +138,31 @@

## 2. Feature interaction components

- FM

| Parameter | Type | Default | Description |
| ----------- | ---- | ------- | --------------------------------------------------------------------------------------------------------- |
| use_variant | bool | false | Whether to use the FM variant: output all second-order interaction terms directly, without summing them |

- DotInteraction

| Parameter | Type | Default | Description |
| ---------------- | ---- | ------- | -------------------------------------------------------------------------------------- |
| self_interaction | bool | false | Whether to allow a feature to interact with itself |
| skip_gather | bool | false | An optimization switch; set to true for faster execution at the cost of extra memory |

- Cross

| Parameter | Type | Default | Description |
| ------------------ | ------ | ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| projection_dim | uint32 | None | Uses matrix factorization to reduce the computational cost: the large kernel matrix is factored into the product of two smaller matrices; projection_dim is the number of columns of the first matrix and the number of rows of the second |
| diag_scale | float | 0 | Used to increase the diagonal of the kernel W by `diag_scale`, that is, W + diag_scale * I, where I is an identity matrix |
| use_bias | bool | true | Whether to add a bias term to this layer |
| kernel_initializer | string | truncated_normal | Initializer to use on the kernel matrix |
| bias_initializer | string | zeros | Initializer to use on the bias vector |
| kernel_regularizer | string | None | Regularizer to use on the kernel matrix |
| bias_regularizer | string | None | Regularizer to use on the bias vector |
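
A hedged config sketch for these interaction components, following the `st_params` pattern used elsewhere in this PR; the block name `dcn_cross`, the feature group `all`, and passing `projection_dim` via `st_params` are assumptions for illustration:

```protobuf
blocks {
  name: 'dcn_cross'
  inputs {
    feature_group_name: 'all'
  }
  keras_layer {
    class_name: 'Cross'
    st_params {
      fields {
        key: 'projection_dim'
        value { number_value: 64 }  # low-rank factorization of the kernel matrix
      }
    }
  }
}
```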

- Bilinear

| Parameter | Type | Default | Description |
36 changes: 36 additions & 0 deletions docs/source/component/custom_loss.md
@@ -0,0 +1,36 @@
# Custom auxiliary loss components

Multiple auxiliary loss functions can be added with the following approach.

Add a new loss function in `easy_rec/python/layers/keras/auxiliary_loss.py`.
If the computation logic is complex, it is recommended to implement it in a separate python file, then import and use it inside `auxiliary_loss.py`.

Note: the `loss_type` parameter that tags the type of the loss function must be globally unique.

## Configuration

```protobuf
blocks {
name: 'custom_loss'
inputs {
block_name: 'pred'
}
inputs {
block_name: 'logit'
}
merge_inputs_into_list: true
keras_layer {
class_name: 'AuxiliaryLoss'
st_params {
fields {
key: "loss_type"
value { string_value: "my_custom_loss" }
}
}
}
}
```

Custom parameters can be appended under the `st_params` parameter list.

Remember to configure the list of output blocks with `concat_blocks` or `output_blocks` (excluding this `custom_loss` node).
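
A minimal sketch of what the new branch in `auxiliary_loss.py` might look like; the function `my_custom_loss`, its signature, and the dispatch shown in the comment are assumptions, not the actual EasyRec implementation:

```python
# easy_rec/python/layers/keras/auxiliary_loss.py (hypothetical excerpt)
import tensorflow as tf


def my_custom_loss(pred, logit):
  """Hypothetical loss: mean squared gap between pred and sigmoid(logit)."""
  return tf.reduce_mean(tf.square(pred - tf.sigmoid(logit)))


# Inside the AuxiliaryLoss layer, dispatch on the globally unique loss_type:
#   if self.loss_type == 'my_custom_loss':
#     loss = my_custom_loss(*inputs)  # inputs is the ['pred', 'logit'] list
```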
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -31,6 +31,7 @@ Welcome to easy_rec's documentation!
component/backbone
component/component
component/sequence
component/custom_loss
component/custom_op

.. toctree::
6 changes: 6 additions & 0 deletions docs/source/models/loss.md
@@ -25,7 +25,12 @@ EasyRec supports two ways of configuring loss functions: 1) use a single loss function; 2
| ORDER_CALIBRATE_LOSS | Auxiliary loss that calibrates predictions using the dependencies between tasks; see the [AITM](aitm.md) model |
| LISTWISE_RANK_LOSS | Listwise ranking loss |
| LISTWISE_DISTILL_LOSS | Loss for distilling the ranking of a given list; quite similar to the listwise rank loss |
| ZILN_LOSS | Loss function for LTV prediction tasks (num_class must be set to 3) |

- ZILN_LOSS: when this loss is used, the model provides 3 optional outputs (in multi-task models, the output names carry a task-specific suffix); see the sketch after this list
  - probs: the predicted conversion probability
  - y: the predicted LTV value
  - logits: a tensor of shape `[batch_size, 3]`; the first column corresponds to `probs` (before the sigmoid), and the second and third columns are the location and scale parameters of the learned LogNormal distribution
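
A short sketch of how the three outputs relate, following `zero_inflated_lognormal_pred` from this PR (the random logits stand in for a real model head):

```python
import tensorflow as tf

logits = tf.random.normal([8, 3])          # stand-in for the [batch_size, 3] head
probs = tf.sigmoid(logits[..., :1])        # predicted conversion probability
loc = logits[..., 1:2]                     # location of the LogNormal
scale = tf.math.softplus(logits[..., 2:])  # scale of the LogNormal
y = probs * tf.exp(loc + 0.5 * tf.square(scale))  # predicted LTV: mean of the ZILN
```
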
- Notes on SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING:
  - Supports parameter configuration, upgrading it to the [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317);
  - Currently only available in the DropoutNet model; see [Deep analysis and improvements of the cold-start recommendation model DropoutNet](https://zhuanlan.zhihu.com/p/475117993).
@@ -184,3 +189,4 @@ EasyRec supports two ways of configuring loss functions: 1) use a single loss function; 2
- [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603)
- [AITM: Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising](https://arxiv.org/pdf/2105.08489.pdf)
- [Pairwise Ranking Distillation for Deep Face Recognition](https://ceur-ws.org/Vol-2744/paper30.pdf)
- [A DEEP PROBABILISTIC MODEL FOR CUSTOMER LIFETIME VALUE PREDICTION](https://arxiv.org/pdf/1912.07753)
8 changes: 8 additions & 0 deletions easy_rec/python/builders/loss_builder.py
@@ -2,6 +2,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import numpy as np
import tensorflow as tf

from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
@@ -14,6 +15,8 @@
from easy_rec.python.loss.pairwise_loss import pairwise_loss
from easy_rec.python.protos.loss_pb2 import LossType

from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA

from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA

if tf.__version__ >= '2.0':
@@ -46,6 +49,11 @@ def build(loss_type,
logging.info('%s is used' % LossType.Name(loss_type))
return tf.losses.mean_squared_error(
labels=label, predictions=pred, weights=loss_weight, **kwargs)
elif loss_type == LossType.ZILN_LOSS:
loss = zero_inflated_lognormal_loss(label, pred)
if np.isscalar(loss_weight) and loss_weight != 1.0:
return loss * loss_weight
return loss
elif loss_type == LossType.JRC_LOSS:
session = kwargs.get('session_ids', None)
if loss_param is None:
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
@@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
3 changes: 2 additions & 1 deletion easy_rec/python/loss/jrc_loss.py
@@ -1,6 +1,7 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import numpy as np
import tensorflow as tf

@@ -122,6 +123,6 @@ def jrc_loss(labels,
else:
raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' %
loss_weight_strategy)
if np.isscalar(sample_weights):
if np.isscalar(sample_weights) and sample_weights != 1.0:
return loss * sample_weights
return loss
76 changes: 76 additions & 0 deletions easy_rec/python/loss/zero_inflated_lognormal.py
@@ -0,0 +1,76 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Zero-inflated lognormal loss for lifetime value prediction."""
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

if tf.__version__ >= '2.0':
tf = tf.compat.v1


def zero_inflated_lognormal_pred(logits):
"""Calculates predicted mean of zero inflated lognormal logits.

Arguments:
logits: [batch_size, 3] tensor of logits.

Returns:
positive_probs: [batch_size, 1] tensor of positive probability.
preds: [batch_size, 1] tensor of predicted mean.
"""
logits = tf.convert_to_tensor(logits, dtype=tf.float32)
positive_probs = tf.keras.backend.sigmoid(logits[..., :1])
loc = logits[..., 1:2]
scale = tf.keras.backend.softplus(logits[..., 2:])
preds = (
positive_probs *
tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))
return positive_probs, preds


def zero_inflated_lognormal_loss(labels, logits, name=''):
"""Computes the zero inflated lognormal loss.

Usage with tf.keras API:

```python
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=zero_inflated_lognormal)
```

Arguments:
labels: True targets, tensor of shape [batch_size, 1].
logits: Logits of output layer, tensor of shape [batch_size, 3].
name: the name of loss

Returns:
Zero inflated lognormal loss value.
"""
loss_name = name if name else 'ziln_loss'
labels = tf.cast(labels, dtype=tf.float32)
if labels.shape.ndims == 1:
labels = tf.expand_dims(labels, 1) # [B, 1]
positive = tf.cast(labels > 0, tf.float32)

logits = tf.convert_to_tensor(logits, dtype=tf.float32)
logits.shape.assert_is_compatible_with(
tf.TensorShape(labels.shape[:-1].as_list() + [3]))

positive_logits = logits[..., :1]
classification_loss = tf.keras.backend.binary_crossentropy(
positive, positive_logits, from_logits=True)
classification_loss = tf.keras.backend.mean(classification_loss)
tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss)

loc = logits[..., 1:2]
scale = tf.math.maximum(
tf.keras.backend.softplus(logits[..., 2:]),
tf.math.sqrt(tf.keras.backend.epsilon()))
safe_labels = positive * labels + (
1 - positive) * tf.keras.backend.ones_like(labels)
regression_loss = -tf.keras.backend.mean(
positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels))
tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss)
return classification_loss + regression_loss
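
A hedged usage sketch of the two functions above; the label values are made up, and the graph-mode session matches the file's `tf.compat.v1` style (the v1 summary ops inside the loss do not run eagerly):

```python
import tensorflow as tf

from easy_rec.python.loss.zero_inflated_lognormal import (
    zero_inflated_lognormal_loss, zero_inflated_lognormal_pred)

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

labels = tf.constant([[0.0], [12.5], [3.2], [0.0]])  # LTV targets; zero = no conversion
logits = tf.random_normal([4, 3])                    # stand-in for a head with num_class=3

loss = zero_inflated_lognormal_loss(labels, logits)  # scalar training loss
probs, preds = zero_inflated_lognormal_pred(logits)  # [4, 1] each: p(convert), E[LTV]

with tf.Session() as sess:
  print(sess.run([loss, probs, preds]))
```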