diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f2bd6a44..4135959e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -77,6 +77,53 @@ jobs: NVCC_VERSION: ${{ matrix.nvcc-version }} PYTORCH_VERSION: ${{ matrix.pytorch-version }} + - name: Manage disk space + if: startsWith(matrix.os, 'ubuntu') # matrix.os carries a version suffix (e.g. ubuntu-22.04), so match on the prefix + run: | + sudo mkdir -p /opt/empty_dir || true + for d in \ + /opt/ghc \ + /opt/hostedtoolcache \ + /usr/lib/jvm \ + /usr/local/.ghcup \ + /usr/local/lib/android \ + /usr/local/share/powershell \ + /usr/share/dotnet \ + /usr/share/swift \ + ; do + sudo rsync --stats -a --delete /opt/empty_dir/ $d || true + done + sudo apt-get purge -y -f firefox \ + google-chrome-stable \ + microsoft-edge-stable + sudo apt-get autoremove -y >& /dev/null + sudo apt-get autoclean -y >& /dev/null + sudo docker image prune --all --force + df -h + + - name: "Install CUDA Toolkit on Linux (if needed)" + uses: Jimver/cuda-toolkit@v0.2.10 + with: + cuda: ${{ matrix.cuda-version }} + linux-local-args: '["--toolkit", "--override"]' + if: startsWith(matrix.os, 'ubuntu') + + - name: "Install SDK on MacOS (if needed)" + run: source devtools/scripts/install_macos_sdk.sh + if: startsWith(matrix.os, 'macos') + + - name: "Update the conda enviroment file" + uses: cschleiden/replace-tokens@v1 + with: + tokenPrefix: '@' + tokenSuffix: '@' + files: devtools/conda-envs/build-${{ matrix.os }}.yml + env: + CUDATOOLKIT_VERSION: ${{ matrix.cuda-version }} + GCC_VERSION: ${{ matrix.gcc-version }} + NVCC_VERSION: ${{ matrix.nvcc-version }} + PYTORCH_VERSION: ${{ matrix.pytorch-version }} + - uses: conda-incubator/setup-miniconda@v2 name: "Install dependencies with Mamba" with: diff --git a/CMakeLists.txt b/CMakeLists.txt index 8573ec30..4c4e83a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ # OpenMM PyTorch Plugin #---------------------------------------------------- -CMAKE_MINIMUM_REQUIRED(VERSION 3.5) +CMAKE_MINIMUM_REQUIRED(VERSION 3.7) # We need to know where 
OpenMM is installed so we can access the headers and libraries. SET(OPENMM_DIR "/usr/local/openmm" CACHE PATH "Where OpenMM is installed") @@ -14,8 +14,15 @@ SET(PYTORCH_DIR "" CACHE PATH "Where the PyTorch C++ API is installed") SET(CMAKE_PREFIX_PATH "${PYTORCH_DIR}") FIND_PACKAGE(Torch REQUIRED) -# Specify the C++ version we are building for. -SET (CMAKE_CXX_STANDARD 14) +# Specify the C++ version we are building for. Latest pytorch versions require C++17 +message(STATUS "Found Torch: ${Torch_VERSION}") +if(${Torch_VERSION} VERSION_GREATER_EQUAL "2.1.0") + set(CMAKE_CXX_STANDARD 17) + message(STATUS "Setting C++ standard to C++17") +else() + set(CMAKE_CXX_STANDARD 14) + message(STATUS "Setting C++ standard to C++14") +endif() # Set flags for linking on mac IF(APPLE) diff --git a/devtools/conda-envs/build-ubuntu-22.04.yml b/devtools/conda-envs/build-ubuntu-22.04.yml index ffd4d2f7..091b45e1 100644 --- a/devtools/conda-envs/build-ubuntu-22.04.yml +++ b/devtools/conda-envs/build-ubuntu-22.04.yml @@ -18,3 +18,6 @@ dependencies: - swig - sysroot_linux-64 2.17 - torchani + # Required by python<3.8 + # xref: https://github.com/conda-forge/linux-sysroot-feedstock/issues/52 + - libxcrypt diff --git a/openmmapi/include/TorchForce.h b/openmmapi/include/TorchForce.h index 4406eb20..53b1b56f 100644 --- a/openmmapi/include/TorchForce.h +++ b/openmmapi/include/TorchForce.h @@ -58,8 +58,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @param file the path to the file containing the network * @param properties optional map of properties */ - TorchForce(const std::string& file, - const std::map& properties = {}); + TorchForce(const std::string& file, const std::map& properties = {}); /** * Create a TorchForce. The network is defined by a PyTorch ScriptModule * Note that this constructor makes a copy of the provided module. 
@@ -68,7 +67,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @param module an instance of the torch module * @param properties optional map of properties */ - TorchForce(const torch::jit::Module &module, const std::map& properties = {}); + TorchForce(const torch::jit::Module& module, const std::map& properties = {}); /** * Get the path to the file containing the network. * If the TorchForce instance was constructed with a module, instead of a filename, @@ -78,7 +77,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { /** * Get the torch module currently in use. */ - const torch::jit::Module & getModule() const; + const torch::jit::Module& getModule() const; /** * Set whether this force makes use of periodic boundary conditions. If this is set * to true, the network must take a 3x3 tensor as its second input, which @@ -116,6 +115,26 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @return the index of the parameter that was added */ int addGlobalParameter(const std::string& name, double defaultValue); + /** + * Add a new energy parameter derivative that the interaction may depend on. + * + * @param name the name of the parameter + */ + void addEnergyParameterDerivative(const std::string& name); + /** + * Get the name of an energy parameter derivative given its global parameter index. + * + * @param index the index of the parameter for which to get the name + * @return the parameter name + */ + const std::string& getEnergyParameterDerivativeName(int index) const; + /** + * Get the number of global parameters with respect to which the derivative of the energy + * should be computed. + */ + int getNumEnergyParameterDerivatives() const { + return energyParameterDerivatives.size(); + } /** * Get the name of a global parameter. * @@ -156,13 +175,16 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @return A map of property names to values. 
*/ const std::map& getProperties() const; + protected: OpenMM::ForceImpl* createImpl() const; + private: class GlobalParameterInfo; std::string file; bool usePeriodic, outputsForces; std::vector globalParameters; + std::vector energyParameterDerivatives; torch::jit::Module module; std::map properties; std::string emptyProperty; diff --git a/openmmapi/src/TorchForce.cpp b/openmmapi/src/TorchForce.cpp index e8892048..f3c0c1ef 100644 --- a/openmmapi/src/TorchForce.cpp +++ b/openmmapi/src/TorchForce.cpp @@ -88,6 +88,24 @@ int TorchForce::addGlobalParameter(const string& name, double defaultValue) { return globalParameters.size() - 1; } +void TorchForce::addEnergyParameterDerivative(const string& name) { + // Derivatives are stored as indices into globalParameters, so the name must already be registered with addGlobalParameter(). + for (int i = 0; i < globalParameters.size(); i++) { + if (globalParameters[i].name == name) { + energyParameterDerivatives.push_back(i); + return; + } + } + // Fail loudly instead of silently dropping a misspelled or not-yet-added parameter. + throw OpenMM::OpenMMException("TorchForce::addEnergyParameterDerivative: Unknown global parameter '" + name + "'"); +} + +const std::string& TorchForce::getEnergyParameterDerivativeName(int index) const { + if (index < 0 || index >= energyParameterDerivatives.size()) + throw OpenMM::OpenMMException("TorchForce::getEnergyParameterDerivativeName: index out of range."); + return globalParameters[energyParameterDerivatives[index]].name; +} + int TorchForce::getNumGlobalParameters() const { return globalParameters.size(); } diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 60ea7794..782632c1 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -66,6 +66,11 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { + auto name = force.getEnergyParameterDerivativeName(i); + energyParameterDerivatives.push_back(name); + cu.addEnergyParameterDerivative(name); + } int numParticles = 
system.getNumParticles(); // Push the PyTorch context @@ -81,6 +86,7 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce boxTensor = torch::empty({3, 3}, options); energyTensor = torch::empty({0}, options); forceTensor = torch::empty({0}, options); + gradientTensors.resize(force.getNumEnergyParameterDerivatives(), torch::empty({0}, options)); // Pop the PyToch context CUcontext ctx; CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); @@ -148,8 +154,13 @@ std::vector CudaCalcTorchForceKernel::prepareTorchInputs(Con vector inputs = {posTensor}; if (usePeriodic) inputs.push_back(boxTensor); - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + for (const string& name : globalNames) { + // Require grad if the parameter is in the list of energy parameter derivatives + bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); + auto options = torch::TensorOptions().requires_grad(requires_grad).device(posTensor.device()); + auto tensor = torch::tensor(context.getParameter(name), options); + inputs.emplace_back(tensor); + } return inputs; } @@ -179,23 +190,43 @@ void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) { * implicit synchronizations) will result in a CUDA error. 
*/ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor, - torch::Tensor& forceTensor) { + torch::Tensor& forceTensor, std::vector& gradientTensors) { if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); energyTensor = outputs->elements()[0].toTensor(); forceTensor = outputs->elements()[1].toTensor(); } else { energyTensor = module.forward(inputs).toTensor(); - // Compute force by backpropagating the PyTorch model - if (includeForces) { - // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions - // See https://github.com/openmm/openmm-torch/pull/120/ - auto none = torch::Tensor(); - energyTensor.backward(none, false, false, posTensor); - // This is minus the forces, we change the sign later on - forceTensor = posTensor.grad().clone(); - // Zero the gradient to avoid accumulating it - posTensor.grad().zero_(); + } + // Compute any gradients by backpropagating the PyTorch model + std::vector inputs_with_grad; + if (includeForces && !outputsForces) { + inputs_with_grad.push_back(posTensor); + } + for (int i = 1; i < inputs.size(); i++) { // Skip the positions + auto& input = inputs[i]; + if (input.isTensor()) { + auto tensor = input.toTensor(); + if (tensor.requires_grad()) + inputs_with_grad.emplace_back(tensor); + } + } + if (inputs_with_grad.size() > 0) { + // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions + // See https://github.com/openmm/openmm-torch/pull/120/ + auto none = torch::Tensor(); + energyTensor.backward(none, false, false, inputs_with_grad); + // Store the gradients for the energy parameters + bool isForceFirst = includeForces && !outputsForces; + for (int i = 0; i < inputs_with_grad.size(); i++) { + if (i == 0 && isForceFirst) { + // This is minus the forces, we change the sign later on + forceTensor = 
inputs_with_grad[i].grad().clone(); + inputs_with_grad[i].grad().zero_(); + } else { + gradientTensors[i - isForceFirst] = inputs_with_grad[i].grad().clone(); + inputs_with_grad[i].grad().zero_(); + } } } } @@ -204,15 +235,17 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // Push to the PyTorch context CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); auto inputs = prepareTorchInputs(context); + // Store the executeGraph call in a lambda to allow for easy reuse + auto payload = [&]() { executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, gradientTensors); }; if (!useGraphs) { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + payload(); } else { // Record graph if not already done bool is_graph_captured = false; if (graphs.find(includeForces) == graphs.end()) { - //CUDA graph capture must occur in a non-default stream + // CUDA graph capture must occur in a non-default stream const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex()); - const c10::cuda::CUDAStreamGuard guard(stream); + const c10::cuda::CUDAStreamGuard guard(stream); // Warmup the graph workload before capturing. This first // run before capture sets up allocations so that no // allocations are needed after. Pytorch's allocator is @@ -220,14 +253,14 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // record static pointers and shapes during capture. try { for (int i = 0; i < this->warmupSteps; i++) - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + payload(); } catch (std::exception& e) { throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. 
Torch reported the following error:\n") + e.what()); } graphs[includeForces].capture_begin(); try { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + payload(); is_graph_captured = true; graphs[includeForces].capture_end(); } @@ -238,10 +271,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what()); } } - // Use the same stream as the OpenMM context, even if it is the default stream + // Use the same stream as the OpenMM context, even if it is the default stream const auto openmmStream = cu.getCurrentStream(); - const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex()); - const c10::cuda::CUDAStreamGuard guard(stream); + const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex()); + const c10::cuda::CUDAStreamGuard guard(stream); graphs[includeForces].replay(); } if (includeForces) { @@ -249,6 +282,15 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce } // Get energy const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context + // Store parameter energy derivatives. executeGraph() fills gradientTensors in globalNames order (the order of the model inputs), which need not match the order in which the derivatives were requested, so walk globalNames. + auto& derivs = cu.getEnergyParamDerivWorkspace(); + int gradientIndex = 0; + for (const string& name : globalNames) { 
+ if (std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) == energyParameterDerivatives.end()) + continue; + double derivative = gradientTensors[gradientIndex++].item(); + derivs[name] = derivative; + } // Pop to the PyTorch context CUcontext ctx; CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); diff --git a/platforms/cuda/src/CudaTorchKernels.h b/platforms/cuda/src/CudaTorchKernels.h index 13f2a9b6..ddafe371 100644 --- a/platforms/cuda/src/CudaTorchKernels.h +++ b/platforms/cuda/src/CudaTorchKernels.h @@ -71,7 +71,9 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { torch::jit::script::Module module; torch::Tensor posTensor, boxTensor; torch::Tensor energyTensor, forceTensor; + std::vector gradientTensors; std::vector globalNames; + std::vector energyParameterDerivatives; bool usePeriodic, outputsForces; CUfunction copyInputsKernel, addForcesKernel; CUcontext primaryContext; diff --git a/platforms/opencl/src/OpenCLTorchKernels.cpp b/platforms/opencl/src/OpenCLTorchKernels.cpp index a232b1c6..0c537953 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.cpp +++ b/platforms/opencl/src/OpenCLTorchKernels.cpp @@ -47,6 +47,12 @@ void OpenCLCalcTorchForceKernel::initialize(const System& system, const TorchFor outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++){ + auto name = force.getEnergyParameterDerivativeName(i); + energyParameterDerivatives.push_back(name); + cl.addEnergyParameterDerivative(name); + } + int numParticles = system.getNumParticles(); // Inititalize OpenCL objects. 
@@ -81,8 +87,14 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor boxTensor = boxTensor.to(torch::kFloat32); inputs.push_back(boxTensor); } - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + for (const string& name : globalNames) { + // Require grad if the parameter is in the list of energy parameter derivatives + bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); + auto options = torch::TensorOptions().requires_grad(requires_grad).device(posTensor.device()); + auto tensor = torch::tensor(context.getParameter(name), options); + // parameterTensors.emplace_back(tensor); + inputs.push_back(tensor); + } torch::Tensor energyTensor, forceTensor; if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); @@ -91,9 +103,27 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor } else energyTensor = module.forward(inputs).toTensor(); + // Compute any gradients by backpropagating the PyTorch model + std::vector inputs_with_grad; + if (includeForces && !outputsForces) { + inputs_with_grad.push_back(posTensor); + } + for (int i = 1; i < inputs.size(); i++) { // Skip the positions + auto& input = inputs[i]; + if (input.isTensor()) { + auto tensor = input.toTensor(); + if (tensor.requires_grad()) + inputs_with_grad.emplace_back(tensor); + } + } + if (inputs_with_grad.size() > 0) { + // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions + // See https://github.com/openmm/openmm-torch/pull/120/ + auto none = torch::Tensor(); + energyTensor.backward(none, false, false, inputs_with_grad); + } if (includeForces) { if (!outputsForces) { - energyTensor.backward(); forceTensor = posTensor.grad(); } if (cl.getUseDoublePrecision()) { @@ -115,6 +145,15 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool 
includeFor addForcesKernel.setArg(4, outputsForces ? 1 : -1); cl.executeKernel(addForcesKernel, numParticles); } + // Store parameter energy derivatives + auto& derivs = cl.getEnergyParamDerivWorkspace(); + int firstParameterIndex = usePeriodic ? 2 : 1; // Skip the position and box tensors + for (int i = 0; i < energyParameterDerivatives.size(); i++) { + // The parameter tensors in inputs follow the order of globalNames, which need not match the order (or number) of the requested derivatives, + // so locate the tensor by parameter name. Its gradient holds the derivative of the energy with respect to that parameter. + auto name = energyParameterDerivatives[i]; + double derivative = inputs[firstParameterIndex + (std::find(globalNames.begin(), globalNames.end(), name) - globalNames.begin())].toTensor().grad().item(); + derivs[name] = derivative; + } return energyTensor.item(); } - diff --git a/platforms/opencl/src/OpenCLTorchKernels.h b/platforms/opencl/src/OpenCLTorchKernels.h index d5ccdee9..7a46d9af 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.h +++ b/platforms/opencl/src/OpenCLTorchKernels.h @@ -49,7 +49,7 @@ class OpenCLCalcTorchForceKernel : public CalcTorchForceKernel { ~OpenCLCalcTorchForceKernel(); /** * Initialize the kernel. 
- * + * * @param system the System this kernel will be applied to * @param force the TorchForce this kernel will be used for * @param module the PyTorch module to use for computing forces and energy @@ -69,6 +69,7 @@ class OpenCLCalcTorchForceKernel : public CalcTorchForceKernel { OpenMM::OpenCLContext& cl; torch::jit::script::Module module; std::vector globalNames; + std::vector energyParameterDerivatives; bool usePeriodic, outputsForces; OpenMM::OpenCLArray networkForces; cl::Kernel addForcesKernel; diff --git a/platforms/reference/src/ReferenceTorchKernels.cpp b/platforms/reference/src/ReferenceTorchKernels.cpp index 5de8a2b8..f20cdd30 100644 --- a/platforms/reference/src/ReferenceTorchKernels.cpp +++ b/platforms/reference/src/ReferenceTorchKernels.cpp @@ -54,6 +54,11 @@ static Vec3* extractBoxVectors(ContextImpl& context) { return data->periodicBoxVectors; } +static map& extractEnergyParameterDerivatives(ContextImpl& context) { + ReferencePlatform::PlatformData* data = reinterpret_cast(context.getPlatformData()); + return *data->energyParameterDerivatives; +} + ReferenceCalcTorchForceKernel::~ReferenceCalcTorchForceKernel() { } @@ -63,6 +68,8 @@ void ReferenceCalcTorchForceKernel::initialize(const System& system, const Torch outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) + energyParameterDerivatives.push_back(force.getEnergyParameterDerivativeName(i)); } double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { @@ -76,19 +83,41 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include torch::Tensor boxTensor = torch::from_blob(box, {3, 3}, torch::kFloat64); inputs.push_back(boxTensor); } - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + // Store parameter 
tensors that need derivatives + vector parameterTensors; + for (const string& name : globalNames) { + // Require grad if the parameter is in the list of energy parameter derivatives + bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end(); + auto tensor = torch::tensor(context.getParameter(name), torch::TensorOptions().requires_grad(requires_grad)); + parameterTensors.emplace_back(tensor); + inputs.push_back(tensor); + } torch::Tensor energyTensor, forceTensor; if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); energyTensor = outputs->elements()[0].toTensor(); forceTensor = outputs->elements()[1].toTensor(); - } - else + } else energyTensor = module.forward(inputs).toTensor(); + // Compute any gradients by backpropagating the PyTorch model + std::vector inputs_with_grad; + if (includeForces && !outputsForces) { + inputs_with_grad.push_back(posTensor); + } + for (int i = 1; i < inputs.size(); i++) { // Skip the positions + auto& input = inputs[i]; + if (input.isTensor()) { + auto tensor = input.toTensor(); + if (tensor.requires_grad()) + inputs_with_grad.emplace_back(tensor); + } + } + if (inputs_with_grad.size() > 0) { + auto none = torch::Tensor(); + energyTensor.backward(none, false, false, inputs_with_grad); + } if (includeForces) { if (!outputsForces) { - energyTensor.backward(); forceTensor = posTensor.grad(); } if (!(forceTensor.dtype() == torch::kFloat64)) @@ -97,7 +126,16 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include double forceSign = (outputsForces ? 
1.0 : -1.0); for (int i = 0; i < numParticles; i++) for (int j = 0; j < 3; j++) - force[i][j] += forceSign*outputForces[3*i+j]; + force[i][j] += forceSign * outputForces[3 * i + j]; + } + // Store parameter energy derivatives + auto& derivs = extractEnergyParameterDerivatives(context); + for (int i = 0; i < energyParameterDerivatives.size(); i++) { + // parameterTensors follows the order of globalNames, which need not match the order (or number) of the requested derivatives, + // so locate the tensor by parameter name. Its gradient holds the derivative of the energy with respect to that parameter. + auto name = energyParameterDerivatives[i]; + double derivative = parameterTensors[std::find(globalNames.begin(), globalNames.end(), name) - globalNames.begin()].grad().item(); + derivs[name] = derivative; } return energyTensor.item(); } diff --git a/platforms/reference/src/ReferenceTorchKernels.h b/platforms/reference/src/ReferenceTorchKernels.h index f3abc1d2..b146699b 100644 --- a/platforms/reference/src/ReferenceTorchKernels.h +++ b/platforms/reference/src/ReferenceTorchKernels.h @@ -48,7 +48,7 @@ class ReferenceCalcTorchForceKernel : public CalcTorchForceKernel { ~ReferenceCalcTorchForceKernel(); /** * Initialize the kernel. 
- * + * * @param system the System this kernel will be applied to * @param force the TorchForce this kernel will be used for * @param module the PyTorch module to use for computing forces and energy @@ -67,6 +67,7 @@ class ReferenceCalcTorchForceKernel : public CalcTorchForceKernel { torch::jit::script::Module module; std::vector positions, boxVectors; std::vector globalNames; + std::vector energyParameterDerivatives; bool usePeriodic, outputsForces; }; diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index ee0c19dd..56e51189 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -23,6 +23,7 @@ add_custom_command( add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}" "${CMAKE_CURRENT_SOURCE_DIR}/setup.py") set(NN_PLUGIN_HEADER_DIR "${CMAKE_SOURCE_DIR}/openmmapi/include") set(NN_PLUGIN_LIBRARY_DIR "${CMAKE_BINARY_DIR}") +set(EXTENSION_CXX_STANDARD "${CMAKE_CXX_STANDARD}") configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/setup.py) add_custom_command(TARGET PythonInstall COMMAND "${PYTHON_EXECUTABLE}" -m pip install . 
diff --git a/python/openmmtorch.i b/python/openmmtorch.i index 05988e3d..77a72e43 100644 --- a/python/openmmtorch.i +++ b/python/openmmtorch.i @@ -74,6 +74,8 @@ public: void setGlobalParameterName(int index, const std::string& name); double getGlobalParameterDefaultValue(int index) const; void setGlobalParameterDefaultValue(int index, double defaultValue); + void addEnergyParameterDerivative(const std::string& name); + const std::string& getEnergyParameterDerivativeName(int index) const; void setProperty(const std::string& name, const std::string& value); const std::map& getProperties() const; diff --git a/python/setup.py b/python/setup.py index 4d03efe2..6748a57a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -6,10 +6,11 @@ torch_include_dirs = '@TORCH_INCLUDE_DIRS@'.split(';') nn_plugin_header_dir = '@NN_PLUGIN_HEADER_DIR@' nn_plugin_library_dir = '@NN_PLUGIN_LIBRARY_DIR@' +cpp_std = '@EXTENSION_CXX_STANDARD@' torch_dir, _ = os.path.split('@TORCH_LIBRARY@') # setup extra compile and link arguments on Mac -extra_compile_args = ['-std=c++14'] +extra_compile_args = ['-std=c++' + cpp_std] extra_link_args = [] if platform.system() == 'Darwin': diff --git a/python/tests/TestParameterDerivatives.py b/python/tests/TestParameterDerivatives.py new file mode 100644 index 00000000..acb28588 --- /dev/null +++ b/python/tests/TestParameterDerivatives.py @@ -0,0 +1,123 @@ +import openmm as mm +import openmm.unit as unit +import openmmtorch as ot +import numpy as np +import pytest +import torch as pt +from torch import Tensor +from typing import Tuple, List, Optional + + +class EnergyWithParameters(pt.nn.Module): + + def __init__(self): + super(EnergyWithParameters, self).__init__() + + def forward( + self, positions: Tensor, parameter1: Tensor, parameter2: Tensor + ) -> Tensor: + x2 = positions.pow(2).sum(dim=1) + u_harmonic = ((parameter1 + parameter2**2) * x2).sum() + return u_harmonic + + +class EnergyForceWithParameters(pt.nn.Module): + + def __init__(self, 
use_backwards=True): + super(EnergyForceWithParameters, self).__init__() + self.use_backwards = use_backwards + + def forward( + self, positions: Tensor, parameter1: Tensor, parameter2: Tensor + ) -> Tuple[Tensor, Tensor]: + positions.requires_grad_(True) + x2 = positions.pow(2).sum(dim=1) + u_harmonic = ((parameter1 + parameter2**2) * x2).sum() + # This way of computing the forces forcefully leaves out the parameter derivatives + if self.use_backwards: + grad_outputs: List[Optional[Tensor]] = [pt.ones_like(u_harmonic)] + dy = pt.autograd.grad( + [u_harmonic], + [positions], + grad_outputs=grad_outputs, + create_graph=False, + # This must be true, otherwise pytorch will not allow to compute the gradients with respect to the parameters + retain_graph=True, + )[0] + assert dy is not None + forces = -dy + else: + forces = -2 * (parameter1 + parameter2**2) * positions + return u_harmonic, forces + + +@pytest.mark.parametrize("use_cv_force", [False, True]) +@pytest.mark.parametrize("platform", ["Reference", "CPU", "CUDA", "OpenCL"]) +@pytest.mark.parametrize( + ("return_forces", "use_backwards"), [(False, False), (True, False), (True, True)] +) +def testParameterEnergyDerivatives( + use_cv_force, platform, return_forces, use_backwards +): + + if pt.cuda.device_count() < 1 and platform == "CUDA": + pytest.skip("A CUDA device is not available") + + # Create a random cloud of particles. 
+ numParticles = 10 + system = mm.System() + positions = np.random.rand(numParticles, 3) + for _ in range(numParticles): + system.addParticle(1.0) + + # Create a force + if return_forces: + pt_force = EnergyForceWithParameters(use_backwards=use_backwards) + else: + pt_force = EnergyWithParameters() + model = pt.jit.script(pt_force) + tforce = ot.TorchForce(model, {"useCUDAGraphs": "false"}) + # Add a parameter + parameter1 = 1.0 + parameter2 = 1.0 + tforce.setOutputsForces(return_forces) + tforce.addGlobalParameter("parameter1", parameter1) + tforce.addGlobalParameter("parameter2", parameter2) + # Enable energy derivatives for the parameters + tforce.addEnergyParameterDerivative("parameter1") + tforce.addEnergyParameterDerivative("parameter2") + if use_cv_force: + # Wrap TorchForce into CustomCVForce + force = mm.CustomCVForce("force") + force.addCollectiveVariable("force", tforce) + else: + force = tforce + system.addForce(force) + # Compute the forces and energy. + integ = mm.VerletIntegrator(1.0) + platform = mm.Platform.getPlatformByName(platform) + context = mm.Context(system, integ, platform) + context.setPositions(positions) + state = context.getState( + getEnergy=True, getForces=True, getParameterDerivatives=True + ) + + # See if the energy and forces and the parameter derivative are correct. + # The network defines a potential of the form E(r) = (parameter1 + parameter2**2)*|r|^2 + r2 = np.sum(positions * positions) + expectedEnergy = (parameter1 + parameter2**2) * r2 + assert np.allclose( + expectedEnergy, + state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole), + ) + assert np.allclose( + -2 * (parameter1 + parameter2**2) * positions, state.getForces(asNumpy=True) + ) + assert np.allclose( + r2, + state.getEnergyParameterDerivatives()["parameter1"], + ) + assert np.allclose( + 2 * parameter2 * r2, + state.getEnergyParameterDerivatives()["parameter2"], + )