From a2da7fd638d731bdc2386d9dd0b8de9cdd61be0d Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Mon, 27 Apr 2026 10:49:58 +0100 Subject: [PATCH] Make kw_only on a class only change fields on that class --- src/ducktools/classbuilder/__init__.py | 117 +++++++++++----- src/ducktools/classbuilder/__init__.pyi | 2 + src/ducktools/classbuilder/prefab.py | 26 +--- src/ducktools/classbuilder/prefab.pyi | 8 +- tests/prefab/test_kw_only.py | 178 +++++++++++++++--------- tests/py314_tests/test_field_type.py | 16 +++ tests/test_core.py | 75 +++++----- 7 files changed, 258 insertions(+), 164 deletions(-) create mode 100644 tests/py314_tests/test_field_type.py diff --git a/src/ducktools/classbuilder/__init__.py b/src/ducktools/classbuilder/__init__.py index 9527ce8..3a2b6c2 100644 --- a/src/ducktools/classbuilder/__init__.py +++ b/src/ducktools/classbuilder/__init__.py @@ -334,15 +334,12 @@ def __get__(self, instance, cls=None): def get_init_generator(null=NOTHING, extra_code=None): def cls_init_maker(cls, funcname="__init__"): fields = get_fields(cls) - flags = get_flags(cls) arglist = [] kw_only_arglist = [] assignments = [] globs = {} - kw_only_flag = flags.get("kw_only", False) - for k, v in fields.items(): if v.init: if v.default is not null: @@ -357,7 +354,7 @@ def cls_init_maker(cls, funcname="__init__"): arg = f"{k}" assignment = f"self.{k} = {k}" - if kw_only_flag or v.kw_only: + if v.kw_only: kw_only_arglist.append(arg) else: arglist.append(arg) @@ -548,31 +545,45 @@ def ge_generator(cls, funcname="__ge__"): return get_order_generator(cls, funcname, operator=">=") -def replace_generator(cls, funcname="__replace__"): - # Generate the replace method for built classes - # unlike the dataclasses implementation this is generated - attribs = get_fields(cls) +def _get_replace_generator(private_type=False): + def cls_replace_generator(cls, funcname="__replace__"): + # Generate the replace method for built classes + # unlike the dataclasses implementation this is generated + attribs = get_fields(cls) + + # This is essentially the as_dict generator for prefabs + # except based on attrib.init instead of .serialize + if private_type: + vals = ", ".join( + f"'{name}': self.{name}" + if name != "type" + else f"'{name}': self._{name}" + for name, attrib in attribs.items() + if attrib.init + ) + else: + vals = ", ".join( + f"'{name}': self.{name}" + for name, attrib in attribs.items() + if attrib.init + ) + init_dict = f"{{{vals}}}" - # This is essentially the as_dict generator for prefabs - # except based on attrib.init instead of .serialize - vals = ", ".join( - f"'{name}': self.{name}" - for name, attrib in attribs.items() - if attrib.init - ) - init_dict = f"{{{vals}}}" + # fmt: off + code = ( + f"def {funcname}(self, /, **changes):\n" + f" new_kwargs = {init_dict}\n" + f" new_kwargs |= changes\n" + f" return self.__class__(**new_kwargs)\n" + ) + # fmt: on + globs = {} + return GeneratedCode(code, globs) - # fmt: off - code = ( - f"def {funcname}(self, /, **changes):\n" - f" new_kwargs = {init_dict}\n" - f" new_kwargs |= changes\n" - f" return self.__class__(**new_kwargs)\n" - ) - # fmt: on - globs = {} - return GeneratedCode(code, globs) + return cls_replace_generator +replace_generator = _get_replace_generator() +_field_replace_generator = _get_replace_generator(private_type=True) def frozen_setattr_generator(cls, funcname="__setattr__"): globs = {} @@ -657,6 +668,10 @@ def hash_generator(cls, funcname="__hash__"): ) ) +# Special `__replace__` method for `Field` that will use the internal `_type` +# value instead of the resolved `type` property +_field_replace_maker = MethodMaker("__replace__", _field_replace_generator) + def add_methods(cls, methods, *, internals=None): """ @@ -730,6 +745,7 @@ def builder(cls=None, /, *, gatherer, methods, flags=None, fix_signature=True, f flag_dict |= flags internals["flags"] = flag_dict + kw_only = flag_dict.get("kw_only", False) cls_gathered = cls.__dict__.get(GATHERED_DATA) if cls_gathered: @@ -737,6 +753,13 @@ def builder(cls=None, /, *, gatherer, methods, flags=None, fix_signature=True, f else: cls_fields, modifications = gatherer(cls) + if kw_only: + # Update the class fields to make all Fields kw_only + cls_fields = { + k: v if v.kw_only else v.__replace__(kw_only=True) + for k, v in cls_fields.items() + } + for name, value in modifications.items(): if value is NOTHING: delattr(cls, name) @@ -1075,7 +1098,7 @@ def __init__( def __init_subclass__(cls, frozen=False, ignore_annotations=False): # Subclasses of Field can be created as if they are dataclasses - field_methods = {_field_init_maker, repr_maker, eq_maker, replace_maker} + field_methods = {_field_init_maker, repr_maker, eq_maker, _field_replace_maker} if frozen or _UNDER_TESTING: field_methods |= {frozen_setattr_maker, frozen_delattr_maker} @@ -1110,8 +1133,9 @@ def from_field(cls, fld, /, **kwargs): :param kwargs: Additional keyword arguments for subclasses :return: new field subclass instance """ + # type is special cased to get the internal value inst_fields = { - k: getattr(fld, k) + k: getattr(fld, k) if k != "type" else getattr(fld, "_type") for k in get_fields(type(fld)) } argument_dict = {**inst_fields, **kwargs} @@ -1152,19 +1176,21 @@ def _build_field(): "kw_only": "Make this a keyword only parameter in __init__", } + # Fields here must be marked as kw_only to prevent the builder from trying + # to call the __replace__ method which doesn't exist yet fields = { - "default": Field(default=NOTHING, doc=field_docs["default"]), - "default_factory": Field(default=NOTHING, doc=field_docs["default_factory"]), - "type": Field(default=NOTHING, doc=field_docs["type"]), - "doc": Field(default=None, doc=field_docs["doc"]), - "init": Field(default=True, doc=field_docs["init"]), - "repr": Field(default=True, doc=field_docs["repr"]), - "compare": Field(default=True, doc=field_docs["compare"]), - "kw_only": Field(default=False, doc=field_docs["kw_only"]), + "default": Field(default=NOTHING, doc=field_docs["default"], kw_only=True), + "default_factory": Field(default=NOTHING, doc=field_docs["default_factory"], kw_only=True), + "type": Field(default=NOTHING, doc=field_docs["type"], kw_only=True), + "doc": Field(default=None, doc=field_docs["doc"], kw_only=True), + "init": Field(default=True, doc=field_docs["init"], kw_only=True), + "repr": Field(default=True, doc=field_docs["repr"], kw_only=True), + "compare": Field(default=True, doc=field_docs["compare"], kw_only=True), + "kw_only": Field(default=False, doc=field_docs["kw_only"], kw_only=True), } modifications = {"__slots__": field_docs} - field_methods = {repr_maker, eq_maker, replace_maker} + field_methods = {repr_maker, eq_maker, _field_replace_maker} if _UNDER_TESTING: field_methods |= {frozen_setattr_maker, frozen_delattr_maker} @@ -1448,6 +1474,23 @@ def check_argument_order(cls): used_default = True +def replace(obj, /, **changes): + """ + Create a copy of a prefab instance with values provided to 'changes' replaced + + :param obj: built class + :return: new built class instance with changes applied + """ + if not build_completed(type(obj)): + raise TypeError("replace() should be called on classbuilder class instances") + try: + replace_func = obj.__replace__ + except AttributeError: + raise TypeError(f"{obj.__class__.__name__!r} does not support __replace__") + + return replace_func(**changes) + + # Class Decorators def slotclass(cls=None, /, *, methods=default_methods, syntax_check=True): """ diff --git a/src/ducktools/classbuilder/__init__.pyi b/src/ducktools/classbuilder/__init__.pyi index 462e967..4b147cd 100644 --- a/src/ducktools/classbuilder/__init__.pyi +++ b/src/ducktools/classbuilder/__init__.pyi @@ -309,6 +309,8 @@ def unified_gatherer(cls_or_ns: type | _CopiableMappings) -> _gatherer_returntyp def check_argument_order(cls: type) -> None: ... +def replace(obj: _T, /, **changes: typing.Any) -> _T: ... + @typing.overload def slotclass( cls: _TypeT, diff --git a/src/ducktools/classbuilder/prefab.py b/src/ducktools/classbuilder/prefab.py index 9bd508a..b6694d1 100644 --- a/src/ducktools/classbuilder/prefab.py +++ b/src/ducktools/classbuilder/prefab.py @@ -67,7 +67,11 @@ from .annotations import get_func_annotations, is_type, replace_generic_with_arg # These aren't used but are re-exported for ease of use -from . import SlotFields as SlotFields, KW_ONLY as KW_ONLY +from . import ( + KW_ONLY as KW_ONLY, + SlotFields as SlotFields, + replace as replace, +) PREFAB_FIELDS = "PREFAB_FIELDS" PREFAB_INIT_FUNC = "__prefab_init__" @@ -123,7 +127,6 @@ def init_generator(cls, funcname="__init__"): attributes = get_attributes(cls) flags = get_flags(cls) - kw_only = flags.get("kw_only", False) frozen = flags.get("frozen", False) slotted = flags.get("slotted", False) @@ -218,7 +221,7 @@ def init_generator(cls, funcname="__init__"): globs[f"_{name}_factory"] = attrib.default_factory else: arg = name - if attrib.kw_only or kw_only: + if attrib.kw_only: kw_only_arglist.append(arg) else: pos_arglist.append(arg) @@ -1038,20 +1041,3 @@ def as_dict(o): for name, attrib in flds.items() if attrib.serialize } - - -def replace(obj, /, **changes): - """ - Create a copy of a prefab instance with values provided to 'changes' replaced - - :param obj: prefab instance - :return: new prefab instance - """ - if not is_prefab_instance(obj): - raise TypeError("replace() should be called on prefab instances") - try: - replace_func = obj.__replace__ - except AttributeError: - raise TypeError(f"{obj.__class__.__name__!r} does not support __replace__") - - return replace_func(**changes) diff --git a/src/ducktools/classbuilder/prefab.pyi b/src/ducktools/classbuilder/prefab.pyi index 596d437..f8677d5 100644 --- a/src/ducktools/classbuilder/prefab.pyi +++ b/src/ducktools/classbuilder/prefab.pyi @@ -17,7 +17,11 @@ from . import ( _SignatureMaker ) -from . import SlotFields as SlotFields, KW_ONLY as KW_ONLY +from . import ( + KW_ONLY as KW_ONLY, + SlotFields as SlotFields, + replace as replace, +) # noinspection PyUnresolvedReferences from . import _NothingType @@ -262,5 +266,3 @@ def is_prefab(o: typing.Any) -> bool: ... def is_prefab_instance(o: object) -> bool: ... def as_dict(o) -> dict[str, typing.Any]: ... - -def replace(obj: _T, /, **changes: typing.Any) -> _T: ... diff --git a/tests/prefab/test_kw_only.py b/tests/prefab/test_kw_only.py index bb93067..6915c7d 100644 --- a/tests/prefab/test_kw_only.py +++ b/tests/prefab/test_kw_only.py @@ -1,59 +1,15 @@ import pytest from ducktools.classbuilder.annotations import get_ns_annotations -from ducktools.classbuilder.prefab import attribute, prefab, KW_ONLY - - -# Test Classes -@prefab -class KWBasic: - x = attribute(kw_only=True) - y = attribute(kw_only=True) - - -@prefab -class KWOrdering: - x = attribute(default=2, kw_only=True) - y = attribute() - - -@prefab -class KWBase: - x = attribute(default=2, kw_only=True) - - -@prefab -class KWChild(KWBase): - y = attribute() - - -@prefab(kw_only=True) -class KWPrefabArgument: - x = attribute() - y = attribute() - - -@prefab(kw_only=True) -class KWPrefabArgumentOverrides: - x = attribute() - y = attribute(kw_only=False) - - -@prefab -class KWFlagNoDefaults: - x: int - _: KW_ONLY # type: ignore - y: int - - -@prefab -class KWFlagXDefault: - x: int = 1 - _: KW_ONLY # type: ignore - y: int # type: ignore +from ducktools.classbuilder.prefab import KW_ONLY, Prefab, attribute, prefab def test_kw_only_basic(): + @prefab + class KWBasic: + x = attribute(kw_only=True) + y = attribute(kw_only=True) + # Check the typeerror is raised for # trying to use positional arguments with pytest.raises(TypeError): @@ -64,15 +20,30 @@ def test_kw_only_basic(): def test_kw_only_ordering(): + # Test the kw_only argument is not also positional + @prefab + class KWOrdering: + x = attribute(default=2, kw_only=True) + y = attribute() + with pytest.raises(TypeError): x = KWOrdering(1, 2) x = KWOrdering(1) assert (x.x, x.y) == (2, 1) - assert repr(x) == "KWOrdering(x=2, y=1)" + assert repr(x).endswith("KWOrdering(x=2, y=1)") + + + +def test_on_attribute(): + @prefab + class KWBase: + x = attribute(default=2, kw_only=True) + @prefab + class KWChild(KWBase): + y = attribute() -def test_kw_only_inheritance(): with pytest.raises(TypeError): x = KWChild(1, 2) @@ -80,30 +51,97 @@ def test_kw_only_inheritance(): y = KWChild(1) assert (x.x, x.y) == (2, 1) assert x == y - assert repr(x) == "KWChild(x=2, y=1)" + assert repr(x).endswith("KWChild(x=2, y=1)") -def test_kw_only_prefab_argument(): - with pytest.raises(TypeError): - x = KWPrefabArgument(1, 2) +class TestKWOnlyClassArg: + def test_kw_only_prefab_argument(self): + @prefab(kw_only=True) + class KWPrefabArgument: + x = attribute() + y = attribute() - x = KWPrefabArgument(x=1, y=2) + with pytest.raises(TypeError): + x = KWPrefabArgument(1, 2) - assert (x.x, x.y) == (1, 2) - assert repr(x) == "KWPrefabArgument(x=1, y=2)" + x = KWPrefabArgument(x=1, y=2) + assert (x.x, x.y) == (1, 2) + assert repr(x).endswith("KWPrefabArgument(x=1, y=2)") -def test_kw_only_prefab_argument_overrides(): - with pytest.raises(TypeError): - x = KWPrefabArgumentOverrides(1, 2) + def test_kw_only_prefab_argument_overrides(self): + @prefab(kw_only=True) + class KWPrefabArgumentOverrides: + x = attribute() + y = attribute(kw_only=False) - x = KWPrefabArgumentOverrides(x=1, y=2) + with pytest.raises(TypeError): + x = KWPrefabArgumentOverrides(1, 2) - assert (x.x, x.y) == (1, 2) - assert repr(x) == "KWPrefabArgumentOverrides(x=1, y=2)" + x = KWPrefabArgumentOverrides(x=1, y=2) + + assert (x.x, x.y) == (1, 2) + assert repr(x).endswith("KWPrefabArgumentOverrides(x=1, y=2)") + + def test_only_applies_to_new_fields(self): + @prefab + class Base: + name: str = "Dent" + + @prefab(kw_only=True) + class Sub(Base): + answer: int = 42 + + with pytest.raises(TypeError): + _ = Sub("Zaphod", 24) + + ex = Sub("Zaphod", answer=54) + assert ex.name == "Zaphod" + assert ex.answer == 54 + + def test_ignored_by_new_subclass(self): + @prefab(kw_only=True) + class Base: + name: str = "Dent" + + @prefab + class Sub(Base): + answer: int = 42 + + with pytest.raises(TypeError): + _ = Sub("Zaphod", 24) + + ex = Sub(54, name="Zaphod") + + assert ex.name == "Zaphod" + assert ex.answer == 54 + + def test_inherited_in_class_form(self): + # The base class version should inherit kw_only + class Base(Prefab, kw_only=True): + name: str = "Dent" + + class Sub(Base): + answer: int = 42 + + with pytest.raises(TypeError): + _ = Sub("Zaphod", 24) + + with pytest.raises(TypeError): + _ = Sub(24, name="Zaphod") + + ex = Sub(name="Zaphod", answer=54) + assert ex.name == "Zaphod" + assert ex.answer == 54 def test_kw_flag_no_defaults(): + @prefab + class KWFlagNoDefaults: + x: int + _: KW_ONLY # type: ignore + y: int + annotations = get_ns_annotations(KWFlagNoDefaults.__dict__) assert "_" in annotations @@ -116,10 +154,16 @@ def test_kw_flag_no_defaults(): assert not hasattr(x, "_") assert (x.x, x.y) == (1, 2) - assert repr(x) == "KWFlagNoDefaults(x=1, y=2)" + assert repr(x).endswith("KWFlagNoDefaults(x=1, y=2)") def test_kw_flat_defaults(): + @prefab + class KWFlagXDefault: + x: int = 1 + _: KW_ONLY # type: ignore + y: int # type: ignore + with pytest.raises(TypeError): x = KWFlagXDefault(1, 2) @@ -128,4 +172,4 @@ def test_kw_flat_defaults(): assert (x.x, x.y) == (1, 2) assert x == y - assert repr(x) == "KWFlagXDefault(x=1, y=2)" + assert repr(x).endswith("KWFlagXDefault(x=1, y=2)") diff --git a/tests/py314_tests/test_field_type.py b/tests/py314_tests/test_field_type.py new file mode 100644 index 0000000..5ff9892 --- /dev/null +++ b/tests/py314_tests/test_field_type.py @@ -0,0 +1,16 @@ +from reannotate import DeferredAnnotation +from ducktools.classbuilder import Field, replace + + +def test_replace_preserves_type(): + # Test that the Field __replace__ method preserves the internal _type + + f = Field(type=DeferredAnnotation(str)) + + assert f.type is str + assert f._type == DeferredAnnotation(str) + + new_f = replace(f) + + assert new_f.type is str + assert new_f._type == DeferredAnnotation(str) diff --git a/tests/test_core.py b/tests/test_core.py index 42917e5..278c497 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -438,56 +438,57 @@ class DictClass: assert inst.__dict__ == {"c": 42} -def test_fieldclass(): - class NewField(Field): - serialize: bool = True +class TestSubclassingField: + def test_fieldclass(self): + class NewField(Field): + serialize: bool = True - f = NewField() + f = NewField() - assert f.default is NOTHING - assert f.default_factory is NOTHING - assert f.type is NOTHING - assert f.doc is None - assert f.serialize is True + assert f.default is NOTHING + assert f.default_factory is NOTHING + assert f.type is NOTHING + assert f.doc is None + assert f.serialize is True - f2 = NewField(default=1, serialize=False) + f2 = NewField(default=1, serialize=False) - assert f2.default == 1 - assert f2.serialize is False + assert f2.default == 1 + assert f2.serialize is False - with pytest.raises(TypeError): - # All arguments are keyword only in fieldclasses - NewField(42) + with pytest.raises(TypeError): + # All arguments are keyword only in fieldclasses + NewField(42) -def test_fieldclass_frozen(): - class NewField(Field, frozen=True): - serialize: bool = True + def test_fieldclass_frozen(self): + class NewField(Field, frozen=True): + serialize: bool = True - f = NewField() + f = NewField() - attr_changes = { - "default": False, - "default_factory": list, - "type": bool, - "doc": "This should fail", - "serialize": False, - } + attr_changes = { + "default": False, + "default_factory": list, + "type": bool, + "doc": "This should fail", + "serialize": False, + } - for k, v in attr_changes.items(): - with pytest.raises(TypeError): - setattr(f, k, v) + for k, v in attr_changes.items(): + with pytest.raises(TypeError): + setattr(f, k, v) - for k in attr_changes: - with pytest.raises(TypeError): - delattr(f, k) + for k in attr_changes: + with pytest.raises(TypeError): + delattr(f, k) - # Even slotted fields raise TypeError as setattr happens first - with pytest.raises(TypeError): - setattr(f, "new_attribute", False) + # Even slotted fields raise TypeError as setattr happens first + with pytest.raises(TypeError): + setattr(f, "new_attribute", False) - with pytest.raises(TypeError): - delattr(f, "new_attribute") + with pytest.raises(TypeError): + delattr(f, "new_attribute") @graalpy_fails