diff --git a/d4ft/hamiltonian/nuclear.py b/d4ft/hamiltonian/nuclear.py index 4d5bb49..ce1a301 100644 --- a/d4ft/hamiltonian/nuclear.py +++ b/d4ft/hamiltonian/nuclear.py @@ -17,15 +17,11 @@ from jaxtyping import Array, Float, Int -def set_diag_zero(x: Array) -> Array: - """Set diagonal items to zero.""" - return x.at[jnp.diag_indices(x.shape[0])].set(0) - - def e_nuclear(center: Float[Array, "n_atoms 3"], charge: Int[Array, "n_atoms"]) -> Float[Array, ""]: """Potential energy between atomic nuclears.""" - dist_nuc = jnp.linalg.norm(center - center[:, None], axis=-1) + dist_diff = center - center[:, None] + dist_nuc = jnp.sqrt(jnp.sum(dist_diff**2, axis=-1) + 1e-20) + dist_nuc = jnp.where(dist_nuc <= 1e-9, 1e20, dist_nuc) charge_outer = jnp.outer(charge, charge) - charge_outer = set_diag_zero(charge_outer) - return 0.5 * jnp.sum(charge_outer / (dist_nuc + 1e-15)) + return 0.5 * jnp.sum(charge_outer / dist_nuc) diff --git a/d4ft/integral/gto/cgto.py b/d4ft/integral/gto/cgto.py index 40c9f7b..890c220 100644 --- a/d4ft/integral/gto/cgto.py +++ b/d4ft/integral/gto/cgto.py @@ -510,10 +510,6 @@ def from_mol(mol: Mol) -> CGTO: """Build CGTO from pyscf mol.""" return build_cgto_from_mol(mol) - @staticmethod - def from_cart(cgto_cart: CGTO) -> CGTO: - return build_cgto_sph_from_mol(cgto_cart) - def to_hk( self, optimizable_params: Sequence[Literal[ diff --git a/d4ft/integral/obara_saika/nuclear_attraction_integral.py b/d4ft/integral/obara_saika/nuclear_attraction_integral.py index 67dadc8..7a22baf 100644 --- a/d4ft/integral/obara_saika/nuclear_attraction_integral.py +++ b/d4ft/integral/obara_saika/nuclear_attraction_integral.py @@ -143,6 +143,7 @@ def horizontal(i, A_0, min_b): return jnp.einsum("a,am->m", w, A_0[min_b[i]:, :]) prefactor = 2 * (jnp.pi / zeta) * jnp.exp(-xi * jnp.dot(ab, ab)) # Eqn.A20 + # TODO(geo_opt): this gives nan A_0_0 = jax.vmap(boys.Boys, in_axes=(0, None))(jnp.arange(M, dtype=int), U) if use_horizontal: diff --git a/d4ft/solver/drivers.py b/d4ft/solver/drivers.py index 3cc4d3e..c2dc127 100644 --- a/d4ft/solver/drivers.py +++ b/d4ft/solver/drivers.py @@ -27,6 +27,7 @@ from d4ft.config import D4FTConfig from d4ft.hamiltonian.cgto_intors import get_cgto_fock_fn, get_cgto_intor from d4ft.hamiltonian.mf_cgto import mf_cgto +from d4ft.hamiltonian.nuclear import e_nuclear from d4ft.hamiltonian.ortho import qr_factor, sqrt_inv from d4ft.integral import obara_saika as obsa from d4ft.integral.gto.cgto import CGTO @@ -71,7 +72,16 @@ def build_mf_cgto(cfg: D4FTConfig): vxc_ab_fn = get_lda_vxc( grids_and_weights, cgto, polarized=not cfg.method_cfg.restricted ) - cgto_fock_fn = get_cgto_fock_fn(cgto, cgto_e_tensors, vxc_ab_fn) + if cfg.intor_cfg.incore: + cgto_fock_fn = get_cgto_fock_fn(cgto, cgto_e_tensors, vxc_ab_fn) + else: + cgto_fock_fn = None + + # # DEBUG + # g1 = jax.grad(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords) + # g2 = jax.jacfwd(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords) + # print(g1, g2) + # breakpoint() def H_factory(with_mo_coeff: bool = True) -> Tuple[Callable, Hamiltonian]: """Auto-grad scope""" diff --git a/d4ft/solver/sgd.py b/d4ft/solver/sgd.py index b324694..e451597 100644 --- a/d4ft/solver/sgd.py +++ b/d4ft/solver/sgd.py @@ -28,8 +28,7 @@ def scipy_opt( - solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, - key: jax.Array + solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, key: jax.Array ) -> float: energy_fn_jit = jax.jit(lambda mo_coeff: H.energy_fn(mo_coeff, key)[0]) import jaxopt @@ -39,8 +38,7 @@ def scipy_opt( def sgd( - solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, - key: jax.Array + solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, key: jax.Array ) -> Tuple[RunLogger, Trajectory]: @jax.jit @@ -103,6 +101,18 @@ def meta_step(state: TrainingState, meta_state: TrainingState): logger.log_step(energies, step, e_total_std) logger.get_segment_summary() + center = state.params['~']['center'] + logging.info(f"{center=}") + + grads, _ = jax.grad(H.energy_fn, has_aux=True)(state.params, state.rng_key) + dE_dR = grads['~']['center'] + logging.info(f"{dE_dR=}") + breakpoint() + + # g1 = jax.grad(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords) + # g2, _ = jax.jacfwd(H.energy_fn, has_aux=True)(state.params, state.rng_key) + # breakpoint() + mo_coeff = H.mo_coeff_fn(state.params, state.rng_key, apply_spin_mask=False) t = Transition(mo_coeff, energies, mo_grads) diff --git a/main.py b/main.py index 320b6a5..ce2384f 100644 --- a/main.py +++ b/main.py @@ -16,8 +16,8 @@ import string from pathlib import Path from typing import Any -import jax +import jax import matplotlib.pyplot as plt import pandas as pd import shortuuid @@ -62,6 +62,7 @@ def get_rxn_energy(rxn: str, benchmark: str, df: pd.DataFrame) -> float: def main(_: Any) -> None: jax.config.update("jax_enable_x64", FLAGS.use_f64) jax.config.update("jax_debug_nans", FLAGS.debug_nans) + # jax.config.update("jax_disable_jit", True) cfg: D4FTConfig = FLAGS.config print(cfg) diff --git a/third_party/pip_requirements/requirements-dev.txt b/third_party/pip_requirements/requirements-dev.txt index 9a8d525..d34f519 100644 --- a/third_party/pip_requirements/requirements-dev.txt +++ b/third_party/pip_requirements/requirements-dev.txt @@ -1,23 +1,23 @@ ase==3.22.1 bs4==0.0.1 -dm-haiku>=0.0.10 +dm-haiku==0.0.12 einops==0.6.1 -jax-xc>=0.0.7 +jax-xc==0.0.8 # --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # jax[cuda12_local]==0.4.13 -jax>=0.4.13 -jaxlib>=0.3.25 -jaxtyping==0.2.15 -matplotlib==3.7.2 +jax==0.4.26 +jaxlib==0.4.26 +jaxtyping==0.2.28 +matplotlib==3.8.4 ml_collections==0.1.1 mpmath==1.3.0 -optax==0.1.5 -pandas==2.0.3 +optax==0.2.2 +pandas==2.2.2 pubchempy==1.0.4 pydantic==1.10.9 -pyscf==2.3.0 +pyscf==2.5.0 requests==2.31.0 -scipy>=1.9.0 +scipy==1.13.0 setuptools==68.0.0 shortuuid==1.0.11 tqdm==4.64.1