Fix interoperability with CustomCVForce (#80)
* Add a test with CustomCVForce

* Test all the platforms

* Add an interoperability test for TorchANI and NNPOps

* Add missing dependencies

* Skip for MacOS

* Move imports

* Fix import

* Retain the primary context

* Switch the contexts properly

* Set the oldest CUDA to 11.0

* Fix nvcc version

* Enable an extra check

* Clean up a temporary file

* Add more checks

* Add comments

* Remove a sync and clean up

* Move the primary context activation
Raimondas Galvelis authored Jul 8, 2022
1 parent 661b004 commit 5dc7279
Showing 6 changed files with 132 additions and 19 deletions.
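
The heart of the fix is visible in the CudaTorchKernels.cpp diff below: PyTorch always runs in the device's primary CUDA context, while OpenMM's CUDA platform uses its own context, so the kernel now retains the primary context for its whole lifetime and pushes/pops it around every PyTorch call. Here is a minimal sketch of that pattern, assuming only the CUDA driver API; the PrimaryContextGuard wrapper is illustrative and not part of the commit:

```cpp
// Illustrative sketch of the context-management pattern adopted in this
// commit; the class is hypothetical, but the driver API calls mirror
// those used in CudaTorchKernels.cpp.
#include <cuda.h>
#include <cassert>

class PrimaryContextGuard {
public:
    explicit PrimaryContextGuard(CUdevice device) : device(device) {
        // Retain the device's primary context (the one PyTorch uses),
        // so it stays alive for the lifetime of the kernel.
        cuDevicePrimaryCtxRetain(&primary, device);
    }
    ~PrimaryContextGuard() {
        cuDevicePrimaryCtxRelease(device);
    }
    void push() {
        // Make the primary context current before any PyTorch call.
        cuCtxPushCurrent(primary);
    }
    void pop() {
        // Restore the previous context and check the stack is balanced.
        CUcontext popped;
        cuCtxPopCurrent(&popped);
        assert(popped == primary);
    }
private:
    CUdevice device;
    CUcontext primary;
};
```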
7 changes: 4 additions & 3 deletions .github/workflows/CI.yml
@@ -23,11 +23,12 @@ jobs:
matrix:
include:
# Oldest supported versions
- name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.11)
# NOTE: re-enable CUDA 10.2 when it is supported by NNPOps (https://github.com/conda-forge/nnpops-feedstock/pull/8)
- name: Linux (CUDA 11.0, Python 3.7, PyTorch 1.11)
os: ubuntu-18.04
cuda-version: "10.2.89"
cuda-version: "11.0.3"
gcc-version: "8.5.*"
nvcc-version: "10.2"
nvcc-version: "11.0"
python-version: "3.7"
pytorch-version: "1.11.*"

4 changes: 3 additions & 1 deletion devtools/conda-envs/build-ubuntu-18.04.yml
@@ -6,6 +6,7 @@ dependencies:
- cudatoolkit @CUDATOOLKIT_VERSION@
- gxx_linux-64 @GCC_VERSION@
- make
- nnpops
- nvcc_linux-64 @NVCC_VERSION@
- ocl-icd
- openmm >=7.7
@@ -15,4 +16,5 @@ dependencies:
- python
- pytorch-gpu @PYTORCH_VERSION@
- swig
- sysroot_linux-64 2.17
- sysroot_linux-64 2.17
- torchani
43 changes: 36 additions & 7 deletions platforms/cuda/src/CudaTorchKernels.cpp
@@ -49,7 +49,14 @@ if (result != CUDA_SUCCESS) { \
throw OpenMMException(m.str());\
}

CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) :
CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) {
// Explicitly activate the primary context
CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context");
}

CudaCalcTorchForceKernel::~CudaCalcTorchForceKernel() {
cuDevicePrimaryCtxRelease(cu.getDevice());
}

void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce& force, torch::jit::script::Module& module) {
@@ -60,6 +67,11 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
globalNames.push_back(force.getGlobalParameterName(i));
int numParticles = system.getNumParticles();

// Push the PyTorch context
// NOTE: PyTorch always uses the primary context.
// It makes the primary context current, if it is not already.
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");

// Initialize CUDA objects for PyTorch
const torch::Device device(torch::kCUDA, cu.getDeviceIndex()); // This implicitly initializes PyTorch
module.to(device);
@@ -69,8 +81,13 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
posTensor = torch::empty({numParticles, 3}, options.requires_grad(!outputsForces));
boxTensor = torch::empty({3, 3}, options);

// Pop the PyTorch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
assert(primaryContext == ctx); // Check that PyTorch hasn't messed up the context stack

// Initialize CUDA objects for OpenMM-Torch
ContextSelector selector(cu);
ContextSelector selector(cu); // Switch to the OpenMM context
map<string, string> defines;
CUmodule program = cu.createModule(CudaTorchKernelSources::torchForce, defines);
copyInputsKernel = cu.getKernel(program, "copyInputs");
@@ -80,6 +97,9 @@
double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
int numParticles = cu.getNumAtoms();

// Push the PyTorch context
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");

// Get pointers to the atomic positions and simulation box
void* posData;
void* boxData;
@@ -94,11 +114,11 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce

// Copy the atomic positions and simulation box to PyTorch tensors
{
ContextSelector selector(cu);
ContextSelector selector(cu); // Switch to the OpenMM context
void* inputArgs[] = {&posData, &boxData, &cu.getPosq().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(),
&numParticles, cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer()};
cu.executeKernel(copyInputsKernel, inputArgs, numParticles);
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the PyTorch context
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
}

// Prepare the input of the PyTorch model
@@ -138,21 +158,30 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
forceTensor = forceTensor.to(torch::kFloat32);
forceData = forceTensor.data_ptr<float>();
}
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the OpenMM context
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the OpenMM context

// Add the computed forces to the total atomic forces
{
ContextSelector selector(cu);
ContextSelector selector(cu); // Switch to the OpenMM context
int paddedNumAtoms = cu.getPaddedNumAtoms();
int forceSign = (outputsForces ? 1 : -1);
void* forceArgs[] = {&forceData, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms, &forceSign};
cu.executeKernel(addForcesKernel, forceArgs, numParticles);
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the PyTorch context
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
}

// Reset the forces
if (!outputsForces)
posTensor.grad().zero_();
}
return energyTensor.item<double>(); // This implicitly synchronize the PyTorch context

// Get energy
const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context

// Pop the PyTorch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
assert(primaryContext == ctx); // Check that the correct context was popped

return energy;
}
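
Taken together, execute() brackets every kernel launch with an explicit synchronization before switching contexts, because work queued in one context is not ordered with respect to work in the other. A condensed sketch of the resulting choreography follows; runCopyInputs(), runModel() and runAddForces() are hypothetical placeholders for the real launches, and the real code uses OpenMM's ContextSelector rather than pushing the OpenMM context by hand:

```cpp
#include <cuda.h>

// Hypothetical placeholders for the real launches in CudaTorchKernels.cpp.
void runCopyInputs(); // OpenMM kernel: copy positions/box into PyTorch tensors
void runModel();      // PyTorch forward (and, if needed, backward) pass
void runAddForces();  // OpenMM kernel: add the computed forces to OpenMM's forces

// Condensed control flow of execute(); error checking is elided and
// openmmContext stands in for the context that ContextSelector activates.
void executeSketch(CUcontext primaryContext, CUcontext openmmContext) {
    CUcontext popped;
    cuCtxPushCurrent(primaryContext); // enter PyTorch's primary context

    cuCtxPushCurrent(openmmContext);  // switch to the OpenMM context
    runCopyInputs();
    cuCtxSynchronize();               // finish OpenMM work before switching
    cuCtxPopCurrent(&popped);         // back to the primary context

    runModel();
    cuCtxSynchronize();               // finish PyTorch work before switching

    cuCtxPushCurrent(openmmContext);  // switch to the OpenMM context again
    runAddForces();
    cuCtxSynchronize();
    cuCtxPopCurrent(&popped);

    cuCtxPopCurrent(&popped);         // restore whatever was current on entry
}
```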
6 changes: 2 additions & 4 deletions platforms/cuda/src/CudaTorchKernels.h
@@ -34,7 +34,6 @@

#include "TorchKernels.h"
#include "openmm/cuda/CudaContext.h"
#include "openmm/cuda/CudaArray.h"

namespace TorchPlugin {

@@ -43,9 +42,7 @@ namespace TorchPlugin {
*/
class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
public:
CudaCalcTorchForceKernel(std::string name, const OpenMM::Platform& platform, OpenMM::CudaContext& cu) :
CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) {
}
CudaCalcTorchForceKernel(std::string name, const OpenMM::Platform& platform, OpenMM::CudaContext& cu);
~CudaCalcTorchForceKernel();
/**
* Initialize the kernel.
@@ -72,6 +69,7 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
std::vector<std::string> globalNames;
bool usePeriodic, outputsForces;
CUfunction copyInputsKernel, addForcesKernel;
CUcontext primaryContext;
};

} // namespace TorchPlugin
71 changes: 71 additions & 0 deletions python/tests/TestInteroperability.py
@@ -0,0 +1,71 @@
import openmm as mm
import openmm.unit as unit
import openmmtorch as ot
import platform
import pytest
from tempfile import NamedTemporaryFile
import torch as pt


@pytest.mark.skipif(platform.system() == 'Darwin', reason='There is no NNPOps package for MacOS')
@pytest.mark.parametrize('use_cv_force', [True, False])
@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL'])
def testTorchANI(use_cv_force, platform):

if pt.cuda.device_count() < 1 and platform == 'CUDA':
pytest.skip('A CUDA device is not available')

import NNPOps # There is no NNPOps package for MacOS
import torchani

class Model(pt.nn.Module):

def __init__(self):
super().__init__()
self.register_buffer('atomic_numbers', pt.tensor([[1, 1]]))
self.model = torchani.models.ANI2x(periodic_table_index=True)
self.model = NNPOps.OptimizedTorchANI(self.model, self.atomic_numbers)

def forward(self, positions):
positions = positions.float().unsqueeze(0) * 10 # nm --> Ang
return self.model((self.atomic_numbers, positions)).energies[0] * 2625.5 # Hartree --> kJ/mol

# Create a system
system = mm.System()
for _ in range(2):
system.addParticle(1.0)
positions = pt.tensor([[-5, 0.0, 0.0], [5, 0.0, 0.0]], requires_grad=True)

with NamedTemporaryFile() as model_file:

# Save the model
pt.jit.script(Model()).save(model_file.name)

# Compute reference energy and forces
model = pt.jit.load(model_file)
ref_energy = model(positions)
ref_energy.backward()
ref_forces = positions.grad

# Create a force
force = ot.TorchForce(model_file.name)
if use_cv_force:
# Wrap TorchForce into CustomCVForce
cv_force = mm.CustomCVForce('force')
cv_force.addCollectiveVariable('force', force)
system.addForce(cv_force)
else:
system.addForce(force)

# Compute energy and forces
integ = mm.VerletIntegrator(1.0)
platform = mm.Platform.getPlatformByName(platform)
context = mm.Context(system, integ, platform)
context.setPositions(positions.detach().numpy())
state = context.getState(getEnergy=True, getForces=True)
energy = state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
forces = state.getForces(asNumpy=True).value_in_unit(unit.kilojoules_per_mole/unit.nanometers)

# Check energy and forces
assert pt.allclose(ref_energy, pt.tensor(energy, dtype=ref_energy.dtype))
assert pt.allclose(ref_forces, pt.tensor(forces, dtype=ref_forces.dtype))
20 changes: 16 additions & 4 deletions python/tests/TestTorchForce.py
@@ -9,25 +9,37 @@
@pytest.mark.parametrize('model_file, output_forces,',
[('../../tests/central.pt', False),
('../../tests/forces.pt', True)])
def testForce(model_file, output_forces):
@pytest.mark.parametrize('use_cv_force', [True, False])
@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL'])
def testForce(model_file, output_forces, use_cv_force, platform):

if pt.cuda.device_count() < 1 and platform == 'CUDA':
pytest.skip('A CUDA device is not available')

# Create a random cloud of particles.
numParticles = 10
system = mm.System()
positions = np.random.rand(numParticles, 3)
for i in range(numParticles):
for _ in range(numParticles):
system.addParticle(1.0)

# Create a force
force = ot.TorchForce(model_file)
assert not force.getOutputsForces() # Check the default
force.setOutputsForces(output_forces)
assert force.getOutputsForces() == output_forces
system.addForce(force)
if use_cv_force:
# Wrap TorchForce into CustomCVForce
cv_force = mm.CustomCVForce('force')
cv_force.addCollectiveVariable('force', force)
system.addForce(cv_force)
else:
system.addForce(force)

# Compute the forces and energy.
integ = mm.VerletIntegrator(1.0)
context = mm.Context(system, integ, mm.Platform.getPlatformByName('Reference'))
platform = mm.Platform.getPlatformByName(platform)
context = mm.Context(system, integ, platform)
context.setPositions(positions)
state = context.getState(getEnergy=True, getForces=True)

