diff --git a/lib/loss/clipped_weighted_huber_loss.py b/lib/loss/clipped_weighted_huber_loss.py index 7e2a87b..e85875a 100644 --- a/lib/loss/clipped_weighted_huber_loss.py +++ b/lib/loss/clipped_weighted_huber_loss.py @@ -34,7 +34,7 @@ def forward(self, inputs): xp.square(abs_diff, out=abs_diff) y = (y - abs_diff) * 0.5 - return y.mean(), + return xp.array(y.mean(), dtype=y.dtype), def backward(self, inputs, grad_outputs): xp = cuda.get_array_module(*inputs)