Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions examples/run_oguk.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,27 @@
)


def run_steady_state(age_specific: str = "pooled"):
def run_steady_state(age_specific: str = "pooled", multi_sector: bool = False):
"""Run baseline and reform steady state, print results."""
print(f"Solving baseline steady state (age_specific='{age_specific}')...")
sector_label = "8-sector" if multi_sector else "1-sector"
print(
f"Solving baseline steady state (age_specific='{age_specific}', {sector_label})..."
)
t0 = time.time()
baseline = solve_steady_state(start_year=2026, age_specific=age_specific)
baseline = solve_steady_state(
start_year=2026, age_specific=age_specific, multi_sector=multi_sector
)
print(f" Done in {time.time() - t0:.1f}s")

print(f"Solving reform steady state (age_specific='{age_specific}')...")
print(
f"Solving reform steady state (age_specific='{age_specific}', {sector_label})..."
)
t0 = time.time()
reform = solve_steady_state(
start_year=2026, policy=REFORM, age_specific=age_specific
start_year=2026,
policy=REFORM,
age_specific=age_specific,
multi_sector=multi_sector,
)
print(f" Done in {time.time() - t0:.1f}s")

Expand Down Expand Up @@ -87,17 +97,21 @@ def run_steady_state(age_specific: str = "pooled"):
print("=" * 60)


def run_tpi():
def run_tpi(multi_sector: bool = False):
"""Run baseline and reform transition paths, print results."""
from dask.distributed import Client

print("Running baseline + reform transition paths...")
sector_label = "8-sector" if multi_sector else "1-sector"
print(f"Running baseline + reform transition paths ({sector_label})...")
print("(This solves SS + TPI for both scenarios — may take a while.)")
client = Client(n_workers=2, threads_per_worker=1, memory_limit="2GB")
t0 = time.time()
try:
base_tp, reform_tp = run_transition_path(
start_year=2026, policy=REFORM, client=client
start_year=2026,
policy=REFORM,
client=client,
multi_sector=multi_sector,
)
finally:
client.close()
Expand Down Expand Up @@ -137,12 +151,13 @@ def run_tpi():
def main():
mode = sys.argv[1] if len(sys.argv) > 1 else "ss"
age_specific = sys.argv[2] if len(sys.argv) > 2 else "pooled"
multi_sector = "multi-sector" in sys.argv
if mode == "ss":
run_steady_state(age_specific=age_specific)
run_steady_state(age_specific=age_specific, multi_sector=multi_sector)
elif mode == "tpi":
run_tpi()
run_tpi(multi_sector=multi_sector)
else:
print(f"Usage: {sys.argv[0]} [ss|tpi] [pooled|brackets|each]")
print(f"Usage: {sys.argv[0]} [ss|tpi] [pooled|brackets|each] [multi-sector]")
sys.exit(1)


Expand Down
53 changes: 35 additions & 18 deletions oguk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ def _build_specs(
max_iter: int = 250,
age_specific: str = "pooled",
param_overrides: dict | None = None,
multi_sector: bool = False,
):
"""Build a calibrated Specifications object (internal).

Expand All @@ -706,6 +707,8 @@ def _build_specs(
are applied last so they take precedence over defaults and
calibration outputs. Demographic / tax-function keys are not
permitted here — those are set by calibrate().
multi_sector: If True, use 8-sector industry calibration (M=8).
If False (default), use a single representative sector (M=1).
"""
from ogcore.parameters import Specifications

Expand Down Expand Up @@ -738,20 +741,21 @@ def _build_specs(
]:
defaults.pop(key, None)

# Strip single-industry defaults that will be replaced by industry_params
for key in [
"gamma",
"gamma_g",
"epsilon",
"Z",
"cit_rate",
"io_matrix",
"alpha_c",
"delta_tau_annual",
"inv_tax_credit",
"tau_c",
]:
defaults.pop(key, None)
if multi_sector:
# Strip single-industry defaults that will be replaced by industry_params
for key in [
"gamma",
"gamma_g",
"epsilon",
"Z",
"cit_rate",
"io_matrix",
"alpha_c",
"delta_tau_annual",
"inv_tax_credit",
"tau_c",
]:
defaults.pop(key, None)

p = Specifications(
baseline=baseline, output_base=output_base, baseline_dir=baseline_dir
Expand Down Expand Up @@ -789,10 +793,14 @@ def _build_specs(
}
)

# Apply 8-sector industry calibration
defaults.update(get_industry_params())
# Levenberg-Marquardt is more robust than the default 'hybr' for M>1
defaults["SS_root_method"] = "lm"
if multi_sector:
# Apply 8-sector industry calibration
defaults.update(get_industry_params())
# hybr (Powell hybrid) for heterogeneous CES; LM gets stuck at ~1e-5
defaults["SS_root_method"] = "hybr"
# Relax tolerances for heterogeneous CES
defaults["mindist_SS"] = 1e-4
defaults["RC_SS"] = 1e-4

if param_overrides:
defaults.update(param_overrides)
Expand Down Expand Up @@ -850,6 +858,7 @@ def solve_steady_state(
max_iter: int = 250,
age_specific: str = "pooled",
param_overrides: dict | None = None,
multi_sector: bool = False,
) -> SteadyStateResult:
"""Solve for steady state equilibrium.

Expand All @@ -863,6 +872,8 @@ def solve_steady_state(
"each" — separate function per individual age (80)
param_overrides: Optional dict of OG-Core parameter names to values
for structural shocks (e.g. ``{"g_y_annual": 0.011}``).
multi_sector: If True, use 8-sector industry calibration (M=8).
If False (default), use a single representative sector (M=1).

Returns:
SteadyStateResult with equilibrium values
Expand All @@ -878,6 +889,7 @@ def solve_steady_state(
max_iter=max_iter,
age_specific=age_specific,
param_overrides=param_overrides,
multi_sector=multi_sector,
)
ss = run_SS(p, client=None)
return _ss_dict_to_result(ss)
Expand All @@ -889,6 +901,7 @@ def run_transition_path(
client=None,
age_specific: str = "pooled",
param_overrides: dict | None = None,
multi_sector: bool = False,
) -> tuple[TransitionPathResult, TransitionPathResult | None]:
"""Run baseline (and optionally reform) transition path.

Expand All @@ -906,6 +919,8 @@ def run_transition_path(
"each" — separate function per individual age (80)
param_overrides: Optional dict of OG-Core parameter names to values
for structural shocks (e.g. ``{"Z": [[1.004]]}``).
multi_sector: If True, use 8-sector industry calibration (M=8).
If False (default), use a single representative sector (M=1).

Returns:
(baseline_tp, reform_tp) — reform_tp is None if no policy/overrides
Expand All @@ -930,6 +945,7 @@ def run_transition_path(
base_dir,
baseline=True,
age_specific=age_specific,
multi_sector=multi_sector,
)

# Solve SS first to auto-calibrate alpha_G.
Expand Down Expand Up @@ -963,6 +979,7 @@ def run_transition_path(
baseline=False,
age_specific=age_specific,
param_overrides=param_overrides,
multi_sector=multi_sector,
)

ss_reform = SS.run_SS(p_reform, client=client)
Expand Down
48 changes: 38 additions & 10 deletions oguk/industry_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,17 @@ def _sector_gva() -> np.ndarray:
]


def _sector_tfp() -> list:
def _sector_tfp(epsilon=None, gamma=None) -> list:
"""Solow-residual TFP by sector, normalised so the GVA-weighted mean = 1.

Computes Z_m = GVA_m / (K_m^gamma_m * L_m^(1 - gamma_m)) for each sector,
then rescales so that sum(gva_share_m * Z_m) = 1.0. This ensures the
aggregate economy matches the baseline calibration while allowing
cross-sector productivity differences.
When epsilon is all 1.0 (Cobb-Douglas), computes:
Z_m = GVA_m / (K_m^gamma_m * L_m^(1 - gamma_m))

When epsilon differs from 1.0 (CES), computes the CES residual:
Z_m = GVA_m / [gamma_m^(1/eps) * K_m^((eps-1)/eps)
+ (1-gamma_m)^(1/eps) * L_m^((eps-1)/eps)]^(eps/(eps-1))

Then rescales so that sum(gva_share_m * Z_m) = 1.0.

Sources:
GVA: _GVA_BY_SIC_SECTION (ONS Blue Book 2024)
Expand All @@ -273,9 +277,31 @@ def _sector_tfp() -> list:
gva = _sector_gva()
capital = np.array(_CAPITAL_STOCK, dtype=float)
labour = np.array(_WORKFORCE_JOBS, dtype=float)
gamma = np.array(_GAMMA, dtype=float)

z_raw = gva / (capital**gamma * labour ** (1 - gamma))
if gamma is None:
gamma = np.array(_GAMMA, dtype=float)
else:
gamma = np.array(gamma, dtype=float)
if epsilon is None:
epsilon = np.ones(M, dtype=float)
else:
epsilon = np.array(epsilon, dtype=float)

z_raw = np.empty(M, dtype=float)
for m in range(M):
eps = epsilon[m]
g = gamma[m]
K = capital[m]
L = labour[m]
if eps == 1.0:
# Cobb-Douglas
z_raw[m] = gva[m] / (K**g * L ** (1 - g))
else:
# CES: Y = Z * [gamma^(1/eps)*K^((eps-1)/eps) + (1-gamma)^(1/eps)*L^((eps-1)/eps)]^(eps/(eps-1))
ces_aggregate = (
g ** (1 / eps) * K ** ((eps - 1) / eps)
+ (1 - g) ** (1 / eps) * L ** ((eps - 1) / eps)
) ** (eps / (eps - 1))
z_raw[m] = gva[m] / ces_aggregate

# Normalise so GVA-weighted mean equals 1
gva_shares = gva / gva.sum()
Expand Down Expand Up @@ -309,13 +335,15 @@ def get_industry_params() -> dict:
# Shrink gamma 40% toward aggregate mean (0.35) for solver stability
gamma_shrunk = [0.35 + 0.6 * (g - 0.35) for g in _GAMMA]

epsilon = list(_EPSILON)

return {
"M": M,
"I": NUM_CONSUMPTION_GOODS,
"gamma": gamma_shrunk,
"gamma_g": [0.0] * M,
"epsilon": [1.0] * M, # Cobb-Douglas; heterogeneous epsilon breaks TPI
"Z": [_sector_tfp()],
"epsilon": epsilon, # calibrated CES elasticities by sector
"Z": [_sector_tfp(epsilon=epsilon, gamma=gamma_shrunk)],
"cit_rate": [[0.27] * M],
"io_matrix": _IO_MATRIX,
"alpha_c": alpha_c.tolist(),
Expand Down
Loading