Example Alanine dipeptide system explodes #115

Closed
JSLJ23 opened this issue Aug 15, 2023 · 10 comments

@JSLJ23

JSLJ23 commented Aug 15, 2023

Hi OpenMM-Torch Devs,
I was hoping to get some help with properly running the system from the example (but without NNPOps, as I am unable to get that working on my system). The peptide seems to break apart and drift away.
The output of the MD simulation shows extremely high temperatures, and the atoms essentially just fly apart.
I have attached a zip file of the PDB output trajectory to this issue.
I would appreciate any help and advice on this.
Thank you.

Best regards,
Joshua

The code I am running:

import sys
from time import perf_counter

import openmmtools
import torch
from openmm import LangevinMiddleIntegrator
from openmm.app import Simulation, StateDataReporter
from openmm.unit import femtosecond, kelvin, kilojoule_per_mole, picosecond
from openmmtorch import TorchForce
from torchani.models import ANI2x

from openmm.app.pdbreporter import PDBReporter

# Get the system of alanine dipeptide
ala2 = openmmtools.testsystems.AlanineDipeptideVacuum(constraints=None)

# Remove MM forces
while ala2.system.getNumForces() > 0:
  ala2.system.removeForce(0)

# The system should not contain any additional forces or constraints
assert ala2.system.getNumConstraints() == 0
assert ala2.system.getNumForces() == 0

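# Get the list of atomic numbers
atomic_numbers = [atom.element.atomic_number for atom in ala2.topology.atoms()]

# Select the device; `device` was not defined in the posted snippet, so this line is an assumption
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")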

class NNP(torch.nn.Module):

  def __init__(self, atomic_numbers):
    super().__init__()
    # Store the atomic numbers
    self.atomic_numbers = torch.tensor(atomic_numbers, device=device).unsqueeze(0)
    # Create an ANI-2x model
    self.model = ANI2x(periodic_table_index=True).to(device)
    
  def forward(self, positions):
    # Prepare the positions
    positions = positions.unsqueeze(0).float() * 10 # nm --> Å
    # Run ANI-2x
    result = self.model((self.atomic_numbers, positions))
    # Get the potential energy
    energy = result.energies[0] * 2625.5 # Hartree --> kJ/mol
    return energy

# Create an instance of the model
nnp = NNP(atomic_numbers)

# Save the NNP to a file and load it with OpenMM-Torch
torch.jit.script(nnp).save("./ANI2x_model.pt")
torch_force = TorchForce("./ANI2x_model.pt")

# Add the NNP to the system
ala2.system.addForce(torch_force)
assert ala2.system.getNumForces() == 1

# Create an integrator with a time step of 1 fs
temperature = 298.15 * kelvin
frictionCoeff = 1 / picosecond
timeStep = 1 * femtosecond
integrator = LangevinMiddleIntegrator(temperature, frictionCoeff, timeStep)

# Create a simulation and set the initial positions and velocities
simulation = Simulation(ala2.topology, ala2.system, integrator)
simulation.context.setPositions(ala2.positions)
#simulation.context.setVelocitiesToTemperature(temperature) # This does not work (https://github.com/openmm/openmm-torch/issues/61)

# Configure a reporter to print to the console every 0.1 ps (100 steps)
reporter = StateDataReporter(file=sys.stdout, reportInterval=100, step=True, time=True, potentialEnergy=True, temperature=True)
simulation.reporters.append(reporter)
# Trajectory
trajectory = PDBReporter(file="ala_dp_traj.pdb", reportInterval=10)
simulation.reporters.append(trajectory)

# Run the simulation
simulation.minimizeEnergy(maxIterations=100)
simulation.step(1000)

Output:

#"Step","Time (ps)","Potential Energy (kJ/mole)","Temperature (K)"
100,0.10000000000000007,-1298358.3156858836,6594.891802480256
200,0.20000000000000015,-1297025.7032524203,16072.77916553958
300,0.3000000000000002,-1297570.2537676548,21468.322367087236
400,0.4000000000000003,-1297590.059707508,18354.094727838114
500,0.5000000000000003,-1297548.3815075015,15174.619776845926
600,0.6000000000000004,-1297802.9520432805,12882.61440885133
700,0.7000000000000005,-1298165.144314811,12254.39717469564
800,0.8000000000000006,-1297760.8873080467,9694.339541851743
900,0.9000000000000007,-1297937.7098039244,8065.8739320859795
1000,1.0000000000000007,-1298187.6265807603,110463.19667367659

ala_dp_traj.zip
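
For a sense of scale, the reported temperatures can be converted to kinetic energy through equipartition, KE = (N_dof / 2) R T. The sketch below assumes 66 degrees of freedom (22 atoms for alanine dipeptide, no constraints, and no centre-of-mass motion removal because all forces were removed from the system); `kinetic_energy_kj_per_mol` is an illustrative helper, not part of OpenMM.

```python
# Molar gas constant in kJ/(mol*K)
R_KJ_PER_MOL_K = 8.314462618e-3


def kinetic_energy_kj_per_mol(temperature_k, n_dof):
    """Equipartition estimate: KE = (n_dof / 2) * R * T."""
    return 0.5 * n_dof * R_KJ_PER_MOL_K * temperature_k


# Assumption: 22 atoms, no constraints, no CMMotionRemover -> 3 * 22 dof
N_DOF = 3 * 22

target = kinetic_energy_kj_per_mol(298.15, N_DOF)
first_report = kinetic_energy_kj_per_mol(6594.891802480256, N_DOF)

print(f"KE at 298.15 K: {target:.1f} kJ/mol")
print(f"KE at the first reported frame: {first_report:.1f} kJ/mol")
print(f"ratio: {first_report / target:.1f}x")
```

Under these assumptions, the first reported frame already carries roughly 1800 kJ/mol of kinetic energy, against roughly 82 kJ/mol expected at the thermostat temperature.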

@RaulPPelaez
Contributor

@stefdoerr we could use your eagle eye here

@JSLJ23
Author

JSLJ23 commented Aug 16, 2023

Linking a bit more information on this issue from here.
It seems that without the NNPOps TorchANISymmetryFunctions(), i.e. when just using ANI2x natively, there is a recurring issue with larger-than-normal force values and temperatures / kinetic energies.
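
One generic way to check whether force values are too large is to compare them with a central finite difference of the energy. The sketch below does this for a toy harmonic bond in plain Python; `harmonic_energy`, `analytic_forces`, and `finite_difference_forces` are hypothetical helpers written for illustration, and the same comparison would have to be adapted to energies and forces read from the actual simulation.

```python
import math


def harmonic_energy(positions, k=1000.0, r0=0.15):
    """Toy potential: one harmonic bond between two particles (kJ/mol, nm)."""
    r = math.dist(positions[0], positions[1])
    return 0.5 * k * (r - r0) ** 2


def analytic_forces(positions, k=1000.0, r0=0.15):
    """Analytic forces F = -dE/dx for the toy bond above."""
    a, b = positions
    r = math.dist(a, b)
    # The force on particle a points along (b - a) when the bond is stretched
    scale = k * (r - r0) / r
    fa = [scale * (bj - aj) for aj, bj in zip(a, b)]
    fb = [-f for f in fa]
    return [fa, fb]


def finite_difference_forces(energy_fn, positions, h=1e-5):
    """Central-difference estimate of F = -dE/dx for every coordinate."""
    forces = []
    for i, atom in enumerate(positions):
        row = []
        for j in range(len(atom)):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][j] += h
            minus[i][j] -= h
            row.append(-(energy_fn(plus) - energy_fn(minus)) / (2 * h))
        forces.append(row)
    return forces


positions = [[0.0, 0.0, 0.0], [0.12, 0.05, 0.10]]
reference = finite_difference_forces(harmonic_energy, positions)
reported = analytic_forces(positions)
max_error = max(
    abs(f - g) for fr, gr in zip(reference, reported) for f, g in zip(fr, gr)
)
print(f"largest force deviation: {max_error:.3e} kJ/mol/nm")
```

If the reported forces were inconsistent with the energy surface, the deviation would be large compared with the forces themselves rather than close to zero.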

@RaulPPelaez
Contributor

RaulPPelaez commented Aug 16, 2023

Some context:
I ran the example notebook (which is almost verbatim what @JSLJ23 is using) and it produces reasonable temperatures.
The notebook uses pytorch==1.11, and thus old versions of nnpops (0.2), openmm-torch (1.0), and torchani (2.2.2). I also tried some other combinations. For each one I create an env and install these:
mamba install openmm-torch nnpops torchani openmmtools pytorch=1.11
I then change the pytorch version and test. I consider this output "reasonable":

-1301523.8704208438
#"Step","Time (ps)","Potential Energy (kJ/mole)","Temperature (K)"
100,0.10000000000000007,-1301527.270483293,62.66180189563013
200,0.20000000000000015,-1301522.7641006862,97.65153999173894
300,0.3000000000000002,-1301515.40287374,118.82399211492178
400,0.4000000000000003,-1301512.0628259708,130.4099461070207
500,0.5000000000000003,-1301509.3890428697,137.24497926864706
600,0.6000000000000004,-1301508.4165233676,159.31845153306006
700,0.7000000000000005,-1301502.2501128023,150.63939303261853
800,0.8000000000000006,-1301505.847511657,174.9348822485912
900,0.9000000000000007,-1301503.4686377202,224.53581700945554
1000,1.0000000000000007,-1301498.117472202,168.15949337376438

while I consider this bogus:

#"Step","Time (ps)","Potential Energy (kJ/mole)","Temperature (K)"
100,0.10000000000000007,-1298613.4143821453,8083.915551056161
200,0.20000000000000015,-1297751.6999760126,10288.500040147252
300,0.3000000000000002,-1298380.3439685558,14269.652341144034
400,0.4000000000000003,-1298193.265926287,12031.456279865946
500,0.5000000000000003,-1298324.1451893304,10396.591828109476
600,0.6000000000000004,-1297342.4101866935,21304.54016036133
700,0.7000000000000005,-1297641.4798471783,81074.63725027355
800,0.8000000000000006,-1298022.843327031,67496.58902081563
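
The split between "reasonable" and "bogus" can be made a little more quantitative. In the canonical ensemble the instantaneous kinetic temperature has a relative standard deviation of sqrt(2 / N_dof), so frames far above the thermostat temperature can be flagged automatically. The sketch below parses StateDataReporter-style CSV text; the helper name `flag_hot_frames`, the 5-sigma threshold, and N_dof = 66 (22 unconstrained atoms) are assumptions made for illustration.

```python
import csv
import io
import math


def flag_hot_frames(report_text, target_k, n_dof, n_sigma=5.0):
    """Return (step, temperature) pairs far above the thermostat temperature."""
    sigma = target_k * math.sqrt(2.0 / n_dof)
    limit = target_k + n_sigma * sigma
    flagged = []
    for row in csv.reader(io.StringIO(report_text)):
        if not row or row[0].startswith("#"):
            continue  # skip blank lines and the header
        step, temperature = int(row[0]), float(row[-1])
        if temperature > limit:
            flagged.append((step, temperature))
    return flagged


# First rows of the two outputs quoted above (step, time, energy, temperature)
reasonable = """#"Step","Time (ps)","Potential Energy (kJ/mole)","Temperature (K)"
100,0.10000000000000007,-1301527.270483293,62.66180189563013
200,0.20000000000000015,-1301522.7641006862,97.65153999173894
"""
bogus = """#"Step","Time (ps)","Potential Energy (kJ/mole)","Temperature (K)"
100,0.10000000000000007,-1298613.4143821453,8083.915551056161
200,0.20000000000000015,-1297751.6999760126,10288.500040147252
"""

print("reasonable run:", flag_hot_frames(reasonable, target_k=298.15, n_dof=66))
print("bogus run:", flag_hot_frames(bogus, target_k=298.15, n_dof=66))
```

Only the upper side is checked, because a run started from zero velocities is expected to sit below the thermostat temperature while it equilibrates.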

I suspect NNPOps is not the problem, so I am using ANI2x directly with the following script:

import openmmtools

# Get the system of alanine dipeptide
ala2 = openmmtools.testsystems.AlanineDipeptideVacuum(constraints=None)

# Remove MM forces
while ala2.system.getNumForces() > 0:
  ala2.system.removeForce(0)

# The system should not contain any additional forces or constraints
assert ala2.system.getNumConstraints() == 0
assert ala2.system.getNumForces() == 0

# Get the list of atomic numbers
atomic_numbers = [atom.element.atomic_number for atom in ala2.topology.atoms()]
import torch as pt
from torchani.models import ANI2x
from NNPOps.BatchedNN import TorchANIBatchedNN
from NNPOps.EnergyShifter import TorchANIEnergyShifter, SpeciesEnergies
from NNPOps.SpeciesConverter import TorchANISpeciesConverter
from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions


class NNP(pt.nn.Module):

  def __init__(self, atomic_numbers):

    super().__init__()

    # Store the atomic numbers
    self.device="cuda"
    self.atomic_numbers = pt.tensor(atomic_numbers).unsqueeze(0).to(self.device)

    # Create an ANI-2x model
    self.model = ANI2x(periodic_table_index=True).to(self.device)

    # Accelerate the model
    #self.model = OptimizedTorchANI(self.model, self.atomic_numbers)
  def forward(self, positions):

    # Prepare the positions
    positions = positions.unsqueeze(0).float().to(self.device) * 10 # nm --> Å

    # Run ANI-2x
    result = self.model((self.atomic_numbers, positions))

    # Get the potential energy
    energy = result.energies[0] * 2625.5 # Hartree --> kJ/mol

    return energy

# Create an instance of the model
nnp = NNP(atomic_numbers)
# Compute the potential energy
pos = pt.tensor(ala2.positions.tolist())
energy_1 = float(nnp(pos))
print(energy_1)

# Check if the energy is correct
assert pt.isclose(pt.tensor(energy_1), pt.tensor(-1301523.8703817206))
from openmmtorch import TorchForce

# Save the NNP to a file and load it with OpenMM-Torch
pt.jit.script(nnp).save('model.pt')
force = TorchForce('model.pt')

# Add the NNP to the system
ala2.system.addForce(force)
assert ala2.system.getNumForces() == 1
import sys
from openmm import LangevinMiddleIntegrator
from openmm.app import Simulation, StateDataReporter
from openmm.unit import kelvin, picosecond, femtosecond

# Create an integrator with a time step of 1 fs
temperature = 298.15 * kelvin
frictionCoeff = 1 / picosecond
timeStep = 1 * femtosecond
integrator = LangevinMiddleIntegrator(temperature, frictionCoeff, timeStep)

# Create a simulation and set the initial positions and velocities
simulation = Simulation(ala2.topology, ala2.system, integrator)
simulation.context.setPositions(ala2.positions)
# simulation.context.setVelocitiesToTemperature(temperature) # This does not work (https://github.com/openmm/openmm-torch/issues/61)

# Configure a reporter to print to the console every 0.1 ps (100 steps)
reporter = StateDataReporter(file=sys.stdout, reportInterval=100, step=True, time=True, potentialEnergy=True, temperature=True)
simulation.reporters.append(reporter)
from openmm.unit import kilojoule_per_mole

# Compute the potential energy
state = simulation.context.getState(getEnergy=True)
energy_2 = state.getPotentialEnergy().value_in_unit(kilojoule_per_mole)
print(energy_2)

# Check if the energy is correct
assert pt.isclose(pt.tensor(energy_1), pt.tensor(energy_2))
# Run the simulations for 1 ps (1000 steps)
simulation.step(1000)
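
A note on the two `pt.isclose` checks in the script: with the default tolerances (`rtol=1e-05`, `atol=1e-08`) the comparison passes when |a - b| <= atol + rtol * |b|. At energies of about -1.3e6 kJ/mol that is a window of roughly 13 kJ/mol, while float32 resolves steps of 0.125 kJ/mol at this magnitude. The plain-Python sketch below reproduces that arithmetic; `isclose_window` and `float32_spacing` are illustrative helpers, not PyTorch functions.

```python
import math


def isclose_window(reference, rtol=1e-5, atol=1e-8):
    """Largest |a - b| accepted by the rule |a - b| <= atol + rtol * |b|."""
    return atol + rtol * abs(reference)


def float32_spacing(value):
    """Gap between adjacent float32 numbers near `value` (normal range only)."""
    exponent = math.floor(math.log2(abs(value)))
    return 2.0 ** (exponent - 23)


energy = -1301523.8703817206  # reference energy from the script, in kJ/mol

print(f"isclose window: {isclose_window(energy):.2f} kJ/mol")
print(f"float32 spacing: {float32_spacing(energy)} kJ/mol")
```

The assertions therefore confirm agreement only to within about 13 kJ/mol; a smaller `rtol` could be passed if a stricter comparison is wanted.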

Pytorch 2.0 (nnpops=0.6, openmm-torch=1.1, torchani=2.2.3)

Env:

# packages in environment at /shared/raul/mambaforge/envs/nnpops_bug:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
astunparse                1.6.3              pyhd8ed1ab_0    conda-forge
aws-c-auth                0.7.0                hbbaa140_3    conda-forge
aws-c-cal                 0.6.0                h93469e0_0    conda-forge
aws-c-common              0.8.23               hd590300_0    conda-forge
aws-c-compression         0.2.17               h862ab75_1    conda-forge
aws-c-event-stream        0.3.1                h9599702_1    conda-forge
aws-c-http                0.7.11               hbe98c3e_0    conda-forge
aws-c-io                  0.13.28              h3870b5a_0    conda-forge
aws-c-mqtt                0.9.0                h2e270ba_0    conda-forge
aws-c-s3                  0.3.13               heb0bb06_2    conda-forge
aws-c-sdkutils            0.1.12               h862ab75_0    conda-forge
aws-checksums             0.1.16               h862ab75_1    conda-forge
aws-crt-cpp               0.21.0               h87b6960_2    conda-forge
aws-sdk-cpp               1.10.57             h7062fed_18    conda-forge
blosc                     1.21.4               h0f2a231_0    conda-forge
brotli-python             1.0.9           py311ha362b79_9    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.19.1               hd590300_0    conda-forge
c-blosc2                  2.10.0               hb4ffafa_0    conda-forge
ca-certificates           2023.7.22            hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
certifi                   2023.7.22          pyhd8ed1ab_0    conda-forge
cftime                    1.6.2           py311h4c7f6c3_1    conda-forge
charset-normalizer        3.2.0              pyhd8ed1ab_0    conda-forge
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0              h4ba93d1_12    conda-forge
cudnn                     8.8.0.121            h0800d71_1    conda-forge
filelock                  3.12.2             pyhd8ed1ab_0    conda-forge
gmp                       6.2.1                h58526e2_0    conda-forge
gmpy2                     2.1.2           py311h6a5fa03_1    conda-forge
h5py                      3.9.0           nompi_py311he78b9b8_101    conda-forge
hdf4                      4.2.15               h501b40f_6    conda-forge
hdf5                      1.14.1          nompi_h4f84152_100    conda-forge
icu                       72.1                 hcb278e6_0    conda-forge
idna                      3.4                pyhd8ed1ab_0    conda-forge
importlib-metadata        6.8.0              pyha770c72_0    conda-forge
importlib_metadata        6.8.0                hd8ed1ab_0    conda-forge
jax                       0.4.14             pyhd8ed1ab_1    conda-forge
jaxlib                    0.4.14          cuda112py311hf2474b9_201    conda-forge
jinja2                    3.1.2              pyhd8ed1ab_1    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lark-parser               0.12.0             pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libabseil                 20230125.3      cxx17_h59595ed_0    conda-forge
libaec                    1.0.6                hcb278e6_1    conda-forge
libblas                   3.9.0            16_linux64_mkl    conda-forge
libcblas                  3.9.0            16_linux64_mkl    conda-forge
libcurl                   8.2.1                hca28451_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libexpat                  2.5.0                hcb278e6_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.1.0               he5830b7_0    conda-forge
libgfortran-ng            13.1.0               h69a702a_0    conda-forge
libgfortran5              13.1.0               h15d22d2_0    conda-forge
libgrpc                   1.54.3               hb20ce57_0    conda-forge
libhwloc                  2.9.2           nocuda_h7313eea_1008    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
libjpeg-turbo             2.1.5.1              h0b41bf4_0    conda-forge
liblapack                 3.9.0            16_linux64_mkl    conda-forge
libllvm14                 14.0.6               hcd5def8_4    conda-forge
libmagma                  2.7.1                hc72dce7_3    conda-forge
libmagma_sparse           2.7.1                hc72dce7_4    conda-forge
libnetcdf                 4.9.2           nompi_h7e745eb_109    conda-forge
libnghttp2                1.52.0               h61bc06f_0    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libprotobuf               3.21.12              h3eb15da_0    conda-forge
libsqlite                 3.42.0               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-ng              13.1.0               hfd8a6a1_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libxml2                   2.11.5               h0d562d8_0    conda-forge
libzip                    1.9.2                hc929e4a_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               16.0.6               h4dfa4b3_0    conda-forge
llvmlite                  0.40.1          py311ha6695c7_0    conda-forge
lz4-c                     1.9.4                hcb278e6_0    conda-forge
lzo                       2.10              h516909a_1000    conda-forge
magma                     2.7.1                ha770c72_4    conda-forge
markupsafe                2.1.3           py311h459d7ec_0    conda-forge
mdtraj                    1.9.9           py311h90fe790_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
ml_dtypes                 0.2.0           py311h320fe9a_1    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.0                hb012696_0    conda-forge
mpiplus                   v0.0.2             pyhd8ed1ab_0    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
nccl                      2.18.3.1             h12f7317_0    conda-forge
ncurses                   6.4                  hcb278e6_0    conda-forge
netcdf4                   1.6.4           nompi_py311h9a7c333_101    conda-forge
networkx                  3.1                pyhd8ed1ab_0    conda-forge
nnpops                    0.6             cuda112py311h86f5c52_0    conda-forge
nose                      1.3.7                   py_1006    conda-forge
numba                     0.57.1          py311h96b013e_0    conda-forge
numexpr                   2.8.4           mkl_py311hbaa3ca7_1    conda-forge
numpy                     1.24.4          py311h64a7726_0    conda-forge
ocl-icd                   2.3.1                h7f98852_0    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openmm                    8.0.0           py311h59c6c42_1    conda-forge
openmm-torch              1.1             cuda112py311h20aef98_0    conda-forge
openmmtools               0.23.1             pyhd8ed1ab_0    conda-forge
openssl                   3.1.2                hd590300_0    conda-forge
opt_einsum                3.3.0              pyhd8ed1ab_1    conda-forge
packaging                 23.1               pyhd8ed1ab_0    conda-forge
pandas                    2.0.3           py311h320fe9a_1    conda-forge
pdbfixer                  1.9                pyh1a96a4e_0    conda-forge
pip                       23.2.1             pyhd8ed1ab_0    conda-forge
platformdirs              3.10.0             pyhd8ed1ab_0    conda-forge
pooch                     1.7.0              pyha770c72_3    conda-forge
py-cpuinfo                9.0.0              pyhd8ed1ab_0    conda-forge
pymbar                    4.0.2                h38be061_0    conda-forge
pymbar-core               4.0.2           py311h1f0f07a_0    conda-forge
pyparsing                 3.1.1              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytables                  3.8.0           py311h504fbfb_2    conda-forge
python                    3.11.4          hab00c5b_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-tzdata             2023.3             pyhd8ed1ab_0    conda-forge
python_abi                3.11                    3_cp311    conda-forge
pytorch                   2.0.0           cuda112py311h13fee9e_200    conda-forge
pytorch-gpu               2.0.0           cuda112py311h9871d0b_200    conda-forge
pytz                      2023.3             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0             py311hd4cff14_5    conda-forge
re2                       2023.03.02           h8c504da_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
rocm-smi                  5.6.0                h59595ed_1    conda-forge
s2n                       1.3.46               h06160fa_0    conda-forge
scipy                     1.11.1          py311h64a7726_0    conda-forge
setuptools                65.3.0             pyhd8ed1ab_1    conda-forge
setuptools-scm            6.3.2              pyhd8ed1ab_0    conda-forge
setuptools_scm            6.3.2                hd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
snappy                    1.1.10               h9fff704_0    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
tbb                       2021.10.0            h00ab1b0_0    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
torchani                  2.2.3           cuda112py311he90cd52_2    conda-forge
typing-extensions         4.7.1                hd8ed1ab_0    conda-forge
typing_extensions         4.7.1              pyha770c72_0    conda-forge
tzdata                    2023c                h71feb2d_0    conda-forge
urllib3                   2.0.4              pyhd8ed1ab_0    conda-forge
wheel                     0.41.1             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zipp                      3.16.2             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zlib-ng                   2.0.7                h0b41bf4_0    conda-forge
zstd                      1.5.2                hfc55251_7    conda-forge
  • CPU ok
  • GPU bad

Pytorch 1.13 (nnpops=0.5, openmm-torch=1.0, torchani=2.2.2)

# packages in environment at /shared/raul/mambaforge/envs/nnpops_bug_pt13:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
astunparse                1.6.3              pyhd8ed1ab_0    conda-forge
aws-c-auth                0.7.0                hbbaa140_3    conda-forge
aws-c-cal                 0.6.0                h93469e0_0    conda-forge
aws-c-common              0.8.23               hd590300_0    conda-forge
aws-c-compression         0.2.17               h862ab75_1    conda-forge
aws-c-event-stream        0.3.1                h9599702_1    conda-forge
aws-c-http                0.7.11               hbe98c3e_0    conda-forge
aws-c-io                  0.13.28              h3870b5a_0    conda-forge
aws-c-mqtt                0.9.0                h2e270ba_0    conda-forge
aws-c-s3                  0.3.13               heb0bb06_2    conda-forge
aws-c-sdkutils            0.1.12               h862ab75_0    conda-forge
aws-checksums             0.1.16               h862ab75_1    conda-forge
aws-crt-cpp               0.21.0               h87b6960_2    conda-forge
aws-sdk-cpp               1.10.57             h7062fed_18    conda-forge
blosc                     1.21.4               h0f2a231_0    conda-forge
brotli-python             1.0.9           py310hd8f1fbe_9    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.19.1               hd590300_0    conda-forge
c-blosc2                  2.10.0               hb4ffafa_0    conda-forge
ca-certificates           2023.7.22            hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
certifi                   2023.7.22          pyhd8ed1ab_0    conda-forge
cffi                      1.15.1          py310h255011f_3    conda-forge
cftime                    1.6.2           py310hde88566_1    conda-forge
charset-normalizer        3.2.0              pyhd8ed1ab_0    conda-forge
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0              h4ba93d1_12    conda-forge
cudnn                     8.8.0.121            h0800d71_1    conda-forge
h5py                      3.9.0           nompi_py310hcca72df_101    conda-forge
hdf4                      4.2.15               h501b40f_6    conda-forge
hdf5                      1.14.1          nompi_h4f84152_100    conda-forge
icu                       72.1                 hcb278e6_0    conda-forge
idna                      3.4                pyhd8ed1ab_0    conda-forge
importlib-metadata        6.8.0              pyha770c72_0    conda-forge
importlib_metadata        6.8.0                hd8ed1ab_0    conda-forge
jax                       0.4.14             pyhd8ed1ab_1    conda-forge
jaxlib                    0.4.14          cpu_py310h67d73b5_1    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lark-parser               0.12.0             pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libabseil                 20230125.3      cxx17_h59595ed_0    conda-forge
libaec                    1.0.6                hcb278e6_1    conda-forge
libblas                   3.9.0           17_linux64_openblas    conda-forge
libcblas                  3.9.0           17_linux64_openblas    conda-forge
libcurl                   8.2.1                hca28451_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.1.0               he5830b7_0    conda-forge
libgfortran-ng            13.1.0               h69a702a_0    conda-forge
libgfortran5              13.1.0               h15d22d2_0    conda-forge
libgrpc                   1.54.3               hb20ce57_0    conda-forge
libhwloc                  2.9.2           nocuda_h7313eea_1008    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
libjpeg-turbo             2.1.5.1              h0b41bf4_0    conda-forge
liblapack                 3.9.0           17_linux64_openblas    conda-forge
libllvm14                 14.0.6               hcd5def8_4    conda-forge
libnetcdf                 4.9.2           nompi_h7e745eb_109    conda-forge
libnghttp2                1.52.0               h61bc06f_0    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libopenblas               0.3.23          pthreads_h80387f5_0    conda-forge
libprotobuf               3.21.12              h3eb15da_0    conda-forge
libsqlite                 3.42.0               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-ng              13.1.0               hfd8a6a1_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libxml2                   2.11.5               h0d562d8_0    conda-forge
libzip                    1.9.2                hc929e4a_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               16.0.6               h4dfa4b3_0    conda-forge
llvmlite                  0.40.1          py310h1b8f574_0    conda-forge
lz4-c                     1.9.4                hcb278e6_0    conda-forge
lzo                       2.10              h516909a_1000    conda-forge
magma                     2.6.2                hc72dce7_0    conda-forge
mdtraj                    1.9.9           py310h8e08b51_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
ml_dtypes                 0.2.0           py310h7cbd5c2_1    conda-forge
mpiplus                   v0.0.2             pyhd8ed1ab_0    conda-forge
nccl                      2.18.3.1             h12f7317_0    conda-forge
ncurses                   6.4                  hcb278e6_0    conda-forge
netcdf4                   1.6.4           nompi_py310h6f5dce6_101    conda-forge
ninja                     1.11.1               h924138e_0    conda-forge
nnpops                    0.5             cuda112py310hd4d1af5_0    conda-forge
nose                      1.3.7                   py_1006    conda-forge
numba                     0.57.1          py310h0f6aa51_0    conda-forge
numexpr                   2.7.3           py310hb5077e9_1    conda-forge
numpy                     1.24.4          py310ha4c1d20_0    conda-forge
ocl-icd                   2.3.1                h7f98852_0    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openmm                    8.0.0           py310h5728c26_1    conda-forge
openmm-torch              1.0             cuda112py310hbd91edb_1    conda-forge
openmmtools               0.23.1             pyhd8ed1ab_0    conda-forge
openssl                   3.1.2                hd590300_0    conda-forge
opt_einsum                3.3.0              pyhd8ed1ab_1    conda-forge
packaging                 23.1               pyhd8ed1ab_0    conda-forge
pandas                    2.0.3           py310h7cbd5c2_1    conda-forge
pdbfixer                  1.9                pyh1a96a4e_0    conda-forge
pip                       23.2.1             pyhd8ed1ab_0    conda-forge
platformdirs              3.10.0             pyhd8ed1ab_0    conda-forge
pooch                     1.7.0              pyha770c72_3    conda-forge
py-cpuinfo                9.0.0              pyhd8ed1ab_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pymbar                    4.0.2                hff52083_0    conda-forge
pymbar-core               4.0.2           py310h278f3c1_0    conda-forge
pyparsing                 3.1.1              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytables                  3.8.0           py310ha028ce3_2    conda-forge
python                    3.10.12         hd12c33a_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-tzdata             2023.3             pyhd8ed1ab_0    conda-forge
python_abi                3.10                    3_cp310    conda-forge
pytorch                   1.13.1          cuda112py310he33e0d6_200    conda-forge
pytz                      2023.3             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0             py310h5764c6d_5    conda-forge
re2                       2023.03.02           h8c504da_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
rocm-smi                  5.6.0                h59595ed_1    conda-forge
s2n                       1.3.46               h06160fa_0    conda-forge
scipy                     1.11.1          py310ha4c1d20_0    conda-forge
setuptools                59.5.0          py310hff52083_0    conda-forge
setuptools-scm            6.3.2              pyhd8ed1ab_0    conda-forge
setuptools_scm            6.3.2                hd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
snappy                    1.1.10               h9fff704_0    conda-forge
tbb                       2021.10.0            h00ab1b0_0    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
torchani                  2.2.2           cuda112py310haf08e2f_7    conda-forge
typing-extensions         4.7.1                hd8ed1ab_0    conda-forge
typing_extensions         4.7.1              pyha770c72_0    conda-forge
tzdata                    2023c                h71feb2d_0    conda-forge
urllib3                   2.0.4              pyhd8ed1ab_0    conda-forge
wheel                     0.41.1             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zipp                      3.16.2             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zlib-ng                   2.0.7                h0b41bf4_0    conda-forge
zstd                      1.5.2                hfc55251_7    conda-forge
  • CPU ok
  • GPU bad

PyTorch 1.11 (nnpops=0.2, openmm-torch=1.0, torchani=2.2.2)

Env:

# packages in environment at /shared/raul/mambaforge/envs/nnpops_bug_pt11:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
absl-py                   1.4.0              pyhd8ed1ab_0    conda-forge
astunparse                1.6.3              pyhd8ed1ab_0    conda-forge
aws-c-auth                0.7.0                hbbaa140_3    conda-forge
aws-c-cal                 0.6.0                h93469e0_0    conda-forge
aws-c-common              0.8.23               hd590300_0    conda-forge
aws-c-compression         0.2.17               h862ab75_1    conda-forge
aws-c-event-stream        0.3.1                h9599702_1    conda-forge
aws-c-http                0.7.11               hbe98c3e_0    conda-forge
aws-c-io                  0.13.28              h3870b5a_0    conda-forge
aws-c-mqtt                0.9.0                h2e270ba_0    conda-forge
aws-c-s3                  0.3.13               heb0bb06_2    conda-forge
aws-c-sdkutils            0.1.12               h862ab75_0    conda-forge
aws-checksums             0.1.16               h862ab75_1    conda-forge
aws-crt-cpp               0.21.0               h87b6960_2    conda-forge
aws-sdk-cpp               1.10.57             h7062fed_18    conda-forge
blosc                     1.21.4               h0f2a231_0    conda-forge
brotli-python             1.0.9           py310hd8f1fbe_9    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.19.1               hd590300_0    conda-forge
c-blosc2                  2.10.0               hb4ffafa_0    conda-forge
ca-certificates           2023.7.22            hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
certifi                   2023.7.22          pyhd8ed1ab_0    conda-forge
cffi                      1.15.1          py310h255011f_3    conda-forge
cftime                    1.6.2           py310hde88566_1    conda-forge
charset-normalizer        3.2.0              pyhd8ed1ab_0    conda-forge
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0              h4ba93d1_12    conda-forge
cudnn                     8.8.0.121            h0800d71_1    conda-forge
grpc-cpp                  1.47.1               hc2bec63_6    conda-forge
h5py                      3.9.0           nompi_py310hcca72df_101    conda-forge
hdf4                      4.2.15               h501b40f_6    conda-forge
hdf5                      1.14.1          nompi_h4f84152_100    conda-forge
icu                       72.1                 hcb278e6_0    conda-forge
idna                      3.4                pyhd8ed1ab_0    conda-forge
importlib-metadata        6.8.0              pyha770c72_0    conda-forge
importlib_metadata        6.8.0                hd8ed1ab_0    conda-forge
jax                       0.4.1              pyhd8ed1ab_0    conda-forge
jaxlib                    0.3.22          cuda112py310hfa36681_200    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lark-parser               0.12.0             pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libabseil                 20220623.0      cxx17_h05df665_6    conda-forge
libaec                    1.0.6                hcb278e6_1    conda-forge
libblas                   3.9.0            16_linux64_mkl    conda-forge
libcblas                  3.9.0            16_linux64_mkl    conda-forge
libcurl                   8.2.1                hca28451_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.1.0               he5830b7_0    conda-forge
libgfortran-ng            13.1.0               h69a702a_0    conda-forge
libgfortran5              13.1.0               h15d22d2_0    conda-forge
libhwloc                  2.9.2           nocuda_h7313eea_1008    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
libjpeg-turbo             2.1.5.1              h0b41bf4_0    conda-forge
liblapack                 3.9.0            16_linux64_mkl    conda-forge
libllvm14                 14.0.6               hcd5def8_4    conda-forge
libnetcdf                 4.9.2           nompi_h7e745eb_109    conda-forge
libnghttp2                1.52.0               h61bc06f_0    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libprotobuf               3.20.3               h3eb15da_0    conda-forge
libsqlite                 3.42.0               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-ng              13.1.0               hfd8a6a1_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libxml2                   2.11.5               h0d562d8_0    conda-forge
libzip                    1.9.2                hc929e4a_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               16.0.6               h4dfa4b3_0    conda-forge
llvmlite                  0.40.1          py310h1b8f574_0    conda-forge
lz4-c                     1.9.4                hcb278e6_0    conda-forge
lzo                       2.10              h516909a_1000    conda-forge
magma                     2.5.4                hc72dce7_4    conda-forge
mdtraj                    1.9.9           py310h8e08b51_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
mpiplus                   v0.0.2             pyhd8ed1ab_0    conda-forge
nccl                      2.18.3.1             h12f7317_0    conda-forge
ncurses                   6.4                  hcb278e6_0    conda-forge
netcdf4                   1.6.4           nompi_py310h6f5dce6_101    conda-forge
ninja                     1.11.1               h924138e_0    conda-forge
nnpops                    0.2             cuda112py310h85a0d14_4    conda-forge
nose                      1.3.7                   py_1006    conda-forge
numba                     0.57.1          py310h0f6aa51_0    conda-forge
numexpr                   2.8.4           mkl_py310hab9d358_1    conda-forge
numpy                     1.24.4          py310ha4c1d20_0    conda-forge
ocl-icd                   2.3.1                h7f98852_0    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openmm                    8.0.0           py310h5728c26_1    conda-forge
openmm-torch              1.0             cuda112py310hdb05021_1    conda-forge
openmmtools               0.23.1             pyhd8ed1ab_0    conda-forge
openssl                   3.1.2                hd590300_0    conda-forge
opt_einsum                3.3.0              pyhd8ed1ab_1    conda-forge
packaging                 23.1               pyhd8ed1ab_0    conda-forge
pandas                    2.0.3           py310h7cbd5c2_1    conda-forge
pdbfixer                  1.9                pyh1a96a4e_0    conda-forge
pip                       23.2.1             pyhd8ed1ab_0    conda-forge
platformdirs              3.10.0             pyhd8ed1ab_0    conda-forge
pooch                     1.7.0              pyha770c72_3    conda-forge
py-cpuinfo                9.0.0              pyhd8ed1ab_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pymbar                    4.0.2                hff52083_0    conda-forge
pymbar-core               4.0.2           py310h278f3c1_0    conda-forge
pyparsing                 3.1.1              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytables                  3.8.0           py310ha028ce3_2    conda-forge
python                    3.10.12         hd12c33a_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-tzdata             2023.3             pyhd8ed1ab_0    conda-forge
python_abi                3.10                    3_cp310    conda-forge
pytorch                   1.11.0          cuda112py310h51fe464_202    conda-forge
pytz                      2023.3             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0             py310h5764c6d_5    conda-forge
re2                       2022.06.01           h27087fc_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
rocm-smi                  5.6.0                h59595ed_1    conda-forge
s2n                       1.3.46               h06160fa_0    conda-forge
scipy                     1.11.1          py310ha4c1d20_0    conda-forge
setuptools                59.5.0          py310hff52083_0    conda-forge
setuptools-scm            6.3.2              pyhd8ed1ab_0    conda-forge
setuptools_scm            6.3.2                hd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
snappy                    1.1.10               h9fff704_0    conda-forge
tbb                       2021.10.0            h00ab1b0_0    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
torchani                  2.2.2           cuda112py310h73d5bcf_5    conda-forge
typing-extensions         4.7.1                hd8ed1ab_0    conda-forge
typing_extensions         4.7.1              pyha770c72_0    conda-forge
tzdata                    2023c                h71feb2d_0    conda-forge
urllib3                   2.0.4              pyhd8ed1ab_0    conda-forge
wheel                     0.41.1             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zipp                      3.16.2             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zlib-ng                   2.0.7                h0b41bf4_0    conda-forge
zstd                      1.5.2                hfc55251_7    conda-forge
  • CPU ok
  • GPU ok

@RaulPPelaez
Contributor

OK, if I use OptimizedTorchANI from NNPOps instead, then everything works with all three PyTorch versions:

import torch as pt
from NNPOps import OptimizedTorchANI
from torchani.models import ANI2x

class NNP(pt.nn.Module):

  def __init__(self, atomic_numbers):

    super().__init__()

    # Keep the model and its inputs on the GPU
    self.device = "cuda"

    # Store the atomic numbers
    self.atomic_numbers = pt.tensor(atomic_numbers).unsqueeze(0).to(self.device)

    # Create an ANI-2x model
    self.model = ANI2x(periodic_table_index=True).to(self.device)

    # Accelerate the model with the NNPOps implementation
    self.model = OptimizedTorchANI(self.model, self.atomic_numbers).to(self.device)

With these results my conclusion is that this is actually a bug in torchani.

@RaulPPelaez
Contributor

RaulPPelaez commented Aug 16, 2023

It looks like @sef43 already encountered and reported this:
aiqm/torchani#628
Indeed, the fix suggested there also works here. Adding the following line fixes the issue originally reported:
torch._C._jit_set_nvfuser_enabled(False)
A similar investigation was carried out here, with the same conclusion: openmm/openmm-ml#50
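
For anyone applying the workaround, here is a minimal sketch of one way to call it defensively. The helper name `disable_nvfuser` is hypothetical, and `torch._C._jit_set_nvfuser_enabled` is a private PyTorch function, so I am assuming it may be missing or behave differently in releases other than the ones tested in this thread. The sketch only makes the call when the function is available:

```python
import warnings


def disable_nvfuser() -> bool:
    """Try to turn off NVFuser for TorchScript; return True if the call was made."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so there is nothing to configure.
        return False

    # Private API (assumption: it may be absent or deprecated in other releases).
    setter = getattr(torch._C, "_jit_set_nvfuser_enabled", None)
    if setter is None:
        return False

    try:
        with warnings.catch_warnings():
            # Some builds may emit a deprecation warning for this call.
            warnings.simplefilter("ignore")
            setter(False)
    except RuntimeError:
        # The build rejected the call; leave the JIT settings untouched.
        return False
    return True


# Run this near the top of the script, before the model is scripted and executed.
nvfuser_disabled = disable_nvfuser()
print("NVFuser disabled:", nvfuser_disabled)
```

If the helper returns False, the setting was not changed, and it is worth checking the GPU results against a CPU run before trusting them.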

@JSLJ23
Author

JSLJ23 commented Aug 17, 2023

Wow, thank you so much for tracking this down across the other reported issues!
I think for now I can keep using it on Ada GPUs without having to set torch._C._jit_set_nvfuser_enabled(False) yet. But on older hardware I'll bear in mind to set this.
Any idea why the NNPOps version does not have the error when it gets JIT-compiled?

@RaulPPelaez
Contributor

I did not dig deep enough into torchani to understand their bug, but it seems they stumbled upon an obscure nvfuser bug. The NNPOps implementation is simply different and thus does not trigger that particular behavior.
You were unlucky in the sense that the one component that is missing for your GPU in the NNPOps conda release happens to be buggy in the torchani implementation.

@JSLJ23
Author

JSLJ23 commented Aug 18, 2023

OK, I will close this issue then, as the bug exists within torchANI itself and using the NNPOps version works perfectly fine. Thank you!

@sef43

sef43 commented Aug 21, 2023

@RaulPPelaez I apologize for not seeing this sooner, and for you having to rediscover the bug.

Yes, it's a bug in PyTorch NVFuser. When I first investigated, it only seemed to occur for a large number of atoms (e.g. > 1000). This is why it occurred in the mixed system example openmm/openmm-ml#50 but not when running the alanine dipeptide in vacuum example. It seems to occur all the time now? Apparently it has been fixed, but the fix is not yet in a released version of PyTorch: pytorch/pytorch#84510

@sef43

sef43 commented Aug 21, 2023

I think this issue should be reopened until the PyTorch fix is released or we document the issue and workaround in the openmm-torch README.
