Skip to content

Commit

Permalink
remove torch_runstats dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
cw-tan committed Oct 6, 2024
1 parent cd4c90c commit 5d50474
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 8 deletions.
3 changes: 1 addition & 2 deletions examples/lj/lj.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

import torch

from torch_runstats.scatter import scatter

from nequip.utils import scatter
from nequip.data import AtomicDataDict
from nequip.nn import GraphModuleMixin, SequentialGraphNetwork, AtomwiseReduce

Expand Down
3 changes: 1 addition & 2 deletions nequip/nn/_atomwise.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import torch
import torch.nn.functional
from torch_runstats.scatter import scatter

from e3nn.o3 import Linear

from nequip.data import AtomicDataDict
from nequip.utils import dtype_from_name, format_type_vals
from nequip.utils import scatter, dtype_from_name, format_type_vals
from nequip.utils.versions import _TORCH_IS_GE_1_13
from ._graph_mixin import GraphModuleMixin
from ._rescale import RescaleOutput
Expand Down
3 changes: 1 addition & 2 deletions nequip/nn/_interaction_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@

import torch

from torch_runstats.scatter import scatter

from e3nn import o3
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import TensorProduct, Linear, FullyConnectedTensorProduct

from nequip.utils import scatter
from nequip.data import AtomicDataDict
from nequip.nn.nonlinearities import ShiftedSoftPlus
from ._graph_mixin import GraphModuleMixin
Expand Down
2 changes: 1 addition & 1 deletion nequip/nn/pair_potential.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Union, Optional, List

import torch
from torch_runstats.scatter import scatter

from e3nn.util.jit import compile_mode

import ase.data

from nequip.utils import scatter
from nequip.data import AtomicDataDict
from nequip.nn import GraphModuleMixin, RescaleOutput

Expand Down
3 changes: 3 additions & 0 deletions nequip/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from .misc import dtype_to_name, dtype_from_name, torch_default_dtype, format_type_vals
from .file_utils import download_url, extract_zip
from .logger import RankedLogger
from .scatter import scatter


__all__ = [
instantiate_from_cls_name,
Expand All @@ -36,4 +38,5 @@
download_url,
extract_zip,
RankedLogger,
scatter,
]
43 changes: 43 additions & 0 deletions nequip/utils/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Migrated from https://github.com/mir-group/pytorch_runstats
"""

import torch
from typing import Optional


def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
if dim < 0:
dim = other.dim() + dim
if src.dim() == 1:
for _ in range(0, dim):
src = src.unsqueeze(0)
for _ in range(src.dim(), other.dim()):
src = src.unsqueeze(-1)
src = src.expand_as(other)
return src


@torch.jit.script
def scatter(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
reduce: str = "sum",
) -> torch.Tensor:
assert reduce == "sum" # for now, TODO
index = _broadcast(index, src, dim)
if out is None:
size = list(src.size())
if dim_size is not None:
size[dim] = dim_size
elif index.numel() == 0:
size[dim] = 0
else:
size[dim] = int(index.max()) + 1
out = torch.zeros(size, dtype=src.dtype, device=src.device)
return out.scatter_add_(dim, index, src)
else:
return out.scatter_add_(dim, index, src)
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
'contextvars;python_version<"3.7"', # backport of contextvars for savenload
"typing_extensions;python_version<'3.8'", # backport of Final
"importlib_metadata;python_version<'3.10'", # backport of importlib
"torch-runstats>=0.2.0",
"torch-ema>=0.3.0",
"hydra-core",
"lightning",
Expand Down

0 comments on commit 5d50474

Please sign in to comment.