Skip to content

Commit

Permalink
add DSSM and YoutubeDNN
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyaoGeng committed Apr 5, 2022
1 parent 09b583a commit 917f4f4
Show file tree
Hide file tree
Showing 8 changed files with 366 additions and 11 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,6 @@ In [example](https://github.com/ZiyaoGeng/Recommender-System-with-TF2.0/tree/rec
</table>






### Ranking

<table style="text-align:center;margin:auto">
Expand Down Expand Up @@ -149,7 +145,7 @@ In [example](https://github.com/ZiyaoGeng/Recommender-System-with-TF2.0/tree/rec

## Discussion

1. If you have any suggestions or questions about the project, you can leave a comment on `Issue` or email `[email protected]`.
1. If you have any suggestions or questions about the project, you can leave a comment on `Issue`.
2. wechat:

<div align=center><img src="https://cdn.jsdelivr.net/gh/BlackSpaceGZY/cdn/img/weixin.jpg" width="20%"/></div>
Expand Down
5 changes: 3 additions & 2 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,14 @@ for file in split_file_list[:-1]:
</tr>
<tr><td>BPR</td><td>0.5768</td><td>0.2392</td><td>0.3016</td><td>0.3708</td><td>0.2108</td><td>0.2485</td><td>0.7728</td><td>0.4220</td><td>0.5054</td></tr>
<tr><td>NCF</td><td>0.5711</td><td>0.2112</td><td>0.2950</td><td>0.5448</td><td>0.2831</td><td>0.3451</td><td>0.7768</td><td>0.4273</td><td>0.5103</td></tr>
<tr><td>DSSM</td><td>0.5410</td><td>0.2016</td><td>0.2807</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td></tr>
<tr><td>YoutubeDNN</td><td>0.6358</td><td>0.3042</td><td>0.3825</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td></tr>
<tr><td>SASRec</td><td>0.8103</td><td>0.4812</td><td>0.5605</td><td>0.5230</td><td>0.2781</td><td>0.3355</td><td>0.8606</td><td>0.5669</td><td>0.6374</td></tr>
</table>





### Ranking

<table style="text-align:center;margin:auto">
Expand Down Expand Up @@ -289,4 +290,4 @@ for file in split_file_list[:-1]:

## 讨论

对于项目有任何建议或问题,可以在`Issue`留言,或者发邮件至`[email protected]`
对于项目有任何建议或问题,可以在`Issue`留言。
88 changes: 88 additions & 0 deletions example/m_dssm_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Created on Apr 1, 2022
train DSSM demo
@author: Ziyao Geng([email protected])
"""
import os
from absl import flags, app
from time import time
from tensorflow.keras.optimizers import Adam

from reclearn.models.matching import DSSM
from reclearn.data.datasets import movielens as ml
from reclearn.evaluator import eval_pos_neg

FLAGS = flags.FLAGS

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Setting training parameters
flags.DEFINE_string("file_path", "data/ml-1m/ratings.dat", "file path.")
flags.DEFINE_string("train_path", "data/ml-1m/ml_train.txt", "train path.")
flags.DEFINE_string("val_path", "data/ml-1m/ml_val.txt", "val path.")
flags.DEFINE_string("test_path", "data/ml-1m/ml_test.txt", "test path.")
flags.DEFINE_string("meta_path", "data/ml-1m/ml_meta.txt", "meta path.")
flags.DEFINE_integer("embed_dim", 64, "The size of embedding dimension.")
flags.DEFINE_float("embed_reg", 0.0, "The value of embedding regularization.")
flags.DEFINE_list("user_mlp", [128], "A list of user MLP hidden units.")
flags.DEFINE_list("item_mlp", [128], "A list of item MLP hidden units")
flags.DEFINE_string("activation", "relu", "Activation Name.")
flags.DEFINE_float("dnn_dropout", 0., "Float between 0 and 1. Dropout of user and item MLP layer.")
flags.DEFINE_boolean("use_l2norm", False, "Whether user embedding, item embedding should be normalized or not.")
flags.DEFINE_string("loss_name", "binary_cross_entropy_loss", "Loss Name.")
flags.DEFINE_float("gamma", 0.5, "If hinge_loss is selected as the loss function, you can specify the margin.")
flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
flags.DEFINE_integer("neg_num", 2, "The number of negative sample for each positive sample.")
flags.DEFINE_integer("epochs", 10, "train steps.")
flags.DEFINE_integer("batch_size", 512, "Batch Size.")
flags.DEFINE_integer("test_neg_num", 100, "The number of test negative samples.")
flags.DEFINE_integer("k", 10, "recall k items at test stage.")


def main(argv):
# TODO: 1. Split Data
if FLAGS.train_path == "None":
train_path, val_path, test_path, meta_path = ml.split_data(file_path=file_path)
else:
train_path, val_path, test_path, meta_path = FLAGS.train_path, FLAGS.val_path, FLAGS.test_path, FLAGS.meta_path
with open(meta_path) as f:
max_user_num, max_item_num = [int(x) for x in f.readline().strip('\n').split('\t')]
# TODO: 2. Load Data
train_data = ml.load_data(train_path, FLAGS.neg_num, max_item_num)
val_data = ml.load_data(val_path, FLAGS.neg_num, max_item_num)
test_data = ml.load_data(test_path, FLAGS.test_neg_num, max_item_num)
# TODO: 3. Set Model Hyper Parameters.
model_params = {
'user_num': max_user_num + 1,
'item_num': max_item_num + 1,
'embed_dim': FLAGS.embed_dim,
'user_mlp': FLAGS.user_mlp,
'item_mlp': FLAGS.item_mlp,
'activation': FLAGS.activation,
'dnn_dropout': FLAGS.dnn_dropout,
'use_l2norm': FLAGS.use_l2norm,
'loss_name': FLAGS.loss_name,
'gamma': FLAGS.gamma,
'embed_reg': FLAGS.embed_reg
}
# TODO: 4. Build Model
model = DSSM(**model_params)
model.compile(optimizer=Adam(learning_rate=FLAGS.learning_rate))
# TODO: 5. Fit Model
for epoch in range(1, FLAGS.epochs + 1):
t1 = time()
model.fit(
x=train_data,
epochs=1,
validation_data=val_data,
batch_size=FLAGS.batch_size
)
t2 = time()
eval_dict = eval_pos_neg(model, test_data, ['hr', 'mrr', 'ndcg'], FLAGS.k, FLAGS.batch_size)
print('Iteration %d Fit [%.1f s], Evaluate [%.1f s]: HR = %.4f, MRR = %.4f, NDCG = %.4f'
% (epoch, t2 - t1, time() - t2, eval_dict['hr'], eval_dict['mrr'], eval_dict['ndcg']))


if __name__ == '__main__':
app.run(main)
84 changes: 84 additions & 0 deletions example/m_youtubednn_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
Created on Apr 5, 2022
train YoutubeDNN demo
@author: Ziyao Geng([email protected])
"""
import os
from absl import flags, app
from time import time
from tensorflow.keras.optimizers import Adam

from reclearn.models.matching import YoutubeDNN
from reclearn.data.datasets import movielens as ml
from reclearn.evaluator import eval_pos_neg

FLAGS = flags.FLAGS

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Setting training parameters
flags.DEFINE_string("file_path", "data/ml-1m/ratings.dat", "file path.")
flags.DEFINE_string("train_path", "data/ml-1m/ml_seq_train.txt", "train path.")
flags.DEFINE_string("val_path", "data/ml-1m/ml_seq_val.txt", "val path.")
flags.DEFINE_string("test_path", "data/ml-1m/ml_seq_test.txt", "test path.")
flags.DEFINE_string("meta_path", "data/ml-1m/ml_seq_meta.txt", "meta path.")
flags.DEFINE_integer("embed_dim", 64, "The size of embedding dimension.")
flags.DEFINE_float("embed_reg", 0.0, "The value of embedding regularization.")
flags.DEFINE_list("user_mlp", [128, 256, 64], "A list of user MLP hidden units.")
flags.DEFINE_string("activation", "relu", "Activation Name.")
flags.DEFINE_float("dnn_dropout", 0., "Float between 0 and 1. Dropout of user and item MLP layer.")
flags.DEFINE_boolean("use_l2norm", False, "Whether user embedding, item embedding should be normalized or not.")
flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
flags.DEFINE_integer("neg_num", 2, "The number of negative sample for each positive sample.")
flags.DEFINE_integer("seq_len", 200, "The length of user's behavior sequence.")
flags.DEFINE_integer("epochs", 10, "train steps.")
flags.DEFINE_integer("batch_size", 512, "Batch Size.")
flags.DEFINE_integer("test_neg_num", 100, "The number of test negative samples.")
flags.DEFINE_integer("k", 10, "recall k items at test stage.")


def main(argv):
# TODO: 1. Split Data
if FLAGS.train_path == "None":
train_path, val_path, test_path, meta_path = ml.split_seq_data(file_path=FLAGS.file_path)
else:
train_path, val_path, test_path, meta_path = FLAGS.train_path, FLAGS.val_path, FLAGS.test_path, FLAGS.meta_path
with open(meta_path) as f:
_, max_item_num = [int(x) for x in f.readline().strip('\n').split('\t')]
# TODO: 2. Load Sequence Data
train_data = ml.load_seq_data(train_path, "train", FLAGS.seq_len, 0, max_item_num)
val_data = ml.load_seq_data(val_path, "val", FLAGS.seq_len, FLAGS.neg_num, max_item_num)
test_data = ml.load_seq_data(test_path, "test", FLAGS.seq_len, FLAGS.test_neg_num, max_item_num)
# TODO: 3. Set Model Hyper Parameters.
model_params = {
'item_num': max_item_num + 1,
'embed_dim': FLAGS.embed_dim,
'user_mlp': FLAGS.user_mlp,
'activation': FLAGS.activation,
'dnn_dropout': FLAGS.dnn_dropout,
'neg_num': FLAGS.neg_num,
'batch_size': FLAGS.batch_size,
'use_l2norm': FLAGS.use_l2norm,
'embed_reg': FLAGS.embed_reg
}
# TODO: 4. Build Model
model = YoutubeDNN(**model_params)
model.compile(optimizer=Adam(learning_rate=FLAGS.learning_rate))
# TODO: 5. Fit Model
for epoch in range(1, FLAGS.epochs + 1):
t1 = time()
model.fit(
x=train_data,
epochs=1,
validation_data=val_data,
batch_size=FLAGS.batch_size
)
t2 = time()
eval_dict = eval_pos_neg(model, test_data, ['hr', 'mrr', 'ndcg'], FLAGS.k, FLAGS.batch_size)
print('Iteration %d Fit [%.1f s], Evaluate [%.1f s]: HR = %.4f, MRR = %.4f, NDCG = %.4f'
% (epoch, t2 - t1, time() - t2, eval_dict['hr'], eval_dict['mrr'], eval_dict['ndcg']))


if __name__ == '__main__':
app.run(main)
6 changes: 3 additions & 3 deletions reclearn/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_loss(pos_scores, neg_scores, loss_name, gamma=None):
elif loss_name == 'hinge_loss':
loss = hinge_loss(pos_scores, neg_scores, gamma)
else:
loss = binary_entropy_loss(pos_scores, neg_scores)
loss = binary_cross_entropy_loss(pos_scores, neg_scores)
return loss


Expand All @@ -48,8 +48,8 @@ def hinge_loss(pos_scores, neg_scores, gamma=0.5):
return loss


def binary_entropy_loss(pos_scores, neg_scores):
"""binary entropy loss.
def binary_cross_entropy_loss(pos_scores, neg_scores):
"""binary cross entropy loss.
Args:
:param pos_scores: A tensor with shape of [batch_size, neg_num].
:param neg_scores: A tensor with shape of [batch_size, neg_num].
Expand Down
4 changes: 3 additions & 1 deletion reclearn/models/matching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from reclearn.models.matching.poprec import PopRec
from reclearn.models.matching.bpr import BPR
from reclearn.models.matching.ncf import NCF
from reclearn.models.matching.dssm import DSSM
from reclearn.models.matching.youtubednn import YoutubeDNN
from reclearn.models.matching.gru4rec import GRU4Rec
from reclearn.models.matching.sasrec import SASRec
from reclearn.models.matching.attrec import AttRec
from reclearn.models.matching.caser import Caser
from reclearn.models.matching.fissa import FISSA


__all__ = ['PopRec', 'BPR', 'NCF', 'GRU4Rec', 'SASRec', 'AttRec', 'Caser', 'FISSA']
__all__ = ['PopRec', 'BPR', 'NCF', 'DSSM', 'YoutubeDNN', 'GRU4Rec', 'SASRec', 'AttRec', 'Caser', 'FISSA']
90 changes: 90 additions & 0 deletions reclearn/models/matching/dssm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Created on Mar 31, 2022
Reference: "Learning Deep Structured Semantic Models for Web Search using Clickthrough Data", CIKM, 2013
@author: Ziyao Geng([email protected])
"""
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Embedding, Input
from tensorflow.keras.regularizers import l2
from reclearn.layers import MLP
from reclearn.models.losses import get_loss


class DSSM(Model):
def __init__(self, user_num, item_num, embed_dim, user_mlp, item_mlp, activation='relu',
dnn_dropout=0., use_l2norm=False, loss_name="binary_cross_entropy_loss",
gamma=0.5, embed_reg=0., seed=None):
"""DSSM: The two-tower matching model commonly used in industry.
Args:
:param user_num: An integer type. The largest user index + 1.
:param item_num: An integer type. The largest item index + 1.
:param embed_dim: An integer type. Embedding dimension of user vector and item vector.
:param user_mlp: A list of user MLP hidden units such as [128, 64, 32].
:param item_mlp: A list of item MLP hidden units such as [128, 64, 32] and
the last unit must be equal to the user's.
:param activation: A string. Activation function name of user and item MLP layer.
:param dnn_dropout: Float between 0 and 1. Dropout of user and item MLP layer.
:param use_l2norm: A boolean. Whether user embedding, item embedding should be normalized or not.
:param loss_name: A string. You can specify the current point-loss function 'binary_cross_entropy_loss' or
pair-loss function as 'bpr_loss'、'hinge_loss'.
:param gamma: A scalar. If hinge_loss is selected as the loss function, you can specify the margin.
:param embed_reg: A float type. The regularizer of embedding.
:param seed: A Python integer to use as random seed.
:return:
"""
super(DSSM, self).__init__()
if user_mlp[-1] != item_mlp[-1]:
raise ValueError("The last value of user_mlp must be equal to item_mlp's.")
# user embedding
self.user_embedding_table = Embedding(input_dim=user_num,
input_length=1,
output_dim=embed_dim,
embeddings_initializer='random_normal',
embeddings_regularizer=l2(embed_reg))
# item embedding
self.item_embedding_table = Embedding(input_dim=item_num,
input_length=1,
output_dim=embed_dim,
embeddings_initializer='random_normal',
embeddings_regularizer=l2(embed_reg))
# user_mlp_layer
self.user_mlp_layer = MLP(user_mlp, activation, dnn_dropout)
# item_mlp_layer
self.item_mlp_layer = MLP(item_mlp, activation, dnn_dropout)
self.use_l2norm = use_l2norm
self.loss_name = loss_name
self.gamma = gamma
# seed
tf.random.set_seed(seed)

def call(self, inputs):
# user info
user_info = self.user_embedding_table(inputs['user']) # (None, embed_dim)
# item info
pos_info = self.item_embedding_table(inputs['pos_item']) # (None, embed_dim)
neg_info = self.item_embedding_table(inputs['neg_item']) # (None, neg_num, embed_dim)
# mlp
user_info = self.user_mlp_layer(user_info)
pos_info = self.item_mlp_layer(pos_info)
neg_info = self.item_mlp_layer(neg_info)
# norm
if self.use_l2norm:
user_info = tf.math.l2_normalize(user_info, axis=-1)
pos_info = tf.math.l2_normalize(pos_info, axis=-1)
neg_info = tf.math.l2_normalize(neg_info, axis=-1)
# calculate similar scores.
pos_scores = tf.reduce_sum(tf.multiply(user_info, pos_info), axis=-1, keepdims=True) # (None, 1)
neg_scores = tf.reduce_sum(tf.multiply(tf.expand_dims(user_info, axis=1), neg_info), axis=-1) # (None, neg_num)
# add loss
self.add_loss(get_loss(pos_scores, neg_scores, self.loss_name, self.gamma))
logits = tf.concat([pos_scores, neg_scores], axis=-1)
return logits

def summary(self):
inputs = {
'user': Input(shape=(), dtype=tf.int32),
'pos_item': Input(shape=(), dtype=tf.int32),
'neg_item': Input(shape=(1,), dtype=tf.int32) # suppose neg_num=1
}
Model(inputs=inputs, outputs=self.call(inputs)).summary()
Loading

0 comments on commit 917f4f4

Please sign in to comment.