diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 76fb30f..ce77c78 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -25,6 +25,11 @@ from pydantic import BaseModel, Field from sqlmodel import Session, select +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + get_modules_for_country, + validate_modules, +) from policyengine_api.models import ( BudgetSummary, BudgetSummaryRead, @@ -32,8 +37,6 @@ CongressionalDistrictImpactRead, ConstituencyImpact, ConstituencyImpactRead, - LocalAuthorityImpact, - LocalAuthorityImpactRead, Dataset, DecileImpact, DecileImpactRead, @@ -41,6 +44,8 @@ InequalityRead, IntraDecileImpact, IntraDecileImpactRead, + LocalAuthorityImpact, + LocalAuthorityImpactRead, Poverty, PovertyRead, ProgramStatistics, @@ -80,6 +85,45 @@ def _safe_float(value: float | None) -> float | None: router = APIRouter(prefix="/analysis", tags=["analysis"]) +# --------------------------------------------------------------------------- +# GET /analysis/options — list available computation modules +# --------------------------------------------------------------------------- + + +class ModuleOption(BaseModel): + """A single computation module available for economy analysis.""" + + name: str + label: str + description: str + response_fields: list[str] + + +@router.get("/options", response_model=list[ModuleOption]) +def list_analysis_options( + country: str | None = None, +) -> list[ModuleOption]: + """List available economy analysis modules. + + Args: + country: Optional country code ('uk' or 'us') to filter modules. 
+ """ + if country: + modules = get_modules_for_country(country) + else: + modules = list(MODULE_REGISTRY.values()) + + return [ + ModuleOption( + name=m.name, + label=m.label, + description=m.description, + response_fields=list(m.response_fields), + ) + for m in modules + ] + + class EconomicImpactRequest(BaseModel): """Request body for economic impact analysis. @@ -433,9 +477,7 @@ def _build_response( # Fetch intra-decile impact records for this report intra_rows = session.exec( - select(IntraDecileImpact).where( - IntraDecileImpact.report_id == report.id - ) + select(IntraDecileImpact).where(IntraDecileImpact.report_id == report.id) ).all() intra_decile_records = [ IntraDecileImpactRead( @@ -484,9 +526,7 @@ def _build_response( # Fetch constituency impact records for this report constituency_rows = session.exec( - select(ConstituencyImpact).where( - ConstituencyImpact.report_id == report.id - ) + select(ConstituencyImpact).where(ConstituencyImpact.report_id == report.id) ).all() if constituency_rows: constituency_impact_records = [ @@ -661,7 +701,9 @@ def _download_dataset_local(filepath: str) -> str: return str(cache_path) -def _run_local_economy_comparison_uk(job_id: str, session: Session) -> None: +def _run_local_economy_comparison_uk( + job_id: str, session: Session, modules: list[str] | None = None +) -> None: """Run UK economy comparison analysis locally.""" from datetime import datetime, timezone from uuid import UUID @@ -670,20 +712,8 @@ def _run_local_economy_comparison_uk(job_id: str, session: Session) -> None: from policyengine.core.dynamic import Dynamic as PEDynamic from policyengine.core.policy import ParameterValue as PEParameterValue from policyengine.core.policy import Policy as PEPolicy - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - from policyengine.outputs.inequality import 
calculate_uk_inequality - from policyengine.outputs.poverty import ( - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - calculate_uk_poverty_rates, - ) from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, - ) from policyengine_api.models import Policy as DBPolicy @@ -804,383 +834,20 @@ def build_dynamic(dynamic_id): ) pe_reform_sim.ensure() - # Calculate decile impacts - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - ) - di.run() - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - PEProgrammeStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "vat": {"entity": "household", "is_tax": True}, - "council_tax": {"entity": "household", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - "pension_credit": {"entity": "person", "is_tax": False}, - "income_support": {"entity": "person", "is_tax": False}, - "working_tax_credit": {"entity": "person", "is_tax": False}, - "child_tax_credit": {"entity": "person", "is_tax": False}, - } - for 
prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not found in model - - # Calculate poverty rates for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_uk_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_uk_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, 
reform_sim), - ]: - gender_poverty_results = calculate_uk_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_uk_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, - report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # UK budget variables — household-level aggregates for fiscal totals - uk_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - } - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - for var_name, entity in uk_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household 
count: bypass Aggregate and compute directly from raw numpy - # values. Using Aggregate(SUM) on household_weight would compute - # sum(weight * weight) because MicroSeries.sum() applies weights - # automatically — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) - - # Calculate intra-decile impact (5-category income change distribution) - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile, - ) - - intra_decile_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - - # Calculate constituency impact (UK only, requires weight matrix) - from policyengine.outputs.constituency_impact import ( - compute_uk_constituency_impacts, - ) - - try: - from policyengine_core.tools.google_cloud import download as gcs_download - - weight_matrix_path = gcs_download( - 
gcs_bucket="policyengine-uk-data-private", - gcs_key="parliamentary_constituency_weights.h5", - ) - constituency_csv_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="constituencies_2024.csv", - ) - constituency_impact = compute_uk_constituency_impacts( - pe_baseline_sim, - pe_reform_sim, - weight_matrix_path=weight_matrix_path, - constituency_csv_path=constituency_csv_path, - ) - if constituency_impact.constituency_results: - for cr in constituency_impact.constituency_results: - record = ConstituencyImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - constituency_code=cr["constituency_code"], - constituency_name=cr["constituency_name"], - x=cr["x"], - y=cr["y"], - average_household_income_change=cr[ - "average_household_income_change" - ], - relative_household_income_change=cr[ - "relative_household_income_change" - ], - population=cr["population"], - ) - session.add(record) - except Exception: - pass # Weight matrix not available, skip constituency impact + # Run computation modules + from policyengine_api.api.computation_modules import UK_MODULE_DISPATCH, run_modules - # Calculate local authority impact (UK only, requires weight matrix) - from policyengine.outputs.local_authority_impact import ( - compute_uk_local_authority_impacts, + run_modules( + dispatch=UK_MODULE_DISPATCH, + modules=modules, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, + report_id=report.id, + session=session, ) - try: - from policyengine_core.tools.google_cloud import download as gcs_download - - la_weight_matrix_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authority_weights.h5", - ) - la_csv_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authorities_2021.csv", - ) - la_impact = compute_uk_local_authority_impacts( - pe_baseline_sim, - 
pe_reform_sim, - weight_matrix_path=la_weight_matrix_path, - local_authority_csv_path=la_csv_path, - ) - if la_impact.local_authority_results: - for lr in la_impact.local_authority_results: - record = LocalAuthorityImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - local_authority_code=lr["local_authority_code"], - local_authority_name=lr["local_authority_name"], - x=lr["x"], - y=lr["y"], - average_household_income_change=lr[ - "average_household_income_change" - ], - relative_household_income_change=lr[ - "relative_household_income_change" - ], - population=lr["population"], - ) - session.add(record) - except Exception: - pass # Weight matrix not available, skip local authority impact - - # Calculate wealth decile impact (UK only) - try: - from policyengine.outputs.decile_impact import ( - DecileImpact as PEDecileImpact, - ) - - PEDecileImpact.model_rebuild(_types_namespace={"Simulation": PESimulation}) - for decile_num in range(1, 11): - wealth_di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - decile=decile_num, - ) - wealth_di.run() - record = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable="household_wealth_decile", - entity="household", - decile=decile_num, - quantiles=10, - baseline_mean=wealth_di.baseline_mean, - reform_mean=wealth_di.reform_mean, - absolute_change=wealth_di.absolute_change, - relative_change=wealth_di.relative_change, - ) - session.add(record) - - # Calculate intra-wealth-decile impact - intra_wealth_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - ) - for r in 
intra_wealth_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile_type="wealth", - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - except (KeyError, Exception): - pass # household_wealth_decile not available (US), skip - # Mark completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = datetime.now(timezone.utc) @@ -1193,7 +860,9 @@ def build_dynamic(dynamic_id): session.commit() -def _run_local_economy_comparison_us(job_id: str, session: Session) -> None: +def _run_local_economy_comparison_us( + job_id: str, session: Session, modules: list[str] | None = None +) -> None: """Run US economy comparison analysis locally.""" from datetime import datetime, timezone from uuid import UUID @@ -1202,21 +871,8 @@ def _run_local_economy_comparison_us(job_id: str, session: Session) -> None: from policyengine.core.dynamic import Dynamic as PEDynamic from policyengine.core.policy import ParameterValue as PEParameterValue from policyengine.core.policy import Policy as PEPolicy - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - from policyengine.outputs.inequality import calculate_us_inequality - from policyengine.outputs.poverty import ( - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - calculate_us_poverty_rates, - ) from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset - from policyengine.tax_benefit_models.us.outputs import ( - ProgramStatistics as 
PEProgramStats, - ) from policyengine_api.models import Policy as DBPolicy @@ -1337,284 +993,20 @@ def build_dynamic(dynamic_id): ) pe_reform_sim.ensure() - # Calculate decile impacts - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - ) - di.run() - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": {"entity": "person", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "spm_unit", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programs.items(): - try: - ps = PEProgramStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, 
- losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not found in model - - # Calculate poverty rates for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_us_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_us_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - gender_poverty_results = calculate_us_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by race for baseline and reform (US only) - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - race_poverty_results = calculate_us_poverty_by_race(pe_sim) - for 
pov in race_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_us_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, - report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # US budget variables — household-level plus state tax at tax_unit level - us_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - "household_state_income_tax": "tax_unit", - } - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - for var_name, entity in us_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household count: bypass Aggregate and compute directly 
from raw numpy - # values. Using Aggregate(SUM) on household_weight would compute - # sum(weight * weight) because MicroSeries.sum() applies weights - # automatically — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) + # Run computation modules + from policyengine_api.api.computation_modules import US_MODULE_DISPATCH, run_modules - # Calculate intra-decile impact (5-category income change distribution) - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile_us, - ) - - intra_decile_results_us = pe_compute_intra_decile_us( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results_us.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - - # Calculate congressional district impact - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, + run_modules( + dispatch=US_MODULE_DISPATCH, + 
modules=modules, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, + report_id=report.id, + session=session, ) - try: - district_impact = compute_us_congressional_district_impacts( - pe_baseline_sim, pe_reform_sim - ) - if district_impact.district_results: - for dr in district_impact.district_results: - record = CongressionalDistrictImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - district_geoid=dr["district_geoid"], - state_fips=dr["state_fips"], - district_number=dr["district_number"], - average_household_income_change=dr[ - "average_household_income_change" - ], - relative_household_income_change=dr[ - "relative_household_income_change" - ], - population=dr["population"], - ) - session.add(record) - except KeyError: - pass # congressional_district_geoid not in dataset - # Mark completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = datetime.now(timezone.utc) @@ -1628,9 +1020,16 @@ def build_dynamic(dynamic_id): def _trigger_economy_comparison( - job_id: str, tax_benefit_model_name: str, session: Session | None = None + job_id: str, + tax_benefit_model_name: str, + session: Session | None = None, + modules: list[str] | None = None, ) -> None: - """Trigger economy comparison analysis (local or Modal).""" + """Trigger economy comparison analysis (local or Modal). + + Args: + modules: Optional list of module names to run. If None, runs all. 
+ """ from policyengine_api.config import settings traceparent = get_traceparent() @@ -1638,11 +1037,11 @@ def _trigger_economy_comparison( if not settings.agent_use_modal and session is not None: # Run locally if tax_benefit_model_name == "policyengine_uk": - _run_local_economy_comparison_uk(job_id, session) + _run_local_economy_comparison_uk(job_id, session, modules=modules) else: - _run_local_economy_comparison_us(job_id, session) + _run_local_economy_comparison_us(job_id, session, modules=modules) else: - # Use Modal + # Use Modal (modules param passed for future selective computation) import modal if tax_benefit_model_name == "policyengine_uk": @@ -1791,3 +1190,193 @@ def get_economic_impact_status( raise HTTPException(status_code=500, detail="Simulation data missing") return _build_response(report, baseline_sim, reform_sim, session) + + +# --------------------------------------------------------------------------- +# POST /analysis/economy-custom — run selected economy modules +# --------------------------------------------------------------------------- + +_MODEL_TO_COUNTRY = { + "policyengine_uk": "uk", + "policyengine_us": "us", +} + + +class EconomyCustomRequest(BaseModel): + """Request body for custom economy analysis with selected modules.""" + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID. 
Either dataset_id or region must be provided.", + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us').", + ) + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID to compare against baseline (current law)", + ) + dynamic_id: UUID | None = Field( + default=None, description="Optional behavioural response specification ID" + ) + modules: list[str] = Field( + description="List of module names to compute (see GET /analysis/options)" + ) + + +def _build_filtered_response( + full_response: EconomicImpactResponse, + modules: list[str], +) -> EconomicImpactResponse: + """Return a copy of the response with only the fields for requested modules.""" + allowed_fields: set[str] = set() + for name in modules: + module = MODULE_REGISTRY.get(name) + if module: + allowed_fields.update(module.response_fields) + + # Fields that are always included regardless of modules + always_included = { + "report_id", + "status", + "baseline_simulation", + "reform_simulation", + "region", + "error_message", + } + + filtered = {} + for field_name in EconomicImpactResponse.model_fields: + value = getattr(full_response, field_name) + if field_name in always_included: + filtered[field_name] = value + elif field_name in allowed_fields: + filtered[field_name] = value + else: + filtered[field_name] = None + + return EconomicImpactResponse.model_construct(**filtered) + + +@router.post("/economy-custom", response_model=EconomicImpactResponse) +def economy_custom( + request: EconomyCustomRequest, + session: Session = Depends(get_session), +) -> EconomicImpactResponse: + """Run economy-wide analysis with only the selected modules. + + Same async pattern as /analysis/economic-impact but accepts a list of + module names. Only the requested modules' response fields are populated; + the rest are null. + + See GET /analysis/options for available module names. 
+ """ + country = _MODEL_TO_COUNTRY[request.tax_benefit_model_name] + + try: + validate_modules(request.modules, country) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) + + # Reuse the same request model for dataset/region resolution + impact_request = EconomicImpactRequest( + tax_benefit_model_name=request.tax_benefit_model_name, + dataset_id=request.dataset_id, + region=request.region, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + ) + + dataset, region_obj = _resolve_dataset_and_region(impact_request, session) + + filter_field = ( + region_obj.filter_field if region_obj and region_obj.requires_filter else None + ) + filter_value = ( + region_obj.filter_value if region_obj and region_obj.requires_filter else None + ) + + model_version = _get_model_version(request.tax_benefit_model_name, session) + + baseline_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + ) + + reform_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + ) + + label = f"Custom analysis: {request.tax_benefit_model_name}" + if request.policy_id: + label += f" (policy {request.policy_id})" + + report = _get_or_create_report( + baseline_sim.id, reform_sim.id, label, "economy_comparison", session + ) + + if report.status == ReportStatus.PENDING: + with logfire.span("trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison( + str(report.id), + request.tax_benefit_model_name, + session, + modules=request.modules, + ) + + full_response = _build_response( + report, baseline_sim, 
@router.get("/economy-custom/{report_id}", response_model=EconomicImpactResponse)
def get_economy_custom_status(
    report_id: UUID,
    modules: str | None = None,
    session: Session = Depends(get_session),
) -> EconomicImpactResponse:
    """Poll for results of custom economy analysis.

    Args:
        report_id: The report ID returned by POST /analysis/economy-custom.
        modules: Optional comma-separated module names to filter the response.
            Surrounding whitespace and empty entries (e.g. a trailing comma)
            are ignored. If omitted or empty, all computed fields are returned.
        session: Database session (injected).

    Raises:
        HTTPException: 422 if ``modules`` names a module that is not in the
            registry; 404 if the report does not exist; 500 if the report or
            its simulations are incomplete.
    """
    # Parse the filter first: it needs no I/O, and a typo should be reported
    # rather than silently producing a response with every result nulled out.
    module_list = [
        name for name in (part.strip() for part in (modules or "").split(",")) if name
    ]
    unknown = [name for name in module_list if name not in MODULE_REGISTRY]
    if unknown:
        raise HTTPException(
            status_code=422,
            detail="; ".join(f"Unknown module: {name!r}" for name in unknown),
        )

    report = session.get(Report, report_id)
    if not report:
        raise HTTPException(status_code=404, detail=f"Report {report_id} not found")

    if not report.baseline_simulation_id or not report.reform_simulation_id:
        raise HTTPException(status_code=500, detail="Report missing simulation IDs")

    baseline_sim = session.get(Simulation, report.baseline_simulation_id)
    reform_sim = session.get(Simulation, report.reform_simulation_id)

    if not baseline_sim or not reform_sim:
        raise HTTPException(status_code=500, detail="Simulation data missing")

    full_response = _build_response(report, baseline_sim, reform_sim, session)

    if not module_list:
        return full_response
    return _build_filtered_response(full_response, module_list)
def compute_intra_decile_module(
    pe_baseline_sim,
    pe_reform_sim,
    baseline_sim_id: UUID,
    reform_sim_id: UUID,
    report_id: UUID,
    session: Session,
) -> None:
    """Store the five-band income-change breakdown for each income decile.

    Runs policyengine's intra-decile computation on household net income and
    adds one ``IntraDecileImpact`` row per returned output to ``session``.
    Nothing is committed here.
    """
    from policyengine.outputs.intra_decile_impact import (
        compute_intra_decile_impacts,
    )

    # Band attributes are copied one-to-one from each policyengine output.
    band_names = (
        "lose_more_than_5pct",
        "lose_less_than_5pct",
        "no_change",
        "gain_less_than_5pct",
        "gain_more_than_5pct",
    )

    outcome = compute_intra_decile_impacts(
        baseline_simulation=pe_baseline_sim,
        reform_simulation=pe_reform_sim,
        income_variable="household_net_income",
        entity="household",
    )
    for row in outcome.outputs:
        session.add(
            IntraDecileImpact(
                baseline_simulation_id=baseline_sim_id,
                reform_simulation_id=reform_sim_id,
                report_id=report_id,
                decile=row.decile,
                **{band: getattr(row, band) for band in band_names},
            )
        )
def compute_inequality_module_uk(
    pe_baseline_sim,
    pe_reform_sim,
    baseline_sim_id: UUID,
    reform_sim_id: UUID,
    report_id: UUID,
    session: Session,
) -> None:
    """Store UK inequality metrics for the baseline and the reform simulation.

    Adds one ``Inequality`` row per simulation (baseline first, then reform)
    to ``session``; nothing is committed here.
    """
    from policyengine.outputs.inequality import calculate_uk_inequality

    pe_sims = (pe_baseline_sim, pe_reform_sim)
    db_ids = (baseline_sim_id, reform_sim_id)

    for pe_sim, db_sim_id in zip(pe_sims, db_ids):
        metrics = calculate_uk_inequality(pe_sim)
        metrics.run()
        session.add(
            Inequality(
                simulation_id=db_sim_id,
                report_id=report_id,
                income_variable=metrics.income_variable,
                entity=metrics.entity,
                gini=metrics.gini,
                top_10_share=metrics.top_10_share,
                top_1_share=metrics.top_1_share,
                bottom_50_share=metrics.bottom_50_share,
            )
        )
def compute_constituency_module(
    pe_baseline_sim,
    pe_reform_sim,
    baseline_sim_id: UUID,
    reform_sim_id: UUID,
    report_id: UUID,
    session: Session,
) -> None:
    """Compute UK parliamentary constituency impact (best effort).

    Downloads the constituency weight matrix and CSV, runs the policyengine
    constituency computation, and adds one ``ConstituencyImpact`` row per
    result to ``session``. Nothing is committed here.

    Any failure is logged with its traceback and the module is skipped, so an
    unavailable weight matrix does not abort the rest of the report. Rows are
    built before any is added, so a failure while building them never leaves a
    partial set of constituencies in the session.
    """
    import logging

    from policyengine.outputs.constituency_impact import (
        compute_uk_constituency_impacts,
    )

    try:
        from policyengine_core.tools.google_cloud import download as gcs_download

        weight_matrix_path = gcs_download(
            gcs_bucket="policyengine-uk-data-private",
            gcs_key="parliamentary_constituency_weights.h5",
        )
        constituency_csv_path = gcs_download(
            gcs_bucket="policyengine-uk-data-private",
            gcs_key="constituencies_2024.csv",
        )
        impact = compute_uk_constituency_impacts(
            pe_baseline_sim,
            pe_reform_sim,
            weight_matrix_path=weight_matrix_path,
            constituency_csv_path=constituency_csv_path,
        )
        records = [
            ConstituencyImpact(
                baseline_simulation_id=baseline_sim_id,
                reform_simulation_id=reform_sim_id,
                report_id=report_id,
                constituency_code=cr["constituency_code"],
                constituency_name=cr["constituency_name"],
                x=cr["x"],
                y=cr["y"],
                average_household_income_change=cr[
                    "average_household_income_change"
                ],
                relative_household_income_change=cr[
                    "relative_household_income_change"
                ],
                population=cr["population"],
            )
            for cr in impact.constituency_results or []
        ]
        for record in records:
            session.add(record)
    except Exception:
        # Still best effort (commonly: weight matrix not available), but the
        # cause is now recorded instead of disappearing.
        logging.getLogger(__name__).warning(
            "Constituency module skipped for report %s", report_id, exc_info=True
        )
def compute_wealth_decile_module(
    pe_baseline_sim,
    pe_reform_sim,
    baseline_sim_id: UUID,
    reform_sim_id: UUID,
    report_id: UUID,
    session: Session,
) -> None:
    """Compute UK wealth decile impact and intra-wealth-decile breakdown.

    Adds ten ``DecileImpact`` rows (household net income change grouped by
    ``household_wealth_decile``) and one ``IntraDecileImpact`` row per
    intra-decile output, tagged ``decile_type="wealth"``. Nothing is
    committed here.

    The module is best effort: any failure is logged with its traceback and
    the remaining work is skipped. Each result set is staged before it is
    added, so a failure part-way through the ten deciles never leaves an
    incomplete decile set in the session.
    """
    import logging

    from policyengine.core import Simulation as PESimulation
    from policyengine.outputs.decile_impact import DecileImpact as PEDecileImpact
    from policyengine.outputs.intra_decile_impact import (
        compute_intra_decile_impacts as pe_compute_intra_decile,
    )

    try:
        PEDecileImpact.model_rebuild(_types_namespace={"Simulation": PESimulation})

        decile_records = []
        for decile_num in range(1, 11):
            wealth_di = PEDecileImpact(
                baseline_simulation=pe_baseline_sim,
                reform_simulation=pe_reform_sim,
                income_variable="household_net_income",
                decile_variable="household_wealth_decile",
                entity="household",
                decile=decile_num,
            )
            wealth_di.run()
            decile_records.append(
                DecileImpact(
                    baseline_simulation_id=baseline_sim_id,
                    reform_simulation_id=reform_sim_id,
                    report_id=report_id,
                    # NOTE(review): stores the decile variable name in the
                    # income_variable column; presumably this is how wealth
                    # rows are told apart from income rows - confirm against
                    # the response builder before changing.
                    income_variable="household_wealth_decile",
                    entity="household",
                    decile=decile_num,
                    quantiles=10,
                    baseline_mean=wealth_di.baseline_mean,
                    reform_mean=wealth_di.reform_mean,
                    absolute_change=wealth_di.absolute_change,
                    relative_change=wealth_di.relative_change,
                )
            )
        for record in decile_records:
            session.add(record)

        # Intra-wealth-decile
        intra_wealth_results = pe_compute_intra_decile(
            baseline_simulation=pe_baseline_sim,
            reform_simulation=pe_reform_sim,
            income_variable="household_net_income",
            decile_variable="household_wealth_decile",
            entity="household",
        )
        intra_records = [
            IntraDecileImpact(
                baseline_simulation_id=baseline_sim_id,
                reform_simulation_id=reform_sim_id,
                report_id=report_id,
                decile_type="wealth",
                decile=r.decile,
                lose_more_than_5pct=r.lose_more_than_5pct,
                lose_less_than_5pct=r.lose_less_than_5pct,
                no_change=r.no_change,
                gain_less_than_5pct=r.gain_less_than_5pct,
                gain_more_than_5pct=r.gain_more_than_5pct,
            )
            for r in intra_wealth_results.outputs
        ]
        for record in intra_records:
            session.add(record)
    except Exception:
        # Still best effort (commonly: household_wealth_decile not available),
        # but the cause is logged rather than silently discarded.
        logging.getLogger(__name__).warning(
            "Wealth decile module skipped for report %s", report_id, exc_info=True
        )
def compute_inequality_module_us(
    pe_baseline_sim,
    pe_reform_sim,
    baseline_sim_id: UUID,
    reform_sim_id: UUID,
    report_id: UUID,
    session: Session,
) -> None:
    """Store US inequality metrics for the baseline and the reform simulation.

    Adds one ``Inequality`` row per simulation (baseline first, then reform)
    to ``session``; nothing is committed here.
    """
    from policyengine.outputs.inequality import calculate_us_inequality

    pe_sims = (pe_baseline_sim, pe_reform_sim)
    db_ids = (baseline_sim_id, reform_sim_id)

    for pe_sim, db_sim_id in zip(pe_sims, db_ids):
        metrics = calculate_us_inequality(pe_sim)
        metrics.run()
        session.add(
            Inequality(
                simulation_id=db_sim_id,
                report_id=report_id,
                income_variable=metrics.income_variable,
                entity=metrics.entity,
                gini=metrics.gini,
                top_10_share=metrics.top_10_share,
                top_1_share=metrics.top_1_share,
                bottom_50_share=metrics.bottom_50_share,
            )
        )
def compute_congressional_district_module(
    pe_baseline_sim,
    pe_reform_sim,
    baseline_sim_id: UUID,
    reform_sim_id: UUID,
    report_id: UUID,
    session: Session,
) -> None:
    """Store per-district household income changes for US congressional districts.

    Adds one ``CongressionalDistrictImpact`` row per district result to
    ``session``; nothing is committed here. A ``KeyError`` raised anywhere in
    the computation or row construction ends the module quietly.
    """
    from policyengine.outputs.congressional_district_impact import (
        compute_us_congressional_district_impacts,
    )

    # Keys copied verbatim from each district result dict, in column order.
    copied_keys = (
        "district_geoid",
        "state_fips",
        "district_number",
        "average_household_income_change",
        "relative_household_income_change",
        "population",
    )

    try:
        outcome = compute_us_congressional_district_impacts(
            pe_baseline_sim, pe_reform_sim
        )
        for district in outcome.district_results or ():
            session.add(
                CongressionalDistrictImpact(
                    baseline_simulation_id=baseline_sim_id,
                    reform_simulation_id=reform_sim_id,
                    report_id=report_id,
                    **{key: district[key] for key in copied_keys},
                )
            )
    except KeyError:
        # congressional_district_geoid not in dataset
        return
def run_modules(
    dispatch: dict[str, ModuleFunction],
    modules: list[str] | None,
    pe_baseline_sim,
    pe_reform_sim,
    baseline_sim_id: UUID,
    reform_sim_id: UUID,
    report_id: UUID,
    session: Session,
) -> None:
    """Run the requested modules (or all if modules is None).

    Args:
        dispatch: Mapping of module name to computation function.
        modules: Names to run, in order. ``None`` runs every entry of
            ``dispatch``; an empty list runs nothing. Repeated names run once,
            because each module function adds rows to ``session`` and a second
            run would duplicate them. Names missing from ``dispatch`` are
            skipped.
        pe_baseline_sim: Baseline simulation handed to every module function.
        pe_reform_sim: Reform simulation handed to every module function.
        baseline_sim_id: Baseline simulation ID passed through to the modules.
        reform_sim_id: Reform simulation ID passed through to the modules.
        report_id: Report ID passed through to the modules.
        session: Session passed through to the modules.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    to_run = list(dispatch) if modules is None else list(dict.fromkeys(modules))
    for mod_name in to_run:
        fn = dispatch.get(mod_name)
        if fn is None:
            continue
        fn(
            pe_baseline_sim,
            pe_reform_sim,
            baseline_sim_id,
            reform_sim_id,
            report_id,
            session,
        )
@dataclass(frozen=True)
class ComputationModule:
    """A named economy analysis computation module.

    Instances are frozen and hashable. The two list fields take part in
    equality but are left out of the generated hash, because lists are
    unhashable and would make ``hash(instance)`` raise ``TypeError``.
    """

    # Registry key / identifier of the module.
    name: str
    # Short human-readable title.
    label: str
    # Longer human-readable explanation.
    description: str
    # Country codes the module applies to. The list itself is still mutable;
    # treat it as read-only.
    countries: list[str] = field(default_factory=list, hash=False)
    # Names of the response fields this module populates. Treat as read-only.
    response_fields: list[str] = field(default_factory=list, hash=False)
def validate_modules(names: list[str], country: str) -> list[str]:
    """Check requested module names against the registry for one country.

    Args:
        names: Module names to check.
        country: Country code used to decide which modules are applicable.

    Returns:
        ``names`` unchanged when every entry is registered and applicable.

    Raises:
        ValueError: If any name is unknown or not available for ``country``;
            the message joins one entry per offending name with ``"; "``.
    """
    supported = {module.name for module in get_modules_for_country(country)}

    def _problem(candidate: str) -> str | None:
        # Unknown names take precedence over "wrong country".
        if candidate not in MODULE_REGISTRY:
            return f"Unknown module: {candidate!r}"
        if candidate not in supported:
            return f"Module {candidate!r} is not available for country {country!r}"
        return None

    problems = [message for message in map(_problem, names) if message is not None]
    if problems:
        raise ValueError("; ".join(problems))
    return names
for item in response.json(): + assert len(item["response_fields"]) > 0 + for field in item["response_fields"]: + assert isinstance(field, str) + + def test_filter_by_uk(self, client): + response = client.get("/analysis/options?country=uk") + assert response.status_code == 200 + data = response.json() + names = [m["name"] for m in data] + assert "constituency" in names + assert "local_authority" in names + assert "wealth_decile" in names + assert "congressional_district" not in names + + def test_filter_by_us(self, client): + response = client.get("/analysis/options?country=us") + assert response.status_code == 200 + data = response.json() + names = [m["name"] for m in data] + assert "congressional_district" in names + assert "constituency" not in names + assert "local_authority" not in names + assert "wealth_decile" not in names + + def test_uk_count_matches_registry(self, client): + response = client.get("/analysis/options?country=uk") + data = response.json() + expected = len(get_modules_for_country("uk")) + assert len(data) == expected + + def test_us_count_matches_registry(self, client): + response = client.get("/analysis/options?country=us") + data = response.json() + expected = len(get_modules_for_country("us")) + assert len(data) == expected + + def test_shared_modules_in_both_countries(self, client): + uk_resp = client.get("/analysis/options?country=uk") + us_resp = client.get("/analysis/options?country=us") + uk_names = {m["name"] for m in uk_resp.json()} + us_names = {m["name"] for m in us_resp.json()} + for shared in [ + "decile", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "program_statistics", + ]: + assert shared in uk_names + assert shared in us_names + + def test_unknown_country_returns_empty(self, client): + response = client.get("/analysis/options?country=fr") + assert response.status_code == 200 + assert response.json() == [] + + def test_program_statistics_has_two_response_fields(self, client): + response = 
client.get("/analysis/options") + ps_module = next( + m for m in response.json() if m["name"] == "program_statistics" + ) + assert "program_statistics" in ps_module["response_fields"] + assert "detailed_budget" in ps_module["response_fields"] + + def test_wealth_decile_has_two_response_fields(self, client): + response = client.get("/analysis/options?country=uk") + wd_module = next(m for m in response.json() if m["name"] == "wealth_decile") + assert "wealth_decile" in wd_module["response_fields"] + assert "intra_wealth_decile" in wd_module["response_fields"] + + def test_no_country_param_returns_all(self, client): + all_resp = client.get("/analysis/options") + data = all_resp.json() + returned_names = {m["name"] for m in data} + assert returned_names == set(MODULE_REGISTRY.keys()) + + def test_response_matches_registry_data(self, client): + response = client.get("/analysis/options") + for item in response.json(): + registry_mod = MODULE_REGISTRY[item["name"]] + assert item["label"] == registry_mod.label + assert item["description"] == registry_mod.description + assert item["response_fields"] == list(registry_mod.response_fields) diff --git a/tests/test_computation_modules.py b/tests/test_computation_modules.py new file mode 100644 index 0000000..5329d39 --- /dev/null +++ b/tests/test_computation_modules.py @@ -0,0 +1,282 @@ +"""Tests for the composable computation module dispatch system.""" + +import inspect +from unittest.mock import MagicMock, call +from uuid import uuid4 + +from policyengine_api.api import computation_modules as cm +from policyengine_api.api.computation_modules import ( + UK_MODULE_DISPATCH, + US_MODULE_DISPATCH, + run_modules, +) +from policyengine_api.api.module_registry import MODULE_REGISTRY + + +class TestDispatchTables: + """Tests for UK_MODULE_DISPATCH and US_MODULE_DISPATCH.""" + + def test_uk_dispatch_keys_match_registry(self): + """Every UK dispatch key should be a valid module in the registry.""" + for key in UK_MODULE_DISPATCH: + 
assert key in MODULE_REGISTRY, f"UK dispatch key {key!r} not in registry" + + def test_us_dispatch_keys_match_registry(self): + """Every US dispatch key should be a valid module in the registry.""" + for key in US_MODULE_DISPATCH: + assert key in MODULE_REGISTRY, f"US dispatch key {key!r} not in registry" + + def test_uk_dispatch_covers_uk_modules(self): + """UK dispatch should have an entry for every UK-applicable module.""" + uk_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "uk" in mod.countries + } + assert set(UK_MODULE_DISPATCH.keys()) == uk_module_names + + def test_us_dispatch_covers_us_modules(self): + """US dispatch should have an entry for every US-applicable module.""" + us_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "us" in mod.countries + } + assert set(US_MODULE_DISPATCH.keys()) == us_module_names + + def test_all_dispatch_values_are_callable(self): + for fn in UK_MODULE_DISPATCH.values(): + assert callable(fn) + for fn in US_MODULE_DISPATCH.values(): + assert callable(fn) + + def test_uk_dispatch_has_9_entries(self): + assert len(UK_MODULE_DISPATCH) == 9 + + def test_us_dispatch_has_7_entries(self): + assert len(US_MODULE_DISPATCH) == 7 + + +class TestSharedModuleFunctions: + """Tests that shared modules reference the same function objects.""" + + def test_decile_function_shared_between_uk_and_us(self): + assert UK_MODULE_DISPATCH["decile"] is US_MODULE_DISPATCH["decile"] + assert UK_MODULE_DISPATCH["decile"] is cm.compute_decile_module + + def test_intra_decile_function_shared_between_uk_and_us(self): + assert UK_MODULE_DISPATCH["intra_decile"] is US_MODULE_DISPATCH["intra_decile"] + assert UK_MODULE_DISPATCH["intra_decile"] is cm.compute_intra_decile_module + + +class TestCountrySpecificFunctions: + """Tests that UK/US specific modules use the correct country-specific functions.""" + + def test_uk_program_statistics(self): + assert ( + UK_MODULE_DISPATCH["program_statistics"] + is 
cm.compute_program_statistics_module_uk + ) + + def test_us_program_statistics(self): + assert ( + US_MODULE_DISPATCH["program_statistics"] + is cm.compute_program_statistics_module_us + ) + + def test_uk_poverty(self): + assert UK_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_uk + + def test_us_poverty(self): + assert US_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_us + + def test_uk_inequality(self): + assert UK_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_uk + + def test_us_inequality(self): + assert US_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_us + + def test_uk_budget_summary(self): + assert ( + UK_MODULE_DISPATCH["budget_summary"] + is cm.compute_budget_summary_module_uk + ) + + def test_us_budget_summary(self): + assert ( + US_MODULE_DISPATCH["budget_summary"] + is cm.compute_budget_summary_module_us + ) + + def test_constituency_is_uk_only(self): + assert UK_MODULE_DISPATCH["constituency"] is cm.compute_constituency_module + assert "constituency" not in US_MODULE_DISPATCH + + def test_local_authority_is_uk_only(self): + assert ( + UK_MODULE_DISPATCH["local_authority"] is cm.compute_local_authority_module + ) + assert "local_authority" not in US_MODULE_DISPATCH + + def test_wealth_decile_is_uk_only(self): + assert UK_MODULE_DISPATCH["wealth_decile"] is cm.compute_wealth_decile_module + assert "wealth_decile" not in US_MODULE_DISPATCH + + def test_congressional_district_is_us_only(self): + assert ( + US_MODULE_DISPATCH["congressional_district"] + is cm.compute_congressional_district_module + ) + assert "congressional_district" not in UK_MODULE_DISPATCH + + +class TestModuleFunctionSignatures: + """Tests that all module functions share the expected 6-param signature.""" + + _EXPECTED_PARAMS = [ + "pe_baseline_sim", + "pe_reform_sim", + "baseline_sim_id", + "reform_sim_id", + "report_id", + "session", + ] + + def _get_all_unique_functions(self): + """Collect all unique module functions from both dispatch 
tables.""" + seen = set() + fns = [] + for fn in list(UK_MODULE_DISPATCH.values()) + list( + US_MODULE_DISPATCH.values() + ): + if id(fn) not in seen: + seen.add(id(fn)) + fns.append(fn) + return fns + + def test_all_functions_have_6_parameters(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + assert len(sig.parameters) == 6, ( + f"{fn.__name__} has {len(sig.parameters)} params, expected 6" + ) + + def test_all_functions_have_expected_param_names(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + param_names = list(sig.parameters.keys()) + assert param_names == self._EXPECTED_PARAMS, ( + f"{fn.__name__} params {param_names} != {self._EXPECTED_PARAMS}" + ) + + def test_all_functions_return_none(self): + for fn in self._get_all_unique_functions(): + sig = inspect.signature(fn) + # `from __future__ import annotations` makes annotations strings + assert sig.return_annotation in (None, "None", inspect.Parameter.empty), ( + f"{fn.__name__} return annotation is {sig.return_annotation!r}, expected None" + ) + + +class TestRunModules: + """Tests for the run_modules dispatch helper.""" + + def _make_mock_dispatch(self, names): + """Create a dispatch dict with mock functions.""" + return {name: MagicMock(name=f"compute_{name}") for name in names} + + def test_runs_all_when_modules_is_none(self): + dispatch = self._make_mock_dispatch(["a", "b", "c"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session) + + for fn in dispatch.values(): + fn.assert_called_once_with("bl", "rf", ids[0], ids[1], ids[2], session) + + def test_runs_only_requested_modules(self): + dispatch = self._make_mock_dispatch(["a", "b", "c"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) + + dispatch["a"].assert_not_called() + dispatch["b"].assert_called_once() + 
dispatch["c"].assert_not_called() + + def test_ignores_unknown_module_names(self): + dispatch = self._make_mock_dispatch(["a"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + # Should not raise + run_modules( + dispatch, ["a", "nonexistent"], "bl", "rf", ids[0], ids[1], ids[2], session + ) + + dispatch["a"].assert_called_once() + + def test_empty_modules_list_runs_nothing(self): + dispatch = self._make_mock_dispatch(["a", "b"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, [], "bl", "rf", ids[0], ids[1], ids[2], session) + + for fn in dispatch.values(): + fn.assert_not_called() + + def test_preserves_call_order(self): + """Modules should be called in the order they appear in the modules list.""" + call_order = [] + + def make_tracker(name): + def fn(*args): + call_order.append(name) + + return fn + + dispatch = {name: make_tracker(name) for name in ["a", "b", "c"]} + ids = [uuid4() for _ in range(3)] + + run_modules( + dispatch, ["c", "a", "b"], "bl", "rf", ids[0], ids[1], ids[2], MagicMock() + ) + + assert call_order == ["c", "a", "b"] + + def test_none_modules_runs_all_in_dispatch_key_order(self): + """When modules is None, all dispatch entries run in dict-iteration order.""" + call_order = [] + + def make_tracker(name): + def fn(*args): + call_order.append(name) + + return fn + + dispatch = {name: make_tracker(name) for name in ["x", "y", "z"]} + ids = [uuid4() for _ in range(3)] + + run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], MagicMock()) + + assert call_order == ["x", "y", "z"] + + def test_passes_all_args_correctly(self): + mock_fn = MagicMock() + dispatch = {"test_mod": mock_fn} + session = MagicMock() + bl, rf, b_id, r_id, rep_id = "baseline", "reform", uuid4(), uuid4(), uuid4() + + run_modules(dispatch, ["test_mod"], bl, rf, b_id, r_id, rep_id, session) + + mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session) + + def test_duplicate_module_name_runs_twice(self): + 
dispatch = self._make_mock_dispatch(["a"]) + session = MagicMock() + ids = [uuid4() for _ in range(3)] + + run_modules( + dispatch, ["a", "a"], "bl", "rf", ids[0], ids[1], ids[2], session + ) + + assert dispatch["a"].call_count == 2 diff --git a/tests/test_economy_custom.py b/tests/test_economy_custom.py new file mode 100644 index 0000000..f96002a --- /dev/null +++ b/tests/test_economy_custom.py @@ -0,0 +1,317 @@ +"""Tests for POST /analysis/economy-custom endpoint.""" + +from uuid import uuid4 + +from policyengine_api.api.analysis import ( + EconomicImpactResponse, + SimulationInfo, + _build_filtered_response, +) +from policyengine_api.models import ReportStatus, SimulationStatus + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_stub_response(**overrides) -> EconomicImpactResponse: + """Build a minimal EconomicImpactResponse for testing.""" + defaults = dict( + report_id=uuid4(), + status=ReportStatus.COMPLETED, + baseline_simulation=SimulationInfo( + id=uuid4(), status=SimulationStatus.COMPLETED + ), + reform_simulation=SimulationInfo(id=uuid4(), status=SimulationStatus.COMPLETED), + region=None, + error_message=None, + decile_impacts=[{"fake": "decile"}], + program_statistics=[{"fake": "program"}], + poverty=[{"fake": "poverty"}], + inequality=[{"fake": "inequality"}], + budget_summary=[{"fake": "budget"}], + intra_decile=[{"fake": "intra"}], + detailed_budget={"prog": {"baseline": 1.0}}, + congressional_district_impact=[{"fake": "district"}], + constituency_impact=[{"fake": "constituency"}], + local_authority_impact=[{"fake": "la"}], + wealth_decile=[{"fake": "wealth"}], + intra_wealth_decile=[{"fake": "intra_wealth"}], + ) + defaults.update(overrides) + return EconomicImpactResponse.model_construct(**defaults) + + +# All data fields that can be nullified by module filtering +_DATA_FIELDS = { + "decile_impacts", + 
"program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "detailed_budget", + "congressional_district_impact", + "constituency_impact", + "local_authority_impact", + "wealth_decile", + "intra_wealth_decile", +} + +_ALWAYS_INCLUDED = { + "report_id", + "status", + "baseline_simulation", + "reform_simulation", + "region", + "error_message", +} + + +# --------------------------------------------------------------------------- +# Unit tests for _build_filtered_response +# --------------------------------------------------------------------------- + + +class TestBuildFilteredResponse: + """Tests for response filtering by module list.""" + + def test_single_module_keeps_only_its_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty"]) + assert filtered.poverty is not None + assert filtered.decile_impacts is None + assert filtered.program_statistics is None + assert filtered.inequality is None + + def test_multiple_modules(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["decile", "inequality"]) + assert filtered.decile_impacts is not None + assert filtered.inequality is not None + assert filtered.poverty is None + assert filtered.congressional_district_impact is None + + def test_program_statistics_includes_detailed_budget(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["program_statistics"]) + assert filtered.program_statistics is not None + assert filtered.detailed_budget is not None + assert filtered.decile_impacts is None + + def test_always_included_fields_preserved(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty"]) + assert filtered.report_id == resp.report_id + assert filtered.status == resp.status + assert filtered.baseline_simulation is not None + assert filtered.reform_simulation is not None + + def test_region_always_included(self): + from policyengine_api.api.analysis 
import RegionInfo + + region = RegionInfo( + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + resp = _make_stub_response(region=region) + filtered = _build_filtered_response(resp, ["decile"]) + assert filtered.region is not None + assert filtered.region.code == "uk" + + def test_error_message_always_included(self): + resp = _make_stub_response(error_message="something went wrong") + filtered = _build_filtered_response(resp, ["decile"]) + assert filtered.error_message == "something went wrong" + + def test_empty_modules_nullifies_all_data_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, []) + for field in _DATA_FIELDS: + assert getattr(filtered, field) is None, f"{field} should be None" + assert filtered.report_id == resp.report_id + + def test_empty_modules_preserves_always_included(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, []) + for field in _ALWAYS_INCLUDED: + original = getattr(resp, field) + assert getattr(filtered, field) == original, ( + f"{field} should be preserved" + ) + + def test_wealth_decile_keeps_both_fields(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["wealth_decile"]) + assert filtered.wealth_decile is not None + assert filtered.intra_wealth_decile is not None + assert filtered.decile_impacts is None + assert filtered.intra_decile is None + + def test_intra_decile_keeps_only_intra_decile(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["intra_decile"]) + assert filtered.intra_decile is not None + assert filtered.decile_impacts is None + assert filtered.intra_wealth_decile is None + + def test_congressional_district_keeps_only_district_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["congressional_district"]) + assert filtered.congressional_district_impact is not None + assert filtered.constituency_impact is None + 
assert filtered.local_authority_impact is None + + def test_constituency_keeps_only_constituency_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["constituency"]) + assert filtered.constituency_impact is not None + assert filtered.congressional_district_impact is None + assert filtered.local_authority_impact is None + + def test_local_authority_keeps_only_la_impact(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["local_authority"]) + assert filtered.local_authority_impact is not None + assert filtered.constituency_impact is None + assert filtered.congressional_district_impact is None + + def test_budget_summary_keeps_only_budget(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["budget_summary"]) + assert filtered.budget_summary is not None + assert filtered.decile_impacts is None + assert filtered.program_statistics is None + + def test_unknown_module_in_list_is_gracefully_ignored(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["poverty", "nonexistent_module"]) + assert filtered.poverty is not None + assert filtered.decile_impacts is None + + def test_all_modules_keeps_all_data_fields(self): + from policyengine_api.api.module_registry import MODULE_REGISTRY + + resp = _make_stub_response() + all_names = list(MODULE_REGISTRY.keys()) + filtered = _build_filtered_response(resp, all_names) + for field in _DATA_FIELDS: + assert getattr(filtered, field) is not None, ( + f"{field} should be preserved when all modules selected" + ) + + def test_returns_economic_impact_response_instance(self): + resp = _make_stub_response() + filtered = _build_filtered_response(resp, ["decile"]) + assert isinstance(filtered, EconomicImpactResponse) + + +# --------------------------------------------------------------------------- +# Integration tests for the endpoint itself +# 
--------------------------------------------------------------------------- + + +class TestEconomyCustomEndpoint: + """Tests for POST /analysis/economy-custom validation.""" + + def test_unknown_module_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["nonexistent_module"], + }, + ) + assert response.status_code == 422 + assert "Unknown module" in response.json()["detail"] + + def test_wrong_country_module_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["constituency"], + }, + ) + assert response.status_code == 422 + assert "not available for country" in response.json()["detail"] + + def test_multiple_errors_in_module_validation(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["nonexistent", "constituency"], + }, + ) + assert response.status_code == 422 + detail = response.json()["detail"] + assert "Unknown module" in detail + assert "not available for country" in detail + + def test_empty_modules_list_passes_validation(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": [], + }, + ) + # Empty list passes module validation, so the error should be about + # dataset/region resolution, not about modules + assert ( + response.status_code != 422 + or "module" not in response.json().get("detail", "").lower() + ) + + def test_valid_modules_but_missing_region_returns_404(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + "modules": ["decile", "poverty"], + }, + ) + # Passes validation but region "us" is not in the DB -> 404 + 
assert response.status_code == 404 + + def test_missing_modules_field_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "policyengine_us", + "region": "us", + }, + ) + assert response.status_code == 422 + + def test_invalid_model_name_returns_422(self, client): + response = client.post( + "/analysis/economy-custom", + json={ + "tax_benefit_model_name": "invalid_model", + "region": "us", + "modules": ["decile"], + }, + ) + assert response.status_code == 422 + + +class TestEconomyCustomPolling: + """Tests for GET /analysis/economy-custom/{report_id}.""" + + def test_not_found(self, client): + fake_id = uuid4() + response = client.get(f"/analysis/economy-custom/{fake_id}") + assert response.status_code == 404 + + def test_invalid_uuid_returns_422(self, client): + response = client.get("/analysis/economy-custom/not-a-uuid") + assert response.status_code == 422 diff --git a/tests/test_module_registry.py b/tests/test_module_registry.py new file mode 100644 index 0000000..b3f8559 --- /dev/null +++ b/tests/test_module_registry.py @@ -0,0 +1,280 @@ +"""Tests for the economy analysis module registry.""" + +from dataclasses import FrozenInstanceError + +import pytest + +from policyengine_api.api.analysis import EconomicImpactResponse +from policyengine_api.api.module_registry import ( + MODULE_REGISTRY, + ComputationModule, + get_all_module_names, + get_modules_for_country, + validate_modules, +) + + +class TestModuleRegistry: + """Tests for MODULE_REGISTRY contents.""" + + def test_registry_is_not_empty(self): + assert len(MODULE_REGISTRY) > 0 + + def test_registry_has_exactly_10_modules(self): + assert len(MODULE_REGISTRY) == 10 + + def test_all_entries_are_computation_modules(self): + for name, module in MODULE_REGISTRY.items(): + assert isinstance(module, ComputationModule) + assert module.name == name + + def test_all_modules_have_countries(self): + for module in MODULE_REGISTRY.values(): + assert 
len(module.countries) > 0 + for country in module.countries: + assert country in ("uk", "us") + + def test_all_modules_have_response_fields(self): + for module in MODULE_REGISTRY.values(): + assert len(module.response_fields) > 0 + + def test_all_modules_have_non_empty_label(self): + for name, module in MODULE_REGISTRY.items(): + assert module.label, f"Module {name!r} has empty label" + assert len(module.label) > 0 + + def test_all_modules_have_non_empty_description(self): + for name, module in MODULE_REGISTRY.items(): + assert module.description, f"Module {name!r} has empty description" + assert len(module.description) > 0 + + def test_expected_modules_exist(self): + expected = [ + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "congressional_district", + "constituency", + "local_authority", + "wealth_decile", + ] + for name in expected: + assert name in MODULE_REGISTRY, f"Missing module: {name}" + + def test_no_unexpected_modules(self): + expected = { + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + "congressional_district", + "constituency", + "local_authority", + "wealth_decile", + } + assert set(MODULE_REGISTRY.keys()) == expected + + +class TestComputationModuleFrozen: + """Tests that ComputationModule instances are immutable.""" + + def test_cannot_mutate_name(self): + module = MODULE_REGISTRY["decile"] + with pytest.raises(FrozenInstanceError): + module.name = "changed" + + def test_cannot_mutate_countries(self): + module = MODULE_REGISTRY["decile"] + with pytest.raises(FrozenInstanceError): + module.countries = ["fr"] + + def test_cannot_mutate_response_fields(self): + module = MODULE_REGISTRY["poverty"] + with pytest.raises(FrozenInstanceError): + module.response_fields = ["something_else"] + + +class TestResponseFieldsMapping: + """Tests that each module's response_fields reference valid EconomicImpactResponse fields.""" + + def 
test_all_response_fields_exist_on_response_model(self): + valid_fields = set(EconomicImpactResponse.model_fields.keys()) + for name, module in MODULE_REGISTRY.items(): + for field in module.response_fields: + assert field in valid_fields, ( + f"Module {name!r} references response field {field!r} " + f"which does not exist on EconomicImpactResponse" + ) + + def test_decile_response_fields(self): + assert MODULE_REGISTRY["decile"].response_fields == ["decile_impacts"] + + def test_program_statistics_includes_detailed_budget(self): + fields = MODULE_REGISTRY["program_statistics"].response_fields + assert "program_statistics" in fields + assert "detailed_budget" in fields + + def test_poverty_response_fields(self): + assert MODULE_REGISTRY["poverty"].response_fields == ["poverty"] + + def test_inequality_response_fields(self): + assert MODULE_REGISTRY["inequality"].response_fields == ["inequality"] + + def test_budget_summary_response_fields(self): + assert MODULE_REGISTRY["budget_summary"].response_fields == ["budget_summary"] + + def test_intra_decile_response_fields(self): + assert MODULE_REGISTRY["intra_decile"].response_fields == ["intra_decile"] + + def test_congressional_district_response_fields(self): + assert MODULE_REGISTRY["congressional_district"].response_fields == [ + "congressional_district_impact" + ] + + def test_constituency_response_fields(self): + assert MODULE_REGISTRY["constituency"].response_fields == [ + "constituency_impact" + ] + + def test_local_authority_response_fields(self): + assert MODULE_REGISTRY["local_authority"].response_fields == [ + "local_authority_impact" + ] + + def test_wealth_decile_includes_both_fields(self): + fields = MODULE_REGISTRY["wealth_decile"].response_fields + assert "wealth_decile" in fields + assert "intra_wealth_decile" in fields + + +class TestCountryApplicability: + """Tests for country-specific module availability.""" + + def test_us_only_modules(self): + assert "us" in 
MODULE_REGISTRY["congressional_district"].countries + assert "uk" not in MODULE_REGISTRY["congressional_district"].countries + + def test_uk_only_modules(self): + for name in ("constituency", "local_authority", "wealth_decile"): + module = MODULE_REGISTRY[name] + assert "uk" in module.countries + assert "us" not in module.countries + + def test_shared_modules(self): + shared = [ + "decile", + "program_statistics", + "poverty", + "inequality", + "budget_summary", + "intra_decile", + ] + for name in shared: + module = MODULE_REGISTRY[name] + assert "uk" in module.countries + assert "us" in module.countries + + +class TestGetModulesForCountry: + """Tests for get_modules_for_country().""" + + def test_uk_includes_constituency(self): + uk_modules = get_modules_for_country("uk") + names = [m.name for m in uk_modules] + assert "constituency" in names + assert "local_authority" in names + assert "wealth_decile" in names + + def test_uk_excludes_congressional_district(self): + uk_modules = get_modules_for_country("uk") + names = [m.name for m in uk_modules] + assert "congressional_district" not in names + + def test_uk_has_9_modules(self): + uk_modules = get_modules_for_country("uk") + assert len(uk_modules) == 9 + + def test_us_includes_congressional_district(self): + us_modules = get_modules_for_country("us") + names = [m.name for m in us_modules] + assert "congressional_district" in names + + def test_us_excludes_uk_only(self): + us_modules = get_modules_for_country("us") + names = [m.name for m in us_modules] + assert "constituency" not in names + assert "local_authority" not in names + assert "wealth_decile" not in names + + def test_us_has_7_modules(self): + us_modules = get_modules_for_country("us") + assert len(us_modules) == 7 + + def test_unknown_country_returns_empty(self): + assert get_modules_for_country("fr") == [] + + def test_returns_computation_module_instances(self): + for m in get_modules_for_country("uk"): + assert isinstance(m, ComputationModule) + + 
class TestGetAllModuleNames:
    """Checks on the value returned by get_all_module_names()."""

    def test_returns_all_names(self):
        # Compared as sets, so ordering and duplicates are not examined here.
        returned = set(get_all_module_names())
        assert returned == set(MODULE_REGISTRY)

    def test_returns_list_of_strings(self):
        returned = get_all_module_names()
        assert isinstance(returned, list)
        assert all(isinstance(entry, str) for entry in returned)


class TestValidateModules:
    """Checks on validate_modules() for accepted and rejected module lists."""

    def test_valid_us_modules(self):
        accepted = validate_modules(["decile", "poverty"], "us")
        expected = ["decile", "poverty"]
        assert accepted == expected

    def test_valid_uk_modules(self):
        accepted = validate_modules(["constituency", "wealth_decile"], "uk")
        expected = ["constituency", "wealth_decile"]
        assert accepted == expected

    def test_empty_list_passes_validation(self):
        assert validate_modules([], "us") == []

    def test_all_us_modules_pass_validation(self):
        every_us_name = [module.name for module in get_modules_for_country("us")]
        assert validate_modules(every_us_name, "us") == every_us_name

    def test_all_uk_modules_pass_validation(self):
        every_uk_name = [module.name for module in get_modules_for_country("uk")]
        assert validate_modules(every_uk_name, "uk") == every_uk_name

    def test_unknown_module_raises(self):
        bad_request = ["nonexistent"]
        with pytest.raises(ValueError, match="Unknown module"):
            validate_modules(bad_request, "us")

    def test_wrong_country_raises(self):
        # A module registered for the US only, requested for the UK.
        bad_request = ["congressional_district"]
        with pytest.raises(ValueError, match="not available for country"):
            validate_modules(bad_request, "uk")

    def test_multiple_errors_combined(self):
        # The pattern needs the unknown-module text to appear before the
        # wrong-country text within one ValueError message.
        bad_request = ["nonexistent", "constituency"]
        with pytest.raises(ValueError, match="Unknown module.*not available"):
            validate_modules(bad_request, "us")

    def test_returns_original_list_on_success(self):
        # Identity, not just equality: the very same list object must come back.
        requested = ["poverty", "decile", "inequality"]
        assert validate_modules(requested, "us") is requested