
OIM loss #90

Open
khadijakhaldi opened this issue Feb 6, 2020 · 1 comment

Comments

@khadijakhaldi

Why do we need to define the OIM class, which extends the Function class, for the OIM loss? Normally PyTorch computes the backward pass for us, so we shouldn't need to write the backward function ourselves.
Thank you.
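
For context, here is a tiny stand-alone example (not the repository's code, just an illustration with a made-up identity lut) of PyTorch computing the gradient of the projection plus cross-entropy automatically:

import torch
import torch.nn.functional as F

# Hypothetical lookup table and features, only to illustrate the point.
lut = torch.eye(3)
x = torch.randn(3, 3, requires_grad=True)
y = torch.arange(3)

# Plain autograd: similarity scores followed by cross-entropy.
loss = F.cross_entropy(x.mm(lut.t()), y)
loss.backward()   # x.grad is filled in without any hand-written backward
print(x.grad)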

@X-funbean

I believe you are right. Also, I find that the momentum seems to have no effect: I get the same output no matter what value I set.
I wrote a test based on test/loss/test_oim.py. Note that I'm using the latest version of PyTorch, so I made several modifications.

from __future__ import absolute_import

import torch
import torch.nn.functional as F
from torch import nn, autograd


class OIM(autograd.Function):

    # def __init__(self, lut, momentum=0.5):
    #     super(OIM, self).__init__()
    #     self.lut = lut
    #     self.momentum = momentum

    @staticmethod
    def forward(ctx, inputs, targets, lut, momentum=0.5):
        ctx.save_for_backward(inputs, targets)
        ctx.lut = lut
        ctx.momentum = momentum

        outputs = inputs.mm(lut.t())
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            print(ctx.needs_input_grad)
            grad_inputs = grad_outputs.mm(ctx.lut)

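        # Side effect of backward: refresh the lookup table with a momentum-
        # weighted moving average of the batch features, then L2-normalize
        # each updated entry.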
        for x, y in zip(inputs, targets):
            ctx.lut[y] = ctx.momentum * ctx.lut[y] + (1. - ctx.momentum) * x
            ctx.lut[y] /= ctx.lut[y].norm()
        return grad_inputs, None, None, None


class OIMLoss(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
                 weight=None, reduction='mean'):
        super(OIMLoss, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.momentum = momentum
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction

        self.register_buffer('lut', torch.zeros(num_classes, num_features))

        self.oim = OIM.apply

    def forward(self, inputs, targets):
        # inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
        inputs = self.oim(inputs, targets, self.lut, self.momentum)
        inputs *= self.scalar
        loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction=self.reduction)
        return loss
        # return loss, inputs


class Test(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, weight=None, reduction='mean'):
        super(Test, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction

        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        inputs = inputs.mm(self.lut.t())
        inputs *= self.scalar
        loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction=self.reduction)
        return loss


if __name__ == '__main__':
    criterion = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=0.9)
    # criterion_2 = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=0)
    criterion_2 = Test(3, 3, scalar=1.0, reduction='sum')

    seed = 2018
    criterion.lut = torch.eye(3)
    criterion_2.lut = torch.eye(3)
    torch.manual_seed(seed)

    x = torch.randn(3, 3, requires_grad=True)
    y = torch.arange(0, 3)

    loss = criterion(x, y)
    loss.backward()
    probs = F.softmax(x, dim=-1)
    grads = probs.data - torch.eye(3)
    abs_diff = torch.abs(grads - x.grad.data)

    print(probs)
    print(grads)
    print(abs_diff)

    print('*' * 50)

    torch.manual_seed(seed)
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.arange(0, 3)
    loss = criterion_2(x, y)
    loss.backward()
    probs = F.softmax(x, dim=-1)
    grads = probs.data - torch.eye(3)
    abs_diff = torch.abs(grads - x.grad.data)

    print(probs)
    print(grads)
    print(abs_diff)

and the output is

(True, False, False, False)
tensor([[0.6779, 0.2672, 0.0548],
        [0.1680, 0.2574, 0.5747],
        [0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221,  0.2672,  0.0548],
        [ 0.1680, -0.7426,  0.5747],
        [ 0.1614,  0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.4901e-08, 0.0000e+00, 5.9605e-08]])
**************************************************
tensor([[0.6779, 0.2672, 0.0548],
        [0.1680, 0.2574, 0.5747],
        [0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221,  0.2672,  0.0548],
        [ 0.1680, -0.7426,  0.5747],
        [ 0.1614,  0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.4901e-08, 0.0000e+00, 5.9605e-08]])

I think this version of the OIM loss needs to be polished, since the code seems to be widely used in person search. Thanks!
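
For what it's worth, if the custom backward were dropped, one possible (untested) way to keep the momentum update of the table would be to do it explicitly in the module's forward under torch.no_grad(). A sketch, not the repository's implementation:

import torch
import torch.nn.functional as F
from torch import nn


class OIMLossSketch(nn.Module):
    # Sketch only: the lookup-table update is moved from the custom
    # backward into forward, so no autograd.Function is needed.
    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
                 weight=None, reduction='mean'):
        super(OIMLossSketch, self).__init__()
        self.scalar = scalar
        self.momentum = momentum
        self.weight = weight
        self.reduction = reduction
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        # Clone so the in-place update below does not invalidate the tensor
        # autograd saved for the backward of mm.
        outputs = inputs.mm(self.lut.clone().t()) * self.scalar
        loss = F.cross_entropy(outputs, targets, weight=self.weight,
                               reduction=self.reduction)
        # Momentum-weighted update of the table, outside the autograd graph.
        with torch.no_grad():
            for x, y in zip(inputs, targets):
                self.lut[y] = self.momentum * self.lut[y] + (1. - self.momentum) * x
                self.lut[y] /= self.lut[y].norm()
        return loss

The gradient with respect to the inputs then comes from autograd alone, and the momentum would only show up in how the table evolves across iterations.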
