Skip to content

Commit 41a493d

Browse files
committed
Format YAML security hardening changes
1 parent 655784c commit 41a493d

5 files changed

Lines changed: 219 additions & 23 deletions

File tree

policyengine_core/parameters/config.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -33,15 +33,23 @@ def date_constructor(_loader, node):
3333

3434

3535
def dict_no_duplicate_constructor(loader, node, deep=False):
36-
keys = [key.value for key, value in node.value]
37-
38-
if len(keys) != len(set(keys)):
39-
duplicate = next((key for key in keys if keys.count(key) > 1))
40-
raise yaml.parser.ParserError(
41-
"", node.start_mark, f"Found duplicate key '{duplicate}'"
42-
)
43-
44-
return loader.construct_mapping(node, deep)
36+
loader.flatten_mapping(node)
37+
pairs = loader.construct_pairs(node, deep=deep)
38+
mapping = {}
39+
40+
for key, value in pairs:
41+
try:
42+
if key in mapping:
43+
raise yaml.parser.ParserError(
44+
"", node.start_mark, f"Found duplicate key '{key}'"
45+
)
46+
except TypeError as exc:
47+
raise yaml.constructor.ConstructorError(
48+
"", node.start_mark, f"Found unhashable key '{key}'"
49+
) from exc
50+
mapping[key] = value
51+
52+
return mapping
4553

4654

4755
yaml.add_constructor(

policyengine_core/parameters/operations/homogenize_parameters.py

Lines changed: 43 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
from policyengine_core.parameters.parameter_node import ParameterNode
88
from policyengine_core.variables import Variable
99

10+
MAX_DYNAMIC_BREAKDOWN_VALUES = 10_000
11+
1012

1113
def homogenize_parameter_structures(
1214
root: ParameterNode, variables: Dict[str, Variable], default_value: Any = 0
@@ -44,6 +46,11 @@ def get_breakdown_variables(node: ParameterNode) -> List[str]:
4446
f"Invalid breakdown metadata for parameter {node.name}: {type(breakdown)}"
4547
)
4648
return None
49+
if len(breakdown) == 0:
50+
logging.warning(
51+
f"Invalid breakdown metadata for parameter {node.name}: empty list"
52+
)
53+
return None
4754
return breakdown
4855
else:
4956
return None
@@ -131,41 +138,72 @@ def evaluate_dynamic_breakdown(expression: str) -> List[Any]:
131138
parsed = ast.parse(expression, mode="eval")
132139
evaluated = evaluate_dynamic_breakdown_node(parsed.body)
133140
if isinstance(evaluated, range):
141+
validate_dynamic_breakdown_range_cardinality(evaluated, expression)
134142
return list(evaluated)
135143
if isinstance(evaluated, (list, tuple)):
144+
validate_dynamic_breakdown_cardinality(len(evaluated), expression)
136145
return list(evaluated)
137146
if isinstance(evaluated, set):
147+
validate_dynamic_breakdown_cardinality(len(evaluated), expression)
138148
return list(evaluated)
139149
raise ValueError(
140150
f"Invalid dynamic breakdown expression '{expression}'. "
141151
"Only literal collections and range() calls are allowed."
142152
)
143153

144154

155+
def validate_dynamic_breakdown_cardinality(count: int, expression: str) -> None:
156+
if count > MAX_DYNAMIC_BREAKDOWN_VALUES:
157+
raise ValueError(
158+
f"Dynamic breakdown expression '{expression}' produces {count} values, "
159+
f"which exceeds the maximum of {MAX_DYNAMIC_BREAKDOWN_VALUES}."
160+
)
161+
162+
163+
def validate_dynamic_breakdown_range_cardinality(
164+
values: range, expression: str
165+
) -> None:
166+
try:
167+
count = len(values)
168+
except OverflowError as exc:
169+
raise ValueError(
170+
f"Dynamic breakdown expression '{expression}' produces too many values."
171+
) from exc
172+
validate_dynamic_breakdown_cardinality(count, expression)
173+
174+
145175
def evaluate_dynamic_breakdown_node(node: ast.AST) -> Any:
146176
if isinstance(node, ast.Constant):
147177
return node.value
148178
if isinstance(node, ast.List):
179+
validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node))
149180
return [evaluate_dynamic_breakdown_node(element) for element in node.elts]
150181
if isinstance(node, ast.Tuple):
182+
validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node))
151183
return tuple(evaluate_dynamic_breakdown_node(element) for element in node.elts)
152184
if isinstance(node, ast.Set):
185+
validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node))
153186
return {evaluate_dynamic_breakdown_node(element) for element in node.elts}
154-
if isinstance(node, ast.UnaryOp) and isinstance(
155-
node.op, (ast.UAdd, ast.USub)
156-
):
187+
if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
157188
operand = evaluate_dynamic_breakdown_node(node.operand)
158189
return operand if isinstance(node.op, ast.UAdd) else -operand
159190
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
160191
if node.func.id == "range":
161192
args = [evaluate_dynamic_breakdown_node(arg) for arg in node.args]
162193
if node.keywords:
163194
raise ValueError("range() keyword arguments are not allowed")
164-
return range(*args)
195+
result = range(*args)
196+
validate_dynamic_breakdown_range_cardinality(result, ast.unparse(node))
197+
return result
165198
if node.func.id == "list":
166199
if len(node.args) != 1 or node.keywords:
167200
raise ValueError("list() must contain a single positional argument")
168-
return list(evaluate_dynamic_breakdown_node(node.args[0]))
201+
evaluated = evaluate_dynamic_breakdown_node(node.args[0])
202+
if isinstance(evaluated, (range, list, tuple, set)):
203+
return evaluated
204+
raise ValueError(
205+
"list() only supports range() and literal collection expressions"
206+
)
169207
raise ValueError(
170208
f"Unsupported dynamic breakdown expression: {ast.unparse(node) if hasattr(ast, 'unparse') else type(node).__name__}"
171209
)

policyengine_core/tools/test_runner.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -119,12 +119,7 @@ def __init__(self, *, tax_benefit_system, options, **kwargs):
119119
def collect(self):
120120
try:
121121
tests = yaml.load(self.path.open(), Loader=Loader)
122-
except (
123-
yaml.scanner.ScannerError,
124-
yaml.parser.ParserError,
125-
yaml.constructor.ConstructorError,
126-
TypeError,
127-
):
122+
except (yaml.YAMLError, TypeError):
128123
message = os.linesep.join(
129124
[
130125
traceback.format_exc(),
@@ -137,6 +132,11 @@ def collect(self):
137132
tests: List[Dict] = [tests]
138133

139134
for test in tests:
135+
if not isinstance(test, dict):
136+
raise ValueError(
137+
f"'{self.path}' is not a valid YAML test file. "
138+
"Expected a mapping or a list of mappings."
139+
)
140140
if not self.should_ignore(test):
141141
yield YamlItem.from_parent(
142142
self,
@@ -148,11 +148,19 @@ def collect(self):
148148

149149
def should_ignore(self, test):
150150
name_filter = self.options.get("name_filter")
151+
keywords = test.get("keywords", [])
152+
if keywords is None:
153+
keywords = []
154+
if not isinstance(keywords, list):
155+
raise ValueError(
156+
f"'{self.path}' is not a valid YAML test file. "
157+
"'keywords' must be a list."
158+
)
151159
return (
152160
name_filter is not None
153161
and name_filter not in os.path.splitext(self.fspath.basename)[0]
154162
and name_filter not in test.get("name", "")
155-
and name_filter not in test.get("keywords", [])
163+
and name_filter not in keywords
156164
)
157165

158166

tests/core/test_parameter_security.py

Lines changed: 61 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@
33
from policyengine_core.errors import ParameterParsingError
44
from policyengine_core.parameters import ParameterNode, homogenize_parameter_structures
55
from policyengine_core.parameters.helpers import _load_yaml_file
6+
from policyengine_core.parameters.operations.homogenize_parameters import (
7+
MAX_DYNAMIC_BREAKDOWN_VALUES,
8+
)
69

710

811
def test_parameter_yaml_loader_rejects_python_object_tags(tmp_path, monkeypatch):
@@ -32,8 +35,9 @@ def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code(
3235

3336
monkeypatch.setattr(
3437
"builtins.eval",
35-
lambda expression, globals=None, locals=None: eval_calls.append(expression)
36-
or range(1, 4),
38+
lambda expression, globals=None, locals=None: (
39+
eval_calls.append(expression) or range(1, 4)
40+
),
3741
)
3842

3943
root = ParameterNode(
@@ -50,3 +54,58 @@ def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code(
5054
homogenize_parameter_structures(root, {}, default_value=0)
5155

5256
assert eval_calls == []
57+
58+
59+
def test_homogenize_parameter_structures_rejects_oversized_dynamic_breakdown():
60+
root = ParameterNode(
61+
data={
62+
"value_by_category": {
63+
"metadata": {
64+
"breakdown": [f"list(range({MAX_DYNAMIC_BREAKDOWN_VALUES + 1}))"],
65+
},
66+
}
67+
}
68+
)
69+
70+
with pytest.raises(ValueError, match="exceeds the maximum"):
71+
homogenize_parameter_structures(root, {}, default_value=0)
72+
73+
74+
def test_homogenize_parameter_structures_rejects_overflowing_dynamic_breakdown():
75+
huge_stop = "1" + ("0" * 100)
76+
root = ParameterNode(
77+
data={
78+
"value_by_category": {
79+
"metadata": {
80+
"breakdown": [f"range(0, {huge_stop})"],
81+
},
82+
}
83+
}
84+
)
85+
86+
with pytest.raises(ValueError, match="too many values"):
87+
homogenize_parameter_structures(root, {}, default_value=0)
88+
89+
90+
def test_parameter_yaml_loader_rejects_implicit_duplicate_keys(tmp_path):
91+
yaml_path = tmp_path / "duplicate-bools.yaml"
92+
yaml_path.write_text("true: 1\nTrue: 2\n", encoding="utf-8")
93+
94+
with pytest.raises(ParameterParsingError, match="duplicate key"):
95+
_load_yaml_file(str(yaml_path))
96+
97+
98+
def test_homogenize_parameter_structures_ignores_empty_breakdown_lists():
99+
root = ParameterNode(
100+
data={
101+
"value_by_category": {
102+
"metadata": {
103+
"breakdown": [],
104+
},
105+
}
106+
}
107+
)
108+
109+
result = homogenize_parameter_structures(root, {}, default_value=0)
110+
111+
assert result is root

tests/core/tools/test_runner/test_yaml_runner.py

Lines changed: 83 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,89 @@ def test_yaml_runner_rejects_python_object_tags(tmp_path, monkeypatch):
227227
assert calls == []
228228

229229

230+
def test_yaml_runner_wraps_composer_errors(tmp_path):
231+
yaml_path = tmp_path / "invalid-anchor.yaml"
232+
yaml_path.write_text("value: *missing_anchor\n", encoding="utf-8")
233+
234+
invalid_yaml_file = object.__new__(YamlFile)
235+
invalid_yaml_file.path = yaml_path
236+
invalid_yaml_file.options = {}
237+
invalid_yaml_file.tax_benefit_system = TaxBenefitSystem()
238+
239+
with pytest.raises(ValueError, match="not a valid YAML file"):
240+
list(invalid_yaml_file.collect())
241+
242+
243+
def test_yaml_runner_rejects_scalar_roots(tmp_path):
244+
yaml_path = tmp_path / "scalar.yaml"
245+
yaml_path.write_text("foo\n", encoding="utf-8")
246+
247+
scalar_yaml_file = object.__new__(YamlFile)
248+
scalar_yaml_file.path = yaml_path
249+
scalar_yaml_file.options = {}
250+
scalar_yaml_file.tax_benefit_system = TaxBenefitSystem()
251+
252+
with pytest.raises(ValueError, match="list of mappings"):
253+
list(scalar_yaml_file.collect())
254+
255+
256+
def test_yaml_runner_rejects_scalar_keywords(tmp_path):
257+
yaml_path = tmp_path / "invalid-keywords.yaml"
258+
yaml_path.write_text(
259+
"name: Example\nkeywords: 0\noutput: {}\n",
260+
encoding="utf-8",
261+
)
262+
263+
invalid_yaml_file = object.__new__(YamlFile)
264+
invalid_yaml_file.path = yaml_path
265+
invalid_yaml_file.options = {"name_filter": "missing"}
266+
invalid_yaml_file.tax_benefit_system = TaxBenefitSystem()
267+
268+
with pytest.raises(ValueError, match="'keywords' must be a list"):
269+
list(invalid_yaml_file.collect())
270+
271+
272+
def test_yaml_runner_allows_yaml_merge_anchors(tmp_path):
273+
yaml_path = tmp_path / "anchors.yaml"
274+
yaml_path.write_text(
275+
"""
276+
- name: define anchor
277+
input:
278+
persons: &persons
279+
Alicia:
280+
salary: 4000
281+
households:
282+
household:
283+
parents: [Alicia]
284+
output:
285+
salary: 4000
286+
287+
- name: merge anchor
288+
input:
289+
persons:
290+
<<: *persons
291+
households:
292+
household:
293+
parents: [Alicia]
294+
output:
295+
salary: 4000
296+
""".strip(),
297+
encoding="utf-8",
298+
)
299+
300+
yaml_file = object.__new__(YamlFile)
301+
yaml_file.config = None
302+
yaml_file.session = None
303+
yaml_file._nodeid = "anchors"
304+
yaml_file.path = yaml_path
305+
yaml_file.options = {}
306+
yaml_file.tax_benefit_system = TaxBenefitSystem()
307+
308+
collected = list(yaml_file.collect())
309+
310+
assert len(collected) == 2
311+
312+
230313
def clean_performance_files(paths: List[str]):
231314
for path in paths:
232315
if os.path.isfile(path):

0 commit comments

Comments
 (0)