Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Szkered committed May 4, 2024
1 parent a76bffa commit 7dfe34c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 27 deletions.
4 changes: 2 additions & 2 deletions d4ft/solver/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,13 @@ def cgto_direct(
where CGTO basis are used and the energy tensors are precomputed/incore."""
key = jax.random.PRNGKey(cfg.method_cfg.rng_seed)

pyscf_mol, H_factory, _, _ = build_mf_cgto(cfg)
pyscf_mol, H_factory, cgto, _ = build_mf_cgto(cfg)

H_transformed = hk.multi_transform(H_factory)
params = H_transformed.init(key)
H_hk = Hamiltonian(*H_transformed.apply)

logger, traj = sgd(cfg.solver_cfg, H_hk, params, key)
logger, traj = sgd(cfg.solver_cfg, H_hk, cgto, params, key)

min_e_step = logger.data_df.e_total.astype(float).idxmin()
logging.info(f"lowest total energy: \n {logger.data_df.iloc[min_e_step]}")
Expand Down
39 changes: 32 additions & 7 deletions d4ft/solver/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
# limitations under the License.
"""Solve DFT with gradient descent"""

from functools import partial
from typing import Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging

from d4ft.config import GDConfig
from d4ft.hamiltonian.nuclear import e_nuclear
from d4ft.integral.gto.cgto import CGTO
from d4ft.logger import RunLogger
from d4ft.optimize import get_optimizer
from d4ft.types import Hamiltonian, TrainingState, Trajectory, Transition
Expand All @@ -38,7 +42,8 @@ def scipy_opt(


def sgd(
solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, key: jax.Array
solver_cfg: GDConfig, H: Hamiltonian, cgto: CGTO, params: hk.Params,
key: jax.Array
) -> Tuple[RunLogger, Trajectory]:

@jax.jit
Expand Down Expand Up @@ -104,14 +109,34 @@ def meta_step(state: TrainingState, meta_state: TrainingState):
center = state.params['~']['center']
logging.info(f"{center=}")

grads, _ = jax.grad(H.energy_fn, has_aux=True)(state.params, state.rng_key)
dE_dR = grads['~']['center']
logging.info(f"{dE_dR=}")
breakpoint()
# grads, _ = jax.grad(H.energy_fn, has_aux=True)(state.params, state.rng_key)
# dE_dR = grads['~']['center']
# logging.info(f"{dE_dR=}")
# breakpoint()

g1 = jax.grad(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords)
logging.info(f"{g1=}")

# g1 = jax.grad(partial(e_nuclear, charge=cgto.charge))(cgto.atom_coords)
# g2, _ = jax.jacfwd(H.energy_fn, has_aux=True)(state.params, state.rng_key)
# breakpoint()

def e_fn(center):
state.params['~']['center'] = center
return H.energy_fn(state.params, state.rng_key)[0]

cur_center = state.params['~']['center']
center = np.zeros_like(cur_center)
center[0, 0] = 1.
new_e = e_fn(center)
logging.info(f"{new_e=}")

tangent = np.zeros_like(cur_center)
tangent[0, 0] = 1.
primal_out, tangent_out = jax.jvp(
e_fn, primals=(cur_center,), tangents=(tangent,)
)
dE_dR = tangent_out
logging.info(f"{dE_dR=}")
breakpoint()

mo_coeff = H.mo_coeff_fn(state.params, state.rng_key, apply_spin_mask=False)
t = Transition(mo_coeff, energies, mo_grads)
Expand Down
30 changes: 14 additions & 16 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,24 @@ classifiers =
packages = find:
python_requires = >=3.7
install_requires =
ase>=3.22.1
ase==3.22.1
bs4==0.0.1
chex==0.1.8
distrax>=0.1.2
dm-haiku>=0.0.9
einops>=0.6.1
jax-xc>=0.0.7
jax>=0.3.25
jaxlib>=0.3.25
jaxtyping==0.2.15
matplotlib>=3.6.2
dm-haiku==0.0.12
einops==0.6.1
jax-xc>=0.0.8
jax==0.4.26
jaxlib==0.4.26
jaxtyping>=0.2.28
matplotlib==3.8.4
ml_collections==0.1.1
mpmath>=1.2.1
optax>=0.1.4
pandas>=1.5.2
mpmath==1.3.0
optax==0.2.2
pandas==2.2.2
pubchempy==1.0.4
pydantic==1.10.9
pyscf>=2.1.1
requests>=2.31.0
scipy>=1.9.0
pyscf==2.5.0
requests==2.31.0
scipy==1.13.0
shortuuid==1.0.11
tqdm==4.64.1

Expand Down
4 changes: 2 additions & 2 deletions third_party/pip_requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ ase==3.22.1
bs4==0.0.1
dm-haiku==0.0.12
einops==0.6.1
jax-xc==0.0.8
jax-xc>=0.0.8
# --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# jax[cuda12_local]==0.4.13
jax==0.4.26
jaxlib==0.4.26
jaxtyping==0.2.28
jaxtyping>=0.2.28
matplotlib==3.8.4
ml_collections==0.1.1
mpmath==1.3.0
Expand Down

0 comments on commit 7dfe34c

Please sign in to comment.