diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index db029e5..72b64ef 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -5,12 +5,16 @@ Parameter names are used when creating policy reforms. """ -from typing import List +from __future__ import annotations + +from typing import List, Literal from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel from sqlmodel import Session, select +from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( Parameter, ParameterRead, @@ -67,6 +71,151 @@ def list_parameters( return parameters +class ParameterByNameRequest(BaseModel): + """Request body for looking up parameters by name.""" + + names: list[str] + country_id: CountryId + + +@router.post("/by-name", response_model=List[ParameterRead]) +def get_parameters_by_name( + request: ParameterByNameRequest, + session: Session = Depends(get_session), +): + """Look up parameters by their exact names. + + Given a list of parameter paths (e.g. "gov.hmrc.income_tax.rates.uk[0].rate"), + returns the full metadata for each matching parameter. Names that don't match + any parameter are silently omitted from the response. + + Use this to fetch metadata for a known set of parameters (e.g. all parameters + referenced in a user's saved policy) without loading the entire parameter catalog. + """ + if not request.names: + return [] + + model_name = COUNTRY_MODEL_NAMES[request.country_id] + + query = ( + select(Parameter) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(Parameter.name.in_(request.names)) + .order_by(Parameter.name) + ) + + return session.exec(query).all() + + +class ParameterChild(BaseModel): + """A single child in the parameter tree.""" + + path: str + label: str + type: Literal["node", "parameter"] + child_count: int | None = None + parameter: ParameterRead | None = None + + +class ParameterChildrenResponse(BaseModel): + """Response for the parameter children endpoint.""" + + parent_path: str + children: list[ParameterChild] + + +@router.get("/children", response_model=ParameterChildrenResponse) +def get_parameter_children( + country_id: CountryId = Query(description='Country ID ("us" or "uk")'), + parent_path: str = Query( + default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')" + ), + session: Session = Depends(get_session), +) -> ParameterChildrenResponse: + """Get direct children of a parameter path for tree navigation. + + Returns both intermediate nodes (folders with child_count) and leaf + parameters (with full metadata). Use this to lazily load the parameter + tree one level at a time. + """ + model_name = COUNTRY_MODEL_NAMES[country_id] + prefix = f"{parent_path}." if parent_path else "" + + # Fetch all parameters under this path + query = ( + select(Parameter) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(Parameter.name.startswith(prefix)) + ) + descendants = session.exec(query).all() + + # Group by direct child path + children_map: dict[str, dict] = {} + prefix_len = len(prefix) + + for param in descendants: + remainder = param.name[prefix_len:] + dot_pos = remainder.find(".") + + if dot_pos == -1: + # Direct child (leaf at this level) + child_path = param.name + if child_path not in children_map: + children_map[child_path] = { + "direct_param": None, + "descendant_count": 0, + } + children_map[child_path]["direct_param"] = param + else: + # Deeper descendant — extract direct child segment + segment = remainder[:dot_pos] + child_path = prefix + segment + if child_path not in children_map: + children_map[child_path] = { + "direct_param": None, + "descendant_count": 0, + } + children_map[child_path]["descendant_count"] += 1 + + # Build response + children = [] + for path in sorted(children_map): + info = children_map[path] + if info["descendant_count"] > 0: + # Node: has children below it + direct_param = info["direct_param"] + label = ( + direct_param.label + if direct_param and direct_param.label + else path.rsplit(".", 1)[-1] + ) + children.append( + ParameterChild( + path=path, + label=label, + type="node", + child_count=info["descendant_count"], + ) + ) + elif info["direct_param"]: + # Leaf parameter + param = info["direct_param"] + children.append( + ParameterChild( + path=path, + label=param.label or path.rsplit(".", 1)[-1], + type="parameter", + parameter=ParameterRead.model_validate(param), + ) + ) + + return ParameterChildrenResponse(parent_path=parent_path, children=children) + + @router.get("/{parameter_id}", response_model=ParameterRead) def get_parameter(parameter_id: UUID, session: Session = Depends(get_session)): """Get a specific parameter.""" diff --git a/src/policyengine_api/api/tax_benefit_models.py b/src/policyengine_api/api/tax_benefit_models.py index 5dda3a1..b4d4921 100644 --- a/src/policyengine_api/api/tax_benefit_models.py +++ b/src/policyengine_api/api/tax_benefit_models.py @@ -9,9 +9,16 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.models import TaxBenefitModel, TaxBenefitModelRead +from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.models import ( + TaxBenefitModel, + TaxBenefitModelRead, + TaxBenefitModelVersion, + TaxBenefitModelVersionRead, +) from policyengine_api.services.database import get_session router = APIRouter(prefix="/tax-benefit-models", tags=["tax-benefit-models"]) @@ -28,6 +35,55 @@ def list_tax_benefit_models(session: Session = Depends(get_session)): return models +class ModelByCountryResponse(BaseModel): + """Response for the model-by-country endpoint.""" + + model: TaxBenefitModelRead + latest_version: TaxBenefitModelVersionRead + + +@router.get( + "/by-country/{country_id}", + response_model=ModelByCountryResponse, +) +def get_model_by_country( + country_id: CountryId, + session: Session = Depends(get_session), +): + """Get a tax-benefit model and its latest version by country ID. + + Returns the model metadata and the most recently created version in a + single response. Use this on page load to check the current model version + for cache invalidation. + """ + model_name = COUNTRY_MODEL_NAMES[country_id] + + model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + if not model: + raise HTTPException( + status_code=404, + detail=f"No model found for country '{country_id}'", + ) + + latest_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + if not latest_version: + raise HTTPException( + status_code=404, + detail=f"No versions found for model '{model_name}'", + ) + + return ModelByCountryResponse( + model=TaxBenefitModelRead.model_validate(model), + latest_version=TaxBenefitModelVersionRead.model_validate(latest_version), + ) + + @router.get("/{model_id}", response_model=TaxBenefitModelRead) def get_tax_benefit_model(model_id: UUID, session: Session = Depends(get_session)): """Get a specific tax-benefit model.""" diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index 3c24f3d..04aa512 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -9,8 +9,10 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlmodel import Session, select +from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( TaxBenefitModel, TaxBenefitModelVersion, @@ -67,6 +69,44 @@ def list_variables( return variables +class VariableByNameRequest(BaseModel): + """Request body for looking up variables by name.""" + + names: list[str] + country_id: CountryId + + +@router.post("/by-name", response_model=List[VariableRead]) +def get_variables_by_name( + request: VariableByNameRequest, + session: Session = Depends(get_session), +): + """Look up variables by their exact names. + + Given a list of variable names (e.g. "employment_income", "income_tax"), + returns the full metadata for each matching variable. Names that don't + match any variable are silently omitted from the response. + + Use this to fetch metadata for a known set of variables (e.g. variables + used in a household builder or report output) without loading the entire + variable catalog. + """ + if not request.names: + return [] + + model_name = COUNTRY_MODEL_NAMES[request.country_id] + query = ( + select(Variable) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(Variable.name.in_(request.names)) + .order_by(Variable.name) + ) + + return session.exec(query).all() + + @router.get("/{variable_id}", response_model=VariableRead) def get_variable(variable_id: UUID, session: Session = Depends(get_session)): """Get a specific variable.""" diff --git a/src/policyengine_api/config/constants.py b/src/policyengine_api/config/constants.py index 527ba25..a39d827 100644 --- a/src/policyengine_api/config/constants.py +++ b/src/policyengine_api/config/constants.py @@ -4,3 +4,9 @@ # Countries supported by the API CountryId = Literal["us", "uk"] + +# Mapping from country ID to tax-benefit model name in the database +COUNTRY_MODEL_NAMES: dict[str, str] = { + "uk": "policyengine-uk", + "us": "policyengine-us", +} diff --git a/tests/test_models_by_country.py b/tests/test_models_by_country.py new file mode 100644 index 0000000..f71be93 --- /dev/null +++ b/tests/test_models_by_country.py @@ -0,0 +1,149 @@ +"""Tests for GET /tax-benefit-models/by-country/{country_id} endpoint.""" + +from datetime import datetime, timezone, timedelta + +import pytest + +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _create_model_and_version(session, name, description, version_str, **version_kw): + """Create a model and a single version, return (model, version).""" + model = TaxBenefitModel(name=name, description=description) + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, + version=version_str, + description=f"{name} {version_str}", + **version_kw, + ) + session.add(version) + session.commit() + session.refresh(version) + return model, version + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestModelByCountry: + """Tests for the by-country lookup.""" + + def test_uk_returns_model_and_version(self, client, session): + """country_id=uk returns the UK model and its latest version.""" + _create_model_and_version(session, "policyengine-uk", "UK model", "2.51.0") + + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 200 + data = response.json() + assert data["model"]["name"] == "policyengine-uk" + assert data["latest_version"]["version"] == "2.51.0" + + def test_us_returns_model_and_version(self, client, session): + """country_id=us returns the US model and its latest version.""" + _create_model_and_version(session, "policyengine-us", "US model", "1.20.0") + + response = client.get("/tax-benefit-models/by-country/us") + + assert response.status_code == 200 + data = response.json() + assert data["model"]["name"] == "policyengine-us" + assert data["latest_version"]["version"] == "1.20.0" + + def test_multiple_versions_returns_latest(self, client, session): + """When multiple versions exist, returns the most recently created.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + session.refresh(model) + + old = TaxBenefitModelVersion( + model_id=model.id, + version="2.50.0", + description="Old", + created_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + new = TaxBenefitModelVersion( + model_id=model.id, + version="2.51.0", + description="New", + created_at=datetime(2026, 2, 1, tzinfo=timezone.utc), + ) + session.add(old) + session.add(new) + session.commit() + + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 200 + assert response.json()["latest_version"]["version"] == "2.51.0" + + def test_no_model_returns_404(self, client): + """When the model doesn't exist in the DB, returns 404.""" + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 404 + assert "No model found" in response.json()["detail"] + + def test_model_without_versions_returns_404(self, client, session): + """When the model exists but has no versions, returns 404.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + + response = client.get("/tax-benefit-models/by-country/uk") + + assert response.status_code == 404 + assert "No versions found" in response.json()["detail"] + + def test_invalid_country_id_returns_422(self, client): + """An invalid country_id is rejected by Literal validation.""" + response = client.get("/tax-benefit-models/by-country/fr") + + assert response.status_code == 422 + + def test_response_shape(self, client, session): + """Response contains the expected fields for both model and version.""" + _create_model_and_version(session, "policyengine-uk", "UK model", "2.51.0") + + response = client.get("/tax-benefit-models/by-country/uk") + data = response.json() + + # Model fields + model = data["model"] + assert "id" in model + assert "name" in model + assert "description" in model + assert "created_at" in model + + # Version fields + version = data["latest_version"] + assert "id" in version + assert "version" in version + assert "model_id" in version + assert "description" in version + assert "created_at" in version + + def test_country_isolation(self, client, session): + """UK endpoint doesn't return US model data and vice versa.""" + _create_model_and_version(session, "policyengine-uk", "UK", "2.51.0") + _create_model_and_version(session, "policyengine-us", "US", "1.20.0") + + uk_resp = client.get("/tax-benefit-models/by-country/uk") + us_resp = client.get("/tax-benefit-models/by-country/us") + + assert uk_resp.json()["model"]["name"] == "policyengine-uk" + assert uk_resp.json()["latest_version"]["version"] == "2.51.0" + assert us_resp.json()["model"]["name"] == "policyengine-us" + assert us_resp.json()["latest_version"]["version"] == "1.20.0" diff --git a/tests/test_parameters_by_name.py b/tests/test_parameters_by_name.py new file mode 100644 index 0000000..b9cace6 --- /dev/null +++ b/tests/test_parameters_by_name.py @@ -0,0 +1,241 @@ +"""Tests for POST /parameters/by-name endpoint.""" + +import pytest + +from policyengine_api.models import ( + Parameter, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def us_version(session): + """Create a policyengine-us model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="US v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +def create_parameter(session, model_version, name: str, label: str) -> Parameter: + """Create and persist a Parameter.""" + param = Parameter( + name=name, + label=label, + tax_benefit_model_version_id=model_version.id, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +class TestParametersByName: + """Tests for looking up parameters by their exact names.""" + + def test_returns_matching_parameters(self, client, session, us_version): + """Given known parameter names, returns their full metadata.""" + create_parameter(session, us_version, "gov.tax.rate", "Tax rate") + create_parameter(session, us_version, "gov.tax.threshold", "Threshold") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.tax.rate", "gov.tax.threshold"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + returned_names = {p["name"] for p in data} + assert returned_names == {"gov.tax.rate", "gov.tax.threshold"} + + def test_returns_empty_list_for_empty_names(self, client): + """Given an empty names list, returns an empty list.""" + response = client.post( + "/parameters/by-name", + json={ + "names": [], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_empty_list_for_unknown_names(self, client, session, us_version): + """Given names that don't match any parameter, returns an empty list.""" + create_parameter(session, us_version, "gov.exists", "Exists") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.does_not_exist", "gov.also_missing"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_only_matching_when_mix_of_known_and_unknown( + self, client, session, us_version + ): + """Given a mix of known and unknown names, returns only the known ones.""" + create_parameter(session, us_version, "gov.real", "Real param") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.real", "gov.fake"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.real" + + def test_filters_by_country(self, client, session): + """Parameters from a different country are excluded.""" + # Create two models + model_uk = TaxBenefitModel(name="policyengine-uk", description="UK") + model_us = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model_uk) + session.add(model_us) + session.commit() + session.refresh(model_uk) + session.refresh(model_us) + + ver_uk = TaxBenefitModelVersion( + model_id=model_uk.id, version="1.0", description="UK v1" + ) + ver_us = TaxBenefitModelVersion( + model_id=model_us.id, version="1.0", description="US v1" + ) + session.add(ver_uk) + session.add(ver_us) + session.commit() + session.refresh(ver_uk) + session.refresh(ver_us) + + # Same parameter name in both models + create_parameter(session, ver_uk, "gov.shared_name", "UK version") + create_parameter(session, ver_us, "gov.shared_name", "US version") + + # Request only UK + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.shared_name"], + "country_id": "uk", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["label"] == "UK version" + + def test_response_shape_matches_parameter_read( + self, client, session, us_version + ): + """Returned objects have the same shape as ParameterRead.""" + create_parameter(session, us_version, "gov.shape_test", "Shape test") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.shape_test"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + param = data[0] + assert "id" in param + assert "name" in param + assert "label" in param + assert "created_at" in param + assert "tax_benefit_model_version_id" in param + + def test_results_ordered_by_name(self, client, session, us_version): + """Returned parameters are sorted alphabetically by name.""" + create_parameter(session, us_version, "gov.zzz", "Last") + create_parameter(session, us_version, "gov.aaa", "First") + create_parameter(session, us_version, "gov.mmm", "Middle") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.zzz", "gov.aaa", "gov.mmm"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + names = [p["name"] for p in response.json()] + assert names == ["gov.aaa", "gov.mmm", "gov.zzz"] + + def test_missing_country_id_returns_422(self, client): + """Request without country_id is rejected.""" + response = client.post( + "/parameters/by-name", + json={"names": ["gov.something"]}, + ) + + assert response.status_code == 422 + + def test_invalid_country_id_returns_422(self, client): + """Request with invalid country_id is rejected.""" + response = client.post( + "/parameters/by-name", + json={"names": ["gov.something"], "country_id": "invalid"}, + ) + + assert response.status_code == 422 + + def test_missing_names_field_returns_422(self, client): + """Request without names field is rejected.""" + response = client.post( + "/parameters/by-name", + json={"country_id": "us"}, + ) + + assert response.status_code == 422 + + def test_single_name_lookup(self, client, session, us_version): + """Looking up a single parameter name works.""" + create_parameter(session, us_version, "gov.single", "Single param") + + response = client.post( + "/parameters/by-name", + json={ + "names": ["gov.single"], + "country_id": "us", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "gov.single" diff --git a/tests/test_parameters_children.py b/tests/test_parameters_children.py new file mode 100644 index 0000000..3db4f01 --- /dev/null +++ b/tests/test_parameters_children.py @@ -0,0 +1,404 @@ +"""Tests for GET /parameters/children endpoint.""" + +import pytest + +from policyengine_api.models import ( + Parameter, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def uk_version(session): + """Create a policyengine-uk model and version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="UK v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +@pytest.fixture +def us_version(session): + """Create a policyengine-us model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="US v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +def _add_params(session, version, names_and_labels): + """Bulk-add parameters. names_and_labels is [(name, label), ...].""" + for name, label in names_and_labels: + session.add( + Parameter( + name=name, + label=label, + tax_benefit_model_version_id=version.id, + ) + ) + session.commit() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestParameterChildrenBasic: + """Basic tree structure tests.""" + + def test_returns_nodes_for_intermediate_paths(self, client, session, uk_version): + """Parameters at gov.hmrc.x and gov.dwp.x produce nodes for hmrc and dwp.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.income_tax.rate", "Basic rate"), + ("gov.hmrc.income_tax.threshold", "Threshold"), + ("gov.dwp.uc.amount", "UC amount"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["parent_path"] == "gov" + children = data["children"] + assert len(children) == 2 + paths = [c["path"] for c in children] + assert paths == ["gov.dwp", "gov.hmrc"] + for child in children: + assert child["type"] == "node" + assert child["child_count"] > 0 + + def test_returns_leaf_parameters(self, client, session, uk_version): + """Direct child parameters are returned with type='parameter'.""" + _add_params( + session, + uk_version, + [ + ("gov.benefit_uprating_cpi", "Benefit uprating CPI"), + ("gov.hmrc.income_tax.rate", "Basic rate"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + assert response.status_code == 200 + children = response.json()["children"] + assert len(children) == 2 + + leaf = next(c for c in children if c["type"] == "parameter") + assert leaf["path"] == "gov.benefit_uprating_cpi" + assert leaf["label"] == "Benefit uprating CPI" + assert leaf["parameter"] is not None + assert leaf["parameter"]["name"] == "gov.benefit_uprating_cpi" + + node = next(c for c in children if c["type"] == "node") + assert node["path"] == "gov.hmrc" + + def test_mixed_nodes_and_leaves(self, client, session, uk_version): + """Both nodes and leaf parameters can appear at the same level.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.tax.rate", "Rate"), + ("gov.flat_rate", "Flat rate"), + ("gov.threshold", "Threshold"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + types = {c["path"]: c["type"] for c in children} + assert types["gov.hmrc"] == "node" + assert types["gov.flat_rate"] == "parameter" + assert types["gov.threshold"] == "parameter" + + +class TestChildCount: + """Tests for child_count accuracy.""" + + def test_child_count_reflects_total_descendants(self, client, session, uk_version): + """child_count counts all leaf parameters under the node.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.income_tax.rate", "Rate"), + ("gov.hmrc.income_tax.threshold", "Threshold"), + ("gov.hmrc.ni.rate", "NI rate"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + hmrc = children[0] + assert hmrc["path"] == "gov.hmrc" + assert hmrc["child_count"] == 3 + + def test_nested_child_count(self, client, session, uk_version): + """Querying a deeper level gives accurate child counts.""" + _add_params( + session, + uk_version, + [ + ("gov.hmrc.income_tax.rate", "Rate"), + ("gov.hmrc.income_tax.threshold", "Threshold"), + ("gov.hmrc.ni.rate", "NI rate"), + ], + ) + + response = client.get( + "/parameters/children", + params={"country_id": "uk", "parent_path": "gov.hmrc"}, + ) + + children = response.json()["children"] + assert len(children) == 2 + income_tax = next(c for c in children if c["path"] == "gov.hmrc.income_tax") + ni = next(c for c in children if c["path"] == "gov.hmrc.ni") + assert income_tax["child_count"] == 2 + assert ni["child_count"] == 1 + + def test_leaf_has_no_child_count(self, client, session, uk_version): + """Leaf parameters have child_count=None.""" + _add_params(session, uk_version, [("gov.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + assert len(children) == 1 + assert children[0]["child_count"] is None + + +class TestCountryFiltering: + """Tests for country_id filtering.""" + + def test_uk_country_id(self, client, session, uk_version): + """country_id=uk returns UK parameters.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + assert response.status_code == 200 + assert len(response.json()["children"]) == 1 + + def test_us_country_id(self, client, session, us_version): + """country_id=us returns US parameters.""" + _add_params(session, us_version, [("gov.irs.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "us", "parent_path": "gov"} + ) + + assert response.status_code == 200 + assert len(response.json()["children"]) == 1 + + def test_country_isolation(self, client, session, uk_version, us_version): + """Parameters from a different country are excluded.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "UK rate")]) + _add_params(session, us_version, [("gov.irs.rate", "US rate")]) + + uk_response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + us_response = client.get( + "/parameters/children", params={"country_id": "us", "parent_path": "gov"} + ) + + uk_paths = [c["path"] for c in uk_response.json()["children"]] + us_paths = [c["path"] for c in us_response.json()["children"]] + assert uk_paths == ["gov.hmrc"] + assert us_paths == ["gov.irs"] + + def test_invalid_country_id_returns_422(self, client): + """An invalid country_id is rejected by Literal validation.""" + response = client.get( + "/parameters/children", + params={"country_id": "fr", "parent_path": "gov"}, + ) + + assert response.status_code == 422 + + +class TestEdgeCases: + """Tests for edge cases and special inputs.""" + + def test_empty_parent_path(self, client, session, uk_version): + """Empty parent_path returns top-level children.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": ""} + ) + + assert response.status_code == 200 + children = response.json()["children"] + assert len(children) == 1 + assert children[0]["path"] == "gov" + assert children[0]["type"] == "node" + + def test_nonexistent_parent_returns_empty(self, client, session, uk_version): + """A parent path with no descendants returns empty children list.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", + params={"country_id": "uk", "parent_path": "gov.dwp"}, + ) + + assert response.status_code == 200 + assert response.json()["children"] == [] + + def test_children_sorted_by_path(self, client, session, uk_version): + """Children are returned sorted alphabetically by path.""" + _add_params( + session, + uk_version, + [ + ("gov.zzz.param", "Z param"), + ("gov.aaa.param", "A param"), + ("gov.mmm.param", "M param"), + ], + ) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + paths = [c["path"] for c in response.json()["children"]] + assert paths == ["gov.aaa", "gov.mmm", "gov.zzz"] + + def test_node_label_from_path_segment(self, client, session, uk_version): + """Node labels default to the last path segment when no parameter exists.""" + _add_params(session, uk_version, [("gov.hmrc.income_tax.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + children = response.json()["children"] + assert children[0]["label"] == "hmrc" + + def test_missing_country_id_returns_422(self, client): + """Request without country_id returns 422.""" + response = client.get( + "/parameters/children", params={"parent_path": "gov"} + ) + + assert response.status_code == 422 + + def test_default_parent_path_is_empty(self, client, session, uk_version): + """Omitting parent_path defaults to empty string (root level).""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk"} + ) + + assert response.status_code == 200 + assert response.json()["parent_path"] == "" + assert len(response.json()["children"]) == 1 + + def test_leaf_parameter_includes_full_metadata(self, client, session, uk_version): + """Leaf parameters include the full ParameterRead shape.""" + _add_params(session, uk_version, [("gov.rate", "The rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + param = response.json()["children"][0]["parameter"] + assert param["name"] == "gov.rate" + assert param["label"] == "The rate" + assert "id" in param + assert "created_at" in param + assert "tax_benefit_model_version_id" in param + + def test_node_has_no_parameter_field(self, client, session, uk_version): + """Nodes do not include the parameter field.""" + _add_params(session, uk_version, [("gov.hmrc.rate", "Rate")]) + + response = client.get( + "/parameters/children", params={"country_id": "uk", "parent_path": "gov"} + ) + + node = response.json()["children"][0] + assert node["type"] == "node" + assert node["parameter"] is None + + def test_deep_nesting(self, client, session, uk_version): + """Works correctly with deeply nested parameter paths.""" + _add_params( + session, + uk_version, + [("gov.hmrc.income_tax.rates.uk[0].rate", "Basic rate")], + ) + + # Each level should show the correct child + for parent, expected_child in [ + ("gov", "gov.hmrc"), + ("gov.hmrc", "gov.hmrc.income_tax"), + ("gov.hmrc.income_tax", "gov.hmrc.income_tax.rates"), + ("gov.hmrc.income_tax.rates", "gov.hmrc.income_tax.rates.uk[0]"), + ]: + resp = client.get( + "/parameters/children", + params={"country_id": "uk", "parent_path": parent}, + ) + children = resp.json()["children"] + assert len(children) == 1 + assert children[0]["path"] == expected_child + assert children[0]["type"] == "node" + + # Final level should be a leaf + resp = client.get( + "/parameters/children", + params={ + "country_id": "uk", + "parent_path": "gov.hmrc.income_tax.rates.uk[0]", + }, + ) + children = resp.json()["children"] + assert len(children) == 1 + assert children[0]["type"] == "parameter" + assert children[0]["path"] == "gov.hmrc.income_tax.rates.uk[0].rate" diff --git a/tests/test_variables_by_name.py b/tests/test_variables_by_name.py new file mode 100644 index 0000000..2c408e6 --- /dev/null +++ b/tests/test_variables_by_name.py @@ -0,0 +1,229 @@ +"""Tests for POST /variables/by-name endpoint.""" + +import pytest + +from policyengine_api.models import ( + TaxBenefitModel, + TaxBenefitModelVersion, + Variable, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def uk_version(session): + """Create a policyengine-uk model and version.""" + model = TaxBenefitModel(name="policyengine-uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="UK v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +@pytest.fixture +def us_version(session): + """Create a policyengine-us model and version.""" + model = TaxBenefitModel(name="policyengine-us", description="US model") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="US v1" + ) + session.add(version) + session.commit() + session.refresh(version) + return version + + +def _add_var(session, version, name, entity="person", description=None): + """Create and persist a Variable.""" + var = Variable( + name=name, + entity=entity, + description=description, + tax_benefit_model_version_id=version.id, + ) + session.add(var) + session.commit() + session.refresh(var) + return var + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestVariablesByName: + """Tests for looking up variables by their exact names.""" + + def test_returns_matching_variables(self, client, session, uk_version): + """Given known variable names, returns their full metadata.""" + _add_var(session, uk_version, "employment_income") + _add_var(session, uk_version, "income_tax") + + response = client.post( + "/variables/by-name", + json={"names": ["employment_income", "income_tax"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + returned_names = {v["name"] for v in data} + assert returned_names == {"employment_income", "income_tax"} + + def test_returns_empty_list_for_empty_names(self, client): + """Given an empty names list, returns an empty list.""" + response = client.post( + "/variables/by-name", + json={"names": [], "country_id": "uk"}, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_empty_list_for_unknown_names(self, client, session, uk_version): + """Given names that don't match any variable, returns an empty list.""" + _add_var(session, uk_version, "employment_income") + + response = client.post( + "/variables/by-name", + json={"names": ["nonexistent_var", "also_missing"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_returns_only_matching_when_mix_of_known_and_unknown( + self, client, session, uk_version + ): + """Given a mix of known and unknown names, returns only the known ones.""" + _add_var(session, uk_version, "income_tax") + + response = client.post( + "/variables/by-name", + json={"names": ["income_tax", "fake_var"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "income_tax" + + def test_single_name_lookup(self, client, session, uk_version): + """Looking up a single variable name works.""" + _add_var(session, uk_version, "age") + + response = client.post( + "/variables/by-name", + json={"names": ["age"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["name"] == "age" + + def test_results_ordered_by_name(self, client, session, uk_version): + """Returned variables are sorted alphabetically by name.""" + _add_var(session, uk_version, "zzz_var") + _add_var(session, uk_version, "aaa_var") + _add_var(session, uk_version, "mmm_var") + + response = client.post( + "/variables/by-name", + json={ + "names": ["zzz_var", "aaa_var", "mmm_var"], + "country_id": "uk", + }, + ) + + assert response.status_code == 200 + names = [v["name"] for v in response.json()] + assert names == ["aaa_var", "mmm_var", "zzz_var"] + + def test_response_shape_matches_variable_read(self, client, session, uk_version): + """Returned objects have the same shape as VariableRead.""" + _add_var(session, uk_version, "income_tax", entity="person", description="Tax") + + response = client.post( + "/variables/by-name", + json={"names": ["income_tax"], "country_id": "uk"}, + ) + + assert response.status_code == 200 + var = response.json()[0] + assert "id" in var + assert "name" in var + assert "entity" in var + assert "description" in var + assert "created_at" in var + assert "tax_benefit_model_version_id" in var + + +class TestVariablesByNameCountryFiltering: + """Tests for country_id filtering.""" + + def test_country_isolation(self, client, session, uk_version, us_version): + """Variables from a different country are excluded.""" + _add_var(session, uk_version, "council_tax") + _add_var(session, us_version, "state_income_tax") + + uk_response = client.post( + "/variables/by-name", + json={"names": ["council_tax", "state_income_tax"], "country_id": "uk"}, + ) + us_response = client.post( + "/variables/by-name", + json={"names": ["council_tax", "state_income_tax"], "country_id": "us"}, + ) + + assert len(uk_response.json()) == 1 + assert uk_response.json()[0]["name"] == "council_tax" + assert len(us_response.json()) == 1 + assert us_response.json()[0]["name"] == "state_income_tax" + + def test_invalid_country_id_returns_422(self, client): + """An invalid country_id is rejected.""" + response = client.post( + "/variables/by-name", + json={"names": ["income_tax"], "country_id": "fr"}, + ) + + assert response.status_code == 422 + + +class TestVariablesByNameValidation: + """Tests for request validation.""" + + def test_missing_country_id_returns_422(self, client): + """Request without country_id is rejected.""" + response = client.post( + "/variables/by-name", + json={"names": ["income_tax"]}, + ) + + assert response.status_code == 422 + + def test_missing_names_field_returns_422(self, client): + """Request without names field is rejected.""" + response = client.post( + "/variables/by-name", + json={"country_id": "uk"}, + ) + + assert response.status_code == 422