Skip to content

Commit

Permalink
Merge pull request eljost#313 from bytedance/gpu4pyscf_support
Browse files Browse the repository at this point in the history
Add gpu4pyscf support, fix bug with pyscf density fitting
  • Loading branch information
eljost authored Nov 17, 2024
2 parents fd87777 + dfe82b0 commit f7f7c51
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions pysisyphus/calculators/PySCF.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import numpy as np
import pyscf
from pyscf import gto, grad, lib, hessian, tddft, qmmm
from pyscf import gto, lib, qmmm
from pyscf import __all__ # ensure all modules are accessible under the pyscf namespace

from pysisyphus.calculators.OverlapCalculator import OverlapCalculator
from pysisyphus.helpers import geom_loader
Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(
unrestricted=None,
grid_level=3,
pruning="nwchem",
use_gpu=False,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -79,6 +81,8 @@ def __init__(
self.chkfile = None
self.out_fn = "pyscf.out"

self.use_gpu = use_gpu

lib.num_threads(self.pal)

@staticmethod
Expand All @@ -98,7 +102,10 @@ def build_grid(self, mf):

def prepare_mf(self, mf):
# Method can be overriden in a subclass to modify the mf object.
return mf
if self.use_gpu:
return mf.to_gpu()
else:
return mf

def get_driver(self, step, mol=None, mf=None):
def _get_driver():
Expand Down Expand Up @@ -244,7 +251,7 @@ def run(self, mol, point_charges=None):
f"Using '{self.chkfile}' as initial guess for {step} calculation."
)
if self.auxbasis:
mf.density_fit(auxbasis=self.auxbasis)
mf = mf.density_fit(auxbasis=self.auxbasis)
self.log(f"Using density fitting with auxbasis {self.auxbasis}.")

if point_charges is not None:
Expand All @@ -270,7 +277,7 @@ def run(self, mol, point_charges=None):

# Keep mf and dump mol
# save_mol(mol, self.make_fn("mol.chk"))
self.mf = mf
self.mf = mf.reset() # release integrals and other temporary intermediates.
self.calc_counter += 1

return mf
Expand Down

0 comments on commit f7f7c51

Please sign in to comment.