From 2a9c0753fe717c0ea9cb65ff36304838ca3c140e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 21 Feb 2026 00:44:34 +0100 Subject: [PATCH 1/5] feat: Add economy analysis module registry Define ComputationModule dataclass and MODULE_REGISTRY with 10 modules (decile, program_statistics, poverty, inequality, budget_summary, intra_decile, congressional_district, constituency, local_authority, wealth_decile) plus helper functions for country filtering and validation. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/module_registry.py | 128 +++++++++++++++++++ tests/test_module_registry.py | 134 ++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 src/policyengine_api/api/module_registry.py create mode 100644 tests/test_module_registry.py diff --git a/src/policyengine_api/api/module_registry.py b/src/policyengine_api/api/module_registry.py new file mode 100644 index 0000000..db3d4d7 --- /dev/null +++ b/src/policyengine_api/api/module_registry.py @@ -0,0 +1,128 @@ +"""Economy analysis module registry. + +Defines the available computation modules for economy-wide analysis. +Each module maps to a named computation (e.g., "poverty", "decile") with +metadata about which countries support it and which response fields it +populates. + +Used by: +- GET /analysis/options — lists available modules +- POST /analysis/economy-custom — runs selected modules +""" + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class ComputationModule: + """A named economy analysis computation module.""" + + name: str + label: str + description: str + countries: list[str] = field(default_factory=list) + response_fields: list[str] = field(default_factory=list) + + +MODULE_REGISTRY: dict[str, ComputationModule] = { + "decile": ComputationModule( + name="decile", + label="Income decile impacts", + description="Relative and average income change by income decile (1-10).", + countries=["uk", "us"], + response_fields=["decile_impacts"], + ), + "program_statistics": ComputationModule( + name="program_statistics", + label="Program statistics", + description="Per-program baseline/reform totals, changes, and winner/loser counts.", + countries=["uk", "us"], + response_fields=["program_statistics", "detailed_budget"], + ), + "poverty": ComputationModule( + name="poverty", + label="Poverty rates", + description="Poverty rates by type, overall and by demographic breakdowns (age, gender, race).", + countries=["uk", "us"], + response_fields=["poverty"], + ), + "inequality": ComputationModule( + name="inequality", + label="Inequality metrics", + description="Gini coefficient, top 10%/1% share, bottom 50% share.", + countries=["uk", "us"], + response_fields=["inequality"], + ), + "budget_summary": ComputationModule( + name="budget_summary", + label="Budget summary", + description="Aggregate tax revenue, benefit spending, net income, and household count.", + countries=["uk", "us"], + response_fields=["budget_summary"], + ), + "intra_decile": ComputationModule( + name="intra_decile", + label="Intra-decile breakdown", + description="Distribution of income change categories (5 bands) within each decile.", + countries=["uk", "us"], + response_fields=["intra_decile"], + ), + "congressional_district": ComputationModule( + name="congressional_district", + label="Congressional district impact", + description="Per-district average and relative household income change for US congressional districts.", + countries=["us"], + response_fields=["congressional_district_impact"], + ), + "constituency": ComputationModule( + name="constituency", + label="Parliamentary constituency impact", + description="Per-constituency average and relative household income change for UK parliamentary constituencies.", + countries=["uk"], + response_fields=["constituency_impact"], + ), + "local_authority": ComputationModule( + name="local_authority", + label="Local authority impact", + description="Per-local-authority average and relative household income change for UK local authorities.", + countries=["uk"], + response_fields=["local_authority_impact"], + ), + "wealth_decile": ComputationModule( + name="wealth_decile", + label="Wealth decile impacts", + description="Income change by wealth decile (1-10) and intra-wealth-decile breakdown.", + countries=["uk"], + response_fields=["wealth_decile", "intra_wealth_decile"], + ), +} + + +def get_modules_for_country(country: str) -> list[ComputationModule]: + """Return modules applicable to a country code ('uk' or 'us').""" + return [m for m in MODULE_REGISTRY.values() if country in m.countries] + + +def get_all_module_names() -> list[str]: + """Return all registered module names.""" + return list(MODULE_REGISTRY.keys()) + + +def validate_modules(names: list[str], country: str) -> list[str]: + """Validate module names against the registry for a given country. + + Returns the validated list. Raises ValueError for unknown or + inapplicable modules. + """ + available = {m.name for m in get_modules_for_country(country)} + errors = [] + for name in names: + if name not in MODULE_REGISTRY: + errors.append(f"Unknown module: {name!r}") + elif name not in available: + errors.append( + f"Module {name!r} is not available for country {country!r}" + ) + if errors: + raise ValueError("; ".join(errors)) + return names diff --git a/tests/test_module_registry.py b/tests/test_module_registry.py new file mode 100644 index 0000000..d8f18f2 --- /dev/null +++ b/tests/test_module_registry.py @@ -0,0 +1,134 @@ +"""Tests for the economy analysis module registry.""" + +import pytest + +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + ComputationModule, + get_all_module_names, + get_modules_for_country, + validate_modules, +) + + +class TestModuleRegistry: + """Tests for MODULE_REGISTRY contents.""" + + def test_registry_is_not_empty(self): + assert len(MODULE_REGISTRY) > 0 + + def test_all_entries_are_computation_modules(self): + for name, module in MODULE_REGISTRY.items(): + assert isinstance(module, ComputationModule) + assert module.name == name + + def test_all_modules_have_countries(self): + for module in MODULE_REGISTRY.values(): + assert len(module.countries) > 0 + for country in module.countries: + assert country in ("uk", "us") + + def test_all_modules_have_response_fields(self): + for module in MODULE_REGISTRY.values(): + assert len(module.response_fields) > 0 + + def test_expected_modules_exist(self): + expected = [ + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "congressional_district", + "constituency", + "local_authority", + "wealth_decile", + ] + for name in expected: + assert name in MODULE_REGISTRY, f"Missing module: {name}" + + +class TestCountryApplicability: + """Tests for country-specific module availability.""" + + def test_us_only_modules(self): + assert "us" in MODULE_REGISTRY["congressional_district"].countries + assert "uk" not in MODULE_REGISTRY["congressional_district"].countries + + def test_uk_only_modules(self): + for name in ("constituency", "local_authority", "wealth_decile"): + module = MODULE_REGISTRY[name] + assert "uk" in module.countries + assert "us" not in module.countries + + def test_shared_modules(self): + shared = ["decile", "program_statistics", "poverty", "inequality", + "budget_summary", "intra_decile"] + for name in shared: + module = MODULE_REGISTRY[name] + assert "uk" in module.countries + assert "us" in module.countries + + +class TestGetModulesForCountry: + """Tests for get_modules_for_country().""" + + def test_uk_includes_constituency(self): + uk_modules = get_modules_for_country("uk") + names = [m.name for m in uk_modules] + assert "constituency" in names + assert "local_authority" in names + assert "wealth_decile" in names + + def test_uk_excludes_congressional_district(self): + uk_modules = get_modules_for_country("uk") + names = [m.name for m in uk_modules] + assert "congressional_district" not in names + + def test_us_includes_congressional_district(self): + us_modules = get_modules_for_country("us") + names = [m.name for m in us_modules] + assert "congressional_district" in names + + def test_us_excludes_uk_only(self): + us_modules = get_modules_for_country("us") + names = [m.name for m in us_modules] + assert "constituency" not in names + assert "local_authority" not in names + assert "wealth_decile" not in names + + def test_unknown_country_returns_empty(self): + assert get_modules_for_country("fr") == [] + + +class TestGetAllModuleNames: + """Tests for get_all_module_names().""" + + def test_returns_all_names(self): + names = get_all_module_names() + assert set(names) == set(MODULE_REGISTRY.keys()) + + +class TestValidateModules: + """Tests for validate_modules().""" + + def test_valid_us_modules(self): + result = validate_modules(["decile", "poverty"], "us") + assert result == ["decile", "poverty"] + + def test_valid_uk_modules(self): + result = validate_modules(["constituency", "wealth_decile"], "uk") + assert result == ["constituency", "wealth_decile"] + + def test_unknown_module_raises(self): + with pytest.raises(ValueError, match="Unknown module"): + validate_modules(["nonexistent"], "us") + + def test_wrong_country_raises(self): + with pytest.raises(ValueError, match="not available for country"): + validate_modules(["congressional_district"], "uk") + + def test_multiple_errors_combined(self): + with pytest.raises(ValueError, match="Unknown module.*not available"): + validate_modules(["nonexistent", "constituency"], "us") From 2a9cb0609b8273f05cdef04f4a6e2e9ebc58cb04 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 23 Feb 2026 10:51:49 +0100 Subject: [PATCH 2/5] feat: Add GET /analysis/options endpoint Returns available economy analysis modules from the registry, with optional country query param to filter by UK/US applicability. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 40 +++++++++++++++++++ tests/test_analysis_options.py | 57 ++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 tests/test_analysis_options.py diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 76fb30f..137879f 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -54,6 +54,7 @@ TaxBenefitModel, TaxBenefitModelVersion, ) +from policyengine_api.api.module_registry import MODULE_REGISTRY, get_modules_for_country from policyengine_api.services.database import get_session @@ -80,6 +81,45 @@ def _safe_float(value: float | None) -> float | None: router = APIRouter(prefix="/analysis", tags=["analysis"]) +# --------------------------------------------------------------------------- +# GET /analysis/options — list available computation modules +# --------------------------------------------------------------------------- + + +class ModuleOption(BaseModel): + """A single computation module available for economy analysis.""" + + name: str + label: str + description: str + response_fields: list[str] + + +@router.get("/options", response_model=list[ModuleOption]) +def list_analysis_options( + country: str | None = None, +) -> list[ModuleOption]: + """List available economy analysis modules. + + Args: + country: Optional country code ('uk' or 'us') to filter modules. + """ + if country: + modules = get_modules_for_country(country) + else: + modules = list(MODULE_REGISTRY.values()) + + return [ + ModuleOption( + name=m.name, + label=m.label, + description=m.description, + response_fields=list(m.response_fields), + ) + for m in modules + ] + + class EconomicImpactRequest(BaseModel): """Request body for economic impact analysis. diff --git a/tests/test_analysis_options.py b/tests/test_analysis_options.py new file mode 100644 index 0000000..0dee9d3 --- /dev/null +++ b/tests/test_analysis_options.py @@ -0,0 +1,57 @@ +"""Tests for GET /analysis/options endpoint.""" + +from policyengine_api.api.module_registry import MODULE_REGISTRY + + +class TestAnalysisOptions: + """Tests for the /analysis/options endpoint.""" + + def test_returns_all_modules(self, client): + response = client.get("/analysis/options") + assert response.status_code == 200 + data = response.json() + assert len(data) == len(MODULE_REGISTRY) + + def test_response_shape(self, client): + response = client.get("/analysis/options") + data = response.json() + for item in data: + assert "name" in item + assert "label" in item + assert "description" in item + assert "response_fields" in item + assert isinstance(item["response_fields"], list) + + def test_filter_by_uk(self, client): + response = client.get("/analysis/options?country=uk") + assert response.status_code == 200 + data = response.json() + names = [m["name"] for m in data] + assert "constituency" in names + assert "local_authority" in names + assert "wealth_decile" in names + assert "congressional_district" not in names + + def test_filter_by_us(self, client): + response = client.get("/analysis/options?country=us") + assert response.status_code == 200 + data = response.json() + names = [m["name"] for m in data] + assert "congressional_district" in names + assert "constituency" not in names + assert "local_authority" not in names + assert "wealth_decile" not in names + + def test_shared_modules_in_both_countries(self, client): + uk_resp = client.get("/analysis/options?country=uk") + us_resp = client.get("/analysis/options?country=us") + uk_names = {m["name"] for m in uk_resp.json()} + us_names = {m["name"] for m in us_resp.json()} + for shared in ["decile", "poverty", "inequality", "budget_summary", "intra_decile"]: + assert shared in uk_names + assert shared in us_names + + def test_unknown_country_returns_empty(self, client): + response = client.get("/analysis/options?country=fr") + assert response.status_code == 200 + assert response.json() == [] From 5806b8a1d5f0de333e8eb8c80f058e5405060301 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 23 Feb 2026 11:01:05 +0100 Subject: [PATCH 3/5] feat: Add POST /analysis/economy-custom endpoint Accepts same inputs as /analysis/economic-impact plus a modules list. Validates module names against the registry for the given country, triggers computation, and filters the response to only include fields for the requested modules. Includes GET polling endpoint with optional modules query param. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 193 ++++++++++++++++++++++++++- tests/test_economy_custom.py | 152 +++++++++++++++++++++ 2 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 tests/test_economy_custom.py diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 137879f..d7a68f9 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -54,7 +54,11 @@ TaxBenefitModel, TaxBenefitModelVersion, ) -from policyengine_api.api.module_registry import MODULE_REGISTRY, get_modules_for_country +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + get_modules_for_country, + validate_modules, +) from policyengine_api.services.database import get_session @@ -1831,3 +1835,190 @@ def get_economic_impact_status( raise HTTPException(status_code=500, detail="Simulation data missing") return _build_response(report, baseline_sim, reform_sim, session) + + +# --------------------------------------------------------------------------- +# POST /analysis/economy-custom — run selected economy modules +# --------------------------------------------------------------------------- + +_MODEL_TO_COUNTRY = { + "policyengine_uk": "uk", + "policyengine_us": "us", +} + + +class EconomyCustomRequest(BaseModel): + """Request body for custom economy analysis with selected modules.""" + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID. Either dataset_id or region must be provided.", + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us').", + ) + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID to compare against baseline (current law)", + ) + dynamic_id: UUID | None = Field( + default=None, description="Optional behavioural response specification ID" + ) + modules: list[str] = Field( + description="List of module names to compute (see GET /analysis/options)" + ) + + +def _build_filtered_response( + full_response: EconomicImpactResponse, + modules: list[str], +) -> EconomicImpactResponse: + """Return a copy of the response with only the fields for requested modules.""" + allowed_fields: set[str] = set() + for name in modules: + module = MODULE_REGISTRY.get(name) + if module: + allowed_fields.update(module.response_fields) + + # Fields that are always included regardless of modules + always_included = { + "report_id", + "status", + "baseline_simulation", + "reform_simulation", + "region", + "error_message", + } + + filtered = {} + for field_name in EconomicImpactResponse.model_fields: + value = getattr(full_response, field_name) + if field_name in always_included: + filtered[field_name] = value + elif field_name in allowed_fields: + filtered[field_name] = value + else: + filtered[field_name] = None + + return EconomicImpactResponse.model_construct(**filtered) + + +@router.post("/economy-custom", response_model=EconomicImpactResponse) +def economy_custom( + request: EconomyCustomRequest, + session: Session = Depends(get_session), +) -> EconomicImpactResponse: + """Run economy-wide analysis with only the selected modules. + + Same async pattern as /analysis/economic-impact but accepts a list of + module names. Only the requested modules' response fields are populated; + the rest are null. + + See GET /analysis/options for available module names. + """ + country = _MODEL_TO_COUNTRY[request.tax_benefit_model_name] + + try: + validate_modules(request.modules, country) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) + + # Reuse the same request model for dataset/region resolution + impact_request = EconomicImpactRequest( + tax_benefit_model_name=request.tax_benefit_model_name, + dataset_id=request.dataset_id, + region=request.region, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + ) + + dataset, region_obj = _resolve_dataset_and_region(impact_request, session) + + filter_field = ( + region_obj.filter_field if region_obj and region_obj.requires_filter else None + ) + filter_value = ( + region_obj.filter_value if region_obj and region_obj.requires_filter else None + ) + + model_version = _get_model_version(request.tax_benefit_model_name, session) + + baseline_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + ) + + reform_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + ) + + label = f"Custom analysis: {request.tax_benefit_model_name}" + if request.policy_id: + label += f" (policy {request.policy_id})" + + report = _get_or_create_report( + baseline_sim.id, reform_sim.id, label, "economy_comparison", session + ) + + if report.status == ReportStatus.PENDING: + with logfire.span("trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison( + str(report.id), request.tax_benefit_model_name, session + ) + + full_response = _build_response( + report, baseline_sim, reform_sim, session, region_obj + ) + return _build_filtered_response(full_response, request.modules) + + +@router.get("/economy-custom/{report_id}", response_model=EconomicImpactResponse) +def get_economy_custom_status( + report_id: UUID, + modules: str | None = None, + session: Session = Depends(get_session), +) -> EconomicImpactResponse: + """Poll for results of custom economy analysis. + + Args: + report_id: The report ID returned by POST /analysis/economy-custom. + modules: Optional comma-separated module names to filter the response. + If omitted, all computed fields are returned. + """ + report = session.get(Report, report_id) + if not report: + raise HTTPException(status_code=404, detail=f"Report {report_id} not found") + + if not report.baseline_simulation_id or not report.reform_simulation_id: + raise HTTPException(status_code=500, detail="Report missing simulation IDs") + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + reform_sim = session.get(Simulation, report.reform_simulation_id) + + if not baseline_sim or not reform_sim: + raise HTTPException(status_code=500, detail="Simulation data missing") + + full_response = _build_response(report, baseline_sim, reform_sim, session) + + if modules: + module_list = [m.strip() for m in modules.split(",")] + return _build_filtered_response(full_response, module_list) + + return full_response diff --git a/tests/test_economy_custom.py b/tests/test_economy_custom.py new file mode 100644 index 0000000..f992c53 --- /dev/null +++ b/tests/test_economy_custom.py @@ -0,0 +1,152 @@ +"""Tests for POST /analysis/economy-custom endpoint.""" + +import pytest +from unittest.mock import patch + +from policyengine_api.api.analysis import ( + EconomicImpactResponse, + SimulationInfo, + _build_filtered_response, +) +from policyengine_api.models import ReportStatus, SimulationStatus + + +# --------------------------------------------------------------------------- +# Unit tests for _build_filtered_response +# --------------------------------------------------------------------------- + + +def _make_stub_response(**overrides) -> EconomicImpactResponse: + """Build a minimal EconomicImpactResponse for testing.""" + from uuid import uuid4 + + defaults = dict( + report_id=uuid4(), + status=ReportStatus.COMPLETED, + baseline_simulation=SimulationInfo( + id=uuid4(), status=SimulationStatus.COMPLETED + ), + reform_simulation=SimulationInfo( + id=uuid4(), status=SimulationStatus.COMPLETED + ), + region=None, + error_message=None, + decile_impacts=[{"fake": "decile"}], + program_statistics=[{"fake": "program"}], + poverty=[{"fake": "poverty"}], + inequality=[{"fake": "inequality"}], + budget_summary=[{"fake": "budget"}], + intra_decile=[{"fake": "intra"}], + detailed_budget={"prog": {"baseline": 1.0}}, + congressional_district_impact=[{"fake": "district"}], + constituency_impact=[{"fake": "constituency"}], + local_authority_impact=[{"fake": "la"}], + wealth_decile=[{"fake": "wealth"}], + intra_wealth_decile=[{"fake": "intra_wealth"}], + ) + defaults.update(overrides) + return EconomicImpactResponse.model_construct(**defaults) + + +class TestBuildFilteredResponse: + """Tests for response filtering by module list.""" + + def test_single_module_keeps_only_its_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty"]) + assert filtered.poverty is not None + assert filtered.decile_impacts is None + assert filtered.program_statistics is None + assert filtered.inequality is None + + def test_multiple_modules(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["decile", "inequality"]) + assert filtered.decile_impacts is not None + assert filtered.inequality is not None + assert filtered.poverty is None + assert filtered.congressional_district_impact is None + + def test_program_statistics_includes_detailed_budget(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["program_statistics"]) + assert filtered.program_statistics is not None + assert filtered.detailed_budget is not None + assert filtered.decile_impacts is None + + def test_always_included_fields_preserved(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty"]) + assert filtered.report_id == resp.report_id + assert filtered.status == resp.status + assert filtered.baseline_simulation is not None + assert filtered.reform_simulation is not None + + def test_empty_modules_nullifies_all_data_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, []) + assert filtered.decile_impacts is None + assert filtered.poverty is None + assert filtered.inequality is None + assert filtered.report_id == resp.report_id + + +# --------------------------------------------------------------------------- +# Integration tests for the endpoint itself +# --------------------------------------------------------------------------- + + +class TestEconomyCustomEndpoint: + """Tests for POST /analysis/economy-custom validation.""" + + def test_unknown_module_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["nonexistent_module"], + }, + ) + assert response.status_code == 422 + assert "Unknown module" in response.json()["detail"] + + def test_wrong_country_module_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["constituency"], + }, + ) + assert response.status_code == 422 + assert "not available for country" in response.json()["detail"] + + def test_empty_modules_list_returns_422(self, client): + """An empty modules list should still be accepted (no validation error).""" + # This will fail on dataset resolution (404), not module validation + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": [], + }, + ) + # Empty list passes validation, so the error should be about + # dataset/region resolution, not about modules + assert response.status_code != 422 or "module" not in response.json().get( + "detail", "" + ).lower() + + +class TestEconomyCustomPolling: + """Tests for GET /analysis/economy-custom/{report_id}.""" + + def test_not_found(self, client): + from uuid import uuid4 + + fake_id = uuid4() + response = client.get(f"/analysis/economy-custom/{fake_id}") + assert response.status_code == 404 From 15366cad85e07aad00a6e23aa1c828aad7e945e9 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 23 Feb 2026 11:16:36 +0100 Subject: [PATCH 4/5] refactor: Extract composable computation modules from monolithic functions Move each module's computation logic (decile, poverty, inequality, budget_summary, program_statistics, intra_decile, constituency, local_authority, wealth_decile, congressional_district) into standalone functions in computation_modules.py with UK/US dispatch tables. The local economy comparison functions now call run_modules() with an optional modules list, enabling selective computation from the /analysis/economy-custom endpoint. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 726 +--------------- .../api/computation_modules.py | 778 ++++++++++++++++++ tests/test_computation_modules.py | 107 +++ 3 files changed, 929 insertions(+), 682 deletions(-) create mode 100644 src/policyengine_api/api/computation_modules.py create mode 100644 tests/test_computation_modules.py diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index d7a68f9..6d35a61 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -705,7 +705,9 @@ def _download_dataset_local(filepath: str) -> str: return str(cache_path) -def _run_local_economy_comparison_uk(job_id: str, session: Session) -> None: +def _run_local_economy_comparison_uk( + job_id: str, session: Session, modules: list[str] | None = None +) -> None: """Run UK economy comparison analysis locally.""" from datetime import datetime, timezone from uuid import UUID @@ -714,20 +716,8 @@ def _run_local_economy_comparison_uk(job_id: str, session: Session) -> None: from policyengine.core.dynamic import Dynamic as PEDynamic from policyengine.core.policy import ParameterValue as PEParameterValue from policyengine.core.policy import Policy as PEPolicy - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - from policyengine.outputs.inequality import calculate_uk_inequality - from policyengine.outputs.poverty import ( - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - calculate_uk_poverty_rates, - ) from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, - ) from policyengine_api.models import Policy as DBPolicy @@ -848,383 +838,20 @@ def build_dynamic(dynamic_id): ) pe_reform_sim.ensure() - # Calculate decile impacts - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - ) - di.run() - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - PEProgrammeStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "vat": {"entity": "household", "is_tax": True}, - "council_tax": {"entity": "household", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - "pension_credit": {"entity": "person", "is_tax": False}, - "income_support": {"entity": "person", "is_tax": False}, - "working_tax_credit": {"entity": "person", "is_tax": False}, - "child_tax_credit": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not found in model - - # Calculate poverty rates for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_uk_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_uk_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - gender_poverty_results = calculate_uk_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_uk_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, - report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # UK budget variables — household-level aggregates for fiscal totals - uk_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - } - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - for var_name, entity in uk_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household count: bypass Aggregate and compute directly from raw numpy - # values. Using Aggregate(SUM) on household_weight would compute - # sum(weight * weight) because MicroSeries.sum() applies weights - # automatically — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) - - # Calculate intra-decile impact (5-category income change distribution) - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile, - ) - - intra_decile_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - - # Calculate constituency impact (UK only, requires weight matrix) - from policyengine.outputs.constituency_impact import ( - compute_uk_constituency_impacts, - ) - - try: - from policyengine_core.tools.google_cloud import download as gcs_download - - weight_matrix_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="parliamentary_constituency_weights.h5", - ) - constituency_csv_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="constituencies_2024.csv", - ) - constituency_impact = compute_uk_constituency_impacts( - pe_baseline_sim, - pe_reform_sim, - weight_matrix_path=weight_matrix_path, - constituency_csv_path=constituency_csv_path, - ) - if constituency_impact.constituency_results: - for cr in constituency_impact.constituency_results: - record = ConstituencyImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - constituency_code=cr["constituency_code"], - constituency_name=cr["constituency_name"], - x=cr["x"], - y=cr["y"], - average_household_income_change=cr[ - "average_household_income_change" - ], - relative_household_income_change=cr[ - "relative_household_income_change" - ], - population=cr["population"], - ) - session.add(record) - except Exception: - pass # Weight matrix not available, skip constituency impact + # Run computation modules + from policyengine_api.api.computation_modules import UK_MODULE_DISPATCH, run_modules - # Calculate local authority impact (UK only, requires weight matrix) - from policyengine.outputs.local_authority_impact import ( - compute_uk_local_authority_impacts, + run_modules( + dispatch=UK_MODULE_DISPATCH, + modules=modules, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, + report_id=report.id, + session=session, ) - try: - from policyengine_core.tools.google_cloud import download as gcs_download - - la_weight_matrix_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authority_weights.h5", - ) - la_csv_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authorities_2021.csv", - ) - la_impact = compute_uk_local_authority_impacts( - pe_baseline_sim, - pe_reform_sim, - weight_matrix_path=la_weight_matrix_path, - local_authority_csv_path=la_csv_path, - ) - if la_impact.local_authority_results: - for lr in la_impact.local_authority_results: - record = LocalAuthorityImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - local_authority_code=lr["local_authority_code"], - local_authority_name=lr["local_authority_name"], - x=lr["x"], - y=lr["y"], - average_household_income_change=lr[ - "average_household_income_change" - ], - relative_household_income_change=lr[ - "relative_household_income_change" - ], - population=lr["population"], - ) - session.add(record) - except Exception: - pass # Weight matrix not available, skip local authority impact - - # Calculate wealth decile impact (UK only) - try: - from policyengine.outputs.decile_impact import ( - DecileImpact as PEDecileImpact, - ) - - PEDecileImpact.model_rebuild(_types_namespace={"Simulation": PESimulation}) - for decile_num in range(1, 11): - wealth_di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - decile=decile_num, - ) - wealth_di.run() - record = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable="household_wealth_decile", - entity="household", - decile=decile_num, - quantiles=10, - baseline_mean=wealth_di.baseline_mean, - reform_mean=wealth_di.reform_mean, - absolute_change=wealth_di.absolute_change, - relative_change=wealth_di.relative_change, - ) - session.add(record) - - # Calculate intra-wealth-decile impact - intra_wealth_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - ) - for r in intra_wealth_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile_type="wealth", - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - except (KeyError, Exception): - pass # household_wealth_decile not available (US), skip - # Mark completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = datetime.now(timezone.utc) @@ -1237,7 +864,9 @@ def build_dynamic(dynamic_id): session.commit() -def _run_local_economy_comparison_us(job_id: str, session: Session) -> None: +def _run_local_economy_comparison_us( + job_id: str, session: Session, modules: list[str] | None = None +) -> None: """Run US economy comparison analysis locally.""" from datetime import datetime, timezone from uuid import UUID @@ -1246,21 +875,8 @@ def _run_local_economy_comparison_us(job_id: str, session: Session) -> None: from policyengine.core.dynamic import Dynamic as PEDynamic from policyengine.core.policy import ParameterValue as PEParameterValue from policyengine.core.policy import Policy as PEPolicy - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - from policyengine.outputs.inequality import calculate_us_inequality - from policyengine.outputs.poverty import ( - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - calculate_us_poverty_rates, - ) from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset - from policyengine.tax_benefit_models.us.outputs import ( - ProgramStatistics as PEProgramStats, - ) from policyengine_api.models import Policy as DBPolicy @@ -1381,284 +997,20 @@ def build_dynamic(dynamic_id): ) pe_reform_sim.ensure() - # Calculate decile impacts - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - ) - di.run() - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": {"entity": "person", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "spm_unit", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programs.items(): - try: - ps = PEProgramStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not found in model - - # Calculate poverty rates for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_us_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_us_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - gender_poverty_results = calculate_us_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by race for baseline and reform (US only) - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - race_poverty_results = calculate_us_poverty_by_race(pe_sim) - for pov in race_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_us_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, - report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # US budget variables — household-level plus state tax at tax_unit level - us_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - "household_state_income_tax": "tax_unit", - } - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - for var_name, entity in us_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household count: bypass Aggregate and compute directly from raw numpy - # values. Using Aggregate(SUM) on household_weight would compute - # sum(weight * weight) because MicroSeries.sum() applies weights - # automatically — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) - - # Calculate intra-decile impact (5-category income change distribution) - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile_us, - ) - - intra_decile_results_us = pe_compute_intra_decile_us( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results_us.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) + # Run computation modules + from policyengine_api.api.computation_modules import US_MODULE_DISPATCH, run_modules - # Calculate congressional district impact - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, + run_modules( + dispatch=US_MODULE_DISPATCH, + modules=modules, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, + report_id=report.id, + session=session, ) - try: - district_impact = compute_us_congressional_district_impacts( - pe_baseline_sim, pe_reform_sim - ) - if district_impact.district_results: - for dr in district_impact.district_results: - record = CongressionalDistrictImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - district_geoid=dr["district_geoid"], - state_fips=dr["state_fips"], - district_number=dr["district_number"], - average_household_income_change=dr[ - "average_household_income_change" - ], - relative_household_income_change=dr[ - "relative_household_income_change" - ], - population=dr["population"], - ) - session.add(record) - except KeyError: - pass # congressional_district_geoid not in dataset - # Mark completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = datetime.now(timezone.utc) @@ -1672,9 +1024,16 @@ def build_dynamic(dynamic_id): def _trigger_economy_comparison( - job_id: str, tax_benefit_model_name: str, session: Session | None = None + job_id: str, + tax_benefit_model_name: str, + session: Session | None = None, + modules: list[str] | None = None, ) -> None: - """Trigger economy comparison analysis (local or Modal).""" + """Trigger economy comparison analysis (local or Modal). + + Args: + modules: Optional list of module names to run. If None, runs all. + """ from policyengine_api.config import settings traceparent = get_traceparent() @@ -1682,11 +1041,11 @@ def _trigger_economy_comparison( if not settings.agent_use_modal and session is not None: # Run locally if tax_benefit_model_name == "policyengine_uk": - _run_local_economy_comparison_uk(job_id, session) + _run_local_economy_comparison_uk(job_id, session, modules=modules) else: - _run_local_economy_comparison_us(job_id, session) + _run_local_economy_comparison_us(job_id, session, modules=modules) else: - # Use Modal + # Use Modal (modules param passed for future selective computation) import modal if tax_benefit_model_name == "policyengine_uk": @@ -1980,7 +1339,10 @@ def economy_custom( if report.status == ReportStatus.PENDING: with logfire.span("trigger_economy_comparison", job_id=str(report.id)): _trigger_economy_comparison( - str(report.id), request.tax_benefit_model_name, session + str(report.id), + request.tax_benefit_model_name, + session, + modules=request.modules, ) full_response = _build_response( diff --git a/src/policyengine_api/api/computation_modules.py b/src/policyengine_api/api/computation_modules.py new file mode 100644 index 0000000..0ab5b42 --- /dev/null +++ b/src/policyengine_api/api/computation_modules.py @@ -0,0 +1,778 @@ +"""Composable computation module functions for economy analysis. + +Each function computes a single module's results and writes DB records. +They share a common signature pattern: + (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, + report_id, session) -> None + +Used by _run_local_economy_comparison_uk/us to run modules selectively. +""" + +from __future__ import annotations + +from uuid import UUID + +from sqlmodel import Session + +from policyengine_api.models import ( + BudgetSummary, + CongressionalDistrictImpact, + ConstituencyImpact, + DecileImpact, + Inequality, + IntraDecileImpact, + LocalAuthorityImpact, + Poverty, + ProgramStatistics, +) + + +# --------------------------------------------------------------------------- +# Shared modules (UK + US) +# --------------------------------------------------------------------------- + + +def compute_decile_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute income decile impacts (1-10).""" + from policyengine.outputs import DecileImpact as PEDecileImpact + + for decile_num in range(1, 11): + di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + decile=decile_num, + ) + di.run() + record = DecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + income_variable=di.income_variable, + entity=di.entity, + decile=di.decile, + quantiles=di.quantiles, + baseline_mean=di.baseline_mean, + reform_mean=di.reform_mean, + absolute_change=di.absolute_change, + relative_change=di.relative_change, + count_better_off=di.count_better_off, + count_worse_off=di.count_worse_off, + count_no_change=di.count_no_change, + ) + session.add(record) + + +def compute_intra_decile_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute intra-decile income change distribution (5 bands).""" + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts as pe_compute_intra_decile, + ) + + results = pe_compute_intra_decile( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + entity="household", + ) + for r in results.outputs: + record = IntraDecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + decile=r.decile, + lose_more_than_5pct=r.lose_more_than_5pct, + lose_less_than_5pct=r.lose_less_than_5pct, + no_change=r.no_change, + gain_less_than_5pct=r.gain_less_than_5pct, + gain_more_than_5pct=r.gain_more_than_5pct, + ) + session.add(record) + + +# --------------------------------------------------------------------------- +# UK-specific modules +# --------------------------------------------------------------------------- + + +def compute_program_statistics_module_uk( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute UK programme statistics.""" + from policyengine.core import Simulation as PESimulation + from policyengine.tax_benefit_models.uk.outputs import ( + ProgrammeStatistics as PEProgrammeStats, + ) + + PEProgrammeStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + programmes = { + "income_tax": {"entity": "person", "is_tax": True}, + "national_insurance": {"entity": "person", "is_tax": True}, + "vat": {"entity": "household", "is_tax": True}, + "council_tax": {"entity": "household", "is_tax": True}, + "universal_credit": {"entity": "person", "is_tax": False}, + "child_benefit": {"entity": "person", "is_tax": False}, + "pension_credit": {"entity": "person", "is_tax": False}, + "income_support": {"entity": "person", "is_tax": False}, + "working_tax_credit": {"entity": "person", "is_tax": False}, + "child_tax_credit": {"entity": "person", "is_tax": False}, + } + for prog_name, prog_info in programmes.items(): + try: + ps = PEProgrammeStats( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + programme_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + record = ProgramStatistics( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(record) + except KeyError: + pass + + +def compute_poverty_module_uk( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute UK poverty rates (overall, by age, by gender).""" + from policyengine.outputs.poverty import ( + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + calculate_uk_poverty_rates, + ) + + sim_pairs = [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ] + + for calculator in [ + calculate_uk_poverty_rates, + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + ]: + for pe_sim, db_sim_id in sim_pairs: + results = calculator(pe_sim) + for pov in results.outputs: + record = Poverty( + simulation_id=db_sim_id, + report_id=report_id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(record) + + +def compute_inequality_module_uk( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute UK inequality metrics.""" + from policyengine.outputs.inequality import calculate_uk_inequality + + for pe_sim, db_sim_id in [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ]: + ineq = calculate_uk_inequality(pe_sim) + ineq.run() + record = Inequality( + simulation_id=db_sim_id, + report_id=report_id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + session.add(record) + + +def compute_budget_summary_module_uk( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute UK budget summary aggregates.""" + from policyengine.core import Simulation as PESimulation + from policyengine.outputs.aggregate import Aggregate as PEAggregate + from policyengine.outputs.aggregate import AggregateType as PEAggregateType + + PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) + + uk_budget_variables = { + "household_tax": "household", + "household_benefits": "household", + "household_net_income": "household", + } + for var_name, entity in uk_budget_variables.items(): + baseline_agg = PEAggregate( + simulation=pe_baseline_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + baseline_agg.run() + reform_agg = PEAggregate( + simulation=pe_reform_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + reform_agg.run() + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name=var_name, + entity=entity, + baseline_total=float(baseline_agg.result), + reform_total=float(reform_agg.result), + change=float(reform_agg.result - baseline_agg.result), + ) + session.add(record) + + # Household count: raw sum of weights (bypasses Aggregate weighting) + baseline_hh_count = float( + pe_baseline_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + reform_hh_count = float( + pe_reform_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name="household_count_total", + entity="household", + baseline_total=baseline_hh_count, + reform_total=reform_hh_count, + change=reform_hh_count - baseline_hh_count, + ) + session.add(record) + + +def compute_constituency_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute UK parliamentary constituency impact.""" + from policyengine.outputs.constituency_impact import ( + compute_uk_constituency_impacts, + ) + + try: + from policyengine_core.tools.google_cloud import download as gcs_download + + weight_matrix_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="parliamentary_constituency_weights.h5", + ) + constituency_csv_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="constituencies_2024.csv", + ) + impact = compute_uk_constituency_impacts( + pe_baseline_sim, + pe_reform_sim, + weight_matrix_path=weight_matrix_path, + constituency_csv_path=constituency_csv_path, + ) + if impact.constituency_results: + for cr in impact.constituency_results: + record = ConstituencyImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + constituency_code=cr["constituency_code"], + constituency_name=cr["constituency_name"], + x=cr["x"], + y=cr["y"], + average_household_income_change=cr[ + "average_household_income_change" + ], + relative_household_income_change=cr[ + "relative_household_income_change" + ], + population=cr["population"], + ) + session.add(record) + except Exception: + pass # Weight matrix not available + + +def compute_local_authority_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute UK local authority impact.""" + from policyengine.outputs.local_authority_impact import ( + compute_uk_local_authority_impacts, + ) + + try: + from policyengine_core.tools.google_cloud import download as gcs_download + + la_weight_matrix_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="local_authority_weights.h5", + ) + la_csv_path = gcs_download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="local_authorities_2021.csv", + ) + impact = compute_uk_local_authority_impacts( + pe_baseline_sim, + pe_reform_sim, + weight_matrix_path=la_weight_matrix_path, + local_authority_csv_path=la_csv_path, + ) + if impact.local_authority_results: + for lr in impact.local_authority_results: + record = LocalAuthorityImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + local_authority_code=lr["local_authority_code"], + local_authority_name=lr["local_authority_name"], + x=lr["x"], + y=lr["y"], + average_household_income_change=lr[ + "average_household_income_change" + ], + relative_household_income_change=lr[ + "relative_household_income_change" + ], + population=lr["population"], + ) + session.add(record) + except Exception: + pass # Weight matrix not available + + +def compute_wealth_decile_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute UK wealth decile impact and intra-wealth-decile breakdown.""" + from policyengine.core import Simulation as PESimulation + from policyengine.outputs.decile_impact import DecileImpact as PEDecileImpact + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts as pe_compute_intra_decile, + ) + + try: + PEDecileImpact.model_rebuild(_types_namespace={"Simulation": PESimulation}) + for decile_num in range(1, 11): + wealth_di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + decile_variable="household_wealth_decile", + entity="household", + decile=decile_num, + ) + wealth_di.run() + record = DecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + income_variable="household_wealth_decile", + entity="household", + decile=decile_num, + quantiles=10, + baseline_mean=wealth_di.baseline_mean, + reform_mean=wealth_di.reform_mean, + absolute_change=wealth_di.absolute_change, + relative_change=wealth_di.relative_change, + ) + session.add(record) + + # Intra-wealth-decile + intra_wealth_results = pe_compute_intra_decile( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable="household_net_income", + decile_variable="household_wealth_decile", + entity="household", + ) + for r in intra_wealth_results.outputs: + record = IntraDecileImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + decile_type="wealth", + decile=r.decile, + lose_more_than_5pct=r.lose_more_than_5pct, + lose_less_than_5pct=r.lose_less_than_5pct, + no_change=r.no_change, + gain_less_than_5pct=r.gain_less_than_5pct, + gain_more_than_5pct=r.gain_more_than_5pct, + ) + session.add(record) + except (KeyError, Exception): + pass # household_wealth_decile not available + + +# --------------------------------------------------------------------------- +# US-specific modules +# --------------------------------------------------------------------------- + + +def compute_program_statistics_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute US program statistics.""" + from policyengine.core import Simulation as PESimulation + from policyengine.tax_benefit_models.us.outputs import ( + ProgramStatistics as PEProgramStats, + ) + + PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + programs = { + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "employee_payroll_tax": {"entity": "person", "is_tax": True}, + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "spm_unit", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + } + for prog_name, prog_info in programs.items(): + try: + ps = PEProgramStats( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + record = ProgramStatistics( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(record) + except KeyError: + pass + + +def compute_poverty_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute US poverty rates (overall, by age, gender, race).""" + from policyengine.outputs.poverty import ( + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + calculate_us_poverty_rates, + ) + + sim_pairs = [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ] + + for calculator in [ + calculate_us_poverty_rates, + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + ]: + for pe_sim, db_sim_id in sim_pairs: + results = calculator(pe_sim) + for pov in results.outputs: + record = Poverty( + simulation_id=db_sim_id, + report_id=report_id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(record) + + +def compute_inequality_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute US inequality metrics.""" + from policyengine.outputs.inequality import calculate_us_inequality + + for pe_sim, db_sim_id in [ + (pe_baseline_sim, baseline_sim_id), + (pe_reform_sim, reform_sim_id), + ]: + ineq = calculate_us_inequality(pe_sim) + ineq.run() + record = Inequality( + simulation_id=db_sim_id, + report_id=report_id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + session.add(record) + + +def compute_budget_summary_module_us( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute US budget summary aggregates.""" + from policyengine.core import Simulation as PESimulation + from policyengine.outputs.aggregate import Aggregate as PEAggregate + from policyengine.outputs.aggregate import AggregateType as PEAggregateType + + PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) + + us_budget_variables = { + "household_tax": "household", + "household_benefits": "household", + "household_net_income": "household", + "household_state_income_tax": "tax_unit", + } + for var_name, entity in us_budget_variables.items(): + baseline_agg = PEAggregate( + simulation=pe_baseline_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + baseline_agg.run() + reform_agg = PEAggregate( + simulation=pe_reform_sim, + variable=var_name, + aggregate_type=PEAggregateType.SUM, + entity=entity, + ) + reform_agg.run() + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name=var_name, + entity=entity, + baseline_total=float(baseline_agg.result), + reform_total=float(reform_agg.result), + change=float(reform_agg.result - baseline_agg.result), + ) + session.add(record) + + # Household count: raw sum of weights + baseline_hh_count = float( + pe_baseline_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + reform_hh_count = float( + pe_reform_sim.output_dataset.data.household[ + "household_weight" + ].values.sum() + ) + record = BudgetSummary( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + variable_name="household_count_total", + entity="household", + baseline_total=baseline_hh_count, + reform_total=reform_hh_count, + change=reform_hh_count - baseline_hh_count, + ) + session.add(record) + + +def compute_congressional_district_module( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Compute US congressional district impact.""" + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) + + try: + impact = compute_us_congressional_district_impacts( + pe_baseline_sim, pe_reform_sim + ) + if impact.district_results: + for dr in impact.district_results: + record = CongressionalDistrictImpact( + baseline_simulation_id=baseline_sim_id, + reform_simulation_id=reform_sim_id, + report_id=report_id, + district_geoid=dr["district_geoid"], + state_fips=dr["state_fips"], + district_number=dr["district_number"], + average_household_income_change=dr[ + "average_household_income_change" + ], + relative_household_income_change=dr[ + "relative_household_income_change" + ], + population=dr["population"], + ) + session.add(record) + except KeyError: + pass # congressional_district_geoid not in dataset + + +# --------------------------------------------------------------------------- +# Dispatch tables: module name -> computation function +# --------------------------------------------------------------------------- + +# Type alias for module computation functions +ModuleFunction = type(compute_decile_module) + +UK_MODULE_DISPATCH: dict[str, ModuleFunction] = { + "decile": compute_decile_module, + "program_statistics": compute_program_statistics_module_uk, + "poverty": compute_poverty_module_uk, + "inequality": compute_inequality_module_uk, + "budget_summary": compute_budget_summary_module_uk, + "intra_decile": compute_intra_decile_module, + "constituency": compute_constituency_module, + "local_authority": compute_local_authority_module, + "wealth_decile": compute_wealth_decile_module, +} + +US_MODULE_DISPATCH: dict[str, ModuleFunction] = { + "decile": compute_decile_module, + "program_statistics": compute_program_statistics_module_us, + "poverty": compute_poverty_module_us, + "inequality": compute_inequality_module_us, + "budget_summary": compute_budget_summary_module_us, + "intra_decile": compute_intra_decile_module, + "congressional_district": compute_congressional_district_module, +} + + +def run_modules( + dispatch: dict[str, ModuleFunction], + modules: list[str] | None, + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id: UUID, + reform_sim_id: UUID, + report_id: UUID, + session: Session, +) -> None: + """Run the requested modules (or all if modules is None).""" + to_run = modules if modules is not None else list(dispatch.keys()) + for mod_name in to_run: + fn = dispatch.get(mod_name) + if fn: + fn( + pe_baseline_sim, + pe_reform_sim, + baseline_sim_id, + reform_sim_id, + report_id, + session, + ) diff --git a/tests/test_computation_modules.py b/tests/test_computation_modules.py new file mode 100644 index 0000000..70df5aa --- /dev/null +++ b/tests/test_computation_modules.py @@ -0,0 +1,107 @@ +"""Tests for the composable computation module dispatch system.""" + +from unittest.mock import MagicMock, patch + +from policyengine_api.api.computation_modules import ( + UK_MODULE_DISPATCH, + US_MODULE_DISPATCH, + run_modules, +) +from policyengine_api.api.module_registry import MODULE_REGISTRY + + +class TestDispatchTables: + """Tests for UK_MODULE_DISPATCH and US_MODULE_DISPATCH.""" + + def test_uk_dispatch_keys_match_registry(self): + """Every UK dispatch key should be a valid module in the registry.""" + for key in UK_MODULE_DISPATCH: + assert key in MODULE_REGISTRY, f"UK dispatch key {key!r} not in registry" + + def test_us_dispatch_keys_match_registry(self): + """Every US dispatch key should be a valid module in the registry.""" + for key in US_MODULE_DISPATCH: + assert key in MODULE_REGISTRY, f"US dispatch key {key!r} not in registry" + + def test_uk_dispatch_covers_uk_modules(self): + """UK dispatch should have an entry for every UK-applicable module.""" + uk_module_names = { + name + for name, mod in MODULE_REGISTRY.items() + if "uk" in mod.countries + } + assert set(UK_MODULE_DISPATCH.keys()) == uk_module_names + + def test_us_dispatch_covers_us_modules(self): + """US dispatch should have an entry for every US-applicable module.""" + us_module_names = { + name + for name, mod in MODULE_REGISTRY.items() + if "us" in mod.countries + } + assert set(US_MODULE_DISPATCH.keys()) == us_module_names + + def test_all_dispatch_values_are_callable(self): + for fn in UK_MODULE_DISPATCH.values(): + assert callable(fn) + for fn in US_MODULE_DISPATCH.values(): + assert callable(fn) + + +class TestRunModules: + """Tests for the run_modules dispatch helper.""" + + def _make_mock_dispatch(self, names): + """Create a dispatch dict with mock functions.""" + return {name: MagicMock(name=f"compute_{name}") for name in names} + + def test_runs_all_when_modules_is_none(self): + dispatch = self._make_mock_dispatch(["a", "b", "c"]) + session = MagicMock() + from uuid import uuid4 + + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session) + + for fn in dispatch.values(): + fn.assert_called_once_with("bl", "rf", ids[0], ids[1], ids[2], session) + + def test_runs_only_requested_modules(self): + dispatch = self._make_mock_dispatch(["a", "b", "c"]) + session = MagicMock() + from uuid import uuid4 + + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) + + dispatch["a"].assert_not_called() + dispatch["b"].assert_called_once() + dispatch["c"].assert_not_called() + + def test_ignores_unknown_module_names(self): + dispatch = self._make_mock_dispatch(["a"]) + session = MagicMock() + from uuid import uuid4 + + ids = [uuid4() for _ in range(3)] + + # Should not raise + run_modules( + dispatch, ["a", "nonexistent"], "bl", "rf", ids[0], ids[1], ids[2], session + ) + + dispatch["a"].assert_called_once() + + def test_empty_modules_list_runs_nothing(self): + dispatch = self._make_mock_dispatch(["a", "b"]) + session = MagicMock() + from uuid import uuid4 + + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, [], "bl", "rf", ids[0], ids[1], ids[2], session) + + for fn in dispatch.values(): + fn.assert_not_called() From 21b93b1bd1ee2a76cfa00262ae97f4d424f7101d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 23 Feb 2026 11:41:07 +0100 Subject: [PATCH 5/5] test: Expand Phase 4 unit tests from 43 to 119 Add comprehensive test coverage for module registry, analysis options, economy-custom endpoint, and computation module dispatch system. Includes lint/format fixes from ruff. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 22 +- .../api/computation_modules.py | 17 +- src/policyengine_api/api/module_registry.py | 4 +- tests/test_analysis_options.py | 76 ++++++- tests/test_computation_modules.py | 205 +++++++++++++++-- tests/test_economy_custom.py | 207 ++++++++++++++++-- tests/test_module_registry.py | 150 ++++++++++++- 7 files changed, 612 insertions(+), 69 deletions(-) diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 6d35a61..ce77c78 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -25,6 +25,11 @@ from pydantic import BaseModel, Field from sqlmodel import Session, select +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + get_modules_for_country, + validate_modules, +) from policyengine_api.models import ( BudgetSummary, BudgetSummaryRead, @@ -32,8 +37,6 @@ CongressionalDistrictImpactRead, ConstituencyImpact, ConstituencyImpactRead, - LocalAuthorityImpact, - LocalAuthorityImpactRead, Dataset, DecileImpact, DecileImpactRead, @@ -41,6 +44,8 @@ InequalityRead, IntraDecileImpact, IntraDecileImpactRead, + LocalAuthorityImpact, + LocalAuthorityImpactRead, Poverty, PovertyRead, ProgramStatistics, @@ -54,11 +59,6 @@ TaxBenefitModel, TaxBenefitModelVersion, ) -from policyengine_api.api.module_registry import ( - MODULE_REGISTRY, - get_modules_for_country, - validate_modules, -) from policyengine_api.services.database import get_session @@ -477,9 +477,7 @@ def _build_response( # Fetch intra-decile impact records for this report intra_rows = session.exec( - select(IntraDecileImpact).where( - IntraDecileImpact.report_id == report.id - ) + select(IntraDecileImpact).where(IntraDecileImpact.report_id == report.id) ).all() intra_decile_records = [ IntraDecileImpactRead( @@ -528,9 +526,7 @@ def _build_response( # Fetch constituency impact records for this report constituency_rows = session.exec( - select(ConstituencyImpact).where( - ConstituencyImpact.report_id == report.id - ) + select(ConstituencyImpact).where(ConstituencyImpact.report_id == report.id) ).all() if constituency_rows: constituency_impact_records = [ diff --git a/src/policyengine_api/api/computation_modules.py b/src/policyengine_api/api/computation_modules.py index 0ab5b42..3795ae9 100644 --- a/src/policyengine_api/api/computation_modules.py +++ b/src/policyengine_api/api/computation_modules.py @@ -26,7 +26,6 @@ ProgramStatistics, ) - # --------------------------------------------------------------------------- # Shared modules (UK + US) # --------------------------------------------------------------------------- @@ -285,14 +284,10 @@ def compute_budget_summary_module_uk( # Household count: raw sum of weights (bypasses Aggregate weighting) baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() + pe_baseline_sim.output_dataset.data.household["household_weight"].values.sum() ) reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() + pe_reform_sim.output_dataset.data.household["household_weight"].values.sum() ) record = BudgetSummary( baseline_simulation_id=baseline_sim_id, @@ -662,14 +657,10 @@ def compute_budget_summary_module_us( # Household count: raw sum of weights baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() + pe_baseline_sim.output_dataset.data.household["household_weight"].values.sum() ) reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() + pe_reform_sim.output_dataset.data.household["household_weight"].values.sum() ) record = BudgetSummary( baseline_simulation_id=baseline_sim_id, diff --git a/src/policyengine_api/api/module_registry.py b/src/policyengine_api/api/module_registry.py index db3d4d7..3cd5c6c 100644 --- a/src/policyengine_api/api/module_registry.py +++ b/src/policyengine_api/api/module_registry.py @@ -120,9 +120,7 @@ def validate_modules(names: list[str], country: str) -> list[str]: if name not in MODULE_REGISTRY: errors.append(f"Unknown module: {name!r}") elif name not in available: - errors.append( - f"Module {name!r} is not available for country {country!r}" - ) + errors.append(f"Module {name!r} is not available for country {country!r}") if errors: raise ValueError("; ".join(errors)) return names diff --git a/tests/test_analysis_options.py b/tests/test_analysis_options.py index 0dee9d3..8df673f 100644 --- a/tests/test_analysis_options.py +++ b/tests/test_analysis_options.py @@ -1,6 +1,6 @@ """Tests for GET /analysis/options endpoint.""" -from policyengine_api.api.module_registry import MODULE_REGISTRY +from policyengine_api.api.module_registry import MODULE_REGISTRY, get_modules_for_country class TestAnalysisOptions: @@ -22,6 +22,31 @@ def test_response_shape(self, client): assert "response_fields" in item assert isinstance(item["response_fields"], list) + def test_all_names_are_strings(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert isinstance(item["name"], str) + assert len(item["name"]) > 0 + + def test_all_labels_are_non_empty(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert isinstance(item["label"], str) + assert len(item["label"]) > 0 + + def test_all_descriptions_are_non_empty(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert isinstance(item["description"], str) + assert len(item["description"]) > 0 + + def test_all_response_fields_are_non_empty_lists(self, client): + response = client.get("/analysis/options") + for item in response.json(): + assert len(item["response_fields"]) > 0 + for field in item["response_fields"]: + assert isinstance(field, str) + def test_filter_by_uk(self, client): response = client.get("/analysis/options?country=uk") assert response.status_code == 200 @@ -42,12 +67,31 @@ def test_filter_by_us(self, client): assert "local_authority" not in names assert "wealth_decile" not in names + def test_uk_count_matches_registry(self, client): + response = client.get("/analysis/options?country=uk") + data = response.json() + expected = len(get_modules_for_country("uk")) + assert len(data) == expected + + def test_us_count_matches_registry(self, client): + response = client.get("/analysis/options?country=us") + data = response.json() + expected = len(get_modules_for_country("us")) + assert len(data) == expected + def test_shared_modules_in_both_countries(self, client): uk_resp = client.get("/analysis/options?country=uk") us_resp = client.get("/analysis/options?country=us") uk_names = {m["name"] for m in uk_resp.json()} us_names = {m["name"] for m in us_resp.json()} - for shared in ["decile", "poverty", "inequality", "budget_summary", "intra_decile"]: + for shared in [ + "decile", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "program_statistics", + ]: assert shared in uk_names assert shared in us_names @@ -55,3 +99,31 @@ def test_unknown_country_returns_empty(self, client): response = client.get("/analysis/options?country=fr") assert response.status_code == 200 assert response.json() == [] + + def test_program_statistics_has_two_response_fields(self, client): + response = client.get("/analysis/options") + ps_module = next( + m for m in response.json() if m["name"] == "program_statistics" + ) + assert "program_statistics" in ps_module["response_fields"] + assert "detailed_budget" in ps_module["response_fields"] + + def test_wealth_decile_has_two_response_fields(self, client): + response = client.get("/analysis/options?country=uk") + wd_module = next(m for m in response.json() if m["name"] == "wealth_decile") + assert "wealth_decile" in wd_module["response_fields"] + assert "intra_wealth_decile" in wd_module["response_fields"] + + def test_no_country_param_returns_all(self, client): + all_resp = client.get("/analysis/options") + data = all_resp.json() + returned_names = {m["name"] for m in data} + assert returned_names == set(MODULE_REGISTRY.keys()) + + def test_response_matches_registry_data(self, client): + response = client.get("/analysis/options") + for item in response.json(): + registry_mod = MODULE_REGISTRY[item["name"]] + assert item["label"] == registry_mod.label + assert item["description"] == registry_mod.description + assert item["response_fields"] == list(registry_mod.response_fields) diff --git a/tests/test_computation_modules.py b/tests/test_computation_modules.py index 70df5aa..5329d39 100644 --- a/tests/test_computation_modules.py +++ b/tests/test_computation_modules.py @@ -1,7 +1,10 @@ """Tests for the composable computation module dispatch system.""" -from unittest.mock import MagicMock, patch +import inspect +from unittest.mock import MagicMock, call +from uuid import uuid4 +from policyengine_api.api import computation_modules as cm from policyengine_api.api.computation_modules import ( UK_MODULE_DISPATCH, US_MODULE_DISPATCH, @@ -26,18 +29,14 @@ def test_us_dispatch_keys_match_registry(self): def test_uk_dispatch_covers_uk_modules(self): """UK dispatch should have an entry for every UK-applicable module.""" uk_module_names = { - name - for name, mod in MODULE_REGISTRY.items() - if "uk" in mod.countries + name for name, mod in MODULE_REGISTRY.items() if "uk" in mod.countries } assert set(UK_MODULE_DISPATCH.keys()) == uk_module_names def test_us_dispatch_covers_us_modules(self): """US dispatch should have an entry for every US-applicable module.""" us_module_names = { - name - for name, mod in MODULE_REGISTRY.items() - if "us" in mod.countries + name for name, mod in MODULE_REGISTRY.items() if "us" in mod.countries } assert set(US_MODULE_DISPATCH.keys()) == us_module_names @@ -47,6 +46,133 @@ def test_all_dispatch_values_are_callable(self): for fn in US_MODULE_DISPATCH.values(): assert callable(fn) + def test_uk_dispatch_has_9_entries(self): + assert len(UK_MODULE_DISPATCH) == 9 + + def test_us_dispatch_has_7_entries(self): + assert len(US_MODULE_DISPATCH) == 7 + + +class TestSharedModuleFunctions: + """Tests that shared modules reference the same function objects.""" + + def test_decile_function_shared_between_uk_and_us(self): + assert UK_MODULE_DISPATCH["decile"] is US_MODULE_DISPATCH["decile"] + assert UK_MODULE_DISPATCH["decile"] is cm.compute_decile_module + + def test_intra_decile_function_shared_between_uk_and_us(self): + assert UK_MODULE_DISPATCH["intra_decile"] is US_MODULE_DISPATCH["intra_decile"] + assert UK_MODULE_DISPATCH["intra_decile"] is cm.compute_intra_decile_module + + +class TestCountrySpecificFunctions: + """Tests that UK/US specific modules use the correct country-specific functions.""" + + def test_uk_program_statistics(self): + assert ( + UK_MODULE_DISPATCH["program_statistics"] + is cm.compute_program_statistics_module_uk + ) + + def test_us_program_statistics(self): + assert ( + US_MODULE_DISPATCH["program_statistics"] + is cm.compute_program_statistics_module_us + ) + + def test_uk_poverty(self): + assert UK_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_uk + + def test_us_poverty(self): + assert US_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_us + + def test_uk_inequality(self): + assert UK_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_uk + + def test_us_inequality(self): + assert US_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_us + + def test_uk_budget_summary(self): + assert ( + UK_MODULE_DISPATCH["budget_summary"] + is cm.compute_budget_summary_module_uk + ) + + def test_us_budget_summary(self): + assert ( + US_MODULE_DISPATCH["budget_summary"] + is cm.compute_budget_summary_module_us + ) + + def test_constituency_is_uk_only(self): + assert UK_MODULE_DISPATCH["constituency"] is cm.compute_constituency_module + assert "constituency" not in US_MODULE_DISPATCH + + def test_local_authority_is_uk_only(self): + assert ( + UK_MODULE_DISPATCH["local_authority"] is cm.compute_local_authority_module + ) + assert "local_authority" not in US_MODULE_DISPATCH + + def test_wealth_decile_is_uk_only(self): + assert UK_MODULE_DISPATCH["wealth_decile"] is cm.compute_wealth_decile_module + assert "wealth_decile" not in US_MODULE_DISPATCH + + def test_congressional_district_is_us_only(self): + assert ( + US_MODULE_DISPATCH["congressional_district"] + is cm.compute_congressional_district_module + ) + assert "congressional_district" not in UK_MODULE_DISPATCH + + +class TestModuleFunctionSignatures: + """Tests that all module functions share the expected 6-param signature.""" + + _EXPECTED_PARAMS = [ + "pe_baseline_sim", + "pe_reform_sim", + "baseline_sim_id", + "reform_sim_id", + "report_id", + "session", + ] + + def _get_all_unique_functions(self): + """Collect all unique module functions from both dispatch tables.""" + seen = set() + fns = [] + for fn in list(UK_MODULE_DISPATCH.values()) + list( + US_MODULE_DISPATCH.values() + ): + if id(fn) not in seen: + seen.add(id(fn)) + fns.append(fn) + return fns + + def test_all_functions_have_6_parameters(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + assert len(sig.parameters) == 6, ( + f"{fn.__name__} has {len(sig.parameters)} params, expected 6" + ) + + def test_all_functions_have_expected_param_names(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + param_names = list(sig.parameters.keys()) + assert param_names == self._EXPECTED_PARAMS, ( + f"{fn.__name__} params {param_names} != {self._EXPECTED_PARAMS}" + ) + + def test_all_functions_return_none(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + # `from __future__ import annotations` makes annotations strings + assert sig.return_annotation in (None, "None", inspect.Parameter.empty), ( + f"{fn.__name__} return annotation is {sig.return_annotation!r}, expected None" + ) + class TestRunModules: """Tests for the run_modules dispatch helper.""" @@ -58,8 +184,6 @@ def _make_mock_dispatch(self, names): def test_runs_all_when_modules_is_none(self): dispatch = self._make_mock_dispatch(["a", "b", "c"]) session = MagicMock() - from uuid import uuid4 - ids = [uuid4() for _ in range(3)] run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session) @@ -70,8 +194,6 @@ def test_runs_all_when_modules_is_none(self): def test_runs_only_requested_modules(self): dispatch = self._make_mock_dispatch(["a", "b", "c"]) session = MagicMock() - from uuid import uuid4 - ids = [uuid4() for _ in range(3)] run_modules(dispatch, ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) @@ -83,8 +205,6 @@ def test_runs_only_requested_modules(self): def test_ignores_unknown_module_names(self): dispatch = self._make_mock_dispatch(["a"]) session = MagicMock() - from uuid import uuid4 - ids = [uuid4() for _ in range(3)] # Should not raise @@ -97,11 +217,66 @@ def test_ignores_unknown_module_names(self): def test_empty_modules_list_runs_nothing(self): dispatch = self._make_mock_dispatch(["a", "b"]) session = MagicMock() - from uuid import uuid4 - ids = [uuid4() for _ in range(3)] run_modules(dispatch, [], "bl", "rf", ids[0], ids[1], ids[2], session) for fn in dispatch.values(): fn.assert_not_called() + + def test_preserves_call_order(self): + """Modules should be called in the order they appear in the modules list.""" + call_order = [] + + def make_tracker(name): + def fn(*args): + call_order.append(name) + + return fn + + dispatch = {name: make_tracker(name) for name in ["a", "b", "c"]} + ids = [uuid4() for _ in range(3)] + + run_modules( + dispatch, ["c", "a", "b"], "bl", "rf", ids[0], ids[1], ids[2], MagicMock() + ) + + assert call_order == ["c", "a", "b"] + + def test_none_modules_runs_all_in_dispatch_key_order(self): + """When modules is None, all dispatch entries run in dict-iteration order.""" + call_order = [] + + def make_tracker(name): + def fn(*args): + call_order.append(name) + + return fn + + dispatch = {name: make_tracker(name) for name in ["x", "y", "z"]} + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], MagicMock()) + + assert call_order == ["x", "y", "z"] + + def test_passes_all_args_correctly(self): + mock_fn = MagicMock() + dispatch = {"test_mod": mock_fn} + session = MagicMock() + bl, rf, b_id, r_id, rep_id = "baseline", "reform", uuid4(), uuid4(), uuid4() + + run_modules(dispatch, ["test_mod"], bl, rf, b_id, r_id, rep_id, session) + + mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session) + + def test_duplicate_module_name_runs_twice(self): + dispatch = self._make_mock_dispatch(["a"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules( + dispatch, ["a", "a"], "bl", "rf", ids[0], ids[1], ids[2], session + ) + + assert dispatch["a"].call_count == 2 diff --git a/tests/test_economy_custom.py b/tests/test_economy_custom.py index f992c53..f96002a 100644 --- a/tests/test_economy_custom.py +++ b/tests/test_economy_custom.py @@ -1,7 +1,6 @@ """Tests for POST /analysis/economy-custom endpoint.""" -import pytest -from unittest.mock import patch +from uuid import uuid4 from policyengine_api.api.analysis import ( EconomicImpactResponse, @@ -10,25 +9,20 @@ ) from policyengine_api.models import ReportStatus, SimulationStatus - # --------------------------------------------------------------------------- -# Unit tests for _build_filtered_response +# Helpers # --------------------------------------------------------------------------- def _make_stub_response(**overrides) -> EconomicImpactResponse: """Build a minimal EconomicImpactResponse for testing.""" - from uuid import uuid4 - defaults = dict( report_id=uuid4(), status=ReportStatus.COMPLETED, baseline_simulation=SimulationInfo( id=uuid4(), status=SimulationStatus.COMPLETED ), - reform_simulation=SimulationInfo( - id=uuid4(), status=SimulationStatus.COMPLETED - ), + reform_simulation=SimulationInfo(id=uuid4(), status=SimulationStatus.COMPLETED), region=None, error_message=None, decile_impacts=[{"fake": "decile"}], @@ -48,6 +42,37 @@ def _make_stub_response(**overrides) -> EconomicImpactResponse: return EconomicImpactResponse.model_construct(**defaults) +# All data fields that can be nullified by module filtering +_DATA_FIELDS = { + "decile_impacts", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "detailed_budget", + "congressional_district_impact", + "constituency_impact", + "local_authority_impact", + "wealth_decile", + "intra_wealth_decile", +} + +_ALWAYS_INCLUDED = { + "report_id", + "status", + "baseline_simulation", + "reform_simulation", + "region", + "error_message", +} + + +# --------------------------------------------------------------------------- +# Unit tests for _build_filtered_response +# --------------------------------------------------------------------------- + + class TestBuildFilteredResponse: """Tests for response filtering by module list.""" @@ -82,14 +107,106 @@ def test_always_included_fields_preserved(self): assert filtered.baseline_simulation is not None assert filtered.reform_simulation is not None + def test_region_always_included(self): + from policyengine_api.api.analysis import RegionInfo + + region = RegionInfo( + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + resp = _make_stub_response(region=region) + filtered = _build_filtered_response(resp, ["decile"]) + assert filtered.region is not None + assert filtered.region.code == "uk" + + def test_error_message_always_included(self): + resp = _make_stub_response(error_message="something went wrong") + filtered = _build_filtered_response(resp, ["decile"]) + assert filtered.error_message == "something went wrong" + def test_empty_modules_nullifies_all_data_fields(self): resp = _make_stub_response() filtered = _build_filtered_response(resp, []) - assert filtered.decile_impacts is None - assert filtered.poverty is None - assert filtered.inequality is None + for field in _DATA_FIELDS: + assert getattr(filtered, field) is None, f"{field} should be None" assert filtered.report_id == resp.report_id + def test_empty_modules_preserves_always_included(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, []) + for field in _ALWAYS_INCLUDED: + original = getattr(resp, field) + assert getattr(filtered, field) == original, ( + f"{field} should be preserved" + ) + + def test_wealth_decile_keeps_both_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["wealth_decile"]) + assert filtered.wealth_decile is not None + assert filtered.intra_wealth_decile is not None + assert filtered.decile_impacts is None + assert filtered.intra_decile is None + + def test_intra_decile_keeps_only_intra_decile(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["intra_decile"]) + assert filtered.intra_decile is not None + assert filtered.decile_impacts is None + assert filtered.intra_wealth_decile is None + + def test_congressional_district_keeps_only_district_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["congressional_district"]) + assert filtered.congressional_district_impact is not None + assert filtered.constituency_impact is None + assert filtered.local_authority_impact is None + + def test_constituency_keeps_only_constituency_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["constituency"]) + assert filtered.constituency_impact is not None + assert filtered.congressional_district_impact is None + assert filtered.local_authority_impact is None + + def test_local_authority_keeps_only_la_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["local_authority"]) + assert filtered.local_authority_impact is not None + assert filtered.constituency_impact is None + assert filtered.congressional_district_impact is None + + def test_budget_summary_keeps_only_budget(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["budget_summary"]) + assert filtered.budget_summary is not None + assert filtered.decile_impacts is None + assert filtered.program_statistics is None + + def test_unknown_module_in_list_is_gracefully_ignored(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty", "nonexistent_module"]) + assert filtered.poverty is not None + assert filtered.decile_impacts is None + + def test_all_modules_keeps_all_data_fields(self): + from policyengine_api.api.module_registry import MODULE_REGISTRY + + resp = _make_stub_response() + all_names = list(MODULE_REGISTRY.keys()) + filtered = _build_filtered_response(resp, all_names) + for field in _DATA_FIELDS: + assert getattr(filtered, field) is not None, ( + f"{field} should be preserved when all modules selected" + ) + + def test_returns_economic_impact_response_instance(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["decile"]) + assert isinstance(filtered, EconomicImpactResponse) + # --------------------------------------------------------------------------- # Integration tests for the endpoint itself @@ -123,9 +240,21 @@ def test_wrong_country_module_returns_422(self, client): assert response.status_code == 422 assert "not available for country" in response.json()["detail"] - def test_empty_modules_list_returns_422(self, client): - """An empty modules list should still be accepted (no validation error).""" - # This will fail on dataset resolution (404), not module validation + def test_multiple_errors_in_module_validation(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["nonexistent", "constituency"], + }, + ) + assert response.status_code == 422 + detail = response.json()["detail"] + assert "Unknown module" in detail + assert "not available for country" in detail + + def test_empty_modules_list_passes_validation(self, client): response = client.post( "/analysis/economy-custom", json={ @@ -134,19 +263,55 @@ def test_empty_modules_list_returns_422(self, client): "modules": [], }, ) - # Empty list passes validation, so the error should be about + # Empty list passes module validation, so the error should be about # dataset/region resolution, not about modules - assert response.status_code != 422 or "module" not in response.json().get( - "detail", "" - ).lower() + assert ( + response.status_code != 422 + or "module" not in response.json().get("detail", "").lower() + ) + + def test_valid_modules_but_missing_region_returns_404(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["decile", "poverty"], + }, + ) + # Passes validation but region "us" is not in the DB -> 404 + assert response.status_code == 404 + + def test_missing_modules_field_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + }, + ) + assert response.status_code == 422 + + def test_invalid_model_name_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "invalid_model", + "region": "us", + "modules": ["decile"], + }, + ) + assert response.status_code == 422 class TestEconomyCustomPolling: """Tests for GET /analysis/economy-custom/{report_id}.""" def test_not_found(self, client): - from uuid import uuid4 - fake_id = uuid4() response = client.get(f"/analysis/economy-custom/{fake_id}") assert response.status_code == 404 + + def test_invalid_uuid_returns_422(self, client): + response = client.get("/analysis/economy-custom/not-a-uuid") + assert response.status_code == 422 diff --git a/tests/test_module_registry.py b/tests/test_module_registry.py index d8f18f2..b3f8559 100644 --- a/tests/test_module_registry.py +++ b/tests/test_module_registry.py @@ -1,7 +1,10 @@ """Tests for the economy analysis module registry.""" +from dataclasses import FrozenInstanceError + import pytest +from policyengine_api.api.analysis import EconomicImpactResponse from policyengine_api.api.module_registry import ( MODULE_REGISTRY, ComputationModule, @@ -17,6 +20,9 @@ class TestModuleRegistry: def test_registry_is_not_empty(self): assert len(MODULE_REGISTRY) > 0 + def test_registry_has_exactly_10_modules(self): + assert len(MODULE_REGISTRY) == 10 + def test_all_entries_are_computation_modules(self): for name, module in MODULE_REGISTRY.items(): assert isinstance(module, ComputationModule) @@ -32,6 +38,16 @@ def test_all_modules_have_response_fields(self): for module in MODULE_REGISTRY.values(): assert len(module.response_fields) > 0 + def test_all_modules_have_non_empty_label(self): + for name, module in MODULE_REGISTRY.items(): + assert module.label, f"Module {name!r} has empty label" + assert len(module.label) > 0 + + def test_all_modules_have_non_empty_description(self): + for name, module in MODULE_REGISTRY.items(): + assert module.description, f"Module {name!r} has empty description" + assert len(module.description) > 0 + def test_expected_modules_exist(self): expected = [ "decile", @@ -48,6 +64,93 @@ def test_expected_modules_exist(self): for name in expected: assert name in MODULE_REGISTRY, f"Missing module: {name}" + def test_no_unexpected_modules(self): + expected = { + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "congressional_district", + "constituency", + "local_authority", + "wealth_decile", + } + assert set(MODULE_REGISTRY.keys()) == expected + + +class TestComputationModuleFrozen: + """Tests that ComputationModule instances are immutable.""" + + def test_cannot_mutate_name(self): + module = MODULE_REGISTRY["decile"] + with pytest.raises(FrozenInstanceError): + module.name = "changed" + + def test_cannot_mutate_countries(self): + module = MODULE_REGISTRY["decile"] + with pytest.raises(FrozenInstanceError): + module.countries = ["fr"] + + def test_cannot_mutate_response_fields(self): + module = MODULE_REGISTRY["poverty"] + with pytest.raises(FrozenInstanceError): + module.response_fields = ["something_else"] + + +class TestResponseFieldsMapping: + """Tests that each module's response_fields reference valid EconomicImpactResponse fields.""" + + def test_all_response_fields_exist_on_response_model(self): + valid_fields = set(EconomicImpactResponse.model_fields.keys()) + for name, module in MODULE_REGISTRY.items(): + for field in module.response_fields: + assert field in valid_fields, ( + f"Module {name!r} references response field {field!r} " + f"which does not exist on EconomicImpactResponse" + ) + + def test_decile_response_fields(self): + assert MODULE_REGISTRY["decile"].response_fields == ["decile_impacts"] + + def test_program_statistics_includes_detailed_budget(self): + fields = MODULE_REGISTRY["program_statistics"].response_fields + assert "program_statistics" in fields + assert "detailed_budget" in fields + + def test_poverty_response_fields(self): + assert MODULE_REGISTRY["poverty"].response_fields == ["poverty"] + + def test_inequality_response_fields(self): + assert MODULE_REGISTRY["inequality"].response_fields == ["inequality"] + + def test_budget_summary_response_fields(self): + assert MODULE_REGISTRY["budget_summary"].response_fields == ["budget_summary"] + + def test_intra_decile_response_fields(self): + assert MODULE_REGISTRY["intra_decile"].response_fields == ["intra_decile"] + + def test_congressional_district_response_fields(self): + assert MODULE_REGISTRY["congressional_district"].response_fields == [ + "congressional_district_impact" + ] + + def test_constituency_response_fields(self): + assert MODULE_REGISTRY["constituency"].response_fields == [ + "constituency_impact" + ] + + def test_local_authority_response_fields(self): + assert MODULE_REGISTRY["local_authority"].response_fields == [ + "local_authority_impact" + ] + + def test_wealth_decile_includes_both_fields(self): + fields = MODULE_REGISTRY["wealth_decile"].response_fields + assert "wealth_decile" in fields + assert "intra_wealth_decile" in fields + class TestCountryApplicability: """Tests for country-specific module availability.""" @@ -63,8 +166,14 @@ def test_uk_only_modules(self): assert "us" not in module.countries def test_shared_modules(self): - shared = ["decile", "program_statistics", "poverty", "inequality", - "budget_summary", "intra_decile"] + shared = [ + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + ] for name in shared: module = MODULE_REGISTRY[name] assert "uk" in module.countries @@ -86,6 +195,10 @@ def test_uk_excludes_congressional_district(self): names = [m.name for m in uk_modules] assert "congressional_district" not in names + def test_uk_has_9_modules(self): + uk_modules = get_modules_for_country("uk") + assert len(uk_modules) == 9 + def test_us_includes_congressional_district(self): us_modules = get_modules_for_country("us") names = [m.name for m in us_modules] @@ -98,9 +211,17 @@ def test_us_excludes_uk_only(self): assert "local_authority" not in names assert "wealth_decile" not in names + def test_us_has_7_modules(self): + us_modules = get_modules_for_country("us") + assert len(us_modules) == 7 + def test_unknown_country_returns_empty(self): assert get_modules_for_country("fr") == [] + def test_returns_computation_module_instances(self): + for m in get_modules_for_country("uk"): + assert isinstance(m, ComputationModule) + class TestGetAllModuleNames: """Tests for get_all_module_names().""" @@ -109,6 +230,12 @@ def test_returns_all_names(self): names = get_all_module_names() assert set(names) == set(MODULE_REGISTRY.keys()) + def test_returns_list_of_strings(self): + names = get_all_module_names() + assert isinstance(names, list) + for name in names: + assert isinstance(name, str) + class TestValidateModules: """Tests for validate_modules().""" @@ -121,6 +248,20 @@ def test_valid_uk_modules(self): result = validate_modules(["constituency", "wealth_decile"], "uk") assert result == ["constituency", "wealth_decile"] + def test_empty_list_passes_validation(self): + result = validate_modules([], "us") + assert result == [] + + def test_all_us_modules_pass_validation(self): + us_names = [m.name for m in get_modules_for_country("us")] + result = validate_modules(us_names, "us") + assert result == us_names + + def test_all_uk_modules_pass_validation(self): + uk_names = [m.name for m in get_modules_for_country("uk")] + result = validate_modules(uk_names, "uk") + assert result == uk_names + def test_unknown_module_raises(self): with pytest.raises(ValueError, match="Unknown module"): validate_modules(["nonexistent"], "us") @@ -132,3 +273,8 @@ def test_wrong_country_raises(self): def test_multiple_errors_combined(self): with pytest.raises(ValueError, match="Unknown module.*not available"): validate_modules(["nonexistent", "constituency"], "us") + + def test_returns_original_list_on_success(self): + names = ["poverty", "decile", "inequality"] + result = validate_modules(names, "us") + assert result is names