Support parameter derivatives #142

Closed. Wants to merge 32 commits.
Commits
6cbc78d
Conditionally set C++17 for latest pytorch versions
RaulPPelaez Nov 14, 2023
d26f945
Add correct std to setup.py too
RaulPPelaez Nov 14, 2023
0f89440
Try to free up some space for the CI
RaulPPelaez Dec 12, 2023
a604b33
Clean space on CI machine before installing CUDA
RaulPPelaez Dec 12, 2023
4b6bd73
Update CMakeLists.txt
RaulPPelaez Jan 29, 2024
a347019
Add energy parameter derivative to the API
RaulPPelaez May 15, 2024
e12b0af
Empty commit to trigger CI
RaulPPelaez May 15, 2024
11fb899
Change to int
RaulPPelaez May 15, 2024
cd35af1
Add test for energy derivatives
RaulPPelaez May 15, 2024
b69456e
Update env
RaulPPelaez May 15, 2024
fc5a85f
Update env
RaulPPelaez May 15, 2024
033b2d4
Typo
RaulPPelaez May 15, 2024
ded6559
typo
RaulPPelaez May 15, 2024
d1e0c22
Merge branch 'cpp17' into derivatives
RaulPPelaez May 15, 2024
55de9ee
Add getNumEnergyParameterDerivatives
RaulPPelaez May 15, 2024
bcc8759
Implement Reference platform
RaulPPelaez May 15, 2024
aba8a62
Update test
RaulPPelaez May 15, 2024
f26e867
Update test adding two parameters
RaulPPelaez May 15, 2024
00b8120
Implement CUDA and OpenCL
RaulPPelaez May 15, 2024
277c90d
Initialize OpenCL energyParameterDerivatives map
RaulPPelaez May 15, 2024
9d6d34c
Fix return
RaulPPelaez May 15, 2024
4778eac
Remove commented code
RaulPPelaez May 15, 2024
4031cc5
Register params to the context
RaulPPelaez May 15, 2024
0acc338
Make sure backwards is called with all the requested parameters as in…
RaulPPelaez May 15, 2024
9c65e6b
Small changes
RaulPPelaez May 15, 2024
0dbc814
Small changes
RaulPPelaez May 15, 2024
46990a9
Test more cases
RaulPPelaez May 15, 2024
ce80b7d
Handle CUDA graphs
RaulPPelaez May 15, 2024
b0d51e6
Test the case when a model already calls backwards to compute the forces
RaulPPelaez May 15, 2024
0f88d7f
Small changes
RaulPPelaez May 15, 2024
c81526c
Handle some corner cases in OpenCL and Reference
RaulPPelaez May 15, 2024
4c9b5e5
Small changes
RaulPPelaez May 15, 2024
47 changes: 47 additions & 0 deletions .github/workflows/CI.yml
@@ -77,6 +77,53 @@ jobs:
NVCC_VERSION: ${{ matrix.nvcc-version }}
PYTORCH_VERSION: ${{ matrix.pytorch-version }}

- name: Manage disk space
if: startsWith(matrix.os, 'ubuntu')
run: |
sudo mkdir -p /opt/empty_dir || true
for d in \
/opt/ghc \
/opt/hostedtoolcache \
/usr/lib/jvm \
/usr/local/.ghcup \
/usr/local/lib/android \
/usr/local/share/powershell \
/usr/share/dotnet \
/usr/share/swift \
; do
sudo rsync --stats -a --delete /opt/empty_dir/ $d || true
done
sudo apt-get purge -y -f firefox \
google-chrome-stable \
microsoft-edge-stable
sudo apt-get autoremove -y >& /dev/null
sudo apt-get autoclean -y >& /dev/null
sudo docker image prune --all --force
df -h

- name: "Install CUDA Toolkit on Linux (if needed)"
uses: Jimver/[email protected]
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit", "--override"]'
if: startsWith(matrix.os, 'ubuntu')

- name: "Install SDK on MacOS (if needed)"
run: source devtools/scripts/install_macos_sdk.sh
if: startsWith(matrix.os, 'macos')

- name: "Update the conda enviroment file"
uses: cschleiden/replace-tokens@v1
with:
tokenPrefix: '@'
tokenSuffix: '@'
files: devtools/conda-envs/build-${{ matrix.os }}.yml
env:
CUDATOOLKIT_VERSION: ${{ matrix.cuda-version }}
GCC_VERSION: ${{ matrix.gcc-version }}
NVCC_VERSION: ${{ matrix.nvcc-version }}
PYTORCH_VERSION: ${{ matrix.pytorch-version }}

- uses: conda-incubator/setup-miniconda@v2
name: "Install dependencies with Mamba"
with:
13 changes: 10 additions & 3 deletions CMakeLists.txt
@@ -2,7 +2,7 @@
# OpenMM PyTorch Plugin
#----------------------------------------------------

CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
CMAKE_MINIMUM_REQUIRED(VERSION 3.7)

# We need to know where OpenMM is installed so we can access the headers and libraries.
SET(OPENMM_DIR "/usr/local/openmm" CACHE PATH "Where OpenMM is installed")
@@ -14,8 +14,15 @@ SET(PYTORCH_DIR "" CACHE PATH "Where the PyTorch C++ API is installed")
SET(CMAKE_PREFIX_PATH "${PYTORCH_DIR}")
FIND_PACKAGE(Torch REQUIRED)

# Specify the C++ version we are building for.
SET (CMAKE_CXX_STANDARD 14)
# Specify the C++ version we are building for. Latest pytorch versions require C++17
message(STATUS "Found Torch: ${Torch_VERSION}")
if(${Torch_VERSION} VERSION_GREATER_EQUAL "2.1.0")
set(CMAKE_CXX_STANDARD 17)
message(STATUS "Setting C++ standard to C++17")
else()
set(CMAKE_CXX_STANDARD 14)
message(STATUS "Setting C++ standard to C++14")
endif()

# Set flags for linking on mac
IF(APPLE)
3 changes: 3 additions & 0 deletions devtools/conda-envs/build-ubuntu-22.04.yml
@@ -18,3 +18,6 @@ dependencies:
- swig
- sysroot_linux-64 2.17
- torchani
# Required by python<3.8
# xref: https://github.com/conda-forge/linux-sysroot-feedstock/issues/52
- libxcrypt
30 changes: 26 additions & 4 deletions openmmapi/include/TorchForce.h
@@ -58,8 +58,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param file the path to the file containing the network
* @param properties optional map of properties
*/
TorchForce(const std::string& file,
const std::map<std::string, std::string>& properties = {});
TorchForce(const std::string& file, const std::map<std::string, std::string>& properties = {});
/**
* Create a TorchForce. The network is defined by a PyTorch ScriptModule
* Note that this constructor makes a copy of the provided module.
@@ -68,7 +67,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param module an instance of the torch module
* @param properties optional map of properties
*/
TorchForce(const torch::jit::Module &module, const std::map<std::string, std::string>& properties = {});
TorchForce(const torch::jit::Module& module, const std::map<std::string, std::string>& properties = {});
/**
* Get the path to the file containing the network.
* If the TorchForce instance was constructed with a module, instead of a filename,
@@ -78,7 +77,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
/**
* Get the torch module currently in use.
*/
const torch::jit::Module & getModule() const;
const torch::jit::Module& getModule() const;
/**
* Set whether this force makes use of periodic boundary conditions. If this is set
* to true, the network must take a 3x3 tensor as its second input, which
@@ -116,6 +115,26 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @return the index of the parameter that was added
*/
int addGlobalParameter(const std::string& name, double defaultValue);
/**
* Request that the derivative of the energy with respect to a global parameter be computed.
* The parameter must first be added with addGlobalParameter().
*
* @param name the name of the parameter
*/
void addEnergyParameterDerivative(const std::string& name);
/**
* Get the name of an energy parameter derivative, given its index in the list of requested derivatives.
*
* @param index the index of the parameter for which to get the name
* @return the parameter name
*/
const std::string& getEnergyParameterDerivativeName(int index) const;
/**
* Get the number of global parameters with respect to which the derivative of the energy
* should be computed.
*/
int getNumEnergyParameterDerivatives() const {
return energyParameterDerivatives.size();
}
/**
* Get the name of a global parameter.
*
@@ -156,13 +175,16 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @return A map of property names to values.
*/
const std::map<std::string, std::string>& getProperties() const;

protected:
OpenMM::ForceImpl* createImpl() const;

private:
class GlobalParameterInfo;
std::string file;
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
std::vector<int> energyParameterDerivatives;
torch::jit::Module module;
std::map<std::string, std::string> properties;
std::string emptyProperty;
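For orientation, here is a minimal sketch of how the new calls are meant to be used alongside the existing global-parameter API. The model file "model.pt" and the parameter name "scale" are placeholders, and the Integrator/Context setup is elided:

#include "TorchForce.h"        // from the openmm-torch plugin
#include <openmm/OpenMM.h>

// Sketch only: assumes a scripted model whose energy depends on a global
// parameter named "scale".
void addForceWithParameterDerivative(OpenMM::System& system) {
    auto* force = new TorchPlugin::TorchForce("model.pt");
    force->addGlobalParameter("scale", 1.0);        // existing API
    force->addEnergyParameterDerivative("scale");   // new in this PR
    system.addForce(force);
    // After creating a Context and evaluating the energy, dE/d(scale) should be
    // reported through the State:
    //   OpenMM::State state = context.getState(OpenMM::State::ParameterDerivatives);
    //   double dEdScale = state.getEnergyParameterDerivatives().at("scale");
}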
15 changes: 15 additions & 0 deletions openmmapi/src/TorchForce.cpp
@@ -88,6 +88,21 @@ int TorchForce::addGlobalParameter(const string& name, double defaultValue) {
return globalParameters.size() - 1;
}

void TorchForce::addEnergyParameterDerivative(const string& name) {
for (int i = 0; i < globalParameters.size(); i++) {
if (globalParameters[i].name == name) {
energyParameterDerivatives.push_back(i);
return;
}
}
}

const std::string& TorchForce::getEnergyParameterDerivativeName(int index) const {
if (index < 0 || index >= energyParameterDerivatives.size())
throw OpenMM::OpenMMException("TorchForce::getEnergyParameterDerivativeName: index out of range.");
return globalParameters[energyParameterDerivatives[index]].name;
}

int TorchForce::getNumGlobalParameters() const {
return globalParameters.size();
}
84 changes: 63 additions & 21 deletions platforms/cuda/src/CudaTorchKernels.cpp
@@ -66,6 +66,11 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
outputsForces = force.getOutputsForces();
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
auto name = force.getEnergyParameterDerivativeName(i);
energyParameterDerivatives.push_back(name);
cu.addEnergyParameterDerivative(name);
}
int numParticles = system.getNumParticles();

// Push the PyTorch context
@@ -81,6 +86,7 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
boxTensor = torch::empty({3, 3}, options);
energyTensor = torch::empty({0}, options);
forceTensor = torch::empty({0}, options);
gradientTensors.resize(force.getNumEnergyParameterDerivatives(), torch::empty({0}, options));
// Pop the PyTorch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
@@ -148,8 +154,13 @@ std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(Con
vector<torch::jit::IValue> inputs = {posTensor};
if (usePeriodic)
inputs.push_back(boxTensor);
for (const string& name : globalNames)
inputs.push_back(torch::tensor(context.getParameter(name)));
for (const string& name : globalNames) {
// Require grad if the parameter is in the list of energy parameter derivatives
bool requires_grad = std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), name) != energyParameterDerivatives.end();
auto options = torch::TensorOptions().requires_grad(requires_grad).device(posTensor.device());
auto tensor = torch::tensor(context.getParameter(name), options);
inputs.emplace_back(tensor);
}
return inputs;
}

@@ -179,23 +190,43 @@ void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) {
* implicit synchronizations) will result in a CUDA error.
*/
static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector<torch::jit::IValue>& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor,
torch::Tensor& forceTensor) {
torch::Tensor& forceTensor, std::vector<torch::Tensor>& gradientTensors) {
if (outputsForces) {
auto outputs = module.forward(inputs).toTuple();
energyTensor = outputs->elements()[0].toTensor();
forceTensor = outputs->elements()[1].toTensor();
} else {
energyTensor = module.forward(inputs).toTensor();
// Compute force by backpropagating the PyTorch model
if (includeForces) {
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
// See https://github.com/openmm/openmm-torch/pull/120/
auto none = torch::Tensor();
energyTensor.backward(none, false, false, posTensor);
// This is minus the forces, we change the sign later on
forceTensor = posTensor.grad().clone();
// Zero the gradient to avoid accumulating it
posTensor.grad().zero_();
}
// Compute any gradients by backpropagating the PyTorch model
std::vector<torch::Tensor> inputs_with_grad;
if (includeForces && !outputsForces) {
inputs_with_grad.push_back(posTensor);
}
for (int i = 1; i < inputs.size(); i++) { // Skip the positions
auto& input = inputs[i];
if (input.isTensor()) {
auto tensor = input.toTensor();
if (tensor.requires_grad())
inputs_with_grad.emplace_back(tensor);
}
}
if (inputs_with_grad.size() > 0) {
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
// See https://github.com/openmm/openmm-torch/pull/120/
auto none = torch::Tensor();
energyTensor.backward(none, false, false, inputs_with_grad);
// Store the gradients for the energy parameters
bool isForceFirst = includeForces && !outputsForces;
for (int i = 0; i < inputs_with_grad.size(); i++) {
if (i == 0 && isForceFirst) {
// This is minus the forces, we change the sign later on
forceTensor = inputs_with_grad[i].grad().clone();
inputs_with_grad[i].grad().zero_();
} else {
gradientTensors[i - isForceFirst] = inputs_with_grad[i].grad().clone();
inputs_with_grad[i].grad().zero_();
}
}
}
}
@@ -204,30 +235,32 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
// Push to the PyTorch context
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
auto inputs = prepareTorchInputs(context);
// Store the executeGraph call in a lambda to allow for easy reuse
auto payload = [&]() { executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, gradientTensors); };
if (!useGraphs) {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
payload();
} else {
// Record graph if not already done
bool is_graph_captured = false;
if (graphs.find(includeForces) == graphs.end()) {
//CUDA graph capture must occur in a non-default stream
// CUDA graph capture must occur in a non-default stream
const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
const c10::cuda::CUDAStreamGuard guard(stream);
// Warmup the graph workload before capturing. This first
// run before capture sets up allocations so that no
// allocations are needed after. Pytorch's allocator is
// stream capture-aware and, after warmup, will record
// static pointers and shapes during capture.
try {
for (int i = 0; i < this->warmupSteps; i++)
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
payload();
}
catch (std::exception& e) {
throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what());
}
graphs[includeForces].capture_begin();
try {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
payload();
is_graph_captured = true;
graphs[includeForces].capture_end();
}
@@ -238,17 +271,26 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
}
}
// Use the same stream as the OpenMM context, even if it is the default stream
// Use the same stream as the OpenMM context, even if it is the default stream
const auto openmmStream = cu.getCurrentStream();
const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
graphs[includeForces].replay();
}
if (includeForces) {
addForces(forceTensor);
}
// Get energy
const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context
// Store parameter energy derivatives
auto& derivs = cu.getEnergyParamDerivWorkspace();
for (int i = 0; i < energyParameterDerivatives.size(); i++) {
// Compute the derivative of the energy with respect to this parameter.
// The derivative is stored in the gradient of the parameter tensor.
double derivative = gradientTensors[i].item<double>();
auto name = energyParameterDerivatives[i];
derivs[name] = derivative;
[Inline review comment from the author] Should I be summing here instead of overwriting?

}
// Pop to the PyTorch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
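The kernel change above rests on a single PyTorch idiom: build the parameter tensors with requires_grad enabled, run one backward pass with an explicit list of inputs, and read each input's .grad(). Below is a standalone libtorch sketch of that pattern with a toy energy function and hypothetical names; in the real kernel the first entry of inputs_with_grad is the position tensor when forces are requested, which is why it is treated specially when the gradients are copied out.

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
    // Positions and a global parameter, both tracked by autograd.
    torch::Tensor pos = torch::randn({10, 3}, torch::requires_grad());
    torch::Tensor scale = torch::tensor(2.0, torch::requires_grad());

    // A toy "energy" that depends on both inputs.
    torch::Tensor energy = scale * pos.pow(2).sum();

    // One backward pass restricted to the inputs of interest, mirroring the
    // inputs_with_grad list assembled in executeGraph().
    std::vector<torch::Tensor> inputsWithGrad = {pos, scale};
    energy.backward(/*gradient=*/{}, /*retain_graph=*/false, /*create_graph=*/false,
                    /*inputs=*/inputsWithGrad);

    torch::Tensor minusForces = pos.grad();         // dE/d(positions), i.e. minus the forces
    double dEdScale = scale.grad().item<double>();  // dE/d(scale), the parameter derivative
    std::cout << "dE/dscale = " << dEdScale << std::endl;
    return 0;
}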
2 changes: 2 additions & 0 deletions platforms/cuda/src/CudaTorchKernels.h
@@ -71,7 +71,9 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
torch::jit::script::Module module;
torch::Tensor posTensor, boxTensor;
torch::Tensor energyTensor, forceTensor;
std::vector<torch::Tensor> gradientTensors;
std::vector<std::string> globalNames;
std::vector<std::string> energyParameterDerivatives;
bool usePeriodic, outputsForces;
CUfunction copyInputsKernel, addForcesKernel;
CUcontext primaryContext;