diff --git a/alembic/versions/20260218_add_simulation_filter_columns.py b/alembic/versions/20260218_add_simulation_filter_columns.py new file mode 100644 index 0000000..ecc5b4c --- /dev/null +++ b/alembic/versions/20260218_add_simulation_filter_columns.py @@ -0,0 +1,41 @@ +"""add filter_field and filter_value to simulations + +Revision ID: add_sim_filters +Revises: merge_001 +Create Date: 2026-02-18 + +The Simulation model already has filter_field and filter_value fields +(used for regional economy simulations), but no migration added them +to the database. This brings the schema in line with the model. +""" + +from typing import Sequence, Union + +import sqlmodel.sql.sqltypes + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "add_sim_filters" +down_revision: Union[str, Sequence[str], None] = "merge_001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add filter_field and filter_value columns to simulations table.""" + op.add_column( + "simulations", + sa.Column("filter_field", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + op.add_column( + "simulations", + sa.Column("filter_value", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + + +def downgrade() -> None: + """Remove filter_field and filter_value columns from simulations table.""" + op.drop_column("simulations", "filter_value") + op.drop_column("simulations", "filter_field") diff --git a/alembic/versions/20260218_drop_parent_report_id.py b/alembic/versions/20260218_drop_parent_report_id.py new file mode 100644 index 0000000..54c548c --- /dev/null +++ b/alembic/versions/20260218_drop_parent_report_id.py @@ -0,0 +1,43 @@ +"""drop parent_report_id from reports + +Revision ID: drop_parent_report +Revises: add_sim_filters +Create Date: 2026-02-18 + +Remove the unused self-referential parent_report_id foreign key from +the reports table. No code reads or writes this column. +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "drop_parent_report" +down_revision: Union[str, Sequence[str], None] = "add_sim_filters" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Drop parent_report_id column and its FK constraint.""" + op.drop_constraint( + "reports_parent_report_id_fkey", "reports", type_="foreignkey" + ) + op.drop_column("reports", "parent_report_id") + + +def downgrade() -> None: + """Re-add parent_report_id column and FK constraint.""" + op.add_column( + "reports", + sa.Column("parent_report_id", sa.Uuid(), nullable=True), + ) + op.create_foreign_key( + "reports_parent_report_id_fkey", + "reports", + "reports", + ["parent_report_id"], + ["id"], + ) diff --git a/alembic/versions/20260218_merge_user_policies_and_household_support.py b/alembic/versions/20260218_merge_user_policies_and_household_support.py new file mode 100644 index 0000000..d880c84 --- /dev/null +++ b/alembic/versions/20260218_merge_user_policies_and_household_support.py @@ -0,0 +1,32 @@ +"""merge user_policies and household_support branches + +Revision ID: merge_001 +Revises: 0002_user_policies, a1b2c3d4e5f6 +Create Date: 2026-02-18 + +Merge the two migration branches that diverged from the initial schema: +- 0002_user_policies: added user_policies table + policy.tax_benefit_model_id +- f419b5f4acba → a1b2c3d4e5f6: added household support + regions table + +No schema changes — both branches modify independent tables. +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "merge_001" +down_revision: tuple[str, str] = ("0002_user_policies", "a1b2c3d4e5f6") +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """No schema changes — merge only.""" + pass + + +def downgrade() -> None: + """No schema changes — merge only.""" + pass diff --git a/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py b/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py new file mode 100644 index 0000000..c31a778 --- /dev/null +++ b/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py @@ -0,0 +1,47 @@ +"""add user_simulation_associations table + +Revision ID: 621977f3b1aa +Revises: drop_parent_report +Create Date: 2026-02-19 00:37:43.378088 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '621977f3b1aa' +down_revision: Union[str, Sequence[str], None] = 'drop_parent_report' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('user_simulation_associations', + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_simulation_associations_simulation_id'), 'user_simulation_associations', ['simulation_id'], unique=False) + op.create_index(op.f('ix_user_simulation_associations_user_id'), 'user_simulation_associations', ['user_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_user_simulation_associations_user_id'), table_name='user_simulation_associations') + op.drop_index(op.f('ix_user_simulation_associations_simulation_id'), table_name='user_simulation_associations') + op.drop_table('user_simulation_associations') + # ### end Alembic commands ### diff --git a/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py b/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py new file mode 100644 index 0000000..b9edb4f --- /dev/null +++ b/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py @@ -0,0 +1,48 @@ +"""add user_report_associations table + +Revision ID: 9daa015274dd +Revises: 621977f3b1aa +Create Date: 2026-02-19 16:58:03.157551 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '9daa015274dd' +down_revision: Union[str, Sequence[str], None] = '621977f3b1aa' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('user_report_associations', + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=False), + sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('last_run_at', sa.DateTime(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_report_associations_report_id'), 'user_report_associations', ['report_id'], unique=False) + op.create_index(op.f('ix_user_report_associations_user_id'), 'user_report_associations', ['user_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_user_report_associations_user_id'), table_name='user_report_associations') + op.drop_index(op.f('ix_user_report_associations_report_id'), table_name='user_report_associations') + op.drop_table('user_report_associations') + # ### end Alembic commands ### diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index 7e94d29..dd8deac 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -21,6 +21,8 @@ tax_benefit_models, user_household_associations, user_policies, + user_report_associations, + user_simulation_associations, variables, ) @@ -45,5 +47,7 @@ api_router.include_router(agent.router) api_router.include_router(user_household_associations.router) api_router.include_router(user_policies.router) +api_router.include_router(user_simulation_associations.router) +api_router.include_router(user_report_associations.router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 17539d7..ba739c7 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -591,6 +591,218 @@ def build_dynamic(dynamic_id): session.commit() +def _run_local_economy_comparison_us(job_id: str, session: Session) -> None: + """Run US economy comparison analysis locally.""" + from datetime import datetime, timezone + from uuid import UUID + + from policyengine.core import Simulation as PESimulation + from policyengine.core.dynamic import Dynamic as PEDynamic + from policyengine.core.policy import ParameterValue as PEParameterValue + from policyengine.core.policy import Policy as PEPolicy + from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset + from policyengine.tax_benefit_models.us.outputs import ( + ProgramStatistics as PEProgramStats, + ) + + from policyengine_api.models import Policy as DBPolicy + + # Load report and simulations + report = session.get(Report, UUID(job_id)) + if not report: + raise ValueError(f"Report {job_id} not found") + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + reform_sim = session.get(Simulation, report.reform_simulation_id) + + if not baseline_sim or not reform_sim: + raise ValueError("Simulations not found") + + # Update status to running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Get dataset + dataset = session.get(Dataset, baseline_sim.dataset_id) + if not dataset: + raise ValueError(f"Dataset {baseline_sim.dataset_id} not found") + + pe_model_version = us_latest + param_lookup = {p.name: p for p in pe_model_version.parameters} + + def build_policy(policy_id): + if not policy_id: + return None + db_policy = session.get(DBPolicy, policy_id) + if not db_policy: + return None + pe_param_values = [] + for pv in db_policy.parameter_values: + if not pv.parameter: + continue + pe_param = param_lookup.get(pv.parameter.name) + if not pe_param: + continue + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + start_date=pv.start_date, + end_date=pv.end_date, + ) + pe_param_values.append(pe_pv) + return PEPolicy( + name=db_policy.name, + description=db_policy.description, + parameter_values=pe_param_values, + ) + + def build_dynamic(dynamic_id): + if not dynamic_id: + return None + from policyengine_api.models import Dynamic as DBDynamic + + db_dynamic = session.get(DBDynamic, dynamic_id) + if not db_dynamic: + return None + pe_param_values = [] + for pv in db_dynamic.parameter_values: + if not pv.parameter: + continue + pe_param = param_lookup.get(pv.parameter.name) + if not pe_param: + continue + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + start_date=pv.start_date, + end_date=pv.end_date, + ) + pe_param_values.append(pe_pv) + return PEDynamic( + name=db_dynamic.name, + description=db_dynamic.description, + parameter_values=pe_param_values, + ) + + baseline_policy = build_policy(baseline_sim.policy_id) + reform_policy = build_policy(reform_sim.policy_id) + baseline_dynamic = build_dynamic(baseline_sim.dynamic_id) + reform_dynamic = build_dynamic(reform_sim.dynamic_id) + + # Download dataset + local_path = _download_dataset_local(dataset.filepath) + pe_dataset = PolicyEngineUSDataset( + name=dataset.name, + description=dataset.description or "", + filepath=local_path, + year=dataset.year, + ) + + # Run simulations (with optional regional filtering) + pe_baseline_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=baseline_policy, + dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, + ) + pe_baseline_sim.ensure() + + pe_reform_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=reform_policy, + dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, + ) + pe_reform_sim.ensure() + + # Calculate decile impacts + for decile_num in range(1, 11): + di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + decile=decile_num, + ) + di.run() + decile_impact = DecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + income_variable=di.income_variable, + entity=di.entity, + decile=di.decile, + quantiles=di.quantiles, + baseline_mean=di.baseline_mean, + reform_mean=di.reform_mean, + absolute_change=di.absolute_change, + relative_change=di.relative_change, + count_better_off=di.count_better_off, + count_worse_off=di.count_worse_off, + count_no_change=di.count_no_change, + ) + session.add(decile_impact) + + # Calculate program statistics + PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + programs = { + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "employee_payroll_tax": {"entity": "person", "is_tax": True}, + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "spm_unit", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + } + for prog_name, prog_info in programs.items(): + try: + ps = PEProgramStats( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + program_stat = ProgramStatistics( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(program_stat) + except KeyError: + pass # Variable not found in model + + # Mark completed + baseline_sim.status = SimulationStatus.COMPLETED + baseline_sim.completed_at = datetime.now(timezone.utc) + reform_sim.status = SimulationStatus.COMPLETED + reform_sim.completed_at = datetime.now(timezone.utc) + report.status = ReportStatus.COMPLETED + session.add(baseline_sim) + session.add(reform_sim) + session.add(report) + session.commit() + + def _trigger_economy_comparison( job_id: str, tax_benefit_model_name: str, session: Session | None = None ) -> None: @@ -604,11 +816,7 @@ def _trigger_economy_comparison( if tax_benefit_model_name == "policyengine_uk": _run_local_economy_comparison_uk(job_id, session) else: - # US not implemented for local yet - fall back to Modal - import modal - - fn = modal.Function.from_name("policyengine", "economy_comparison_us") - fn.spawn(job_id=job_id, traceparent=traceparent) + _run_local_economy_comparison_us(job_id, session) else: # Use Modal import modal diff --git a/src/policyengine_api/api/simulations.py b/src/policyengine_api/api/simulations.py index 633c57c..bf57c32 100644 --- a/src/policyengine_api/api/simulations.py +++ b/src/policyengine_api/api/simulations.py @@ -1,36 +1,362 @@ -"""Simulation status endpoints. +"""Simulation endpoints. -Simulations are economy-wide tax-benefit calculations running on population datasets. -They are created automatically when you call /analysis/economic-impact. Use these -endpoints to check simulation status (pending, running, completed, failed). +Simulations are individual tax-benefit calculations. Use these endpoints to: +- Create and run household simulations (single household, single policy) +- Create and run economy simulations (population dataset, single policy) +- Check simulation status and retrieve results + +For baseline-vs-reform comparisons, use the /analysis/ endpoints instead. """ -from typing import List +from typing import Any, List, Literal from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field from sqlmodel import Session, select -from policyengine_api.models import Simulation, SimulationRead +from policyengine_api.models import ( + Dataset, + Household, + Policy, + Region, + Simulation, + SimulationRead, + SimulationStatus, + SimulationType, + TaxBenefitModel, +) from policyengine_api.services.database import get_session +from .analysis import ( + RegionInfo, + _get_model_version, + _get_or_create_simulation, +) + router = APIRouter(prefix="/simulations", tags=["simulations"]) +# --------------------------------------------------------------------------- +# Request / Response schemas +# --------------------------------------------------------------------------- + + +class HouseholdSimulationRequest(BaseModel): + """Request body for creating a household simulation.""" + + household_id: UUID = Field(description="ID of the stored household") + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. If None, runs under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, + description="Optional behavioural response specification ID", + ) + + +class HouseholdSimulationResponse(BaseModel): + """Response for a household simulation.""" + + id: UUID + status: SimulationStatus + household_id: UUID | None = None + policy_id: UUID | None = None + household_result: dict[str, Any] | None = None + error_message: str | None = None + + +class EconomySimulationRequest(BaseModel): + """Request body for creating an economy simulation.""" + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us'). Either region or dataset_id must be provided.", + ) + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID. Either region or dataset_id must be provided.", + ) + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. If None, runs under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, + description="Optional behavioural response specification ID", + ) + + +class EconomySimulationResponse(BaseModel): + """Response for an economy simulation.""" + + id: UUID + status: SimulationStatus + dataset_id: UUID | None = None + policy_id: UUID | None = None + output_dataset_id: UUID | None = None + filter_field: str | None = None + filter_value: str | None = None + region: RegionInfo | None = None + error_message: str | None = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _resolve_economy_dataset( + tax_benefit_model_name: str, + region_code: str | None, + dataset_id: UUID | None, + session: Session, +) -> tuple[Dataset, Region | None]: + """Resolve dataset from region code or dataset_id for economy simulations.""" + if region_code: + model_name = tax_benefit_model_name.replace("_", "-") + region = session.exec( + select(Region) + .join(TaxBenefitModel) + .where(Region.code == region_code) + .where(TaxBenefitModel.name == model_name) + ).first() + if not region: + raise HTTPException( + status_code=404, + detail=f"Region '{region_code}' not found for model {model_name}", + ) + dataset = session.get(Dataset, region.dataset_id) + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset for region '{region_code}' not found", + ) + return dataset, region + + elif dataset_id: + dataset = session.get(Dataset, dataset_id) + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset {dataset_id} not found", + ) + return dataset, None + + else: + raise HTTPException( + status_code=400, + detail="Either region or dataset_id must be provided", + ) + + +def _build_household_response(simulation: Simulation) -> HouseholdSimulationResponse: + """Build response from a household simulation.""" + return HouseholdSimulationResponse( + id=simulation.id, + status=simulation.status, + household_id=simulation.household_id, + policy_id=simulation.policy_id, + household_result=simulation.household_result, + error_message=simulation.error_message, + ) + + +def _build_economy_response( + simulation: Simulation, region: Region | None = None +) -> EconomySimulationResponse: + """Build response from an economy simulation.""" + region_info = None + if region: + region_info = RegionInfo( + code=region.code, + label=region.label, + region_type=region.region_type, + requires_filter=region.requires_filter, + filter_field=region.filter_field, + filter_value=region.filter_value, + ) + + return EconomySimulationResponse( + id=simulation.id, + status=simulation.status, + dataset_id=simulation.dataset_id, + policy_id=simulation.policy_id, + output_dataset_id=simulation.output_dataset_id, + filter_field=simulation.filter_field, + filter_value=simulation.filter_value, + region=region_info, + error_message=simulation.error_message, + ) + + +# --------------------------------------------------------------------------- +# List / generic get (existing endpoints) +# --------------------------------------------------------------------------- + + @router.get("/", response_model=List[SimulationRead]) def list_simulations(session: Session = Depends(get_session)): - """List all simulations. - - Simulations are created automatically via /analysis/economic-impact. - Check status to see if computation is pending, running, completed, or failed. - """ + """List all simulations.""" simulations = session.exec(select(Simulation)).all() return simulations +# --------------------------------------------------------------------------- +# Household simulation endpoints +# --------------------------------------------------------------------------- + + +@router.post("/household", response_model=HouseholdSimulationResponse) +def create_household_simulation( + request: HouseholdSimulationRequest, + session: Session = Depends(get_session), +): + """Create a household simulation job. + + Creates a Simulation record for the given household and policy. + Returns immediately with status "pending". + Poll GET /simulations/household/{id} until status is "completed". + """ + # Validate household exists + household = session.get(Household, request.household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {request.household_id} not found", + ) + + # Validate policy exists (if provided) + if request.policy_id: + policy = session.get(Policy, request.policy_id) + if not policy: + raise HTTPException( + status_code=404, + detail=f"Policy {request.policy_id} not found", + ) + + # Get model version + model_version = _get_model_version(household.tax_benefit_model_name, session) + + # Get or create simulation (deterministic UUID) + simulation = _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + household_id=request.household_id, + ) + + return _build_household_response(simulation) + + +@router.get("/household/{simulation_id}", response_model=HouseholdSimulationResponse) +def get_household_simulation( + simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get a household simulation's status and result.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + if simulation.simulation_type != SimulationType.HOUSEHOLD: + raise HTTPException( + status_code=400, + detail="Simulation is not a household simulation", + ) + + return _build_household_response(simulation) + + +# --------------------------------------------------------------------------- +# Economy simulation endpoints +# --------------------------------------------------------------------------- + + +@router.post("/economy", response_model=EconomySimulationResponse) +def create_economy_simulation( + request: EconomySimulationRequest, + session: Session = Depends(get_session), +): + """Create a single economy simulation. + + Creates a Simulation record for the given dataset/region and policy. + Poll GET /simulations/economy/{id} until status is "completed". + + Note: standalone economy simulation computation will be connected + in future tasks. For full baseline-vs-reform economy analysis, + use POST /analysis/economic-impact instead. + """ + # Resolve dataset and region + dataset, region = _resolve_economy_dataset( + request.tax_benefit_model_name, + request.region, + request.dataset_id, + session, + ) + + # Validate policy exists (if provided) + if request.policy_id: + policy = session.get(Policy, request.policy_id) + if not policy: + raise HTTPException( + status_code=404, + detail=f"Policy {request.policy_id} not found", + ) + + # Extract filter parameters from region + filter_field = region.filter_field if region and region.requires_filter else None + filter_value = region.filter_value if region and region.requires_filter else None + + # Get model version + model_version = _get_model_version(request.tax_benefit_model_name, session) + + # Get or create simulation (deterministic UUID) + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + ) + + return _build_economy_response(simulation, region) + + +@router.get("/economy/{simulation_id}", response_model=EconomySimulationResponse) +def get_economy_simulation( + simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get an economy simulation's status and result.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + if simulation.simulation_type != SimulationType.ECONOMY: + raise HTTPException( + status_code=400, + detail="Simulation is not an economy simulation", + ) + + return _build_economy_response(simulation) + + +# --------------------------------------------------------------------------- +# Generic get (keep after specific routes to avoid path conflicts) +# --------------------------------------------------------------------------- + + @router.get("/{simulation_id}", response_model=SimulationRead) def get_simulation(simulation_id: UUID, session: Session = Depends(get_session)): - """Get a specific simulation.""" + """Get a specific simulation (any type).""" simulation = session.get(Simulation, simulation_id) if not simulation: raise HTTPException(status_code=404, detail="Simulation not found") diff --git a/src/policyengine_api/api/user_report_associations.py b/src/policyengine_api/api/user_report_associations.py new file mode 100644 index 0000000..ec1bd6b --- /dev/null +++ b/src/policyengine_api/api/user_report_associations.py @@ -0,0 +1,146 @@ +"""User-report association endpoints. + +Associates users with reports they've created. This enables users to +maintain a list of their reports across sessions without duplicating +the underlying report data. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save reports without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.config.constants import CountryId +from policyengine_api.models import ( + Report, + UserReportAssociation, + UserReportAssociationCreate, + UserReportAssociationRead, + UserReportAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/user-reports", tags=["user-reports"]) + + +@router.post("/", response_model=UserReportAssociationRead) +def create_user_report( + body: UserReportAssociationCreate, + session: Session = Depends(get_session), +): + """Create a new user-report association. + + Associates a user with a report, allowing them to save it to their list. + Duplicates are allowed - users can save the same report multiple times + with different labels. + """ + report = session.get(Report, body.report_id) + if not report: + raise HTTPException(status_code=404, detail="Report not found") + + record = UserReportAssociation.model_validate(body) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/", response_model=list[UserReportAssociationRead]) +def list_user_reports( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all report associations for a user. + + Returns all reports saved by the specified user. Optionally filter by country. + """ + query = select(UserReportAssociation).where( + UserReportAssociation.user_id == user_id + ) + + if country_id: + query = query.where(UserReportAssociation.country_id == country_id) + + return session.exec(query).all() + + +@router.get("/{user_report_id}", response_model=UserReportAssociationRead) +def get_user_report( + user_report_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-report association by ID.""" + record = session.get(UserReportAssociation, user_report_id) + if not record: + raise HTTPException( + status_code=404, detail="User-report association not found" + ) + return record + + +@router.patch("/{user_report_id}", response_model=UserReportAssociationRead) +def update_user_report( + user_report_id: UUID, + updates: UserReportAssociationUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-report association (e.g., rename label or update last_run_at). + + Requires user_id to verify ownership - only the owner can update. + """ + record = session.exec( + select(UserReportAssociation).where( + UserReportAssociation.id == user_report_id, + UserReportAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-report association not found" + ) + + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + + record.updated_at = datetime.now(timezone.utc) + + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{user_report_id}", status_code=204) +def delete_user_report( + user_report_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-report association. + + This only removes the association, not the underlying report. + Requires user_id to verify ownership - only the owner can delete. + """ + record = session.exec( + select(UserReportAssociation).where( + UserReportAssociation.id == user_report_id, + UserReportAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-report association not found" + ) + + session.delete(record) + session.commit() diff --git a/src/policyengine_api/api/user_simulation_associations.py b/src/policyengine_api/api/user_simulation_associations.py new file mode 100644 index 0000000..2341d91 --- /dev/null +++ b/src/policyengine_api/api/user_simulation_associations.py @@ -0,0 +1,146 @@ +"""User-simulation association endpoints. + +Associates users with simulations they've run. This enables users to +maintain a list of their simulations across sessions without duplicating +the underlying simulation data. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save simulations without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.config.constants import CountryId +from policyengine_api.models import ( + Simulation, + UserSimulationAssociation, + UserSimulationAssociationCreate, + UserSimulationAssociationRead, + UserSimulationAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/user-simulations", tags=["user-simulations"]) + + +@router.post("/", response_model=UserSimulationAssociationRead) +def create_user_simulation( + body: UserSimulationAssociationCreate, + session: Session = Depends(get_session), +): + """Create a new user-simulation association. + + Associates a user with a simulation, allowing them to save it to their list. + Duplicates are allowed - users can save the same simulation multiple times + with different labels. + """ + simulation = session.get(Simulation, body.simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + + record = UserSimulationAssociation.model_validate(body) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/", response_model=list[UserSimulationAssociationRead]) +def list_user_simulations( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all simulation associations for a user. + + Returns all simulations saved by the specified user. Optionally filter by country. + """ + query = select(UserSimulationAssociation).where( + UserSimulationAssociation.user_id == user_id + ) + + if country_id: + query = query.where(UserSimulationAssociation.country_id == country_id) + + return session.exec(query).all() + + +@router.get("/{user_simulation_id}", response_model=UserSimulationAssociationRead) +def get_user_simulation( + user_simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-simulation association by ID.""" + record = session.get(UserSimulationAssociation, user_simulation_id) + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + return record + + +@router.patch("/{user_simulation_id}", response_model=UserSimulationAssociationRead) +def update_user_simulation( + user_simulation_id: UUID, + updates: UserSimulationAssociationUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-simulation association (e.g., rename label). + + Requires user_id to verify ownership - only the owner can update. + """ + record = session.exec( + select(UserSimulationAssociation).where( + UserSimulationAssociation.id == user_simulation_id, + UserSimulationAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + + record.updated_at = datetime.now(timezone.utc) + + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{user_simulation_id}", status_code=204) +def delete_user_simulation( + user_simulation_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-simulation association. + + This only removes the association, not the underlying simulation. + Requires user_id to verify ownership - only the owner can delete. + """ + record = session.exec( + select(UserSimulationAssociation).where( + UserSimulationAssociation.id == user_simulation_id, + UserSimulationAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + + session.delete(record) + session.commit() diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 84f8e89..0feb0d7 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -1736,6 +1736,294 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: logfire.force_flush() +# --------------------------------------------------------------------------- +# Household impact (report-based) — called by _trigger_household_impact() +# --------------------------------------------------------------------------- + + +@app.function( + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, +) +def household_impact_uk(report_id: str, traceparent: str | None = None) -> None: + """Run UK household impact analysis for a report. + + Loads the Report and its Simulations from the database, runs household + calculations for each simulation, stores results, and marks the report + as completed. Called via Modal.spawn() from _trigger_household_impact(). + """ + import logfire + + configure_logfire("policyengine-modal-uk", traceparent) + + try: + with logfire.span("household_impact_uk", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.api.household import _calculate_household_uk + from policyengine_api.api.household_analysis import ( + _ensure_list, + _extract_policy_data, + ) + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run each simulation (baseline, then reform if present) + for sim_id in [ + report.baseline_simulation_id, + report.reform_simulation_id, + ]: + if not sim_id: + continue + + simulation = session.get(Simulation, sim_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + continue + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError( + f"Household {simulation.household_id} not found" + ) + + # Convert policy to calculation format + policy_data = None + if simulation.policy_id: + from policyengine_api.models import Policy + + policy = session.get(Policy, simulation.policy_id) + policy_data = _extract_policy_data(policy) + + # Mark simulation as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + hh_data = household.household_data + with logfire.span( + "run_household_calculation", + simulation_id=str(sim_id), + ): + result = _calculate_household_uk( + people=hh_data.get("people", []), + benunit=_ensure_list(hh_data.get("benunit")), + household=_ensure_list(hh_data.get("household")), + year=household.year, + policy_data=policy_data, + ) + + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "UK household impact failed", + report_id=report_id, + error=str(e), + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', " + "error_message = :error WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + +@app.function( + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, +) +def household_impact_us(report_id: str, traceparent: str | None = None) -> None: + """Run US household impact analysis for a report. + + Loads the Report and its Simulations from the database, runs household + calculations for each simulation, stores results, and marks the report + as completed. Called via Modal.spawn() from _trigger_household_impact(). + """ + import logfire + + configure_logfire("policyengine-modal-us", traceparent) + + try: + with logfire.span("household_impact_us", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.api.household import _calculate_household_us + from policyengine_api.api.household_analysis import ( + _ensure_list, + _extract_policy_data, + ) + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run each simulation (baseline, then reform if present) + for sim_id in [ + report.baseline_simulation_id, + report.reform_simulation_id, + ]: + if not sim_id: + continue + + simulation = session.get(Simulation, sim_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + continue + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError( + f"Household {simulation.household_id} not found" + ) + + # Convert policy to calculation format + policy_data = None + if simulation.policy_id: + from policyengine_api.models import Policy + + policy = session.get(Policy, simulation.policy_id) + policy_data = _extract_policy_data(policy) + + # Mark simulation as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + hh_data = household.household_data + with logfire.span( + "run_household_calculation", + simulation_id=str(sim_id), + ): + result = _calculate_household_us( + people=hh_data.get("people", []), + marital_unit=_ensure_list( + hh_data.get("marital_unit") + ), + family=_ensure_list(hh_data.get("family")), + spm_unit=_ensure_list(hh_data.get("spm_unit")), + tax_unit=_ensure_list(hh_data.get("tax_unit")), + household=_ensure_list(hh_data.get("household")), + year=household.year, + policy_data=policy_data, + ) + + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "US household impact failed", + report_id=report_id, + error=str(e), + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', " + "error_message = :error WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + def _get_pe_policy_uk(policy_id, model_version, session): """Convert database Policy to policyengine Policy for UK.""" if policy_id is None: diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index e7b386a..e73b75a 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -61,6 +61,18 @@ UserHouseholdAssociationRead, UserHouseholdAssociationUpdate, ) +from .user_simulation_association import ( + UserSimulationAssociation, + UserSimulationAssociationCreate, + UserSimulationAssociationRead, + UserSimulationAssociationUpdate, +) +from .user_report_association import ( + UserReportAssociation, + UserReportAssociationCreate, + UserReportAssociationRead, + UserReportAssociationUpdate, +) from .user_policy import ( UserPolicy, UserPolicyCreate, @@ -142,6 +154,14 @@ "UserHouseholdAssociationRead", "UserHouseholdAssociationUpdate", "UserRead", + "UserSimulationAssociation", + "UserSimulationAssociationCreate", + "UserSimulationAssociationRead", + "UserSimulationAssociationUpdate", + "UserReportAssociation", + "UserReportAssociationCreate", + "UserReportAssociationRead", + "UserReportAssociationUpdate", "UserPolicy", "UserPolicyCreate", "UserPolicyRead", diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index bc2cd40..ffc33b8 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -22,7 +22,6 @@ class ReportBase(SQLModel): report_type: str | None = None user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) - parent_report_id: UUID | None = Field(default=None, foreign_key="reports.id") status: ReportStatus = ReportStatus.PENDING error_message: str | None = None baseline_simulation_id: UUID | None = Field( diff --git a/src/policyengine_api/models/user_report_association.py b/src/policyengine_api/models/user_report_association.py new file mode 100644 index 0000000..4f078cb --- /dev/null +++ b/src/policyengine_api/models/user_report_association.py @@ -0,0 +1,63 @@ +"""User-report association model. + +Associates users with reports they've created. This enables users to +maintain a list of their reports across sessions. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save reports without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine_api.config.constants import CountryId + + +class UserReportAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(index=True) + report_id: UUID = Field(foreign_key="reports.id", index=True) + country_id: str + label: str | None = None + last_run_at: datetime | None = None + + +class UserReportAssociation(UserReportAssociationBase, table=True): + """User-report association database model.""" + + __tablename__ = "user_report_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserReportAssociationCreate(SQLModel): + """Schema for creating user-report associations.""" + + user_id: UUID + report_id: UUID + country_id: CountryId + label: str | None = None + last_run_at: datetime | None = None + + +class UserReportAssociationRead(UserReportAssociationBase): + """Schema for reading user-report associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserReportAssociationUpdate(SQLModel): + """Schema for updating user-report associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None + last_run_at: datetime | None = None diff --git a/src/policyengine_api/models/user_simulation_association.py b/src/policyengine_api/models/user_simulation_association.py new file mode 100644 index 0000000..9b07d19 --- /dev/null +++ b/src/policyengine_api/models/user_simulation_association.py @@ -0,0 +1,60 @@ +"""User-simulation association model. + +Associates users with simulations they've run. This enables users to +maintain a list of their simulations across sessions. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save simulations without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine_api.config.constants import CountryId + + +class UserSimulationAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(index=True) + simulation_id: UUID = Field(foreign_key="simulations.id", index=True) + country_id: str + label: str | None = None + + +class UserSimulationAssociation(UserSimulationAssociationBase, table=True): + """User-simulation association database model.""" + + __tablename__ = "user_simulation_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserSimulationAssociationCreate(SQLModel): + """Schema for creating user-simulation associations.""" + + user_id: UUID + simulation_id: UUID + country_id: CountryId + label: str | None = None + + +class UserSimulationAssociationRead(UserSimulationAssociationBase): + """Schema for reading user-simulation associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserSimulationAssociationUpdate(SQLModel): + """Schema for updating user-simulation associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py new file mode 100644 index 0000000..314afe5 --- /dev/null +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -0,0 +1,164 @@ +"""Fixtures and helpers for standalone simulation endpoint tests.""" + +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Household, + Policy, + Region, + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def create_us_model_and_version(session) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create a US tax-benefit model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + return model, version + + +def create_uk_model_and_version(session) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create a UK tax-benefit model and version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + return model, version + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str = "Test household", +) -> Household: + """Create and persist a Household record.""" + household = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data={ + "people": [{"age": {"2024": 30}, "employment_income": {"2024": 50000}}], + "household": [{"state_code": {"2024": "CA"}}], + }, + ) + session.add(household) + session.commit() + session.refresh(household) + return household + + +def create_policy(session, model: TaxBenefitModel) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + name="Test reform", + description="A test reform policy", + tax_benefit_model_id=model.id, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_dataset(session, model: TaxBenefitModel) -> Dataset: + """Create and persist a Dataset record.""" + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + return dataset + + +def create_region( + session, + model: TaxBenefitModel, + dataset: Dataset, + code: str = "us", + label: str = "United States", + region_type: str = "country", + requires_filter: bool = False, + filter_field: str | None = None, + filter_value: str | None = None, +) -> Region: + """Create and persist a Region record.""" + region = Region( + code=code, + label=label, + region_type=region_type, + requires_filter=requires_filter, + filter_field=filter_field, + filter_value=filter_value, + dataset_id=dataset.id, + tax_benefit_model_id=model.id, + ) + session.add(region) + session.commit() + session.refresh(region) + return region + + +def create_economy_simulation( + session, + version: TaxBenefitModelVersion, + dataset: Dataset, +) -> Simulation: + """Create and persist an economy Simulation record.""" + simulation = Simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +def create_household_simulation( + session, + version: TaxBenefitModelVersion, + household: Household, +) -> Simulation: + """Create and persist a household Simulation record.""" + simulation = Simulation( + simulation_type=SimulationType.HOUSEHOLD, + household_id=household.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + household_result={"person": [{"income_tax": {"2024": 5000}}]}, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation diff --git a/test_fixtures/fixtures_user_report_associations.py b/test_fixtures/fixtures_user_report_associations.py new file mode 100644 index 0000000..4ef07df --- /dev/null +++ b/test_fixtures/fixtures_user_report_associations.py @@ -0,0 +1,101 @@ +"""Fixtures and helpers for user-report association tests.""" + +from datetime import datetime +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Report, + ReportStatus, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, + UserReportAssociation, +) + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_report(session, model: TaxBenefitModel | None = None) -> Report: + """Create and persist a Report with required simulation dependencies.""" + if model is None: + model = create_tax_benefit_model(session) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + baseline = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + reform = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(baseline) + session.add(reform) + session.commit() + session.refresh(baseline) + session.refresh(reform) + + report = Report( + label="Test report", + status=ReportStatus.COMPLETED, + baseline_simulation_id=baseline.id, + reform_simulation_id=reform.id, + ) + session.add(report) + session.commit() + session.refresh(report) + return report + + +def create_user_report_association( + session, + user_id: UUID, + report: Report, + country_id: str = "us", + label: str | None = None, + last_run_at: datetime | None = None, +) -> UserReportAssociation: + """Create and persist a UserReportAssociation record.""" + record = UserReportAssociation( + user_id=user_id, + report_id=report.id, + country_id=country_id, + label=label, + last_run_at=last_run_at, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/test_fixtures/fixtures_user_simulation_associations.py b/test_fixtures/fixtures_user_simulation_associations.py new file mode 100644 index 0000000..c2cbd74 --- /dev/null +++ b/test_fixtures/fixtures_user_simulation_associations.py @@ -0,0 +1,79 @@ +"""Fixtures and helpers for user-simulation association tests.""" + +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, + UserSimulationAssociation, +) + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_simulation(session, model: TaxBenefitModel | None = None) -> Simulation: + """Create and persist a Simulation with required dependencies.""" + if model is None: + model = create_tax_benefit_model(session) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + simulation = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +def create_user_simulation_association( + session, + user_id: UUID, + simulation: Simulation, + country_id: str = "us", + label: str | None = None, +) -> UserSimulationAssociation: + """Create and persist a UserSimulationAssociation record.""" + record = UserSimulationAssociation( + user_id=user_id, + simulation_id=simulation.id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/test_simulations_standalone.py b/tests/test_simulations_standalone.py new file mode 100644 index 0000000..97d5743 --- /dev/null +++ b/tests/test_simulations_standalone.py @@ -0,0 +1,312 @@ +"""Tests for standalone simulation endpoints (/simulations/household, /simulations/economy).""" + +from uuid import uuid4 + +from test_fixtures.fixtures_simulations_standalone import ( + create_dataset, + create_economy_simulation, + create_household, + create_household_simulation, + create_policy, + create_region, + create_uk_model_and_version, + create_us_model_and_version, +) + + +# =========================================================================== +# POST /simulations/household +# =========================================================================== + + +def test_create_household_simulation(client, session): + """Create a household simulation returns 200 with pending status.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = {"household_id": str(household.id)} + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["household_id"] == str(household.id) + assert data["household_result"] is None + assert data["policy_id"] is None + + +def test_create_household_simulation_with_policy(client, session): + """Create a household simulation with a reform policy.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + policy = create_policy(session, model) + + payload = { + "household_id": str(household.id), + "policy_id": str(policy.id), + } + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["policy_id"] == str(policy.id) + + +def test_create_household_simulation_not_found(client, session): + """Creating with a non-existent household returns 404.""" + model, version = create_us_model_and_version(session) + payload = {"household_id": str(uuid4())} + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_household_simulation_policy_not_found(client, session): + """Creating with a non-existent policy returns 404.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = { + "household_id": str(household.id), + "policy_id": str(uuid4()), + } + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 404 + assert "Policy" in response.json()["detail"] + + +def test_household_simulation_deduplication(client, session): + """Same inputs produce the same simulation (deterministic UUID).""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = {"household_id": str(household.id)} + response1 = client.post("/simulations/household", json=payload) + response2 = client.post("/simulations/household", json=payload) + + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response1.json()["id"] == response2.json()["id"] + + +# =========================================================================== +# GET /simulations/household/{id} +# =========================================================================== + + +def test_get_household_simulation(client, session): + """Get a household simulation by ID.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + simulation = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/household/{simulation.id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(simulation.id) + assert data["status"] == "completed" + assert data["household_result"] is not None + + +def test_get_household_simulation_not_found(client, session): + """Get a non-existent household simulation returns 404.""" + response = client.get(f"/simulations/household/{uuid4()}") + assert response.status_code == 404 + + +def test_get_household_simulation_wrong_type(client, session): + """Get an economy simulation via the household endpoint returns 400.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + economy_sim = create_economy_simulation(session, version, dataset) + + response = client.get(f"/simulations/household/{economy_sim.id}") + assert response.status_code == 400 + assert "not a household simulation" in response.json()["detail"] + + +# =========================================================================== +# POST /simulations/economy +# =========================================================================== + + +def test_create_economy_simulation_with_region(client, session): + """Create an economy simulation using a region code.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + region = create_region(session, model, dataset, code="us", label="United States") + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "us", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["dataset_id"] == str(dataset.id) + assert data["region"]["code"] == "us" + assert data["region"]["label"] == "United States" + + +def test_create_economy_simulation_with_dataset(client, session): + """Create an economy simulation using a dataset_id directly.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["dataset_id"] == str(dataset.id) + assert data["region"] is None + + +def test_create_economy_simulation_with_region_filter(client, session): + """Create an economy simulation with a region that requires filtering.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + region = create_region( + session, + model, + dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "state/ca", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["filter_field"] == "state_code" + assert data["filter_value"] == "CA" + assert data["region"]["requires_filter"] is True + + +def test_create_economy_simulation_invalid_region(client, session): + """Creating with a non-existent region returns 404.""" + model, version = create_us_model_and_version(session) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "nonexistent/region", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_economy_simulation_no_region_or_dataset(client, session): + """Creating without region or dataset_id returns 400.""" + model, version = create_us_model_and_version(session) + + payload = {"tax_benefit_model_name": "policyengine_us"} + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 400 + assert "Either region or dataset_id" in response.json()["detail"] + + +def test_create_economy_simulation_policy_not_found(client, session): + """Creating with a non-existent policy returns 404.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + "policy_id": str(uuid4()), + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 404 + assert "Policy" in response.json()["detail"] + + +def test_economy_simulation_deduplication(client, session): + """Same inputs produce the same simulation (deterministic UUID).""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + } + response1 = client.post("/simulations/economy", json=payload) + response2 = client.post("/simulations/economy", json=payload) + + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response1.json()["id"] == response2.json()["id"] + + +# =========================================================================== +# GET /simulations/economy/{id} +# =========================================================================== + + +def test_get_economy_simulation(client, session): + """Get an economy simulation by ID.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + simulation = create_economy_simulation(session, version, dataset) + + response = client.get(f"/simulations/economy/{simulation.id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(simulation.id) + assert data["status"] == "completed" + + +def test_get_economy_simulation_not_found(client, session): + """Get a non-existent economy simulation returns 404.""" + response = client.get(f"/simulations/economy/{uuid4()}") + assert response.status_code == 404 + + +def test_get_economy_simulation_wrong_type(client, session): + """Get a household simulation via the economy endpoint returns 400.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + household_sim = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/economy/{household_sim.id}") + assert response.status_code == 400 + assert "not an economy simulation" in response.json()["detail"] + + +# =========================================================================== +# Generic GET /simulations/{id} still works +# =========================================================================== + + +def test_get_simulation_generic(client, session): + """The generic GET /simulations/{id} endpoint still works for any type.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + simulation = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/{simulation.id}") + + assert response.status_code == 200 + assert response.json()["id"] == str(simulation.id) diff --git a/tests/test_user_report_associations.py b/tests/test_user_report_associations.py new file mode 100644 index 0000000..30821f7 --- /dev/null +++ b/tests/test_user_report_associations.py @@ -0,0 +1,250 @@ +"""Tests for user-report association endpoints.""" + +from datetime import datetime, timezone +from uuid import uuid4 + +from test_fixtures.fixtures_user_report_associations import ( + create_report, + create_user_report_association, +) + +# --------------------------------------------------------------------------- +# POST /user-reports +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 200 with id and timestamps.""" + user_id = uuid4() + report = create_report(session) + payload = { + "user_id": str(user_id), + "report_id": str(report.id), + "country_id": "us", + "label": "My US report", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user_id) + assert data["report_id"] == str(report.id) + assert data["country_id"] == "us" + assert data["label"] == "My US report" + assert data["last_run_at"] is None + + +def test_create_association_with_last_run_at(client, session): + """Create an association with last_run_at set.""" + user_id = uuid4() + report = create_report(session) + now = datetime.now(timezone.utc).isoformat() + payload = { + "user_id": str(user_id), + "report_id": str(report.id), + "country_id": "us", + "last_run_at": now, + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 200 + assert response.json()["last_run_at"] is not None + + +def test_create_association_report_not_found(client): + """Creating with a non-existent report returns 404.""" + payload = { + "user_id": str(uuid4()), + "report_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_association_invalid_country(client, session): + """Creating with an invalid country_id returns 422.""" + report = create_report(session) + payload = { + "user_id": str(uuid4()), + "report_id": str(report.id), + "country_id": "invalid", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /user-reports/?user_id=... +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get("/user-reports/", params={"user_id": str(uuid4())}) + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user_id = uuid4() + r1 = create_report(session) + r2 = create_report(session) + create_user_report_association(session, user_id, r1, label="First") + create_user_report_association(session, user_id, r2, label="Second") + + response = client.get("/user-reports/", params={"user_id": str(user_id)}) + assert response.status_code == 200 + assert len(response.json()) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user_id = uuid4() + report = create_report(session) + create_user_report_association(session, user_id, report, country_id="us") + create_user_report_association(session, user_id, report, country_id="uk") + + response = client.get( + "/user-reports/", + params={"user_id": str(user_id), "country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-reports/{id} +# --------------------------------------------------------------------------- + + +def test_get_by_id(client, session): + """Get a specific association by ID.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association( + session, user_id, report, label="Test" + ) + + response = client.get(f"/user-reports/{assoc.id}") + assert response.status_code == 200 + assert response.json()["id"] == str(assoc.id) + assert response.json()["label"] == "Test" + + +def test_get_by_id_not_found(client): + """Get a non-existent association returns 404.""" + response = client.get(f"/user-reports/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# PATCH /user-reports/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_update_label(client, session): + """Update label via PATCH.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association( + session, user_id, report, label="Old" + ) + + response = client.patch( + f"/user-reports/{assoc.id}", + json={"label": "New label"}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["label"] == "New label" + + +def test_update_last_run_at(client, session): + """Update last_run_at via PATCH.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + now = datetime.now(timezone.utc).isoformat() + response = client.patch( + f"/user-reports/{assoc.id}", + json={"last_run_at": now}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["last_run_at"] is not None + + +def test_update_wrong_user(client, session): + """Update with wrong user_id returns 404.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association( + session, user_id, report, label="Mine" + ) + + response = client.patch( + f"/user-reports/{assoc.id}", + json={"label": "Stolen"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_update_not_found(client): + """Update a non-existent association returns 404.""" + response = client.patch( + f"/user-reports/{uuid4()}", + json={"label": "Something"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /user-reports/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + response = client.delete( + f"/user-reports/{assoc.id}", + params={"user_id": str(user_id)}, + ) + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-reports/{assoc.id}") + assert response.status_code == 404 + + +def test_delete_wrong_user(client, session): + """Delete with wrong user_id returns 404.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + response = client.delete( + f"/user-reports/{assoc.id}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_delete_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete( + f"/user-reports/{uuid4()}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 diff --git a/tests/test_user_simulation_associations.py b/tests/test_user_simulation_associations.py new file mode 100644 index 0000000..95f799e --- /dev/null +++ b/tests/test_user_simulation_associations.py @@ -0,0 +1,244 @@ +"""Tests for user-simulation association endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_simulation_associations import ( + create_simulation, + create_user_simulation_association, +) + +# --------------------------------------------------------------------------- +# POST /user-simulations +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 200 with id and timestamps.""" + user_id = uuid4() + simulation = create_simulation(session) + payload = { + "user_id": str(user_id), + "simulation_id": str(simulation.id), + "country_id": "us", + "label": "My US simulation", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user_id) + assert data["simulation_id"] == str(simulation.id) + assert data["country_id"] == "us" + assert data["label"] == "My US simulation" + + +def test_create_association_allows_duplicates(client, session): + """Multiple associations to the same simulation are allowed.""" + user_id = uuid4() + simulation = create_simulation(session) + payload = { + "user_id": str(user_id), + "simulation_id": str(simulation.id), + "country_id": "us", + "label": "First label", + } + r1 = client.post("/user-simulations/", json=payload) + assert r1.status_code == 200 + + payload["label"] = "Second label" + r2 = client.post("/user-simulations/", json=payload) + assert r2.status_code == 200 + assert r1.json()["id"] != r2.json()["id"] + + +def test_create_association_simulation_not_found(client): + """Creating with a non-existent simulation returns 404.""" + payload = { + "user_id": str(uuid4()), + "simulation_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_association_invalid_country(client, session): + """Creating with an invalid country_id returns 422.""" + simulation = create_simulation(session) + payload = { + "user_id": str(uuid4()), + "simulation_id": str(simulation.id), + "country_id": "invalid", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /user-simulations/?user_id=... +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get( + "/user-simulations/", params={"user_id": str(uuid4())} + ) + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user_id = uuid4() + sim1 = create_simulation(session) + sim2 = create_simulation(session) + create_user_simulation_association(session, user_id, sim1, label="First") + create_user_simulation_association(session, user_id, sim2, label="Second") + + response = client.get( + "/user-simulations/", params={"user_id": str(user_id)} + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user_id = uuid4() + simulation = create_simulation(session) + create_user_simulation_association( + session, user_id, simulation, country_id="us" + ) + create_user_simulation_association( + session, user_id, simulation, country_id="uk" + ) + + response = client.get( + "/user-simulations/", + params={"user_id": str(user_id), "country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-simulations/{id} +# --------------------------------------------------------------------------- + + +def test_get_by_id(client, session): + """Get a specific association by ID.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Test" + ) + + response = client.get(f"/user-simulations/{assoc.id}") + assert response.status_code == 200 + assert response.json()["id"] == str(assoc.id) + assert response.json()["label"] == "Test" + + +def test_get_by_id_not_found(client): + """Get a non-existent association returns 404.""" + response = client.get(f"/user-simulations/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# PATCH /user-simulations/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_update_label(client, session): + """Update label via PATCH.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Old" + ) + + response = client.patch( + f"/user-simulations/{assoc.id}", + json={"label": "New label"}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["label"] == "New label" + + +def test_update_wrong_user(client, session): + """Update with wrong user_id returns 404.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Mine" + ) + + response = client.patch( + f"/user-simulations/{assoc.id}", + json={"label": "Stolen"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_update_not_found(client): + """Update a non-existent association returns 404.""" + response = client.patch( + f"/user-simulations/{uuid4()}", + json={"label": "Something"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /user-simulations/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association(session, user_id, simulation) + + response = client.delete( + f"/user-simulations/{assoc.id}", + params={"user_id": str(user_id)}, + ) + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-simulations/{assoc.id}") + assert response.status_code == 404 + + +def test_delete_wrong_user(client, session): + """Delete with wrong user_id returns 404.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association(session, user_id, simulation) + + response = client.delete( + f"/user-simulations/{assoc.id}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_delete_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete( + f"/user-simulations/{uuid4()}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404