add ZILN loss for ltv prediction task (#498)
* add ZILN loss for ltv prediction task & add documents
1 parent 5996b18 · commit f00d8a8
Showing 19 changed files with 543 additions and 27 deletions.
@@ -0,0 +1,36 @@
# Custom Auxiliary Loss Function Components

Multiple auxiliary loss functions can be added as follows.

Add a new loss function in `easy_rec/python/layers/keras/auxiliary_loss.py`.
If the computation logic is complex, implement it in a separate Python file, then import and use it in `auxiliary_loss.py`.

Note: the `loss_type` parameter used to identify the loss function must be globally unique.
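For illustration, here is a minimal sketch of what such a function might look like. The name `my_custom_loss`, its signature, and the dispatch snippet are hypothetical; follow the conventions of the existing losses in `auxiliary_loss.py`:

```python
import tensorflow as tf


def my_custom_loss(pred, logit):
  """Hypothetical auxiliary loss: squared gap between a prediction
  and the sigmoid of a logit."""
  return tf.reduce_mean(tf.square(pred - tf.sigmoid(logit)))
```

Inside the `AuxiliaryLoss` layer, the new function would then be invoked for its `loss_type` (sketch only; the actual dispatch mechanism is whatever `auxiliary_loss.py` already uses):

```python
if loss_type == 'my_custom_loss':
  pred, logit = inputs  # a list, because merge_inputs_into_list is true below
  return my_custom_loss(pred, logit)
```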
## Configuration
```protobuf
blocks {
  name: 'custom_loss'
  inputs {
    block_name: 'pred'
  }
  inputs {
    block_name: 'logit'
  }
  merge_inputs_into_list: true
  keras_layer {
    class_name: 'AuxiliaryLoss'
    st_params {
      fields {
        key: "loss_type"
        value { string_value: "my_custom_loss" }
      }
    }
  }
}
```
Custom parameters can be appended under `st_params`.

Remember to configure the list of output blocks with `concat_blocks` or `output_blocks`, excluding the current `custom_loss` node; see the sketch below.
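For example, assuming the model's real output block is named `pred` (a hypothetical name), the backbone output could be configured as:

```protobuf
# Only the real output blocks are routed downstream; 'custom_loss'
# contributes to training solely through its auxiliary loss.
concat_blocks: 'pred'
```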
@@ -0,0 +1,76 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Zero-inflated lognormal loss for lifetime value prediction."""
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def zero_inflated_lognormal_pred(logits):
  """Calculates predicted mean of zero inflated lognormal logits.

  Arguments:
    logits: [batch_size, 3] tensor of logits.

  Returns:
    positive_probs: [batch_size, 1] tensor of positive probability.
    preds: [batch_size, 1] tensor of predicted mean.
  """
  logits = tf.convert_to_tensor(logits, dtype=tf.float32)
  positive_probs = tf.keras.backend.sigmoid(logits[..., :1])
  loc = logits[..., 1:2]
  scale = tf.keras.backend.softplus(logits[..., 2:])
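  # Mean of LogNormal(loc, scale) is exp(loc + scale^2 / 2); weighting by the
  # positive-class probability gives the expected value of the zero-inflated
  # mixture.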
  preds = (
      positive_probs *
      tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))
  return positive_probs, preds


def zero_inflated_lognormal_loss(labels, logits, name=''):
  """Computes the zero inflated lognormal loss.

  Usage with tf.keras API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=zero_inflated_lognormal_loss)
  ```

  Arguments:
    labels: True targets, tensor of shape [batch_size, 1].
    logits: Logits of output layer, tensor of shape [batch_size, 3].
    name: the name of the loss.

  Returns:
    Zero inflated lognormal loss value.
  """
  loss_name = name if name else 'ziln_loss'
  labels = tf.cast(labels, dtype=tf.float32)
  if labels.shape.ndims == 1:
    labels = tf.expand_dims(labels, 1)  # [B, 1]
  positive = tf.cast(labels > 0, tf.float32)

  logits = tf.convert_to_tensor(logits, dtype=tf.float32)
  logits.shape.assert_is_compatible_with(
      tf.TensorShape(labels.shape[:-1].as_list() + [3]))

  positive_logits = logits[..., :1]
  classification_loss = tf.keras.backend.binary_crossentropy(
      positive, positive_logits, from_logits=True)
  classification_loss = tf.keras.backend.mean(classification_loss)
  tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss)

  loc = logits[..., 1:2]
  scale = tf.math.maximum(
      tf.keras.backend.softplus(logits[..., 2:]),
      tf.math.sqrt(tf.keras.backend.epsilon()))
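  # Replace non-positive labels with 1 so log_prob stays finite; the
  # `positive` mask removes these placeholder terms from the mean below.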
  safe_labels = positive * labels + (
      1 - positive) * tf.keras.backend.ones_like(labels)
  regression_loss = -tf.keras.backend.mean(
      positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels))
  tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss)
  return classification_loss + regression_loss
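As a quick sanity check (not part of the commit), the loss can be exercised on toy data. This is a hedged sketch: it assumes the two functions above are in scope, and it runs in TF1-style graph mode, since the module aliases `tf` to `tf.compat.v1` under TF2:

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # needed when running under TF2

# Toy batch: two zero spenders and two payers (hypothetical values).
labels = np.array([[0.0], [12.5], [0.0], [3.2]], dtype=np.float32)
logits = np.random.randn(4, 3).astype(np.float32)  # [p, loc, scale] logits

loss = zero_inflated_lognormal_loss(labels, logits, name='ltv')
probs, preds = zero_inflated_lognormal_pred(logits)

with tf.compat.v1.Session() as sess:
  print(sess.run([loss, probs, preds]))
```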