diff --git a/gEconpy/model/statespace.py b/gEconpy/model/statespace.py index d6b75ef..6f75ee6 100644 --- a/gEconpy/model/statespace.py +++ b/gEconpy/model/statespace.py @@ -351,7 +351,6 @@ def build_statespace_graph( self, data: np.ndarray | pd.DataFrame | pt.TensorVariable, register_data: bool = True, - mode: str | None = None, missing_fill_value: float | None = None, cov_jitter: float | None = JITTER_DEFAULT, save_kalman_filter_outputs_in_idata: bool = False, @@ -359,7 +358,7 @@ def build_statespace_graph( add_bk_check: bool = False, add_solver_success_check: bool = False, add_steady_state_penalty: bool = True, - tol: float = 1e-8, + resid_penalty: float = 1.0, ) -> None: super().build_statespace_graph( data=data, @@ -367,7 +366,7 @@ def build_statespace_graph( missing_fill_value=missing_fill_value, cov_jitter=cov_jitter, save_kalman_filter_outputs_in_idata=save_kalman_filter_outputs_in_idata, - mode=mode, + mode=self._mode, ) pymc_model = pm.modelcontext(None) @@ -415,10 +414,6 @@ def build_statespace_graph( shock_idx = pt.arange(n_shocks) state_var_mask = pt.bitwise_and(tm1_idx, t_idx) - # PP = T[pt.abs(T) < tol].set(0.0) - # PP = T.copy() - # QQ = QQ[pt.abs(QQ) < tol].set(0.0) - QQ = R[:n_vars, :] P = T[state_var_mask, :][:, state_var_mask] Q = QQ[state_var_mask, :][:, shock_idx] @@ -438,14 +433,7 @@ def build_statespace_graph( # Add penalty terms to the likelihood to rule out invalid solutions pm.Potential( "solution_norm_penalty", - -(norm_deterministic + norm_stochastic), - # pt.switch( - # pt.bitwise_and( - # pt.lt(norm_deterministic, tol), pt.lt(norm_stochastic, tol) - # ), - # 0.0, - # -np.inf, - # ), + -resid_penalty * (norm_deterministic + norm_stochastic), ) if add_bk_check: @@ -456,11 +444,11 @@ def build_statespace_graph( if add_solver_success_check: policy_resid = pm.Deterministic("policy_resid", policy_resid) - pm.Potential("policy_resid_penalty", -policy_resid) + pm.Potential("policy_resid_penalty", -resid_penalty * policy_resid) if add_steady_state_penalty: ss_resid = pm.Deterministic("ss_resid", ss_resid) - pm.Potential("steady_state_resid_penalty", -ss_resid) + pm.Potential("steady_state_resid_penalty", -resid_penalty * ss_resid) def priors_to_preliz(self): priors = self.priors[0] diff --git a/gEconpy/solvers/cycle_reduction.py b/gEconpy/solvers/cycle_reduction.py index ce2b36f..39cbf2c 100644 --- a/gEconpy/solvers/cycle_reduction.py +++ b/gEconpy/solvers/cycle_reduction.py @@ -11,6 +11,7 @@ from gEconpy.solvers.shared import ( o1_policy_function_adjoints, pt_compute_selection_matrix, + stabilize, ) @@ -193,7 +194,10 @@ def cycle_step(A0, A1, A2, A1_hat, step_num, idx_0, idx_1): tmp = pt.dot( pt.vertical_stack(A0, A2), pt.linalg.solve( - A1, pt.horizontal_stack(A0, A2), assume_a="gen", check_finite=False + stabilize(A1), + pt.horizontal_stack(A0, A2), + assume_a="gen", + check_finite=False, ), ) @@ -228,7 +232,7 @@ def step(A0, A1, A2, A1_hat, norm, step_num, idx_0, idx_1, tol): ) A1_hat = A1_hat[-1] - T = -pt.linalg.solve(A1_hat, A, assume_a="gen", check_finite=False) + T = -pt.linalg.solve(stabilize(A1_hat), A, assume_a="gen", check_finite=False) return [T, n_steps[-1]] @@ -238,7 +242,7 @@ def scan_cycle_reduction( B: pt.TensorLike, C: pt.TensorLike, D: pt.TensorLike, - max_iter: int = 1000, + max_iter: int = 100, tol: float = 1e-7, mode: str | None = None, use_adjoint_gradients: bool = True, diff --git a/tests/test_perturbation.py b/tests/test_perturbation.py index faf125d..1820cd5 100644 --- a/tests/test_perturbation.py +++ b/tests/test_perturbation.py @@ -184,7 +184,7 @@ def test_cycle_reduction_gradients(op): for name, x in zip(list("ABCD"), [A, B, C, D]) ) - T, R = op(A_pt, B_pt, C_pt, D_pt) + T, R, *_ = op(A_pt, B_pt, C_pt, D_pt) T_grad = pt.grad(T.sum(), [A_pt, B_pt, C_pt]) f = pytensor.function(