Skip to content

Commit 0ea65e4

Browse files
authored
Merge pull request #124 from PolicyEngine/feat/parameter-nodes
feat: Add parameter_nodes table for folder/category labels
2 parents bb13933 + 3e01c82 commit 0ea65e4

7 files changed

Lines changed: 198 additions & 11 deletions

File tree

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Add parameter_nodes table
2+
3+
Revision ID: 67608331ee8a
4+
Revises: add_modelled_policies
5+
Create Date: 2026-03-10 18:29:54.555074
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
import sqlalchemy as sa
12+
import sqlmodel.sql.sqltypes
13+
14+
from alembic import op
15+
16+
# revision identifiers, used by Alembic.
17+
revision: str = "67608331ee8a"
18+
down_revision: Union[str, Sequence[str], None] = "886921687770"
19+
branch_labels: Union[str, Sequence[str], None] = None
20+
depends_on: Union[str, Sequence[str], None] = None
21+
22+
23+
def upgrade() -> None:
24+
"""Upgrade schema."""
25+
op.create_table(
26+
"parameter_nodes",
27+
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
28+
sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
29+
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
30+
sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False),
31+
sa.Column("id", sa.Uuid(), nullable=False),
32+
sa.Column("created_at", sa.DateTime(), nullable=False),
33+
sa.ForeignKeyConstraint(
34+
["tax_benefit_model_version_id"],
35+
["tax_benefit_model_versions.id"],
36+
),
37+
sa.PrimaryKeyConstraint("id"),
38+
)
39+
40+
41+
def downgrade() -> None:
42+
"""Downgrade schema."""
43+
op.drop_table("parameter_nodes")

changelog.d/124.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add parameter_nodes table to store folder/category labels for parameter tree navigation

scripts/seed_models.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
"""Seed tax-benefit models with variables and parameters.
1+
"""Seed tax-benefit models with variables, parameters, and parameter nodes.
22
33
This script seeds TaxBenefitModel, TaxBenefitModelVersion, Variables,
4-
Parameters, and ParameterValues from policyengine.py.
4+
Parameters, ParameterValues, and ParameterNodes from policyengine.py.
55
66
Usage:
77
python scripts/seed_models.py # Seed UK and US models
@@ -55,10 +55,10 @@ def seed_model(
5555
variable_whitelist: set[str] | None = None,
5656
parameter_prefixes: set[str] | None = None,
5757
) -> TaxBenefitModelVersion:
58-
"""Seed a tax-benefit model with its variables and parameters.
58+
"""Seed a tax-benefit model with its variables, parameters, and parameter nodes.
5959
6060
Args:
61-
model_version: The policyengine.py model version object
61+
model_version: The policyengine.py model version object (with parameter_nodes)
6262
session: Database session
6363
skip_state_params: Skip US state-level parameters (gov.states.*)
6464
variable_whitelist: If provided, only seed variables whose name is in this set
@@ -338,6 +338,72 @@ def seed_model(
338338
+ (f" (skipped {skipped} invalid)" if skipped else "")
339339
)
340340

341+
# Add parameter nodes (folder/category structure)
342+
# Uses model_version.parameter_nodes exposed by policyengine.py
343+
parameter_nodes = model_version.parameter_nodes
344+
345+
# Filter by prefix if specified (same as parameters)
346+
if parameter_prefixes is not None:
347+
parameter_nodes = [
348+
n
349+
for n in parameter_nodes
350+
if any(n.name.startswith(prefix) for prefix in parameter_prefixes)
351+
]
352+
353+
# Deduplicate by name
354+
seen_node_names = set()
355+
nodes_to_add = []
356+
for node in parameter_nodes:
357+
if node.name not in seen_node_names:
358+
nodes_to_add.append(node)
359+
seen_node_names.add(node.name)
360+
361+
console.print(f" Found {len(nodes_to_add)} parameter nodes (folder structure)")
362+
363+
with logfire.span("add_parameter_nodes", count=len(nodes_to_add)):
364+
node_rows = []
365+
366+
with Progress(
367+
SpinnerColumn(),
368+
TextColumn("[progress.description]{task.description}"),
369+
console=console,
370+
) as progress:
371+
task = progress.add_task(
372+
f"Preparing {len(nodes_to_add)} parameter nodes",
373+
total=len(nodes_to_add),
374+
)
375+
for node in nodes_to_add:
376+
node_rows.append(
377+
{
378+
"id": uuid4(),
379+
"name": node.name,
380+
"label": node.label,
381+
"description": node.description or "",
382+
"tax_benefit_model_version_id": db_version.id,
383+
"created_at": datetime.now(timezone.utc),
384+
}
385+
)
386+
progress.advance(task)
387+
388+
console.print(f" Inserting {len(node_rows)} parameter nodes...")
389+
bulk_insert(
390+
session,
391+
"parameter_nodes",
392+
[
393+
"id",
394+
"name",
395+
"label",
396+
"description",
397+
"tax_benefit_model_version_id",
398+
"created_at",
399+
],
400+
node_rows,
401+
)
402+
403+
console.print(
404+
f" [green]✓[/green] Added {len(nodes_to_add)} parameter nodes"
405+
)
406+
341407
return db_version
342408

343409

src/policyengine_api/api/parameters.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId
1818
from policyengine_api.models import (
1919
Parameter,
20+
ParameterNode,
2021
ParameterRead,
2122
TaxBenefitModel,
2223
TaxBenefitModelVersion,
@@ -144,14 +145,27 @@ def get_parameter_children(
144145
prefix = f"{parent_path}." if parent_path else ""
145146

146147
# Fetch all parameters under this path
147-
query = (
148+
param_query = (
148149
select(Parameter)
149150
.join(TaxBenefitModelVersion)
150151
.join(TaxBenefitModel)
151152
.where(TaxBenefitModel.name == model_name)
152153
.where(Parameter.name.startswith(prefix))
153154
)
154-
descendants = session.exec(query).all()
155+
descendants = session.exec(param_query).all()
156+
157+
# Fetch all parameter nodes under this path for labels
158+
node_query = (
159+
select(ParameterNode)
160+
.join(TaxBenefitModelVersion)
161+
.join(TaxBenefitModel)
162+
.where(TaxBenefitModel.name == model_name)
163+
.where(ParameterNode.name.startswith(prefix))
164+
)
165+
nodes = session.exec(node_query).all()
166+
167+
# Build a map of node path -> label for quick lookup
168+
node_labels: dict[str, str | None] = {node.name: node.label for node in nodes}
155169

156170
# Group by direct child path
157171
children_map: dict[str, dict] = {}
@@ -187,12 +201,13 @@ def get_parameter_children(
187201
info = children_map[path]
188202
if info["descendant_count"] > 0:
189203
# Node: has children below it
204+
# Priority: 1) parameter_nodes label, 2) direct_param label, 3) path segment
190205
direct_param = info["direct_param"]
191-
label = (
192-
direct_param.label
193-
if direct_param and direct_param.label
194-
else path.rsplit(".", 1)[-1]
195-
)
206+
label = node_labels.get(path)
207+
if not label and direct_param and direct_param.label:
208+
label = direct_param.label
209+
if not label:
210+
label = path.rsplit(".", 1)[-1]
196211
children.append(
197212
ParameterChild(
198213
path=path,

src/policyengine_api/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
AggregateType,
5555
)
5656
from .parameter import Parameter, ParameterCreate, ParameterRead
57+
from .parameter_node import ParameterNode, ParameterNodeCreate, ParameterNodeRead
5758
from .parameter_value import (
5859
ParameterValue,
5960
ParameterValueCreate,
@@ -166,6 +167,9 @@
166167
"IntraDecileImpactRead",
167168
"Parameter",
168169
"ParameterCreate",
170+
"ParameterNode",
171+
"ParameterNodeCreate",
172+
"ParameterNodeRead",
169173
"ParameterRead",
170174
"ParameterValue",
171175
"ParameterValueCreate",
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from datetime import datetime, timezone
2+
from typing import TYPE_CHECKING
3+
from uuid import UUID, uuid4
4+
5+
from sqlmodel import Field, Relationship, SQLModel
6+
7+
if TYPE_CHECKING:
8+
from .tax_benefit_model_version import TaxBenefitModelVersion
9+
10+
11+
class ParameterNodeBase(SQLModel):
12+
"""Base parameter node fields.
13+
14+
Parameter nodes represent folder/category nodes in the parameter hierarchy
15+
(e.g., "gov", "gov.hmrc", "gov.hmrc.income_tax"). They provide structure
16+
and human-readable labels for navigating the parameter tree, but don't
17+
have values themselves.
18+
"""
19+
20+
name: str = Field(description="Full path of the node (e.g., 'gov.hmrc')")
21+
label: str | None = Field(
22+
default=None, description="Human-readable label (e.g., 'HMRC')"
23+
)
24+
description: str | None = Field(default=None, description="Node description")
25+
tax_benefit_model_version_id: UUID = Field(
26+
foreign_key="tax_benefit_model_versions.id"
27+
)
28+
29+
30+
class ParameterNode(ParameterNodeBase, table=True):
31+
"""Parameter node database model."""
32+
33+
__tablename__ = "parameter_nodes"
34+
35+
id: UUID = Field(default_factory=uuid4, primary_key=True)
36+
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
37+
38+
# Relationships
39+
tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship(
40+
back_populates="parameter_nodes"
41+
)
42+
43+
44+
class ParameterNodeCreate(ParameterNodeBase):
45+
"""Schema for creating parameter nodes."""
46+
47+
pass
48+
49+
50+
class ParameterNodeRead(ParameterNodeBase):
51+
"""Schema for reading parameter nodes."""
52+
53+
id: UUID
54+
created_at: datetime

src/policyengine_api/models/tax_benefit_model_version.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
if TYPE_CHECKING:
88
from .parameter import Parameter
9+
from .parameter_node import ParameterNode
910
from .tax_benefit_model import TaxBenefitModel
1011
from .variable import Variable
1112

@@ -34,6 +35,9 @@ class TaxBenefitModelVersion(TaxBenefitModelVersionBase, table=True):
3435
parameters: list["Parameter"] = Relationship(
3536
back_populates="tax_benefit_model_version"
3637
)
38+
parameter_nodes: list["ParameterNode"] = Relationship(
39+
back_populates="tax_benefit_model_version"
40+
)
3741

3842

3943
class TaxBenefitModelVersionCreate(TaxBenefitModelVersionBase):

0 commit comments

Comments
 (0)