Skip to content

Commit

Permalink
Allow users to mark equations to be excluded from the final system
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Nov 2, 2024
1 parent 7acb733 commit 71fe39d
Show file tree
Hide file tree
Showing 5 changed files with 757 additions and 556 deletions.
89 changes: 83 additions & 6 deletions gEconpy/model/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,37 @@ def _validate_initialization(self) -> bool:
if not control_found:
raise ControlVariableNotFoundException(self.name, control)

# Validate equation flags
# - the "is_calibrating" key can only occur in the calibration block
# - the "exclude" key can only occur in the constraints block
valid_flags = {
"is_calibrating": ["calibration"],
"exclude": ["constraints"],
}

for name, eq_block in zip(
["definitions", "objective", "constraints", "identities"],
[self.definitions, self.objective, self.constraints, self.identities],
):
if eq_block is not None:
for key, eq in eq_block.items():
if (
self.equation_flags[key].get("is_calibrating", False)
and name not in valid_flags["is_calibrating"]
):
raise ValueError(
f"Equation {eq} in {name} block of {self.name} has an invalid decorator: is_calibrating. "
f"This flag should only appear in the calibration block."
)
if (
self.equation_flags[key].get("exclude", False)
and name not in valid_flags["exclude"]
):
raise ValueError(
f"Equation {eq} in {name} block of {self.name} has an invalid decorator: exclude. "
f"This flag should only appear in the constraints block."
)

return True

def _validate_key(self, block_dict: dict, key: str) -> bool:
Expand Down Expand Up @@ -269,6 +300,46 @@ def _extract_lagrange_multipliers(

return result, multipliers

def _extract_decorators(
self, equations: list[list[str]], assumptions: dict
) -> tuple[list[list[str]], list[dict[str, bool]]]:
"""
Extract decorators from the equations in the block. Decorators are flags that indicate special properties of the
equation, such as whether it should be excluded from the final system of equations.
Parameters
----------
equations : list
A list of lists of strings, each list representing a model equation. Created by the
gEcon_parser.parsed_block_to_dict function.
assumptions : dict
Assumptions for the model.
Returns
-------
equations: list
List of lists of strings. All decorator strings have been removed.
flags: dict
A dictionary of flags for each equation, indexed by equation number.
"""

result, decorator_flags = [], []
for i, eq in enumerate(equations):
new_eq = []
flags = {}
for token in eq:
if token.startswith("@"):
decorator = token.removeprefix("@")
flags[decorator] = True
else:
new_eq.append(token)
result.append(new_eq)
decorator_flags.append(flags)

return result, decorator_flags

def _parse_variable_list(
self, block_dict: dict, key: str, assumptions: dict | None = None
) -> list[sp.Symbol] | None:
Expand Down Expand Up @@ -401,8 +472,10 @@ def _parse_equation_list(
equations, lagrange_multipliers = self._extract_lagrange_multipliers(
equations, assumptions
)
equations, decorators = self._extract_decorators(equations, assumptions)

parser_output = parse_equations.build_sympy_equations(equations, assumptions)

if len(parser_output) > 0:
equations, flags = list(zip(*parser_output))
else:
Expand All @@ -412,10 +485,13 @@ def _parse_equation_list(

equations = dict(zip(equation_numbers, equations))
flags = dict(zip(equation_numbers, flags))
decorator_flags = dict(zip(equation_numbers, decorators))

lagrange_multipliers = dict(zip(equation_numbers, lagrange_multipliers))
self.multipliers.update(lagrange_multipliers)
self.equation_flags.update(flags)
for k in equation_numbers:
self.equation_flags[k] = flags[k]
self.equation_flags[k].update(decorator_flags[k])

return equations

Expand Down Expand Up @@ -677,11 +753,12 @@ def solve_optimization(self, try_simplify: bool = True) -> None:
)

if self.constraints is not None:
_, constraints = unpack_keys_and_values(self.constraints)
for eq in constraints:
self.system_equations.append(
set_equality_equals_zero(eq.subs(sub_dict))
)
eq_idx, constraints = unpack_keys_and_values(self.constraints)
for idx, eq in zip(eq_idx, constraints):
if not self.equation_flags[idx].get("exclude", False):
self.system_equations.append(
set_equality_equals_zero(eq.subs(sub_dict))
)

if self.controls is None and self.objective is None:
return
Expand Down
75 changes: 75 additions & 0 deletions tests/Test GCNs/rbc_with_excluded.gcn
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
block HOUSEHOLD
{
definitions
{
u[] = C[] ^ (1 - sigma_C) / (1 - sigma_C) -
L[] ^ (1 + sigma_L) / (1 + sigma_L);
};
controls
{
K[], C[], L[], I[];
};
objective
{
U[] = u[] + beta * E[][U[1]];
};
constraints
{
@exclude
C[] + I[] = w[] * L[] + r[] * K[-1] : lambda[];

K[] = (1 - delta) * K[-1] + I[] : q[];
};

calibration
{
beta = 0.985;
delta = 0.025;
sigma_C = 2;
sigma_L = 1.5;
};
};


block FIRM
{
controls
{
K[-1], L[];
};

objective
{
TC[] = -(w[] * L[] + r[] * K[-1]);
};

constraints
{
Y[] = A[] * K[-1] ^ alpha * L[] ^ (1 - alpha) : P[];
};

identities
{
log(A[]) = rho_A * log(A[-1]) + epsilon_A[];
P[] = 1;
};

shocks
{
epsilon_A[];
};

calibration
{
alpha = 0.35;
rho_A = 0.95;
};
};

block EQUILIBRIUM
{
constraints
{
Y[] = C[] + I[];
};
};
54 changes: 54 additions & 0 deletions tests/test_block.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import re
import unittest

from pathlib import Path

import numpy as np
import pytest
import sympy as sp

from gEconpy.classes.time_aware_symbol import TimeAwareSymbol
Expand All @@ -15,6 +17,7 @@
)
from gEconpy.model.block import Block
from gEconpy.parser import constants, file_loaders, gEcon_parser
from gEconpy.parser.file_loaders import block_dict_to_equation_list
from gEconpy.utilities import set_equality_equals_zero, unpack_keys_and_values

ROOT = Path(__file__).parent.absolute()
Expand Down Expand Up @@ -258,6 +261,38 @@ def test_lagrange_multiplier_in_objective(self):
block.solve_optimization()


def test_invalid_decorator_raises():
test_file = """
block HOUSEHOLD
{
objective
{
@exclude
U[] = u[] + beta * E[][U[1]] : lambda[];
};
controls
{
u[];
};
};
"""

parser_output, prior_dict = gEcon_parser.preprocess_gcn(test_file)
block_dict, options, tryreduce, assumptions = (
gEcon_parser.split_gcn_into_dictionaries(parser_output)
)
block_dict = gEcon_parser.parsed_block_to_dict(block_dict["HOUSEHOLD"])
with pytest.raises(
ValueError,
match=re.escape(
"Equation Eq(U_t, beta*U_t+1 + u_t) in objective block of HOUSEHOLD "
"has an invalid decorator: exclude."
),
):
Block("HOUSEHOLD", block_dict)


class BlockTestCases(unittest.TestCase):
def setUp(self):
test_file = file_loaders.load_gcn(
Expand Down Expand Up @@ -588,5 +623,24 @@ def test_variable_list(self):
self.assertEqual({x.base_name for x in self.block.shocks}, {"epsilon"})


def test_block_with_exlcuded_equation():
test_file = file_loaders.load_gcn(
os.path.join(ROOT, "Test GCNs/rbc_with_excluded.gcn")
)

parser_output, prior_dict = gEcon_parser.preprocess_gcn(test_file)
block_dict, options, tryreduce, assumptions = (
gEcon_parser.split_gcn_into_dictionaries(parser_output)
)

block_dict = gEcon_parser.parsed_block_to_dict(block_dict["HOUSEHOLD"])

block = Block("HOUSEHOLD", block_dict)
block.solve_optimization()

# 6 equations are 4 controls, 1 objective, 1 constraint (excluding the excluded equation)
assert len(block.system_equations) == 6


if __name__ == "__main__":
unittest.main()
6 changes: 1 addition & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,11 +583,7 @@ def test_model_gradient(backend, gcn_file):
@pytest.mark.parametrize("how", ["root", "minimize"], ids=["root", "minimize"])
@pytest.mark.parametrize(
"gcn_file",
[
"one_block_1_ss.gcn",
"open_rbc.gcn",
"full_nk.gcn",
],
["one_block_1_ss.gcn", "open_rbc.gcn", "full_nk.gcn", "rbc_with_excluded.gcn"],
)
@pytest.mark.parametrize(
"backend", ["numpy", "numba", "pytensor"], ids=["numpy", "numba", "pytensor"]
Expand Down
Loading

0 comments on commit 71fe39d

Please sign in to comment.