Skip to content

Commit

Permalink
Bug fix for jrc loss with scalar sample weight (#491)
Browse files Browse the repository at this point in the history
* fix bug of jrc loss when sample weight is a scalar
  • Loading branch information
yangxudong authored Nov 8, 2024
1 parent a669523 commit beafb3c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions easy_rec/python/loss/jrc_loss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
Expand Down Expand Up @@ -66,8 +66,6 @@ def jrc_loss(labels,
pairwise_weights = tf.tile(weights, tf.stack([batch_size, 1]))
y_pos *= pairwise_weights
y_neg *= pairwise_weights
else:
assert sample_weights == 1.0, 'invalid sample_weight %d' % sample_weights

# Compute list-wise generative loss -log p(x|y, z)
if same_label_loss:
Expand Down Expand Up @@ -124,4 +122,6 @@ def jrc_loss(labels,
else:
raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' %
loss_weight_strategy)
if np.isscalar(sample_weights):
return loss * sample_weights
return loss

0 comments on commit beafb3c

Please sign in to comment.