add contrastive learning CL4SRec
yangxudong committed Sep 21, 2023
1 parent 8f5c05d commit f8ccdd9
Showing 8 changed files with 73 additions and 64 deletions.
10 changes: 5 additions & 5 deletions docs/source/component/backbone.md
@@ -1002,11 +1002,11 @@ Results on the MovieLens-1M dataset:

## 4. Sequence feature encoding components

| Class | Function | Description | Example |
| --- | ---------------- | -------- | ------------------------------------------------------------------------------------------------------------------------ |
| DIN | target attention | Component of the DIN model | [DIN_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/din_backbone_on_taobao.config) |
| BST | transformer | Component of the BST model | [BST_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/bst_backbone_on_taobao.config) |
| SeqAugment | Sequence data augmentation | crop, mask, reorder | [CL4SRec](../models/cl4srec.html) |
| Class      | Function                   | Description                | Example                                                                                                                  |
| ---------- | -------------------------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------ |
| DIN        | target attention           | Component of the DIN model | [DIN_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/din_backbone_on_taobao.config) |
| BST        | transformer                | Component of the BST model | [BST_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/bst_backbone_on_taobao.config) |
| SeqAugment | Sequence data augmentation | crop, mask, reorder        | [CL4SRec](../models/cl4srec.html)                                                                                        |

## 5. Multi-objective learning components

56 changes: 28 additions & 28 deletions docs/source/component/component.md
@@ -100,36 +100,36 @@

- SeqAugment

| Parameter | Type | Default | Description |
| ------------ | ---- | ---- | ------------- |
| mask_rate | float | 0.6 | Fraction of tokens that are masked |
| crop_rate | float | 0.2 | Fraction of tokens kept after cropping |
| reorder_rate | float | 0.6 | Relative length of the shuffled subsequence |
| Parameter    | Type  | Default | Description                                 |
| ------------ | ----- | ------- | ------------------------------------------- |
| mask_rate    | float | 0.6     | Fraction of tokens that are masked          |
| crop_rate    | float | 0.2     | Fraction of tokens kept after cropping      |
| reorder_rate | float | 0.6     | Relative length of the shuffled subsequence |
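
The three rates above can be illustrated on a plain Python list. This is a minimal sketch, not the `SeqAugment` implementation: it works on item ids rather than embedding tensors, does not pad the result back to a fixed length, and the helper names (`crop`, `mask`, `reorder`, `augment`) and the `[MASK]` placeholder are assumptions made for illustration.

```python
import math
import random


def crop(seq, crop_rate, rng):
  """Keep a random contiguous window of floor(len(seq) * crop_rate) items."""
  num_left = math.floor(len(seq) * crop_rate)
  begin = rng.randint(0, len(seq) - num_left)  # randint is inclusive on both ends
  return seq[begin:begin + num_left]


def mask(seq, mask_rate, rng, mask_token='[MASK]'):
  """Replace floor(len(seq) * mask_rate) randomly chosen items with mask_token."""
  num_mask = math.floor(len(seq) * mask_rate)
  positions = set(rng.sample(range(len(seq)), num_mask))
  return [mask_token if i in positions else item for i, item in enumerate(seq)]


def reorder(seq, reorder_rate, rng):
  """Shuffle a random contiguous window of floor(len(seq) * reorder_rate) items."""
  num_reorder = math.floor(len(seq) * reorder_rate)
  begin = rng.randint(0, len(seq) - num_reorder)
  window = seq[begin:begin + num_reorder]
  rng.shuffle(window)
  return seq[:begin] + window + seq[begin + num_reorder:]


def augment(seq, rng, mask_rate=0.6, crop_rate=0.2, reorder_rate=0.6):
  """Apply one augmentation chosen uniformly at random."""
  method = rng.choice(['crop', 'mask', 'reorder'])
  if method == 'crop':
    return crop(seq, crop_rate, rng)
  if method == 'mask':
    return mask(seq, mask_rate, rng)
  return reorder(seq, reorder_rate, rng)


rng = random.Random(0)
history = list(range(10))
# Two independently augmented views of the same behaviour sequence.
view_a = augment(history, rng)
view_b = augment(history, rng)
```

Calling `augment` twice on the same sequence yields two views that a contrastive loss can treat as a positive pair, which is the usual way such augmentations are consumed.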

- DIN

| Parameter | Type | Default | Description |
| ------------ | ---- | ---- | ------------- |
| attention_dnn | MLP | | MLP of the attention unit |
| need_target_feature | bool | true | Whether to return the target item embedding |
| Parameter            | Type   | Default | Description                                 |
| -------------------- | ------ | ------- | ------------------------------------------- |
| attention_dnn        | MLP    |         | MLP of the attention unit                   |
| need_target_feature  | bool   | true    | Whether to return the target item embedding |
| attention_normalizer | string | softmax | softmax or sigmoid                          |

- BST

| Parameter | Type | Default | Description |
| ------------ | ---- | ---- | ------------- |
| hidden_size | int | | Number of units in the transformer encoder layer |
| num_hidden_layers | int | | Number of transformer layers |
| num_attention_heads | int | | Number of transformer attention heads |
| intermediate_size | int | | Number of units in the transformer intermediate layer |
| hidden_act | string | gelu | Hidden activation function |
| hidden_dropout_prob | float | 0.1 | Hidden dropout rate |
| attention_probs_dropout_prob | float | 0.1 | Dropout rate of the attention layer|
| max_position_embeddings | int | 512 | Maximum sequence length |
| use_position_embeddings | bool | true | Whether to use position embeddings |
| initializer_range | float | 0.2 | Range for initializing weight parameters |
| output_all_token_embeddings | bool | true | Whether to output all token embeddings |
| target_item_position | string | head | Where to insert the target item; options: head, tail, ignore |
| Parameter                    | Type   | Default | Description                                                  |
| ---------------------------- | ------ | ------- | ------------------------------------------------------------ |
| hidden_size                  | int    |         | Number of units in the transformer encoder layer             |
| num_hidden_layers            | int    |         | Number of transformer layers                                 |
| num_attention_heads          | int    |         | Number of transformer attention heads                        |
| intermediate_size            | int    |         | Number of units in the transformer intermediate layer        |
| hidden_act                   | string | gelu    | Hidden activation function                                   |
| hidden_dropout_prob          | float  | 0.1     | Hidden dropout rate                                          |
| attention_probs_dropout_prob | float  | 0.1     | Dropout rate of the attention layer                          |
| max_position_embeddings      | int    | 512     | Maximum sequence length                                      |
| use_position_embeddings      | bool   | true    | Whether to use position embeddings                           |
| initializer_range            | float  | 0.2     | Range for initializing weight parameters                     |
| output_all_token_embeddings  | bool   | true    | Whether to output all token embeddings                       |
| target_item_position         | string | head    | Where to insert the target item; options: head, tail, ignore |

## 5. Multi-task learning components

@@ -145,9 +145,9 @@

- AuxiliaryLoss

| Parameter | Type | Default | Description |
| ----------- | ------ | --- | --------------------------- |
| Parameter   | Type   | Default | Description                                                |
| ----------- | ------ | ------- | ---------------------------------------------------------- |
| loss_type   | string |         | Loss function type, including: l2_loss, nce_loss, info_nce |
| loss_weight | float | 1.0 | Loss weight |
| temperature | float | 0.1 | Parameter of the info_nce loss |
| Others | | | Determined by loss_type |
| loss_weight | float  | 1.0     | Loss weight                                                |
| temperature | float  | 0.1     | Parameter of the info_nce loss                             |
| Others      |        |         | Determined by loss_type                                    |
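
To make the role of `temperature` concrete, the following is a minimal pure-Python sketch of an InfoNCE loss with in-batch negatives. It illustrates the general technique and is not EasyRec's `info_nce` implementation: the function name `info_nce`, the use of cosine similarity, and the choice of the other rows in the batch as negatives are assumptions made for illustration.

```python
import math


def info_nce(anchors, positives, temperature=0.1):
  """Mean InfoNCE loss; row i of `positives` is the positive for row i of `anchors`."""

  def normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0  # avoid division by zero
    return [x / norm for x in vec]

  a = [normalize(v) for v in anchors]
  p = [normalize(v) for v in positives]
  total = 0.0
  for i, anchor in enumerate(a):
    # Cosine similarity to every candidate, scaled by the temperature.
    logits = [sum(x * y for x, y in zip(anchor, cand)) / temperature for cand in p]
    # Numerically stable log-sum-exp over all candidates.
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(l - top) for l in logits))
    # Cross-entropy with the matching row as the correct class.
    total += log_sum_exp - logits[i]
  return total / len(a)


aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

A smaller `temperature` sharpens the softmax over candidates, so mismatched pairs are penalised more heavily; `aligned` is therefore much smaller than `swapped` here.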
6 changes: 4 additions & 2 deletions easy_rec/python/layers/backbone.py
@@ -118,7 +118,8 @@ def __init__(self, config, features, input_layer, l2_reg=None):
assert iname != name, 'input name can not equal to block name:' + iname
self._dag.add_edge(iname, name)
elif iname not in input_feature_groups:
if input_type == 'feature_group_name' and input_layer.has_group(iname):
is_fea_group = input_type == 'feature_group_name'
if is_fea_group and input_layer.has_group(iname):
logging.info('adding an input_layer block: ' + iname)
new_block = backbone_pb2.Block()
new_block.name = iname
@@ -202,7 +203,8 @@ def block_input(self, config, block_outputs, training=None):
pkg_input = block_outputs[pkg_input_name]
else:
if pkg_input_name not in Package.__packages:
raise KeyError('package name `%s` does not exist' % pkg_input_name)
raise KeyError('package name `%s` does not exist' %
pkg_input_name)
inner_package = Package.__packages[pkg_input_name]
pkg_input = inner_package(training)
if input_node.HasField('package_input_fn'):
3 changes: 2 additions & 1 deletion easy_rec/python/layers/common_layers.py
@@ -125,7 +125,8 @@ def build(self, config, training):
target_features = tf.concat(target_features, axis=-1)
else:
target_features = None
assert len(seq_features) > 0, '[%s] sequence feature is empty' % self.name
assert len(
seq_features) > 0, '[%s] sequence feature is empty' % self.name
seq_features = tf.concat(seq_features, axis=-1)
self.inputs = seq_features, seq_len, target_features
self.reset(config, training)
2 changes: 1 addition & 1 deletion easy_rec/python/layers/keras/__init__.py
@@ -3,6 +3,7 @@
from .blocks import Gate
from .blocks import Highway
from .bst import BST
from .data_augment import SeqAugment
from .din import DIN
from .fibinet import BiLinear
from .fibinet import FiBiNet
@@ -15,4 +16,3 @@
from .multi_task import MMoE
from .numerical_embedding import AutoDisEmbedding
from .numerical_embedding import PeriodicEmbedding
from .data_augment import SeqAugment
6 changes: 4 additions & 2 deletions easy_rec/python/layers/keras/bst.py
@@ -66,7 +66,8 @@ def call(self, inputs, training=None, **kwargs):
target = inputs[2] if len(inputs) > 2 else None
max_position = self.config.max_position_embeddings
# max_seq_len: the max sequence length in current mini-batch, all sequences are padded to this length
batch_size, cur_batch_max_seq_len, seq_embed_size = get_shape_list(seq_input, 3)
batch_size, cur_batch_max_seq_len, seq_embed_size = get_shape_list(
seq_input, 3)
valid_len = tf.assert_less_equal(
cur_batch_max_seq_len,
max_position,
@@ -90,7 +91,8 @@ def call(self, inputs, training=None, **kwargs):
name=self.name + '/seq_project',
reuse=self.reuse)

if target is not None and self.config.target_item_position in ('head', 'tail'):
if target is not None and self.config.target_item_position in ('head',
'tail'):
target_size = target.shape.as_list()[-1]
assert seq_embed_size == target_size, 'the embedding size of sequence and target item is not equal' \
' in feature group:' + self.name
52 changes: 29 additions & 23 deletions easy_rec/python/layers/keras/data_augment.py
@@ -2,6 +2,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from tensorflow.python.keras.layers import Layer

from easy_rec.python.utils.shape_utils import get_shape_list

if tf.__version__ >= '2.0':
@@ -14,7 +15,7 @@ def item_mask(aug_data, length, mask_emb, mask_rate):
max_len = tf.shape(aug_data)[0]
seq_mask = tf.sequence_mask(num_mask, length)
seq_mask = tf.random.shuffle(seq_mask)
padding = tf.sequence_mask(0, max_len-length)
padding = tf.sequence_mask(0, max_len - length)
seq_mask = tf.concat([seq_mask, padding], axis=0)

mask_emb = tf.tile(mask_emb, [max_len, 1])
@@ -25,22 +26,21 @@ def item_mask(aug_data, length, mask_emb, mask_rate):

def item_crop(aug_data, length, crop_rate):
length1 = tf.cast(length, dtype=tf.float32)
max_len, embedding_size = get_shape_list(aug_data)
max_len, _ = get_shape_list(aug_data)
max_length = tf.cast(max_len, dtype=tf.int32)

num_left = tf.cast(tf.math.floor(length1 * crop_rate), dtype=tf.int32)
crop_begin = tf.random.uniform([],
minval=0,
maxval=length - num_left,
dtype=tf.int32)
cropped_item_seq = tf.zeros([max_len, embedding_size])
zeros = tf.zeros_like(aug_data)
x = aug_data[crop_begin:crop_begin + num_left]
y = zeros[:max_length - num_left]
cropped = tf.concat([x, y], axis=0)
cropped_item_seq = tf.where(
crop_begin + num_left < max_length,
tf.concat([
aug_data[crop_begin:crop_begin + num_left],
cropped_item_seq[:max_length - num_left]
], axis=0),
tf.concat([aug_data[crop_begin:], cropped_item_seq[:crop_begin]], axis=0))
crop_begin + num_left < max_length, cropped,
tf.concat([aug_data[crop_begin:], zeros[:crop_begin]], axis=0))
return cropped_item_seq, num_left


@@ -58,41 +58,46 @@ def item_reorder(aug_data, length, reorder_rate):
right = tf.slice(x, [reorder_begin + num_reorder], [-1])
reordered_item_index = tf.concat([left, shuffle_index, right], axis=0)
reordered_item_seq = tf.scatter_nd(
tf.expand_dims(reordered_item_index, axis=1), aug_data,
tf.shape(aug_data))
tf.expand_dims(reordered_item_index, axis=1), aug_data,
tf.shape(aug_data))
return reordered_item_seq, length


def augment_fn(x, aug_param, mask):
seq, length = x

crop_fn = lambda: item_crop(seq, length, aug_param.crop_rate)
mask_fn = lambda: item_mask(seq, length, mask, aug_param.mask_rate)
reorder_fn = lambda: item_reorder(seq, length, aug_param.reorder_rate)
def crop_fn():
return item_crop(seq, length, aug_param.crop_rate)

def mask_fn():
return item_mask(seq, length, mask, aug_param.mask_rate)

def reorder_fn():
return item_reorder(seq, length, aug_param.reorder_rate)

methods = tf.range(3, dtype=tf.int32)
method = tf.random.shuffle(methods)[0]

aug_seq, aug_len = tf.cond(
tf.equal(method, 0), crop_fn,
lambda: tf.cond(tf.equal(method, 1), mask_fn, reorder_fn))
tf.equal(method, 0), crop_fn,
lambda: tf.cond(tf.equal(method, 1), mask_fn, reorder_fn))

return aug_seq, aug_len


def sequence_augment(seq_input, seq_len, mask, aug_param):
lengths = tf.cast(seq_len, dtype=tf.int32)
aug_seq, aug_len = tf.map_fn(
lambda elems: augment_fn(elems, aug_param, mask),
elems=(seq_input, lengths),
dtype=(tf.float32, tf.int32))
lambda elems: augment_fn(elems, aug_param, mask),
elems=(seq_input, lengths),
dtype=(tf.float32, tf.int32))

aug_seq = tf.reshape(aug_seq, tf.shape(seq_input))
return aug_seq, aug_len


class SeqAugment(Layer):
"""Do data augmentation for input sequence embedding"""
"""Do data augmentation for input sequence embedding."""

def __init__(self, params, name='seq_aug', reuse=None, **kwargs):
super(SeqAugment, self).__init__(name, **kwargs)
@@ -105,8 +110,9 @@ def call(self, inputs, training=None, **kwargs):

embedding_size = int(seq_input.shape[-1])
with tf.variable_scope(self.name, reuse=self.reuse):
mask_emb = tf.get_variable('mask', [1, embedding_size], dtype=tf.float32, trainable=True)
mask_emb = tf.get_variable(
'mask', [1, embedding_size], dtype=tf.float32, trainable=True)

aug_seq, aug_len = sequence_augment(
seq_input, seq_len, mask_emb, self.seq_aug_params)
aug_seq, aug_len = sequence_augment(seq_input, seq_len, mask_emb,
self.seq_aug_params)
return aug_seq, aug_len
2 changes: 0 additions & 2 deletions easy_rec/python/protos/seq_encoder.proto
@@ -47,6 +47,4 @@ message SequenceAugment {
required float crop_rate = 2 [default = 0.2];
// Percentage length of reorder original sequence
required float reorder_rate = 3 [default = 0.6];
// The embedding size of each sequence elements
required uint32 embedding_size = 4;
}
