Skip to content

Commit e37abc9

Browse files
committed
feat: support defining schemas as dicts
This provides an alternative way to define schemas, which IMO looks a bit nicer than the class based approach. One big benefit of this method is that you can use dashes in the keys, so it's possible to use both underscores and dashes as identifiers and it's clear to the user which is which.
1 parent 195eef8 commit e37abc9

2 files changed

Lines changed: 132 additions & 0 deletions

File tree

src/taskgraph/util/schema.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# License, v. 2.0. If a copy of the MPL was not distributed with this
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

5+
import inspect
56
import pprint
67
import re
78
import threading
@@ -318,6 +319,11 @@ def __getitem__(self, item):
318319
return self.schema[item] # type: ignore
319320

320321

322+
def _caller_module_name(depth=1):
323+
frame = inspect.stack()[depth + 1].frame
324+
return frame.f_globals.get("__name__", "schema")
325+
326+
321327
class Schema(
322328
msgspec.Struct,
323329
kw_only=True,
@@ -345,6 +351,11 @@ class MySchema(Schema, forbid_unknown_fields=False, kw_only=True):
345351
foo: str
346352
"""
347353

354+
def __init_subclass__(cls, exclusive=None, **kwargs):
355+
super().__init_subclass__(**kwargs)
356+
if exclusive is not None:
357+
cls.exclusive = exclusive
358+
348359
def __post_init__(self):
349360
if taskgraph.fast:
350361
return
@@ -370,6 +381,76 @@ def __post_init__(self):
370381

371382
keyed_by.validate(obj)
372383

384+
# Validate mutually exclusive field groups.
385+
for group in getattr(self, "exclusive", []):
386+
set_fields = [f for f in group if getattr(self, f) is not None]
387+
if len(set_fields) > 1:
388+
raise ValueError(
389+
f"{' and '.join(repr(f) for f in set_fields)} are mutually exclusive"
390+
)
391+
392+
@classmethod
393+
def from_dict(
394+
cls,
395+
fields_dict: dict[str, Any],
396+
name: Optional[str] = None,
397+
optional: bool = False,
398+
**kwargs,
399+
) -> Union[type[msgspec.Struct], type[Optional[msgspec.Struct]]]:
400+
"""Create a Schema subclass dynamically from a dict of field definitions.
401+
402+
Each key is a field name and each value is either a type annotation or a
403+
``(type, default)`` tuple. Fields typed as ``Optional[...]`` automatically
404+
receive a default of ``None`` when no explicit default is provided.
405+
406+
Usage::
407+
408+
Schema.from_dict("MySchema", {
409+
"required_field": str,
410+
"optional_field": Optional[int], # default None inferred
411+
"explicit_default": (list[str], []), # explicit default
412+
})
413+
414+
Keyword arguments are forwarded to ``msgspec.defstruct`` (e.g.
415+
``forbid_unknown_fields=False``).
416+
"""
417+
# Don't use `rename=kebab` by default as we can define kebab case
418+
# properly in dicts.
419+
kwargs.setdefault("rename", None)
420+
421+
# Ensure name and module are set correctly for error messages.
422+
caller_module = _caller_module_name()
423+
kwargs.setdefault("module", caller_module)
424+
name = name or caller_module.rsplit(".", 1)[-1]
425+
426+
fields = []
427+
for field_name, field_spec in fields_dict.items():
428+
python_name = field_name.replace("-", "_")
429+
430+
if isinstance(field_spec, tuple):
431+
typ, default = field_spec
432+
else:
433+
typ = field_spec
434+
if get_origin(typ) is Union and type(None) in get_args(typ):
435+
default = None
436+
else:
437+
default = msgspec.NODEFAULT
438+
439+
if field_name != python_name:
440+
# Use msgspec.field to preserve the kebab-case encoded name.
441+
# Explicit field names take priority over the struct-level rename.
442+
fields.append(
443+
(python_name, typ, msgspec.field(name=field_name, default=default))
444+
)
445+
else:
446+
fields.append((python_name, typ, default))
447+
448+
exclusive = kwargs.pop("exclusive", None)
449+
result = msgspec.defstruct(name or "Schema", fields, bases=(cls,), **kwargs)
450+
if exclusive:
451+
result.exclusive = exclusive # type: ignore[attr-defined]
452+
return Optional[result] if optional else result # type: ignore[valid-type]
453+
373454
@classmethod
374455
def validate(cls, data):
375456
"""Validate data against this schema."""

test/test_util_schema.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

55
import unittest
6+
from typing import Optional
67

78
import msgspec
89
import pytest
@@ -348,3 +349,53 @@ class TestSchema(Schema):
348349

349350
with pytest.raises(msgspec.ValidationError):
350351
TestSchema.validate({"field": {"by-foo": {"a": "b"}}})
352+
353+
354+
@pytest.mark.parametrize(
355+
"fields_dict, data, attr, expected",
356+
[
357+
({"name": str}, {"name": "foo"}, "name", "foo"),
358+
({"count": Optional[int]}, {}, "count", None),
359+
({"tags": (list, [])}, {}, "tags", []),
360+
({"my-field": str}, {"my-field": "bar"}, "my_field", "bar"),
361+
],
362+
)
363+
def test_from_dict_valid(fields_dict, data, attr, expected):
364+
S = Schema.from_dict(fields_dict)
365+
result = msgspec.convert(data, S)
366+
assert getattr(result, attr) == expected
367+
368+
369+
@pytest.mark.parametrize(
370+
"fields_dict, data",
371+
[
372+
({"name": str}, {}),
373+
({"my-field": str}, {"my_field": "bar"}),
374+
],
375+
)
376+
def test_from_dict_invalid(fields_dict, data):
377+
S = Schema.from_dict(fields_dict)
378+
with pytest.raises(msgspec.ValidationError):
379+
msgspec.convert(data, S)
380+
381+
382+
@pytest.mark.parametrize(
383+
"data, raises",
384+
[
385+
({"a": "x", "b": "y"}, True),
386+
({"a": "x"}, False),
387+
({}, False),
388+
],
389+
)
390+
def test_exclusive(data, raises):
391+
S = Schema.from_dict(
392+
{"a": Optional[str], "b": Optional[str]},
393+
exclusive=[["a", "b"]],
394+
)
395+
if raises:
396+
with pytest.raises(
397+
(ValueError, msgspec.ValidationError), match="mutually exclusive"
398+
):
399+
msgspec.convert(data, S)
400+
else:
401+
msgspec.convert(data, S)

0 commit comments

Comments
 (0)