refactor & add some negative sampling strategies (#81)
Support different negative sampling strategies, including `inbatch`, `uniform`, `frequency`, and `adaptive` (see the usage sketch below).
1. add `deepmatch.utils.NegativeSampler`
2. remove `deepmatch.layers.core.NegativeSampler`
3. add `temperature`, `sampler_config`, and `loss_type` arguments for models
4. add some Google Colab scripts
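
For orientation, here is a minimal usage sketch of the new API. It is not part of the diff below: the `NegativeSampler` field names mirror the `sampler_config` keys read in `deepmatch/layers/core.py`, but the exact constructor and model signatures are assumptions based on this commit.

```python
# Hedged sketch: wiring the new NegativeSampler into a model.
# Feature columns and hyperparameters are illustrative placeholders.
import numpy as np
from deepctr.feature_column import SparseFeat, VarLenSparseFeat
from deepmatch.models import YoutubeDNN
from deepmatch.utils import NegativeSampler, sampledsoftmaxloss

n_items, embedding_dim = 1000, 16
user_feature_columns = [
    SparseFeat('user_id', 100, embedding_dim),
    VarLenSparseFeat(SparseFeat('hist_movie_id', n_items, embedding_dim,
                                embedding_name='movie_id'), 50, 'mean', 'hist_len'),
]
item_feature_columns = [SparseFeat('movie_id', n_items, embedding_dim)]

# item_count[i] = training-set frequency of item i; the 'frequency' sampler
# draws negatives from this distribution, and 'inbatch' uses it for the
# log-Q correction.
item_count = np.random.randint(1, 100, size=n_items)

sampler_config = NegativeSampler('frequency', num_sampled=255,
                                 item_name='movie_id', item_count=item_count,
                                 distortion=0.1)
model = YoutubeDNN(user_feature_columns, item_feature_columns,
                   user_dnn_hidden_units=(64, embedding_dim),
                   temperature=0.02, sampler_config=sampler_config)
model.compile(optimizer='adam', loss=sampledsoftmaxloss)
```

Swapping `'frequency'` for `'uniform'`, `'adaptive'`, or `'inbatch'` selects the other strategies; `distortion` flattens the frequency distribution the same way it does in `tf.nn.fixed_unigram_candidate_sampler`.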
shenweichen authored Jul 3, 2022
1 parent be6c028 commit 5dab795
Showing 37 changed files with 1,953 additions and 506 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.md
@@ -19,8 +19,8 @@ Steps to reproduce the behavior:
 
 **Operating environment(运行环境):**
 - python version [e.g. 3.6, 3.7, 3.8]
-- tensorflow version [e.g. 1.4.0, 1.14.0, 2.5.0]
-- deepmatch version [e.g. 0.2.1,]
+- tensorflow version [e.g. 1.9.0, 1.14.0, 2.5.0]
+- deepmatch version [e.g. 0.3.0,]
 
 **Additional context**
 Add any other context about the problem here.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/question.md
@@ -16,5 +16,5 @@ Add any other context about the problem here.
 
 **Operating environment(运行环境):**
 - python version [e.g. 3.6, 3.7, 3.8]
-- tensorflow version [e.g. 1.4.0, 1.14.0, 2.5.0]
-- deepmatch version [e.g. 0.2.1,]
+- tensorflow version [e.g. 1.9.0, 1.14.0, 2.5.0]
+- deepmatch version [e.g. 0.3.0,]
22 changes: 21 additions & 1 deletion .github/workflows/ci.yml
@@ -18,15 +18,35 @@ jobs:
     strategy:
       matrix:
         python-version: [3.6,3.7,3.8]
-        tf-version: [1.4.0,1.14.0,2.5.0]
+        tf-version: [1.9.0,1.14.0,2.5.0]
 
         exclude:
           - python-version: 3.7
            tf-version: 1.4.0
+          - python-version: 3.7
+            tf-version: 1.9.0
+          - python-version: 3.7
+            tf-version: 1.10.0
+          - python-version: 3.7
+            tf-version: 1.11.0
+          - python-version: 3.7
+            tf-version: 1.12.0
+          - python-version: 3.7
+            tf-version: 1.13.0
           - python-version: 3.7
            tf-version: 1.15.0
           - python-version: 3.8
            tf-version: 1.4.0
+          - python-version: 3.8
+            tf-version: 1.9.0
+          - python-version: 3.8
+            tf-version: 1.10.0
+          - python-version: 3.8
+            tf-version: 1.11.0
+          - python-version: 3.8
+            tf-version: 1.12.0
+          - python-version: 3.8
+            tf-version: 1.13.0
           - python-version: 3.8
            tf-version: 1.14.0
           - python-version: 3.8
41 changes: 12 additions & 29 deletions README.md
@@ -1,7 +1,8 @@
 # DeepMatch
 
 [![Python Versions](https://img.shields.io/pypi/pyversions/deepmatch.svg)](https://pypi.org/project/deepmatch)
-[![TensorFlow Versions](https://img.shields.io/badge/TensorFlow-1.4+/2.0+-blue.svg)](https://pypi.org/project/deepmatch)
+[![TensorFlow Versions](https://img.shields.io/badge/TensorFlow-1.9+/2.0+-blue.svg)](https://pypi.org/project/deepmatch)
 [![Downloads](https://pepy.tech/badge/deepmatch)](https://pepy.tech/project/deepmatch)
 [![PyPI Version](https://img.shields.io/pypi/v/deepmatch.svg)](https://pypi.org/project/deepmatch)
 [![GitHub Issues](https://img.shields.io/github/issues/shenweichen/deepmatch.svg
 )](https://github.com/shenweichen/deepmatch/issues)
@@ -11,7 +12,8 @@
 [![Documentation Status](https://readthedocs.org/projects/deepmatch/badge/?version=latest)](https://deepmatch.readthedocs.io/)
 ![CI status](https://github.com/shenweichen/deepmatch/workflows/CI/badge.svg)
 [![codecov](https://codecov.io/gh/shenweichen/DeepMatch/branch/master/graph/badge.svg)](https://codecov.io/gh/shenweichen/DeepMatch)
-[![Disscussion](https://img.shields.io/badge/chat-wechat-brightgreen?style=flat)](./README.md#disscussiongroup)
+[![Codacy Badge](https://app.codacy.com/project/badge/Grade/c5a2769ec35444d8958f6b58ff85029b)](https://www.codacy.com/gh/shenweichen/DeepMatch/dashboard?utm_source=github.com&utm_medium=referral&utm_content=shenweichen/DeepMatch&utm_campaign=Badge_Grade)
+[![Disscussion](https://img.shields.io/badge/chat-wechat-brightgreen?style=flat)](https://github.com/shenweichen/DeepMatch#disscussiongroup)
 [![License](https://img.shields.io/github/license/shenweichen/deepmatch.svg)](https://github.com/shenweichen/deepmatch/blob/master/LICENSE)
 
 DeepMatch is a deep matching model library for recommendations & advertising. It's easy to **train models** and to **export representation vectors** for user and item which can be used for **ANN search**.You can use any complex model with `model.fit()`and `model.predict()` .
@@ -72,31 +74,12 @@ Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.html)
 </tbody>
 </table>
 
-## DisscussionGroup & Related Projects
+## DisscussionGroup
 
+- [Github Discussions](https://github.com/shenweichen/DeepMatch/discussions)
+- Wechat Discussions
+
+|公众号:浅梦学习笔记|微信:deepctrbot|学习小组 [加入](https://t.zsxq.com/026UJEuzv) [主题集合](https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MjM5MzY4NzE3MA==&action=getalbum&album_id=1361647041096843265&scene=126#wechat_redirect)|
+|:--:|:--:|:--:|
+|[![公众号](./docs/pics/code.png)](https://github.com/shenweichen/AlgoNotes)|[![微信](./docs/pics/deepctrbot.png)](https://github.com/shenweichen/AlgoNotes)|[![学习小组](./docs/pics/planet_github.png)](https://t.zsxq.com/026UJEuzv)|
 
-<html>
-<table style="margin-left: 20px; margin-right: auto;">
-  <tr>
-    <td>
-      公众号:<b>浅梦的学习笔记</b><br><br>
-      <a href="https://github.com/shenweichen/deepmatch">
-        <img align="center" src="./docs/pics/code.png" />
-      </a>
-    </td>
-    <td>
-      微信:<b>deepctrbot</b><br><br>
-      <a href="https://github.com/shenweichen/deepmatch">
-        <img align="center" src="./docs/pics/deepctrbot.png" />
-      </a>
-    </td>
-    <td>
-      <ul>
-        <li><a href="https://github.com/shenweichen/AlgoNotes">AlgoNotes</a></li>
-        <li><a href="https://github.com/shenweichen/DeepCTR">DeepCTR</a></li>
-        <li><a href="https://github.com/shenweichen/DeepCTR-Torch">DeepCTR-Torch</a></li>
-        <li><a href="https://github.com/shenweichen/GraphEmbedding">GraphEmbedding</a></li>
-      </ul>
-    </td>
-  </tr>
-</table>
-</html>
2 changes: 1 addition & 1 deletion deepmatch/__init__.py
@@ -1,4 +1,4 @@
 from .utils import check_version
 
-__version__ = '0.2.1'
+__version__ = '0.3.0'
 check_version(__version__)
6 changes: 3 additions & 3 deletions deepmatch/layers/__init__.py
@@ -1,20 +1,20 @@
 from deepctr.layers import custom_objects
 from deepctr.layers.utils import reduce_sum
 
-from .core import PoolingLayer, Similarity, LabelAwareAttention, CapsuleLayer, SampledSoftmaxLayer, EmbeddingIndex, \
-    MaskUserEmbedding
+from .core import PoolingLayer, LabelAwareAttention, CapsuleLayer, SampledSoftmaxLayer, EmbeddingIndex, \
+    MaskUserEmbedding, InBatchSoftmaxLayer
 from .interaction import DotAttention, ConcatAttention, SoftmaxWeightedSum, AttentionSequencePoolingLayer, \
     SelfAttention, \
     SelfMultiHeadAttention, UserAttention
 from .sequence import DynamicMultiRNN
 from ..utils import sampledsoftmaxloss
 
 _custom_objects = {'PoolingLayer': PoolingLayer,
-                   'Similarity': Similarity,
                    'LabelAwareAttention': LabelAwareAttention,
                    'CapsuleLayer': CapsuleLayer,
                    'reduce_sum': reduce_sum,
                    'SampledSoftmaxLayer': SampledSoftmaxLayer,
+                   'InBatchSoftmaxLayer': InBatchSoftmaxLayer,
                    'sampledsoftmaxloss': sampledsoftmaxloss,
                    'EmbeddingIndex': EmbeddingIndex,
                    'DotAttention': DotAttention,
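
Because the new layer is registered in `custom_objects`, a saved model that uses it can be reloaded in the usual DeepMatch way. A sketch, continuing from a trained `model`; the file name is a placeholder:

```python
# Sketch: custom_objects is what lets Keras deserialize InBatchSoftmaxLayer
# and the other DeepMatch layers from a saved model file.
from deepmatch.layers import custom_objects
from tensorflow.python.keras.models import load_model, save_model

save_model(model, 'youtubednn.h5')                    # placeholder path
model = load_model('youtubednn.h5', custom_objects)   # resolves custom layers
```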
159 changes: 104 additions & 55 deletions deepmatch/layers/core.py
@@ -1,9 +1,15 @@
"""
Author:
Weichen Shen,[email protected]
"""

import numpy as np
import tensorflow as tf
from deepctr.layers.activation import activation_layer
from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax
from tensorflow.python.keras.initializers import RandomNormal, Zeros, TruncatedNormal
from tensorflow.python.keras.initializers import Zeros
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras.regularizers import l2


class PoolingLayer(Layer):
@@ -45,45 +51,103 @@ def get_config(self, ):
 
 
 class SampledSoftmaxLayer(Layer):
-    def __init__(self, num_sampled=5, **kwargs):
-        self.num_sampled = num_sampled
+    def __init__(self, sampler_config, temperature=1.0, **kwargs):
+        self.sampler_config = sampler_config
+        self.temperature = temperature
+        self.sampler = self.sampler_config['sampler']
+        self.item_count = self.sampler_config['item_count']
+
         super(SampledSoftmaxLayer, self).__init__(**kwargs)
 
     def build(self, input_shape):
-        self.size = input_shape[0][0]
-        self.zero_bias = self.add_weight(shape=[self.size],
+        self.vocabulary_size = input_shape[0][0]
+        self.zero_bias = self.add_weight(shape=[self.vocabulary_size],
                                          initializer=Zeros,
                                          dtype=tf.float32,
                                          trainable=False,
                                          name="bias")
         super(SampledSoftmaxLayer, self).build(input_shape)
 
-    def call(self, inputs_with_label_idx, training=None, **kwargs):
-        """
-        The first input should be the model as it were, and the second the
-        target (i.e., a repeat of the training data) to compute the labels
-        argument
-        """
-        embeddings, inputs, label_idx = inputs_with_label_idx
-
-        loss = tf.nn.sampled_softmax_loss(weights=embeddings,  # self.item_embedding.
-                                          biases=self.zero_bias,
-                                          labels=label_idx,
-                                          inputs=inputs,
-                                          num_sampled=self.num_sampled,
-                                          num_classes=self.size,  # self.target_song_size
-                                          )
+    def call(self, inputs_with_item_idx, training=None, **kwargs):
+        item_embeddings, user_vec, item_idx = inputs_with_item_idx
+        if item_idx.dtype != tf.int64:
+            item_idx = tf.cast(item_idx, tf.int64)
+        user_vec /= self.temperature
+        if self.sampler == "inbatch":
+            item_vec = tf.gather(item_embeddings, tf.squeeze(item_idx, axis=1))
+            logits = tf.matmul(user_vec, item_vec, transpose_b=True)
+            loss = inbatch_softmax_cross_entropy_with_logits(logits, self.item_count, item_idx)
+
+        else:
+            num_sampled = self.sampler_config['num_sampled']
+            if self.sampler == "frequency":
+                sampled_values = tf.nn.fixed_unigram_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                       self.vocabulary_size,
+                                                                       distortion=self.sampler_config['distortion'],
+                                                                       unigrams=np.maximum(self.item_count, 1).tolist(),
+                                                                       seed=None,
+                                                                       name=None)
+            elif self.sampler == "adaptive":
+                sampled_values = tf.nn.learned_unigram_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                         self.vocabulary_size, seed=None, name=None)
+            elif self.sampler == "uniform":
+                try:
+                    sampled_values = tf.nn.uniform_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                     self.vocabulary_size, seed=None, name=None)
+                except AttributeError:
+                    sampled_values = tf.random.uniform_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                         self.vocabulary_size, seed=None, name=None)
+            else:
+                raise ValueError(' `%s` sampler is not supported ' % self.sampler)
+
+            loss = tf.nn.sampled_softmax_loss(weights=item_embeddings,
+                                              biases=self.zero_bias,
+                                              labels=item_idx,
+                                              inputs=user_vec,
+                                              num_sampled=num_sampled,
+                                              num_classes=self.vocabulary_size,
+                                              sampled_values=sampled_values
+                                              )
         return tf.expand_dims(loss, axis=1)
 
     def compute_output_shape(self, input_shape):
         return (None, 1)
 
     def get_config(self, ):
-        config = {'num_sampled': self.num_sampled}
+        config = {'sampler_config': self.sampler_config, 'temperature': self.temperature}
         base_config = super(SampledSoftmaxLayer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
 
+class InBatchSoftmaxLayer(Layer):
+    def __init__(self, sampler_config, temperature=1.0, **kwargs):
+        self.sampler_config = sampler_config
+        self.temperature = temperature
+        self.item_count = self.sampler_config['item_count']
+
+        super(InBatchSoftmaxLayer, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        super(InBatchSoftmaxLayer, self).build(input_shape)
+
+    def call(self, inputs_with_item_idx, training=None, **kwargs):
+        user_vec, item_vec, item_idx = inputs_with_item_idx
+        if item_idx.dtype != tf.int64:
+            item_idx = tf.cast(item_idx, tf.int64)
+        user_vec /= self.temperature
+        logits = tf.matmul(user_vec, item_vec, transpose_b=True)
+        loss = inbatch_softmax_cross_entropy_with_logits(logits, self.item_count, item_idx)
+        return tf.expand_dims(loss, axis=1)
+
+    def compute_output_shape(self, input_shape):
+        return (None, 1)
+
+    def get_config(self, ):
+        config = {'sampler_config': self.sampler_config, 'temperature': self.temperature}
+        base_config = super(InBatchSoftmaxLayer, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
 class LabelAwareAttention(Layer):
     def __init__(self, k_max, pow_p=1, **kwargs):
         self.k_max = k_max
@@ -128,38 +192,6 @@ def get_config(self, ):
         return dict(list(base_config.items()) + list(config.items()))
 
 
-class Similarity(Layer):
-
-    def __init__(self, gamma=1, axis=-1, type='cos', **kwargs):
-        self.gamma = gamma
-        self.axis = axis
-        self.type = type
-        super(Similarity, self).__init__(**kwargs)
-
-    def build(self, input_shape):
-        # Be sure to call this somewhere!
-        super(Similarity, self).build(input_shape)
-
-    def call(self, inputs, **kwargs):
-        query, candidate = inputs
-        if self.type == "cos":
-            query_norm = tf.norm(query, axis=self.axis)
-            candidate_norm = tf.norm(candidate, axis=self.axis)
-        cosine_score = reduce_sum(tf.multiply(query, candidate), -1)
-        if self.type == "cos":
-            cosine_score = div(cosine_score, query_norm * candidate_norm + 1e-8)
-        cosine_score = tf.clip_by_value(cosine_score, -1, 1.0) * self.gamma
-        return cosine_score
-
-    def compute_output_shape(self, input_shape):
-        return (None, 1)
-
-    def get_config(self, ):
-        config = {'gamma': self.gamma, 'axis': self.axis, 'type': self.type}
-        base_config = super(Similarity, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
-
-
 class CapsuleLayer(Layer):
     def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
                  init_std=1.0, **kwargs):
@@ -245,6 +277,23 @@ def squash(inputs):
     return vec_squashed
 
 
+def inbatch_softmax_cross_entropy_with_logits(logits, item_count, item_idx):
+    Q = tf.gather(tf.constant(item_count / np.sum(item_count), 'float32'),
+                  tf.squeeze(item_idx, axis=1))
+    try:
+        logQ = tf.reshape(tf.math.log(Q), (1, -1))
+        logits -= logQ  # subtract_log_q
+        labels = tf.linalg.diag(tf.ones_like(logits[0]))
+    except AttributeError:
+        logQ = tf.reshape(tf.log(Q), (1, -1))
+        logits -= logQ  # subtract_log_q
+        labels = tf.diag(tf.ones_like(logits[0]))
+
+    loss = tf.nn.softmax_cross_entropy_with_logits(
+        labels=labels, logits=logits)
+    return loss
+
+
 class EmbeddingIndex(Layer):
 
     def __init__(self, index, **kwargs):
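
The `inbatch_softmax_cross_entropy_with_logits` helper added above applies the standard candidate-sampling bias correction: with in-batch negatives, an item appears as a negative with probability roughly proportional to its frequency, so each logit is debiased by subtracting log Q before the cross-entropy. A standalone numeric sketch of the same computation (not part of the commit; all values are toy data):

```python
import numpy as np
import tensorflow as tf

item_count = np.array([80, 15, 5], dtype=np.float32)     # toy corpus frequencies
item_idx = tf.constant([[0], [1], [2]], dtype=tf.int64)  # items in this batch
logits = tf.constant([[4.0, 1.0, 0.5],   # user x in-batch-item scores; each
                      [1.0, 3.0, 0.5],   # user's positive item sits on the
                      [0.5, 1.0, 2.0]])  # diagonal

# Q[j] = probability of drawing batch item j under the item-frequency distribution
Q = tf.gather(tf.constant(item_count / item_count.sum()), tf.squeeze(item_idx, axis=1))
corrected = logits - tf.reshape(tf.math.log(Q), (1, -1))  # rare items get the larger boost
labels = tf.linalg.diag(tf.ones_like(corrected[0]))       # identity: diagonal positives
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=corrected)
print(loss.numpy())  # one cross-entropy value per user in the batch
```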
13 changes: 10 additions & 3 deletions deepmatch/layers/interaction.py
@@ -1,3 +1,10 @@
"""
Author:
Weichen Shen,[email protected]
"""

import tensorflow as tf
from deepctr.layers.normalization import LayerNormalization
from deepctr.layers.utils import softmax, reduce_mean
@@ -109,7 +116,7 @@ def call(self, inputs, mask=None, training=None, **kwargs):
         lower_tri = tf.ones([length, length])
         try:
             lower_tri = tf.contrib.linalg.LinearOperatorTriL(lower_tri).to_dense()
-        except:
+        except AttributeError:
             lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()
         masks = tf.tile(tf.expand_dims(lower_tri, 0), [tf.shape(align)[0], 1, 1])
         align = tf.where(tf.equal(masks, 0), paddings, align)
@@ -199,8 +206,8 @@ def build(self, input_shape):
         super(SelfAttention, self).build(input_shape)
 
     def call(self, inputs, mask=None, **kwargs):
-        input, key_masks = inputs
-        querys, keys, values = input, input, input
+        _input, key_masks = inputs
+        querys, keys, values = _input, _input, _input
         align = self.attention([querys, keys])
         output = self.softmax_weight_sum([align, values, key_masks])
         if self.use_layer_norm:
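
The `except AttributeError:` change above is the notable fix in this file: a bare `except:` silently swallows unrelated errors, while catching `AttributeError` narrows the fallback to the one expected failure, `tf.contrib` not existing on newer TensorFlow. The same compatibility pattern in isolation (a sketch, runnable on either major TF version):

```python
import tensorflow as tf

lower_tri = tf.ones([4, 4])
try:
    # TF 1.x location of the lower-triangular operator
    lower_tri = tf.contrib.linalg.LinearOperatorTriL(lower_tri).to_dense()
except AttributeError:
    # tf.contrib was removed in TF 2.x; fall back to the tf.linalg equivalent
    lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()
```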
7 changes: 7 additions & 0 deletions deepmatch/layers/sequence.py
@@ -1,3 +1,10 @@
"""
Author:
Weichen Shen,[email protected]
"""

import tensorflow as tf
from tensorflow.python.keras.layers import Layer
