From 7841d3b17cd3a7676a9d6b221c25f323a91a6b0b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 10 Feb 2026 21:01:26 +0100 Subject: [PATCH 01/10] feat: Add regions support for geographic analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Region SQLModel with filtering fields (code, label, region_type, requires_filter, filter_field, filter_value, dataset_id, etc.) - Add Alembic migration for regions table - Add GET /regions/ endpoint with filters by model and region type - Add GET /regions/{region_id} and GET /regions/by-code/{code} endpoints - Add region parameter to analysis endpoint with dataset/region resolution 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../versions/20260210_add_regions_table.py | 63 ++++++++++ src/policyengine_api/api/__init__.py | 2 + src/policyengine_api/api/analysis.py | 112 ++++++++++++++++-- src/policyengine_api/api/regions.py | 102 ++++++++++++++++ src/policyengine_api/models/__init__.py | 4 + src/policyengine_api/models/region.py | 60 ++++++++++ 6 files changed, 330 insertions(+), 13 deletions(-) create mode 100644 alembic/versions/20260210_add_regions_table.py create mode 100644 src/policyengine_api/api/regions.py create mode 100644 src/policyengine_api/models/region.py diff --git a/alembic/versions/20260210_add_regions_table.py b/alembic/versions/20260210_add_regions_table.py new file mode 100644 index 0000000..effeab2 --- /dev/null +++ b/alembic/versions/20260210_add_regions_table.py @@ -0,0 +1,63 @@ +"""add regions table + +Revision ID: a1b2c3d4e5f6 +Revises: f419b5f4acba +Create Date: 2026-02-10 12:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = "a1b2c3d4e5f6" +down_revision: Union[str, Sequence[str], None] = "f419b5f4acba" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create regions table.""" + op.create_table( + "regions", + sa.Column("code", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("region_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("requires_filter", sa.Boolean(), nullable=False, default=False), + sa.Column("filter_field", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("filter_value", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("parent_code", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("state_code", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("state_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("dataset_id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["datasets.id"], + ), + sa.ForeignKeyConstraint( + ["tax_benefit_model_id"], + ["tax_benefit_models.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # Create unique constraint on (code, tax_benefit_model_id) + op.create_index( + "ix_regions_code_model", + "regions", + ["code", "tax_benefit_model_id"], + unique=True, + ) + + +def downgrade() -> None: + """Drop regions table.""" + op.drop_index("ix_regions_code_model", table_name="regions") + op.drop_table("regions") diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index c3e0353..f135b14 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -15,6 +15,7 @@ parameter_values, parameters, policies, + regions, simulations, tax_benefit_model_versions, tax_benefit_models, @@ -26,6 +27,7 @@ api_router.include_router(datasets.router) api_router.include_router(policies.router) +api_router.include_router(regions.router) api_router.include_router(simulations.router) api_router.include_router(outputs.router) api_router.include_router(variables.router) diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 10e6fc5..101669b 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -31,6 +31,7 @@ DecileImpactRead, ProgramStatistics, ProgramStatisticsRead, + Region, Report, ReportStatus, Simulation, @@ -68,19 +69,31 @@ def _safe_float(value: float | None) -> float | None: class EconomicImpactRequest(BaseModel): """Request body for economic impact analysis. - Example: + Example with dataset_id: { "tax_benefit_model_name": "policyengine_uk", "dataset_id": "uuid-from-datasets-endpoint", "policy_id": "uuid-of-reform-policy" } + + Example with region: + { + "tax_benefit_model_name": "policyengine_us", + "region": "state/ca", + "policy_id": "uuid-of-reform-policy" + } """ tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( description="Which country model to use" ) - dataset_id: UUID = Field( - description="Dataset ID from /datasets endpoint containing population microdata" + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID from /datasets endpoint. Either dataset_id or region must be provided.", + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us'). Either dataset_id or region must be provided.", ) policy_id: UUID | None = Field( default=None, @@ -99,6 +112,17 @@ class SimulationInfo(BaseModel): error_message: str | None = None +class RegionInfo(BaseModel): + """Region information used in analysis.""" + + code: str + label: str + region_type: str + requires_filter: bool + filter_field: str | None = None + filter_value: str | None = None + + class EconomicImpactResponse(BaseModel): """Response from economic impact analysis.""" @@ -106,6 +130,7 @@ class EconomicImpactResponse(BaseModel): status: ReportStatus baseline_simulation: SimulationInfo reform_simulation: SimulationInfo + region: RegionInfo | None = None error_message: str | None = None decile_impacts: list[DecileImpactRead] | None = None program_statistics: list[ProgramStatisticsRead] | None = None @@ -235,6 +260,7 @@ def _build_response( baseline_sim: Simulation, reform_sim: Simulation, session: Session, + region: Region | None = None, ) -> EconomicImpactResponse: """Build response from report and simulations.""" decile_impacts = None @@ -292,6 +318,17 @@ def _build_response( for s in stats ] + region_info = None + if region: + region_info = RegionInfo( + code=region.code, + label=region.label, + region_type=region.region_type, + requires_filter=region.requires_filter, + filter_field=region.filter_field, + filter_value=region.filter_value, + ) + return EconomicImpactResponse( report_id=report.id, status=report.status, @@ -305,6 +342,7 @@ def _build_response( status=reform_sim.status, error_message=reform_sim.error_message, ), + region=region_info, error_message=report.error_message, decile_impacts=decile_impacts, program_statistics=program_statistics, @@ -571,6 +609,54 @@ def _trigger_economy_comparison( fn.spawn(job_id=job_id, traceparent=traceparent) +def _resolve_dataset_and_region( + request: EconomicImpactRequest, + session: Session, +) -> tuple[Dataset, Region | None]: + """Resolve dataset from request, optionally via region lookup. + + Returns: + Tuple of (dataset, region) where region is None if dataset_id was provided directly. + """ + if request.region: + # Look up region by code + model_name = request.tax_benefit_model_name.replace("_", "-") + region = session.exec( + select(Region) + .join(TaxBenefitModel) + .where(Region.code == request.region) + .where(TaxBenefitModel.name == model_name) + ).first() + + if not region: + raise HTTPException( + status_code=404, + detail=f"Region '{request.region}' not found for model {model_name}", + ) + + dataset = session.get(Dataset, region.dataset_id) + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset for region '{request.region}' not found", + ) + return dataset, region + + elif request.dataset_id: + dataset = session.get(Dataset, request.dataset_id) + if not dataset: + raise HTTPException( + status_code=404, detail=f"Dataset {request.dataset_id} not found" + ) + return dataset, None + + else: + raise HTTPException( + status_code=400, + detail="Either dataset_id or region must be provided", + ) + + @router.post("/economic-impact", response_model=EconomicImpactResponse) def economic_impact( request: EconomicImpactRequest, @@ -584,25 +670,25 @@ def economic_impact( Results include decile impacts (income changes by income group) and program statistics (budgetary effects of tax/benefit programs). + + You can specify the geographic scope either by: + - dataset_id: Direct dataset reference + - region: Region code (e.g., "state/ca", "us") which resolves to a dataset """ - # Validate dataset exists - dataset = session.get(Dataset, request.dataset_id) - if not dataset: - raise HTTPException( - status_code=404, detail=f"Dataset {request.dataset_id} not found" - ) + # Resolve dataset (and optionally region) + dataset, region = _resolve_dataset_and_region(request, session) # Get model version model_version = _get_model_version(request.tax_benefit_model_name, session) - # Get or create simulations + # Get or create simulations using the resolved dataset baseline_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=None, dynamic_id=request.dynamic_id, session=session, - dataset_id=request.dataset_id, + dataset_id=dataset.id, ) reform_sim = _get_or_create_simulation( @@ -611,7 +697,7 @@ def economic_impact( policy_id=request.policy_id, dynamic_id=request.dynamic_id, session=session, - dataset_id=request.dataset_id, + dataset_id=dataset.id, ) # Get or create report @@ -630,7 +716,7 @@ def economic_impact( str(report.id), request.tax_benefit_model_name, session ) - return _build_response(report, baseline_sim, reform_sim, session) + return _build_response(report, baseline_sim, reform_sim, session, region) @router.get("/economic-impact/{report_id}", response_model=EconomicImpactResponse) diff --git a/src/policyengine_api/api/regions.py b/src/policyengine_api/api/regions.py new file mode 100644 index 0000000..1d0a34e --- /dev/null +++ b/src/policyengine_api/api/regions.py @@ -0,0 +1,102 @@ +"""Region endpoints for geographic areas used in analysis. + +Regions represent geographic areas from countries down to states, +congressional districts, cities, etc. Each region has an associated +dataset for running simulations. +""" + +from typing import List +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import Region, RegionRead, TaxBenefitModel +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/regions", tags=["regions"]) + + +@router.get("/", response_model=List[RegionRead]) +def list_regions( + tax_benefit_model_id: UUID | None = Query( + None, description="Filter by tax-benefit model ID" + ), + tax_benefit_model_name: str | None = Query( + None, description="Filter by tax-benefit model name (e.g., 'policyengine-us')" + ), + region_type: str | None = Query( + None, + description="Filter by region type (e.g., 'state', 'congressional_district')", + ), + session: Session = Depends(get_session), +): + """List available regions. + + Returns regions that can be used with the /analysis/economic-impact endpoint. + Each region represents a geographic area with an associated dataset. + + Args: + tax_benefit_model_id: Filter by tax-benefit model UUID. + tax_benefit_model_name: Filter by model name (e.g., "policyengine-us"). + region_type: Filter by region type (e.g., "state", "congressional_district"). + """ + query = select(Region) + + if tax_benefit_model_id: + query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) + elif tax_benefit_model_name: + query = query.join(TaxBenefitModel).where( + TaxBenefitModel.name == tax_benefit_model_name + ) + + if region_type: + query = query.where(Region.region_type == region_type) + + regions = session.exec(query).all() + return regions + + +@router.get("/{region_id}", response_model=RegionRead) +def get_region(region_id: UUID, session: Session = Depends(get_session)): + """Get a specific region by ID.""" + region = session.get(Region, region_id) + if not region: + raise HTTPException(status_code=404, detail="Region not found") + return region + + +@router.get("/by-code/{region_code:path}", response_model=RegionRead) +def get_region_by_code( + region_code: str, + tax_benefit_model_id: UUID | None = Query( + None, + description="Tax-benefit model ID (required if multiple models have same region code)", + ), + tax_benefit_model_name: str | None = Query( + None, description="Tax-benefit model name (e.g., 'policyengine-us')" + ), + session: Session = Depends(get_session), +): + """Get a specific region by code. + + Region codes use a prefix format like "state/ca" or "constituency/Sheffield Central". + + Args: + region_code: The region code (e.g., "state/ca", "us"). + tax_benefit_model_id: Filter by tax-benefit model UUID. + tax_benefit_model_name: Filter by model name. + """ + query = select(Region).where(Region.code == region_code) + + if tax_benefit_model_id: + query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) + elif tax_benefit_model_name: + query = query.join(TaxBenefitModel).where( + TaxBenefitModel.name == tax_benefit_model_name + ) + + region = session.exec(query).first() + if not region: + raise HTTPException(status_code=404, detail="Region not found") + return region diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index c49b457..7361979 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -30,6 +30,7 @@ from .parameter_value import ParameterValue, ParameterValueCreate, ParameterValueRead from .policy import Policy, PolicyCreate, PolicyRead from .poverty import Poverty, PovertyCreate, PovertyRead +from .region import Region, RegionCreate, RegionRead from .program_statistics import ( ProgramStatistics, ProgramStatisticsCreate, @@ -107,6 +108,9 @@ "Poverty", "PovertyCreate", "PovertyRead", + "Region", + "RegionCreate", + "RegionRead", "ProgramStatistics", "ProgramStatisticsCreate", "ProgramStatisticsRead", diff --git a/src/policyengine_api/models/region.py b/src/policyengine_api/models/region.py new file mode 100644 index 0000000..7c87a00 --- /dev/null +++ b/src/policyengine_api/models/region.py @@ -0,0 +1,60 @@ +"""Region model for geographic areas used in analysis.""" + +from datetime import datetime, timezone +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import Field, Relationship, SQLModel + +if TYPE_CHECKING: + from .dataset import Dataset + from .tax_benefit_model import TaxBenefitModel + + +class RegionBase(SQLModel): + """Base region fields.""" + + code: str # e.g., "state/ca", "constituency/Sheffield Central" + label: str # e.g., "California", "Sheffield Central" + region_type: str # e.g., "state", "congressional_district", "constituency" + requires_filter: bool = False + filter_field: str | None = None # e.g., "state_code", "place_fips" + filter_value: str | None = None # e.g., "CA", "44000" + parent_code: str | None = None # e.g., "us", "state/ca" + state_code: str | None = None # For US regions + state_name: str | None = None # For US regions + dataset_id: UUID = Field(foreign_key="datasets.id") + tax_benefit_model_id: UUID = Field(foreign_key="tax_benefit_models.id") + + +class Region(RegionBase, table=True): + """Region database model. + + Regions represent geographic areas for analysis, from countries + down to states, congressional districts, cities, etc. + Each region has a dataset (either dedicated or filtered from parent). + """ + + __tablename__ = "regions" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + # Relationships + dataset: "Dataset" = Relationship() + tax_benefit_model: "TaxBenefitModel" = Relationship() + + +class RegionCreate(RegionBase): + """Schema for creating regions.""" + + pass + + +class RegionRead(RegionBase): + """Schema for reading regions.""" + + id: UUID + created_at: datetime + updated_at: datetime From 9cd42a2370d8dd6aab865304857e1979433f7129 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 10 Feb 2026 21:07:37 +0100 Subject: [PATCH 02/10] feat: Wire filter_field/filter_value through to policyengine.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add filter_field and filter_value to Simulation model - Include filter params in deterministic simulation ID generation - Pass filter params from region to simulation creation - Pass filter params to policyengine.py PESimulation when running 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/policyengine_api/api/analysis.py | 24 +++++++++++++++++++++-- src/policyengine_api/models/simulation.py | 10 ++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 101669b..17539d7 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -170,10 +170,12 @@ def _get_deterministic_simulation_id( dynamic_id: UUID | None, dataset_id: UUID | None = None, household_id: UUID | None = None, + filter_field: str | None = None, + filter_value: str | None = None, ) -> UUID: """Generate a deterministic UUID from simulation parameters.""" if simulation_type == SimulationType.ECONOMY: - key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}:{filter_field}:{filter_value}" else: key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}" return uuid5(SIMULATION_NAMESPACE, key) @@ -196,6 +198,8 @@ def _get_or_create_simulation( session: Session, dataset_id: UUID | None = None, household_id: UUID | None = None, + filter_field: str | None = None, + filter_value: str | None = None, ) -> Simulation: """Get existing simulation or create a new one.""" sim_id = _get_deterministic_simulation_id( @@ -205,6 +209,8 @@ def _get_or_create_simulation( dynamic_id, dataset_id=dataset_id, household_id=household_id, + filter_field=filter_field, + filter_value=filter_value, ) existing = session.get(Simulation, sim_id) @@ -220,6 +226,8 @@ def _get_or_create_simulation( policy_id=policy_id, dynamic_id=dynamic_id, status=SimulationStatus.PENDING, + filter_field=filter_field, + filter_value=filter_value, ) session.add(simulation) session.commit() @@ -487,12 +495,14 @@ def build_dynamic(dynamic_id): year=dataset.year, ) - # Run simulations + # Run simulations (with optional regional filtering) pe_baseline_sim = PESimulation( dataset=pe_dataset, tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, ) pe_baseline_sim.ensure() @@ -501,6 +511,8 @@ def build_dynamic(dynamic_id): tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, ) pe_reform_sim.ensure() @@ -678,6 +690,10 @@ def economic_impact( # Resolve dataset (and optionally region) dataset, region = _resolve_dataset_and_region(request, session) + # Extract filter parameters from region (if present) + filter_field = region.filter_field if region and region.requires_filter else None + filter_value = region.filter_value if region and region.requires_filter else None + # Get model version model_version = _get_model_version(request.tax_benefit_model_name, session) @@ -689,6 +705,8 @@ def economic_impact( dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, ) reform_sim = _get_or_create_simulation( @@ -698,6 +716,8 @@ def economic_impact( dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, ) # Get or create report diff --git a/src/policyengine_api/models/simulation.py b/src/policyengine_api/models/simulation.py index 985db3e..176f12e 100644 --- a/src/policyengine_api/models/simulation.py +++ b/src/policyengine_api/models/simulation.py @@ -46,6 +46,16 @@ class SimulationBase(SQLModel): status: SimulationStatus = SimulationStatus.PENDING error_message: str | None = None + # Regional filtering parameters (passed to policyengine.py) + filter_field: str | None = Field( + default=None, + description="Household-level variable to filter dataset by (e.g., 'place_fips', 'country')", + ) + filter_value: str | None = Field( + default=None, + description="Value to match when filtering (e.g., '44000', 'ENGLAND')", + ) + class Simulation(SimulationBase, table=True): """Simulation database model.""" From cea3410fe9faddfaea97cec3010fc2d2084e83e5 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 10 Feb 2026 23:41:08 +0100 Subject: [PATCH 03/10] feat: Add filter pass-through to Modal functions + region unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Wire filter_field/filter_value through Modal functions to policyengine.py: - simulate_economy_uk, simulate_economy_us - economy_comparison_uk, economy_comparison_us - Add fixtures_regions.py with factory functions for test data - Add 25 unit tests for region resolution and filtering: - test__given_region_with_filter__then_filter_params_included.py - test__given_region_without_filter__then_filter_params_none.py - test__given_dataset_id__then_region_is_none.py - test__given_same_params__then_deterministic_id.py - test__given_invalid_region__then_404_error.py - test__given_existing_simulation__then_reuses_existing.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/policyengine_api/modal_app.py | 27 +- test_fixtures/fixtures_regions.py | 260 ++++++++++++++++++ ...__given_dataset_id__then_region_is_none.py | 94 +++++++ ...isting_simulation__then_reuses_existing.py | 144 ++++++++++ ...t__given_invalid_region__then_404_error.py | 107 +++++++ ...ith_filter__then_filter_params_included.py | 170 ++++++++++++ ...without_filter__then_filter_params_none.py | 110 ++++++++ ...iven_same_params__then_deterministic_id.py | 171 ++++++++++++ 8 files changed, 1078 insertions(+), 5 deletions(-) create mode 100644 test_fixtures/fixtures_regions.py create mode 100644 tests/test__given_dataset_id__then_region_is_none.py create mode 100644 tests/test__given_existing_simulation__then_reuses_existing.py create mode 100644 tests/test__given_invalid_region__then_404_error.py create mode 100644 tests/test__given_region_with_filter__then_filter_params_included.py create mode 100644 tests/test__given_region_without_filter__then_filter_params_none.py create mode 100644 tests/test__given_same_params__then_deterministic_id.py diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 1aa8119..332c349 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -841,6 +841,8 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N tax_benefit_model_version=pe_model_version, policy=policy, dynamic=dynamic, + filter_field=simulation.filter_field, + filter_value=simulation.filter_value, ) pe_sim.ensure() @@ -1007,6 +1009,8 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N tax_benefit_model_version=pe_model_version, policy=policy, dynamic=dynamic, + filter_field=simulation.filter_field, + filter_value=simulation.filter_value, ) pe_sim.ensure() @@ -1112,11 +1116,12 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: # Debug: log the key role import base64 import json + try: - payload = supabase_key.split('.')[1] - payload += '=' * (4 - len(payload) % 4) + payload = supabase_key.split(".")[1] + payload += "=" * (4 - len(payload) % 4) decoded = json.loads(base64.urlsafe_b64decode(payload)) - logfire.info("Supabase key info", role=decoded.get('role', 'unknown')) + logfire.info("Supabase key info", role=decoded.get("role", "unknown")) except Exception as e: logfire.warn("Could not decode key", error=str(e)) @@ -1213,6 +1218,8 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, ) pe_baseline_sim.ensure() @@ -1222,6 +1229,8 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, ) pe_reform_sim.ensure() @@ -1535,6 +1544,8 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: tax_benefit_model_version=pe_model_version, policy=baseline_policy, dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, ) pe_baseline_sim.ensure() @@ -1544,6 +1555,8 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: tax_benefit_model_version=pe_model_version, policy=reform_policy, dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, ) pe_reform_sim.ensure() @@ -1930,7 +1943,9 @@ def compute_aggregate_uk(aggregate_id: str, traceparent: str | None = None) -> N pe_aggregate = PEAggregate( simulation=pe_sim, variable=aggregate.variable, - aggregate_type=PEAggregateType(aggregate.aggregate_type.value), + aggregate_type=PEAggregateType( + aggregate.aggregate_type.value + ), entity=aggregate.entity, ) pe_aggregate.run() @@ -2074,7 +2089,9 @@ def compute_aggregate_us(aggregate_id: str, traceparent: str | None = None) -> N pe_aggregate = PEAggregate( simulation=pe_sim, variable=aggregate.variable, - aggregate_type=PEAggregateType(aggregate.aggregate_type.value), + aggregate_type=PEAggregateType( + aggregate.aggregate_type.value + ), entity=aggregate.entity, ) pe_aggregate.run() diff --git a/test_fixtures/fixtures_regions.py b/test_fixtures/fixtures_regions.py new file mode 100644 index 0000000..e95e0d8 --- /dev/null +++ b/test_fixtures/fixtures_regions.py @@ -0,0 +1,260 @@ +"""Fixtures and helpers for region-related tests.""" + +from uuid import uuid4 + +import pytest + +from policyengine_api.models import ( + Dataset, + Region, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + +TEST_UUIDS = { + "DATASET": uuid4(), + "DATASET_UK": uuid4(), + "DATASET_US": uuid4(), + "MODEL_UK": uuid4(), + "MODEL_US": uuid4(), + "MODEL_VERSION_UK": uuid4(), + "MODEL_VERSION_US": uuid4(), + "REGION_UK": uuid4(), + "REGION_US_STATE": uuid4(), + "REGION_US_NATIONAL": uuid4(), + "POLICY": uuid4(), + "DYNAMIC": uuid4(), +} + +REGION_CODES = { + "UK_ENGLAND": "country/england", + "US_CALIFORNIA": "state/ca", + "US_NATIONAL": "us", + "UK_NATIONAL": "uk", +} + +FILTER_FIELDS = { + "UK_COUNTRY": "country", + "US_STATE": "state_code", + "US_FIPS": "place_fips", +} + +FILTER_VALUES = { + "ENGLAND": "ENGLAND", + "CALIFORNIA": "CA", + "CA_FIPS": "06000", +} + + +# ----------------------------------------------------------------------------- +# Factory Functions +# ----------------------------------------------------------------------------- + + +def create_tax_benefit_model( + session, name: str = "policyengine-uk", description: str = "UK model" +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel.""" + model = TaxBenefitModel(name=name, description=description) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_tax_benefit_model_version( + session, model: TaxBenefitModel, version: str = "1.0.0" +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion.""" + model_version = TaxBenefitModelVersion( + model_id=model.id, + version=version, + description=f"Version {version}", + ) + session.add(model_version) + session.commit() + session.refresh(model_version) + return model_version + + +def create_dataset( + session, + model: TaxBenefitModel, + name: str = "test_dataset", + filepath: str = "test/path/dataset.h5", + year: int = 2024, +) -> Dataset: + """Create and persist a Dataset.""" + dataset = Dataset( + name=name, + description=f"Test dataset: {name}", + filepath=filepath, + year=year, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + return dataset + + +def create_region( + session, + model: TaxBenefitModel, + dataset: Dataset, + code: str, + label: str, + region_type: str, + requires_filter: bool = False, + filter_field: str | None = None, + filter_value: str | None = None, +) -> Region: + """Create and persist a Region.""" + region = Region( + code=code, + label=label, + region_type=region_type, + requires_filter=requires_filter, + filter_field=filter_field, + filter_value=filter_value, + dataset_id=dataset.id, + tax_benefit_model_id=model.id, + ) + session.add(region) + session.commit() + session.refresh(region) + return region + + +def create_simulation( + session, + dataset: Dataset, + model_version: TaxBenefitModelVersion, + filter_field: str | None = None, + filter_value: str | None = None, + status: SimulationStatus = SimulationStatus.PENDING, +) -> Simulation: + """Create and persist a Simulation with optional filter parameters.""" + simulation = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=model_version.id, + status=status, + filter_field=filter_field, + filter_value=filter_value, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +# ----------------------------------------------------------------------------- +# Composite Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def uk_model_and_version(session): + """Create UK model with version.""" + model = create_tax_benefit_model( + session, name="policyengine-uk", description="UK model" + ) + version = create_tax_benefit_model_version(session, model) + return model, version + + +@pytest.fixture +def us_model_and_version(session): + """Create US model with version.""" + model = create_tax_benefit_model( + session, name="policyengine-us", description="US model" + ) + version = create_tax_benefit_model_version(session, model) + return model, version + + +@pytest.fixture +def uk_dataset(session, uk_model_and_version): + """Create a UK dataset.""" + model, _ = uk_model_and_version + return create_dataset( + session, model, name="uk_enhanced_frs", filepath="uk/enhanced_frs_2024.h5" + ) + + +@pytest.fixture +def us_dataset(session, us_model_and_version): + """Create a US dataset.""" + model, _ = us_model_and_version + return create_dataset(session, model, name="us_cps", filepath="us/cps_2024.h5") + + +@pytest.fixture +def uk_region_national(session, uk_model_and_version, uk_dataset): + """Create UK national region (no filter required).""" + model, _ = uk_model_and_version + return create_region( + session, + model=model, + dataset=uk_dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + + +@pytest.fixture +def uk_region_england(session, uk_model_and_version, uk_dataset): + """Create England region (filter required).""" + model, _ = uk_model_and_version + return create_region( + session, + model=model, + dataset=uk_dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + + +@pytest.fixture +def us_region_national(session, us_model_and_version, us_dataset): + """Create US national region (no filter required).""" + model, _ = us_model_and_version + return create_region( + session, + model=model, + dataset=us_dataset, + code="us", + label="United States", + region_type="national", + requires_filter=False, + ) + + +@pytest.fixture +def us_region_california(session, us_model_and_version, us_dataset): + """Create California state region (filter required).""" + model, _ = us_model_and_version + return create_region( + session, + model=model, + dataset=us_dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) diff --git a/tests/test__given_dataset_id__then_region_is_none.py b/tests/test__given_dataset_id__then_region_is_none.py new file mode 100644 index 0000000..ee3c1d5 --- /dev/null +++ b/tests/test__given_dataset_id__then_region_is_none.py @@ -0,0 +1,94 @@ +"""Tests for dataset resolution when dataset_id is provided directly. + +When a dataset_id is provided instead of a region code, +the resolved region should be None. +""" + +import pytest +from sqlmodel import Session + +from policyengine_api.api.analysis import ( + EconomicImpactRequest, + _resolve_dataset_and_region, +) +from test_fixtures.fixtures_regions import ( + create_dataset, + create_tax_benefit_model, +) + + +class TestResolveDatasetWithDatasetId: + """Tests for _resolve_dataset_and_region when dataset_id is provided.""" + + def test_given_dataset_id_then_region_is_none(self, session: Session): + """Given a dataset_id, then region is None in the response.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset.id, + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_region is None + + def test_given_dataset_id_then_dataset_is_returned(self, session: Session): + """Given a dataset_id, then the correct dataset is returned.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset.id, + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + def test_given_dataset_id_and_region_then_region_takes_precedence( + self, session: Session + ): + """Given both dataset_id and region, then region takes precedence.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset1 = create_dataset(session, model, name="dataset_from_id") + dataset2 = create_dataset(session, model, name="dataset_from_region") + from test_fixtures.fixtures_regions import create_region + + region = create_region( + session, + model=model, + dataset=dataset2, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset1.id, + region="uk", + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + # Region code takes precedence, so we get dataset2 + assert resolved_dataset.id == dataset2.id + assert resolved_region is not None + assert resolved_region.code == "uk" diff --git a/tests/test__given_existing_simulation__then_reuses_existing.py b/tests/test__given_existing_simulation__then_reuses_existing.py new file mode 100644 index 0000000..77731f1 --- /dev/null +++ b/tests/test__given_existing_simulation__then_reuses_existing.py @@ -0,0 +1,144 @@ +"""Tests for simulation reuse with filter parameters. + +When a simulation with the same parameters already exists, +it should be reused instead of creating a new one. +""" + +import pytest +from sqlmodel import Session + +from policyengine_api.api.analysis import _get_or_create_simulation +from policyengine_api.models import SimulationStatus +from test_fixtures.fixtures_regions import ( + create_dataset, + create_simulation, + create_tax_benefit_model, + create_tax_benefit_model_version, +) + + +class TestSimulationReuse: + """Tests for simulation reuse behavior.""" + + def test_given_existing_simulation_with_filter_then_reuses(self, session: Session): + """Given an existing simulation with filter params, then it is reused.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # Create initial simulation with filter params + first_sim = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + # When - request same simulation again + second_sim = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + # Then + assert first_sim.id == second_sim.id + + def test_given_different_filter_then_creates_new_simulation(self, session: Session): + """Given different filter params, then a new simulation is created.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # Create simulation for England + england_sim = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + # When - request simulation for Scotland + scotland_sim = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="SCOTLAND", + ) + + # Then + assert england_sim.id != scotland_sim.id + assert england_sim.filter_value == "ENGLAND" + assert scotland_sim.filter_value == "SCOTLAND" + + def test_given_no_filter_vs_filter_then_creates_separate_simulations( + self, session: Session + ): + """Given national vs filtered, then separate simulations are created.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # Create national (no filter) simulation + national_sim = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field=None, + filter_value=None, + ) + + # When - request filtered simulation + filtered_sim = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + # Then + assert national_sim.id != filtered_sim.id + assert national_sim.filter_field is None + assert filtered_sim.filter_field == "country" + + def test_given_new_simulation_then_status_is_pending(self, session: Session): + """Given a new simulation request, then status is PENDING.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + simulation = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + # Then + assert simulation.status == SimulationStatus.PENDING diff --git a/tests/test__given_invalid_region__then_404_error.py b/tests/test__given_invalid_region__then_404_error.py new file mode 100644 index 0000000..562a5c9 --- /dev/null +++ b/tests/test__given_invalid_region__then_404_error.py @@ -0,0 +1,107 @@ +"""Tests for region resolution error cases. + +When an invalid region code is provided or required parameters are missing, +appropriate HTTP errors should be raised. +""" + +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from policyengine_api.api.analysis import ( + EconomicImpactRequest, + _resolve_dataset_and_region, +) +from test_fixtures.fixtures_regions import ( + create_dataset, + create_region, + create_tax_benefit_model, +) + + +class TestInvalidRegionCode: + """Tests for invalid region code handling.""" + + def test_given_nonexistent_region_code_then_raises_404(self, session: Session): + """Given a region code that doesn't exist, then raises 404.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + # Note: No region is created for this code + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="nonexistent/region", + ) + + # When/Then + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + def test_given_region_for_wrong_model_then_raises_404(self, session: Session): + """Given a region code for wrong model, then raises 404.""" + # Given + uk_model = create_tax_benefit_model(session, name="policyengine-uk") + uk_dataset = create_dataset(session, uk_model, name="uk_enhanced_frs") + create_region( + session, + model=uk_model, + dataset=uk_dataset, + code="uk", + label="United Kingdom", + region_type="national", + ) + # Request uses US model but UK region code + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="uk", + ) + + # When/Then + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + + +class TestMissingRequiredParams: + """Tests for missing required parameters.""" + + def test_given_neither_dataset_nor_region_then_raises_400(self, session: Session): + """Given neither dataset_id nor region, then raises 400.""" + # Given + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + # Neither dataset_id nor region provided + ) + + # When/Then + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 400 + assert "either dataset_id or region" in exc_info.value.detail.lower() + + +class TestNonexistentDataset: + """Tests for nonexistent dataset handling.""" + + def test_given_nonexistent_dataset_id_then_raises_404(self, session: Session): + """Given a dataset_id that doesn't exist, then raises 404.""" + # Given + from uuid import uuid4 + + nonexistent_id = uuid4() + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=nonexistent_id, + ) + + # When/Then + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() diff --git a/tests/test__given_region_with_filter__then_filter_params_included.py b/tests/test__given_region_with_filter__then_filter_params_included.py new file mode 100644 index 0000000..c84a372 --- /dev/null +++ b/tests/test__given_region_with_filter__then_filter_params_included.py @@ -0,0 +1,170 @@ +"""Tests for region resolution with filter parameters. + +When a region requires filtering (e.g., England from UK dataset, +California from US dataset), the filter_field and filter_value +should be extracted and passed through to simulations. +""" + +import pytest +from sqlmodel import Session + +from policyengine_api.api.analysis import ( + EconomicImpactRequest, + _get_or_create_simulation, + _resolve_dataset_and_region, +) +from test_fixtures.fixtures_regions import ( + create_dataset, + create_region, + create_tax_benefit_model, + create_tax_benefit_model_version, +) + + +class TestResolveDatasetAndRegionWithFilter: + """Tests for _resolve_dataset_and_region when region requires filtering.""" + + def test_given_region_requires_filter_then_returns_filter_field( + self, session: Session + ): + """Given a region that requires filtering, then filter_field is populated.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + region = create_region( + session, + model=model, + dataset=dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="country/england", + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_region is not None + assert resolved_region.filter_field == "country" + assert resolved_region.filter_value == "ENGLAND" + assert resolved_region.requires_filter is True + + def test_given_us_state_region_then_returns_state_filter(self, session: Session): + """Given a US state region, then returns state code filter.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-us") + dataset = create_dataset(session, model, name="us_cps") + region = create_region( + session, + model=model, + dataset=dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="state/ca", + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_region is not None + assert resolved_region.filter_field == "state_code" + assert resolved_region.filter_value == "CA" + + def test_given_region_with_filter_then_dataset_is_resolved(self, session: Session): + """Given a region code, then the associated dataset is returned.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + region = create_region( + session, + model=model, + dataset=dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="country/england", + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + +class TestSimulationCreationWithFilter: + """Tests for creating simulations with filter parameters.""" + + def test_given_filter_params_then_simulation_has_filter_fields( + self, session: Session + ): + """Given filter parameters, then created simulation has filter fields populated.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + simulation = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + # Then + assert simulation.filter_field == "country" + assert simulation.filter_value == "ENGLAND" + + def test_given_no_filter_params_then_simulation_has_null_filter_fields( + self, session: Session + ): + """Given no filter parameters, then created simulation has null filter fields.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + # When + simulation = _get_or_create_simulation( + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + ) + + # Then + assert simulation.filter_field is None + assert simulation.filter_value is None diff --git a/tests/test__given_region_without_filter__then_filter_params_none.py b/tests/test__given_region_without_filter__then_filter_params_none.py new file mode 100644 index 0000000..e81d7a8 --- /dev/null +++ b/tests/test__given_region_without_filter__then_filter_params_none.py @@ -0,0 +1,110 @@ +"""Tests for region resolution without filter parameters. + +When a region does not require filtering (e.g., national UK or US), +the filter_field and filter_value should be None. +""" + +import pytest +from sqlmodel import Session + +from policyengine_api.api.analysis import ( + EconomicImpactRequest, + _resolve_dataset_and_region, +) +from test_fixtures.fixtures_regions import ( + create_dataset, + create_region, + create_tax_benefit_model, +) + + +class TestResolveDatasetAndRegionWithoutFilter: + """Tests for _resolve_dataset_and_region when region does not require filtering.""" + + def test_given_national_uk_region_then_filter_params_none(self, session: Session): + """Given UK national region, then filter_field and filter_value are None.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + region = create_region( + session, + model=model, + dataset=dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="uk", + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_region is not None + assert resolved_region.requires_filter is False + assert resolved_region.filter_field is None + assert resolved_region.filter_value is None + + def test_given_national_us_region_then_filter_params_none(self, session: Session): + """Given US national region, then filter_field and filter_value are None.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-us") + dataset = create_dataset(session, model, name="us_cps") + region = create_region( + session, + model=model, + dataset=dataset, + code="us", + label="United States", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="us", + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_region is not None + assert resolved_region.requires_filter is False + assert resolved_region.filter_field is None + assert resolved_region.filter_value is None + + def test_given_national_region_then_dataset_still_resolved(self, session: Session): + """Given national region without filter, then dataset is still correctly resolved.""" + # Given + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + region = create_region( + session, + model=model, + dataset=dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="uk", + ) + + # When + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + # Then + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" diff --git a/tests/test__given_same_params__then_deterministic_id.py b/tests/test__given_same_params__then_deterministic_id.py new file mode 100644 index 0000000..d393f53 --- /dev/null +++ b/tests/test__given_same_params__then_deterministic_id.py @@ -0,0 +1,171 @@ +"""Tests for deterministic simulation ID generation. + +The simulation ID is generated deterministically from the simulation +parameters (dataset, model version, policy, dynamic, filter params). +This ensures that re-running the same simulation reuses existing results. +""" + +from uuid import uuid4 + +import pytest + +from policyengine_api.api.analysis import _get_deterministic_simulation_id + + +class TestDeterministicSimulationId: + """Tests for _get_deterministic_simulation_id function.""" + + def test_given_same_params_then_same_id_returned(self): + """Given identical parameters, then the same ID is returned.""" + # Given + dataset_id = uuid4() + model_version_id = uuid4() + policy_id = uuid4() + dynamic_id = uuid4() + filter_field = "country" + filter_value = "ENGLAND" + + # When + id1 = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field, + filter_value, + ) + id2 = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field, + filter_value, + ) + + # Then + assert id1 == id2 + + def test_given_different_filter_field_then_different_id(self): + """Given different filter_field, then a different ID is returned.""" + # Given + dataset_id = uuid4() + model_version_id = uuid4() + policy_id = None + dynamic_id = None + + # When + id1 = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field="state_code", + filter_value="ENGLAND", + ) + + # Then + assert id1 != id2 + + def test_given_different_filter_value_then_different_id(self): + """Given different filter_value, then a different ID is returned.""" + # Given + dataset_id = uuid4() + model_version_id = uuid4() + policy_id = None + dynamic_id = None + + # When + id1 = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field="country", + filter_value="SCOTLAND", + ) + + # Then + assert id1 != id2 + + def test_given_filter_none_vs_filter_set_then_different_id(self): + """Given None filter vs set filter, then different IDs are returned.""" + # Given + dataset_id = uuid4() + model_version_id = uuid4() + policy_id = None + dynamic_id = None + + # When + id_no_filter = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field=None, + filter_value=None, + ) + id_with_filter = _get_deterministic_simulation_id( + dataset_id, + model_version_id, + policy_id, + dynamic_id, + filter_field="country", + filter_value="ENGLAND", + ) + + # Then + assert id_no_filter != id_with_filter + + def test_given_different_dataset_then_different_id(self): + """Given different dataset_id, then a different ID is returned.""" + # Given + model_version_id = uuid4() + policy_id = None + dynamic_id = None + filter_field = "country" + filter_value = "ENGLAND" + + # When + id1 = _get_deterministic_simulation_id( + uuid4(), model_version_id, policy_id, dynamic_id, filter_field, filter_value + ) + id2 = _get_deterministic_simulation_id( + uuid4(), model_version_id, policy_id, dynamic_id, filter_field, filter_value + ) + + # Then + assert id1 != id2 + + def test_given_null_optional_params_then_consistent_id(self): + """Given null optional parameters, then consistent ID is generated.""" + # Given + dataset_id = uuid4() + model_version_id = uuid4() + + # When + id1 = _get_deterministic_simulation_id( + dataset_id, model_version_id, None, None, None, None + ) + id2 = _get_deterministic_simulation_id( + dataset_id, model_version_id, None, None, None, None + ) + + # Then + assert id1 == id2 From 325db403a763474c6332d45f9d7788d38c188107 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Feb 2026 23:45:13 +0100 Subject: [PATCH 04/10] feat: Add seed script for US and UK regions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add seed_regions.py to populate the regions table with geographic data from policyengine.py's region registries: - US: National + 51 states (DC included) - UK: National + 4 countries (England, Scotland, Wales, NI) Optional flags: - --include-places: Add US cities (333 places over 100K population) - --include-districts: Add US congressional districts (436) - --us-only / --uk-only: Seed only one country 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- scripts/seed_regions.py | 280 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 scripts/seed_regions.py diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py new file mode 100644 index 0000000..f4f9954 --- /dev/null +++ b/scripts/seed_regions.py @@ -0,0 +1,280 @@ +"""Seed regions for US and UK geographic analysis. + +This script populates the regions table with: +- US: National, 51 states (incl. DC), and optionally places/cities +- UK: National and 4 countries (England, Scotland, Wales, Northern Ireland) + +Regions are sourced from policyengine.py's region registries and linked +to the appropriate datasets in the database. + +Usage: + python scripts/seed_regions.py # Seed US and UK regions + python scripts/seed_regions.py --us-only # Seed only US regions + python scripts/seed_regions.py --uk-only # Seed only UK regions + python scripts/seed_regions.py --include-places # Include US places (cities) +""" + +import argparse +import sys +import time +from pathlib import Path + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import Session, create_engine, select + +from policyengine_api.config.settings import settings +from policyengine_api.models import Dataset, Region, TaxBenefitModel + +console = Console() + + +def get_session() -> Session: + """Get database session.""" + engine = create_engine(settings.database_url) + return Session(engine) + + +def seed_us_regions( + session: Session, + include_places: bool = False, + include_districts: bool = False, +) -> tuple[int, int]: + """Seed US regions from policyengine.py registry. + + Args: + session: Database session + include_places: Include US places (cities over 100K population) + include_districts: Include congressional districts + + Returns: + Tuple of (created_count, skipped_count) + """ + from policyengine.countries.us.regions import us_region_registry + + # Get US model + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed.py first.[/red]") + return 0, 0 + + # Get US national dataset (CPS) + us_dataset = session.exec( + select(Dataset) + .where(Dataset.tax_benefit_model_id == us_model.id) + .where(Dataset.name.contains("cps")) # type: ignore + .order_by(Dataset.year.desc()) # type: ignore + ).first() + + if not us_dataset: + console.print("[red]Error: US dataset not found. Run seed.py first.[/red]") + return 0, 0 + + created = 0 + skipped = 0 + + # Filter regions based on options + regions_to_seed = [] + for region in us_region_registry.regions: + if region.region_type == "national": + regions_to_seed.append(region) + elif region.region_type == "state": + regions_to_seed.append(region) + elif region.region_type == "congressional_district" and include_districts: + regions_to_seed.append(region) + elif region.region_type == "place" and include_places: + regions_to_seed.append(region) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("US regions", total=len(regions_to_seed)) + + for pe_region in regions_to_seed: + progress.update(task, description=f"US: {pe_region.label}") + + # Check if region already exists + existing = session.exec( + select(Region).where(Region.code == pe_region.code) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Create region record + db_region = Region( + code=pe_region.code, + label=pe_region.label, + region_type=pe_region.region_type, + requires_filter=pe_region.requires_filter, + filter_field=pe_region.filter_field, + filter_value=pe_region.filter_value, + parent_code=pe_region.parent_code, + state_code=pe_region.state_code, + state_name=pe_region.state_name, + dataset_id=us_dataset.id, # All US regions use the national dataset + tax_benefit_model_id=us_model.id, + ) + session.add(db_region) + created += 1 + progress.advance(task) + + session.commit() + + return created, skipped + + +def seed_uk_regions(session: Session) -> tuple[int, int]: + """Seed UK regions from policyengine.py registry. + + Args: + session: Database session + + Returns: + Tuple of (created_count, skipped_count) + """ + from policyengine.countries.uk.regions import uk_region_registry + + # Get UK model + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print( + "[yellow]Warning: UK model not found. Skipping UK regions.[/yellow]" + ) + return 0, 0 + + # Get UK national dataset (FRS) + uk_dataset = session.exec( + select(Dataset) + .where(Dataset.tax_benefit_model_id == uk_model.id) + .where(Dataset.name.contains("frs")) # type: ignore + .order_by(Dataset.year.desc()) # type: ignore + ).first() + + if not uk_dataset: + console.print( + "[yellow]Warning: UK dataset not found. Skipping UK regions.[/yellow]" + ) + return 0, 0 + + created = 0 + skipped = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK regions", total=len(uk_region_registry.regions)) + + for pe_region in uk_region_registry.regions: + progress.update(task, description=f"UK: {pe_region.label}") + + # Check if region already exists + existing = session.exec( + select(Region).where(Region.code == pe_region.code) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Create region record + db_region = Region( + code=pe_region.code, + label=pe_region.label, + region_type=pe_region.region_type, + requires_filter=pe_region.requires_filter, + filter_field=pe_region.filter_field, + filter_value=pe_region.filter_value, + parent_code=pe_region.parent_code, + state_code=None, # UK regions don't have state_code + state_name=None, + dataset_id=uk_dataset.id, # All UK regions use the national dataset + tax_benefit_model_id=uk_model.id, + ) + session.add(db_region) + created += 1 + progress.advance(task) + + session.commit() + + return created, skipped + + +def main(): + parser = argparse.ArgumentParser(description="Seed US and UK regions") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US regions", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK regions", + ) + parser.add_argument( + "--include-places", + action="store_true", + help="Include US places (cities over 100K population)", + ) + parser.add_argument( + "--include-districts", + action="store_true", + help="Include US congressional districts", + ) + args = parser.parse_args() + + console.print("[bold green]Seeding regions...[/bold green]\n") + + start = time.time() + total_created = 0 + total_skipped = 0 + + with get_session() as session: + # Seed US regions + if not args.uk_only: + console.print("[bold]US Regions[/bold]") + us_created, us_skipped = seed_us_regions( + session, + include_places=args.include_places, + include_districts=args.include_districts, + ) + total_created += us_created + total_skipped += us_skipped + console.print( + f"[green]✓[/green] US regions: {us_created} created, {us_skipped} skipped\n" + ) + + # Seed UK regions + if not args.us_only: + console.print("[bold]UK Regions[/bold]") + uk_created, uk_skipped = seed_uk_regions(session) + total_created += uk_created + total_skipped += uk_skipped + console.print( + f"[green]✓[/green] UK regions: {uk_created} created, {uk_skipped} skipped\n" + ) + + elapsed = time.time() - start + console.print(f"[bold]Total: {total_created} created, {total_skipped} skipped[/bold]") + console.print(f"[bold]Time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() From b0f7082217ad6b6efbb207b5e89022aa174a899e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 12 Feb 2026 16:56:18 +0100 Subject: [PATCH 05/10] feat: Integrate regions seeding into main seed script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add --skip-regions, --include-places, and --include-districts CLI options to seed.py. Regions are now seeded as part of the standard database setup process, sourcing region definitions from policyengine.py's registries. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- scripts/seed.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/scripts/seed.py b/scripts/seed.py index 4274528..2568d26 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -631,6 +631,21 @@ def main(): action="store_true", help="Skip UK datasets (useful when HuggingFace token is not available)", ) + parser.add_argument( + "--skip-regions", + action="store_true", + help="Skip seeding regions", + ) + parser.add_argument( + "--include-places", + action="store_true", + help="Include US places (cities over 100K population) when seeding regions", + ) + parser.add_argument( + "--include-districts", + action="store_true", + help="Include US congressional districts when seeding regions", + ) args = parser.parse_args() with logfire.span("database_seeding"): @@ -652,6 +667,26 @@ def main(): # Seed example policies seed_example_policies(session) + # Seed regions + if not args.skip_regions: + from seed_regions import seed_us_regions, seed_uk_regions + + console.print("\n[bold]Seeding regions...[/bold]") + us_created, us_skipped = seed_us_regions( + session, + include_places=args.include_places, + include_districts=args.include_districts, + ) + console.print( + f"[green]✓[/green] US regions: {us_created} created, {us_skipped} skipped" + ) + + if not args.skip_uk_datasets: + uk_created, uk_skipped = seed_uk_regions(session) + console.print( + f"[green]✓[/green] UK regions: {uk_created} created, {uk_skipped} skipped" + ) + console.print("\n[bold green]✓ Database seeding complete![/bold green]") From 662b6cb2c9bddcdd05681ed65f94d7be008d8efc Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 12 Feb 2026 17:48:17 +0100 Subject: [PATCH 06/10] refactor: Change regions CLI to use skip flags instead of include flags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Default behavior now seeds all US regions (national, states, districts, places). Use --skip-places and --skip-districts to exclude specific region types. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- scripts/seed.py | 12 ++++++------ scripts/seed_regions.py | 31 ++++++++++++++++--------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/scripts/seed.py b/scripts/seed.py index 2568d26..072b7bb 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -637,14 +637,14 @@ def main(): help="Skip seeding regions", ) parser.add_argument( - "--include-places", + "--skip-places", action="store_true", - help="Include US places (cities over 100K population) when seeding regions", + help="Skip US places (cities over 100K population) when seeding regions", ) parser.add_argument( - "--include-districts", + "--skip-districts", action="store_true", - help="Include US congressional districts when seeding regions", + help="Skip US congressional districts when seeding regions", ) args = parser.parse_args() @@ -674,8 +674,8 @@ def main(): console.print("\n[bold]Seeding regions...[/bold]") us_created, us_skipped = seed_us_regions( session, - include_places=args.include_places, - include_districts=args.include_districts, + skip_places=args.skip_places, + skip_districts=args.skip_districts, ) console.print( f"[green]✓[/green] US regions: {us_created} created, {us_skipped} skipped" diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py index f4f9954..c8cc9d8 100644 --- a/scripts/seed_regions.py +++ b/scripts/seed_regions.py @@ -1,17 +1,18 @@ """Seed regions for US and UK geographic analysis. This script populates the regions table with: -- US: National, 51 states (incl. DC), and optionally places/cities +- US: National, 51 states (incl. DC), 436 congressional districts, 333 places/cities - UK: National and 4 countries (England, Scotland, Wales, Northern Ireland) Regions are sourced from policyengine.py's region registries and linked to the appropriate datasets in the database. Usage: - python scripts/seed_regions.py # Seed US and UK regions + python scripts/seed_regions.py # Seed all US and UK regions python scripts/seed_regions.py --us-only # Seed only US regions python scripts/seed_regions.py --uk-only # Seed only UK regions - python scripts/seed_regions.py --include-places # Include US places (cities) + python scripts/seed_regions.py --skip-places # Exclude US places (cities) + python scripts/seed_regions.py --skip-districts # Exclude US congressional districts """ import argparse @@ -40,15 +41,15 @@ def get_session() -> Session: def seed_us_regions( session: Session, - include_places: bool = False, - include_districts: bool = False, + skip_places: bool = False, + skip_districts: bool = False, ) -> tuple[int, int]: """Seed US regions from policyengine.py registry. Args: session: Database session - include_places: Include US places (cities over 100K population) - include_districts: Include congressional districts + skip_places: Skip US places (cities over 100K population) + skip_districts: Skip congressional districts Returns: Tuple of (created_count, skipped_count) @@ -86,9 +87,9 @@ def seed_us_regions( regions_to_seed.append(region) elif region.region_type == "state": regions_to_seed.append(region) - elif region.region_type == "congressional_district" and include_districts: + elif region.region_type == "congressional_district" and not skip_districts: regions_to_seed.append(region) - elif region.region_type == "place" and include_places: + elif region.region_type == "place" and not skip_places: regions_to_seed.append(region) with Progress( @@ -229,14 +230,14 @@ def main(): help="Only seed UK regions", ) parser.add_argument( - "--include-places", + "--skip-places", action="store_true", - help="Include US places (cities over 100K population)", + help="Skip US places (cities over 100K population)", ) parser.add_argument( - "--include-districts", + "--skip-districts", action="store_true", - help="Include US congressional districts", + help="Skip US congressional districts", ) args = parser.parse_args() @@ -252,8 +253,8 @@ def main(): console.print("[bold]US Regions[/bold]") us_created, us_skipped = seed_us_regions( session, - include_places=args.include_places, - include_districts=args.include_districts, + skip_places=args.skip_places, + skip_districts=args.skip_districts, ) total_created += us_created total_skipped += us_skipped From 673de0935bed897ae07860ac375f097a38409adb Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 13 Feb 2026 00:40:03 +0100 Subject: [PATCH 07/10] refactor: Split seed.py into modular subscripts with presets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Break up monolithic seed.py into focused subscripts: - seed_utils.py: Shared utilities (get_session, bulk_insert, console) - seed_models.py: TaxBenefitModel, Version, Variables, Parameters, ParameterValues - seed_datasets.py: Dataset seeding and S3 upload - seed_policies.py: Example policy reforms - seed_regions.py: Geographic regions (updated to use seed_utils) Main seed.py is now an orchestrator with preset configurations: - full: Everything (default) - lite: Both countries, 2026 only, skip state params, core regions - minimal: Both countries, 2026 only, no policies/regions - uk-lite, uk-minimal: UK-only variants - us-lite, us-minimal: US-only variants Each subscript can also run standalone with its own CLI. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- scripts/seed.py | 880 ++++++--------------- scripts/seed_datasets.py | 227 ++++++ scripts/{seed_common.py => seed_models.py} | 211 +++-- scripts/seed_policies.py | 303 ++++--- scripts/seed_regions.py | 20 +- scripts/seed_uk_datasets.py | 113 --- scripts/seed_uk_model.py | 33 - scripts/seed_us_datasets.py | 108 --- scripts/seed_us_model.py | 33 - scripts/seed_utils.py | 72 ++ 10 files changed, 800 insertions(+), 1200 deletions(-) create mode 100644 scripts/seed_datasets.py rename scripts/{seed_common.py => seed_models.py} (68%) delete mode 100644 scripts/seed_uk_datasets.py delete mode 100644 scripts/seed_uk_model.py delete mode 100644 scripts/seed_us_datasets.py delete mode 100644 scripts/seed_us_model.py create mode 100644 scripts/seed_utils.py diff --git a/scripts/seed.py b/scripts/seed.py index 072b7bb..a1e1c35 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -1,693 +1,259 @@ -"""Seed database with UK and US models, variables, parameters, datasets.""" +"""Seed PolicyEngine database with models, datasets, policies, and regions. + +This is the main orchestrator script that calls individual seed scripts +based on the selected preset. + +Presets: + full - Everything (default) + lite - Both countries, 2026 datasets only, skip state params, core regions + minimal - Both countries, 2026 datasets only, skip state params, no policies/regions + uk-lite - UK only, 2026 datasets, skip state params + uk-minimal - UK only, 2026 datasets, skip state params, no policies/regions + us-lite - US only, 2026 datasets, skip state params, core regions only + us-minimal - US only, 2026 datasets, skip state params, no policies/regions + +Usage: + python scripts/seed.py # Full seed (default) + python scripts/seed.py --preset=lite # Lite mode for both countries + python scripts/seed.py --preset=us-lite # US only, lite mode + python scripts/seed.py --preset=minimal # Minimal seed (no policies/regions) +""" import argparse -import json -import logging -import math -import sys -import warnings -from datetime import datetime, timezone -from pathlib import Path -from uuid import uuid4 - -import logfire - -# Disable all SQLAlchemy and database logging BEFORE any imports -logging.basicConfig(level=logging.ERROR) -logging.getLogger("sqlalchemy").setLevel(logging.ERROR) -warnings.filterwarnings("ignore") - -# Add src to path -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - -from policyengine.tax_benefit_models.uk import uk_latest # noqa: E402 -from policyengine.tax_benefit_models.uk.datasets import ( # noqa: E402 - ensure_datasets as ensure_uk_datasets, -) -from policyengine.tax_benefit_models.us import us_latest # noqa: E402 -from policyengine.tax_benefit_models.us.datasets import ( # noqa: E402 - ensure_datasets as ensure_us_datasets, -) -from rich.console import Console # noqa: E402 -from rich.progress import Progress, SpinnerColumn, TextColumn # noqa: E402 -from sqlmodel import Session, create_engine, select # noqa: E402 - -from policyengine_api.config.settings import settings # noqa: E402 -from policyengine_api.models import ( # noqa: E402 - Dataset, - Parameter, - ParameterValue, - Policy, - TaxBenefitModel, - TaxBenefitModelVersion, -) -from policyengine_api.services.storage import ( # noqa: E402 - upload_dataset_for_seeding, -) - -# Configure logfire -if settings.logfire_token: - logfire.configure( - token=settings.logfire_token, - environment=settings.logfire_environment, - console=False, - ) - -console = Console() - - -def get_quiet_session(): - """Get database session with logging disabled.""" - engine = create_engine(settings.database_url, echo=False) - with Session(engine) as session: - yield session - - -def bulk_insert(session, table: str, columns: list[str], rows: list[dict]): - """Fast bulk insert using PostgreSQL COPY via StringIO.""" - if not rows: - return - - import io - - # Get raw psycopg2 connection - need to use the connection from session - # but not commit separately to avoid transaction issues - connection = session.connection() - raw_conn = connection.connection.dbapi_connection - cursor = raw_conn.cursor() - - # Build CSV-like data in memory - output = io.StringIO() - for row in rows: - values = [] - for col in columns: - val = row[col] - if val is None: - values.append("\\N") - elif isinstance(val, str): - # Escape special characters for COPY - val = ( - val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n") - ) - values.append(val) - else: - values.append(str(val)) - output.write("\t".join(values) + "\n") - - output.seek(0) - - # COPY is the fastest way to bulk load PostgreSQL - cursor.copy_from(output, table, columns=columns, null="\\N") - # Let SQLAlchemy handle the commit via session - session.commit() - - -def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVersion: - """Seed a tax-benefit model with its variables and parameters.""" - - with logfire.span( - "seed_model", - model=model_version.model.id, - version=model_version.version, - ): - # Create or get the model - console.print(f"[bold blue]Seeding {model_version.model.id}...") - - existing_model = session.exec( - select(TaxBenefitModel).where( - TaxBenefitModel.name == model_version.model.id - ) - ).first() - - if existing_model: - db_model = existing_model - console.print(f" Using existing model: {db_model.id}") - else: - db_model = TaxBenefitModel( - name=model_version.model.id, - description=model_version.model.description, - ) - session.add(db_model) - session.commit() - session.refresh(db_model) - console.print(f" Created model: {db_model.id}") - - # Create model version - existing_version = session.exec( - select(TaxBenefitModelVersion).where( - TaxBenefitModelVersion.model_id == db_model.id, - TaxBenefitModelVersion.version == model_version.version, - ) - ).first() - - if existing_version: - console.print( - f" Model version {model_version.version} already exists, skipping" - ) - return existing_version - - db_version = TaxBenefitModelVersion( - model_id=db_model.id, - version=model_version.version, - description=f"Version {model_version.version}", - ) - session.add(db_version) - session.commit() - session.refresh(db_version) - console.print(f" Created version: {db_version.version}") - - # Add variables - with logfire.span("add_variables", count=len(model_version.variables)): - var_rows = [] - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task( - f"Preparing {len(model_version.variables)} variables", - total=len(model_version.variables), - ) - for var in model_version.variables: - var_rows.append( - { - "id": uuid4(), - "name": var.name, - "entity": var.entity, - "description": var.description or "", - "data_type": var.data_type.__name__ - if hasattr(var.data_type, "__name__") - else str(var.data_type), - "possible_values": None, - "tax_benefit_model_version_id": db_version.id, - "created_at": datetime.now(timezone.utc), - } - ) - progress.advance(task) - - console.print(f" Inserting {len(var_rows)} variables...") - bulk_insert( - session, - "variables", - [ - "id", - "name", - "entity", - "description", - "data_type", - "possible_values", - "tax_benefit_model_version_id", - "created_at", - ], - var_rows, +import time +from dataclasses import dataclass + +from seed_utils import console, get_session + +# Import seed functions from subscripts +from seed_datasets import seed_uk_datasets, seed_us_datasets +from seed_models import seed_uk_model, seed_us_model +from seed_policies import seed_uk_policy, seed_us_policy +from seed_regions import seed_uk_regions, seed_us_regions + + +@dataclass +class SeedConfig: + """Configuration for database seeding.""" + + # Countries + seed_uk: bool = True + seed_us: bool = True + + # Models + skip_state_params: bool = False + + # Datasets + dataset_year: int | None = None # None = all years + + # Policies + seed_policies: bool = True + + # Regions + seed_regions: bool = True + skip_places: bool = False + skip_districts: bool = False + + +# Preset configurations +PRESETS: dict[str, SeedConfig] = { + "full": SeedConfig( + seed_uk=True, + seed_us=True, + skip_state_params=False, + dataset_year=None, + seed_policies=True, + seed_regions=True, + skip_places=False, + skip_districts=False, + ), + "lite": SeedConfig( + seed_uk=True, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=True, + seed_regions=True, + skip_places=True, + skip_districts=True, + ), + "minimal": SeedConfig( + seed_uk=True, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=False, + seed_regions=False, + ), + "uk-lite": SeedConfig( + seed_uk=True, + seed_us=False, + skip_state_params=True, + dataset_year=2026, + seed_policies=True, + seed_regions=True, + ), + "uk-minimal": SeedConfig( + seed_uk=True, + seed_us=False, + skip_state_params=True, + dataset_year=2026, + seed_policies=False, + seed_regions=False, + ), + "us-lite": SeedConfig( + seed_uk=False, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=True, + seed_regions=True, + skip_places=True, + skip_districts=True, + ), + "us-minimal": SeedConfig( + seed_uk=False, + seed_us=True, + skip_state_params=True, + dataset_year=2026, + seed_policies=False, + seed_regions=False, + ), +} + + +def run_seed(config: SeedConfig): + """Run database seeding with the given configuration.""" + start = time.time() + + with get_session() as session: + # Step 1: Seed models + console.print("[bold blue]Step 1: Seeding models...[/bold blue]\n") + + if config.seed_uk: + seed_uk_model(session, skip_state_params=config.skip_state_params) + + if config.seed_us: + seed_us_model(session, skip_state_params=config.skip_state_params) + + # Step 2: Seed datasets + console.print("[bold blue]Step 2: Seeding datasets...[/bold blue]\n") + + if config.seed_uk: + console.print("[bold]UK Datasets[/bold]") + uk_created, uk_skipped = seed_uk_datasets( + session, year=config.dataset_year ) - console.print( - f" [green]✓[/green] Added {len(model_version.variables)} variables" + f"[green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped\n" ) - # Add parameters (only user-facing ones: those with labels) - # Deduplicate by name - keep first occurrence - # In lite mode, exclude US state parameters (gov.states.*) - seen_names = set() - parameters_to_add = [] - skipped_state_params = 0 - for p in model_version.parameters: - if p.label is None or p.name in seen_names: - continue - # In lite mode, skip state-level parameters for faster seeding - if lite and p.name.startswith("gov.states."): - skipped_state_params += 1 - continue - parameters_to_add.append(p) - seen_names.add(p.name) - - filter_msg = f" Filtered to {len(parameters_to_add)} user-facing parameters" - filter_msg += f" (from {len(model_version.parameters)} total, deduplicated by name)" - if lite and skipped_state_params > 0: - filter_msg += f", skipped {skipped_state_params} state params (lite mode)" - console.print(filter_msg) - - with logfire.span("add_parameters", count=len(parameters_to_add)): - # Build list of parameter dicts for bulk insert - param_rows = [] - param_names = [] # Track (pe_id, name, generated_uuid) - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task( - f"Preparing {len(parameters_to_add)} parameters", - total=len(parameters_to_add), - ) - for param in parameters_to_add: - param_uuid = uuid4() - param_rows.append( - { - "id": param_uuid, - "name": param.name, - "label": param.label if hasattr(param, "label") else None, - "description": param.description or "", - "data_type": param.data_type.__name__ - if hasattr(param.data_type, "__name__") - else str(param.data_type), - "unit": param.unit, - "tax_benefit_model_version_id": db_version.id, - "created_at": datetime.now(timezone.utc), - } - ) - param_names.append((param.id, param.name, param_uuid)) - progress.advance(task) - - console.print(f" Inserting {len(param_rows)} parameters...") - bulk_insert( - session, - "parameters", - [ - "id", - "name", - "label", - "description", - "data_type", - "unit", - "tax_benefit_model_version_id", - "created_at", - ], - param_rows, + if config.seed_us: + console.print("[bold]US Datasets[/bold]") + us_created, us_skipped = seed_us_datasets( + session, year=config.dataset_year ) - - # Build param_id_map from pre-generated UUIDs - param_id_map = {pe_id: db_uuid for pe_id, name, db_uuid in param_names} - console.print( - f" [green]✓[/green] Added {len(parameters_to_add)} parameters" + f"[green]✓[/green] US: {us_created} created, {us_skipped} skipped\n" ) - # Add parameter values - # Filter to only include values for parameters we added - parameter_values_to_add = [ - pv - for pv in model_version.parameter_values - if pv.parameter.id in param_id_map - ] - console.print(f" Found {len(parameter_values_to_add)} parameter values to add") - - with logfire.span("add_parameter_values", count=len(parameter_values_to_add)): - pv_rows = [] - skipped = 0 - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task( - f"Preparing {len(parameter_values_to_add)} parameter values", - total=len(parameter_values_to_add), - ) - for pv in parameter_values_to_add: - # Handle Infinity values - skip them as they can't be stored in JSON - if isinstance(pv.value, float) and ( - math.isinf(pv.value) or math.isnan(pv.value) - ): - skipped += 1 - progress.advance(task) - continue - - # Source data has dates swapped (start > end), fix ordering - # Only swap if both dates are set, otherwise keep original - if pv.start_date and pv.end_date: - start = pv.end_date # Swap: source end is our start - end = pv.start_date # Swap: source start is our end - else: - start = pv.start_date - end = pv.end_date - pv_rows.append( - { - "id": uuid4(), - "parameter_id": param_id_map[pv.parameter.id], - "value_json": json.dumps(pv.value), - "start_date": start, - "end_date": end, - "policy_id": None, - "dynamic_id": None, - "created_at": datetime.now(timezone.utc), - } - ) - progress.advance(task) - - console.print(f" Inserting {len(pv_rows)} parameter values...") - bulk_insert( - session, - "parameter_values", - [ - "id", - "parameter_id", - "value_json", - "start_date", - "end_date", - "policy_id", - "dynamic_id", - "created_at", - ], - pv_rows, - ) + # Step 3: Seed policies + if config.seed_policies: + console.print("[bold blue]Step 3: Seeding policies...[/bold blue]\n") - console.print( - f" [green]✓[/green] Added {len(pv_rows)} parameter values" - + (f" (skipped {skipped} invalid)" if skipped else "") - ) - - return db_version - - -def seed_datasets(session, lite: bool = False, skip_uk_datasets: bool = False): - """Seed datasets and upload to S3.""" - with logfire.span("seed_datasets"): - mode_str = " (lite mode - 2026 only)" if lite else "" - console.print(f"[bold blue]Seeding datasets{mode_str}...") + if config.seed_uk: + seed_uk_policy(session) - # Get UK and US models - uk_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") - ).first() - us_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") - ).first() - - if not uk_model or not us_model: - console.print( - "[red]Error: UK or US model not found. Run seed_model first.[/red]" - ) - return - - data_folder = str(Path(__file__).parent.parent / "data") - - # UK datasets - uk_created = 0 - uk_skipped = 0 - - if skip_uk_datasets: - console.print(" [yellow]Skipping UK datasets (--skip-uk-datasets)[/yellow]") - else: - console.print(" Creating UK datasets...") - uk_datasets = ensure_uk_datasets(data_folder=data_folder) - - # In lite mode, only upload FRS 2026 - if lite: - uk_datasets = { - k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k - } - console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") - - with logfire.span("seed_uk_datasets", count=len(uk_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("UK datasets", total=len(uk_datasets)) - for _, pe_dataset in uk_datasets.items(): - progress.update(task, description=f"UK: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - uk_skipped += 1 - progress.advance(task) - continue - - # Upload to S3 - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=uk_model.id, - ) - session.add(db_dataset) - session.commit() - uk_created += 1 - progress.advance(task) + if config.seed_us: + seed_us_policy(session) - console.print( - f" [green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped" - ) + console.print() - # US datasets - console.print(" Creating US datasets...") - us_datasets = ensure_us_datasets(data_folder=data_folder) - - # In lite mode, only upload CPS 2026 - if lite: - us_datasets = { - k: v for k, v in us_datasets.items() if v.year == 2026 and "cps" in k - } - console.print(f" Lite mode: filtered to {len(us_datasets)} dataset(s)") - - us_created = 0 - us_skipped = 0 - - with logfire.span("seed_us_datasets", count=len(us_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("US datasets", total=len(us_datasets)) - for _, pe_dataset in us_datasets.items(): - progress.update(task, description=f"US: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - us_skipped += 1 - progress.advance(task) - continue - - # Upload to S3 - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=us_model.id, - ) - session.add(db_dataset) - session.commit() - us_created += 1 - progress.advance(task) - - console.print( - f" [green]✓[/green] US: {us_created} created, {us_skipped} skipped" - ) - console.print( - f"[green]✓[/green] Seeded {uk_created + us_created} datasets total\n" - ) - - -def seed_example_policies(session): - """Seed example policy reforms for UK and US.""" - with logfire.span("seed_example_policies"): - console.print("[bold blue]Seeding example policies...") - - # Get model versions - uk_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") - ).first() - us_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") - ).first() - - if not uk_model or not us_model: - console.print( - "[red]Error: UK or US model not found. Run seed_model first.[/red]" - ) - return - - uk_version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == uk_model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - - us_version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == us_model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - - # UK example policy: raise basic rate to 22p - uk_policy_name = "UK basic rate 22p" - existing_uk_policy = session.exec( - select(Policy).where(Policy.name == uk_policy_name) - ).first() - - if existing_uk_policy: - console.print(f" Policy '{uk_policy_name}' already exists, skipping") - else: - # Find the basic rate parameter - uk_basic_rate_param = session.exec( - select(Parameter).where( - Parameter.name == "gov.hmrc.income_tax.rates.uk[0].rate", - Parameter.tax_benefit_model_version_id == uk_version.id, - ) - ).first() + # Step 4: Seed regions + if config.seed_regions: + console.print("[bold blue]Step 4: Seeding regions...[/bold blue]\n") - if uk_basic_rate_param: - uk_policy = Policy( - name=uk_policy_name, - description="Raise the UK income tax basic rate from 20p to 22p", - ) - session.add(uk_policy) - session.commit() - session.refresh(uk_policy) - - # Add parameter value (22% = 0.22) - uk_param_value = ParameterValue( - parameter_id=uk_basic_rate_param.id, - value_json={"value": 0.22}, - start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), - end_date=None, - policy_id=uk_policy.id, + if config.seed_us: + console.print("[bold]US Regions[/bold]") + us_created, us_skipped = seed_us_regions( + session, + skip_places=config.skip_places, + skip_districts=config.skip_districts, ) - session.add(uk_param_value) - session.commit() - console.print(f" [green]✓[/green] Created UK policy: {uk_policy_name}") - else: console.print( - " [yellow]Warning: UK basic rate parameter not found[/yellow]" + f"[green]✓[/green] US: {us_created} created, {us_skipped} skipped\n" ) - # US example policy: raise first bracket rate to 12% - us_policy_name = "US 12% lowest bracket" - existing_us_policy = session.exec( - select(Policy).where(Policy.name == us_policy_name) - ).first() - - if existing_us_policy: - console.print(f" Policy '{us_policy_name}' already exists, skipping") - else: - # Find the first bracket rate parameter - us_first_bracket_param = session.exec( - select(Parameter).where( - Parameter.name == "gov.irs.income.bracket.rates.1", - Parameter.tax_benefit_model_version_id == us_version.id, - ) - ).first() - - if us_first_bracket_param: - us_policy = Policy( - name=us_policy_name, - description="Raise US federal income tax lowest bracket to 12%", - ) - session.add(us_policy) - session.commit() - session.refresh(us_policy) - - # Add parameter value (12% = 0.12) - us_param_value = ParameterValue( - parameter_id=us_first_bracket_param.id, - value_json={"value": 0.12}, - start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), - end_date=None, - policy_id=us_policy.id, - ) - session.add(us_param_value) - session.commit() - console.print(f" [green]✓[/green] Created US policy: {us_policy_name}") - else: + if config.seed_uk: + console.print("[bold]UK Regions[/bold]") + uk_created, uk_skipped = seed_uk_regions(session) console.print( - " [yellow]Warning: US first bracket parameter not found[/yellow]" + f"[green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped\n" ) - console.print("[green]✓[/green] Example policies seeded\n") + elapsed = time.time() - start + console.print(f"[bold green]✓ Database seeding complete![/bold green]") + console.print(f"[bold]Total time: {elapsed:.1f}s[/bold]") def main(): - """Main seed function.""" - parser = argparse.ArgumentParser(description="Seed PolicyEngine database") - parser.add_argument( - "--lite", - action="store_true", - help="Lite mode: skip US state parameters, only seed FRS 2026 and CPS 2026 datasets", - ) - parser.add_argument( - "--skip-uk-datasets", - action="store_true", - help="Skip UK datasets (useful when HuggingFace token is not available)", + parser = argparse.ArgumentParser( + description="Seed PolicyEngine database", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Presets: + full Everything (default) + lite Both countries, 2026 datasets only, skip state params, core regions + minimal Both countries, 2026 datasets only, skip state params, no policies/regions + uk-lite UK only, 2026 datasets, skip state params + uk-minimal UK only, 2026 datasets, skip state params, no policies/regions + us-lite US only, 2026 datasets, skip state params, core regions only + us-minimal US only, 2026 datasets, skip state params, no policies/regions +""", ) parser.add_argument( - "--skip-regions", - action="store_true", - help="Skip seeding regions", - ) - parser.add_argument( - "--skip-places", - action="store_true", - help="Skip US places (cities over 100K population) when seeding regions", - ) - parser.add_argument( - "--skip-districts", - action="store_true", - help="Skip US congressional districts when seeding regions", + "--preset", + choices=list(PRESETS.keys()), + default="full", + help="Seeding preset (default: full)", ) args = parser.parse_args() - with logfire.span("database_seeding"): - mode_str = " (lite mode)" if args.lite else "" - console.print(f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n") + config = PRESETS[args.preset] - with next(get_quiet_session()) as session: - # Seed UK model - uk_version = seed_model(uk_latest, session, lite=args.lite) - console.print(f"[green]✓[/green] UK model seeded: {uk_version.id}\n") + # Build description of what we're doing + countries = [] + if config.seed_uk: + countries.append("UK") + if config.seed_us: + countries.append("US") + country_str = " + ".join(countries) - # Seed US model - us_version = seed_model(us_latest, session, lite=args.lite) - console.print(f"[green]✓[/green] US model seeded: {us_version.id}\n") + year_str = f", {config.dataset_year} only" if config.dataset_year else "" + state_str = ", skip state params" if config.skip_state_params else "" - # Seed datasets - seed_datasets(session, lite=args.lite, skip_uk_datasets=args.skip_uk_datasets) - - # Seed example policies - seed_example_policies(session) - - # Seed regions - if not args.skip_regions: - from seed_regions import seed_us_regions, seed_uk_regions - - console.print("\n[bold]Seeding regions...[/bold]") - us_created, us_skipped = seed_us_regions( - session, - skip_places=args.skip_places, - skip_districts=args.skip_districts, - ) - console.print( - f"[green]✓[/green] US regions: {us_created} created, {us_skipped} skipped" - ) - - if not args.skip_uk_datasets: - uk_created, uk_skipped = seed_uk_regions(session) - console.print( - f"[green]✓[/green] UK regions: {uk_created} created, {uk_skipped} skipped" - ) - - console.print("\n[bold green]✓ Database seeding complete![/bold green]") + console.print( + f"[bold green]PolicyEngine database seeding[/bold green] " + f"[dim](preset: {args.preset})[/dim]\n" + ) + console.print(f" Countries: {country_str}") + console.print(f" Datasets: {'all years' if not config.dataset_year else config.dataset_year}") + if config.skip_state_params: + console.print(" State params: skipped") + console.print(f" Policies: {'yes' if config.seed_policies else 'no'}") + if config.seed_regions: + region_details = [] + if config.skip_places: + region_details.append("no places") + if config.skip_districts: + region_details.append("no districts") + region_str = f"yes ({', '.join(region_details)})" if region_details else "yes (all)" + console.print(f" Regions: {region_str}") + else: + console.print(" Regions: no") + console.print() + + run_seed(config) if __name__ == "__main__": diff --git a/scripts/seed_datasets.py b/scripts/seed_datasets.py new file mode 100644 index 0000000..8a13130 --- /dev/null +++ b/scripts/seed_datasets.py @@ -0,0 +1,227 @@ +"""Seed datasets and upload to S3. + +This script downloads datasets from policyengine.py, uploads them to S3, +and creates database records. + +Usage: + python scripts/seed_datasets.py # Seed UK and US datasets + python scripts/seed_datasets.py --us-only # Seed only US datasets + python scripts/seed_datasets.py --uk-only # Seed only UK datasets + python scripts/seed_datasets.py --year=2026 # Seed only 2026 datasets +""" + +import argparse +from pathlib import Path + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import Session, select + +from seed_utils import console, get_session + +# Import after seed_utils sets up path +from policyengine_api.models import Dataset, TaxBenefitModel # noqa: E402 +from policyengine_api.services.storage import upload_dataset_for_seeding # noqa: E402 + + +def seed_uk_datasets(session: Session, year: int | None = None) -> tuple[int, int]: + """Seed UK datasets. + + Args: + session: Database session + year: If specified, only seed datasets for this year + + Returns: + Tuple of (created_count, skipped_count) + """ + from policyengine.tax_benefit_models.uk.datasets import ( + ensure_datasets as ensure_uk_datasets, + ) + + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print("[red]Error: UK model not found. Run seed_models.py first.[/red]") + return 0, 0 + + data_folder = str(Path(__file__).parent.parent / "data") + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + + # Filter by year if specified + if year: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == year and "frs" in k + } + console.print(f" Filtered to {len(uk_datasets)} dataset(s) for year {year}") + + created = 0 + skipped = 0 + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for _, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + created += 1 + progress.advance(task) + + return created, skipped + + +def seed_us_datasets(session: Session, year: int | None = None) -> tuple[int, int]: + """Seed US datasets. + + Args: + session: Database session + year: If specified, only seed datasets for this year + + Returns: + Tuple of (created_count, skipped_count) + """ + from policyengine.tax_benefit_models.us.datasets import ( + ensure_datasets as ensure_us_datasets, + ) + + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed_models.py first.[/red]") + return 0, 0 + + data_folder = str(Path(__file__).parent.parent / "data") + us_datasets = ensure_us_datasets(data_folder=data_folder) + + # Filter by year if specified + if year: + us_datasets = { + k: v for k, v in us_datasets.items() if v.year == year and "cps" in k + } + console.print(f" Filtered to {len(us_datasets)} dataset(s) for year {year}") + + created = 0 + skipped = 0 + + with logfire.span("seed_us_datasets", count=len(us_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("US datasets", total=len(us_datasets)) + for _, pe_dataset in us_datasets.items(): + progress.update(task, description=f"US: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + created += 1 + progress.advance(task) + + return created, skipped + + +def main(): + parser = argparse.ArgumentParser(description="Seed datasets") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US datasets", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK datasets", + ) + parser.add_argument( + "--year", + type=int, + default=None, + help="Only seed datasets for this year (e.g., 2026)", + ) + args = parser.parse_args() + + year_str = f" (year {args.year})" if args.year else "" + console.print(f"[bold green]Seeding datasets{year_str}...[/bold green]\n") + + total_created = 0 + total_skipped = 0 + + with get_session() as session: + if not args.us_only: + console.print("[bold]UK Datasets[/bold]") + uk_created, uk_skipped = seed_uk_datasets(session, year=args.year) + total_created += uk_created + total_skipped += uk_skipped + console.print( + f"[green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped\n" + ) + + if not args.uk_only: + console.print("[bold]US Datasets[/bold]") + us_created, us_skipped = seed_us_datasets(session, year=args.year) + total_created += us_created + total_skipped += us_skipped + console.print( + f"[green]✓[/green] US: {us_created} created, {us_skipped} skipped\n" + ) + + console.print( + f"[bold green]✓ Dataset seeding complete! " + f"{total_created} created, {total_skipped} skipped[/bold green]" + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_common.py b/scripts/seed_models.py similarity index 68% rename from scripts/seed_common.py rename to scripts/seed_models.py index 49797cb..d970df9 100644 --- a/scripts/seed_common.py +++ b/scripts/seed_models.py @@ -1,106 +1,59 @@ -"""Shared utilities for seed scripts.""" +"""Seed tax-benefit models with variables and parameters. -import io +This script seeds TaxBenefitModel, TaxBenefitModelVersion, Variables, +Parameters, and ParameterValues from policyengine.py. + +Usage: + python scripts/seed_models.py # Seed UK and US models + python scripts/seed_models.py --us-only # Seed only US model + python scripts/seed_models.py --uk-only # Seed only UK model + python scripts/seed_models.py --skip-state-params # Skip US state parameters +""" + +import argparse import json -import logging import math -import sys -import warnings from datetime import datetime, timezone -from pathlib import Path from uuid import uuid4 import logfire -from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn -from sqlmodel import Session, create_engine - -# Disable all SQLAlchemy and database logging BEFORE any imports -logging.basicConfig(level=logging.ERROR) -logging.getLogger("sqlalchemy").setLevel(logging.ERROR) -warnings.filterwarnings("ignore") - -# Add src to path -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - -from policyengine_api.config.settings import settings # noqa: E402 - -# Configure logfire -if settings.logfire_token: - logfire.configure( - token=settings.logfire_token, - environment=settings.logfire_environment, - console=False, - ) - -console = Console() - - -def get_session(): - """Get database session with logging disabled.""" - engine = create_engine(settings.database_url, echo=False) - return Session(engine) - - -def bulk_insert(session, table: str, columns: list[str], rows: list[dict]): - """Fast bulk insert using PostgreSQL COPY via StringIO.""" - if not rows: - return - - # Get raw psycopg2 connection - connection = session.connection() - raw_conn = connection.connection.dbapi_connection - cursor = raw_conn.cursor() - - # Build CSV-like data in memory - output = io.StringIO() - for row in rows: - values = [] - for col in columns: - val = row[col] - if val is None: - values.append("\\N") - elif isinstance(val, str): - # Escape special characters for COPY - val = ( - val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n") - ) - values.append(val) - else: - values.append(str(val)) - output.write("\t".join(values) + "\n") +from sqlmodel import Session, select - output.seek(0) +from seed_utils import bulk_insert, console, get_session - # COPY is the fastest way to bulk load PostgreSQL - cursor.copy_from(output, table, columns=columns, null="\\N") - session.commit() +# Import after seed_utils sets up path +from policyengine_api.models import ( # noqa: E402 + Parameter, + ParameterValue, + TaxBenefitModel, + TaxBenefitModelVersion, +) -def seed_model(model_version, session, lite: bool = False): +def seed_model( + model_version, + session: Session, + skip_state_params: bool = False, +) -> TaxBenefitModelVersion: """Seed a tax-benefit model with its variables and parameters. Args: - model_version: The policyengine package model version + model_version: The policyengine.py model version object session: Database session - lite: If True, skip state-level parameters + skip_state_params: Skip US state-level parameters (gov.states.*) - Returns the TaxBenefitModelVersion that was created or found. + Returns: + The created or existing TaxBenefitModelVersion """ - from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, - ) - from sqlmodel import select - with logfire.span( "seed_model", model=model_version.model.id, version=model_version.version, ): - # Create or get the model console.print(f"[bold blue]Seeding {model_version.model.id}...") + # Create or get the model existing_model = session.exec( select(TaxBenefitModel).where( TaxBenefitModel.name == model_version.model.id @@ -157,10 +110,6 @@ def seed_model(model_version, session, lite: bool = False): total=len(model_version.variables), ) for var in model_version.variables: - # default_value is pre-serialized by policyengine.py: - # - Enum values are converted to their name (e.g., "SINGLE") - # - datetime.date values are converted to ISO format - # - Primitives (bool, int, float, str) are kept as-is var_rows.append( { "id": uuid4(), @@ -171,7 +120,6 @@ def seed_model(model_version, session, lite: bool = False): if hasattr(var.data_type, "__name__") else str(var.data_type), "possible_values": None, - "default_value": json.dumps(var.default_value), "tax_benefit_model_version_id": db_version.id, "created_at": datetime.now(timezone.utc), } @@ -189,7 +137,6 @@ def seed_model(model_version, session, lite: bool = False): "description", "data_type", "possible_values", - "default_value", "tax_benefit_model_version_id", "created_at", ], @@ -200,43 +147,30 @@ def seed_model(model_version, session, lite: bool = False): f" [green]✓[/green] Added {len(model_version.variables)} variables" ) - # Add parameters - deduplicate by name (keep first occurrence) - # - # WHY DEDUPLICATION IS NEEDED: - # The policyengine package can provide multiple parameter entries with the same - # name. This happens because parameters can have multiple bracket entries or - # state-specific variants that share the same base name. We keep only the first - # occurrence to avoid database unique constraint violations and reduce redundancy. - # - # NOTE: We do NOT filter by label. Parameters without labels (bracket params, - # breakdown params) are still valid and needed for policy analysis. - # - # In lite mode, exclude US state parameters (gov.states.*) + # Add parameters (only user-facing ones: those with labels) + # Deduplicate by name - keep first occurrence seen_names = set() parameters_to_add = [] - skipped_state_params = 0 - skipped_duplicate = 0 - + skipped_state_params_count = 0 for p in model_version.parameters: - if p.name in seen_names: - skipped_duplicate += 1 + if p.label is None or p.name in seen_names: continue - # In lite mode, skip state-level parameters for faster seeding - if lite and p.name.startswith("gov.states."): - skipped_state_params += 1 + # Skip state-level parameters if requested + if skip_state_params and p.name.startswith("gov.states."): + skipped_state_params_count += 1 continue parameters_to_add.append(p) seen_names.add(p.name) - console.print(f" Parameter filtering:") - console.print(f" - Total from source: {len(model_version.parameters)}") - console.print(f" - Skipped (duplicate name): {skipped_duplicate}") - if lite and skipped_state_params > 0: - console.print(f" - Skipped (state params, lite mode): {skipped_state_params}") - console.print(f" - To add: {len(parameters_to_add)}") + filter_msg = f" Filtered to {len(parameters_to_add)} user-facing parameters" + filter_msg += ( + f" (from {len(model_version.parameters)} total, deduplicated by name)" + ) + if skip_state_params and skipped_state_params_count > 0: + filter_msg += f", skipped {skipped_state_params_count} state params" + console.print(filter_msg) with logfire.span("add_parameters", count=len(parameters_to_add)): - # Build list of parameter dicts for bulk insert param_rows = [] param_names = [] # Track (pe_id, name, generated_uuid) @@ -293,7 +227,6 @@ def seed_model(model_version, session, lite: bool = False): ) # Add parameter values - # Filter to only include values for parameters we added parameter_values_to_add = [ pv for pv in model_version.parameter_values @@ -324,7 +257,6 @@ def seed_model(model_version, session, lite: bool = False): continue # Source data has dates swapped (start > end), fix ordering - # Only swap if both dates are set, otherwise keep original if pv.start_date and pv.end_date: start = pv.end_date # Swap: source end is our start end = pv.start_date # Swap: source start is our end @@ -368,3 +300,56 @@ def seed_model(model_version, session, lite: bool = False): ) return db_version + + +def seed_uk_model(session: Session, skip_state_params: bool = False): + """Seed UK model.""" + from policyengine.tax_benefit_models.uk import uk_latest + + version = seed_model(uk_latest, session, skip_state_params=skip_state_params) + console.print(f"[green]✓[/green] UK model seeded: {version.id}\n") + return version + + +def seed_us_model(session: Session, skip_state_params: bool = False): + """Seed US model.""" + from policyengine.tax_benefit_models.us import us_latest + + version = seed_model(us_latest, session, skip_state_params=skip_state_params) + console.print(f"[green]✓[/green] US model seeded: {version.id}\n") + return version + + +def main(): + parser = argparse.ArgumentParser(description="Seed tax-benefit models") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US model", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK model", + ) + parser.add_argument( + "--skip-state-params", + action="store_true", + help="Skip US state-level parameters (gov.states.*)", + ) + args = parser.parse_args() + + console.print("[bold green]Seeding tax-benefit models...[/bold green]\n") + + with get_session() as session: + if not args.us_only: + seed_uk_model(session, skip_state_params=args.skip_state_params) + + if not args.uk_only: + seed_us_model(session, skip_state_params=args.skip_state_params) + + console.print("[bold green]✓ Model seeding complete![/bold green]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_policies.py b/scripts/seed_policies.py index e57b964..c3212d6 100644 --- a/scripts/seed_policies.py +++ b/scripts/seed_policies.py @@ -1,142 +1,191 @@ -"""Seed example policy reforms for UK and US.""" +"""Seed example policy reforms. -import time +This script creates example policy reforms for UK and US models. + +Usage: + python scripts/seed_policies.py # Seed UK and US example policies + python scripts/seed_policies.py --us-only # Seed only US example policy + python scripts/seed_policies.py --uk-only # Seed only UK example policy +""" + +import argparse from datetime import datetime, timezone import logfire -from sqlmodel import select - -from seed_common import console, get_session +from sqlmodel import Session, select + +from seed_utils import console, get_session + +# Import after seed_utils sets up path +from policyengine_api.models import ( # noqa: E402 + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def seed_uk_policy(session: Session) -> bool: + """Seed UK example policy: raise basic rate to 22p. + + Returns: + True if created, False if skipped + """ + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print("[red]Error: UK model not found. Run seed_models.py first.[/red]") + return False + + uk_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == uk_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + if not uk_version: + console.print( + "[red]Error: UK model version not found. Run seed_models.py first.[/red]" + ) + return False + + policy_name = "UK basic rate 22p" + existing = session.exec(select(Policy).where(Policy.name == policy_name)).first() + + if existing: + console.print(f" Policy '{policy_name}' already exists, skipping") + return False + + # Find the basic rate parameter + uk_basic_rate_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.hmrc.income_tax.rates.uk[0].rate", + Parameter.tax_benefit_model_version_id == uk_version.id, + ) + ).first() + + if not uk_basic_rate_param: + console.print(" [yellow]Warning: UK basic rate parameter not found[/yellow]") + return False + + uk_policy = Policy( + name=policy_name, + description="Raise the UK income tax basic rate from 20p to 22p", + ) + session.add(uk_policy) + session.commit() + session.refresh(uk_policy) + + # Add parameter value (22% = 0.22) + uk_param_value = ParameterValue( + parameter_id=uk_basic_rate_param.id, + value_json={"value": 0.22}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=uk_policy.id, + ) + session.add(uk_param_value) + session.commit() + console.print(f" [green]✓[/green] Created UK policy: {policy_name}") + return True + + +def seed_us_policy(session: Session) -> bool: + """Seed US example policy: raise first bracket to 12%. + + Returns: + True if created, False if skipped + """ + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed_models.py first.[/red]") + return False + + us_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == us_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + if not us_version: + console.print( + "[red]Error: US model version not found. Run seed_models.py first.[/red]" + ) + return False + + policy_name = "US 12% lowest bracket" + existing = session.exec(select(Policy).where(Policy.name == policy_name)).first() + + if existing: + console.print(f" Policy '{policy_name}' already exists, skipping") + return False + + # Find the first bracket rate parameter + us_first_bracket_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.irs.income.bracket.rates.1", + Parameter.tax_benefit_model_version_id == us_version.id, + ) + ).first() + + if not us_first_bracket_param: + console.print( + " [yellow]Warning: US first bracket parameter not found[/yellow]" + ) + return False + + us_policy = Policy( + name=policy_name, + description="Raise US federal income tax lowest bracket to 12%", + ) + session.add(us_policy) + session.commit() + session.refresh(us_policy) + + # Add parameter value (12% = 0.12) + us_param_value = ParameterValue( + parameter_id=us_first_bracket_param.id, + value_json={"value": 0.12}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=us_policy.id, + ) + session.add(us_param_value) + session.commit() + console.print(f" [green]✓[/green] Created US policy: {policy_name}") + return True def main(): - from policyengine_api.models import ( - Parameter, - ParameterValue, - Policy, - TaxBenefitModel, - TaxBenefitModelVersion, + parser = argparse.ArgumentParser(description="Seed example policies") + parser.add_argument( + "--us-only", + action="store_true", + help="Only seed US example policy", + ) + parser.add_argument( + "--uk-only", + action="store_true", + help="Only seed UK example policy", ) + args = parser.parse_args() console.print("[bold green]Seeding example policies...[/bold green]\n") - start = time.time() with get_session() as session: - with logfire.span("seed_example_policies"): - # Get model versions - uk_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") - ).first() - us_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") - ).first() - - if not uk_model or not us_model: - console.print( - "[red]Error: UK or US model not found. Run seed_*_model.py first.[/red]" - ) - return - - uk_version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == uk_model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - - us_version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == us_model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - - # UK example policy: raise basic rate to 22p - uk_policy_name = "UK basic rate 22p" - existing_uk_policy = session.exec( - select(Policy).where(Policy.name == uk_policy_name) - ).first() - - if existing_uk_policy: - console.print(f" Policy '{uk_policy_name}' already exists, skipping") - else: - # Find the basic rate parameter - uk_basic_rate_param = session.exec( - select(Parameter).where( - Parameter.name == "gov.hmrc.income_tax.rates.uk[0].rate", - Parameter.tax_benefit_model_version_id == uk_version.id, - ) - ).first() - - if uk_basic_rate_param: - uk_policy = Policy( - name=uk_policy_name, - description="Raise the UK income tax basic rate from 20p to 22p", - ) - session.add(uk_policy) - session.commit() - session.refresh(uk_policy) - - # Add parameter value (22% = 0.22) - uk_param_value = ParameterValue( - parameter_id=uk_basic_rate_param.id, - value_json={"value": 0.22}, - start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), - end_date=None, - policy_id=uk_policy.id, - ) - session.add(uk_param_value) - session.commit() - console.print(f" [green]✓[/green] Created UK policy: {uk_policy_name}") - else: - console.print( - " [yellow]Warning: UK basic rate parameter not found[/yellow]" - ) - - # US example policy: raise first bracket rate to 12% - us_policy_name = "US 12% lowest bracket" - existing_us_policy = session.exec( - select(Policy).where(Policy.name == us_policy_name) - ).first() - - if existing_us_policy: - console.print(f" Policy '{us_policy_name}' already exists, skipping") - else: - # Find the first bracket rate parameter - us_first_bracket_param = session.exec( - select(Parameter).where( - Parameter.name == "gov.irs.income.bracket.rates.1", - Parameter.tax_benefit_model_version_id == us_version.id, - ) - ).first() - - if us_first_bracket_param: - us_policy = Policy( - name=us_policy_name, - description="Raise US federal income tax lowest bracket to 12%", - ) - session.add(us_policy) - session.commit() - session.refresh(us_policy) - - # Add parameter value (12% = 0.12) - us_param_value = ParameterValue( - parameter_id=us_first_bracket_param.id, - value_json={"value": 0.12}, - start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), - end_date=None, - policy_id=us_policy.id, - ) - session.add(us_param_value) - session.commit() - console.print(f" [green]✓[/green] Created US policy: {us_policy_name}") - else: - console.print( - " [yellow]Warning: US first bracket parameter not found[/yellow]" - ) - - console.print("[green]✓[/green] Example policies seeded") - - elapsed = time.time() - start - console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + if not args.us_only: + seed_uk_policy(session) + + if not args.uk_only: + seed_us_policy(session) + + console.print("\n[bold green]✓ Policy seeding complete![/bold green]") if __name__ == "__main__": diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py index c8cc9d8..060fb2f 100644 --- a/scripts/seed_regions.py +++ b/scripts/seed_regions.py @@ -16,27 +16,15 @@ """ import argparse -import sys import time -from pathlib import Path -# Add src to path -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - -from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn -from sqlmodel import Session, create_engine, select - -from policyengine_api.config.settings import settings -from policyengine_api.models import Dataset, Region, TaxBenefitModel - -console = Console() +from sqlmodel import Session, select +from seed_utils import console, get_session -def get_session() -> Session: - """Get database session.""" - engine = create_engine(settings.database_url) - return Session(engine) +# Import after seed_utils sets up path +from policyengine_api.models import Dataset, Region, TaxBenefitModel # noqa: E402 def seed_us_regions( diff --git a/scripts/seed_uk_datasets.py b/scripts/seed_uk_datasets.py deleted file mode 100644 index 1754454..0000000 --- a/scripts/seed_uk_datasets.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Seed UK datasets (FRS) and upload to S3. - -NOTE: Requires HUGGING_FACE_TOKEN environment variable to be set, -as UK FRS datasets are hosted on a private HuggingFace repository. -""" - -import argparse -import time -from pathlib import Path - -import logfire -from rich.progress import Progress, SpinnerColumn, TextColumn -from sqlmodel import select - -from seed_common import console, get_session - - -def main(): - parser = argparse.ArgumentParser(description="Seed UK datasets") - parser.add_argument( - "--lite", - action="store_true", - help="Lite mode: only seed FRS 2026", - ) - args = parser.parse_args() - - # Import here to avoid slow import at module level - from policyengine.tax_benefit_models.uk.datasets import ( - ensure_datasets as ensure_uk_datasets, - ) - - from policyengine_api.models import Dataset, TaxBenefitModel - from policyengine_api.services.storage import upload_dataset_for_seeding - - console.print("[bold green]Seeding UK datasets...[/bold green]\n") - console.print("[yellow]Note: Requires HUGGING_FACE_TOKEN environment variable[/yellow]\n") - - start = time.time() - with get_session() as session: - # Get UK model - uk_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") - ).first() - - if not uk_model: - console.print("[red]Error: UK model not found. Run seed_uk_model.py first.[/red]") - return - - data_folder = str(Path(__file__).parent.parent / "data") - console.print(f" Data folder: {data_folder}") - - # Get datasets - console.print(" Loading UK datasets from policyengine package...") - ds_start = time.time() - uk_datasets = ensure_uk_datasets(data_folder=data_folder) - console.print(f" Loaded {len(uk_datasets)} datasets in {time.time() - ds_start:.1f}s") - - # In lite mode, only upload FRS 2026 - if args.lite: - uk_datasets = { - k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k - } - console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") - - created = 0 - skipped = 0 - - with logfire.span("seed_uk_datasets", count=len(uk_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("UK datasets", total=len(uk_datasets)) - for name, pe_dataset in uk_datasets.items(): - progress.update(task, description=f"UK: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - skipped += 1 - progress.advance(task) - continue - - # Upload to S3 - upload_start = time.time() - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=uk_model.id, - ) - session.add(db_dataset) - session.commit() - created += 1 - progress.advance(task) - - console.print(f"[green]✓[/green] UK datasets: {created} created, {skipped} skipped") - - elapsed = time.time() - start - console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") - - -if __name__ == "__main__": - main() diff --git a/scripts/seed_uk_model.py b/scripts/seed_uk_model.py deleted file mode 100644 index 07543bf..0000000 --- a/scripts/seed_uk_model.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Seed UK model (variables, parameters, parameter values).""" - -import argparse -import time - -from seed_common import console, get_session, seed_model - - -def main(): - parser = argparse.ArgumentParser(description="Seed UK model") - parser.add_argument( - "--lite", - action="store_true", - help="Lite mode: skip state parameters", - ) - args = parser.parse_args() - - # Import here to avoid slow import at module level - from policyengine.tax_benefit_models.uk import uk_latest - - console.print("[bold green]Seeding UK model...[/bold green]\n") - - start = time.time() - with get_session() as session: - version = seed_model(uk_latest, session, lite=args.lite) - console.print(f"[green]✓[/green] UK model seeded: {version.id}") - - elapsed = time.time() - start - console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") - - -if __name__ == "__main__": - main() diff --git a/scripts/seed_us_datasets.py b/scripts/seed_us_datasets.py deleted file mode 100644 index abf1995..0000000 --- a/scripts/seed_us_datasets.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Seed US datasets (CPS) and upload to S3.""" - -import argparse -import time -from pathlib import Path - -import logfire -from rich.progress import Progress, SpinnerColumn, TextColumn -from sqlmodel import select - -from seed_common import console, get_session - - -def main(): - parser = argparse.ArgumentParser(description="Seed US datasets") - parser.add_argument( - "--lite", - action="store_true", - help="Lite mode: only seed CPS 2026", - ) - args = parser.parse_args() - - # Import here to avoid slow import at module level - from policyengine.tax_benefit_models.us.datasets import ( - ensure_datasets as ensure_us_datasets, - ) - - from policyengine_api.models import Dataset, TaxBenefitModel - from policyengine_api.services.storage import upload_dataset_for_seeding - - console.print("[bold green]Seeding US datasets...[/bold green]\n") - - start = time.time() - with get_session() as session: - # Get US model - us_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") - ).first() - - if not us_model: - console.print("[red]Error: US model not found. Run seed_us_model.py first.[/red]") - return - - data_folder = str(Path(__file__).parent.parent / "data") - console.print(f" Data folder: {data_folder}") - - # Get datasets - console.print(" Loading US datasets from policyengine package...") - ds_start = time.time() - us_datasets = ensure_us_datasets(data_folder=data_folder) - console.print(f" Loaded {len(us_datasets)} datasets in {time.time() - ds_start:.1f}s") - - # In lite mode, only upload CPS 2026 - if args.lite: - us_datasets = { - k: v for k, v in us_datasets.items() if v.year == 2026 and "cps" in k - } - console.print(f" Lite mode: filtered to {len(us_datasets)} dataset(s)") - - created = 0 - skipped = 0 - - with logfire.span("seed_us_datasets", count=len(us_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("US datasets", total=len(us_datasets)) - for name, pe_dataset in us_datasets.items(): - progress.update(task, description=f"US: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - skipped += 1 - progress.advance(task) - continue - - # Upload to S3 - upload_start = time.time() - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=us_model.id, - ) - session.add(db_dataset) - session.commit() - created += 1 - progress.advance(task) - - console.print(f"[green]✓[/green] US datasets: {created} created, {skipped} skipped") - - elapsed = time.time() - start - console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") - - -if __name__ == "__main__": - main() diff --git a/scripts/seed_us_model.py b/scripts/seed_us_model.py deleted file mode 100644 index ce8a829..0000000 --- a/scripts/seed_us_model.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Seed US model (variables, parameters, parameter values).""" - -import argparse -import time - -from seed_common import console, get_session, seed_model - - -def main(): - parser = argparse.ArgumentParser(description="Seed US model") - parser.add_argument( - "--lite", - action="store_true", - help="Lite mode: skip state parameters", - ) - args = parser.parse_args() - - # Import here to avoid slow import at module level - from policyengine.tax_benefit_models.us import us_latest - - console.print("[bold green]Seeding US model...[/bold green]\n") - - start = time.time() - with get_session() as session: - version = seed_model(us_latest, session, lite=args.lite) - console.print(f"[green]✓[/green] US model seeded: {version.id}") - - elapsed = time.time() - start - console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") - - -if __name__ == "__main__": - main() diff --git a/scripts/seed_utils.py b/scripts/seed_utils.py new file mode 100644 index 0000000..624379f --- /dev/null +++ b/scripts/seed_utils.py @@ -0,0 +1,72 @@ +"""Shared utilities for seed scripts.""" + +import io +import logging +import sys +import warnings +from pathlib import Path + +import logfire +from rich.console import Console +from sqlmodel import Session, create_engine + +# Disable all SQLAlchemy and database logging +logging.basicConfig(level=logging.ERROR) +logging.getLogger("sqlalchemy").setLevel(logging.ERROR) +warnings.filterwarnings("ignore") + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from policyengine_api.config.settings import settings # noqa: E402 + +# Configure logfire +if settings.logfire_token: + logfire.configure( + token=settings.logfire_token, + environment=settings.logfire_environment, + console=False, + ) + +console = Console() + + +def get_session() -> Session: + """Get database session with logging disabled.""" + engine = create_engine(settings.database_url, echo=False) + return Session(engine) + + +def bulk_insert(session: Session, table: str, columns: list[str], rows: list[dict]): + """Fast bulk insert using PostgreSQL COPY via StringIO.""" + if not rows: + return + + # Get raw psycopg2 connection + connection = session.connection() + raw_conn = connection.connection.dbapi_connection + cursor = raw_conn.cursor() + + # Build CSV-like data in memory + output = io.StringIO() + for row in rows: + values = [] + for col in columns: + val = row[col] + if val is None: + values.append("\\N") + elif isinstance(val, str): + # Escape special characters for COPY + val = ( + val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n") + ) + values.append(val) + else: + values.append(str(val)) + output.write("\t".join(values) + "\n") + + output.seek(0) + + # COPY is the fastest way to bulk load PostgreSQL + cursor.copy_from(output, table, columns=columns, null="\\N") + session.commit() From 048b7e274092a135c2d83264e3a03a29cceb8822 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 17 Feb 2026 00:18:44 +0100 Subject: [PATCH 08/10] fix: Update region tests to pass simulation_type argument Tests were written before _get_or_create_simulation and _get_deterministic_simulation_id gained the simulation_type parameter. Add SimulationType.ECONOMY and use keyword args for dataset_id/filter params to match the current function signatures. Co-Authored-By: Claude Opus 4.6 --- ...isting_simulation__then_reuses_existing.py | 9 ++- ...ith_filter__then_filter_params_included.py | 3 + ...iven_same_params__then_deterministic_id.py | 65 ++++++++++++++----- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/tests/test__given_existing_simulation__then_reuses_existing.py b/tests/test__given_existing_simulation__then_reuses_existing.py index 77731f1..09f4df0 100644 --- a/tests/test__given_existing_simulation__then_reuses_existing.py +++ b/tests/test__given_existing_simulation__then_reuses_existing.py @@ -8,7 +8,7 @@ from sqlmodel import Session from policyengine_api.api.analysis import _get_or_create_simulation -from policyengine_api.models import SimulationStatus +from policyengine_api.models import SimulationStatus, SimulationType from test_fixtures.fixtures_regions import ( create_dataset, create_simulation, @@ -29,6 +29,7 @@ def test_given_existing_simulation_with_filter_then_reuses(self, session: Sessio # Create initial simulation with filter params first_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, @@ -40,6 +41,7 @@ def test_given_existing_simulation_with_filter_then_reuses(self, session: Sessio # When - request same simulation again second_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, @@ -61,6 +63,7 @@ def test_given_different_filter_then_creates_new_simulation(self, session: Sessi # Create simulation for England england_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, @@ -72,6 +75,7 @@ def test_given_different_filter_then_creates_new_simulation(self, session: Sessi # When - request simulation for Scotland scotland_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, @@ -97,6 +101,7 @@ def test_given_no_filter_vs_filter_then_creates_separate_simulations( # Create national (no filter) simulation national_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, @@ -108,6 +113,7 @@ def test_given_no_filter_vs_filter_then_creates_separate_simulations( # When - request filtered simulation filtered_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, @@ -131,6 +137,7 @@ def test_given_new_simulation_then_status_is_pending(self, session: Session): # When simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, diff --git a/tests/test__given_region_with_filter__then_filter_params_included.py b/tests/test__given_region_with_filter__then_filter_params_included.py index c84a372..bc1d862 100644 --- a/tests/test__given_region_with_filter__then_filter_params_included.py +++ b/tests/test__given_region_with_filter__then_filter_params_included.py @@ -13,6 +13,7 @@ _get_or_create_simulation, _resolve_dataset_and_region, ) +from policyengine_api.models import SimulationType from test_fixtures.fixtures_regions import ( create_dataset, create_region, @@ -134,6 +135,7 @@ def test_given_filter_params_then_simulation_has_filter_fields( # When simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, @@ -158,6 +160,7 @@ def test_given_no_filter_params_then_simulation_has_null_filter_fields( # When simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, dataset_id=dataset.id, model_version_id=model_version.id, policy_id=None, diff --git a/tests/test__given_same_params__then_deterministic_id.py b/tests/test__given_same_params__then_deterministic_id.py index d393f53..cbac11c 100644 --- a/tests/test__given_same_params__then_deterministic_id.py +++ b/tests/test__given_same_params__then_deterministic_id.py @@ -10,6 +10,7 @@ import pytest from policyengine_api.api.analysis import _get_deterministic_simulation_id +from policyengine_api.models import SimulationType class TestDeterministicSimulationId: @@ -27,20 +28,22 @@ def test_given_same_params_then_same_id_returned(self): # When id1 = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, - filter_field, - filter_value, + dataset_id=dataset_id, + filter_field=filter_field, + filter_value=filter_value, ) id2 = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, - filter_field, - filter_value, + dataset_id=dataset_id, + filter_field=filter_field, + filter_value=filter_value, ) # Then @@ -56,18 +59,20 @@ def test_given_different_filter_field_then_different_id(self): # When id1 = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, + dataset_id=dataset_id, filter_field="country", filter_value="ENGLAND", ) id2 = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, + dataset_id=dataset_id, filter_field="state_code", filter_value="ENGLAND", ) @@ -85,18 +90,20 @@ def test_given_different_filter_value_then_different_id(self): # When id1 = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, + dataset_id=dataset_id, filter_field="country", filter_value="ENGLAND", ) id2 = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, + dataset_id=dataset_id, filter_field="country", filter_value="SCOTLAND", ) @@ -114,18 +121,20 @@ def test_given_filter_none_vs_filter_set_then_different_id(self): # When id_no_filter = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, + dataset_id=dataset_id, filter_field=None, filter_value=None, ) id_with_filter = _get_deterministic_simulation_id( - dataset_id, + SimulationType.ECONOMY, model_version_id, policy_id, dynamic_id, + dataset_id=dataset_id, filter_field="country", filter_value="ENGLAND", ) @@ -144,10 +153,22 @@ def test_given_different_dataset_then_different_id(self): # When id1 = _get_deterministic_simulation_id( - uuid4(), model_version_id, policy_id, dynamic_id, filter_field, filter_value + SimulationType.ECONOMY, + model_version_id, + policy_id, + dynamic_id, + dataset_id=uuid4(), + filter_field=filter_field, + filter_value=filter_value, ) id2 = _get_deterministic_simulation_id( - uuid4(), model_version_id, policy_id, dynamic_id, filter_field, filter_value + SimulationType.ECONOMY, + model_version_id, + policy_id, + dynamic_id, + dataset_id=uuid4(), + filter_field=filter_field, + filter_value=filter_value, ) # Then @@ -161,10 +182,22 @@ def test_given_null_optional_params_then_consistent_id(self): # When id1 = _get_deterministic_simulation_id( - dataset_id, model_version_id, None, None, None, None + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, ) id2 = _get_deterministic_simulation_id( - dataset_id, model_version_id, None, None, None, None + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, ) # Then From b42335f77e1af5c395a829447aa2af3c8f21537e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 17 Feb 2026 01:00:38 +0100 Subject: [PATCH 09/10] refactor: Consolidate region test files into test_analysis.py Merge 6 separate test__given_* files into test_analysis.py organized by function tested: TestResolveDatasetAndRegion, TestGetDeterministicSimulationId, TestGetOrCreateSimulation. Fix pre-existing test_missing_dataset_id assertion (400 not 422). Move @pytest.mark.integration from file-level to class-level. Co-Authored-By: Claude Opus 4.6 --- ...__given_dataset_id__then_region_is_none.py | 94 --- ...isting_simulation__then_reuses_existing.py | 151 ---- ...t__given_invalid_region__then_404_error.py | 107 --- ...ith_filter__then_filter_params_included.py | 173 ----- ...without_filter__then_filter_params_none.py | 110 --- ...iven_same_params__then_deterministic_id.py | 204 ------ tests/test_analysis.py | 679 +++++++++++++++++- 7 files changed, 661 insertions(+), 857 deletions(-) delete mode 100644 tests/test__given_dataset_id__then_region_is_none.py delete mode 100644 tests/test__given_existing_simulation__then_reuses_existing.py delete mode 100644 tests/test__given_invalid_region__then_404_error.py delete mode 100644 tests/test__given_region_with_filter__then_filter_params_included.py delete mode 100644 tests/test__given_region_without_filter__then_filter_params_none.py delete mode 100644 tests/test__given_same_params__then_deterministic_id.py diff --git a/tests/test__given_dataset_id__then_region_is_none.py b/tests/test__given_dataset_id__then_region_is_none.py deleted file mode 100644 index ee3c1d5..0000000 --- a/tests/test__given_dataset_id__then_region_is_none.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Tests for dataset resolution when dataset_id is provided directly. - -When a dataset_id is provided instead of a region code, -the resolved region should be None. -""" - -import pytest -from sqlmodel import Session - -from policyengine_api.api.analysis import ( - EconomicImpactRequest, - _resolve_dataset_and_region, -) -from test_fixtures.fixtures_regions import ( - create_dataset, - create_tax_benefit_model, -) - - -class TestResolveDatasetWithDatasetId: - """Tests for _resolve_dataset_and_region when dataset_id is provided.""" - - def test_given_dataset_id_then_region_is_none(self, session: Session): - """Given a dataset_id, then region is None in the response.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset = create_dataset(session, model, name="uk_enhanced_frs") - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - dataset_id=dataset.id, - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_region is None - - def test_given_dataset_id_then_dataset_is_returned(self, session: Session): - """Given a dataset_id, then the correct dataset is returned.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset = create_dataset(session, model, name="uk_enhanced_frs") - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - dataset_id=dataset.id, - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_dataset.id == dataset.id - assert resolved_dataset.name == "uk_enhanced_frs" - - def test_given_dataset_id_and_region_then_region_takes_precedence( - self, session: Session - ): - """Given both dataset_id and region, then region takes precedence.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset1 = create_dataset(session, model, name="dataset_from_id") - dataset2 = create_dataset(session, model, name="dataset_from_region") - from test_fixtures.fixtures_regions import create_region - - region = create_region( - session, - model=model, - dataset=dataset2, - code="uk", - label="United Kingdom", - region_type="national", - requires_filter=False, - ) - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - dataset_id=dataset1.id, - region="uk", - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - # Region code takes precedence, so we get dataset2 - assert resolved_dataset.id == dataset2.id - assert resolved_region is not None - assert resolved_region.code == "uk" diff --git a/tests/test__given_existing_simulation__then_reuses_existing.py b/tests/test__given_existing_simulation__then_reuses_existing.py deleted file mode 100644 index 09f4df0..0000000 --- a/tests/test__given_existing_simulation__then_reuses_existing.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Tests for simulation reuse with filter parameters. - -When a simulation with the same parameters already exists, -it should be reused instead of creating a new one. -""" - -import pytest -from sqlmodel import Session - -from policyengine_api.api.analysis import _get_or_create_simulation -from policyengine_api.models import SimulationStatus, SimulationType -from test_fixtures.fixtures_regions import ( - create_dataset, - create_simulation, - create_tax_benefit_model, - create_tax_benefit_model_version, -) - - -class TestSimulationReuse: - """Tests for simulation reuse behavior.""" - - def test_given_existing_simulation_with_filter_then_reuses(self, session: Session): - """Given an existing simulation with filter params, then it is reused.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - model_version = create_tax_benefit_model_version(session, model) - dataset = create_dataset(session, model, name="uk_enhanced_frs") - - # Create initial simulation with filter params - first_sim = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field="country", - filter_value="ENGLAND", - ) - - # When - request same simulation again - second_sim = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field="country", - filter_value="ENGLAND", - ) - - # Then - assert first_sim.id == second_sim.id - - def test_given_different_filter_then_creates_new_simulation(self, session: Session): - """Given different filter params, then a new simulation is created.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - model_version = create_tax_benefit_model_version(session, model) - dataset = create_dataset(session, model, name="uk_enhanced_frs") - - # Create simulation for England - england_sim = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field="country", - filter_value="ENGLAND", - ) - - # When - request simulation for Scotland - scotland_sim = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field="country", - filter_value="SCOTLAND", - ) - - # Then - assert england_sim.id != scotland_sim.id - assert england_sim.filter_value == "ENGLAND" - assert scotland_sim.filter_value == "SCOTLAND" - - def test_given_no_filter_vs_filter_then_creates_separate_simulations( - self, session: Session - ): - """Given national vs filtered, then separate simulations are created.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - model_version = create_tax_benefit_model_version(session, model) - dataset = create_dataset(session, model, name="uk_enhanced_frs") - - # Create national (no filter) simulation - national_sim = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field=None, - filter_value=None, - ) - - # When - request filtered simulation - filtered_sim = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field="country", - filter_value="ENGLAND", - ) - - # Then - assert national_sim.id != filtered_sim.id - assert national_sim.filter_field is None - assert filtered_sim.filter_field == "country" - - def test_given_new_simulation_then_status_is_pending(self, session: Session): - """Given a new simulation request, then status is PENDING.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - model_version = create_tax_benefit_model_version(session, model) - dataset = create_dataset(session, model, name="uk_enhanced_frs") - - # When - simulation = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field="country", - filter_value="ENGLAND", - ) - - # Then - assert simulation.status == SimulationStatus.PENDING diff --git a/tests/test__given_invalid_region__then_404_error.py b/tests/test__given_invalid_region__then_404_error.py deleted file mode 100644 index 562a5c9..0000000 --- a/tests/test__given_invalid_region__then_404_error.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Tests for region resolution error cases. - -When an invalid region code is provided or required parameters are missing, -appropriate HTTP errors should be raised. -""" - -import pytest -from fastapi import HTTPException -from sqlmodel import Session - -from policyengine_api.api.analysis import ( - EconomicImpactRequest, - _resolve_dataset_and_region, -) -from test_fixtures.fixtures_regions import ( - create_dataset, - create_region, - create_tax_benefit_model, -) - - -class TestInvalidRegionCode: - """Tests for invalid region code handling.""" - - def test_given_nonexistent_region_code_then_raises_404(self, session: Session): - """Given a region code that doesn't exist, then raises 404.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset = create_dataset(session, model, name="uk_enhanced_frs") - # Note: No region is created for this code - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - region="nonexistent/region", - ) - - # When/Then - with pytest.raises(HTTPException) as exc_info: - _resolve_dataset_and_region(request, session) - - assert exc_info.value.status_code == 404 - assert "not found" in exc_info.value.detail.lower() - - def test_given_region_for_wrong_model_then_raises_404(self, session: Session): - """Given a region code for wrong model, then raises 404.""" - # Given - uk_model = create_tax_benefit_model(session, name="policyengine-uk") - uk_dataset = create_dataset(session, uk_model, name="uk_enhanced_frs") - create_region( - session, - model=uk_model, - dataset=uk_dataset, - code="uk", - label="United Kingdom", - region_type="national", - ) - # Request uses US model but UK region code - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", - region="uk", - ) - - # When/Then - with pytest.raises(HTTPException) as exc_info: - _resolve_dataset_and_region(request, session) - - assert exc_info.value.status_code == 404 - - -class TestMissingRequiredParams: - """Tests for missing required parameters.""" - - def test_given_neither_dataset_nor_region_then_raises_400(self, session: Session): - """Given neither dataset_id nor region, then raises 400.""" - # Given - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - # Neither dataset_id nor region provided - ) - - # When/Then - with pytest.raises(HTTPException) as exc_info: - _resolve_dataset_and_region(request, session) - - assert exc_info.value.status_code == 400 - assert "either dataset_id or region" in exc_info.value.detail.lower() - - -class TestNonexistentDataset: - """Tests for nonexistent dataset handling.""" - - def test_given_nonexistent_dataset_id_then_raises_404(self, session: Session): - """Given a dataset_id that doesn't exist, then raises 404.""" - # Given - from uuid import uuid4 - - nonexistent_id = uuid4() - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - dataset_id=nonexistent_id, - ) - - # When/Then - with pytest.raises(HTTPException) as exc_info: - _resolve_dataset_and_region(request, session) - - assert exc_info.value.status_code == 404 - assert "not found" in exc_info.value.detail.lower() diff --git a/tests/test__given_region_with_filter__then_filter_params_included.py b/tests/test__given_region_with_filter__then_filter_params_included.py deleted file mode 100644 index bc1d862..0000000 --- a/tests/test__given_region_with_filter__then_filter_params_included.py +++ /dev/null @@ -1,173 +0,0 @@ -"""Tests for region resolution with filter parameters. - -When a region requires filtering (e.g., England from UK dataset, -California from US dataset), the filter_field and filter_value -should be extracted and passed through to simulations. -""" - -import pytest -from sqlmodel import Session - -from policyengine_api.api.analysis import ( - EconomicImpactRequest, - _get_or_create_simulation, - _resolve_dataset_and_region, -) -from policyengine_api.models import SimulationType -from test_fixtures.fixtures_regions import ( - create_dataset, - create_region, - create_tax_benefit_model, - create_tax_benefit_model_version, -) - - -class TestResolveDatasetAndRegionWithFilter: - """Tests for _resolve_dataset_and_region when region requires filtering.""" - - def test_given_region_requires_filter_then_returns_filter_field( - self, session: Session - ): - """Given a region that requires filtering, then filter_field is populated.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset = create_dataset(session, model, name="uk_enhanced_frs") - region = create_region( - session, - model=model, - dataset=dataset, - code="country/england", - label="England", - region_type="country", - requires_filter=True, - filter_field="country", - filter_value="ENGLAND", - ) - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - region="country/england", - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_region is not None - assert resolved_region.filter_field == "country" - assert resolved_region.filter_value == "ENGLAND" - assert resolved_region.requires_filter is True - - def test_given_us_state_region_then_returns_state_filter(self, session: Session): - """Given a US state region, then returns state code filter.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-us") - dataset = create_dataset(session, model, name="us_cps") - region = create_region( - session, - model=model, - dataset=dataset, - code="state/ca", - label="California", - region_type="state", - requires_filter=True, - filter_field="state_code", - filter_value="CA", - ) - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", - region="state/ca", - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_region is not None - assert resolved_region.filter_field == "state_code" - assert resolved_region.filter_value == "CA" - - def test_given_region_with_filter_then_dataset_is_resolved(self, session: Session): - """Given a region code, then the associated dataset is returned.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset = create_dataset(session, model, name="uk_enhanced_frs") - region = create_region( - session, - model=model, - dataset=dataset, - code="country/england", - label="England", - region_type="country", - requires_filter=True, - filter_field="country", - filter_value="ENGLAND", - ) - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - region="country/england", - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_dataset.id == dataset.id - assert resolved_dataset.name == "uk_enhanced_frs" - - -class TestSimulationCreationWithFilter: - """Tests for creating simulations with filter parameters.""" - - def test_given_filter_params_then_simulation_has_filter_fields( - self, session: Session - ): - """Given filter parameters, then created simulation has filter fields populated.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - model_version = create_tax_benefit_model_version(session, model) - dataset = create_dataset(session, model, name="uk_enhanced_frs") - - # When - simulation = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - filter_field="country", - filter_value="ENGLAND", - ) - - # Then - assert simulation.filter_field == "country" - assert simulation.filter_value == "ENGLAND" - - def test_given_no_filter_params_then_simulation_has_null_filter_fields( - self, session: Session - ): - """Given no filter parameters, then created simulation has null filter fields.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - model_version = create_tax_benefit_model_version(session, model) - dataset = create_dataset(session, model, name="uk_enhanced_frs") - - # When - simulation = _get_or_create_simulation( - simulation_type=SimulationType.ECONOMY, - dataset_id=dataset.id, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=None, - session=session, - ) - - # Then - assert simulation.filter_field is None - assert simulation.filter_value is None diff --git a/tests/test__given_region_without_filter__then_filter_params_none.py b/tests/test__given_region_without_filter__then_filter_params_none.py deleted file mode 100644 index e81d7a8..0000000 --- a/tests/test__given_region_without_filter__then_filter_params_none.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Tests for region resolution without filter parameters. - -When a region does not require filtering (e.g., national UK or US), -the filter_field and filter_value should be None. -""" - -import pytest -from sqlmodel import Session - -from policyengine_api.api.analysis import ( - EconomicImpactRequest, - _resolve_dataset_and_region, -) -from test_fixtures.fixtures_regions import ( - create_dataset, - create_region, - create_tax_benefit_model, -) - - -class TestResolveDatasetAndRegionWithoutFilter: - """Tests for _resolve_dataset_and_region when region does not require filtering.""" - - def test_given_national_uk_region_then_filter_params_none(self, session: Session): - """Given UK national region, then filter_field and filter_value are None.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset = create_dataset(session, model, name="uk_enhanced_frs") - region = create_region( - session, - model=model, - dataset=dataset, - code="uk", - label="United Kingdom", - region_type="national", - requires_filter=False, - ) - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - region="uk", - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_region is not None - assert resolved_region.requires_filter is False - assert resolved_region.filter_field is None - assert resolved_region.filter_value is None - - def test_given_national_us_region_then_filter_params_none(self, session: Session): - """Given US national region, then filter_field and filter_value are None.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-us") - dataset = create_dataset(session, model, name="us_cps") - region = create_region( - session, - model=model, - dataset=dataset, - code="us", - label="United States", - region_type="national", - requires_filter=False, - ) - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", - region="us", - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_region is not None - assert resolved_region.requires_filter is False - assert resolved_region.filter_field is None - assert resolved_region.filter_value is None - - def test_given_national_region_then_dataset_still_resolved(self, session: Session): - """Given national region without filter, then dataset is still correctly resolved.""" - # Given - model = create_tax_benefit_model(session, name="policyengine-uk") - dataset = create_dataset(session, model, name="uk_enhanced_frs") - region = create_region( - session, - model=model, - dataset=dataset, - code="uk", - label="United Kingdom", - region_type="national", - requires_filter=False, - ) - request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", - region="uk", - ) - - # When - resolved_dataset, resolved_region = _resolve_dataset_and_region( - request, session - ) - - # Then - assert resolved_dataset.id == dataset.id - assert resolved_dataset.name == "uk_enhanced_frs" diff --git a/tests/test__given_same_params__then_deterministic_id.py b/tests/test__given_same_params__then_deterministic_id.py deleted file mode 100644 index cbac11c..0000000 --- a/tests/test__given_same_params__then_deterministic_id.py +++ /dev/null @@ -1,204 +0,0 @@ -"""Tests for deterministic simulation ID generation. - -The simulation ID is generated deterministically from the simulation -parameters (dataset, model version, policy, dynamic, filter params). -This ensures that re-running the same simulation reuses existing results. -""" - -from uuid import uuid4 - -import pytest - -from policyengine_api.api.analysis import _get_deterministic_simulation_id -from policyengine_api.models import SimulationType - - -class TestDeterministicSimulationId: - """Tests for _get_deterministic_simulation_id function.""" - - def test_given_same_params_then_same_id_returned(self): - """Given identical parameters, then the same ID is returned.""" - # Given - dataset_id = uuid4() - model_version_id = uuid4() - policy_id = uuid4() - dynamic_id = uuid4() - filter_field = "country" - filter_value = "ENGLAND" - - # When - id1 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field=filter_field, - filter_value=filter_value, - ) - id2 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field=filter_field, - filter_value=filter_value, - ) - - # Then - assert id1 == id2 - - def test_given_different_filter_field_then_different_id(self): - """Given different filter_field, then a different ID is returned.""" - # Given - dataset_id = uuid4() - model_version_id = uuid4() - policy_id = None - dynamic_id = None - - # When - id1 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field="country", - filter_value="ENGLAND", - ) - id2 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field="state_code", - filter_value="ENGLAND", - ) - - # Then - assert id1 != id2 - - def test_given_different_filter_value_then_different_id(self): - """Given different filter_value, then a different ID is returned.""" - # Given - dataset_id = uuid4() - model_version_id = uuid4() - policy_id = None - dynamic_id = None - - # When - id1 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field="country", - filter_value="ENGLAND", - ) - id2 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field="country", - filter_value="SCOTLAND", - ) - - # Then - assert id1 != id2 - - def test_given_filter_none_vs_filter_set_then_different_id(self): - """Given None filter vs set filter, then different IDs are returned.""" - # Given - dataset_id = uuid4() - model_version_id = uuid4() - policy_id = None - dynamic_id = None - - # When - id_no_filter = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field=None, - filter_value=None, - ) - id_with_filter = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=dataset_id, - filter_field="country", - filter_value="ENGLAND", - ) - - # Then - assert id_no_filter != id_with_filter - - def test_given_different_dataset_then_different_id(self): - """Given different dataset_id, then a different ID is returned.""" - # Given - model_version_id = uuid4() - policy_id = None - dynamic_id = None - filter_field = "country" - filter_value = "ENGLAND" - - # When - id1 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=uuid4(), - filter_field=filter_field, - filter_value=filter_value, - ) - id2 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - policy_id, - dynamic_id, - dataset_id=uuid4(), - filter_field=filter_field, - filter_value=filter_value, - ) - - # Then - assert id1 != id2 - - def test_given_null_optional_params_then_consistent_id(self): - """Given null optional parameters, then consistent ID is generated.""" - # Given - dataset_id = uuid4() - model_version_id = uuid4() - - # When - id1 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - None, - None, - dataset_id=dataset_id, - filter_field=None, - filter_value=None, - ) - id2 = _get_deterministic_simulation_id( - SimulationType.ECONOMY, - model_version_id, - None, - None, - dataset_id=dataset_id, - filter_field=None, - filter_value=None, - ) - - # Then - assert id1 == id2 diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 90dbe7c..ebcb7c2 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -1,26 +1,674 @@ -"""Tests for economic impact analysis endpoint. +"""Tests for economic impact analysis (analysis.py). -These tests require a running database with seeded data. -Run with: make integration-test +Unit tests for internal functions (_resolve_dataset_and_region, +_get_deterministic_simulation_id, _get_or_create_simulation) and +integration tests for the /analysis/economic-impact endpoint. """ -import pytest +from uuid import uuid4 -pytestmark = pytest.mark.integration +import pytest +from fastapi import HTTPException from fastapi.testclient import TestClient from sqlmodel import Session, select +from policyengine_api.api.analysis import ( + EconomicImpactRequest, + _get_deterministic_simulation_id, + _get_or_create_simulation, + _resolve_dataset_and_region, +) from policyengine_api.main import app -from policyengine_api.models import Dataset, Simulation, TaxBenefitModel +from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, +) +from test_fixtures.fixtures_regions import ( + create_dataset, + create_region, + create_tax_benefit_model, + create_tax_benefit_model_version, +) client = TestClient(app) +# --------------------------------------------------------------------------- +# _resolve_dataset_and_region +# --------------------------------------------------------------------------- + + +class TestResolveDatasetAndRegion: + """Tests for _resolve_dataset_and_region.""" + + # -- dataset_id path -- + + def test__given_dataset_id__then_region_is_none(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset.id, + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is None + + def test__given_dataset_id__then_dataset_is_returned(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset.id, + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + def test__given_dataset_id_and_region__then_region_takes_precedence( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset1 = create_dataset(session, model, name="dataset_from_id") + dataset2 = create_dataset(session, model, name="dataset_from_region") + create_region( + session, + model=model, + dataset=dataset2, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=dataset1.id, + region="uk", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset2.id + assert resolved_region is not None + assert resolved_region.code == "uk" + + # -- region with filter -- + + def test__given_region_requires_filter__then_returns_filter_fields( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="country/england", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.filter_field == "country" + assert resolved_region.filter_value == "ENGLAND" + assert resolved_region.requires_filter is True + + def test__given_us_state_region__then_returns_state_filter( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-us") + dataset = create_dataset(session, model, name="us_cps") + create_region( + session, + model=model, + dataset=dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="state/ca", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.filter_field == "state_code" + assert resolved_region.filter_value == "CA" + + def test__given_region_with_filter__then_dataset_is_resolved( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="country/england", + label="England", + region_type="country", + requires_filter=True, + filter_field="country", + filter_value="ENGLAND", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="country/england", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + # -- region without filter -- + + def test__given_national_uk_region__then_filter_params_none( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="uk", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.requires_filter is False + assert resolved_region.filter_field is None + assert resolved_region.filter_value is None + + def test__given_national_us_region__then_filter_params_none( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-us") + dataset = create_dataset(session, model, name="us_cps") + create_region( + session, + model=model, + dataset=dataset, + code="us", + label="United States", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="us", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_region is not None + assert resolved_region.requires_filter is False + assert resolved_region.filter_field is None + assert resolved_region.filter_value is None + + def test__given_national_region__then_dataset_still_resolved( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + dataset = create_dataset(session, model, name="uk_enhanced_frs") + create_region( + session, + model=model, + dataset=dataset, + code="uk", + label="United Kingdom", + region_type="national", + requires_filter=False, + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="uk", + ) + + resolved_dataset, resolved_region = _resolve_dataset_and_region( + request, session + ) + + assert resolved_dataset.id == dataset.id + assert resolved_dataset.name == "uk_enhanced_frs" + + # -- error cases -- + + def test__given_nonexistent_region_code__then_raises_404(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + create_dataset(session, model, name="uk_enhanced_frs") + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + region="nonexistent/region", + ) + + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + def test__given_region_for_wrong_model__then_raises_404(self, session: Session): + uk_model = create_tax_benefit_model(session, name="policyengine-uk") + uk_dataset = create_dataset(session, uk_model, name="uk_enhanced_frs") + create_region( + session, + model=uk_model, + dataset=uk_dataset, + code="uk", + label="United Kingdom", + region_type="national", + ) + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_us", + region="uk", + ) + + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + + def test__given_neither_dataset_nor_region__then_raises_400(self, session: Session): + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + ) + + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 400 + assert "either dataset_id or region" in exc_info.value.detail.lower() + + def test__given_nonexistent_dataset_id__then_raises_404(self, session: Session): + nonexistent_id = uuid4() + request = EconomicImpactRequest( + tax_benefit_model_name="policyengine_uk", + dataset_id=nonexistent_id, + ) + + with pytest.raises(HTTPException) as exc_info: + _resolve_dataset_and_region(request, session) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() + + +# --------------------------------------------------------------------------- +# _get_deterministic_simulation_id +# --------------------------------------------------------------------------- + + +class TestGetDeterministicSimulationId: + """Tests for _get_deterministic_simulation_id.""" + + def test__given_same_params__then_same_id_returned(self): + dataset_id = uuid4() + model_version_id = uuid4() + policy_id = uuid4() + dynamic_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + policy_id, + dynamic_id, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + policy_id, + dynamic_id, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + + assert id1 == id2 + + def test__given_different_filter_field__then_different_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="state_code", + filter_value="ENGLAND", + ) + + assert id1 != id2 + + def test__given_different_filter_value__then_different_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="SCOTLAND", + ) + + assert id1 != id2 + + def test__given_filter_none_vs_filter_set__then_different_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id_no_filter = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, + ) + id_with_filter = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field="country", + filter_value="ENGLAND", + ) + + assert id_no_filter != id_with_filter + + def test__given_different_dataset__then_different_id(self): + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=uuid4(), + filter_field="country", + filter_value="ENGLAND", + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=uuid4(), + filter_field="country", + filter_value="ENGLAND", + ) + + assert id1 != id2 + + def test__given_null_optional_params__then_consistent_id(self): + dataset_id = uuid4() + model_version_id = uuid4() + + id1 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, + ) + id2 = _get_deterministic_simulation_id( + SimulationType.ECONOMY, + model_version_id, + None, + None, + dataset_id=dataset_id, + filter_field=None, + filter_value=None, + ) + + assert id1 == id2 + + +# --------------------------------------------------------------------------- +# _get_or_create_simulation +# --------------------------------------------------------------------------- + + +class TestGetOrCreateSimulation: + """Tests for _get_or_create_simulation.""" + + def test__given_existing_simulation_with_filter__then_reuses( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + first_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + second_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert first_sim.id == second_sim.id + + def test__given_different_filter__then_creates_new_simulation( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + england_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + scotland_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="SCOTLAND", + ) + + assert england_sim.id != scotland_sim.id + assert england_sim.filter_value == "ENGLAND" + assert scotland_sim.filter_value == "SCOTLAND" + + def test__given_no_filter_vs_filter__then_creates_separate_simulations( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + national_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field=None, + filter_value=None, + ) + filtered_sim = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert national_sim.id != filtered_sim.id + assert national_sim.filter_field is None + assert filtered_sim.filter_field == "country" + + def test__given_new_simulation__then_status_is_pending(self, session: Session): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert simulation.status == SimulationStatus.PENDING + + def test__given_filter_params__then_simulation_has_filter_fields( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + filter_field="country", + filter_value="ENGLAND", + ) + + assert simulation.filter_field == "country" + assert simulation.filter_value == "ENGLAND" + + def test__given_no_filter_params__then_simulation_has_null_filter_fields( + self, session: Session + ): + model = create_tax_benefit_model(session, name="policyengine-uk") + model_version = create_tax_benefit_model_version(session, model) + dataset = create_dataset(session, model, name="uk_enhanced_frs") + + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=None, + session=session, + ) + + assert simulation.filter_field is None + assert simulation.filter_value is None + + +# --------------------------------------------------------------------------- +# HTTP endpoint validation (no database required) +# --------------------------------------------------------------------------- + + class TestEconomicImpactValidation: """Tests for request validation (no database required).""" def test_invalid_model_name(self): - """Test that invalid model name returns 422.""" response = client.post( "/analysis/economic-impact", json={ @@ -31,17 +679,15 @@ def test_invalid_model_name(self): assert response.status_code == 422 def test_missing_dataset_id(self): - """Test that missing dataset_id returns 422.""" response = client.post( "/analysis/economic-impact", json={ "tax_benefit_model_name": "policyengine_uk", }, ) - assert response.status_code == 422 + assert response.status_code == 400 def test_invalid_uuid(self): - """Test that invalid UUID returns 422.""" response = client.post( "/analysis/economic-impact", json={ @@ -56,7 +702,6 @@ class TestEconomicImpactNotFound: """Tests for 404 responses.""" def test_dataset_not_found(self): - """Test that non-existent dataset returns 404.""" response = client.post( "/analysis/economic-impact", json={ @@ -68,8 +713,11 @@ def test_dataset_not_found(self): assert "not found" in response.json()["detail"].lower() -# Integration tests that require a running database with seeded data -# These are marked with pytest.mark.integration and skipped by default +# --------------------------------------------------------------------------- +# Integration tests (require running database with seeded data) +# --------------------------------------------------------------------------- + + @pytest.mark.integration class TestEconomicImpactIntegration: """Integration tests for economic impact analysis. @@ -97,7 +745,6 @@ def uk_dataset_id(self, session: Session): return dataset.id def test_uk_economic_impact_baseline_only(self, uk_dataset_id): - """Test UK economic impact with no reform policy.""" response = client.post( "/analysis/economic-impact", json={ @@ -113,10 +760,8 @@ def test_uk_economic_impact_baseline_only(self, uk_dataset_id): assert "decile_impacts" in data assert "programme_statistics" in data - # Should have 10 deciles assert len(data["decile_impacts"]) == 10 - # Check decile structure for di in data["decile_impacts"]: assert "decile" in di assert "baseline_mean" in di @@ -124,7 +769,6 @@ def test_uk_economic_impact_baseline_only(self, uk_dataset_id): assert "absolute_change" in di def test_simulations_created(self, uk_dataset_id, session: Session): - """Test that simulations are created in the database.""" response = client.post( "/analysis/economic-impact", json={ @@ -135,7 +779,6 @@ def test_simulations_created(self, uk_dataset_id, session: Session): assert response.status_code == 200 data = response.json() - # Check simulations exist in database baseline_sim = session.get(Simulation, data["baseline_simulation_id"]) assert baseline_sim is not None assert baseline_sim.status == "completed" From 3bd874648cad080a5a0b1c025318340f52f70ff1 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 17 Feb 2026 01:06:40 +0100 Subject: [PATCH 10/10] fix: Mark TestEconomicImpactNotFound as integration test This test hits the real database (valid request passes validation), so it needs a running Supabase instance like the other integration tests. Co-Authored-By: Claude Opus 4.6 --- tests/test_analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index ebcb7c2..bfdd4a2 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -698,6 +698,7 @@ def test_invalid_uuid(self): assert response.status_code == 422 +@pytest.mark.integration class TestEconomicImpactNotFound: """Tests for 404 responses."""