Hi, thanks for your wonderful work.
I have a question about class EquivariantLayerNormV2 in /nets/layer_norm.py.
When computing the field mean with field_mean = torch.mean(field, dim=1, keepdim=True) # [batch, mul, 1]],
should dim here actually be -1?
We also compute field_norm with dim=-1 in the next few lines.
Related code:
```python
for mul, ir in self.irreps:  # mul is the multiplicity (number of copies) of some irrep type (ir)
    d = ir.dim
    field = node_input.narrow(1, ix, mul * d)
    ix += mul * d

    # [batch * sample, mul, repr]
    field = field.reshape(-1, mul, d)

    # For scalars first compute and subtract the mean
    if ir.l == 0 and ir.p == 1:
        # TODO: here the dim should be -1?
        field_mean = torch.mean(field, dim=1, keepdim=True)  # [batch, mul, 1]]
        field = field - field_mean

    # Then compute the rescaling factor (norm of each feature vector)
    # Rescaling of the norms themselves based on the option "normalization"
    if self.normalization == 'norm':
        field_norm = field.pow(2).sum(-1)  # [batch * sample, mul]
    elif self.normalization == 'component':
        field_norm = field.pow(2).mean(-1)  # [batch * sample, mul]
```
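For context, here is a minimal sketch of what the (mul, ir) pairs in this loop look like, assuming e3nn's o3.Irreps (the irreps string below is just an example, not the model's actual configuration):

```python
from e3nn import o3

# Hypothetical irreps, for illustration only.
irreps = o3.Irreps("4x0e + 2x1o")

for mul, ir in irreps:
    # mul: number of copies of this irrep; ir.dim: dimension d of one copy
    print(mul, ir, ir.dim, ir.l, ir.p)

# 4 0e 1 0 1    -> scalars (l=0, p=1): d = 1, the mean-subtraction branch applies
# 2 1o 3 1 -1   -> vectors (l=1, p=-1): d = 3, only field_norm is computed
```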
The mean is computed over all channels (the mul dimension in the code).
Besides, we compute field_mean only for scalars (i.e., degree l = 0, even parity), so d will be 1.
In that case, computing the mean over d would be the same as not computing the mean at all.
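To make this concrete, here is a small standalone sketch (not taken from the repo; the tensor sizes are made up) of why dim=1 is the axis layer norm needs for field_mean, while dim=-1 would be a no-op for scalars:

```python
import torch

# For a scalar irrep (l=0, p=1) the irrep dimension d is 1,
# so the field has shape [batch, mul, 1].
batch, mul, d = 2, 4, 1
field = torch.randn(batch, mul, d)

# Mean over dim=1 (the mul / channel dimension): one mean per sample,
# shared across all channels -- this is what layer norm subtracts.
# Actual shape with keepdim=True is [batch, 1, 1] here.
mean_over_channels = field.mean(dim=1, keepdim=True)

# Mean over dim=-1 (the irrep dimension, d = 1): averaging a single
# element just returns the field unchanged, so it would center nothing.
mean_over_d = field.mean(dim=-1, keepdim=True)
assert torch.allclose(mean_over_d, field)

# field_norm, by contrast, reduces over dim=-1 on purpose: it collapses
# the d components of each irrep copy into one norm per channel.
centered = field - mean_over_channels        # broadcasts to [batch, mul, 1]
field_norm = centered.pow(2).mean(dim=-1)    # [batch, mul]
```

So the two reductions act on different axes by design: field_mean averages over the channels (dim=1), while field_norm averages over the components of each irrep copy (dim=-1).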