diff --git a/d4ft/logger.py b/d4ft/logger.py index d667e7d..3b2a3bd 100644 --- a/d4ft/logger.py +++ b/d4ft/logger.py @@ -51,7 +51,7 @@ def log_step(self, metrics: NamedTuple, t: int, thresh: float) -> None: def get_segment_summary(self) -> pd.DataFrame: segment_df = self.data_df[self.last_t:] self.last_t = self.data_df.index[-1] - logging.info(f"Iter: {self.last_t}\n{segment_df.mean()}") + logging.info(f"Iter: {self.last_t}\n{segment_df.iloc[-1]}") return segment_df def log_summary(self) -> None: diff --git a/d4ft/solver/sgd.py b/d4ft/solver/sgd.py index f3681b5..b324694 100644 --- a/d4ft/solver/sgd.py +++ b/d4ft/solver/sgd.py @@ -29,7 +29,7 @@ def scipy_opt( solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, - key: jax.random.KeyArray + key: jax.Array ) -> float: energy_fn_jit = jax.jit(lambda mo_coeff: H.energy_fn(mo_coeff, key)[0]) import jaxopt @@ -40,7 +40,7 @@ def scipy_opt( def sgd( solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params, - key: jax.random.KeyArray + key: jax.Array ) -> Tuple[RunLogger, Trajectory]: @jax.jit diff --git a/main.py b/main.py index 23c3ab3..320b6a5 100644 --- a/main.py +++ b/main.py @@ -16,12 +16,12 @@ import string from pathlib import Path from typing import Any +import jax import matplotlib.pyplot as plt import pandas as pd import shortuuid from absl import app, flags, logging -from jax.config import config from ml_collections.config_flags import config_flags from d4ft.config import D4FTConfig @@ -60,8 +60,8 @@ def get_rxn_energy(rxn: str, benchmark: str, df: pd.DataFrame) -> float: def main(_: Any) -> None: - config.update("jax_enable_x64", FLAGS.use_f64) - config.update("jax_debug_nans", FLAGS.debug_nans) + jax.config.update("jax_enable_x64", FLAGS.use_f64) + jax.config.update("jax_debug_nans", FLAGS.debug_nans) cfg: D4FTConfig = FLAGS.config print(cfg)