Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 151 additions & 2 deletions src/policyengine_api/api/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
Parameter names are used when creating policy reforms.
"""

from typing import List
from __future__ import annotations

from typing import List, Literal
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import Session, select

from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId
from policyengine_api.models import (
Parameter,
ParameterRead,
Expand Down Expand Up @@ -67,6 +71,151 @@ def list_parameters(
return parameters


class ParameterByNameRequest(BaseModel):
"""Request body for looking up parameters by name."""

names: list[str]
country_id: CountryId


@router.post("/by-name", response_model=List[ParameterRead])
def get_parameters_by_name(
request: ParameterByNameRequest,
session: Session = Depends(get_session),
):
"""Look up parameters by their exact names.

Given a list of parameter paths (e.g. "gov.hmrc.income_tax.rates.uk[0].rate"),
returns the full metadata for each matching parameter. Names that don't match
any parameter are silently omitted from the response.

Use this to fetch metadata for a known set of parameters (e.g. all parameters
referenced in a user's saved policy) without loading the entire parameter catalog.
"""
if not request.names:
return []

model_name = COUNTRY_MODEL_NAMES[request.country_id]

query = (
select(Parameter)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(Parameter.name.in_(request.names))
.order_by(Parameter.name)
)

return session.exec(query).all()


class ParameterChild(BaseModel):
"""A single child in the parameter tree."""

path: str
label: str
type: Literal["node", "parameter"]
child_count: int | None = None
parameter: ParameterRead | None = None


class ParameterChildrenResponse(BaseModel):
"""Response for the parameter children endpoint."""

parent_path: str
children: list[ParameterChild]


@router.get("/children", response_model=ParameterChildrenResponse)
def get_parameter_children(
country_id: CountryId = Query(description='Country ID ("us" or "uk")'),
parent_path: str = Query(
default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')"
),
session: Session = Depends(get_session),
) -> ParameterChildrenResponse:
"""Get direct children of a parameter path for tree navigation.

Returns both intermediate nodes (folders with child_count) and leaf
parameters (with full metadata). Use this to lazily load the parameter
tree one level at a time.
"""
model_name = COUNTRY_MODEL_NAMES[country_id]
prefix = f"{parent_path}." if parent_path else ""

# Fetch all parameters under this path
query = (
select(Parameter)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(Parameter.name.startswith(prefix))
)
descendants = session.exec(query).all()

# Group by direct child path
children_map: dict[str, dict] = {}
prefix_len = len(prefix)

for param in descendants:
remainder = param.name[prefix_len:]
dot_pos = remainder.find(".")

if dot_pos == -1:
# Direct child (leaf at this level)
child_path = param.name
if child_path not in children_map:
children_map[child_path] = {
"direct_param": None,
"descendant_count": 0,
}
children_map[child_path]["direct_param"] = param
else:
# Deeper descendant — extract direct child segment
segment = remainder[:dot_pos]
child_path = prefix + segment
if child_path not in children_map:
children_map[child_path] = {
"direct_param": None,
"descendant_count": 0,
}
children_map[child_path]["descendant_count"] += 1

# Build response
children = []
for path in sorted(children_map):
info = children_map[path]
if info["descendant_count"] > 0:
# Node: has children below it
direct_param = info["direct_param"]
label = (
direct_param.label
if direct_param and direct_param.label
else path.rsplit(".", 1)[-1]
)
children.append(
ParameterChild(
path=path,
label=label,
type="node",
child_count=info["descendant_count"],
)
)
elif info["direct_param"]:
# Leaf parameter
param = info["direct_param"]
children.append(
ParameterChild(
path=path,
label=param.label or path.rsplit(".", 1)[-1],
type="parameter",
parameter=ParameterRead.model_validate(param),
)
)

return ParameterChildrenResponse(parent_path=parent_path, children=children)


@router.get("/{parameter_id}", response_model=ParameterRead)
def get_parameter(parameter_id: UUID, session: Session = Depends(get_session)):
"""Get a specific parameter."""
Expand Down
58 changes: 57 additions & 1 deletion src/policyengine_api/api/tax_benefit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlmodel import Session, select

from policyengine_api.models import TaxBenefitModel, TaxBenefitModelRead
from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId
from policyengine_api.models import (
TaxBenefitModel,
TaxBenefitModelRead,
TaxBenefitModelVersion,
TaxBenefitModelVersionRead,
)
from policyengine_api.services.database import get_session

router = APIRouter(prefix="/tax-benefit-models", tags=["tax-benefit-models"])
Expand All @@ -28,6 +35,55 @@ def list_tax_benefit_models(session: Session = Depends(get_session)):
return models


class ModelByCountryResponse(BaseModel):
"""Response for the model-by-country endpoint."""

model: TaxBenefitModelRead
latest_version: TaxBenefitModelVersionRead


@router.get(
"/by-country/{country_id}",
response_model=ModelByCountryResponse,
)
def get_model_by_country(
country_id: CountryId,
session: Session = Depends(get_session),
):
"""Get a tax-benefit model and its latest version by country ID.

Returns the model metadata and the most recently created version in a
single response. Use this on page load to check the current model version
for cache invalidation.
"""
model_name = COUNTRY_MODEL_NAMES[country_id]

model = session.exec(
select(TaxBenefitModel).where(TaxBenefitModel.name == model_name)
).first()
if not model:
raise HTTPException(
status_code=404,
detail=f"No model found for country '{country_id}'",
)

latest_version = session.exec(
select(TaxBenefitModelVersion)
.where(TaxBenefitModelVersion.model_id == model.id)
.order_by(TaxBenefitModelVersion.created_at.desc())
).first()
if not latest_version:
raise HTTPException(
status_code=404,
detail=f"No versions found for model '{model_name}'",
)

return ModelByCountryResponse(
model=TaxBenefitModelRead.model_validate(model),
latest_version=TaxBenefitModelVersionRead.model_validate(latest_version),
)


@router.get("/{model_id}", response_model=TaxBenefitModelRead)
def get_tax_benefit_model(model_id: UUID, session: Session = Depends(get_session)):
"""Get a specific tax-benefit model."""
Expand Down
40 changes: 40 additions & 0 deletions src/policyengine_api/api/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlmodel import Session, select

from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId
from policyengine_api.models import (
TaxBenefitModel,
TaxBenefitModelVersion,
Expand Down Expand Up @@ -67,6 +69,44 @@ def list_variables(
return variables


class VariableByNameRequest(BaseModel):
"""Request body for looking up variables by name."""

names: list[str]
country_id: CountryId


@router.post("/by-name", response_model=List[VariableRead])
def get_variables_by_name(
request: VariableByNameRequest,
session: Session = Depends(get_session),
):
"""Look up variables by their exact names.

Given a list of variable names (e.g. "employment_income", "income_tax"),
returns the full metadata for each matching variable. Names that don't
match any variable are silently omitted from the response.

Use this to fetch metadata for a known set of variables (e.g. variables
used in a household builder or report output) without loading the entire
variable catalog.
"""
if not request.names:
return []

model_name = COUNTRY_MODEL_NAMES[request.country_id]
query = (
select(Variable)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(Variable.name.in_(request.names))
.order_by(Variable.name)
)

return session.exec(query).all()


@router.get("/{variable_id}", response_model=VariableRead)
def get_variable(variable_id: UUID, session: Session = Depends(get_session)):
"""Get a specific variable."""
Expand Down
6 changes: 6 additions & 0 deletions src/policyengine_api/config/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@

# Countries supported by the API
CountryId = Literal["us", "uk"]

# Mapping from country ID to tax-benefit model name in the database
COUNTRY_MODEL_NAMES: dict[str, str] = {
"uk": "policyengine-uk",
"us": "policyengine-us",
}
Loading