[QUESTION] about EquivariantLayerNormV2 #13

Open
kzhoa opened this issue Aug 20, 2023 · 1 comment
kzhoa commented Aug 20, 2023

Hi, thanks for your wonderful work.
I have a question about the class EquivariantLayerNormV2 in /nets/layer_norm.py.
The field mean is computed with
field_mean = torch.mean(field, dim=1, keepdim=True) # [batch, mul, 1]
Should dim here actually be -1?
I ask because field_norm is computed with dim=-1 a few lines later.

Related code:

for mul, ir in self.irreps:  # mul is the multiplicity (number of copies) of some irrep type (ir)
    d = ir.dim
    field = node_input.narrow(1, ix, mul * d)
    ix += mul * d

    # [batch * sample, mul, repr]
    field = field.reshape(-1, mul, d)

    # For scalars first compute and subtract the mean
    if ir.l == 0 and ir.p == 1:
        # TODO: here the dim should be -1?
        field_mean = torch.mean(field, dim=1, keepdim=True)  # [batch, mul, 1]
        field = field - field_mean

    # Then compute the rescaling factor (norm of each feature vector)
    # Rescaling of the norms themselves based on the option "normalization"
    if self.normalization == 'norm':
        field_norm = field.pow(2).sum(-1)  # [batch * sample, mul]
    elif self.normalization == 'component':
        field_norm = field.pow(2).mean(-1)  # [batch * sample, mul]

yilunliao (Member) commented

Hi @w55100

The original code is correct.

  1. The mean is computed over all channels (dimension mul in the code), which is dim=1 after the reshape to [batch * sample, mul, d].

  2. Besides, we compute field_mean only for scalars (i.e., degree = 0, parity = even), so d will be 1.
    In this case, a mean over d (dim=-1) would just return each value unchanged, which is the same as not computing a mean at all.

Best
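
The two points in the reply can be checked on a toy tensor. The following is a minimal sketch, not code from the repository: the tensor values and the variable names `mean_over_channels`, `centered`, and `mean_over_d` are made up for illustration, and the field is assumed to be already reshaped to [batch, mul, d] with d = 1, as it is for scalar irreps.

```python
import torch

# Toy scalar field: 2 samples, mul = 3 channels, d = 1
# (scalar irreps have dimension 1). Values are arbitrary.
field = torch.tensor([[[1.0], [2.0], [6.0]],
                      [[0.0], [4.0], [8.0]]])  # shape [2, 3, 1]

# dim=1 averages over the mul channels of each sample, as in the quoted code.
mean_over_channels = torch.mean(field, dim=1, keepdim=True)  # shape [2, 1, 1]
centered = field - mean_over_channels  # broadcasts over the mul dimension

# dim=-1 averages over d, which has size 1 here, so it returns the input itself.
mean_over_d = torch.mean(field, dim=-1, keepdim=True)  # shape [2, 3, 1]

print(mean_over_channels.shape)         # torch.Size([2, 1, 1])
print(torch.equal(mean_over_d, field))  # True
```

With `keepdim=True`, the mean over `dim=1` has shape [batch, 1, 1] and is broadcast back over the channels when subtracted. A mean over `dim=-1` would equal the field itself for scalars, so subtracting it would set every scalar feature to zero instead of centering it.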
