From 46576d744da0a7661d7130c3c43058f7c9401381 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 18 Feb 2026 23:40:07 +0100 Subject: [PATCH 1/8] fix: Merge divergent Alembic migration branches Two migrations (0002_user_policies and f419b5f4acba) both descended from the initial schema, creating two Alembic heads. This empty merge migration reconciles them into a single linear chain so future migrations can be added. Co-Authored-By: Claude Opus 4.6 --- ...rge_user_policies_and_household_support.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 alembic/versions/20260218_merge_user_policies_and_household_support.py diff --git a/alembic/versions/20260218_merge_user_policies_and_household_support.py b/alembic/versions/20260218_merge_user_policies_and_household_support.py new file mode 100644 index 0000000..d880c84 --- /dev/null +++ b/alembic/versions/20260218_merge_user_policies_and_household_support.py @@ -0,0 +1,32 @@ +"""merge user_policies and household_support branches + +Revision ID: merge_001 +Revises: 0002_user_policies, a1b2c3d4e5f6 +Create Date: 2026-02-18 + +Merge the two migration branches that diverged from the initial schema: +- 0002_user_policies: added user_policies table + policy.tax_benefit_model_id +- f419b5f4acba → a1b2c3d4e5f6: added household support + regions table + +No schema changes — both branches modify independent tables. +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "merge_001" +down_revision: tuple[str, str] = ("0002_user_policies", "a1b2c3d4e5f6") +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """No schema changes — merge only.""" + pass + + +def downgrade() -> None: + """No schema changes — merge only.""" + pass From 8c0c5bcb17c0faf949a89167c77cac8a66ec3e25 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 18 Feb 2026 23:47:10 +0100 Subject: [PATCH 2/8] feat: Add filter_field/filter_value columns to simulations table The Simulation model already uses these fields for regional economy simulations (e.g., filtering a dataset to a specific state), but no Alembic migration created the columns. This aligns the DB schema with the existing Python model. Co-Authored-By: Claude Opus 4.6 --- .../20260218_add_simulation_filter_columns.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 alembic/versions/20260218_add_simulation_filter_columns.py diff --git a/alembic/versions/20260218_add_simulation_filter_columns.py b/alembic/versions/20260218_add_simulation_filter_columns.py new file mode 100644 index 0000000..ecc5b4c --- /dev/null +++ b/alembic/versions/20260218_add_simulation_filter_columns.py @@ -0,0 +1,41 @@ +"""add filter_field and filter_value to simulations + +Revision ID: add_sim_filters +Revises: merge_001 +Create Date: 2026-02-18 + +The Simulation model already has filter_field and filter_value fields +(used for regional economy simulations), but no migration added them +to the database. This brings the schema in line with the model. +""" + +from typing import Sequence, Union + +import sqlmodel.sql.sqltypes + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "add_sim_filters" +down_revision: Union[str, Sequence[str], None] = "merge_001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add filter_field and filter_value columns to simulations table.""" + op.add_column( + "simulations", + sa.Column("filter_field", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + op.add_column( + "simulations", + sa.Column("filter_value", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + + +def downgrade() -> None: + """Remove filter_field and filter_value columns from simulations table.""" + op.drop_column("simulations", "filter_value") + op.drop_column("simulations", "filter_field") From ac82debaa69aab73f2d9f796a624d9a86c9f8b2d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Feb 2026 00:33:17 +0100 Subject: [PATCH 3/8] feat: Remove unused parent_report_id from Report model The self-referential parent_report_id FK on reports was never read or written by any code. Removing it keeps the schema clean before we build further on the Report model in later migration tasks. Co-Authored-By: Claude Opus 4.6 --- .../20260218_drop_parent_report_id.py | 43 +++++++++++++++++++ src/policyengine_api/models/report.py | 1 - 2 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/20260218_drop_parent_report_id.py diff --git a/alembic/versions/20260218_drop_parent_report_id.py b/alembic/versions/20260218_drop_parent_report_id.py new file mode 100644 index 0000000..54c548c --- /dev/null +++ b/alembic/versions/20260218_drop_parent_report_id.py @@ -0,0 +1,43 @@ +"""drop parent_report_id from reports + +Revision ID: drop_parent_report +Revises: add_sim_filters +Create Date: 2026-02-18 + +Remove the unused self-referential parent_report_id foreign key from +the reports table. No code reads or writes this column. +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "drop_parent_report" +down_revision: Union[str, Sequence[str], None] = "add_sim_filters" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Drop parent_report_id column and its FK constraint.""" + op.drop_constraint( + "reports_parent_report_id_fkey", "reports", type_="foreignkey" + ) + op.drop_column("reports", "parent_report_id") + + +def downgrade() -> None: + """Re-add parent_report_id column and FK constraint.""" + op.add_column( + "reports", + sa.Column("parent_report_id", sa.Uuid(), nullable=True), + ) + op.create_foreign_key( + "reports_parent_report_id_fkey", + "reports", + "reports", + ["parent_report_id"], + ["id"], + ) diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index bc2cd40..ffc33b8 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -22,7 +22,6 @@ class ReportBase(SQLModel): report_type: str | None = None user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) - parent_report_id: UUID | None = Field(default=None, foreign_key="reports.id") status: ReportStatus = ReportStatus.PENDING error_message: str | None = None baseline_simulation_id: UUID | None = Field( From 6c4af281f64d2667997437e5123838b0cbdd5e67 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Feb 2026 00:55:54 +0100 Subject: [PATCH 4/8] feat: Add user_simulation_associations table and CRUD endpoints Enables users to save simulations to their list across sessions, following the same pattern as user_policies and user_household_associations. Includes model, autogenerated Alembic migration, API endpoints with ownership verification on update/delete, and 15 tests. Co-Authored-By: Claude Opus 4.6 --- ..._add_user_simulation_associations_table.py | 47 ++++ src/policyengine_api/api/__init__.py | 2 + .../api/user_simulation_associations.py | 146 +++++++++++ src/policyengine_api/models/__init__.py | 10 + .../models/user_simulation_association.py | 60 +++++ .../fixtures_user_simulation_associations.py | 79 ++++++ tests/test_user_simulation_associations.py | 244 ++++++++++++++++++ 7 files changed, 588 insertions(+) create mode 100644 alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py create mode 100644 src/policyengine_api/api/user_simulation_associations.py create mode 100644 src/policyengine_api/models/user_simulation_association.py create mode 100644 test_fixtures/fixtures_user_simulation_associations.py create mode 100644 tests/test_user_simulation_associations.py diff --git a/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py b/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py new file mode 100644 index 0000000..c31a778 --- /dev/null +++ b/alembic/versions/20260219_621977f3b1aa_add_user_simulation_associations_table.py @@ -0,0 +1,47 @@ +"""add user_simulation_associations table + +Revision ID: 621977f3b1aa +Revises: drop_parent_report +Create Date: 2026-02-19 00:37:43.378088 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '621977f3b1aa' +down_revision: Union[str, Sequence[str], None] = 'drop_parent_report' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('user_simulation_associations', + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_simulation_associations_simulation_id'), 'user_simulation_associations', ['simulation_id'], unique=False) + op.create_index(op.f('ix_user_simulation_associations_user_id'), 'user_simulation_associations', ['user_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_user_simulation_associations_user_id'), table_name='user_simulation_associations') + op.drop_index(op.f('ix_user_simulation_associations_simulation_id'), table_name='user_simulation_associations') + op.drop_table('user_simulation_associations') + # ### end Alembic commands ### diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index 7e94d29..a216e81 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -21,6 +21,7 @@ tax_benefit_models, user_household_associations, user_policies, + user_simulation_associations, variables, ) @@ -45,5 +46,6 @@ api_router.include_router(agent.router) api_router.include_router(user_household_associations.router) api_router.include_router(user_policies.router) +api_router.include_router(user_simulation_associations.router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/user_simulation_associations.py b/src/policyengine_api/api/user_simulation_associations.py new file mode 100644 index 0000000..2341d91 --- /dev/null +++ b/src/policyengine_api/api/user_simulation_associations.py @@ -0,0 +1,146 @@ +"""User-simulation association endpoints. + +Associates users with simulations they've run. This enables users to +maintain a list of their simulations across sessions without duplicating +the underlying simulation data. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save simulations without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.config.constants import CountryId +from policyengine_api.models import ( + Simulation, + UserSimulationAssociation, + UserSimulationAssociationCreate, + UserSimulationAssociationRead, + UserSimulationAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/user-simulations", tags=["user-simulations"]) + + +@router.post("/", response_model=UserSimulationAssociationRead) +def create_user_simulation( + body: UserSimulationAssociationCreate, + session: Session = Depends(get_session), +): + """Create a new user-simulation association. + + Associates a user with a simulation, allowing them to save it to their list. + Duplicates are allowed - users can save the same simulation multiple times + with different labels. + """ + simulation = session.get(Simulation, body.simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + + record = UserSimulationAssociation.model_validate(body) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/", response_model=list[UserSimulationAssociationRead]) +def list_user_simulations( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all simulation associations for a user. + + Returns all simulations saved by the specified user. Optionally filter by country. + """ + query = select(UserSimulationAssociation).where( + UserSimulationAssociation.user_id == user_id + ) + + if country_id: + query = query.where(UserSimulationAssociation.country_id == country_id) + + return session.exec(query).all() + + +@router.get("/{user_simulation_id}", response_model=UserSimulationAssociationRead) +def get_user_simulation( + user_simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-simulation association by ID.""" + record = session.get(UserSimulationAssociation, user_simulation_id) + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + return record + + +@router.patch("/{user_simulation_id}", response_model=UserSimulationAssociationRead) +def update_user_simulation( + user_simulation_id: UUID, + updates: UserSimulationAssociationUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-simulation association (e.g., rename label). + + Requires user_id to verify ownership - only the owner can update. + """ + record = session.exec( + select(UserSimulationAssociation).where( + UserSimulationAssociation.id == user_simulation_id, + UserSimulationAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + + record.updated_at = datetime.now(timezone.utc) + + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{user_simulation_id}", status_code=204) +def delete_user_simulation( + user_simulation_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-simulation association. + + This only removes the association, not the underlying simulation. + Requires user_id to verify ownership - only the owner can delete. + """ + record = session.exec( + select(UserSimulationAssociation).where( + UserSimulationAssociation.id == user_simulation_id, + UserSimulationAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-simulation association not found" + ) + + session.delete(record) + session.commit() diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index e7b386a..1809b49 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -61,6 +61,12 @@ UserHouseholdAssociationRead, UserHouseholdAssociationUpdate, ) +from .user_simulation_association import ( + UserSimulationAssociation, + UserSimulationAssociationCreate, + UserSimulationAssociationRead, + UserSimulationAssociationUpdate, +) from .user_policy import ( UserPolicy, UserPolicyCreate, @@ -142,6 +148,10 @@ "UserHouseholdAssociationRead", "UserHouseholdAssociationUpdate", "UserRead", + "UserSimulationAssociation", + "UserSimulationAssociationCreate", + "UserSimulationAssociationRead", + "UserSimulationAssociationUpdate", "UserPolicy", "UserPolicyCreate", "UserPolicyRead", diff --git a/src/policyengine_api/models/user_simulation_association.py b/src/policyengine_api/models/user_simulation_association.py new file mode 100644 index 0000000..9b07d19 --- /dev/null +++ b/src/policyengine_api/models/user_simulation_association.py @@ -0,0 +1,60 @@ +"""User-simulation association model. + +Associates users with simulations they've run. This enables users to +maintain a list of their simulations across sessions. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save simulations without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine_api.config.constants import CountryId + + +class UserSimulationAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(index=True) + simulation_id: UUID = Field(foreign_key="simulations.id", index=True) + country_id: str + label: str | None = None + + +class UserSimulationAssociation(UserSimulationAssociationBase, table=True): + """User-simulation association database model.""" + + __tablename__ = "user_simulation_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserSimulationAssociationCreate(SQLModel): + """Schema for creating user-simulation associations.""" + + user_id: UUID + simulation_id: UUID + country_id: CountryId + label: str | None = None + + +class UserSimulationAssociationRead(UserSimulationAssociationBase): + """Schema for reading user-simulation associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserSimulationAssociationUpdate(SQLModel): + """Schema for updating user-simulation associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None diff --git a/test_fixtures/fixtures_user_simulation_associations.py b/test_fixtures/fixtures_user_simulation_associations.py new file mode 100644 index 0000000..c2cbd74 --- /dev/null +++ b/test_fixtures/fixtures_user_simulation_associations.py @@ -0,0 +1,79 @@ +"""Fixtures and helpers for user-simulation association tests.""" + +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, + UserSimulationAssociation, +) + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_simulation(session, model: TaxBenefitModel | None = None) -> Simulation: + """Create and persist a Simulation with required dependencies.""" + if model is None: + model = create_tax_benefit_model(session) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + simulation = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +def create_user_simulation_association( + session, + user_id: UUID, + simulation: Simulation, + country_id: str = "us", + label: str | None = None, +) -> UserSimulationAssociation: + """Create and persist a UserSimulationAssociation record.""" + record = UserSimulationAssociation( + user_id=user_id, + simulation_id=simulation.id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/test_user_simulation_associations.py b/tests/test_user_simulation_associations.py new file mode 100644 index 0000000..95f799e --- /dev/null +++ b/tests/test_user_simulation_associations.py @@ -0,0 +1,244 @@ +"""Tests for user-simulation association endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_simulation_associations import ( + create_simulation, + create_user_simulation_association, +) + +# --------------------------------------------------------------------------- +# POST /user-simulations +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 200 with id and timestamps.""" + user_id = uuid4() + simulation = create_simulation(session) + payload = { + "user_id": str(user_id), + "simulation_id": str(simulation.id), + "country_id": "us", + "label": "My US simulation", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user_id) + assert data["simulation_id"] == str(simulation.id) + assert data["country_id"] == "us" + assert data["label"] == "My US simulation" + + +def test_create_association_allows_duplicates(client, session): + """Multiple associations to the same simulation are allowed.""" + user_id = uuid4() + simulation = create_simulation(session) + payload = { + "user_id": str(user_id), + "simulation_id": str(simulation.id), + "country_id": "us", + "label": "First label", + } + r1 = client.post("/user-simulations/", json=payload) + assert r1.status_code == 200 + + payload["label"] = "Second label" + r2 = client.post("/user-simulations/", json=payload) + assert r2.status_code == 200 + assert r1.json()["id"] != r2.json()["id"] + + +def test_create_association_simulation_not_found(client): + """Creating with a non-existent simulation returns 404.""" + payload = { + "user_id": str(uuid4()), + "simulation_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_association_invalid_country(client, session): + """Creating with an invalid country_id returns 422.""" + simulation = create_simulation(session) + payload = { + "user_id": str(uuid4()), + "simulation_id": str(simulation.id), + "country_id": "invalid", + } + response = client.post("/user-simulations/", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /user-simulations/?user_id=... +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get( + "/user-simulations/", params={"user_id": str(uuid4())} + ) + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user_id = uuid4() + sim1 = create_simulation(session) + sim2 = create_simulation(session) + create_user_simulation_association(session, user_id, sim1, label="First") + create_user_simulation_association(session, user_id, sim2, label="Second") + + response = client.get( + "/user-simulations/", params={"user_id": str(user_id)} + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user_id = uuid4() + simulation = create_simulation(session) + create_user_simulation_association( + session, user_id, simulation, country_id="us" + ) + create_user_simulation_association( + session, user_id, simulation, country_id="uk" + ) + + response = client.get( + "/user-simulations/", + params={"user_id": str(user_id), "country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-simulations/{id} +# --------------------------------------------------------------------------- + + +def test_get_by_id(client, session): + """Get a specific association by ID.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Test" + ) + + response = client.get(f"/user-simulations/{assoc.id}") + assert response.status_code == 200 + assert response.json()["id"] == str(assoc.id) + assert response.json()["label"] == "Test" + + +def test_get_by_id_not_found(client): + """Get a non-existent association returns 404.""" + response = client.get(f"/user-simulations/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# PATCH /user-simulations/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_update_label(client, session): + """Update label via PATCH.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Old" + ) + + response = client.patch( + f"/user-simulations/{assoc.id}", + json={"label": "New label"}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["label"] == "New label" + + +def test_update_wrong_user(client, session): + """Update with wrong user_id returns 404.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association( + session, user_id, simulation, label="Mine" + ) + + response = client.patch( + f"/user-simulations/{assoc.id}", + json={"label": "Stolen"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_update_not_found(client): + """Update a non-existent association returns 404.""" + response = client.patch( + f"/user-simulations/{uuid4()}", + json={"label": "Something"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /user-simulations/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association(session, user_id, simulation) + + response = client.delete( + f"/user-simulations/{assoc.id}", + params={"user_id": str(user_id)}, + ) + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-simulations/{assoc.id}") + assert response.status_code == 404 + + +def test_delete_wrong_user(client, session): + """Delete with wrong user_id returns 404.""" + user_id = uuid4() + simulation = create_simulation(session) + assoc = create_user_simulation_association(session, user_id, simulation) + + response = client.delete( + f"/user-simulations/{assoc.id}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_delete_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete( + f"/user-simulations/{uuid4()}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 From 432f502a7092ebd13aabb3024cfa1bef19701ee3 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Feb 2026 17:41:07 +0100 Subject: [PATCH 5/8] feat: Add user_report_associations table and CRUD endpoints Enables users to save reports to their list across sessions, following the same pattern as user_simulations and user_policies. Includes a last_run_at field for tracking when a report was last calculated. Model, autogenerated Alembic migration, API endpoints with ownership verification, and 16 tests. Co-Authored-By: Claude Opus 4.6 --- ...74dd_add_user_report_associations_table.py | 48 ++++ src/policyengine_api/api/__init__.py | 2 + .../api/user_report_associations.py | 146 ++++++++++ src/policyengine_api/models/__init__.py | 10 + .../models/user_report_association.py | 63 +++++ .../fixtures_user_report_associations.py | 101 +++++++ tests/test_user_report_associations.py | 250 ++++++++++++++++++ 7 files changed, 620 insertions(+) create mode 100644 alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py create mode 100644 src/policyengine_api/api/user_report_associations.py create mode 100644 src/policyengine_api/models/user_report_association.py create mode 100644 test_fixtures/fixtures_user_report_associations.py create mode 100644 tests/test_user_report_associations.py diff --git a/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py b/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py new file mode 100644 index 0000000..b9edb4f --- /dev/null +++ b/alembic/versions/20260219_9daa015274dd_add_user_report_associations_table.py @@ -0,0 +1,48 @@ +"""add user_report_associations table + +Revision ID: 9daa015274dd +Revises: 621977f3b1aa +Create Date: 2026-02-19 16:58:03.157551 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '9daa015274dd' +down_revision: Union[str, Sequence[str], None] = '621977f3b1aa' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('user_report_associations', + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=False), + sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('last_run_at', sa.DateTime(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_report_associations_report_id'), 'user_report_associations', ['report_id'], unique=False) + op.create_index(op.f('ix_user_report_associations_user_id'), 'user_report_associations', ['user_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_user_report_associations_user_id'), table_name='user_report_associations') + op.drop_index(op.f('ix_user_report_associations_report_id'), table_name='user_report_associations') + op.drop_table('user_report_associations') + # ### end Alembic commands ### diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index a216e81..dd8deac 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -21,6 +21,7 @@ tax_benefit_models, user_household_associations, user_policies, + user_report_associations, user_simulation_associations, variables, ) @@ -47,5 +48,6 @@ api_router.include_router(user_household_associations.router) api_router.include_router(user_policies.router) api_router.include_router(user_simulation_associations.router) +api_router.include_router(user_report_associations.router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/user_report_associations.py b/src/policyengine_api/api/user_report_associations.py new file mode 100644 index 0000000..ec1bd6b --- /dev/null +++ b/src/policyengine_api/api/user_report_associations.py @@ -0,0 +1,146 @@ +"""User-report association endpoints. + +Associates users with reports they've created. This enables users to +maintain a list of their reports across sessions without duplicating +the underlying report data. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save reports without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.config.constants import CountryId +from policyengine_api.models import ( + Report, + UserReportAssociation, + UserReportAssociationCreate, + UserReportAssociationRead, + UserReportAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/user-reports", tags=["user-reports"]) + + +@router.post("/", response_model=UserReportAssociationRead) +def create_user_report( + body: UserReportAssociationCreate, + session: Session = Depends(get_session), +): + """Create a new user-report association. + + Associates a user with a report, allowing them to save it to their list. + Duplicates are allowed - users can save the same report multiple times + with different labels. + """ + report = session.get(Report, body.report_id) + if not report: + raise HTTPException(status_code=404, detail="Report not found") + + record = UserReportAssociation.model_validate(body) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/", response_model=list[UserReportAssociationRead]) +def list_user_reports( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all report associations for a user. + + Returns all reports saved by the specified user. Optionally filter by country. + """ + query = select(UserReportAssociation).where( + UserReportAssociation.user_id == user_id + ) + + if country_id: + query = query.where(UserReportAssociation.country_id == country_id) + + return session.exec(query).all() + + +@router.get("/{user_report_id}", response_model=UserReportAssociationRead) +def get_user_report( + user_report_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-report association by ID.""" + record = session.get(UserReportAssociation, user_report_id) + if not record: + raise HTTPException( + status_code=404, detail="User-report association not found" + ) + return record + + +@router.patch("/{user_report_id}", response_model=UserReportAssociationRead) +def update_user_report( + user_report_id: UUID, + updates: UserReportAssociationUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-report association (e.g., rename label or update last_run_at). + + Requires user_id to verify ownership - only the owner can update. + """ + record = session.exec( + select(UserReportAssociation).where( + UserReportAssociation.id == user_report_id, + UserReportAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-report association not found" + ) + + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + + record.updated_at = datetime.now(timezone.utc) + + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{user_report_id}", status_code=204) +def delete_user_report( + user_report_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-report association. + + This only removes the association, not the underlying report. + Requires user_id to verify ownership - only the owner can delete. + """ + record = session.exec( + select(UserReportAssociation).where( + UserReportAssociation.id == user_report_id, + UserReportAssociation.user_id == user_id, + ) + ).first() + if not record: + raise HTTPException( + status_code=404, detail="User-report association not found" + ) + + session.delete(record) + session.commit() diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 1809b49..e73b75a 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -67,6 +67,12 @@ UserSimulationAssociationRead, UserSimulationAssociationUpdate, ) +from .user_report_association import ( + UserReportAssociation, + UserReportAssociationCreate, + UserReportAssociationRead, + UserReportAssociationUpdate, +) from .user_policy import ( UserPolicy, UserPolicyCreate, @@ -152,6 +158,10 @@ "UserSimulationAssociationCreate", "UserSimulationAssociationRead", "UserSimulationAssociationUpdate", + "UserReportAssociation", + "UserReportAssociationCreate", + "UserReportAssociationRead", + "UserReportAssociationUpdate", "UserPolicy", "UserPolicyCreate", "UserPolicyRead", diff --git a/src/policyengine_api/models/user_report_association.py b/src/policyengine_api/models/user_report_association.py new file mode 100644 index 0000000..4f078cb --- /dev/null +++ b/src/policyengine_api/models/user_report_association.py @@ -0,0 +1,63 @@ +"""User-report association model. + +Associates users with reports they've created. This enables users to +maintain a list of their reports across sessions. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save reports without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine_api.config.constants import CountryId + + +class UserReportAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(index=True) + report_id: UUID = Field(foreign_key="reports.id", index=True) + country_id: str + label: str | None = None + last_run_at: datetime | None = None + + +class UserReportAssociation(UserReportAssociationBase, table=True): + """User-report association database model.""" + + __tablename__ = "user_report_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserReportAssociationCreate(SQLModel): + """Schema for creating user-report associations.""" + + user_id: UUID + report_id: UUID + country_id: CountryId + label: str | None = None + last_run_at: datetime | None = None + + +class UserReportAssociationRead(UserReportAssociationBase): + """Schema for reading user-report associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserReportAssociationUpdate(SQLModel): + """Schema for updating user-report associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None + last_run_at: datetime | None = None diff --git a/test_fixtures/fixtures_user_report_associations.py b/test_fixtures/fixtures_user_report_associations.py new file mode 100644 index 0000000..4ef07df --- /dev/null +++ b/test_fixtures/fixtures_user_report_associations.py @@ -0,0 +1,101 @@ +"""Fixtures and helpers for user-report association tests.""" + +from datetime import datetime +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Report, + ReportStatus, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, + UserReportAssociation, +) + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_report(session, model: TaxBenefitModel | None = None) -> Report: + """Create and persist a Report with required simulation dependencies.""" + if model is None: + model = create_tax_benefit_model(session) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + baseline = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + reform = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(baseline) + session.add(reform) + session.commit() + session.refresh(baseline) + session.refresh(reform) + + report = Report( + label="Test report", + status=ReportStatus.COMPLETED, + baseline_simulation_id=baseline.id, + reform_simulation_id=reform.id, + ) + session.add(report) + session.commit() + session.refresh(report) + return report + + +def create_user_report_association( + session, + user_id: UUID, + report: Report, + country_id: str = "us", + label: str | None = None, + last_run_at: datetime | None = None, +) -> UserReportAssociation: + """Create and persist a UserReportAssociation record.""" + record = UserReportAssociation( + user_id=user_id, + report_id=report.id, + country_id=country_id, + label=label, + last_run_at=last_run_at, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/test_user_report_associations.py b/tests/test_user_report_associations.py new file mode 100644 index 0000000..30821f7 --- /dev/null +++ b/tests/test_user_report_associations.py @@ -0,0 +1,250 @@ +"""Tests for user-report association endpoints.""" + +from datetime import datetime, timezone +from uuid import uuid4 + +from test_fixtures.fixtures_user_report_associations import ( + create_report, + create_user_report_association, +) + +# --------------------------------------------------------------------------- +# POST /user-reports +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 200 with id and timestamps.""" + user_id = uuid4() + report = create_report(session) + payload = { + "user_id": str(user_id), + "report_id": str(report.id), + "country_id": "us", + "label": "My US report", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user_id) + assert data["report_id"] == str(report.id) + assert data["country_id"] == "us" + assert data["label"] == "My US report" + assert data["last_run_at"] is None + + +def test_create_association_with_last_run_at(client, session): + """Create an association with last_run_at set.""" + user_id = uuid4() + report = create_report(session) + now = datetime.now(timezone.utc).isoformat() + payload = { + "user_id": str(user_id), + "report_id": str(report.id), + "country_id": "us", + "last_run_at": now, + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 200 + assert response.json()["last_run_at"] is not None + + +def test_create_association_report_not_found(client): + """Creating with a non-existent report returns 404.""" + payload = { + "user_id": str(uuid4()), + "report_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_association_invalid_country(client, session): + """Creating with an invalid country_id returns 422.""" + report = create_report(session) + payload = { + "user_id": str(uuid4()), + "report_id": str(report.id), + "country_id": "invalid", + } + response = client.post("/user-reports/", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /user-reports/?user_id=... +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get("/user-reports/", params={"user_id": str(uuid4())}) + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user_id = uuid4() + r1 = create_report(session) + r2 = create_report(session) + create_user_report_association(session, user_id, r1, label="First") + create_user_report_association(session, user_id, r2, label="Second") + + response = client.get("/user-reports/", params={"user_id": str(user_id)}) + assert response.status_code == 200 + assert len(response.json()) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user_id = uuid4() + report = create_report(session) + create_user_report_association(session, user_id, report, country_id="us") + create_user_report_association(session, user_id, report, country_id="uk") + + response = client.get( + "/user-reports/", + params={"user_id": str(user_id), "country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-reports/{id} +# --------------------------------------------------------------------------- + + +def test_get_by_id(client, session): + """Get a specific association by ID.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association( + session, user_id, report, label="Test" + ) + + response = client.get(f"/user-reports/{assoc.id}") + assert response.status_code == 200 + assert response.json()["id"] == str(assoc.id) + assert response.json()["label"] == "Test" + + +def test_get_by_id_not_found(client): + """Get a non-existent association returns 404.""" + response = client.get(f"/user-reports/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# PATCH /user-reports/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_update_label(client, session): + """Update label via PATCH.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association( + session, user_id, report, label="Old" + ) + + response = client.patch( + f"/user-reports/{assoc.id}", + json={"label": "New label"}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["label"] == "New label" + + +def test_update_last_run_at(client, session): + """Update last_run_at via PATCH.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + now = datetime.now(timezone.utc).isoformat() + response = client.patch( + f"/user-reports/{assoc.id}", + json={"last_run_at": now}, + params={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + assert response.json()["last_run_at"] is not None + + +def test_update_wrong_user(client, session): + """Update with wrong user_id returns 404.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association( + session, user_id, report, label="Mine" + ) + + response = client.patch( + f"/user-reports/{assoc.id}", + json={"label": "Stolen"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_update_not_found(client): + """Update a non-existent association returns 404.""" + response = client.patch( + f"/user-reports/{uuid4()}", + json={"label": "Something"}, + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /user-reports/{id}?user_id=... +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + response = client.delete( + f"/user-reports/{assoc.id}", + params={"user_id": str(user_id)}, + ) + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-reports/{assoc.id}") + assert response.status_code == 404 + + +def test_delete_wrong_user(client, session): + """Delete with wrong user_id returns 404.""" + user_id = uuid4() + report = create_report(session) + assoc = create_user_report_association(session, user_id, report) + + response = client.delete( + f"/user-reports/{assoc.id}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 + + +def test_delete_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete( + f"/user-reports/{uuid4()}", + params={"user_id": str(uuid4())}, + ) + assert response.status_code == 404 From a997f847bf3cd197e528fab52765f6fe84a37045 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Feb 2026 18:10:46 +0100 Subject: [PATCH 6/8] feat: Add standalone simulation endpoints (/simulations/household, /simulations/economy) Non-blocking POST endpoints that create simulation records and return immediately with pending status. GET endpoints for polling status and retrieving results. Region validation for economy sims, deterministic UUID deduplication, and 19 tests. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/simulations.py | 350 +++++++++++++++++- .../fixtures_simulations_standalone.py | 164 ++++++++ tests/test_simulations_standalone.py | 312 ++++++++++++++++ 3 files changed, 814 insertions(+), 12 deletions(-) create mode 100644 test_fixtures/fixtures_simulations_standalone.py create mode 100644 tests/test_simulations_standalone.py diff --git a/src/policyengine_api/api/simulations.py b/src/policyengine_api/api/simulations.py index 633c57c..bf57c32 100644 --- a/src/policyengine_api/api/simulations.py +++ b/src/policyengine_api/api/simulations.py @@ -1,36 +1,362 @@ -"""Simulation status endpoints. +"""Simulation endpoints. -Simulations are economy-wide tax-benefit calculations running on population datasets. -They are created automatically when you call /analysis/economic-impact. Use these -endpoints to check simulation status (pending, running, completed, failed). +Simulations are individual tax-benefit calculations. Use these endpoints to: +- Create and run household simulations (single household, single policy) +- Create and run economy simulations (population dataset, single policy) +- Check simulation status and retrieve results + +For baseline-vs-reform comparisons, use the /analysis/ endpoints instead. """ -from typing import List +from typing import Any, List, Literal from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field from sqlmodel import Session, select -from policyengine_api.models import Simulation, SimulationRead +from policyengine_api.models import ( + Dataset, + Household, + Policy, + Region, + Simulation, + SimulationRead, + SimulationStatus, + SimulationType, + TaxBenefitModel, +) from policyengine_api.services.database import get_session +from .analysis import ( + RegionInfo, + _get_model_version, + _get_or_create_simulation, +) + router = APIRouter(prefix="/simulations", tags=["simulations"]) +# --------------------------------------------------------------------------- +# Request / Response schemas +# --------------------------------------------------------------------------- + + +class HouseholdSimulationRequest(BaseModel): + """Request body for creating a household simulation.""" + + household_id: UUID = Field(description="ID of the stored household") + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. If None, runs under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, + description="Optional behavioural response specification ID", + ) + + +class HouseholdSimulationResponse(BaseModel): + """Response for a household simulation.""" + + id: UUID + status: SimulationStatus + household_id: UUID | None = None + policy_id: UUID | None = None + household_result: dict[str, Any] | None = None + error_message: str | None = None + + +class EconomySimulationRequest(BaseModel): + """Request body for creating an economy simulation.""" + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + region: str | None = Field( + default=None, + description="Region code (e.g., 'state/ca', 'us'). Either region or dataset_id must be provided.", + ) + dataset_id: UUID | None = Field( + default=None, + description="Dataset ID. Either region or dataset_id must be provided.", + ) + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. If None, runs under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, + description="Optional behavioural response specification ID", + ) + + +class EconomySimulationResponse(BaseModel): + """Response for an economy simulation.""" + + id: UUID + status: SimulationStatus + dataset_id: UUID | None = None + policy_id: UUID | None = None + output_dataset_id: UUID | None = None + filter_field: str | None = None + filter_value: str | None = None + region: RegionInfo | None = None + error_message: str | None = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _resolve_economy_dataset( + tax_benefit_model_name: str, + region_code: str | None, + dataset_id: UUID | None, + session: Session, +) -> tuple[Dataset, Region | None]: + """Resolve dataset from region code or dataset_id for economy simulations.""" + if region_code: + model_name = tax_benefit_model_name.replace("_", "-") + region = session.exec( + select(Region) + .join(TaxBenefitModel) + .where(Region.code == region_code) + .where(TaxBenefitModel.name == model_name) + ).first() + if not region: + raise HTTPException( + status_code=404, + detail=f"Region '{region_code}' not found for model {model_name}", + ) + dataset = session.get(Dataset, region.dataset_id) + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset for region '{region_code}' not found", + ) + return dataset, region + + elif dataset_id: + dataset = session.get(Dataset, dataset_id) + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset {dataset_id} not found", + ) + return dataset, None + + else: + raise HTTPException( + status_code=400, + detail="Either region or dataset_id must be provided", + ) + + +def _build_household_response(simulation: Simulation) -> HouseholdSimulationResponse: + """Build response from a household simulation.""" + return HouseholdSimulationResponse( + id=simulation.id, + status=simulation.status, + household_id=simulation.household_id, + policy_id=simulation.policy_id, + household_result=simulation.household_result, + error_message=simulation.error_message, + ) + + +def _build_economy_response( + simulation: Simulation, region: Region | None = None +) -> EconomySimulationResponse: + """Build response from an economy simulation.""" + region_info = None + if region: + region_info = RegionInfo( + code=region.code, + label=region.label, + region_type=region.region_type, + requires_filter=region.requires_filter, + filter_field=region.filter_field, + filter_value=region.filter_value, + ) + + return EconomySimulationResponse( + id=simulation.id, + status=simulation.status, + dataset_id=simulation.dataset_id, + policy_id=simulation.policy_id, + output_dataset_id=simulation.output_dataset_id, + filter_field=simulation.filter_field, + filter_value=simulation.filter_value, + region=region_info, + error_message=simulation.error_message, + ) + + +# --------------------------------------------------------------------------- +# List / generic get (existing endpoints) +# --------------------------------------------------------------------------- + + @router.get("/", response_model=List[SimulationRead]) def list_simulations(session: Session = Depends(get_session)): - """List all simulations. - - Simulations are created automatically via /analysis/economic-impact. - Check status to see if computation is pending, running, completed, or failed. - """ + """List all simulations.""" simulations = session.exec(select(Simulation)).all() return simulations +# --------------------------------------------------------------------------- +# Household simulation endpoints +# --------------------------------------------------------------------------- + + +@router.post("/household", response_model=HouseholdSimulationResponse) +def create_household_simulation( + request: HouseholdSimulationRequest, + session: Session = Depends(get_session), +): + """Create a household simulation job. + + Creates a Simulation record for the given household and policy. + Returns immediately with status "pending". + Poll GET /simulations/household/{id} until status is "completed". + """ + # Validate household exists + household = session.get(Household, request.household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {request.household_id} not found", + ) + + # Validate policy exists (if provided) + if request.policy_id: + policy = session.get(Policy, request.policy_id) + if not policy: + raise HTTPException( + status_code=404, + detail=f"Policy {request.policy_id} not found", + ) + + # Get model version + model_version = _get_model_version(household.tax_benefit_model_name, session) + + # Get or create simulation (deterministic UUID) + simulation = _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + household_id=request.household_id, + ) + + return _build_household_response(simulation) + + +@router.get("/household/{simulation_id}", response_model=HouseholdSimulationResponse) +def get_household_simulation( + simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get a household simulation's status and result.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + if simulation.simulation_type != SimulationType.HOUSEHOLD: + raise HTTPException( + status_code=400, + detail="Simulation is not a household simulation", + ) + + return _build_household_response(simulation) + + +# --------------------------------------------------------------------------- +# Economy simulation endpoints +# --------------------------------------------------------------------------- + + +@router.post("/economy", response_model=EconomySimulationResponse) +def create_economy_simulation( + request: EconomySimulationRequest, + session: Session = Depends(get_session), +): + """Create a single economy simulation. + + Creates a Simulation record for the given dataset/region and policy. + Poll GET /simulations/economy/{id} until status is "completed". + + Note: standalone economy simulation computation will be connected + in future tasks. For full baseline-vs-reform economy analysis, + use POST /analysis/economic-impact instead. + """ + # Resolve dataset and region + dataset, region = _resolve_economy_dataset( + request.tax_benefit_model_name, + request.region, + request.dataset_id, + session, + ) + + # Validate policy exists (if provided) + if request.policy_id: + policy = session.get(Policy, request.policy_id) + if not policy: + raise HTTPException( + status_code=404, + detail=f"Policy {request.policy_id} not found", + ) + + # Extract filter parameters from region + filter_field = region.filter_field if region and region.requires_filter else None + filter_value = region.filter_value if region and region.requires_filter else None + + # Get model version + model_version = _get_model_version(request.tax_benefit_model_name, session) + + # Get or create simulation (deterministic UUID) + simulation = _get_or_create_simulation( + simulation_type=SimulationType.ECONOMY, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + dataset_id=dataset.id, + filter_field=filter_field, + filter_value=filter_value, + ) + + return _build_economy_response(simulation, region) + + +@router.get("/economy/{simulation_id}", response_model=EconomySimulationResponse) +def get_economy_simulation( + simulation_id: UUID, + session: Session = Depends(get_session), +): + """Get an economy simulation's status and result.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise HTTPException(status_code=404, detail="Simulation not found") + if simulation.simulation_type != SimulationType.ECONOMY: + raise HTTPException( + status_code=400, + detail="Simulation is not an economy simulation", + ) + + return _build_economy_response(simulation) + + +# --------------------------------------------------------------------------- +# Generic get (keep after specific routes to avoid path conflicts) +# --------------------------------------------------------------------------- + + @router.get("/{simulation_id}", response_model=SimulationRead) def get_simulation(simulation_id: UUID, session: Session = Depends(get_session)): - """Get a specific simulation.""" + """Get a specific simulation (any type).""" simulation = session.get(Simulation, simulation_id) if not simulation: raise HTTPException(status_code=404, detail="Simulation not found") diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py new file mode 100644 index 0000000..314afe5 --- /dev/null +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -0,0 +1,164 @@ +"""Fixtures and helpers for standalone simulation endpoint tests.""" + +from uuid import UUID + +from policyengine_api.models import ( + Dataset, + Household, + Policy, + Region, + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def create_us_model_and_version(session) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create a US tax-benefit model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + return model, version + + +def create_uk_model_and_version(session) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create a UK tax-benefit model and version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="test", description="Test version" + ) + session.add(version) + session.commit() + session.refresh(version) + + return model, version + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str = "Test household", +) -> Household: + """Create and persist a Household record.""" + household = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data={ + "people": [{"age": {"2024": 30}, "employment_income": {"2024": 50000}}], + "household": [{"state_code": {"2024": "CA"}}], + }, + ) + session.add(household) + session.commit() + session.refresh(household) + return household + + +def create_policy(session, model: TaxBenefitModel) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + name="Test reform", + description="A test reform policy", + tax_benefit_model_id=model.id, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_dataset(session, model: TaxBenefitModel) -> Dataset: + """Create and persist a Dataset record.""" + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + return dataset + + +def create_region( + session, + model: TaxBenefitModel, + dataset: Dataset, + code: str = "us", + label: str = "United States", + region_type: str = "country", + requires_filter: bool = False, + filter_field: str | None = None, + filter_value: str | None = None, +) -> Region: + """Create and persist a Region record.""" + region = Region( + code=code, + label=label, + region_type=region_type, + requires_filter=requires_filter, + filter_field=filter_field, + filter_value=filter_value, + dataset_id=dataset.id, + tax_benefit_model_id=model.id, + ) + session.add(region) + session.commit() + session.refresh(region) + return region + + +def create_economy_simulation( + session, + version: TaxBenefitModelVersion, + dataset: Dataset, +) -> Simulation: + """Create and persist an economy Simulation record.""" + simulation = Simulation( + simulation_type=SimulationType.ECONOMY, + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +def create_household_simulation( + session, + version: TaxBenefitModelVersion, + household: Household, +) -> Simulation: + """Create and persist a household Simulation record.""" + simulation = Simulation( + simulation_type=SimulationType.HOUSEHOLD, + household_id=household.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + household_result={"person": [{"income_tax": {"2024": 5000}}]}, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation diff --git a/tests/test_simulations_standalone.py b/tests/test_simulations_standalone.py new file mode 100644 index 0000000..97d5743 --- /dev/null +++ b/tests/test_simulations_standalone.py @@ -0,0 +1,312 @@ +"""Tests for standalone simulation endpoints (/simulations/household, /simulations/economy).""" + +from uuid import uuid4 + +from test_fixtures.fixtures_simulations_standalone import ( + create_dataset, + create_economy_simulation, + create_household, + create_household_simulation, + create_policy, + create_region, + create_uk_model_and_version, + create_us_model_and_version, +) + + +# =========================================================================== +# POST /simulations/household +# =========================================================================== + + +def test_create_household_simulation(client, session): + """Create a household simulation returns 200 with pending status.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = {"household_id": str(household.id)} + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["household_id"] == str(household.id) + assert data["household_result"] is None + assert data["policy_id"] is None + + +def test_create_household_simulation_with_policy(client, session): + """Create a household simulation with a reform policy.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + policy = create_policy(session, model) + + payload = { + "household_id": str(household.id), + "policy_id": str(policy.id), + } + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["policy_id"] == str(policy.id) + + +def test_create_household_simulation_not_found(client, session): + """Creating with a non-existent household returns 404.""" + model, version = create_us_model_and_version(session) + payload = {"household_id": str(uuid4())} + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_household_simulation_policy_not_found(client, session): + """Creating with a non-existent policy returns 404.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = { + "household_id": str(household.id), + "policy_id": str(uuid4()), + } + response = client.post("/simulations/household", json=payload) + + assert response.status_code == 404 + assert "Policy" in response.json()["detail"] + + +def test_household_simulation_deduplication(client, session): + """Same inputs produce the same simulation (deterministic UUID).""" + model, version = create_us_model_and_version(session) + household = create_household(session) + + payload = {"household_id": str(household.id)} + response1 = client.post("/simulations/household", json=payload) + response2 = client.post("/simulations/household", json=payload) + + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response1.json()["id"] == response2.json()["id"] + + +# =========================================================================== +# GET /simulations/household/{id} +# =========================================================================== + + +def test_get_household_simulation(client, session): + """Get a household simulation by ID.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + simulation = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/household/{simulation.id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(simulation.id) + assert data["status"] == "completed" + assert data["household_result"] is not None + + +def test_get_household_simulation_not_found(client, session): + """Get a non-existent household simulation returns 404.""" + response = client.get(f"/simulations/household/{uuid4()}") + assert response.status_code == 404 + + +def test_get_household_simulation_wrong_type(client, session): + """Get an economy simulation via the household endpoint returns 400.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + economy_sim = create_economy_simulation(session, version, dataset) + + response = client.get(f"/simulations/household/{economy_sim.id}") + assert response.status_code == 400 + assert "not a household simulation" in response.json()["detail"] + + +# =========================================================================== +# POST /simulations/economy +# =========================================================================== + + +def test_create_economy_simulation_with_region(client, session): + """Create an economy simulation using a region code.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + region = create_region(session, model, dataset, code="us", label="United States") + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "us", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["dataset_id"] == str(dataset.id) + assert data["region"]["code"] == "us" + assert data["region"]["label"] == "United States" + + +def test_create_economy_simulation_with_dataset(client, session): + """Create an economy simulation using a dataset_id directly.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["dataset_id"] == str(dataset.id) + assert data["region"] is None + + +def test_create_economy_simulation_with_region_filter(client, session): + """Create an economy simulation with a region that requires filtering.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + region = create_region( + session, + model, + dataset, + code="state/ca", + label="California", + region_type="state", + requires_filter=True, + filter_field="state_code", + filter_value="CA", + ) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "state/ca", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["filter_field"] == "state_code" + assert data["filter_value"] == "CA" + assert data["region"]["requires_filter"] is True + + +def test_create_economy_simulation_invalid_region(client, session): + """Creating with a non-existent region returns 404.""" + model, version = create_us_model_and_version(session) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "region": "nonexistent/region", + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_create_economy_simulation_no_region_or_dataset(client, session): + """Creating without region or dataset_id returns 400.""" + model, version = create_us_model_and_version(session) + + payload = {"tax_benefit_model_name": "policyengine_us"} + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 400 + assert "Either region or dataset_id" in response.json()["detail"] + + +def test_create_economy_simulation_policy_not_found(client, session): + """Creating with a non-existent policy returns 404.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + "policy_id": str(uuid4()), + } + response = client.post("/simulations/economy", json=payload) + + assert response.status_code == 404 + assert "Policy" in response.json()["detail"] + + +def test_economy_simulation_deduplication(client, session): + """Same inputs produce the same simulation (deterministic UUID).""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + + payload = { + "tax_benefit_model_name": "policyengine_us", + "dataset_id": str(dataset.id), + } + response1 = client.post("/simulations/economy", json=payload) + response2 = client.post("/simulations/economy", json=payload) + + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response1.json()["id"] == response2.json()["id"] + + +# =========================================================================== +# GET /simulations/economy/{id} +# =========================================================================== + + +def test_get_economy_simulation(client, session): + """Get an economy simulation by ID.""" + model, version = create_us_model_and_version(session) + dataset = create_dataset(session, model) + simulation = create_economy_simulation(session, version, dataset) + + response = client.get(f"/simulations/economy/{simulation.id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(simulation.id) + assert data["status"] == "completed" + + +def test_get_economy_simulation_not_found(client, session): + """Get a non-existent economy simulation returns 404.""" + response = client.get(f"/simulations/economy/{uuid4()}") + assert response.status_code == 404 + + +def test_get_economy_simulation_wrong_type(client, session): + """Get a household simulation via the economy endpoint returns 400.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + household_sim = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/economy/{household_sim.id}") + assert response.status_code == 400 + assert "not an economy simulation" in response.json()["detail"] + + +# =========================================================================== +# Generic GET /simulations/{id} still works +# =========================================================================== + + +def test_get_simulation_generic(client, session): + """The generic GET /simulations/{id} endpoint still works for any type.""" + model, version = create_us_model_and_version(session) + household = create_household(session) + simulation = create_household_simulation(session, version, household) + + response = client.get(f"/simulations/{simulation.id}") + + assert response.status_code == 200 + assert response.json()["id"] == str(simulation.id) From 5e07daa5b8020573b5bec12ff1a5581c2b239d73 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Feb 2026 18:53:10 +0100 Subject: [PATCH 7/8] feat: Add household_impact_uk and household_impact_us Modal functions These report-based Modal functions are called by _trigger_household_impact() when agent_use_modal=True. They load the Report and its Simulations from the database, run household calculations (reusing _calculate_household_uk/us from household.py), store results in simulation.household_result, and mark the report as completed. Follows the economy_comparison_uk/us error handling pattern with raw SQL fallback for failure marking. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/modal_app.py | 288 ++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 84f8e89..0feb0d7 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -1736,6 +1736,294 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: logfire.force_flush() +# --------------------------------------------------------------------------- +# Household impact (report-based) — called by _trigger_household_impact() +# --------------------------------------------------------------------------- + + +@app.function( + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, +) +def household_impact_uk(report_id: str, traceparent: str | None = None) -> None: + """Run UK household impact analysis for a report. + + Loads the Report and its Simulations from the database, runs household + calculations for each simulation, stores results, and marks the report + as completed. Called via Modal.spawn() from _trigger_household_impact(). + """ + import logfire + + configure_logfire("policyengine-modal-uk", traceparent) + + try: + with logfire.span("household_impact_uk", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.api.household import _calculate_household_uk + from policyengine_api.api.household_analysis import ( + _ensure_list, + _extract_policy_data, + ) + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run each simulation (baseline, then reform if present) + for sim_id in [ + report.baseline_simulation_id, + report.reform_simulation_id, + ]: + if not sim_id: + continue + + simulation = session.get(Simulation, sim_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + continue + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError( + f"Household {simulation.household_id} not found" + ) + + # Convert policy to calculation format + policy_data = None + if simulation.policy_id: + from policyengine_api.models import Policy + + policy = session.get(Policy, simulation.policy_id) + policy_data = _extract_policy_data(policy) + + # Mark simulation as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + hh_data = household.household_data + with logfire.span( + "run_household_calculation", + simulation_id=str(sim_id), + ): + result = _calculate_household_uk( + people=hh_data.get("people", []), + benunit=_ensure_list(hh_data.get("benunit")), + household=_ensure_list(hh_data.get("household")), + year=household.year, + policy_data=policy_data, + ) + + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "UK household impact failed", + report_id=report_id, + error=str(e), + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', " + "error_message = :error WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + +@app.function( + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, +) +def household_impact_us(report_id: str, traceparent: str | None = None) -> None: + """Run US household impact analysis for a report. + + Loads the Report and its Simulations from the database, runs household + calculations for each simulation, stores results, and marks the report + as completed. Called via Modal.spawn() from _trigger_household_impact(). + """ + import logfire + + configure_logfire("policyengine-modal-us", traceparent) + + try: + with logfire.span("household_impact_us", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.api.household import _calculate_household_us + from policyengine_api.api.household_analysis import ( + _ensure_list, + _extract_policy_data, + ) + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run each simulation (baseline, then reform if present) + for sim_id in [ + report.baseline_simulation_id, + report.reform_simulation_id, + ]: + if not sim_id: + continue + + simulation = session.get(Simulation, sim_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + continue + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError( + f"Household {simulation.household_id} not found" + ) + + # Convert policy to calculation format + policy_data = None + if simulation.policy_id: + from policyengine_api.models import Policy + + policy = session.get(Policy, simulation.policy_id) + policy_data = _extract_policy_data(policy) + + # Mark simulation as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + hh_data = household.household_data + with logfire.span( + "run_household_calculation", + simulation_id=str(sim_id), + ): + result = _calculate_household_us( + people=hh_data.get("people", []), + marital_unit=_ensure_list( + hh_data.get("marital_unit") + ), + family=_ensure_list(hh_data.get("family")), + spm_unit=_ensure_list(hh_data.get("spm_unit")), + tax_unit=_ensure_list(hh_data.get("tax_unit")), + household=_ensure_list(hh_data.get("household")), + year=household.year, + policy_data=policy_data, + ) + + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "US household impact failed", + report_id=report_id, + error=str(e), + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', " + "error_message = :error WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + def _get_pe_policy_uk(policy_id, model_version, session): """Convert database Policy to policyengine Policy for UK.""" if policy_id is None: From 609b47ef24875c4f5c9ae04cbc192bad193ebc09 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Feb 2026 22:19:15 +0100 Subject: [PATCH 8/8] feat: Add _run_local_economy_comparison_us local fallback Implements the US counterpart to _run_local_economy_comparison_uk so that US economy comparisons work when agent_use_modal=False (local dev, testing, CI). Uses us_latest model, PolicyEngineUSDataset, and US-specific programs (income_tax, employee_payroll_tax, snap, tanf, ssi, social_security). Updates _trigger_economy_comparison to call it instead of falling back to Modal. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 218 ++++++++++++++++++++++++++- 1 file changed, 213 insertions(+), 5 deletions(-) diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 17539d7..ba739c7 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -591,6 +591,218 @@ def build_dynamic(dynamic_id): session.commit() +def _run_local_economy_comparison_us(job_id: str, session: Session) -> None: + """Run US economy comparison analysis locally.""" + from datetime import datetime, timezone + from uuid import UUID + + from policyengine.core import Simulation as PESimulation + from policyengine.core.dynamic import Dynamic as PEDynamic + from policyengine.core.policy import ParameterValue as PEParameterValue + from policyengine.core.policy import Policy as PEPolicy + from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset + from policyengine.tax_benefit_models.us.outputs import ( + ProgramStatistics as PEProgramStats, + ) + + from policyengine_api.models import Policy as DBPolicy + + # Load report and simulations + report = session.get(Report, UUID(job_id)) + if not report: + raise ValueError(f"Report {job_id} not found") + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + reform_sim = session.get(Simulation, report.reform_simulation_id) + + if not baseline_sim or not reform_sim: + raise ValueError("Simulations not found") + + # Update status to running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Get dataset + dataset = session.get(Dataset, baseline_sim.dataset_id) + if not dataset: + raise ValueError(f"Dataset {baseline_sim.dataset_id} not found") + + pe_model_version = us_latest + param_lookup = {p.name: p for p in pe_model_version.parameters} + + def build_policy(policy_id): + if not policy_id: + return None + db_policy = session.get(DBPolicy, policy_id) + if not db_policy: + return None + pe_param_values = [] + for pv in db_policy.parameter_values: + if not pv.parameter: + continue + pe_param = param_lookup.get(pv.parameter.name) + if not pe_param: + continue + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + start_date=pv.start_date, + end_date=pv.end_date, + ) + pe_param_values.append(pe_pv) + return PEPolicy( + name=db_policy.name, + description=db_policy.description, + parameter_values=pe_param_values, + ) + + def build_dynamic(dynamic_id): + if not dynamic_id: + return None + from policyengine_api.models import Dynamic as DBDynamic + + db_dynamic = session.get(DBDynamic, dynamic_id) + if not db_dynamic: + return None + pe_param_values = [] + for pv in db_dynamic.parameter_values: + if not pv.parameter: + continue + pe_param = param_lookup.get(pv.parameter.name) + if not pe_param: + continue + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + start_date=pv.start_date, + end_date=pv.end_date, + ) + pe_param_values.append(pe_pv) + return PEDynamic( + name=db_dynamic.name, + description=db_dynamic.description, + parameter_values=pe_param_values, + ) + + baseline_policy = build_policy(baseline_sim.policy_id) + reform_policy = build_policy(reform_sim.policy_id) + baseline_dynamic = build_dynamic(baseline_sim.dynamic_id) + reform_dynamic = build_dynamic(reform_sim.dynamic_id) + + # Download dataset + local_path = _download_dataset_local(dataset.filepath) + pe_dataset = PolicyEngineUSDataset( + name=dataset.name, + description=dataset.description or "", + filepath=local_path, + year=dataset.year, + ) + + # Run simulations (with optional regional filtering) + pe_baseline_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=baseline_policy, + dynamic=baseline_dynamic, + filter_field=baseline_sim.filter_field, + filter_value=baseline_sim.filter_value, + ) + pe_baseline_sim.ensure() + + pe_reform_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=reform_policy, + dynamic=reform_dynamic, + filter_field=reform_sim.filter_field, + filter_value=reform_sim.filter_value, + ) + pe_reform_sim.ensure() + + # Calculate decile impacts + for decile_num in range(1, 11): + di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + decile=decile_num, + ) + di.run() + decile_impact = DecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + income_variable=di.income_variable, + entity=di.entity, + decile=di.decile, + quantiles=di.quantiles, + baseline_mean=di.baseline_mean, + reform_mean=di.reform_mean, + absolute_change=di.absolute_change, + relative_change=di.relative_change, + count_better_off=di.count_better_off, + count_worse_off=di.count_worse_off, + count_no_change=di.count_no_change, + ) + session.add(decile_impact) + + # Calculate program statistics + PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + programs = { + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "employee_payroll_tax": {"entity": "person", "is_tax": True}, + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "spm_unit", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + } + for prog_name, prog_info in programs.items(): + try: + ps = PEProgramStats( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + program_stat = ProgramStatistics( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(program_stat) + except KeyError: + pass # Variable not found in model + + # Mark completed + baseline_sim.status = SimulationStatus.COMPLETED + baseline_sim.completed_at = datetime.now(timezone.utc) + reform_sim.status = SimulationStatus.COMPLETED + reform_sim.completed_at = datetime.now(timezone.utc) + report.status = ReportStatus.COMPLETED + session.add(baseline_sim) + session.add(reform_sim) + session.add(report) + session.commit() + + def _trigger_economy_comparison( job_id: str, tax_benefit_model_name: str, session: Session | None = None ) -> None: @@ -604,11 +816,7 @@ def _trigger_economy_comparison( if tax_benefit_model_name == "policyengine_uk": _run_local_economy_comparison_uk(job_id, session) else: - # US not implemented for local yet - fall back to Modal - import modal - - fn = modal.Function.from_name("policyengine", "economy_comparison_us") - fn.spawn(job_id=job_id, traceparent=traceparent) + _run_local_economy_comparison_us(job_id, session) else: # Use Modal import modal