Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Szkered committed Apr 30, 2024
1 parent e0b846d commit a76bffa
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 28 deletions.
12 changes: 4 additions & 8 deletions d4ft/hamiltonian/nuclear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@
from jaxtyping import Array, Float, Int


def set_diag_zero(x: Array) -> Array:
"""Set diagonal items to zero."""
return x.at[jnp.diag_indices(x.shape[0])].set(0)


def e_nuclear(center: Float[Array, "n_atoms 3"],
charge: Int[Array, "n_atoms"]) -> Float[Array, ""]:
"""Potential energy between atomic nuclears."""
dist_nuc = jnp.linalg.norm(center - center[:, None], axis=-1)
dist_diff = center - center[:, None]
dist_nuc = jnp.sqrt(jnp.sum(dist_diff**2, axis=-1) + 1e-20)
dist_nuc = jnp.where(dist_nuc <= 1e-9, 1e20, dist_nuc)
charge_outer = jnp.outer(charge, charge)
charge_outer = set_diag_zero(charge_outer)
return 0.5 * jnp.sum(charge_outer / (dist_nuc + 1e-15))
return 0.5 * jnp.sum(charge_outer / dist_nuc)
4 changes: 0 additions & 4 deletions d4ft/integral/gto/cgto.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,6 @@ def from_mol(mol: Mol) -> CGTO:
"""Build CGTO from pyscf mol."""
return build_cgto_from_mol(mol)

@staticmethod
def from_cart(cgto_cart: CGTO) -> CGTO:
return build_cgto_sph_from_mol(cgto_cart)

def to_hk(
self,
optimizable_params: Sequence[Literal[
Expand Down
1 change: 1 addition & 0 deletions d4ft/integral/obara_saika/nuclear_attraction_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def horizontal(i, A_0, min_b):
return jnp.einsum("a,am->m", w, A_0[min_b[i]:, :])

prefactor = 2 * (jnp.pi / zeta) * jnp.exp(-xi * jnp.dot(ab, ab)) # Eqn.A20
# TODO(geo_opt): this gives nan
A_0_0 = jax.vmap(boys.Boys, in_axes=(0, None))(jnp.arange(M, dtype=int), U)

if use_horizontal:
Expand Down
12 changes: 11 additions & 1 deletion d4ft/solver/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from d4ft.config import D4FTConfig
from d4ft.hamiltonian.cgto_intors import get_cgto_fock_fn, get_cgto_intor
from d4ft.hamiltonian.mf_cgto import mf_cgto
from d4ft.hamiltonian.nuclear import e_nuclear
from d4ft.hamiltonian.ortho import qr_factor, sqrt_inv
from d4ft.integral import obara_saika as obsa
from d4ft.integral.gto.cgto import CGTO
Expand Down Expand Up @@ -71,7 +72,16 @@ def build_mf_cgto(cfg: D4FTConfig):
vxc_ab_fn = get_lda_vxc(
grids_and_weights, cgto, polarized=not cfg.method_cfg.restricted
)
cgto_fock_fn = get_cgto_fock_fn(cgto, cgto_e_tensors, vxc_ab_fn)
if cfg.intor_cfg.incore:
cgto_fock_fn = get_cgto_fock_fn(cgto, cgto_e_tensors, vxc_ab_fn)
else:
cgto_fock_fn = None

# # DEBUG
# g1 = jax.grad(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords)
# g2 = jax.jacfwd(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords)
# print(g1, g2)
# breakpoint()

def H_factory(with_mo_coeff: bool = True) -> Tuple[Callable, Hamiltonian]:
"""Auto-grad scope"""
Expand Down
18 changes: 14 additions & 4 deletions d4ft/solver/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@


def scipy_opt(
solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params,
key: jax.Array
solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, key: jax.Array
) -> float:
energy_fn_jit = jax.jit(lambda mo_coeff: H.energy_fn(mo_coeff, key)[0])
import jaxopt
Expand All @@ -39,8 +38,7 @@ def scipy_opt(


def sgd(
solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params,
key: jax.Array
solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, key: jax.Array
) -> Tuple[RunLogger, Trajectory]:

@jax.jit
Expand Down Expand Up @@ -103,6 +101,18 @@ def meta_step(state: TrainingState, meta_state: TrainingState):
logger.log_step(energies, step, e_total_std)
logger.get_segment_summary()

center = state.params['~']['center']
logging.info(f"{center=}")

grads, _ = jax.grad(H.energy_fn, has_aux=True)(state.params, state.rng_key)
dE_dR = grads['~']['center']
logging.info(f"{dE_dR=}")
breakpoint()

# g1 = jax.grad(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords)
# g2, _ = jax.jacfwd(H.energy_fn, has_aux=True)(state.params, state.rng_key)
# breakpoint()

mo_coeff = H.mo_coeff_fn(state.params, state.rng_key, apply_spin_mask=False)
t = Transition(mo_coeff, energies, mo_grads)

Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import string
from pathlib import Path
from typing import Any
import jax

import jax
import matplotlib.pyplot as plt
import pandas as pd
import shortuuid
Expand Down Expand Up @@ -62,6 +62,7 @@ def get_rxn_energy(rxn: str, benchmark: str, df: pd.DataFrame) -> float:
def main(_: Any) -> None:
jax.config.update("jax_enable_x64", FLAGS.use_f64)
jax.config.update("jax_debug_nans", FLAGS.debug_nans)
# jax.config.update("jax_disable_jit", True)

cfg: D4FTConfig = FLAGS.config
print(cfg)
Expand Down
20 changes: 10 additions & 10 deletions third_party/pip_requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
ase==3.22.1
bs4==0.0.1
dm-haiku>=0.0.10
dm-haiku==0.0.12
einops==0.6.1
jax-xc>=0.0.7
jax-xc==0.0.8
# --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# jax[cuda12_local]==0.4.13
jax>=0.4.13
jaxlib>=0.3.25
jaxtyping==0.2.15
matplotlib==3.7.2
jax==0.4.26
jaxlib==0.4.26
jaxtyping==0.2.28
matplotlib==3.8.4
ml_collections==0.1.1
mpmath==1.3.0
optax==0.1.5
pandas==2.0.3
optax==0.2.2
pandas==2.2.2
pubchempy==1.0.4
pydantic==1.10.9
pyscf==2.3.0
pyscf==2.5.0
requests==2.31.0
scipy>=1.9.0
scipy==1.13.0
setuptools==68.0.0
shortuuid==1.0.11
tqdm==4.64.1
Expand Down

0 comments on commit a76bffa

Please sign in to comment.