
line 244 in meta_neural_network_architectures.py #24

Open · RuohanW opened this issue Oct 7, 2019 · 4 comments

Comments

RuohanW commented Oct 7, 2019

Thank you for releasing the code.

I notice that the function
def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False)

has a training indicator. However, within the function (line 244):

output = F.batch_norm(input, running_mean, running_var, weight, bias,
                      training=True, momentum=momentum, eps=self.eps)

should training always be set to True? Does this affect the results reported in the original paper, given that per-step batch norm appears to be an important trick for improving MAML in the paper?
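For concreteness, here is a minimal sketch of what forwarding the flag could look like (the forward_bn wrapper and all values below are hypothetical illustrations, not the repository's code):

import torch
import torch.nn.functional as F

# Hypothetical sketch: a batch-norm call that forwards the `training` flag
# instead of hard-coding training=True.
def forward_bn(input, running_mean, running_var, weight, bias,
               momentum=0.1, eps=1e-5, training=False):
    return F.batch_norm(input, running_mean, running_var, weight, bias,
                        training=training, momentum=momentum, eps=eps)

x = torch.randn(8, 4, 5, 5)
out = forward_bn(x, torch.zeros(4), torch.ones(4), None, None, training=False)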

Many thanks.

jfb54 commented Apr 10, 2020

This is the same issue as #3.

My understanding is that this will affect the results reported in the paper. The code as written will always use the batch statistics, not a running average accumulated per step.

AntreasAntoniou (Owner) commented Apr 10, 2020 via email

jfb54 commented Apr 10, 2020

Thanks for the quick response! The following is a short script that demonstrates my assertion. If you have tests that show otherwise, it would be great to see them.

import torch
import torch.nn.functional as F

N = 64  # batch size
C = 16  # number of channels
H = 32  # image height
W = 32  # image width
eps = 1e-05

input = 10 * torch.randn(N, C, H, W)  # create a random input

running_mean = torch.zeros(C)  # set the running mean for all channels to be 0
running_var = torch.ones(C)  # set the running var for all channels to be 1

# Call batch norm with training=False. Expect that the input is normalized with the running mean and running variance
output = F.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=eps)

# Assert that the output is equal to the input
assert torch.allclose(input, output)

# Call batch norm with training=True. Expect that the input is normalized with batch statistics of the input.
output_bn = F.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=True, momentum=0.1, eps=eps)

# Normalize the input manually
batch_mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
# batch_norm normalizes with the biased variance, so use unbiased=False here
batch_var = torch.var(input, dim=(0, 2, 3), keepdim=True, unbiased=False)
output_manual = (input - batch_mean) / torch.sqrt(batch_var + eps)

# Assert that output_bn equals output_manual
assert torch.allclose(output_bn, output_manual)
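
As a further illustration of the "running average accumulated per step" alternative, the following hypothetical snippet (reusing input, C, and eps from the script above; num_steps and the per-step buffers are assumptions, not the repository's code) shows how evaluation could normalize with step-specific running statistics and training=False:

num_steps = 5  # assumed number of inner-loop steps
# One pair of running buffers per inner-loop step, accumulated during meta-training.
per_step_mean = [torch.zeros(C) for _ in range(num_steps)]
per_step_var = [torch.ones(C) for _ in range(num_steps)]

step = 0  # current inner-loop step
# At evaluation time, training=False normalizes with the selected step's statistics.
output_eval = F.batch_norm(input, per_step_mean[step], per_step_var[step],
                           weight=None, bias=None, training=False, momentum=0.1, eps=eps)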

AntreasAntoniou (Owner) commented Apr 10, 2020 via email
