OIM loss #90
I believe you are right. Also, I find that the `momentum` seems useless, because I get the same output no matter what value I set (see the sketch after the output below).

```python
from __future__ import absolute_import

import torch
import torch.nn.functional as F
from torch import nn, autograd
from torch.autograd import Variable  # unused; kept from the original

class OIM(autograd.Function):
    # def __init__(self, lut, momentum=0.5):
    #     super(OIM, self).__init__()
    #     self.lut = lut
    #     self.momentum = momentum

    @staticmethod
    def forward(ctx, inputs, targets, lut, momentum=0.5):
        ctx.save_for_backward(inputs, targets)
        ctx.lut = lut
        ctx.momentum = momentum
        outputs = inputs.mm(lut.t())
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            print(ctx.needs_input_grad)  # debug: prints (True, False, False, False)
            grad_inputs = grad_outputs.mm(ctx.lut)
        # The LUT update runs *after* grad_inputs has been computed from the
        # old LUT, so momentum cannot change the gradients of this iteration.
        for x, y in zip(inputs, targets):
            ctx.lut[y] = ctx.momentum * ctx.lut[y] + (1. - ctx.momentum) * x
            ctx.lut[y] /= ctx.lut[y].norm()
        return grad_inputs, None, None, None
class OIMLoss(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
                 weight=None, reduction='mean'):
        super(OIMLoss, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.momentum = momentum
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction
        self.register_buffer('lut', torch.zeros(num_classes, num_features))
        self.oim = OIM.apply

    def forward(self, inputs, targets):
        # inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
        inputs = self.oim(inputs, targets, self.lut, self.momentum)
        inputs *= self.scalar
        loss = F.cross_entropy(inputs, targets, weight=self.weight,
                               reduction=self.reduction)
        return loss
        # return loss, inputs
class Test(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, weight=None,
                 reduction='mean'):
        super(Test, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        inputs = inputs.mm(self.lut.t())
        inputs *= self.scalar
        loss = F.cross_entropy(inputs, targets, weight=self.weight,
                               reduction=self.reduction)
        return loss
if __name__ == '__main__':
    criterion = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=0.9)
    # criterion_2 = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=0)
    criterion_2 = Test(3, 3, scalar=1.0, reduction='sum')
    seed = 2018
    criterion.lut = torch.eye(3)
    criterion_2.lut = torch.eye(3)

    torch.manual_seed(seed)
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.arange(0, 3)
    loss = criterion(x, y)
    loss.backward()
    # With lut == I the logits equal x, so the analytic gradient of the
    # summed cross-entropy is softmax(x) - onehot(y).
    probs = F.softmax(x, dim=-1)
    grads = probs.data - torch.eye(3)
    abs_diff = torch.abs(grads - x.grad.data)
    print(probs)
    print(grads)
    print(abs_diff)
    print('*' * 50)

    torch.manual_seed(seed)
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.arange(0, 3)
    loss = criterion_2(x, y)
    loss.backward()
    probs = F.softmax(x, dim=-1)
    grads = probs.data - torch.eye(3)
    abs_diff = torch.abs(grads - x.grad.data)
    print(probs)
    print(grads)
    print(abs_diff)
```

and the output is:

```
(True, False, False, False)
tensor([[0.6779, 0.2672, 0.0548],
[0.1680, 0.2574, 0.5747],
[0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221, 0.2672, 0.0548],
[ 0.1680, -0.7426, 0.5747],
[ 0.1614, 0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
[0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.4901e-08, 0.0000e+00, 5.9605e-08]])
**************************************************
tensor([[0.6779, 0.2672, 0.0548],
[0.1680, 0.2574, 0.5747],
[0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221, 0.2672, 0.0548],
[ 0.1680, -0.7426, 0.5747],
[ 0.1614, 0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
[0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.4901e-08, 0.0000e+00, 5.9605e-08]])
```

I think this version of the OIM loss needs to be cleaned up, because this code is widely used in person search. Thanks!
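Here is the sketch referenced above. The gradients match for any `momentum` because `grad_inputs` is computed from the old LUT before the update loop runs, and the test does only a single backward pass (with `lut = torch.eye(3)` the logits equal `x`, so both versions reduce to the analytic `softmax(x) - onehot(y)` gradient). Momentum only changes what the LUT becomes for the next iteration, which the following minimal check makes visible; it reuses the `OIMLoss` class above, and the two momentum values are arbitrary:

```python
# Momentum cannot change x.grad within one backward pass, but it does
# change the LUT that the next iteration will see.
torch.manual_seed(2018)
x = torch.randn(3, 3, requires_grad=True)
y = torch.arange(0, 3)

for m in (0.1, 0.9):
    crit = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=m)
    crit.lut = torch.eye(3)
    crit(x, y).backward()  # backward triggers the in-place LUT update
    print('momentum =', m)
    print(crit.lut)        # the updated rows differ between the two runs
```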
Why does the OIM loss need to define the class OIM, which extends the Function class? Normally PyTorch does the backward for us, so we don't need to write the backward function.
Thank you.
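Not the author, but one plausible answer: the LUT is not a learnable parameter. It is a memory bank that must receive no gradient and must be updated in place with momentum as a side effect of training, which is why the code wraps the matmul in a custom `autograd.Function`. Autograd would indeed produce the same `grad_inputs` for a plain `inputs.mm(lut.t())`; you would then just have to apply the LUT update yourself outside the graph. A sketch of that autograd-only alternative (the class name and the update-in-forward timing are my own choices, not from the repo):

```python
import torch
import torch.nn.functional as F
from torch import nn


class OIMLossNoCustomFn(nn.Module):
    """Autograd-only variant: same logits and loss, LUT updated by hand."""

    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5):
        super(OIMLossNoCustomFn, self).__init__()
        self.scalar = scalar
        self.momentum = momentum
        # A buffer, not a Parameter: the optimizer never touches it and
        # no gradient flows into it.
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        # Autograd computes grad_inputs = grad_outputs.mm(lut) on its own,
        # exactly what the hand-written backward above returns.
        outputs = inputs.mm(self.lut.t()) * self.scalar
        loss = F.cross_entropy(outputs, targets, reduction='sum')
        # Momentum update as an explicit side effect, outside the graph.
        # Rebuild the buffer rather than mutating it in place, so the
        # tensor saved by mm() for backward stays untouched.
        with torch.no_grad():
            lut = self.lut.clone()
            for x, y in zip(inputs, targets):
                lut[y] = self.momentum * lut[y] + (1. - self.momentum) * x
                lut[y] /= lut[y].norm()
            self.lut = lut
        return loss
```

One difference to be aware of: here the update runs during `forward` instead of `backward`, but since `mm()` saved the pre-update LUT, the gradient w.r.t. `inputs` is the same either way.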