refactor & add some negative sampling strategies (#81)

Support different negative sampling strategies, including `inbatch`, `uniform`, `frequency`, and `adaptive`:

1. Add `deepmatch.utils.NegativeSampler`.
2. Remove `deepmatch.layers.core.NegativeSampler`.
3. Add `temperature`, `sampler_config`, and `loss_type` arguments to the models.
4. Add some Google Colab scripts.
1 parent be6c028 · commit 5dab795
Showing 37 changed files with 1,953 additions and 506 deletions.
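The commit message references the new `deepmatch.utils.NegativeSampler` and the `temperature`/`sampler_config` model arguments. A minimal usage sketch follows, assuming a `YoutubeDNN`-style model; the `NegativeSampler(sampler, num_sampled, item_name, item_count, distortion)` signature is inferred from the config keys used in the diffs below, and `item_name` in particular is an assumed argument, not taken verbatim from the release:

import numpy as np
from deepctr.feature_column import SparseFeat
from deepmatch.models import YoutubeDNN
from deepmatch.utils import NegativeSampler, sampledsoftmaxloss

n_items = 1000
user_feature_columns = [SparseFeat('user_id', 100, embedding_dim=16)]
item_feature_columns = [SparseFeat('movie_id', n_items, embedding_dim=16)]

# Global frequency of each item id; the `frequency` and `inbatch`
# samplers use it to estimate the sampling distribution Q.
item_count = np.random.randint(1, 100, size=n_items)

# sampler is one of 'inbatch', 'uniform', 'frequency', 'adaptive'.
sampler_config = NegativeSampler(sampler='frequency', num_sampled=5,
                                 item_name='movie_id', item_count=item_count,
                                 distortion=0.75)

model = YoutubeDNN(user_feature_columns, item_feature_columns,
                   temperature=0.05, sampler_config=sampler_config)
model.compile(optimizer='adam', loss=sampledsoftmaxloss)

Note in the diff below that the `inbatch` branch ignores `num_sampled` and instead scores the batch items against each other.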
deepmatch/__init__.py

@@ -1,4 +1,4 @@
 from .utils import check_version
 
-__version__ = '0.2.1'
+__version__ = '0.3.0'
 check_version(__version__)
deepmatch/layers/core.py

@@ -1,9 +1,15 @@
+"""
+Author:
+    Weichen Shen,[email protected]
+"""
+
+import numpy as np
 import tensorflow as tf
 from deepctr.layers.activation import activation_layer
 from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax
-from tensorflow.python.keras.initializers import RandomNormal, Zeros, TruncatedNormal
+from tensorflow.python.keras.initializers import Zeros
 from tensorflow.python.keras.layers import Layer
 from tensorflow.python.keras.regularizers import l2
 
 
 class PoolingLayer(Layer):
@@ -45,45 +51,103 @@ def get_config(self, ):
 
 
 class SampledSoftmaxLayer(Layer):
-    def __init__(self, num_sampled=5, **kwargs):
-        self.num_sampled = num_sampled
+    def __init__(self, sampler_config, temperature=1.0, **kwargs):
+        self.sampler_config = sampler_config
+        self.temperature = temperature
+        self.sampler = self.sampler_config['sampler']
+        self.item_count = self.sampler_config['item_count']
+
         super(SampledSoftmaxLayer, self).__init__(**kwargs)
 
     def build(self, input_shape):
-        self.size = input_shape[0][0]
-        self.zero_bias = self.add_weight(shape=[self.size],
+        self.vocabulary_size = input_shape[0][0]
+        self.zero_bias = self.add_weight(shape=[self.vocabulary_size],
                                          initializer=Zeros,
                                          dtype=tf.float32,
                                          trainable=False,
                                          name="bias")
         super(SampledSoftmaxLayer, self).build(input_shape)
 
-    def call(self, inputs_with_label_idx, training=None, **kwargs):
-        """
-        The first input should be the model as it were, and the second the
-        target (i.e., a repeat of the training data) to compute the labels
-        argument
-        """
-        embeddings, inputs, label_idx = inputs_with_label_idx
-
-        loss = tf.nn.sampled_softmax_loss(weights=embeddings,  # self.item_embedding.
-                                          biases=self.zero_bias,
-                                          labels=label_idx,
-                                          inputs=inputs,
-                                          num_sampled=self.num_sampled,
-                                          num_classes=self.size,  # self.target_song_size
-                                          )
+    def call(self, inputs_with_item_idx, training=None, **kwargs):
+        item_embeddings, user_vec, item_idx = inputs_with_item_idx
+        if item_idx.dtype != tf.int64:
+            item_idx = tf.cast(item_idx, tf.int64)
+        user_vec /= self.temperature
+        if self.sampler == "inbatch":
+            item_vec = tf.gather(item_embeddings, tf.squeeze(item_idx, axis=1))
+            logits = tf.matmul(user_vec, item_vec, transpose_b=True)
+            loss = inbatch_softmax_cross_entropy_with_logits(logits, self.item_count, item_idx)
+
+        else:
+            num_sampled = self.sampler_config['num_sampled']
+            if self.sampler == "frequency":
+                sampled_values = tf.nn.fixed_unigram_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                       self.vocabulary_size,
+                                                                       distortion=self.sampler_config['distortion'],
+                                                                       unigrams=np.maximum(self.item_count, 1).tolist(),
+                                                                       seed=None,
+                                                                       name=None)
+            elif self.sampler == "adaptive":
+                sampled_values = tf.nn.learned_unigram_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                         self.vocabulary_size, seed=None, name=None)
+            elif self.sampler == "uniform":
+                try:
+                    sampled_values = tf.nn.uniform_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                     self.vocabulary_size, seed=None, name=None)
+                except AttributeError:
+                    sampled_values = tf.random.uniform_candidate_sampler(item_idx, 1, num_sampled, True,
+                                                                         self.vocabulary_size, seed=None, name=None)
+            else:
+                raise ValueError(' `%s` sampler is not supported ' % self.sampler)
+
+            loss = tf.nn.sampled_softmax_loss(weights=item_embeddings,
+                                              biases=self.zero_bias,
+                                              labels=item_idx,
+                                              inputs=user_vec,
+                                              num_sampled=num_sampled,
+                                              num_classes=self.vocabulary_size,
+                                              sampled_values=sampled_values
+                                              )
         return tf.expand_dims(loss, axis=1)
 
     def compute_output_shape(self, input_shape):
         return (None, 1)
 
     def get_config(self, ):
-        config = {'num_sampled': self.num_sampled}
+        config = {'sampler_config': self.sampler_config, 'temperature': self.temperature}
         base_config = super(SampledSoftmaxLayer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
 
+class InBatchSoftmaxLayer(Layer):
+    def __init__(self, sampler_config, temperature=1.0, **kwargs):
+        self.sampler_config = sampler_config
+        self.temperature = temperature
+        self.item_count = self.sampler_config['item_count']
+
+        super(InBatchSoftmaxLayer, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        super(InBatchSoftmaxLayer, self).build(input_shape)
+
+    def call(self, inputs_with_item_idx, training=None, **kwargs):
+        user_vec, item_vec, item_idx = inputs_with_item_idx
+        if item_idx.dtype != tf.int64:
+            item_idx = tf.cast(item_idx, tf.int64)
+        user_vec /= self.temperature
+        logits = tf.matmul(user_vec, item_vec, transpose_b=True)
+        loss = inbatch_softmax_cross_entropy_with_logits(logits, self.item_count, item_idx)
+        return tf.expand_dims(loss, axis=1)
+
+    def compute_output_shape(self, input_shape):
+        return (None, 1)
+
+    def get_config(self, ):
+        config = {'sampler_config': self.sampler_config, 'temperature': self.temperature}
+        base_config = super(InBatchSoftmaxLayer, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
 class LabelAwareAttention(Layer):
     def __init__(self, k_max, pow_p=1, **kwargs):
         self.k_max = k_max
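For context on the three candidate-sampler branches above: `frequency` delegates to `tf.nn.fixed_unigram_candidate_sampler` over the global item counts (skewed by `distortion`), `adaptive` to `tf.nn.learned_unigram_candidate_sampler`, which adapts to the ids seen during training, and `uniform` to the uniform candidate sampler. A standalone sketch of the `uniform` case, with illustrative shapes and values:

import tensorflow as tf

# One positive item id per row, shape [batch_size, 1], dtype int64.
item_idx = tf.constant([[3], [7]], dtype=tf.int64)

# Draw 5 unique negative ids from [0, 100); the result also carries the
# expected counts that tf.nn.sampled_softmax_loss uses to de-bias logits.
sampled_values = tf.random.uniform_candidate_sampler(
    true_classes=item_idx, num_true=1, num_sampled=5,
    unique=True, range_max=100)
print(sampled_values.sampled_candidates)  # e.g. [41 7 93 12 60]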
@@ -128,38 +192,6 @@ def get_config(self, ):
         return dict(list(base_config.items()) + list(config.items()))
 
 
-class Similarity(Layer):
-
-    def __init__(self, gamma=1, axis=-1, type='cos', **kwargs):
-        self.gamma = gamma
-        self.axis = axis
-        self.type = type
-        super(Similarity, self).__init__(**kwargs)
-
-    def build(self, input_shape):
-        # Be sure to call this somewhere!
-        super(Similarity, self).build(input_shape)
-
-    def call(self, inputs, **kwargs):
-        query, candidate = inputs
-        if self.type == "cos":
-            query_norm = tf.norm(query, axis=self.axis)
-            candidate_norm = tf.norm(candidate, axis=self.axis)
-        cosine_score = reduce_sum(tf.multiply(query, candidate), -1)
-        if self.type == "cos":
-            cosine_score = div(cosine_score, query_norm * candidate_norm + 1e-8)
-        cosine_score = tf.clip_by_value(cosine_score, -1, 1.0) * self.gamma
-        return cosine_score
-
-    def compute_output_shape(self, input_shape):
-        return (None, 1)
-
-    def get_config(self, ):
-        config = {'gamma': self.gamma, 'axis': self.axis, 'type': self.type}
-        base_config = super(Similarity, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
-
-
 class CapsuleLayer(Layer):
     def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
                  init_std=1.0, **kwargs):
@@ -245,6 +277,23 @@ def squash(inputs):
     return vec_squashed
 
 
+def inbatch_softmax_cross_entropy_with_logits(logits, item_count, item_idx):
+    Q = tf.gather(tf.constant(item_count / np.sum(item_count), 'float32'),
+                  tf.squeeze(item_idx, axis=1))
+    try:
+        logQ = tf.reshape(tf.math.log(Q), (1, -1))
+        logits -= logQ  # subtract_log_q
+        labels = tf.linalg.diag(tf.ones_like(logits[0]))
+    except AttributeError:
+        logQ = tf.reshape(tf.log(Q), (1, -1))
+        logits -= logQ  # subtract_log_q
+        labels = tf.diag(tf.ones_like(logits[0]))
+
+    loss = tf.nn.softmax_cross_entropy_with_logits(
+        labels=labels, logits=logits)
+    return loss
+
+
 class EmbeddingIndex(Layer):
 
     def __init__(self, index, **kwargs):
deepmatch/layers/interaction.py

@@ -1,3 +1,10 @@
+"""
+Author:
+    Weichen Shen,[email protected]
+"""
+
 import tensorflow as tf
 from deepctr.layers.normalization import LayerNormalization
 from deepctr.layers.utils import softmax, reduce_mean
@@ -109,7 +116,7 @@ def call(self, inputs, mask=None, training=None, **kwargs):
         lower_tri = tf.ones([length, length])
         try:
             lower_tri = tf.contrib.linalg.LinearOperatorTriL(lower_tri).to_dense()
-        except:
+        except AttributeError:
             lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()
         masks = tf.tile(tf.expand_dims(lower_tri, 0), [tf.shape(align)[0], 1, 1])
         align = tf.where(tf.equal(masks, 0), paddings, align)
@@ -199,8 +206,8 @@ def build(self, input_shape):
         super(SelfAttention, self).build(input_shape)
 
     def call(self, inputs, mask=None, **kwargs):
-        input, key_masks = inputs
-        querys, keys, values = input, input, input
+        _input, key_masks = inputs
+        querys, keys, values = _input, _input, _input
         align = self.attention([querys, keys])
         output = self.softmax_weight_sum([align, values, key_masks])
         if self.use_layer_norm:
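The `try`/`except AttributeError` above covers the `tf.contrib` → `tf.linalg` move between TF 1.x and 2.x; both branches build the same lower-triangular causal mask. A quick illustration (TF 2.x eager):

import tensorflow as tf

length = 4
lower_tri = tf.linalg.LinearOperatorLowerTriangular(
    tf.ones([length, length])).to_dense()
print(lower_tri.numpy())
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
# Row i keeps only keys at positions <= i, i.e. no attention to the future.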
@@ -1,3 +1,10 @@
+"""
+Author:
+    Weichen Shen,[email protected]
+"""
+
 import tensorflow as tf
 from tensorflow.python.keras.layers import Layer
 
 