diff --git a/.gitignore b/.gitignore index 2f8c3e6..bb80186 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,6 @@ docs/.env.local data/ *.h5 *.db + +# macOS +.DS_Store diff --git a/alembic/versions/20260205_0002_add_user_policies.py b/alembic/versions/20260205_0002_add_user_policies.py new file mode 100644 index 0000000..5f061de --- /dev/null +++ b/alembic/versions/20260205_0002_add_user_policies.py @@ -0,0 +1,94 @@ +"""add_user_policies + +Revision ID: 0002_user_policies +Revises: 36f9d434e95b +Create Date: 2026-02-05 + +This migration adds: +1. tax_benefit_model_id foreign key to policies table +2. user_policies table for user-policy associations + +Note: user_id in user_policies is NOT a foreign key to users table. +It's a client-generated UUID stored in localStorage, allowing anonymous +users to save policies without authentication. +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0002_user_policies" +down_revision: Union[str, Sequence[str], None] = "36f9d434e95b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add user_policies table and policy.tax_benefit_model_id.""" + # Add tax_benefit_model_id to policies table + op.add_column( + "policies", sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False) + ) + op.create_index( + op.f("ix_policies_tax_benefit_model_id"), + "policies", + ["tax_benefit_model_id"], + unique=False, + ) + op.create_foreign_key( + "fk_policies_tax_benefit_model_id", + "policies", + "tax_benefit_models", + ["tax_benefit_model_id"], + ["id"], + ) + + # Create user_policies table + # Note: user_id is NOT a foreign key - it's a client-generated UUID from localStorage + op.create_table( + "user_policies", + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("policy_id", sa.Uuid(), nullable=False), + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_policies_policy_id"), + "user_policies", + ["policy_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_policies_user_id"), "user_policies", ["user_id"], unique=False + ) + op.create_index( + op.f("ix_user_policies_country_id"), + "user_policies", + ["country_id"], + unique=False, + ) + + +def downgrade() -> None: + """Remove user_policies table and policy.tax_benefit_model_id.""" + # Drop user_policies table + op.drop_index(op.f("ix_user_policies_country_id"), table_name="user_policies") + op.drop_index(op.f("ix_user_policies_user_id"), table_name="user_policies") + op.drop_index(op.f("ix_user_policies_policy_id"), table_name="user_policies") + op.drop_table("user_policies") + + # Remove tax_benefit_model_id from policies + op.drop_constraint( + "fk_policies_tax_benefit_model_id", "policies", type_="foreignkey" + ) + op.drop_index(op.f("ix_policies_tax_benefit_model_id"), table_name="policies") + op.drop_column("policies", "tax_benefit_model_id") diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index 9d0436c..bcf1ab0 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -235,8 +235,7 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]: prop = schema_to_json_schema(spec, param_schema) prop["description"] = ( - param.get("description", "") - + f" (in: {param_in})" + param.get("description", "") + f" (in: {param_in})" ) properties[param_name] = prop @@ -268,16 +267,18 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]: if required: input_schema["required"] = list(set(required)) - tools.append({ - "name": tool_name, - "description": full_desc[:1024], # Claude has limits - "input_schema": input_schema, - "_meta": { - "path": path, - "method": method, - "parameters": operation.get("parameters", []), - }, - }) + tools.append( + { + "name": tool_name, + "description": full_desc[:1024], # Claude has limits + "input_schema": input_schema, + "_meta": { + "path": path, + "method": method, + "parameters": operation.get("parameters", []), + }, + } + ) return tools @@ -347,7 +348,9 @@ def execute_api_tool( url, params=query_params, json=body_data, headers=headers, timeout=60 ) elif method == "delete": - resp = requests.delete(url, params=query_params, headers=headers, timeout=60) + resp = requests.delete( + url, params=query_params, headers=headers, timeout=60 + ) else: return f"Unsupported method: {method}" @@ -415,9 +418,7 @@ def log(msg: str) -> None: tool_lookup = {t["name"]: t for t in tools} # Strip _meta from tools before sending to Claude (it doesn't need it) - claude_tools = [ - {k: v for k, v in t.items() if k != "_meta"} for t in tools - ] + claude_tools = [{k: v for k, v in t.items() if k != "_meta"} for t in tools] # Add the sleep tool claude_tools.append(SLEEP_TOOL) @@ -477,11 +478,13 @@ def log(msg: str) -> None: log(f"[TOOL_RESULT] {result[:300]}") - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) messages.append({"role": "assistant", "content": assistant_content}) diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index f135b14..7e94d29 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -20,6 +20,7 @@ tax_benefit_model_versions, tax_benefit_models, user_household_associations, + user_policies, variables, ) @@ -43,5 +44,6 @@ api_router.include_router(analysis.router) api_router.include_router(agent.router) api_router.include_router(user_household_associations.router) +api_router.include_router(user_policies.router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/agent.py b/src/policyengine_api/api/agent.py index 7b7d108..6c26e80 100644 --- a/src/policyengine_api/api/agent.py +++ b/src/policyengine_api/api/agent.py @@ -24,6 +24,7 @@ def get_traceparent() -> str | None: TraceContextTextMapPropagator().inject(carrier) return carrier.get("traceparent") + router = APIRouter(prefix="/agent", tags=["agent"]) @@ -93,7 +94,9 @@ def _run_local_agent( from policyengine_api.agent_sandbox import _run_agent_impl try: - history_dicts = [{"role": m.role, "content": m.content} for m in (history or [])] + history_dicts = [ + {"role": m.role, "content": m.content} for m in (history or []) + ] result = _run_agent_impl(question, api_base_url, call_id, history_dicts) _calls[call_id]["status"] = result.get("status", "completed") _calls[call_id]["result"] = result @@ -136,9 +139,15 @@ async def run_agent(request: RunRequest) -> RunResponse: traceparent = get_traceparent() run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent") - history_dicts = [{"role": m.role, "content": m.content} for m in request.history] + history_dicts = [ + {"role": m.role, "content": m.content} for m in request.history + ] call = run_fn.spawn( - request.question, api_base_url, call_id, history_dicts, traceparent=traceparent + request.question, + api_base_url, + call_id, + history_dicts, + traceparent=traceparent, ) _calls[call_id] = { @@ -166,7 +175,12 @@ async def run_agent(request: RunRequest) -> RunResponse: # Run in background using asyncio loop = asyncio.get_event_loop() loop.run_in_executor( - None, _run_local_agent, call_id, request.question, api_base_url, request.history + None, + _run_local_agent, + call_id, + request.question, + api_base_url, + request.history, ) return RunResponse(call_id=call_id, status="running") diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 0e89b5e..eba986d 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -300,11 +300,13 @@ def _calculate_household_uk( from pathlib import Path import pandas as pd - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset - from policyengine.tax_benefit_models.uk.datasets import UKYearData + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + UKYearData, + ) n_people = len(people) n_benunits = max(1, len(benunit)) @@ -466,7 +468,14 @@ def _run_local_household_us( try: result = _calculate_household_us( - people, marital_unit, family, spm_unit, tax_unit, household, year, policy_data + people, + marital_unit, + family, + spm_unit, + tax_unit, + household, + year, + policy_data, ) # Update job with result @@ -512,11 +521,13 @@ def _calculate_household_us( from pathlib import Path import pandas as pd - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset - from policyengine.tax_benefit_models.us.datasets import USYearData + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) n_people = len(people) n_households = max(1, len(household)) @@ -672,7 +683,9 @@ def safe_convert(value): except (ValueError, TypeError): return str(value) - def extract_entity_outputs(entity_name: str, entity_data, n_rows: int) -> list[dict]: + def extract_entity_outputs( + entity_name: str, entity_data, n_rows: int + ) -> list[dict]: outputs = [] for i in range(n_rows): row_dict = {} diff --git a/src/policyengine_api/api/policies.py b/src/policyengine_api/api/policies.py index d0e2ca5..ad5397d 100644 --- a/src/policyengine_api/api/policies.py +++ b/src/policyengine_api/api/policies.py @@ -31,7 +31,7 @@ from typing import List from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session, select from policyengine_api.models import ( @@ -40,6 +40,7 @@ Policy, PolicyCreate, PolicyRead, + TaxBenefitModel, ) from policyengine_api.services.database import get_session @@ -67,8 +68,17 @@ def create_policy(policy: PolicyCreate, session: Session = Depends(get_session)) ] } """ + # Validate tax_benefit_model exists + tax_model = session.get(TaxBenefitModel, policy.tax_benefit_model_id) + if not tax_model: + raise HTTPException(status_code=404, detail="Tax benefit model not found") + # Create the policy - db_policy = Policy(name=policy.name, description=policy.description) + db_policy = Policy( + name=policy.name, + description=policy.description, + tax_benefit_model_id=policy.tax_benefit_model_id, + ) session.add(db_policy) session.flush() # Get the policy ID before adding parameter values @@ -112,10 +122,17 @@ def create_policy(policy: PolicyCreate, session: Session = Depends(get_session)) @router.get("/", response_model=List[PolicyRead]) -def list_policies(session: Session = Depends(get_session)): - """List all policies.""" - policies = session.exec(select(Policy)).all() - return policies +def list_policies( + tax_benefit_model_id: UUID | None = Query( + None, description="Filter by tax benefit model" + ), + session: Session = Depends(get_session), +): + """List all policies, optionally filtered by tax benefit model.""" + query = select(Policy) + if tax_benefit_model_id: + query = query.where(Policy.tax_benefit_model_id == tax_benefit_model_id) + return session.exec(query).all() @router.get("/{policy_id}", response_model=PolicyRead) diff --git a/src/policyengine_api/api/user_policies.py b/src/policyengine_api/api/user_policies.py new file mode 100644 index 0000000..3cc2c08 --- /dev/null +++ b/src/policyengine_api/api/user_policies.py @@ -0,0 +1,147 @@ +"""User-policy association endpoints. + +Associates users with policies they've saved/created. This enables users to +maintain a list of their policies across sessions without duplicating the +underlying policy data. + +Note: user_id is a client-generated UUID (via crypto.randomUUID()) stored in +the browser's localStorage. It is NOT validated against a users table, allowing +anonymous users to save policies without authentication. +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.config.constants import CountryId +from policyengine_api.models import ( + Policy, + UserPolicy, + UserPolicyCreate, + UserPolicyRead, + UserPolicyUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/user-policies", tags=["user-policies"]) + + +@router.post("/", response_model=UserPolicyRead) +def create_user_policy( + user_policy: UserPolicyCreate, + session: Session = Depends(get_session), +): + """Create a new user-policy association. + + Associates a user with a policy, allowing them to save it to their list. + Duplicates are allowed - users can save the same policy multiple times + with different labels (matching FE localStorage behavior). + + Note: user_id is not validated - it's a client-generated UUID from localStorage. + Note: country_id is validated via Pydantic Literal type to "us" or "uk". + """ + # Validate policy exists + policy = session.get(Policy, user_policy.policy_id) + if not policy: + raise HTTPException(status_code=404, detail="Policy not found") + + # Create the association (duplicates allowed) + db_user_policy = UserPolicy.model_validate(user_policy) + session.add(db_user_policy) + session.commit() + session.refresh(db_user_policy) + return db_user_policy + + +@router.get("/", response_model=list[UserPolicyRead]) +def list_user_policies( + user_id: UUID = Query(..., description="User ID to filter by"), + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), + session: Session = Depends(get_session), +): + """List all policy associations for a user. + + Returns all policies saved by the specified user. Optionally filter by country. + Country ID is validated via Pydantic Literal type. + """ + query = select(UserPolicy).where(UserPolicy.user_id == user_id) + + if country_id: + query = query.where(UserPolicy.country_id == country_id) + + user_policies = session.exec(query).all() + return user_policies + + +@router.get("/{user_policy_id}", response_model=UserPolicyRead) +def get_user_policy( + user_policy_id: UUID, + session: Session = Depends(get_session), +): + """Get a specific user-policy association by ID.""" + user_policy = session.get(UserPolicy, user_policy_id) + if not user_policy: + raise HTTPException(status_code=404, detail="User-policy association not found") + return user_policy + + +@router.patch("/{user_policy_id}", response_model=UserPolicyRead) +def update_user_policy( + user_policy_id: UUID, + updates: UserPolicyUpdate, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Update a user-policy association (e.g., rename label). + + Requires user_id to verify ownership - only the owner can update. + """ + user_policy = session.exec( + select(UserPolicy).where( + UserPolicy.id == user_policy_id, + UserPolicy.user_id == user_id, + ) + ).first() + if not user_policy: + raise HTTPException(status_code=404, detail="User-policy association not found") + + # Apply updates + update_data = updates.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(user_policy, key, value) + + # Update timestamp + user_policy.updated_at = datetime.now(timezone.utc) + + session.add(user_policy) + session.commit() + session.refresh(user_policy) + return user_policy + + +@router.delete("/{user_policy_id}", status_code=204) +def delete_user_policy( + user_policy_id: UUID, + user_id: UUID = Query(..., description="User ID for ownership verification"), + session: Session = Depends(get_session), +): + """Delete a user-policy association. + + This only removes the association, not the underlying policy. + Requires user_id to verify ownership - only the owner can delete. + """ + user_policy = session.exec( + select(UserPolicy).where( + UserPolicy.id == user_policy_id, + UserPolicy.user_id == user_id, + ) + ).first() + if not user_policy: + raise HTTPException(status_code=404, detail="User-policy association not found") + + session.delete(user_policy) + session.commit() diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index d660b1b..3c24f3d 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -56,9 +56,9 @@ def list_variables( # Case-insensitive search using ILIKE # Note: Variables don't have a label field, only name and description search_pattern = f"%{search}%" - search_filter = Variable.name.ilike(search_pattern) | Variable.description.ilike( + search_filter = Variable.name.ilike( search_pattern - ) + ) | Variable.description.ilike(search_pattern) query = query.where(search_filter) variables = session.exec( diff --git a/src/policyengine_api/config/constants.py b/src/policyengine_api/config/constants.py new file mode 100644 index 0000000..527ba25 --- /dev/null +++ b/src/policyengine_api/config/constants.py @@ -0,0 +1,6 @@ +"""Shared constants for the PolicyEngine API.""" + +from typing import Literal + +# Countries supported by the API +CountryId = Literal["us", "uk"] diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 332c349..84f8e89 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -242,13 +242,13 @@ def simulate_household_uk( engine = create_engine(database_url) try: - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, + UKYearData, ) - from policyengine.tax_benefit_models.uk.datasets import UKYearData n_people = len(people) n_benunits = max(1, len(benunit)) @@ -487,13 +487,13 @@ def simulate_household_us( engine = create_engine(database_url) try: - from policyengine.core import Simulation from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, + USYearData, ) - from policyengine.tax_benefit_models.us.datasets import USYearData n_people = len(people) n_households = max(1, len(household)) diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 7361979..e7b386a 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -61,6 +61,12 @@ UserHouseholdAssociationRead, UserHouseholdAssociationUpdate, ) +from .user_policy import ( + UserPolicy, + UserPolicyCreate, + UserPolicyRead, + UserPolicyUpdate, +) from .variable import Variable, VariableCreate, VariableRead __all__ = [ @@ -136,6 +142,10 @@ "UserHouseholdAssociationRead", "UserHouseholdAssociationUpdate", "UserRead", + "UserPolicy", + "UserPolicyCreate", + "UserPolicyRead", + "UserPolicyUpdate", "Variable", "VariableCreate", "VariableRead", diff --git a/src/policyengine_api/models/policy.py b/src/policyengine_api/models/policy.py index 570320b..69eeecf 100644 --- a/src/policyengine_api/models/policy.py +++ b/src/policyengine_api/models/policy.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from .parameter_value import ParameterValue + from .tax_benefit_model import TaxBenefitModel class PolicyBase(SQLModel): @@ -13,6 +14,7 @@ class PolicyBase(SQLModel): name: str description: str | None = None + tax_benefit_model_id: UUID = Field(foreign_key="tax_benefit_models.id", index=True) class Policy(PolicyBase, table=True): @@ -26,6 +28,7 @@ class Policy(PolicyBase, table=True): # Relationships parameter_values: list["ParameterValue"] = Relationship(back_populates="policy") + tax_benefit_model: "TaxBenefitModel" = Relationship() class PolicyCreate(PolicyBase): diff --git a/src/policyengine_api/models/user_policy.py b/src/policyengine_api/models/user_policy.py new file mode 100644 index 0000000..a9a86b6 --- /dev/null +++ b/src/policyengine_api/models/user_policy.py @@ -0,0 +1,64 @@ +from datetime import datetime, timezone +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import Field, Relationship, SQLModel + +from policyengine_api.config.constants import CountryId + +if TYPE_CHECKING: + from .policy import Policy + + +class UserPolicyBase(SQLModel): + """Base user-policy association fields.""" + + # user_id is a client-generated UUID stored in localStorage, not a foreign key. + # This allows anonymous users to save policies without requiring authentication. + # The UUID is generated once per browser via crypto.randomUUID() and persisted + # in localStorage for stable identity across sessions. + user_id: UUID = Field(index=True) + policy_id: UUID = Field(foreign_key="policies.id", index=True) + country_id: str # Stored as string in DB, validated via Pydantic in Create schema + label: str | None = None + + +class UserPolicy(UserPolicyBase, table=True): + """User-policy association database model.""" + + __tablename__ = "user_policies" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + # Relationships + policy: "Policy" = Relationship() + + +class UserPolicyCreate(SQLModel): + """Schema for creating user-policy associations. + + Uses CountryId Literal type for validation of country_id. + """ + + user_id: UUID + policy_id: UUID + country_id: CountryId # Validated to "us" or "uk" + label: str | None = None + + +class UserPolicyRead(UserPolicyBase): + """Schema for reading user-policy associations.""" + + id: UUID + created_at: datetime + updated_at: datetime + + +class UserPolicyUpdate(SQLModel): + """Schema for updating user-policy associations.""" + + model_config = {"extra": "forbid"} + + label: str | None = None diff --git a/test_fixtures/fixtures_household_analysis.py b/test_fixtures/fixtures_household_analysis.py index 573930a..5cf5cd1 100644 --- a/test_fixtures/fixtures_household_analysis.py +++ b/test_fixtures/fixtures_household_analysis.py @@ -272,13 +272,13 @@ def create_parameter( def create_policy( session: Session, - model_version_id: UUID, + model_id: UUID, name: str = "Test Policy", description: str = "A test policy", ) -> Policy: """Create and persist a Policy record.""" policy = Policy( - tax_benefit_model_version_id=model_version_id, + tax_benefit_model_id=model_id, name=name, description=description, ) @@ -290,13 +290,13 @@ def create_policy( def create_policy_with_parameter_value( session: Session, - model_version_id: UUID, + model_id: UUID, parameter_id: UUID, value: float, name: str = "Test Policy", ) -> Policy: """Create a Policy with an associated ParameterValue.""" - policy = create_policy(session, model_version_id, name=name) + policy = create_policy(session, model_id, name=name) param_value = ParameterValue( policy_id=policy.id, diff --git a/test_fixtures/fixtures_parameters.py b/test_fixtures/fixtures_parameters.py index ff69b0e..0df134c 100644 --- a/test_fixtures/fixtures_parameters.py +++ b/test_fixtures/fixtures_parameters.py @@ -54,9 +54,15 @@ def create_parameter(session, model_version, name: str, label: str) -> Parameter return param -def create_policy(session, name: str, description: str = "A test policy") -> Policy: +def create_policy( + session, name: str, model_version, description: str = "A test policy" +) -> Policy: """Create and persist a Policy.""" - policy = Policy(name=name, description=description) + policy = Policy( + name=name, + description=description, + tax_benefit_model_id=model_version.model_id, + ) session.add(policy) session.commit() session.refresh(policy) diff --git a/test_fixtures/fixtures_user_policies.py b/test_fixtures/fixtures_user_policies.py new file mode 100644 index 0000000..1572ca7 --- /dev/null +++ b/test_fixtures/fixtures_user_policies.py @@ -0,0 +1,70 @@ +"""Fixtures and helpers for user-policy association tests.""" + +from uuid import UUID + +from policyengine_api.models import Policy, TaxBenefitModel, UserPolicy + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + +US_COUNTRY_ID = "us" +UK_COUNTRY_ID = "uk" + +DEFAULT_POLICY_NAME = "Test policy" +DEFAULT_POLICY_DESCRIPTION = "A test policy" + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_tax_benefit_model( + session, + name: str = "policyengine-us", + description: str = "US model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + record = TaxBenefitModel(name=name, description=description) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_policy( + session, + tax_benefit_model: TaxBenefitModel, + name: str = DEFAULT_POLICY_NAME, + description: str = DEFAULT_POLICY_DESCRIPTION, +) -> Policy: + """Create and persist a Policy record.""" + record = Policy( + name=name, + description=description, + tax_benefit_model_id=tax_benefit_model.id, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_user_policy( + session, + user_id: UUID, + policy: Policy, + country_id: str = US_COUNTRY_ID, + label: str | None = None, +) -> UserPolicy: + """Create and persist a UserPolicy association record.""" + record = UserPolicy( + user_id=user_id, + policy_id=policy.id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/conftest.py b/tests/conftest.py index 8be9b3f..77c29ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,5 @@ """Pytest fixtures for tests.""" -from uuid import uuid4 - import pytest from fastapi.testclient import TestClient from fastapi_cache import FastAPICache @@ -48,6 +46,26 @@ def get_session_override(): app.dependency_overrides.clear() +@pytest.fixture(name="tax_benefit_model") +def tax_benefit_model_fixture(session: Session): + """Create a TaxBenefitModel for tests.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + return model + + +@pytest.fixture(name="uk_tax_benefit_model") +def uk_tax_benefit_model_fixture(session: Session): + """Create a UK TaxBenefitModel for tests.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + return model + + @pytest.fixture(name="simulation_id") def simulation_fixture(session: Session): """Create a test simulation with required dependencies.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index 2c591f5..55bb2c2 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -9,6 +9,7 @@ import json from unittest.mock import AsyncMock, MagicMock, patch + from fastapi.testclient import TestClient from policyengine_api.main import app diff --git a/tests/test_agent_policy_questions.py b/tests/test_agent_policy_questions.py index 1550f89..289d73c 100644 --- a/tests/test_agent_policy_questions.py +++ b/tests/test_agent_policy_questions.py @@ -11,10 +11,10 @@ pytestmark = pytest.mark.integration -from policyengine_api.agent_sandbox import _run_agent_impl - import os +from policyengine_api.agent_sandbox import _run_agent_impl + # Use local API by default, override with POLICYENGINE_API_URL env var API_BASE = os.environ.get("POLICYENGINE_API_URL", "http://localhost:8000") @@ -218,4 +218,6 @@ def test_turn_efficiency(self, question, max_expected_turns): print(f"Result: {result['result'][:300]}") if result["turns"] > max_expected_turns: - print(f"WARNING: Took {result['turns']} turns, expected <= {max_expected_turns}") + print( + f"WARNING: Took {result['turns']} turns, expected <= {max_expected_turns}" + ) diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py index 23465c7..3633e68 100644 --- a/tests/test_analysis_household_impact.py +++ b/tests/test_analysis_household_impact.py @@ -341,9 +341,9 @@ def test_single_run_creates_one_simulation(self, client, session): def test_comparison_creates_two_simulations(self, client, session): """Comparison (with policy_id) creates two simulations.""" - _, version = setup_uk_model_and_version(session) + model, version = setup_uk_model_and_version(session) household = create_household_for_analysis(session) - policy = create_policy(session, version.id) + policy = create_policy(session, model.id) response = client.post( "/analysis/household-impact", @@ -381,9 +381,9 @@ def test_simulation_type_is_household(self, client, session): def test_report_links_simulations(self, client, session): """Report correctly links baseline and reform simulations.""" - _, version = setup_uk_model_and_version(session) + model, version = setup_uk_model_and_version(session) household = create_household_for_analysis(session) - policy = create_policy(session, version.id) + policy = create_policy(session, model.id) response = client.post( "/analysis/household-impact", @@ -435,10 +435,10 @@ def test_same_request_returns_same_simulation(self, client, session): def test_different_policy_creates_different_simulation(self, client, session): """Different policy creates different simulation.""" - _, version = setup_uk_model_and_version(session) + model, version = setup_uk_model_and_version(session) household = create_household_for_analysis(session) - policy1 = create_policy(session, version.id, name="Policy 1") - policy2 = create_policy(session, version.id, name="Policy 2") + policy1 = create_policy(session, model.id, name="Policy 1") + policy2 = create_policy(session, model.id, name="Policy 2") # Request with policy1 response1 = client.post( diff --git a/tests/test_integration.py b/tests/test_integration.py index e044cab..e055423 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -9,6 +9,7 @@ pytestmark = pytest.mark.integration from datetime import datetime, timezone + from rich.console import Console from sqlmodel import Session, create_engine, select diff --git a/tests/test_parameters.py b/tests/test_parameters.py index f95016b..50bb213 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -107,7 +107,7 @@ def test__given_policy_id_filter__then_returns_only_matching_values( """GET /parameter-values?policy_id=X returns only values for that policy.""" # Given param = create_parameter(session, model_version, "test.param", "Test Param") - policy = create_policy(session, "Test Policy") + policy = create_policy(session, "Test Policy", model_version) create_parameter_value(session, param.id, 100, policy_id=None) # baseline create_parameter_value(session, param.id, 150, policy_id=policy.id) # reform @@ -135,7 +135,7 @@ def test__given_both_parameter_and_policy_filters__then_returns_matching_interse param2 = create_parameter( session, model_version, "test.both.param2", "Test Both Param 2" ) - policy = create_policy(session, "Test Both Policy") + policy = create_policy(session, "Test Both Policy", model_version) create_parameter_value(session, param1.id, 100, policy_id=None) # baseline create_parameter_value(session, param1.id, 150, policy_id=policy.id) # target diff --git a/tests/test_policies.py b/tests/test_policies.py index f48730b..b4ac25f 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -12,25 +12,46 @@ def test_list_policies_empty(client): assert response.json() == [] -def test_create_policy(client): +def test_create_policy(client, tax_benefit_model): """Create a new policy.""" response = client.post( "/policies", json={ "name": "Test policy", "description": "A test policy", + "tax_benefit_model_id": str(tax_benefit_model.id), }, ) assert response.status_code == 200 data = response.json() assert data["name"] == "Test policy" assert data["description"] == "A test policy" + assert data["tax_benefit_model_id"] == str(tax_benefit_model.id) assert "id" in data -def test_list_policies_with_data(client, session): +def test_create_policy_invalid_tax_benefit_model(client): + """Create policy with non-existent tax_benefit_model returns 404.""" + fake_id = uuid4() + response = client.post( + "/policies", + json={ + "name": "Test policy", + "description": "A test policy", + "tax_benefit_model_id": str(fake_id), + }, + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Tax benefit model not found" + + +def test_list_policies_with_data(client, session, tax_benefit_model): """List policies returns all policies.""" - policy = Policy(name="test-policy", description="Test") + policy = Policy( + name="test-policy", + description="Test", + tax_benefit_model_id=tax_benefit_model.id, + ) session.add(policy) session.commit() @@ -41,9 +62,39 @@ def test_list_policies_with_data(client, session): assert data[0]["name"] == "test-policy" -def test_get_policy(client, session): +def test_list_policies_filter_by_tax_benefit_model( + client, session, tax_benefit_model, uk_tax_benefit_model +): + """List policies with tax_benefit_model_id filter.""" + policy1 = Policy( + name="US policy", + description="US", + tax_benefit_model_id=tax_benefit_model.id, + ) + policy2 = Policy( + name="UK policy", + description="UK", + tax_benefit_model_id=uk_tax_benefit_model.id, + ) + session.add(policy1) + session.add(policy2) + session.commit() + + # Filter by US model + response = client.get(f"/policies?tax_benefit_model_id={tax_benefit_model.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "US policy" + + +def test_get_policy(client, session, tax_benefit_model): """Get a specific policy by ID.""" - policy = Policy(name="test-policy", description="Test") + policy = Policy( + name="test-policy", + description="Test", + tax_benefit_model_id=tax_benefit_model.id, + ) session.add(policy) session.commit() session.refresh(policy) diff --git a/tests/test_user_policies.py b/tests/test_user_policies.py new file mode 100644 index 0000000..3b8d06e --- /dev/null +++ b/tests/test_user_policies.py @@ -0,0 +1,262 @@ +"""Tests for user-policy association endpoints. + +Note: user_id is a client-generated UUID (not validated against users table), +so tests use uuid4() directly rather than creating User records. +""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_policies import ( + UK_COUNTRY_ID, + US_COUNTRY_ID, + create_policy, + create_user_policy, +) + + +def test_list_user_policies_empty(client): + """List user policies returns empty list when user has no associations.""" + user_id = uuid4() + response = client.get(f"/user-policies?user_id={user_id}") + assert response.status_code == 200 + assert response.json() == [] + + +def test_create_user_policy(client, session, tax_benefit_model): + """Create a new user-policy association.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(policy.id), + "country_id": US_COUNTRY_ID, + "label": "My test policy", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["user_id"] == str(user_id) + assert data["policy_id"] == str(policy.id) + assert data["country_id"] == US_COUNTRY_ID + assert data["label"] == "My test policy" + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + + +def test_create_user_policy_without_label(client, session, tax_benefit_model): + """Create a user-policy association without a label.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(policy.id), + "country_id": US_COUNTRY_ID, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["label"] is None + assert data["country_id"] == US_COUNTRY_ID + + +def test_create_user_policy_policy_not_found(client): + """Create user-policy association with non-existent policy returns 404.""" + user_id = uuid4() + fake_policy_id = uuid4() + + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(fake_policy_id), + "country_id": US_COUNTRY_ID, + }, + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Policy not found" + + +def test_create_user_policy_duplicate_allowed(client, session, tax_benefit_model): + """Creating duplicate user-policy association is allowed (matches FE localStorage behavior).""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID) + + # Create duplicate - should succeed with a new ID + response = client.post( + "/user-policies", + json={ + "user_id": str(user_id), + "policy_id": str(policy.id), + "country_id": US_COUNTRY_ID, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] != str(user_policy.id) # New association created + assert data["user_id"] == str(user_id) + assert data["policy_id"] == str(policy.id) + + +def test_list_user_policies_with_data( + client, session, tax_benefit_model, uk_tax_benefit_model +): + """List user policies returns all associations for a user.""" + user_id = uuid4() + policy1 = create_policy(session, tax_benefit_model, name="Policy 1", description="First policy") + policy2 = create_policy(session, uk_tax_benefit_model, name="Policy 2", description="Second policy") + create_user_policy(session, user_id, policy1, country_id=US_COUNTRY_ID, label="US policy") + create_user_policy(session, user_id, policy2, country_id=UK_COUNTRY_ID, label="UK policy") + + response = client.get(f"/user-policies?user_id={user_id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_user_policies_filter_by_country( + client, session, tax_benefit_model, uk_tax_benefit_model +): + """List user policies filtered by country_id.""" + user_id = uuid4() + policy1 = create_policy(session, tax_benefit_model, name="Policy 1", description="First policy") + policy2 = create_policy(session, uk_tax_benefit_model, name="Policy 2", description="Second policy") + create_user_policy(session, user_id, policy1, country_id=US_COUNTRY_ID) + create_user_policy(session, user_id, policy2, country_id=UK_COUNTRY_ID) + + response = client.get( + f"/user-policies?user_id={user_id}&country_id={US_COUNTRY_ID}" + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["policy_id"] == str(policy1.id) + assert data[0]["country_id"] == US_COUNTRY_ID + + +def test_get_user_policy(client, session, tax_benefit_model): + """Get a specific user-policy association by ID.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID, label="My policy") + + response = client.get(f"/user-policies/{user_policy.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(user_policy.id) + assert data["label"] == "My policy" + assert data["country_id"] == US_COUNTRY_ID + + +def test_get_user_policy_not_found(client): + """Get a non-existent user-policy association returns 404.""" + fake_id = uuid4() + response = client.get(f"/user-policies/{fake_id}") + assert response.status_code == 404 + assert response.json()["detail"] == "User-policy association not found" + + +def test_update_user_policy(client, session, tax_benefit_model): + """Update a user-policy association label.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID, label="Old label") + + response = client.patch( + f"/user-policies/{user_policy.id}?user_id={user_id}", + json={"label": "New label"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["label"] == "New label" + assert data["country_id"] == US_COUNTRY_ID + + +def test_update_user_policy_not_found(client): + """Update a non-existent user-policy association returns 404.""" + fake_id = uuid4() + fake_user_id = uuid4() + response = client.patch( + f"/user-policies/{fake_id}?user_id={fake_user_id}", + json={"label": "New label"}, + ) + assert response.status_code == 404 + assert response.json()["detail"] == "User-policy association not found" + + +def test_update_user_policy_wrong_user(client, session, tax_benefit_model): + """Update with wrong user_id returns 404 (ownership check).""" + user_id = uuid4() + wrong_user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID, label="Original label") + + # Try to update with wrong user_id + response = client.patch( + f"/user-policies/{user_policy.id}?user_id={wrong_user_id}", + json={"label": "Hacked label"}, + ) + assert response.status_code == 404 + + # Verify original label unchanged + response = client.get(f"/user-policies/{user_policy.id}") + assert response.json()["label"] == "Original label" + + +def test_update_user_policy_rejects_extra_fields(client, session, tax_benefit_model): + """Update with extra fields returns 422 (extra='forbid').""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID, label="Original") + + response = client.patch( + f"/user-policies/{user_policy.id}?user_id={user_id}", + json={"label": "New", "user_id": str(uuid4())}, + ) + assert response.status_code == 422 + + +def test_delete_user_policy(client, session, tax_benefit_model): + """Delete a user-policy association.""" + user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID) + + response = client.delete(f"/user-policies/{user_policy.id}?user_id={user_id}") + assert response.status_code == 204 + + # Verify it's deleted + response = client.get(f"/user-policies/{user_policy.id}") + assert response.status_code == 404 + + +def test_delete_user_policy_not_found(client): + """Delete a non-existent user-policy association returns 404.""" + fake_id = uuid4() + fake_user_id = uuid4() + response = client.delete(f"/user-policies/{fake_id}?user_id={fake_user_id}") + assert response.status_code == 404 + assert response.json()["detail"] == "User-policy association not found" + + +def test_delete_user_policy_wrong_user(client, session, tax_benefit_model): + """Delete with wrong user_id returns 404 (ownership check).""" + user_id = uuid4() + wrong_user_id = uuid4() + policy = create_policy(session, tax_benefit_model) + user_policy = create_user_policy(session, user_id, policy, country_id=US_COUNTRY_ID) + + # Try to delete with wrong user_id + response = client.delete(f"/user-policies/{user_policy.id}?user_id={wrong_user_id}") + assert response.status_code == 404 + + # Verify it still exists + response = client.get(f"/user-policies/{user_policy.id}") + assert response.status_code == 200 diff --git a/uv.lock b/uv.lock index 466caf4..f66f5d0 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.13" [[package]]