Skip to content

Commit a498061

Browse files
feat: convert parameter schema from voluptuous to msgspec
1 parent c2bd792 commit a498061

3 files changed

Lines changed: 107 additions & 67 deletions

File tree

src/taskgraph/parameters.py

Lines changed: 82 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@
1010
from io import BytesIO
1111
from pprint import pformat
1212
from subprocess import CalledProcessError
13+
from typing import Optional, Union
1314
from unittest.mock import Mock
1415
from urllib.parse import urlparse
1516
from urllib.request import urlopen
1617

1718
import mozilla_repo_urls
18-
from voluptuous import ALLOW_EXTRA, Any, Optional, Required, Schema
19+
import msgspec
1920

2021
from taskgraph.util import json, yaml
2122
from taskgraph.util.readonlydict import ReadOnlyDict
22-
from taskgraph.util.schema import validate_schema
23+
from taskgraph.util.schema import Schema, validate_schema
2324
from taskgraph.util.taskcluster import find_task_id, get_artifact_url
2425
from taskgraph.util.vcs import get_repository
2526

@@ -28,44 +29,52 @@ class ParameterMismatch(Exception):
2829
"""Raised when a parameters.yml has extra or missing parameters."""
2930

3031

32+
class CodeReviewConfig(Schema, kw_only=True):
33+
phabricator_build_target: str
34+
35+
3136
#: Schema for base parameters.
3237
#: Please keep this list sorted and in sync with docs/reference/parameters.rst
33-
base_schema = Schema(
34-
{
35-
Required("base_repository"): str,
36-
Optional("base_ref"): str,
37-
Required("base_rev"): str,
38-
Required("build_date"): int,
39-
Required("build_number"): int,
40-
Required("do_not_optimize"): [str],
41-
Required("enable_always_target"): Any(bool, [str]),
42-
Required("existing_tasks"): {str: str},
43-
Required("files_changed"): [str],
44-
Required("filters"): [str],
45-
Required("head_ref"): str,
46-
Required("head_repository"): str,
47-
Required("head_rev"): str,
48-
Required("head_tag"): str,
49-
Required("level"): str,
50-
Required("moz_build_date"): str,
51-
Required("next_version"): Any(str, None),
52-
Required("optimize_strategies"): Any(str, None),
53-
Required("optimize_target_tasks"): bool,
54-
Required("owner"): str,
55-
Required("project"): str,
56-
Required("pushdate"): int,
57-
Required("pushlog_id"): str,
58-
Required("repository_type"): str,
59-
# target-kinds is not included, since it should never be
60-
# used at run-time
61-
Required("target_tasks_method"): str,
62-
Required("tasks_for"): str,
63-
Required("version"): Any(str, None),
64-
Optional("code-review"): {
65-
Required("phabricator-build-target"): str,
66-
},
67-
}
68-
)
38+
class BaseParametersSchema(
39+
Schema,
40+
kw_only=True,
41+
forbid_unknown_fields=False,
42+
rename={"code_review": "code-review"},
43+
):
44+
base_repository: str
45+
base_rev: str
46+
build_date: int
47+
build_number: int
48+
do_not_optimize: list[str]
49+
enable_always_target: Union[bool, list[str]]
50+
existing_tasks: dict[str, str]
51+
files_changed: list[str]
52+
filters: list[str]
53+
head_ref: str
54+
head_repository: str
55+
head_rev: str
56+
head_tag: str
57+
level: str
58+
moz_build_date: str
59+
next_version: Optional[str]
60+
optimize_strategies: Optional[str]
61+
optimize_target_tasks: bool
62+
owner: str
63+
project: str
64+
pushdate: int
65+
pushlog_id: str
66+
repository_type: str
67+
# target-kinds is not included, since it should never be
68+
# used at run-time
69+
target_tasks_method: str
70+
tasks_for: str
71+
version: Optional[str]
72+
base_ref: Optional[str] = None
73+
code_review: Optional[CodeReviewConfig] = None
74+
75+
76+
base_schema = BaseParametersSchema
77+
_parameter_extensions: list = []
6978

7079

7180
def get_contents(path):
@@ -143,19 +152,27 @@ def extend_parameters_schema(schema, defaults_fn=None):
143152
graph-configuration.
144153
145154
Args:
146-
schema (Schema): The voluptuous.Schema object used to describe extended
147-
parameters.
155+
schema: A msgspec ``Schema`` subclass describing extended parameters.
148156
defaults_fn (function): A function which takes no arguments and returns a
149157
dict mapping parameter name to default value in the
150158
event strict=False (optional).
151159
"""
152-
global base_schema
153160
global defaults_functions
154-
base_schema = base_schema.extend(schema)
161+
if not (isinstance(schema, type) and issubclass(schema, msgspec.Struct)):
162+
raise TypeError(
163+
"extend_parameters_schema requires a msgspec Schema subclass; "
164+
f"got {type(schema).__name__}"
165+
)
166+
_parameter_extensions.append(schema)
155167
if defaults_fn:
156168
defaults_functions.append(defaults_fn)
157169

158170

171+
def _schema_key_names(schema) -> set:
172+
"""Return the data-level field names declared by a parameters schema."""
173+
return {f.encode_name for f in msgspec.structs.fields(schema)}
174+
175+
159176
class Parameters(ReadOnlyDict):
160177
"""An immutable dictionary with nicer KeyError messages on failure"""
161178

@@ -214,11 +231,30 @@ def _fill_defaults(repo_root=None, **kwargs):
214231
return kwargs
215232

216233
def check(self):
217-
schema = (
218-
base_schema if self.strict else base_schema.extend({}, extra=ALLOW_EXTRA)
219-
)
234+
data = dict(self.copy())
220235
try:
221-
validate_schema(schema, self.copy(), "Invalid parameters:")
236+
# Validate core fields against just the subset of data owned by the
237+
# base schema. Extension keys are validated separately below, and a
238+
# strict-mode check rejects anything unknown to either.
239+
base_keys = _schema_key_names(base_schema)
240+
base_data = {k: v for k, v in data.items() if k in base_keys}
241+
validate_schema(base_schema, base_data, "Invalid parameters:")
242+
243+
# Validate each registered extension against the keys it declares.
244+
allowed = set(base_keys)
245+
for ext in _parameter_extensions:
246+
ext_keys = _schema_key_names(ext)
247+
allowed |= ext_keys
248+
ext_data = {k: data[k] for k in ext_keys if k in data}
249+
validate_schema(ext, ext_data, "Invalid parameters:")
250+
251+
# Strict mode: reject any data key not covered by base or extensions.
252+
if self.strict:
253+
unknown = sorted(set(data) - allowed)
254+
if unknown:
255+
raise Exception(
256+
"Invalid parameters:\nunknown keys: " + ", ".join(unknown)
257+
)
222258
except Exception as e:
223259
raise ParameterMismatch(str(e))
224260

taskcluster/self_taskgraph/custom_parameters.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

55
import os
6-
7-
from voluptuous import All, Any, Range, Required
6+
from typing import Optional
87

98
from taskgraph.parameters import extend_parameters_schema
9+
from taskgraph.util.schema import Schema
1010

1111

1212
def get_defaults(repo_root):
@@ -15,12 +15,18 @@ def get_defaults(repo_root):
1515
}
1616

1717

18-
extend_parameters_schema(
19-
{
20-
Required("pull_request_number"): Any(All(int, Range(min=1)), None),
21-
},
22-
defaults_fn=get_defaults,
23-
)
18+
class CustomParametersSchema(Schema, kw_only=True, rename=None):
19+
pull_request_number: Optional[int]
20+
21+
def __post_init__(self):
22+
super().__post_init__()
23+
if self.pull_request_number is not None and self.pull_request_number < 1:
24+
raise ValueError(
25+
f"pull_request_number must be >= 1, got {self.pull_request_number}"
26+
)
27+
28+
29+
extend_parameters_schema(CustomParametersSchema, defaults_fn=get_defaults)
2430

2531

2632
def decision_parameters(graph_config, parameters):

test/test_parameters.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import gzip
88
import os
99
from base64 import b64decode
10+
from typing import Optional
1011
from unittest import TestCase, mock
1112

1213
import mozilla_repo_urls
1314
import pytest
14-
from voluptuous import Optional, Required, Schema
1515

1616
import taskgraph # noqa: F401
1717
from taskgraph import parameters
@@ -21,6 +21,7 @@
2121
extend_parameters_schema,
2222
load_parameters_file,
2323
)
24+
from taskgraph.util.schema import Schema
2425

2526
from .mockedopen import MockedOpen
2627

@@ -274,20 +275,19 @@ def test_parameters_format_spec(spec, expected):
274275

275276

276277
def test_extend_parameters_schema(monkeypatch):
277-
monkeypatch.setattr(
278-
parameters,
279-
"base_schema",
280-
Schema(
281-
{
282-
Required("foo"): str,
283-
}
284-
),
285-
)
278+
class FooSchema(Schema, kw_only=True, rename=None):
279+
foo: str
280+
281+
class BarSchema(Schema, kw_only=True, rename=None):
282+
bar: Optional[bool] = None
283+
284+
monkeypatch.setattr(parameters, "base_schema", FooSchema)
286285
monkeypatch.setattr(
287286
parameters,
288287
"defaults_functions",
289288
list(parameters.defaults_functions),
290289
)
290+
monkeypatch.setattr(parameters, "_parameter_extensions", [])
291291

292292
with pytest.raises(ParameterMismatch):
293293
Parameters(strict=False).check()
@@ -296,9 +296,7 @@ def test_extend_parameters_schema(monkeypatch):
296296
Parameters(foo="1", bar=True).check()
297297

298298
extend_parameters_schema(
299-
{
300-
Optional("bar"): bool,
301-
},
299+
BarSchema,
302300
defaults_fn=lambda root: {"foo": "1", "bar": False},
303301
)
304302

0 commit comments

Comments
 (0)