From 6cbc78d409f3e1d682e0661c0d150c3629ad0fdb Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 14 Nov 2023 08:17:54 +0100 Subject: [PATCH 01/31] Conditionally set C++17 for latest pytorch versions --- CMakeLists.txt | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8573ec30..4c4e83a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ # OpenMM PyTorch Plugin #---------------------------------------------------- -CMAKE_MINIMUM_REQUIRED(VERSION 3.5) +CMAKE_MINIMUM_REQUIRED(VERSION 3.7) # We need to know where OpenMM is installed so we can access the headers and libraries. SET(OPENMM_DIR "/usr/local/openmm" CACHE PATH "Where OpenMM is installed") @@ -14,8 +14,15 @@ SET(PYTORCH_DIR "" CACHE PATH "Where the PyTorch C++ API is installed") SET(CMAKE_PREFIX_PATH "${PYTORCH_DIR}") FIND_PACKAGE(Torch REQUIRED) -# Specify the C++ version we are building for. -SET (CMAKE_CXX_STANDARD 14) +# Specify the C++ version we are building for. 
Latest pytorch versions require C++17 +message(STATUS "Found Torch: ${Torch_VERSION}") +if(${Torch_VERSION} VERSION_GREATER_EQUAL "2.1.0") + set(CMAKE_CXX_STANDARD 17) + message(STATUS "Setting C++ standard to C++17") +else() + set(CMAKE_CXX_STANDARD 14) + message(STATUS "Setting C++ standard to C++14") +endif() # Set flags for linking on mac IF(APPLE) From d26f9454883b14d9c9a3442f8541864eb4270614 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 14 Nov 2023 08:53:45 +0100 Subject: [PATCH 02/31] Add correct std to setup.py too --- python/CMakeLists.txt | 1 + python/setup.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index ee0c19dd..65593c5f 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -23,6 +23,7 @@ add_custom_command( add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}" "${CMAKE_CURRENT_SOURCE_DIR}/setup.py") set(NN_PLUGIN_HEADER_DIR "${CMAKE_SOURCE_DIR}/openmmapi/include") set(NN_PLUGIN_LIBRARY_DIR "${CMAKE_BINARY_DIR}") +set(EXTENSION_CXX_STANDARD ${CMAKE_CXX_STANDARD}) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/setup.py) add_custom_command(TARGET PythonInstall COMMAND "${PYTHON_EXECUTABLE}" -m pip install . 
diff --git a/python/setup.py b/python/setup.py index 4d03efe2..6748a57a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -6,10 +6,11 @@ torch_include_dirs = '@TORCH_INCLUDE_DIRS@'.split(';') nn_plugin_header_dir = '@NN_PLUGIN_HEADER_DIR@' nn_plugin_library_dir = '@NN_PLUGIN_LIBRARY_DIR@' +cpp_std = '@EXTENSION_CXX_STANDARD@' torch_dir, _ = os.path.split('@TORCH_LIBRARY@') # setup extra compile and link arguments on Mac -extra_compile_args = ['-std=c++14'] +extra_compile_args = ['-std=c++' + cpp_std] extra_link_args = [] if platform.system() == 'Darwin': From 0f894404b8fcdc6463761fa01e1ad70ad302e13c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 12 Dec 2023 12:47:09 +0100 Subject: [PATCH 03/31] Try to free up some space for the CI --- .github/workflows/CI.yml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f2bd6a44..28385f1f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -77,6 +77,30 @@ jobs: NVCC_VERSION: ${{ matrix.nvcc-version }} PYTORCH_VERSION: ${{ matrix.pytorch-version }} + - name: Manage disk space + if: matrix.os == 'ubuntu' + run: | + sudo mkdir -p /opt/empty_dir || true + for d in \ + /opt/ghc \ + /opt/hostedtoolcache \ + /usr/lib/jvm \ + /usr/local/.ghcup \ + /usr/local/lib/android \ + /usr/local/share/powershell \ + /usr/share/dotnet \ + /usr/share/swift \ + ; do + sudo rsync --stats -a --delete /opt/empty_dir/ $d || true + done + sudo apt-get purge -y -f firefox \ + google-chrome-stable \ + microsoft-edge-stable + sudo apt-get autoremove -y >& /dev/null + sudo apt-get autoclean -y >& /dev/null + sudo docker image prune --all --force + df -h + - uses: conda-incubator/setup-miniconda@v2 name: "Install dependencies with Mamba" with: From a604b330f2b2d9914055d429e68c4468897aa261 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 12 Dec 2023 13:06:12 +0100 Subject: [PATCH 04/31] Clean space on CI machine before installing CUDA --- 
.github/workflows/CI.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 28385f1f..4135959e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -100,6 +100,29 @@ jobs: sudo apt-get autoclean -y >& /dev/null sudo docker image prune --all --force df -h + + - name: "Install CUDA Toolkit on Linux (if needed)" + uses: Jimver/cuda-toolkit@v0.2.10 + with: + cuda: ${{ matrix.cuda-version }} + linux-local-args: '["--toolkit", "--override"]' + if: startsWith(matrix.os, 'ubuntu') + + - name: "Install SDK on MacOS (if needed)" + run: source devtools/scripts/install_macos_sdk.sh + if: startsWith(matrix.os, 'macos') + + - name: "Update the conda environment file" + uses: cschleiden/replace-tokens@v1 + with: + tokenPrefix: '@' + tokenSuffix: '@' + files: devtools/conda-envs/build-${{ matrix.os }}.yml + env: + CUDATOOLKIT_VERSION: ${{ matrix.cuda-version }} + GCC_VERSION: ${{ matrix.gcc-version }} + NVCC_VERSION: ${{ matrix.nvcc-version }} + PYTORCH_VERSION: ${{ matrix.pytorch-version }} - uses: conda-incubator/setup-miniconda@v2 name: "Install dependencies with Mamba" From 4b6bd7364295a1496289ee974bcfb0692aad67d0 Mon Sep 17 00:00:00 2001 From: Raul Date: Mon, 29 Jan 2024 11:14:26 +0100 Subject: [PATCH 05/31] Update CMakeLists.txt --- python/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 65593c5f..56e51189 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -23,7 +23,7 @@ add_custom_command( add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}" "${CMAKE_CURRENT_SOURCE_DIR}/setup.py") set(NN_PLUGIN_HEADER_DIR "${CMAKE_SOURCE_DIR}/openmmapi/include") set(NN_PLUGIN_LIBRARY_DIR "${CMAKE_BINARY_DIR}") -set(EXTENSION_CXX_STANDARD ${CMAKE_CXX_STANDARD}) +set(EXTENSION_CXX_STANDARD "${CMAKE_CXX_STANDARD}") configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py 
${CMAKE_CURRENT_BINARY_DIR}/setup.py) add_custom_command(TARGET PythonInstall COMMAND "${PYTHON_EXECUTABLE}" -m pip install . From a347019ab7792b3e1e0123c1f6c61d07bd3c8242 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 09:03:38 +0200 Subject: [PATCH 06/31] Add energy parameter derivative to the API --- openmmapi/include/TorchForce.h | 23 +++++++++++++++++++---- openmmapi/src/TorchForce.cpp | 15 +++++++++++++++ python/openmmtorch.i | 2 ++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/openmmapi/include/TorchForce.h b/openmmapi/include/TorchForce.h index 4406eb20..c3b0ce8a 100644 --- a/openmmapi/include/TorchForce.h +++ b/openmmapi/include/TorchForce.h @@ -58,8 +58,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @param file the path to the file containing the network * @param properties optional map of properties */ - TorchForce(const std::string& file, - const std::map& properties = {}); + TorchForce(const std::string& file, const std::map& properties = {}); /** * Create a TorchForce. The network is defined by a PyTorch ScriptModule * Note that this constructor makes a copy of the provided module. @@ -68,7 +67,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @param module an instance of the torch module * @param properties optional map of properties */ - TorchForce(const torch::jit::Module &module, const std::map& properties = {}); + TorchForce(const torch::jit::Module& module, const std::map& properties = {}); /** * Get the path to the file containing the network. * If the TorchForce instance was constructed with a module, instead of a filename, @@ -78,7 +77,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { /** * Get the torch module currently in use. */ - const torch::jit::Module & getModule() const; + const torch::jit::Module& getModule() const; /** * Set whether this force makes use of periodic boundary conditions. 
If this is set * to true, the network must take a 3x3 tensor as its second input, which @@ -116,6 +115,19 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @return the index of the parameter that was added */ int addGlobalParameter(const std::string& name, double defaultValue); + /** + * Add a new energy parameter derivative that the interaction may depend on. + * + * @param name the name of the parameter + */ + void addEnergyParameterDerivative(const std::string& name); + /** + * Get the name of an energy parameter derivative given its global parameter index. + * + * @param index the index of the parameter for which to get the name + * @return the parameter name + */ + const std::string& getEnergyParameterDerivativeName(int index) const; /** * Get the name of a global parameter. * @@ -156,13 +168,16 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @return A map of property names to values. */ const std::map& getProperties() const; + protected: OpenMM::ForceImpl* createImpl() const; + private: class GlobalParameterInfo; std::string file; bool usePeriodic, outputsForces; std::vector globalParameters; + std::vector energyParameterDerivatives; torch::jit::Module module; std::map properties; std::string emptyProperty; diff --git a/openmmapi/src/TorchForce.cpp b/openmmapi/src/TorchForce.cpp index e8892048..f3c0c1ef 100644 --- a/openmmapi/src/TorchForce.cpp +++ b/openmmapi/src/TorchForce.cpp @@ -88,6 +88,21 @@ int TorchForce::addGlobalParameter(const string& name, double defaultValue) { return globalParameters.size() - 1; } +void TorchForce::addEnergyParameterDerivative(const string& name) { + for (int i = 0; i < globalParameters.size(); i++) { + if (globalParameters[i].name == name) { + energyParameterDerivatives.push_back(i); + return; + } + } +} + +const std::string& TorchForce::getEnergyParameterDerivativeName(int index) const { + if (index < 0 || index >= energyParameterDerivatives.size()) + throw 
OpenMM::OpenMMException("TorchForce::getEnergyParameterDerivativeName: index out of range."); + return globalParameters[energyParameterDerivatives[index]].name; +} + int TorchForce::getNumGlobalParameters() const { return globalParameters.size(); } diff --git a/python/openmmtorch.i b/python/openmmtorch.i index 05988e3d..77a72e43 100644 --- a/python/openmmtorch.i +++ b/python/openmmtorch.i @@ -74,6 +74,8 @@ public: void setGlobalParameterName(int index, const std::string& name); double getGlobalParameterDefaultValue(int index) const; void setGlobalParameterDefaultValue(int index, double defaultValue); + void addEnergyParameterDerivative(const std::string& name); + const std::string& getEnergyParameterDerivativeName(int index) const; void setProperty(const std::string& name, const std::string& value); const std::map& getProperties() const; From e12b0af15b8951fae0bedc14f19f43d94f88a47d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 09:08:29 +0200 Subject: [PATCH 07/31] Empty commit to trigger CI From 11fb899d05f9a5db357d0063c1177af3c6326040 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 09:35:38 +0200 Subject: [PATCH 08/31] Change to int --- openmmapi/include/TorchForce.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openmmapi/include/TorchForce.h b/openmmapi/include/TorchForce.h index c3b0ce8a..79900d1e 100644 --- a/openmmapi/include/TorchForce.h +++ b/openmmapi/include/TorchForce.h @@ -177,7 +177,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { std::string file; bool usePeriodic, outputsForces; std::vector globalParameters; - std::vector energyParameterDerivatives; + std::vector energyParameterDerivatives; torch::jit::Module module; std::map properties; std::string emptyProperty; From cd35af11dc326962ceee7c534d72ea09c0c02f5d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 09:38:38 +0200 Subject: [PATCH 09/31] Add test for energy derivatives --- 
python/tests/TestParameterDerivatives.py | 76 ++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 python/tests/TestParameterDerivatives.py diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py new file mode 100644 index 00000000..c62bf7b7 --- /dev/null +++ b/python/tests/TestParameterDerivatives.py @@ -0,0 +1,76 @@ +import openmm as mm +import openmm.unit as unit +import openmmtorch as ot +import numpy as np +import pytest +import torch as pt +from torch import Tensor + + +class ForceWithParameters(pt.nn.Module): + + def __init__(self): + super(ForceWithParameters, self).__init__() + + def forward(self, positions: Tensor, parameter: Tensor) -> Tensor: + x2 = positions.pow(2).sum(dim=1) + u_harmonic = (parameter * x2).sum() + return u_harmonic + + +@pytest.mark.parametrize("use_cv_force", [True, False]) +@pytest.mark.parametrize("platform", ["Reference", "CPU", "CUDA", "OpenCL"]) +def testParameterEnergyDerivatives(use_cv_force, platform): + + if pt.cuda.device_count() < 1 and platform == "CUDA": + pytest.skip("A CUDA device is not available") + + # Create a random cloud of particles. + numParticles = 10 + system = mm.System() + positions = np.random.rand(numParticles, 3) + for _ in range(numParticles): + system.addParticle(1.0) + + # Create a force + pt_force = ForceWithParameters() + model = pt.jit.script(pt_force) + force = ot.TorchForce(model, {"useCUDAGraphs": "false"}) + # Add a parameter + parameter = 1.0 + force.addGlobalParameter("parameter", parameter) + # Enable energy derivatives for the parameter + force.setEnergyParameterDerivatives("parameter") + force.setOutputsForces(True) + if use_cv_force: + # Wrap TorchForce into CustomCVForce + cv_force = mm.CustomCVForce("force") + cv_force.addCollectiveVariable("force", force) + system.addForce(cv_force) + else: + system.addForce(force) + + # Compute the forces and energy. 
+ integ = mm.VerletIntegrator(1.0) + platform = mm.Platform.getPlatformByName(platform) + context = mm.Context(system, integ, platform) + context.setPositions(positions) + state = context.getState( + getEnergy=True, getForces=True, getEnergyParameterDerivatives=True + ) + + # See if the energy and forces and the parameter derivative are correct. + # The network defines a potential of the form E(r) = parameter*|r|^2 + r2 = np.sum(positions * positions) + expectedEnergy = parameter * r2 + assert np.allclose( + expectedEnergy, + state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole), + ) + assert np.allclose(-2 * parameter * positions, state.getForces(asNumpy=True)) + assert np.allclose( + 2 * r2, + state.getEnergyParameterDerivatives()["parameter"].value_in_unit( + unit.kilojoules_per_mole + ), + ) From b69456e07f3a5ce150f0b2162e8f7dbc06e2934b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 09:51:52 +0200 Subject: [PATCH 10/31] Update env --- devtools/conda-envs/build-ubuntu-22.04.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/devtools/conda-envs/build-ubuntu-22.04.yml b/devtools/conda-envs/build-ubuntu-22.04.yml index ffd4d2f7..091b45e1 100644 --- a/devtools/conda-envs/build-ubuntu-22.04.yml +++ b/devtools/conda-envs/build-ubuntu-22.04.yml @@ -18,3 +18,6 @@ dependencies: - swig - sysroot_linux-64 2.17 - torchani + # Required by python<3.8 + # xref: https://github.com/conda-forge/linux-sysroot-feedstock/issues/52 + - libxcrypt From fc5a85fa15f2b9a346003761e14fd979b63103bd Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 09:51:52 +0200 Subject: [PATCH 11/31] Update env --- devtools/conda-envs/build-ubuntu-22.04.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/devtools/conda-envs/build-ubuntu-22.04.yml b/devtools/conda-envs/build-ubuntu-22.04.yml index ffd4d2f7..091b45e1 100644 --- a/devtools/conda-envs/build-ubuntu-22.04.yml +++ b/devtools/conda-envs/build-ubuntu-22.04.yml @@ -18,3 +18,6 @@ dependencies: 
- swig - sysroot_linux-64 2.17 - torchani + # Required by python<3.8 + # xref: https://github.com/conda-forge/linux-sysroot-feedstock/issues/52 + - libxcrypt From 033b2d483866e2bf4ea5782624460e6bef90a061 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 10:05:54 +0200 Subject: [PATCH 12/31] Typo --- python/tests/TestParameterDerivatives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index c62bf7b7..cb30300b 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -40,7 +40,7 @@ def testParameterEnergyDerivatives(use_cv_force, platform): parameter = 1.0 force.addGlobalParameter("parameter", parameter) # Enable energy derivatives for the parameter - force.setEnergyParameterDerivatives("parameter") + force.addEnergyParameterDerivatives("parameter") force.setOutputsForces(True) if use_cv_force: # Wrap TorchForce into CustomCVForce From ded655914fce7ca630a55756dca18eba7b5e7011 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 10:15:17 +0200 Subject: [PATCH 13/31] typo --- python/tests/TestParameterDerivatives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index cb30300b..3d8e339e 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -40,7 +40,7 @@ def testParameterEnergyDerivatives(use_cv_force, platform): parameter = 1.0 force.addGlobalParameter("parameter", parameter) # Enable energy derivatives for the parameter - force.addEnergyParameterDerivatives("parameter") + force.addEnergyParameterDerivative("parameter") force.setOutputsForces(True) if use_cv_force: # Wrap TorchForce into CustomCVForce From 55de9ee606c527a7673a3bfd7ecd56487c09f50e Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 11:12:03 +0200 Subject: [PATCH 
14/31] Add getNumEnergyParameterDerivatives --- openmmapi/include/TorchForce.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/openmmapi/include/TorchForce.h b/openmmapi/include/TorchForce.h index 79900d1e..53b1b56f 100644 --- a/openmmapi/include/TorchForce.h +++ b/openmmapi/include/TorchForce.h @@ -128,6 +128,13 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @return the parameter name */ const std::string& getEnergyParameterDerivativeName(int index) const; + /** + * Get the number of global parameters with respect to which the derivative of the energy + * should be computed. + */ + int getNumEnergyParameterDerivatives() const { + return energyParameterDerivatives.size(); + } /** * Get the name of a global parameter. * From bcc87597d485c1099aec8177b491fd19ced065c2 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 11:12:59 +0200 Subject: [PATCH 15/31] Implement Reference platform --- .../reference/src/ReferenceTorchKernels.cpp | 32 ++++++++++++++++--- .../reference/src/ReferenceTorchKernels.h | 3 +- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/platforms/reference/src/ReferenceTorchKernels.cpp b/platforms/reference/src/ReferenceTorchKernels.cpp index 5de8a2b8..9b329386 100644 --- a/platforms/reference/src/ReferenceTorchKernels.cpp +++ b/platforms/reference/src/ReferenceTorchKernels.cpp @@ -54,6 +54,11 @@ static Vec3* extractBoxVectors(ContextImpl& context) { return data->periodicBoxVectors; } +static map& extractEnergyParameterDerivatives(ContextImpl& context) { + ReferencePlatform::PlatformData* data = reinterpret_cast(context.getPlatformData()); + return *data->energyParameterDerivatives; +} + ReferenceCalcTorchForceKernel::~ReferenceCalcTorchForceKernel() { } @@ -63,6 +68,8 @@ void ReferenceCalcTorchForceKernel::initialize(const System& system, const Torch outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) 
globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) + energyParameterDerivatives.push_back(force.getEnergyParameterDerivativeName(i)); } double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { @@ -76,15 +83,21 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include torch::Tensor boxTensor = torch::from_blob(box, {3, 3}, torch::kFloat64); inputs.push_back(boxTensor); } - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + // Store parameter tensors that need derivatives + vector parameterTensors; + for (const string& name : globalNames) { + // Require grad if the parameter is in the list of energy parameter derivatives + bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); + auto tensor = torch::tensor(context.getParameter(name), torch::TensorOptions().requires_grad(requires_grad)); + parameterTensors.emplace_back(tensor); + inputs.push_back(tensor); + } torch::Tensor energyTensor, forceTensor; if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); energyTensor = outputs->elements()[0].toTensor(); forceTensor = outputs->elements()[1].toTensor(); - } - else + } else energyTensor = module.forward(inputs).toTensor(); if (includeForces) { if (!outputsForces) { @@ -97,7 +110,16 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include double forceSign = (outputsForces ? 
1.0 : -1.0); for (int i = 0; i < numParticles; i++) for (int j = 0; j < 3; j++) - force[i][j] += forceSign*outputForces[3*i+j]; + force[i][j] += forceSign * outputForces[3 * i + j]; + } + // Store parameter energy derivatives + auto& derivs = extractEnergyParameterDerivatives(context); + for (int i = 0; i < energyParameterDerivatives.size(); i++) { + // Compute the derivative of the energy with respect to this parameter. + // The derivative is stored in the gradient of the parameter tensor. + double derivative = parameterTensors[i].grad().item(); + auto name = energyParameterDerivatives[i]; + derivs[name] = derivative; } return energyTensor.item(); } diff --git a/platforms/reference/src/ReferenceTorchKernels.h b/platforms/reference/src/ReferenceTorchKernels.h index f3abc1d2..b146699b 100644 --- a/platforms/reference/src/ReferenceTorchKernels.h +++ b/platforms/reference/src/ReferenceTorchKernels.h @@ -48,7 +48,7 @@ class ReferenceCalcTorchForceKernel : public CalcTorchForceKernel { ~ReferenceCalcTorchForceKernel(); /** * Initialize the kernel. 
- * + * * @param system the System this kernel will be applied to * @param force the TorchForce this kernel will be used for * @param module the PyTorch module to use for computing forces and energy @@ -67,6 +67,7 @@ class ReferenceCalcTorchForceKernel : public CalcTorchForceKernel { torch::jit::script::Module module; std::vector positions, boxVectors; std::vector globalNames; + std::vector energyParameterDerivatives; bool usePeriodic, outputsForces; }; From aba8a62062c03cb2597b6dc7ba6d6d78a64258f4 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 11:13:31 +0200 Subject: [PATCH 16/31] Update test --- python/tests/TestParameterDerivatives.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index 3d8e339e..66c2521c 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -18,7 +18,7 @@ def forward(self, positions: Tensor, parameter: Tensor) -> Tensor: return u_harmonic -@pytest.mark.parametrize("use_cv_force", [True, False]) +@pytest.mark.parametrize("use_cv_force", [False, True]) @pytest.mark.parametrize("platform", ["Reference", "CPU", "CUDA", "OpenCL"]) def testParameterEnergyDerivatives(use_cv_force, platform): @@ -41,7 +41,7 @@ def testParameterEnergyDerivatives(use_cv_force, platform): force.addGlobalParameter("parameter", parameter) # Enable energy derivatives for the parameter force.addEnergyParameterDerivative("parameter") - force.setOutputsForces(True) + force.setOutputsForces(False) if use_cv_force: # Wrap TorchForce into CustomCVForce cv_force = mm.CustomCVForce("force") @@ -56,7 +56,7 @@ def testParameterEnergyDerivatives(use_cv_force, platform): context = mm.Context(system, integ, platform) context.setPositions(positions) state = context.getState( - getEnergy=True, getForces=True, getEnergyParameterDerivatives=True + getEnergy=True, getForces=True, getParameterDerivatives=True ) 
# See if the energy and forces and the parameter derivative are correct. @@ -69,8 +69,6 @@ def testParameterEnergyDerivatives(use_cv_force, platform): ) assert np.allclose(-2 * parameter * positions, state.getForces(asNumpy=True)) assert np.allclose( - 2 * r2, - state.getEnergyParameterDerivatives()["parameter"].value_in_unit( - unit.kilojoules_per_mole - ), + r2, + state.getEnergyParameterDerivatives()["parameter"], ) From f26e867fb3a2eaad9b32f1cc1e9c439e6bf5dd33 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 11:30:34 +0200 Subject: [PATCH 17/31] Update test adding two parameters --- python/tests/TestParameterDerivatives.py | 44 +++++++++++++++--------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index 66c2521c..54de917a 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -12,9 +12,11 @@ class ForceWithParameters(pt.nn.Module): def __init__(self): super(ForceWithParameters, self).__init__() - def forward(self, positions: Tensor, parameter: Tensor) -> Tensor: + def forward( + self, positions: Tensor, parameter1: Tensor, parameter2: Tensor + ) -> Tensor: x2 = positions.pow(2).sum(dim=1) - u_harmonic = (parameter * x2).sum() + u_harmonic = ((parameter1 + parameter2**2) * x2).sum() return u_harmonic @@ -35,21 +37,23 @@ def testParameterEnergyDerivatives(use_cv_force, platform): # Create a force pt_force = ForceWithParameters() model = pt.jit.script(pt_force) - force = ot.TorchForce(model, {"useCUDAGraphs": "false"}) + tforce = ot.TorchForce(model, {"useCUDAGraphs": "false"}) # Add a parameter - parameter = 1.0 - force.addGlobalParameter("parameter", parameter) - # Enable energy derivatives for the parameter - force.addEnergyParameterDerivative("parameter") - force.setOutputsForces(False) + parameter1 = 1.0 + parameter2 = 1.0 + tforce.setOutputsForces(False) + 
tforce.addGlobalParameter("parameter1", parameter1) + tforce.addEnergyParameterDerivative("parameter1") + tforce.addGlobalParameter("parameter2", parameter2) + tforce.addEnergyParameterDerivative("parameter2") if use_cv_force: # Wrap TorchForce into CustomCVForce - cv_force = mm.CustomCVForce("force") - cv_force.addCollectiveVariable("force", force) - system.addForce(cv_force) + force = mm.CustomCVForce("force") + force.addCollectiveVariable("force", tforce) else: - system.addForce(force) - + force = tforce + # Enable energy derivatives for the parameter + system.addForce(force) # Compute the forces and energy. integ = mm.VerletIntegrator(1.0) platform = mm.Platform.getPlatformByName(platform) @@ -60,15 +64,21 @@ def testParameterEnergyDerivatives(use_cv_force, platform): ) # See if the energy and forces and the parameter derivative are correct. - # The network defines a potential of the form E(r) = parameter*|r|^2 + # The network defines a potential of the form E(r) = (parameter1 + parameter2**2)*|r|^2 r2 = np.sum(positions * positions) - expectedEnergy = parameter * r2 + expectedEnergy = (parameter1 + parameter2**2) * r2 assert np.allclose( expectedEnergy, state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole), ) - assert np.allclose(-2 * parameter * positions, state.getForces(asNumpy=True)) + assert np.allclose( + -2 * (parameter1 + parameter2**2) * positions, state.getForces(asNumpy=True) + ) assert np.allclose( r2, - state.getEnergyParameterDerivatives()["parameter"], + state.getEnergyParameterDerivatives()["parameter1"], + ) + assert np.allclose( + 2 * parameter2 * r2, + state.getEnergyParameterDerivatives()["parameter2"], ) From 00b8120c536e02de5601d48432fab8a70c05f338 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 14:06:25 +0200 Subject: [PATCH 18/31] Implement CUDA and OpenCL --- platforms/cuda/src/CudaTorchKernels.cpp | 32 +++++++++++++++++++-- platforms/cuda/src/CudaTorchKernels.h | 1 + 
platforms/opencl/src/OpenCLTorchKernels.cpp | 27 +++++++++++++++-- platforms/opencl/src/OpenCLTorchKernels.h | 3 +- 4 files changed, 56 insertions(+), 7 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 60ea7794..d59e9167 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -51,6 +51,12 @@ using namespace std; throw OpenMMException(m.str()); \ } +static map& extractEnergyParameterDerivatives(CudaContext& context) { + //CudaPlatform::PlatformData* data = reinterpret_cast(context.getPlatformData()); + //return *data->energyParameterDerivatives; + return context.getEnergyParamDerivWorkspace(); +} + CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) : CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) { // Explicitly activate the primary context CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context"); @@ -66,6 +72,8 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) + energyParameterDerivatives.push_back(force.getEnergyParameterDerivativeName(i)); int numParticles = system.getNumParticles(); // Push the PyTorch context @@ -148,8 +156,16 @@ std::vector CudaCalcTorchForceKernel::prepareTorchInputs(Con vector inputs = {posTensor}; if (usePeriodic) inputs.push_back(boxTensor); - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + for (const string& name : globalNames) { + // Require grad if the parameter is in the list of energy parameter derivatives + bool requires_grad = std::find(energyParameterDerivatives.begin(), 
energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); + auto options = torch::TensorOptions().requires_grad(requires_grad).device(posTensor.device()); + auto tensor = torch::tensor(context.getParameter(name), options); + // parameterTensors.emplace_back(tensor); + inputs.push_back(tensor); + } + // for (const string& name : globalNames) + // inputs.push_back(torch::tensor(context.getParameter(name))); return inputs; } @@ -210,7 +226,7 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // Record graph if not already done bool is_graph_captured = false; if (graphs.find(includeForces) == graphs.end()) { - //CUDA graph capture must occur in a non-default stream + // CUDA graph capture must occur in a non-default stream const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex()); const c10::cuda::CUDAStreamGuard guard(stream); // Warmup the graph workload before capturing. This first @@ -249,6 +265,16 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce } // Get energy const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context + // Store parameter energy derivatives + auto& derivs = extractEnergyParameterDerivatives(cu); + int firstParameterIndex = usePeriodic ? 2 : 1; // Skip the position and box tensors + for (int i = 0; i < energyParameterDerivatives.size(); i++) { + // Compute the derivative of the energy with respect to this parameter. + // The derivative is stored in the gradient of the parameter tensor. 
+ double derivative = inputs[i + firstParameterIndex].toTensor().grad().item(); + auto name = energyParameterDerivatives[i]; + derivs[name] = derivative; + } // Pop to the PyTorch context CUcontext ctx; CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); diff --git a/platforms/cuda/src/CudaTorchKernels.h b/platforms/cuda/src/CudaTorchKernels.h index 13f2a9b6..e93e0508 100644 --- a/platforms/cuda/src/CudaTorchKernels.h +++ b/platforms/cuda/src/CudaTorchKernels.h @@ -72,6 +72,7 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { torch::Tensor posTensor, boxTensor; torch::Tensor energyTensor, forceTensor; std::vector globalNames; + std::vector energyParameterDerivatives; bool usePeriodic, outputsForces; CUfunction copyInputsKernel, addForcesKernel; CUcontext primaryContext; diff --git a/platforms/opencl/src/OpenCLTorchKernels.cpp b/platforms/opencl/src/OpenCLTorchKernels.cpp index a232b1c6..a5d54a6c 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.cpp +++ b/platforms/opencl/src/OpenCLTorchKernels.cpp @@ -38,6 +38,10 @@ using namespace TorchPlugin; using namespace OpenMM; using namespace std; +static map& extractEnergyParameterDerivatives(OpenCLContext& cl) { + return cl.getEnergyParamDerivWorkspace(); +} + OpenCLCalcTorchForceKernel::~OpenCLCalcTorchForceKernel() { } @@ -81,8 +85,16 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor boxTensor = boxTensor.to(torch::kFloat32); inputs.push_back(boxTensor); } - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + for (const string& name : globalNames) { + // Require grad if the parameter is in the list of energy parameter derivatives + bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); + auto options = torch::TensorOptions().requires_grad(requires_grad).device(posTensor.device()); + auto tensor = 
torch::tensor(context.getParameter(name), options); + // parameterTensors.emplace_back(tensor); + inputs.push_back(tensor); + } + // for (const string& name : globalNames) + // inputs.push_back(torch::tensor(context.getParameter(name))); torch::Tensor energyTensor, forceTensor; if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); @@ -115,6 +127,15 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor addForcesKernel.setArg(4, outputsForces ? 1 : -1); cl.executeKernel(addForcesKernel, numParticles); } + // Store parameter energy derivatives + auto& derivs = extractEnergyParameterDerivatives(cl); + int firstParameterIndex = usePeriodic ? 2 : 1; // Skip the position and box tensors + for (int i = 0; i < energyParameterDerivatives.size(); i++) { + // Compute the derivative of the energy with respect to this parameter. + // The derivative is stored in the gradient of the parameter tensor. + double derivative = inputs[i + firstParameterIndex].toTensor().grad().item(); + auto name = energyParameterDerivatives[i]; + derivs[name] = derivative; + } return energyTensor.item(); } - diff --git a/platforms/opencl/src/OpenCLTorchKernels.h b/platforms/opencl/src/OpenCLTorchKernels.h index d5ccdee9..7a46d9af 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.h +++ b/platforms/opencl/src/OpenCLTorchKernels.h @@ -49,7 +49,7 @@ class OpenCLCalcTorchForceKernel : public CalcTorchForceKernel { ~OpenCLCalcTorchForceKernel(); /** * Initialize the kernel. 
- * + * * @param system the System this kernel will be applied to * @param force the TorchForce this kernel will be used for * @param module the PyTorch module to use for computing forces and energy @@ -69,6 +69,7 @@ class OpenCLCalcTorchForceKernel : public CalcTorchForceKernel { OpenMM::OpenCLContext& cl; torch::jit::script::Module module; std::vector globalNames; + std::vector energyParameterDerivatives; bool usePeriodic, outputsForces; OpenMM::OpenCLArray networkForces; cl::Kernel addForcesKernel; From 277c90d29c28df567d2368229ebfb910b06ea010 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 14:20:18 +0200 Subject: [PATCH 19/31] Initialize OpenCL energyParameterDerivatives map --- platforms/opencl/src/OpenCLTorchKernels.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/platforms/opencl/src/OpenCLTorchKernels.cpp b/platforms/opencl/src/OpenCLTorchKernels.cpp index a5d54a6c..d78ee6ad 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.cpp +++ b/platforms/opencl/src/OpenCLTorchKernels.cpp @@ -51,6 +51,9 @@ void OpenCLCalcTorchForceKernel::initialize(const System& system, const TorchFor outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) + energyParameterDerivatives.push_back(force.getEnergyParameterDerivativeName(i)); + int numParticles = system.getNumParticles(); // Inititalize OpenCL objects. 
From 9d6d34c809d8eb60c639a645bb23a6dbb2bb6a6c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 14:39:45 +0200 Subject: [PATCH 20/31] Fix return --- platforms/cuda/src/CudaTorchKernels.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index d59e9167..56559f62 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -52,9 +52,7 @@ using namespace std; } static map& extractEnergyParameterDerivatives(CudaContext& context) { - //CudaPlatform::PlatformData* data = reinterpret_cast(context.getPlatformData()); - //return *data->energyParameterDerivatives; - context.getEnergyParamDerivWorkspace(); + return context.getEnergyParamDerivWorkspace(); } CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) : CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) { From 4778eacb1b29248c8759442c04337cb162a007de Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 14:39:50 +0200 Subject: [PATCH 21/31] Remove commented code --- platforms/cuda/src/CudaTorchKernels.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 56559f62..f5ab9dfa 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -159,11 +159,8 @@ std::vector CudaCalcTorchForceKernel::prepareTorchInputs(Con bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); auto options = torch::TensorOptions().requires_grad(requires_grad).device(posTensor.device()); auto tensor = torch::tensor(context.getParameter(name), options); - // parameterTensors.emplace_back(tensor); inputs.push_back(tensor); } - // for (const string& name : globalNames) - // 
inputs.push_back(torch::tensor(context.getParameter(name))); return inputs; } From 4031cc53e755d6bc351117486c06006b6ab16337 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 15:19:39 +0200 Subject: [PATCH 22/31] Register params to the context --- platforms/cuda/src/CudaTorchKernels.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index f5ab9dfa..d28ed170 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -70,8 +70,10 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); - for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) - energyParameterDerivatives.push_back(force.getEnergyParameterDerivativeName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++){ + auto name = force.getEnergyParameterDerivativeName(i); + energyParameterDerivatives.push_back(name); + cu.addEnergyParameterDerivative(name); int numParticles = system.getNumParticles(); // Push the PyTorch context From 0acc338576a21e573742cf92f677e43346e9edf8 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 15:20:01 +0200 Subject: [PATCH 23/31] Make sure backwards is called with all the requested parameters as inputs --- platforms/cuda/src/CudaTorchKernels.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index d28ed170..3b2eb29e 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -204,7 +204,15 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr // CUDA graph capture sometimes fails if backwards is not explicitly 
requested w.r.t positions // See https://github.com/openmm/openmm-torch/pull/120/ auto none = torch::Tensor(); - energyTensor.backward(none, false, false, posTensor); + std::vector inputs_with_grad; + for (auto& input : inputs) { + if (input.isTensor()) { + auto tensor = input.toTensor(); + if (tensor.requires_grad()) + inputs_with_grad.push_back(tensor); + } + } + energyTensor.backward(none, false, false, inputs_with_grad); // This is minus the forces, we change the sign later on forceTensor = posTensor.grad().clone(); // Zero the gradient to avoid accumulating it From 9c65e6b7fb91723e2e9548ff01f6e5308cd1b0e0 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 15:20:41 +0200 Subject: [PATCH 24/31] Small changes --- platforms/cuda/src/CudaTorchKernels.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 3b2eb29e..26812740 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -51,10 +51,6 @@ using namespace std; throw OpenMMException(m.str()); \ } -static map& extractEnergyParameterDerivatives(CudaContext& context) { - return context.getEnergyParamDerivWorkspace(); -} - CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) : CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) { // Explicitly activate the primary context CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context"); @@ -74,6 +70,7 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce auto name = force.getEnergyParameterDerivativeName(i); energyParameterDerivatives.push_back(name); cu.addEnergyParameterDerivative(name); + } int numParticles = system.getNumParticles(); // Push the PyTorch context @@ -161,7 +158,7 @@ std::vector 
CudaCalcTorchForceKernel::prepareTorchInputs(Con bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); auto options = torch::TensorOptions().requires_grad(requires_grad).device(posTensor.device()); auto tensor = torch::tensor(context.getParameter(name), options); - inputs.push_back(tensor); + inputs.emplace_back(tensor); } return inputs; } @@ -271,12 +268,13 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // Get energy const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context // Store parameter energy derivatives - auto& derivs = extractEnergyParameterDerivatives(cu); + auto& derivs = cu.getEnergyParamDerivWorkspace(); int firstParameterIndex = usePeriodic ? 2 : 1; // Skip the position and box tensors for (int i = 0; i < energyParameterDerivatives.size(); i++) { // Compute the derivative of the energy with respect to this parameter. // The derivative is stored in the gradient of the parameter tensor. 
- double derivative = inputs[i + firstParameterIndex].toTensor().grad().item(); + auto parameter_tensor = inputs[i + firstParameterIndex].toTensor(); + double derivative = parameter_tensor.grad().item(); auto name = energyParameterDerivatives[i]; derivs[name] = derivative; } From 0dbc814df8d7b5b26297aad01b327fb9ae491506 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 15:20:53 +0200 Subject: [PATCH 25/31] Small changes --- platforms/opencl/src/OpenCLTorchKernels.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/platforms/opencl/src/OpenCLTorchKernels.cpp b/platforms/opencl/src/OpenCLTorchKernels.cpp index d78ee6ad..13200086 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.cpp +++ b/platforms/opencl/src/OpenCLTorchKernels.cpp @@ -38,10 +38,6 @@ using namespace TorchPlugin; using namespace OpenMM; using namespace std; -static map& extractEnergyParameterDerivatives(OpenCLContext& cl) { - return cl.getEnergyParamDerivWorkspace(); -} - OpenCLCalcTorchForceKernel::~OpenCLCalcTorchForceKernel() { } @@ -96,8 +92,6 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor // parameterTensors.emplace_back(tensor); inputs.push_back(tensor); } - // for (const string& name : globalNames) - // inputs.push_back(torch::tensor(context.getParameter(name))); torch::Tensor energyTensor, forceTensor; if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); @@ -131,7 +125,7 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor cl.executeKernel(addForcesKernel, numParticles); } // Store parameter energy derivatives - auto& derivs = extractEnergyParameterDerivatives(cl); + auto& derivs = cl.getEnergyParamDerivWorkspace(); int firstParameterIndex = usePeriodic ? 2 : 1; // Skip the position and box tensors for (int i = 0; i < energyParameterDerivatives.size(); i++) { // Compute the derivative of the energy with respect to this parameter. 
From 46990a96fccefe4011a2287c662e66a0c41cf8fc Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 16:09:33 +0200 Subject: [PATCH 26/31] Test more cases --- python/tests/TestParameterDerivatives.py | 43 +++++++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index 54de917a..da649ced 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -5,12 +5,13 @@ import pytest import torch as pt from torch import Tensor +from typing import Tuple, List, Optional -class ForceWithParameters(pt.nn.Module): +class EnergyWithParameters(pt.nn.Module): def __init__(self): - super(ForceWithParameters, self).__init__() + super(EnergyWithParameters, self).__init__() def forward( self, positions: Tensor, parameter1: Tensor, parameter2: Tensor @@ -20,9 +21,38 @@ def forward( return u_harmonic +class EnergyForceWithParameters(pt.nn.Module): + + def __init__(self, use_backwards=False): + super(EnergyForceWithParameters, self).__init__() + self.use_backwards = use_backwards + + def forward( + self, positions: Tensor, parameter1: Tensor, parameter2: Tensor + ) -> Tuple[Tensor, Tensor]: + x2 = positions.pow(2).sum(dim=1) + u_harmonic = ((parameter1 + parameter2**2) * x2).sum() + # This way of computing the forces forcefully leaves out the parameter derivatives + if self.use_backwards: + grad_outputs: List[Optional[Tensor]] = [pt.ones_like(u_harmonic)] + dy = pt.autograd.grad( + [u_harmonic], + [positions], + grad_outputs=grad_outputs, + create_graph=False, + retain_graph=False, + )[0] + assert dy is not None + forces = -dy + else: + forces = -2 * (parameter1 + parameter2**2) * positions + return u_harmonic, forces + + @pytest.mark.parametrize("use_cv_force", [False, True]) @pytest.mark.parametrize("platform", ["Reference", "CPU", "CUDA", "OpenCL"]) -def testParameterEnergyDerivatives(use_cv_force, platform): 
+@pytest.mark.parametrize("return_forces", [False, True]) +def testParameterEnergyDerivatives(use_cv_force, platform, return_forces): if pt.cuda.device_count() < 1 and platform == "CUDA": pytest.skip("A CUDA device is not available") @@ -35,13 +65,16 @@ def testParameterEnergyDerivatives(use_cv_force, platform): system.addParticle(1.0) # Create a force - pt_force = ForceWithParameters() + if return_forces: + pt_force = EnergyForceWithParameters() + else: + pt_force = EnergyWithParameters() model = pt.jit.script(pt_force) tforce = ot.TorchForce(model, {"useCUDAGraphs": "false"}) # Add a parameter parameter1 = 1.0 parameter2 = 1.0 - tforce.setOutputsForces(False) + tforce.setOutputsForces(return_forces) tforce.addGlobalParameter("parameter1", parameter1) tforce.addEnergyParameterDerivative("parameter1") tforce.addGlobalParameter("parameter2", parameter2) From ce80b7d4abd62c4648857cc444a252e471b30e0c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 16:10:42 +0200 Subject: [PATCH 27/31] Handle CUDA graphs --- platforms/cuda/src/CudaTorchKernels.cpp | 75 +++++++++++++++---------- platforms/cuda/src/CudaTorchKernels.h | 1 + 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 26812740..782632c1 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -66,10 +66,10 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); - for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++){ + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { auto name = force.getEnergyParameterDerivativeName(i); energyParameterDerivatives.push_back(name); - cu.addEnergyParameterDerivative(name); + 
cu.addEnergyParameterDerivative(name); } int numParticles = system.getNumParticles(); @@ -86,6 +86,7 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce boxTensor = torch::empty({3, 3}, options); energyTensor = torch::empty({0}, options); forceTensor = torch::empty({0}, options); + gradientTensors.resize(force.getNumEnergyParameterDerivatives(), torch::empty({0}, options)); // Pop the PyToch context CUcontext ctx; CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); @@ -189,31 +190,43 @@ void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) { * implicit synchronizations) will result in a CUDA error. */ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor, - torch::Tensor& forceTensor) { + torch::Tensor& forceTensor, std::vector& gradientTensors) { if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); energyTensor = outputs->elements()[0].toTensor(); forceTensor = outputs->elements()[1].toTensor(); } else { energyTensor = module.forward(inputs).toTensor(); - // Compute force by backpropagating the PyTorch model - if (includeForces) { - // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions - // See https://github.com/openmm/openmm-torch/pull/120/ - auto none = torch::Tensor(); - std::vector inputs_with_grad; - for (auto& input : inputs) { - if (input.isTensor()) { - auto tensor = input.toTensor(); - if (tensor.requires_grad()) - inputs_with_grad.push_back(tensor); - } - } - energyTensor.backward(none, false, false, inputs_with_grad); - // This is minus the forces, we change the sign later on - forceTensor = posTensor.grad().clone(); - // Zero the gradient to avoid accumulating it - posTensor.grad().zero_(); + } + // Compute any gradients by backpropagating the PyTorch model + std::vector inputs_with_grad; + if (includeForces && 
!outputsForces) { + inputs_with_grad.push_back(posTensor); + } + for (int i = 1; i < inputs.size(); i++) { // Skip the positions + auto& input = inputs[i]; + if (input.isTensor()) { + auto tensor = input.toTensor(); + if (tensor.requires_grad()) + inputs_with_grad.emplace_back(tensor); + } + } + if (inputs_with_grad.size() > 0) { + // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions + // See https://github.com/openmm/openmm-torch/pull/120/ + auto none = torch::Tensor(); + energyTensor.backward(none, false, false, inputs_with_grad); + // Store the gradients for the energy parameters + bool isForceFirst = includeForces && !outputsForces; + for (int i = 0; i < inputs_with_grad.size(); i++) { + if (i == 0 && isForceFirst) { + // This is minus the forces, we change the sign later on + forceTensor = inputs_with_grad[i].grad().clone(); + inputs_with_grad[i].grad().zero_(); + } else { + gradientTensors[i - isForceFirst] = inputs_with_grad[i].grad().clone(); + inputs_with_grad[i].grad().zero_(); + } } } } @@ -222,15 +235,17 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // Push to the PyTorch context CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); auto inputs = prepareTorchInputs(context); + // Store the executeGraph call in a lambda to allow for easy reuse + auto payload = [&]() { executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, gradientTensors); }; if (!useGraphs) { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + payload(); } else { // Record graph if not already done bool is_graph_captured = false; if (graphs.find(includeForces) == graphs.end()) { // CUDA graph capture must occur in a non-default stream const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex()); - const c10::cuda::CUDAStreamGuard guard(stream); + const 
c10::cuda::CUDAStreamGuard guard(stream); // Warmup the graph workload before capturing. This first // run before capture sets up allocations so that no // allocations are needed after. Pytorch's allocator is @@ -238,14 +253,14 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // record static pointers and shapes during capture. try { for (int i = 0; i < this->warmupSteps; i++) - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + payload(); } catch (std::exception& e) { throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what()); } graphs[includeForces].capture_begin(); try { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + payload(); is_graph_captured = true; graphs[includeForces].capture_end(); } @@ -256,10 +271,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. 
Torch reported the following error:\n") + e.what()); } } - // Use the same stream as the OpenMM context, even if it is the default stream + // Use the same stream as the OpenMM context, even if it is the default stream const auto openmmStream = cu.getCurrentStream(); - const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex()); - const c10::cuda::CUDAStreamGuard guard(stream); + const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex()); + const c10::cuda::CUDAStreamGuard guard(stream); graphs[includeForces].replay(); } if (includeForces) { @@ -269,12 +284,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context // Store parameter energy derivatives auto& derivs = cu.getEnergyParamDerivWorkspace(); - int firstParameterIndex = usePeriodic ? 2 : 1; // Skip the position and box tensors for (int i = 0; i < energyParameterDerivatives.size(); i++) { // Compute the derivative of the energy with respect to this parameter. // The derivative is stored in the gradient of the parameter tensor. 
- auto parameter_tensor = inputs[i + firstParameterIndex].toTensor(); - double derivative = parameter_tensor.grad().item(); + double derivative = gradientTensors[i].item(); auto name = energyParameterDerivatives[i]; derivs[name] = derivative; } diff --git a/platforms/cuda/src/CudaTorchKernels.h b/platforms/cuda/src/CudaTorchKernels.h index e93e0508..ddafe371 100644 --- a/platforms/cuda/src/CudaTorchKernels.h +++ b/platforms/cuda/src/CudaTorchKernels.h @@ -71,6 +71,7 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { torch::jit::script::Module module; torch::Tensor posTensor, boxTensor; torch::Tensor energyTensor, forceTensor; + std::vector gradientTensors; std::vector globalNames; std::vector energyParameterDerivatives; bool usePeriodic, outputsForces; From b0d51e6a61de01673d3c73c449fc9030bb792bbc Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 16:16:41 +0200 Subject: [PATCH 28/31] Test the case when a model already calls backwards to compute the forces --- python/tests/TestParameterDerivatives.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index da649ced..ae274b1a 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -23,13 +23,14 @@ def forward( class EnergyForceWithParameters(pt.nn.Module): - def __init__(self, use_backwards=False): + def __init__(self, use_backwards=True): super(EnergyForceWithParameters, self).__init__() self.use_backwards = use_backwards def forward( self, positions: Tensor, parameter1: Tensor, parameter2: Tensor ) -> Tuple[Tensor, Tensor]: + positions.requires_grad_(True) x2 = positions.pow(2).sum(dim=1) u_harmonic = ((parameter1 + parameter2**2) * x2).sum() # This way of computing the forces forcefully leaves out the parameter derivatives @@ -40,7 +41,7 @@ def forward( [positions], grad_outputs=grad_outputs, create_graph=False, - 
retain_graph=False, + retain_graph=True, )[0] assert dy is not None forces = -dy @@ -51,8 +52,12 @@ def forward( @pytest.mark.parametrize("use_cv_force", [False, True]) @pytest.mark.parametrize("platform", ["Reference", "CPU", "CUDA", "OpenCL"]) -@pytest.mark.parametrize("return_forces", [False, True]) -def testParameterEnergyDerivatives(use_cv_force, platform, return_forces): +@pytest.mark.parametrize( + ("return_forces", "use_backwards"), [(False, False), (True, False), (True, True)] +) +def testParameterEnergyDerivatives( + use_cv_force, platform, return_forces, use_backwards +): if pt.cuda.device_count() < 1 and platform == "CUDA": pytest.skip("A CUDA device is not available") @@ -66,7 +71,7 @@ def testParameterEnergyDerivatives(use_cv_force, platform, return_forces): # Create a force if return_forces: - pt_force = EnergyForceWithParameters() + pt_force = EnergyForceWithParameters(use_backwards=use_backwards) else: pt_force = EnergyWithParameters() model = pt.jit.script(pt_force) From 0f88d7f793647dcfdbcce65152d0f873563f91c4 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 16:31:33 +0200 Subject: [PATCH 29/31] Small changes --- python/tests/TestParameterDerivatives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index ae274b1a..888acd99 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -41,6 +41,7 @@ def forward( [positions], grad_outputs=grad_outputs, create_graph=False, + # This must be true, otherwise pytorch will not allow to compute the gradients with respect to the parameters retain_graph=True, )[0] assert dy is not None From c81526c543f6f5c0d15900c94ca6deabc2eebb90 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 16:31:52 +0200 Subject: [PATCH 30/31] Handle some corner cases in OpenCL and Reference --- platforms/opencl/src/OpenCLTorchKernels.cpp | 27 ++++++++++++++++--- 
.../reference/src/ReferenceTorchKernels.cpp | 18 ++++++++++++- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/platforms/opencl/src/OpenCLTorchKernels.cpp b/platforms/opencl/src/OpenCLTorchKernels.cpp index 13200086..0c537953 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.cpp +++ b/platforms/opencl/src/OpenCLTorchKernels.cpp @@ -47,8 +47,11 @@ void OpenCLCalcTorchForceKernel::initialize(const System& system, const TorchFor outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); - for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) - energyParameterDerivatives.push_back(force.getEnergyParameterDerivativeName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++){ + auto name = force.getEnergyParameterDerivativeName(i); + energyParameterDerivatives.push_back(name); + cl.addEnergyParameterDerivative(name); + } int numParticles = system.getNumParticles(); @@ -100,9 +103,27 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor } else energyTensor = module.forward(inputs).toTensor(); + // Compute any gradients by backpropagating the PyTorch model + std::vector inputs_with_grad; + if (includeForces && !outputsForces) { + inputs_with_grad.push_back(posTensor); + } + for (int i = 1; i < inputs.size(); i++) { // Skip the positions + auto& input = inputs[i]; + if (input.isTensor()) { + auto tensor = input.toTensor(); + if (tensor.requires_grad()) + inputs_with_grad.emplace_back(tensor); + } + } + if (inputs_with_grad.size() > 0) { + // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions + // See https://github.com/openmm/openmm-torch/pull/120/ + auto none = torch::Tensor(); + energyTensor.backward(none, false, false, inputs_with_grad); + } if (includeForces) { if (!outputsForces) { - energyTensor.backward(); forceTensor = posTensor.grad(); } if 
(cl.getUseDoublePrecision()) { diff --git a/platforms/reference/src/ReferenceTorchKernels.cpp b/platforms/reference/src/ReferenceTorchKernels.cpp index 9b329386..f20cdd30 100644 --- a/platforms/reference/src/ReferenceTorchKernels.cpp +++ b/platforms/reference/src/ReferenceTorchKernels.cpp @@ -99,9 +99,25 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include forceTensor = outputs->elements()[1].toTensor(); } else energyTensor = module.forward(inputs).toTensor(); + // Compute any gradients by backpropagating the PyTorch model + std::vector inputs_with_grad; + if (includeForces && !outputsForces) { + inputs_with_grad.push_back(posTensor); + } + for (int i = 1; i < inputs.size(); i++) { // Skip the positions + auto& input = inputs[i]; + if (input.isTensor()) { + auto tensor = input.toTensor(); + if (tensor.requires_grad()) + inputs_with_grad.emplace_back(tensor); + } + } + if (inputs_with_grad.size() > 0) { + auto none = torch::Tensor(); + energyTensor.backward(none, false, false, inputs_with_grad); + } if (includeForces) { if (!outputsForces) { - energyTensor.backward(); forceTensor = posTensor.grad(); } if (!(forceTensor.dtype() == torch::kFloat64)) From 4c9b5e509d15ce52db7600f7538fbe66538c1c2a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 15 May 2024 16:55:41 +0200 Subject: [PATCH 31/31] Small changes --- python/tests/TestParameterDerivatives.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py index 888acd99..acb28588 100644 --- a/python/tests/TestParameterDerivatives.py +++ b/python/tests/TestParameterDerivatives.py @@ -82,8 +82,9 @@ def testParameterEnergyDerivatives( parameter2 = 1.0 tforce.setOutputsForces(return_forces) tforce.addGlobalParameter("parameter1", parameter1) - tforce.addEnergyParameterDerivative("parameter1") tforce.addGlobalParameter("parameter2", parameter2) + # Enable energy derivatives for the parameters 
+ tforce.addEnergyParameterDerivative("parameter1") tforce.addEnergyParameterDerivative("parameter2") if use_cv_force: # Wrap TorchForce into CustomCVForce @@ -91,7 +92,6 @@ def testParameterEnergyDerivatives( force.addCollectiveVariable("force", tforce) else: force = tforce - # Enable energy derivatives for the parameter system.addForce(force) # Compute the forces and energy. integ = mm.VerletIntegrator(1.0)