diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index ea35c29..0c99d8c 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -383,7 +383,7 @@ def __init__( self._value_func_has_rng = value_func_has_rng self._value_func: ValueFunc = convert_value_and_grad_to_value_func( value_and_grad_func, - has_aux=value_func_has_aux, + has_aux=value_func_has_aux or value_func_has_state, ) self._l2_reg = jnp.asarray(l2_reg) self._use_adaptive_learning_rate = use_adaptive_learning_rate