diff --git a/benchmarks/bench_mahler_yum.py b/benchmarks/bench_mahler_yum.py index e31c2e2a..fa8e3f6b 100644 --- a/benchmarks/bench_mahler_yum.py +++ b/benchmarks/bench_mahler_yum.py @@ -20,22 +20,13 @@ def _build(self): create_inputs, ) - start_params_without_beta = { - k: v for k, v in START_PARAMS.items() if k != "beta" - } - self.model = MAHLER_YUM_MODEL - common_params, initial_states, _discount_factor_type = create_inputs( + common_params, initial_states = create_inputs( seed=0, n_simulation_subjects=_N_SUBJECTS, - **start_params_without_beta, + **START_PARAMS, ) - self.model_params = { - "alive": { - "discount_factor": START_PARAMS["beta"]["mean"], - **common_params, - }, - } + self.model_params = {"alive": common_params} self.initial_conditions = { **initial_states, "regime": jnp.full( diff --git a/docs/examples/mahler_yum_2024.md b/docs/examples/mahler_yum_2024.md index 3d6ae63d..5d816efe 100644 --- a/docs/examples/mahler_yum_2024.md +++ b/docs/examples/mahler_yum_2024.md @@ -27,35 +27,26 @@ from lcm_examples.mahler_yum_2024 import ( create_inputs, ) -# Build inputs (params, initial states, discount factor types) -start_params_without_beta = {k: v for k, v in START_PARAMS.items() if k != "beta"} -common_params, initial_states, discount_factor_types = create_inputs( +n_subjects = 1_000 + +# Build inputs: per-subject initial states include `discount_type` +# (small/large), and `params["discount_factor"]["discount_factor_by_type"]` +# carries the two-element beta array that the `discount_factor` DAG +# function indexes with the state. +common_params, initial_states = create_inputs( seed=7235, - n_simulation_subjects=1_000, - **start_params_without_beta, + n_simulation_subjects=n_subjects, + **START_PARAMS, ) -beta_mean = START_PARAMS["beta"]["mean"] -beta_std = START_PARAMS["beta"]["std"] - -# Select initial states with high discount factor type -selected_ids_high = jnp.flatnonzero(discount_factor_types) -initial_states_high = { - state: values[selected_ids_high] for state, values in initial_states.items() -} - -# Solve and simulate for high discount factor type +# One solve, one simulate — both discount types are handled inside the +# regime via the `discount_type` state. result = MAHLER_YUM_MODEL.simulate( - params={ - "alive": { - "discount_factor": beta_mean + beta_std, - **common_params, - }, - }, + params={"alive": common_params}, initial_conditions={ - **initial_states_high, + **initial_states, "regime": jnp.full( - selected_ids_high.shape[0], + n_subjects, MAHLER_YUM_MODEL.regime_names_to_ids["alive"], ), }, diff --git a/src/lcm/model_processing.py b/src/lcm/model_processing.py index e745f9b1..b544a679 100644 --- a/src/lcm/model_processing.py +++ b/src/lcm/model_processing.py @@ -24,6 +24,10 @@ ) from lcm.params.sequence_leaf import SequenceLeaf from lcm.regime import Regime +from lcm.regime_building.h_dag import ( + get_h_accepted_params, + get_h_dag_target_names, +) from lcm.regime_building.processing import ( InternalRegime, process_regimes, @@ -205,6 +209,7 @@ def _validate_all_variables_used(regimes: Mapping[str, Regime]) -> list[str]: Each state or action must appear in at least one of: - The concurrent valuation (utility or constraints) - A transition function + - An H-DAG target function (a regime function whose output H consumes) Args: regimes: Mapping of regime names to regimes to validate. @@ -219,6 +224,11 @@ def _validate_all_variables_used(regimes: Mapping[str, Regime]) -> list[str]: variable_names = set(regime.states) | set(regime.actions) user_functions = dict(regime.get_all_functions(phase="solve")) + h_accepted_params = get_h_accepted_params(user_functions) + h_dag_target_names = get_h_dag_target_names( + functions=user_functions, h_accepted_params=h_accepted_params + ) + targets = [ "utility", *list(regime.constraints), @@ -228,6 +238,7 @@ def _validate_all_variables_used(regimes: Mapping[str, Regime]) -> list[str]: if name.startswith("next_") and not getattr(user_functions[name], "_is_auto_identity", False) ), + *h_dag_target_names, ] reachable = get_ancestors( user_functions, targets=targets, include_targets=False diff --git a/src/lcm/regime_building/Q_and_F.py b/src/lcm/regime_building/Q_and_F.py index a9bd594a..492d8a52 100644 --- a/src/lcm/regime_building/Q_and_F.py +++ b/src/lcm/regime_building/Q_and_F.py @@ -6,6 +6,7 @@ from dags import concatenate_functions, with_signature from jax import Array +from lcm.regime_building.h_dag import get_h_dag_target_names from lcm.regime_building.next_state import ( get_next_state_function_for_solution, get_next_stochastic_weights_function, @@ -135,6 +136,13 @@ def get_Q_and_F( exclude=frozenset(), ) + # Resolve H arguments that are regime-function outputs (e.g. a + # `discount_factor` DAG function that indexes a per-type Series by a + # state). `None` when H only needs state/action/user-param values. + _h_dag_func = _get_h_dag_func( + functions=functions, h_accepted_params=_H_accepted_params + ) + @with_signature( args=arg_names_of_Q_and_F, return_annotation="tuple[FloatND, BoolND]" ) @@ -203,6 +211,8 @@ def Q_and_F( H_kwargs = { k: v for k, v in states_actions_params.items() if k in _H_accepted_params } + if _h_dag_func is not None: + H_kwargs |= _h_dag_func(**states_actions_params) Q_arr = _H_func(utility=U_arr, E_next_V=E_next_V, **H_kwargs) # Handle cases when there is only one state. @@ -301,6 +311,9 @@ def get_compute_intermediates( _H_accepted_params = frozenset( get_union_of_args([_H_func]) - {"utility", "E_next_V"} ) + _h_dag_func = _get_h_dag_func( + functions=functions, h_accepted_params=_H_accepted_params + ) arg_names_of_compute_intermediates = _get_arg_names_of_Q_and_F( [ @@ -357,6 +370,8 @@ def compute_intermediates( H_kwargs = { k: v for k, v in states_actions_params.items() if k in _H_accepted_params } + if _h_dag_func is not None: + H_kwargs |= _h_dag_func(**states_actions_params) Q_arr = _H_func(utility=U_arr, E_next_V=E_next_V, **H_kwargs) return U_arr, F_arr, E_next_V, Q_arr, active_regime_probs @@ -536,6 +551,54 @@ def _outer(**kwargs: Float1D) -> FloatND: ) +def _get_h_dag_func( + *, + functions: FunctionsMapping, + h_accepted_params: frozenset[str], +) -> Callable[..., dict[str, Any]] | None: + """Compile a DAG that resolves H arguments computed by regime functions. + + `H` may name any argument supported by regime functions: states, + actions, flat params, or outputs of other user-provided functions. + Names in H's signature are resolved at runtime from, in order: + + 1. `states_actions_params` (states, actions, and flat params — the + same scalar pool every regime function draws from), and + 2. DAG-output functions, compiled here. + + This helper handles only (2): for every name in H's signature that + is also a user-provided function, compile a DAG target so its + output can be merged into `H_kwargs` alongside the values supplied + by (1). If no such names exist, return `None`. + + Args: + functions: Regime functions (user and generated). + h_accepted_params: Names H accepts beyond `utility` / `E_next_V`. + + Returns: + A callable mapping `states_actions_params` kwargs to a dict of + the resolved DAG outputs, or `None` if H needs no DAG outputs. + + """ + dag_targets = tuple( + sorted( + get_h_dag_target_names( + functions=functions, h_accepted_params=h_accepted_params + ) + ) + ) + + if not dag_targets: + return None + + return concatenate_functions( + functions={k: v for k, v in functions.items() if k != "H"}, + targets=list(dag_targets), + return_type="dict", + enforce_signature=False, + ) + + def _get_U_and_F( *, functions: FunctionsMapping, diff --git a/src/lcm/regime_building/h_dag.py b/src/lcm/regime_building/h_dag.py new file mode 100644 index 00000000..2a0025d3 --- /dev/null +++ b/src/lcm/regime_building/h_dag.py @@ -0,0 +1,72 @@ +"""H's DAG-target bookkeeping, shared between runtime and validation. + +The default Bellman aggregator `H(utility, E_next_V, discount_factor)` — +and any user-supplied H — may declare parameters that are not +states/actions/user-params but are outputs of regime functions +registered under the same name (e.g. a `discount_factor` DAG function +that indexes a per-type Series by a `discount_type` state). + +This module exposes: + +- `get_h_accepted_params`: H's signature minus `utility` / `E_next_V`. +- `get_h_dag_target_names`: those H parameters that are *also* regime + functions. Q_and_F compiles these into a runtime DAG; + `_validate_all_variables_used` uses them as reachability targets so + states consumed only via H's DAG dependencies count as "used". +""" + +from collections.abc import Callable, Mapping +from typing import Any + +from lcm.utils.functools import get_union_of_args + + +def get_h_accepted_params( + functions: Mapping[str, Callable[..., Any]], +) -> frozenset[str]: + """H's signature parameters, minus `utility` and `E_next_V`. + + Empty when the regime has no `H` (terminal regimes). + + Args: + functions: Mapping of regime function names to callables (user + and generated). + + Returns: + Frozenset of parameter names H accepts beyond `utility` / `E_next_V`. + + """ + h_func = functions.get("H") + if h_func is None: + return frozenset() + return frozenset(get_union_of_args([h_func]) - {"utility", "E_next_V"}) + + +def get_h_dag_target_names( + *, + functions: Mapping[str, Callable[..., Any]], + h_accepted_params: frozenset[str], +) -> frozenset[str]: + """Names of regime functions whose outputs H consumes via the DAG. + + These are H's signature parameters that are also regime functions, + minus `H`, `utility`, `feasibility` (H cannot consume its own + output; `utility` is wired directly from `U_and_F`; `feasibility` + is never a legitimate H input). + + Args: + functions: Mapping of regime function names to callables (user + and generated). + h_accepted_params: Names H accepts beyond `utility` / `E_next_V` + (typically the output of `get_h_accepted_params`). + + Returns: + Frozenset of regime function names whose outputs are routed + into H at runtime. + + """ + return frozenset(h_accepted_params) & set(functions) - { + "H", + "utility", + "feasibility", + } diff --git a/src/lcm_examples/mahler_yum_2024/__init__.py b/src/lcm_examples/mahler_yum_2024/__init__.py index 34397498..ca2eea6d 100644 --- a/src/lcm_examples/mahler_yum_2024/__init__.py +++ b/src/lcm_examples/mahler_yum_2024/__init__.py @@ -9,6 +9,7 @@ DEAD_REGIME, MAHLER_YUM_MODEL, START_PARAMS, + DiscountType, Education, Effort, Health, @@ -26,6 +27,7 @@ "DEAD_REGIME", "MAHLER_YUM_MODEL", "START_PARAMS", + "DiscountType", "Education", "Effort", "Health", diff --git a/src/lcm_examples/mahler_yum_2024/_model.py b/src/lcm_examples/mahler_yum_2024/_model.py index 481cf90a..f3698ae4 100644 --- a/src/lcm_examples/mahler_yum_2024/_model.py +++ b/src/lcm_examples/mahler_yum_2024/_model.py @@ -35,7 +35,6 @@ FloatND, Int1D, Period, - RegimeName, ) from lcm.utils.dispatchers import productmap @@ -111,12 +110,30 @@ class ProductivityShock: val4: int +@categorical(ordered=True) +class DiscountType: + small: int + large: int + + @categorical(ordered=False) class RegimeId: alive: int dead: int +def discount_factor( + discount_type: DiscreteState, + discount_factor_by_type: FloatND, +) -> FloatND: + """Per-period discount factor indexed by `discount_type`. + + Wired as a DAG function on `ALIVE_REGIME.functions`; pylcm's default + Bellman aggregator picks the scalar up as a DAG-output H input. + """ + return discount_factor_by_type[discount_type] + + def utility( scaled_adjustment_cost: FloatND, fcost: FloatND, @@ -309,6 +326,7 @@ def dead_is_active(age: int, initial_age: float) -> bool: "education": DiscreteGrid(Education), "productivity": DiscreteGrid(ProductivityType), "health_type": DiscreteGrid(HealthType), + "discount_type": DiscreteGrid(DiscountType, batch_size=1), }, state_transitions={ "wealth": next_wealth, @@ -317,6 +335,7 @@ def dead_is_active(age: int, initial_age: float) -> bool: "education": None, "productivity": None, "health_type": None, + "discount_type": None, }, actions={ "labor_supply": DiscreteGrid(LaborSupply), @@ -336,6 +355,11 @@ def dead_is_active(age: int, initial_age: float) -> bool: "taxed_income": taxed_income, "pension": pension, "scaled_productivity_shock": scaled_productivity_shock, + # Heterogeneous β: the scalar is produced by indexing + # `discount_factor_by_type` by the `discount_type` state, and + # pylcm's default Bellman aggregator picks it up as a DAG-output + # argument (see `_default_H`). + "discount_factor": discount_factor, }, constraints={ "retirement_constraint": retirement_constraint, @@ -343,10 +367,20 @@ def dead_is_active(age: int, initial_age: float) -> bool: }, ) + +def dead_utility(discount_type: DiscreteState) -> FloatND: # noqa: ARG001 + """Dead-regime utility: always zero. `discount_type` is in the + signature so pylcm's usage check accepts the state declaration.""" + return jnp.asarray(0.0) + + DEAD_REGIME = Regime( transition=None, active=partial(dead_is_active, initial_age=ages.values[0]), - functions={"utility": lambda: 0.0}, + states={ + "discount_type": DiscreteGrid(DiscountType, batch_size=1), + }, + functions={"utility": dead_utility}, ) MAHLER_YUM_MODEL = Model( @@ -582,12 +616,13 @@ def create_inputs( xi: dict[str, dict[str, list[float]]], income_process: dict[str, dict[str, float] | float], chi: list[float], + beta: dict[str, float], psi: float, bb: float, conp: float, penre: float, sigma: int, -) -> tuple[dict[RegimeName, Any], dict[RegimeName, Any], Int1D]: +) -> tuple[dict[str, Any], dict[str, Any]]: # Create variable grids from supplied parameters income_grid = create_income_grid(income_process) # ty: ignore[invalid-argument-type] chimax_grid = create_chimaxgrid(chi) @@ -598,6 +633,10 @@ def create_inputs( regime_transition = create_regime_transition_grid() + discount_factor_by_type = jnp.array( + [beta["mean"] - beta["std"], beta["mean"] + beta["std"]] + ) + params = { "disutil": {"phigrid": phi_grid}, "fcost": {"psi": psi, "xigrid": xi_grid}, @@ -608,6 +647,7 @@ def create_inputs( "scaled_productivity_shock": {"sigx": jnp.sqrt(income_process["sigx"])}, # ty: ignore[invalid-argument-type] "next_health": {"probs_array": tr2yp_grid}, "next_regime": {"probs_array": regime_transition}, + "discount_factor": {"discount_factor_by_type": discount_factor_by_type}, } # Create initial states for the simulation @@ -646,7 +686,6 @@ def create_inputs( initial_productivity = prod[types] initial_effort = jnp.searchsorted(eff_grid, init_distr_2b2t2h[:, 2][types]) initial_adjustment_cost = random.uniform(new_keys[1], (n_simulation_subjects,)) - discount_factor_type = discount[types] prod_dist = jax.lax.fori_loop( 0, 200, @@ -666,5 +705,6 @@ def create_inputs( "adjustment_cost": initial_adjustment_cost, "education": initial_education, "productivity": initial_productivity, + "discount_type": discount[types], } - return params, initial_states, discount_factor_type + return params, initial_states diff --git a/tests/data/regression_tests/f64/mahler_yum_simulation.pkl b/tests/data/regression_tests/f64/mahler_yum_simulation.pkl index b4109af2..c72a5baa 100644 Binary files a/tests/data/regression_tests/f64/mahler_yum_simulation.pkl and b/tests/data/regression_tests/f64/mahler_yum_simulation.pkl differ diff --git a/tests/data/regression_tests/generate_benchmark_data.py b/tests/data/regression_tests/generate_benchmark_data.py index 546a5ca3..748ddd51 100644 --- a/tests/data/regression_tests/generate_benchmark_data.py +++ b/tests/data/regression_tests/generate_benchmark_data.py @@ -93,19 +93,13 @@ def _generate_mortality(data_dir: Path) -> None: def _generate_mahler_yum(data_dir: Path) -> None: n_subjects = 4 - start_params_without_beta = {k: v for k, v in START_PARAMS.items() if k != "beta"} - common_params, initial_states, _discount_factor_type = create_inputs( + common_params, initial_states = create_inputs( seed=0, n_simulation_subjects=n_subjects, - **start_params_without_beta, # ty: ignore[invalid-argument-type] + **START_PARAMS, # ty: ignore[invalid-argument-type] ) model = MAHLER_YUM_MODEL - params = { - "alive": { - "discount_factor": START_PARAMS["beta"]["mean"], # ty: ignore[invalid-argument-type, not-subscriptable] - **common_params, - }, - } + params = {"alive": common_params} initial_conditions = { **initial_states, "regime": jnp.full( @@ -116,7 +110,7 @@ def _generate_mahler_yum(data_dir: Path) -> None: } result = model.simulate( - params=params, # ty: ignore[invalid-argument-type] + params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, seed=12345, diff --git a/tests/solution/test_custom_aggregator.py b/tests/solution/test_custom_aggregator.py index f595956f..6dcd227c 100644 --- a/tests/solution/test_custom_aggregator.py +++ b/tests/solution/test_custom_aggregator.py @@ -1,5 +1,7 @@ """Test that a custom aggregation function H can be used in a model.""" +from collections.abc import Callable + import jax.numpy as jnp from numpy.testing import assert_array_equal @@ -9,6 +11,7 @@ ContinuousAction, ContinuousState, DiscreteAction, + DiscreteState, FloatND, ScalarInt, ) @@ -77,36 +80,89 @@ def ces_H( FINAL_AGE_ALIVE = START_AGE + N_PERIODS - 2 # = 2 -def _make_model(custom_H=None): - """Create a simple model, optionally with a custom H.""" - functions = { +@categorical(ordered=False) +class PrefType: + type_0: int + type_1: int + type_2: int + + +def discount_factor_from_type( + pref_type: DiscreteState, + discount_factor_by_type: FloatND, +) -> FloatND: + """Index a per-type discount factor Series by the pref_type state. + + Wiring this as `functions["discount_factor"]` exercises pylcm's + ability to resolve an H argument from a DAG function output when + the name is not in `states_actions_params`. + """ + return discount_factor_by_type[pref_type] + + +def _make_model(custom_H=None, *, with_pref_type: bool = False): + """Create a simple model, optionally with a custom H and pref_type state. + + When `with_pref_type=True`, the working-life regime gains a + `pref_type` discrete state (`batch_size=1`, three categories) and + wires `discount_factor` as a DAG function that indexes + `discount_factor_by_type` by the state. This exercises the + "DAG output feeds H" path in pylcm's Q_and_F — and relies on + `_validate_all_variables_used` treating H-DAG targets as reachable + so `pref_type` counts as used without any workaround in `utility`. + """ + functions: dict[str, Callable] = { "utility": utility, "labor_income": labor_income, "is_working": is_working, } if custom_H is not None: functions["H"] = custom_H + if with_pref_type: + functions["discount_factor"] = discount_factor_from_type + + working_life_states: dict = { + "wealth": LinSpacedGrid(start=0.5, stop=10, n_points=30), + } + working_life_state_transitions: dict = { + "wealth": next_wealth, + } + dead_states: dict = {} + if with_pref_type: + working_life_states["pref_type"] = DiscreteGrid(PrefType, batch_size=1) + working_life_state_transitions["pref_type"] = None + dead_states["pref_type"] = DiscreteGrid(PrefType, batch_size=1) working_life_regime = Regime( actions={ "labor_supply": DiscreteGrid(LaborSupply), "consumption": LinSpacedGrid(start=0.5, stop=10, n_points=50), }, - states={ - "wealth": LinSpacedGrid(start=0.5, stop=10, n_points=30), - }, - state_transitions={ - "wealth": next_wealth, - }, + states=working_life_states, + state_transitions=working_life_state_transitions, constraints={"borrowing_constraint": borrowing_constraint}, transition=next_regime, functions=functions, active=lambda age: age <= FINAL_AGE_ALIVE, ) + # Terminal regime: when pref_type is declared as a state across + # regimes, dead_utility must reference it so pylcm's state-usage + # check accepts the declaration (terminal regimes have no H, so + # the H-DAG reachability fix does not apply here). + if with_pref_type: + + def dead_utility(pref_type: DiscreteState) -> FloatND: # noqa: ARG001 + return jnp.asarray(0.0) + else: + + def dead_utility() -> float: + return 0.0 + dead_regime = Regime( transition=None, - functions={"utility": lambda: 0.0}, + functions={"utility": dead_utility}, + states=dead_states, active=lambda age: age > FINAL_AGE_ALIVE, ) @@ -237,3 +293,286 @@ def test_terminal_regime_value_unchanged_by_H(): V_default[last_period]["dead"], V_ces[last_period]["dead"], ) + + +# --------------------------------------------------------------------------- +# DAG-output feeds H: `discount_factor` computed by a DAG function that +# indexes a per-type Series by the `pref_type` state. +# --------------------------------------------------------------------------- + + +def test_model_constructs_when_state_reachable_only_via_h_dag(): + """State reached only via H's DAG deps must pass the usage check. + + `pref_type` is used by `discount_factor_from_type`, whose output + feeds the default H. `utility` / `feasibility` / transitions do + not reference `pref_type`. Pre-fix, this failed with + "states defined but never used"; post-fix, the state-usage walk + treats H-DAG targets as reachable. + """ + _make_model(with_pref_type=True) + + +def test_dag_output_feeds_default_h_monotone_in_discount_factor(): + """Higher per-type discount factor ⇒ higher value function. + + The working-life regime uses the default H (which expects a scalar + `discount_factor`). That scalar is produced by a DAG function that + indexes `discount_factor_by_type` by the `pref_type` state. This + only works if pylcm's Q_and_F resolves H arguments from DAG + function outputs when they are not in `states_actions_params`. + """ + model = _make_model(with_pref_type=True) + + params = { + "discount_factor_by_type": jnp.array([0.70, 0.85, 0.99]), + "working_life": { + "utility": {"disutility_of_work": 0.5}, + "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, + }, + } + V = model.solve(params=params) + + # Pick a non-terminal period; slice each pref_type. + non_terminal_periods = [p for p in V if p < max(V.keys())] + assert non_terminal_periods + for period in non_terminal_periods: + # Shape is (..., n_pref_type). Compare averages across the + # other axes so the comparison is robust to the grid layout. + v = V[period]["working_life"] + # pref_type is the innermost batched state ⇒ last axis. + v_type_0 = jnp.mean(v[..., 0]) + v_type_1 = jnp.mean(v[..., 1]) + v_type_2 = jnp.mean(v[..., 2]) + assert v_type_0 < v_type_1 < v_type_2, ( + f"Expected V monotone in discount factor at period {period}; " + f"got {v_type_0:.4f} < {v_type_1:.4f} < {v_type_2:.4f}" + ) + + +# H's permissive kwarg contract: H may name any argument supported by +# regime functions — states, actions, flat params, or DAG-output +# functions. The following tests lock that contract in. + + +def wealth_H( + utility: float, + E_next_V: float, + discount_factor: float, + wealth: float, + wealth_weight: float, +) -> float: + return utility + discount_factor * E_next_V + wealth_weight * wealth + + +def test_h_consumes_continuous_state(): + """Solve when H names a continuous state; exact lift at the last period. + + Regression guard against a refactor that narrows `_H_accepted_params` + to reject state names. At the last period where `working_life` is + active, `E_next_V = 0` (dead utility is zero), so adding + `wealth_weight * wealth` to `Q` shifts `V` by exactly that term — + independent of the argmax. + """ + model = _make_model(custom_H=wealth_H) + common = { + "utility": {"disutility_of_work": 0.5}, + "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, + } + V_zero = model.solve( + params={ + "working_life": { + "H": {"discount_factor": 0.95, "wealth_weight": 0.0}, + **common, + }, + "dead": {}, + }, + ) + V_pos = model.solve( + params={ + "working_life": { + "H": {"discount_factor": 0.95, "wealth_weight": 0.1}, + **common, + }, + "dead": {}, + }, + ) + lift_at_terminal = ( + V_pos[FINAL_AGE_ALIVE]["working_life"] - V_zero[FINAL_AGE_ALIVE]["working_life"] + ) + expected = 0.1 * jnp.linspace(0.5, 10.0, 30) + assert bool(jnp.allclose(lift_at_terminal, expected, atol=1e-5)) + + +def consumption_H( + utility: float, + E_next_V: float, + discount_factor: float, + consumption: float, + action_weight: float, +) -> float: + return utility + discount_factor * E_next_V + action_weight * consumption + + +def test_h_consumes_continuous_action(): + """H may name a continuous action; non-zero weight shifts V. + + Regression guard: when `H` names `consumption`, the scalar at the + current action-gridpoint is bound at Q evaluation (before argmax). + A positive `action_weight` therefore shifts `V` relative to the + `action_weight=0` baseline. + """ + model = _make_model(custom_H=consumption_H) + common = { + "utility": {"disutility_of_work": 0.5}, + "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, + } + V_zero = model.solve( + params={ + "working_life": { + "H": {"discount_factor": 0.95, "action_weight": 0.0}, + **common, + }, + "dead": {}, + }, + ) + V_pos = model.solve( + params={ + "working_life": { + "H": {"discount_factor": 0.95, "action_weight": 0.1}, + **common, + }, + "dead": {}, + }, + ) + non_terminal = [p for p in V_zero if p <= FINAL_AGE_ALIVE] + assert non_terminal + diffs_exist = any( + not jnp.allclose(V_zero[p]["working_life"], V_pos[p]["working_life"]) + for p in non_terminal + ) + assert diffs_exist, "action_weight>0 must shift V at some working-life period" + + +def labor_supply_H( + utility: float, + E_next_V: float, + discount_factor: float, + labor_supply: DiscreteAction, + bonus: float, +) -> FloatND: + return ( + utility + discount_factor * E_next_V + bonus * labor_supply.astype(jnp.float32) + ) + + +def test_h_consumes_discrete_action(): + """H may name a discrete action; solve compiles and V shapes match baseline. + + Regression guard: discrete action scalars reach `H` via + `states_actions_params` the same way continuous ones do. + """ + model = _make_model(custom_H=labor_supply_H) + V = model.solve( + params={ + "working_life": { + "H": {"discount_factor": 0.95, "bonus": 0.1}, + "utility": {"disutility_of_work": 0.5}, + "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, + }, + "dead": {}, + }, + ) + baseline = _make_model().solve( + params={ + "discount_factor": 0.95, + "working_life": { + "utility": {"disutility_of_work": 0.5}, + "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, + }, + }, + ) + for period in V: + for regime in V[period]: + assert V[period][regime].shape == baseline[period][regime].shape + + +def pref_type_direct_H( + utility: float, + E_next_V: float, + discount_factor: float, + pref_type: DiscreteState, +) -> FloatND: + return utility + discount_factor * E_next_V + 0.1 * pref_type.astype(jnp.float32) + + +def test_h_consumes_discrete_state(): + """H may name a discrete state directly, without a DAG function of that name. + + Regression guard: `pref_type` reaches `H` as a scalar per + state-action gridpoint — the same path utility uses. + `discount_factor` here is still a DAG output + (`discount_factor_from_type`), proving state-direct and + DAG-output paths can coexist in one `H`. + """ + model = _make_model(custom_H=pref_type_direct_H, with_pref_type=True) + V = model.solve( + params={ + "discount_factor_by_type": jnp.array([0.70, 0.85, 0.99]), + "working_life": { + "utility": {"disutility_of_work": 0.5}, + "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, + }, + }, + ) + non_terminal = [p for p in V if p <= FINAL_AGE_ALIVE] + assert non_terminal + for period in non_terminal: + v = V[period]["working_life"] + assert 3 in v.shape, f"Period {period}: pref_type axis missing ({v.shape})" + assert bool(jnp.all(jnp.isfinite(v))) + + +def mixed_H( + utility: float, + E_next_V: float, + discount_factor: float, + ies: float, + wealth: float, + consumption: float, + pref_type: DiscreteState, +) -> FloatND: + rho = 1 - ies + u_eff = utility + 1e-3 * wealth + v_eff = E_next_V + 1e-3 * consumption + combined = ((1 - discount_factor) * u_eff**rho + discount_factor * v_eff**rho) ** ( + 1 / rho + ) + return combined + 1e-3 * pref_type.astype(jnp.float32) + + +def test_h_consumes_flat_param_state_action_and_dag_output(): + """H may simultaneously name a flat param, a state, an action, and a DAG output. + + Regression guard: every kwarg-resolution path fires at once — + `states_actions_params` supplies wealth/consumption/pref_type, + flat params supply `ies`, the DAG supplies `discount_factor`. + """ + model = _make_model(custom_H=mixed_H, with_pref_type=True) + V = model.solve( + params={ + "discount_factor_by_type": jnp.array([0.70, 0.85, 0.99]), + "working_life": { + "H": {"ies": 0.5}, + "utility": {"disutility_of_work": 0.5}, + "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, + }, + }, + ) + for period in V: + if "working_life" in V[period]: + v = V[period]["working_life"] + assert bool(jnp.all(jnp.isfinite(v))), ( + f"Non-finite working_life V at period {period}" + ) + assert 3 in v.shape diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 408780b0..a6a47168 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -167,19 +167,13 @@ def test_regression_mahler_yum(): expected = pd.read_pickle(_PRECISION_DIR / "mahler_yum_simulation.pkl") n_subjects = 4 - start_params_without_beta = {k: v for k, v in START_PARAMS.items() if k != "beta"} - common_params, initial_states, _discount_factor_type = create_inputs( + common_params, initial_states = create_inputs( seed=0, n_simulation_subjects=n_subjects, - **start_params_without_beta, # ty: ignore[invalid-argument-type] + **START_PARAMS, # ty: ignore[invalid-argument-type] ) model = MAHLER_YUM_MODEL - params = { - "alive": { - "discount_factor": START_PARAMS["beta"]["mean"], # ty: ignore[invalid-argument-type, not-subscriptable] - **common_params, - }, - } + params = {"alive": common_params} initial_conditions = { **initial_states, "regime": jnp.full( @@ -190,7 +184,7 @@ def test_regression_mahler_yum(): } got = model.simulate( - params=params, # ty: ignore[invalid-argument-type] + params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, seed=12345,