Skip to content

Commit b3e0a05

Browse files
feat!: convert parameter schema from voluptuous to msgspec
1 parent a63271c commit b3e0a05

3 files changed

Lines changed: 99 additions & 66 deletions

File tree

src/taskgraph/parameters.py

Lines changed: 77 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@
1010
from io import BytesIO
1111
from pprint import pformat
1212
from subprocess import CalledProcessError
13+
from typing import Optional, Union
1314
from unittest.mock import Mock
1415
from urllib.parse import urlparse
1516
from urllib.request import urlopen
1617

1718
import mozilla_repo_urls
18-
from voluptuous import ALLOW_EXTRA, Any, Optional, Required, Schema
19+
import msgspec
1920

2021
from taskgraph.util import json, yaml
2122
from taskgraph.util.readonlydict import ReadOnlyDict
22-
from taskgraph.util.schema import validate_schema
23+
from taskgraph.util.schema import Schema, validate_schema
2324
from taskgraph.util.taskcluster import find_task_id, get_artifact_url
2425
from taskgraph.util.vcs import get_repository
2526

@@ -30,43 +31,50 @@ class ParameterMismatch(Exception):
3031

3132
#: Schema for base parameters.
3233
#: Please keep this list sorted and in sync with docs/reference/parameters.rst
33-
base_schema = Schema(
34+
base_schema = Schema.from_dict(
3435
{
35-
Required("base_repository"): str,
36-
Optional("base_ref"): str,
37-
Required("base_rev"): str,
38-
Required("build_date"): int,
39-
Required("build_number"): int,
40-
Required("do_not_optimize"): [str],
41-
Required("enable_always_target"): Any(bool, [str]),
42-
Required("existing_tasks"): {str: str},
43-
Required("files_changed"): [str],
44-
Required("filters"): [str],
45-
Required("head_ref"): str,
46-
Required("head_repository"): str,
47-
Required("head_rev"): str,
48-
Required("head_tag"): str,
49-
Required("level"): str,
50-
Required("moz_build_date"): str,
51-
Required("next_version"): Any(str, None),
52-
Required("optimize_strategies"): Any(str, None),
53-
Required("optimize_target_tasks"): bool,
54-
Required("owner"): str,
55-
Required("project"): str,
56-
Required("pushdate"): int,
57-
Required("pushlog_id"): str,
58-
Required("repository_type"): str,
36+
"base_repository": str,
37+
"base_ref": Optional[str],
38+
"base_rev": str,
39+
"build_date": int,
40+
"build_number": int,
41+
"do_not_optimize": list[str],
42+
"enable_always_target": Union[bool, list[str]],
43+
"existing_tasks": dict[str, str],
44+
"files_changed": list[str],
45+
"filters": list[str],
46+
"head_ref": str,
47+
"head_repository": str,
48+
"head_rev": str,
49+
"head_tag": str,
50+
"level": str,
51+
"moz_build_date": str,
52+
"next_version": Optional[str],
53+
"optimize_strategies": Optional[str],
54+
"optimize_target_tasks": bool,
55+
"owner": str,
56+
"project": str,
57+
"pushdate": int,
58+
"pushlog_id": str,
59+
"repository_type": str,
5960
# target-kinds is not included, since it should never be
6061
# used at run-time
61-
Required("target_tasks_method"): str,
62-
Required("tasks_for"): str,
63-
Required("version"): Any(str, None),
64-
Optional("code-review"): {
65-
Required("phabricator-build-target"): str,
66-
},
67-
}
62+
"target_tasks_method": str,
63+
"tasks_for": str,
64+
"version": Optional[str],
65+
"code-review": Schema.from_dict(
66+
{"phabricator-build-target": str},
67+
name="CodeReviewConfig",
68+
optional=True,
69+
),
70+
},
71+
name="BaseParametersSchema",
72+
forbid_unknown_fields=False,
73+
kw_only=True,
6874
)
6975

76+
_parameter_extensions: list = []
77+
7078

7179
def get_contents(path):
7280
with open(path) as fh:
@@ -143,19 +151,27 @@ def extend_parameters_schema(schema, defaults_fn=None):
143151
graph-configuration.
144152
145153
Args:
146-
schema (Schema): The voluptuous.Schema object used to describe extended
147-
parameters.
154+
schema: A msgspec ``Schema`` subclass describing extended parameters.
148155
defaults_fn (function): A function which takes no arguments and returns a
149156
dict mapping parameter name to default value in the
150157
event strict=False (optional).
151158
"""
152-
global base_schema
153159
global defaults_functions
154-
base_schema = base_schema.extend(schema)
160+
if not (isinstance(schema, type) and issubclass(schema, msgspec.Struct)):
161+
raise TypeError(
162+
"extend_parameters_schema requires a msgspec Schema subclass; "
163+
f"got {type(schema).__name__}"
164+
)
165+
_parameter_extensions.append(schema)
155166
if defaults_fn:
156167
defaults_functions.append(defaults_fn)
157168

158169

170+
def _schema_key_names(schema) -> set:
171+
"""Return the data-level field names declared by a parameters schema."""
172+
return {f.encode_name for f in msgspec.structs.fields(schema)}
173+
174+
159175
class Parameters(ReadOnlyDict):
160176
"""An immutable dictionary with nicer KeyError messages on failure"""
161177

@@ -214,11 +230,30 @@ def _fill_defaults(repo_root=None, **kwargs):
214230
return kwargs
215231

216232
def check(self):
217-
schema = (
218-
base_schema if self.strict else base_schema.extend({}, extra=ALLOW_EXTRA)
219-
)
233+
data = dict(self.copy())
220234
try:
221-
validate_schema(schema, self.copy(), "Invalid parameters:")
235+
# Validate core fields against just the subset of data owned by the
236+
# base schema. Extension keys are validated separately below, and a
237+
# strict-mode check rejects anything unknown to either.
238+
base_keys = _schema_key_names(base_schema)
239+
base_data = {k: v for k, v in data.items() if k in base_keys}
240+
validate_schema(base_schema, base_data, "Invalid parameters:")
241+
242+
# Validate each registered extension against the keys it declares.
243+
allowed = set(base_keys)
244+
for ext in _parameter_extensions:
245+
ext_keys = _schema_key_names(ext)
246+
allowed |= ext_keys
247+
ext_data = {k: data[k] for k in ext_keys if k in data}
248+
validate_schema(ext, ext_data, "Invalid parameters:")
249+
250+
# Strict mode: reject any data key not covered by base or extensions.
251+
if self.strict:
252+
unknown = sorted(set(data) - allowed)
253+
if unknown:
254+
raise Exception(
255+
"Invalid parameters:\nunknown keys: " + ", ".join(unknown)
256+
)
222257
except Exception as e:
223258
raise ParameterMismatch(str(e))
224259

taskcluster/self_taskgraph/custom_parameters.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

55
import os
6+
from typing import Annotated, Optional
67

7-
from voluptuous import All, Any, Range, Required
8+
import msgspec
89

910
from taskgraph.parameters import extend_parameters_schema
11+
from taskgraph.util.schema import Schema
1012

1113

1214
def get_defaults(repo_root):
@@ -15,14 +17,15 @@ def get_defaults(repo_root):
1517
}
1618

1719

18-
extend_parameters_schema(
19-
{
20-
Required("pull_request_number"): Any(All(int, Range(min=1)), None),
21-
},
22-
defaults_fn=get_defaults,
20+
CustomParametersSchema = Schema.from_dict(
21+
{"pull_request_number": Optional[Annotated[int, msgspec.Meta(ge=1)]]},
22+
name="CustomParametersSchema",
2323
)
2424

2525

26+
extend_parameters_schema(CustomParametersSchema, defaults_fn=get_defaults)
27+
28+
2629
def decision_parameters(graph_config, parameters):
2730
if parameters["tasks_for"] == "github-release":
2831
parameters["target_tasks_method"] = "release"

test/test_parameters.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import gzip
88
import os
99
from base64 import b64decode
10+
from typing import Optional
1011
from unittest import TestCase, mock
1112

1213
import mozilla_repo_urls
1314
import pytest
14-
from voluptuous import Optional, Required, Schema
1515

1616
import taskgraph # noqa: F401
1717
from taskgraph import parameters
@@ -21,6 +21,7 @@
2121
extend_parameters_schema,
2222
load_parameters_file,
2323
)
24+
from taskgraph.util.schema import Schema
2425

2526
from .mockedopen import MockedOpen
2627

@@ -274,20 +275,16 @@ def test_parameters_format_spec(spec, expected):
274275

275276

276277
def test_extend_parameters_schema(monkeypatch):
277-
monkeypatch.setattr(
278-
parameters,
279-
"base_schema",
280-
Schema(
281-
{
282-
Required("foo"): str,
283-
}
284-
),
285-
)
286-
monkeypatch.setattr(
287-
parameters,
288-
"defaults_functions",
289-
list(parameters.defaults_functions),
290-
)
278+
FooSchema = Schema.from_dict({"foo": str}, name="FooSchema")
279+
BarSchema = Schema.from_dict({"bar": Optional[bool]}, name="BarSchema")
280+
281+
monkeypatch.setattr(parameters, "base_schema", FooSchema)
282+
# Replace defaults_functions with an empty list so _fill_defaults doesn't
283+
# shell out to git via the built-in _get_defaults (which fails on Windows
284+
# CI when safe.directory isn't honored). The third assertion below adds
285+
# back its own defaults_fn via extend_parameters_schema.
286+
monkeypatch.setattr(parameters, "defaults_functions", [])
287+
monkeypatch.setattr(parameters, "_parameter_extensions", [])
291288

292289
with pytest.raises(ParameterMismatch):
293290
Parameters(strict=False).check()
@@ -296,9 +293,7 @@ def test_extend_parameters_schema(monkeypatch):
296293
Parameters(foo="1", bar=True).check()
297294

298295
extend_parameters_schema(
299-
{
300-
Optional("bar"): bool,
301-
},
296+
BarSchema,
302297
defaults_fn=lambda root: {"foo": "1", "bar": False},
303298
)
304299

0 commit comments

Comments
 (0)