Skip to content

Commit

Permalink
add contrastive learning CL4SRec
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Sep 21, 2023
1 parent f8ccdd9 commit fb36269
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
5 changes: 4 additions & 1 deletion docs/source/component/backbone.md
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,8 @@ MovieLens-1M数据集效果:
| MultiTower | 1 | 0.8814 |
| ContrastiveLearning | 1 | 0.8728 |

一个更复杂一点的对比学习模型案例:[CL4SRec](../models/cl4srec.md)

## 案例8:多目标模型 MMoE

多目标模型的model_class一般配置为"MultiTaskModel",并且需要在`model_params`里配置多个目标对应的Tower。`model_name`为任意自定义字符串,仅有注释作用。
Expand Down Expand Up @@ -962,6 +964,7 @@ MovieLens-1M数据集效果:

- DIN模型配置文件:[DIN_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/din_backbone_on_taobao.config)
- BST模型配置文件:[BST_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/bst_backbone_on_taobao.config)
- CL4SRec模型:[CL4SRec](../models/cl4srec.md)

其他模型:

Expand Down Expand Up @@ -1006,7 +1009,7 @@ MovieLens-1M数据集效果:
| ---------- | ---------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------ |
| DIN | target attention | DIN模型的组件 | [DIN_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/din_backbone_on_taobao.config) |
| BST | transformer | BST模型的组件 | [BST_backbone.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/bst_backbone_on_taobao.config) |
| SeqAugment | 序列数据增强 | crop, mask, reorder | [CL4SRec](../models/cl4srec.html) |
| SeqAugment | 序列数据增强 | crop, mask, reorder | [CL4SRec](../models/cl4srec.html#id2) |

## 5. 多目标学习组件

Expand Down
2 changes: 2 additions & 0 deletions docs/source/models/cl4srec.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ model_config: {
- backbone: 通过组件化的方式搭建的主干网络,[参考文档](../component/backbone.md)
- blocks: 由多个`组件块`组成的一个有向无环图(DAG),框架负责按照DAG的拓扑排序执行个`组件块`关联的代码逻辑,构建TF Graph的一个子图
- name/inputs: 每个`block`有一个唯一的名字(name),并且有一个或多个输入(inputs)和输出
- package: package可以打包一组block,构成一个可被复用的子网络,即被打包的子网络以共享参数的方式在同一个模型中调用多次
- use_package_input: 当`package`的输入是动态的时,设置该输入占位符,表示当前`block`的输入由调用`package`时指定
- keras_layer: 加载由`class_name`指定的自定义或系统内置的keras layer,执行一段代码逻辑;[参考文档](../component/backbone.md#keraslayer)
- SeqAugment: 序列数据增强的组件,参数详见[参考文档](../component/component.md#id5)
- AuxiliaryLoss: 计算辅助任务损失函数的组件,参数详见[参考文档](../component/component.md#id7)
Expand Down
6 changes: 3 additions & 3 deletions easy_rec/python/layers/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def __init__(self, config, features, input_layer, l2_reg=None):
fn = EnhancedInputLayer(self._input_layer, self._features, iname)
self._name_to_layer[iname] = fn
elif Package.has_backbone_block(iname):
backbone = Package.__packages['backbone']
backbone._dag.add_node_if_not_exists(self._config.name)
backbone._dag.add_edge(iname, self._config.name)
num_pkg_input += 1
else:
raise KeyError(
Expand Down Expand Up @@ -180,9 +183,6 @@ def has_block(self, name):
def block_outputs(self, name):
return self._block_outputs.get(name, None)

# def add_edge(self, src, dest):
# self._dag.add_edge(src, dest)

def block_input(self, config, block_outputs, training=None):
inputs = []
for input_node in config.inputs:
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/layers/keras/data_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def item_crop(aug_data, length, crop_rate):
zeros = tf.zeros_like(aug_data)
x = aug_data[crop_begin:crop_begin + num_left]
y = zeros[:max_length - num_left]
cropped = tf.concat([x, y], axis=0),
cropped = tf.concat([x, y], axis=0)
cropped_item_seq = tf.where(
crop_begin + num_left < max_length, cropped,
tf.concat([aug_data[crop_begin:], zeros[:crop_begin]], axis=0))
Expand Down

0 comments on commit fb36269

Please sign in to comment.