Skip to content

Commit a2da7fd

Browse files
committed
Make kw_only on a class only change fields on that class
1 parent c57adeb commit a2da7fd

7 files changed

Lines changed: 258 additions & 164 deletions

File tree

src/ducktools/classbuilder/__init__.py

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,12 @@ def __get__(self, instance, cls=None):
334334
def get_init_generator(null=NOTHING, extra_code=None):
335335
def cls_init_maker(cls, funcname="__init__"):
336336
fields = get_fields(cls)
337-
flags = get_flags(cls)
338337

339338
arglist = []
340339
kw_only_arglist = []
341340
assignments = []
342341
globs = {}
343342

344-
kw_only_flag = flags.get("kw_only", False)
345-
346343
for k, v in fields.items():
347344
if v.init:
348345
if v.default is not null:
@@ -357,7 +354,7 @@ def cls_init_maker(cls, funcname="__init__"):
357354
arg = f"{k}"
358355
assignment = f"self.{k} = {k}"
359356

360-
if kw_only_flag or v.kw_only:
357+
if v.kw_only:
361358
kw_only_arglist.append(arg)
362359
else:
363360
arglist.append(arg)
@@ -548,31 +545,45 @@ def ge_generator(cls, funcname="__ge__"):
548545
return get_order_generator(cls, funcname, operator=">=")
549546

550547

551-
def replace_generator(cls, funcname="__replace__"):
552-
# Generate the replace method for built classes
553-
# unlike the dataclasses implementation this is generated
554-
attribs = get_fields(cls)
548+
def _get_replace_generator(private_type=False):
549+
def cls_replace_generator(cls, funcname="__replace__"):
550+
# Generate the replace method for built classes
551+
# unlike the dataclasses implementation this is generated
552+
attribs = get_fields(cls)
553+
554+
# This is essentially the as_dict generator for prefabs
555+
# except based on attrib.init instead of .serialize
556+
if private_type:
557+
vals = ", ".join(
558+
f"'{name}': self.{name}"
559+
if name != "type"
560+
else f"'{name}': self._{name}"
561+
for name, attrib in attribs.items()
562+
if attrib.init
563+
)
564+
else:
565+
vals = ", ".join(
566+
f"'{name}': self.{name}"
567+
for name, attrib in attribs.items()
568+
if attrib.init
569+
)
570+
init_dict = f"{{{vals}}}"
555571

556-
# This is essentially the as_dict generator for prefabs
557-
# except based on attrib.init instead of .serialize
558-
vals = ", ".join(
559-
f"'{name}': self.{name}"
560-
for name, attrib in attribs.items()
561-
if attrib.init
562-
)
563-
init_dict = f"{{{vals}}}"
572+
# fmt: off
573+
code = (
574+
f"def {funcname}(self, /, **changes):\n"
575+
f" new_kwargs = {init_dict}\n"
576+
f" new_kwargs |= changes\n"
577+
f" return self.__class__(**new_kwargs)\n"
578+
)
579+
# fmt: on
580+
globs = {}
581+
return GeneratedCode(code, globs)
564582

565-
# fmt: off
566-
code = (
567-
f"def {funcname}(self, /, **changes):\n"
568-
f" new_kwargs = {init_dict}\n"
569-
f" new_kwargs |= changes\n"
570-
f" return self.__class__(**new_kwargs)\n"
571-
)
572-
# fmt: on
573-
globs = {}
574-
return GeneratedCode(code, globs)
583+
return cls_replace_generator
575584

585+
replace_generator = _get_replace_generator()
586+
_field_replace_generator = _get_replace_generator(private_type=True)
576587

577588
def frozen_setattr_generator(cls, funcname="__setattr__"):
578589
globs = {}
@@ -657,6 +668,10 @@ def hash_generator(cls, funcname="__hash__"):
657668
)
658669
)
659670

671+
# Special `__replace__` method for `Field` that will use the internal `_type`
672+
# value instead of the resolved `type` property
673+
_field_replace_maker = MethodMaker("__replace__", _field_replace_generator)
674+
660675

661676
def add_methods(cls, methods, *, internals=None):
662677
"""
@@ -730,13 +745,21 @@ def builder(cls=None, /, *, gatherer, methods, flags=None, fix_signature=True, f
730745
flag_dict |= flags
731746
internals["flags"] = flag_dict
732747

748+
kw_only = flag_dict.get("kw_only", False)
733749
cls_gathered = cls.__dict__.get(GATHERED_DATA)
734750

735751
if cls_gathered:
736752
cls_fields, modifications = cls_gathered
737753
else:
738754
cls_fields, modifications = gatherer(cls)
739755

756+
if kw_only:
757+
# Update the class fields to make all Fields kw_only
758+
cls_fields = {
759+
k: v if v.kw_only else v.__replace__(kw_only=True)
760+
for k, v in cls_fields.items()
761+
}
762+
740763
for name, value in modifications.items():
741764
if value is NOTHING:
742765
delattr(cls, name)
@@ -1075,7 +1098,7 @@ def __init__(
10751098

10761099
def __init_subclass__(cls, frozen=False, ignore_annotations=False):
10771100
# Subclasses of Field can be created as if they are dataclasses
1078-
field_methods = {_field_init_maker, repr_maker, eq_maker, replace_maker}
1101+
field_methods = {_field_init_maker, repr_maker, eq_maker, _field_replace_maker}
10791102
if frozen or _UNDER_TESTING:
10801103
field_methods |= {frozen_setattr_maker, frozen_delattr_maker}
10811104

@@ -1110,8 +1133,9 @@ def from_field(cls, fld, /, **kwargs):
11101133
:param kwargs: Additional keyword arguments for subclasses
11111134
:return: new field subclass instance
11121135
"""
1136+
# type is special cased to get the internal value
11131137
inst_fields = {
1114-
k: getattr(fld, k)
1138+
k: getattr(fld, k) if k != "type" else getattr(fld, "_type")
11151139
for k in get_fields(type(fld))
11161140
}
11171141
argument_dict = {**inst_fields, **kwargs}
@@ -1152,19 +1176,21 @@ def _build_field():
11521176
"kw_only": "Make this a keyword only parameter in __init__",
11531177
}
11541178

1179+
# Fields here must be marked as kw_only to prevent the builder from trying
1180+
# to call the __replace__ method which doesn't exist yet
11551181
fields = {
1156-
"default": Field(default=NOTHING, doc=field_docs["default"]),
1157-
"default_factory": Field(default=NOTHING, doc=field_docs["default_factory"]),
1158-
"type": Field(default=NOTHING, doc=field_docs["type"]),
1159-
"doc": Field(default=None, doc=field_docs["doc"]),
1160-
"init": Field(default=True, doc=field_docs["init"]),
1161-
"repr": Field(default=True, doc=field_docs["repr"]),
1162-
"compare": Field(default=True, doc=field_docs["compare"]),
1163-
"kw_only": Field(default=False, doc=field_docs["kw_only"]),
1182+
"default": Field(default=NOTHING, doc=field_docs["default"], kw_only=True),
1183+
"default_factory": Field(default=NOTHING, doc=field_docs["default_factory"], kw_only=True),
1184+
"type": Field(default=NOTHING, doc=field_docs["type"], kw_only=True),
1185+
"doc": Field(default=None, doc=field_docs["doc"], kw_only=True),
1186+
"init": Field(default=True, doc=field_docs["init"], kw_only=True),
1187+
"repr": Field(default=True, doc=field_docs["repr"], kw_only=True),
1188+
"compare": Field(default=True, doc=field_docs["compare"], kw_only=True),
1189+
"kw_only": Field(default=False, doc=field_docs["kw_only"], kw_only=True),
11641190
}
11651191
modifications = {"__slots__": field_docs}
11661192

1167-
field_methods = {repr_maker, eq_maker, replace_maker}
1193+
field_methods = {repr_maker, eq_maker, _field_replace_maker}
11681194
if _UNDER_TESTING:
11691195
field_methods |= {frozen_setattr_maker, frozen_delattr_maker}
11701196

@@ -1448,6 +1474,23 @@ def check_argument_order(cls):
14481474
used_default = True
14491475

14501476

1477+
def replace(obj, /, **changes):
1478+
"""
1479+
Create a copy of a prefab instance with values provided to 'changes' replaced
1480+
1481+
:param obj: built class
1482+
:return: new built class instance with changes applied
1483+
"""
1484+
if not build_completed(type(obj)):
1485+
raise TypeError("replace() should be called on classbuilder class instances")
1486+
try:
1487+
replace_func = obj.__replace__
1488+
except AttributeError:
1489+
raise TypeError(f"{obj.__class__.__name__!r} does not support __replace__")
1490+
1491+
return replace_func(**changes)
1492+
1493+
14511494
# Class Decorators
14521495
def slotclass(cls=None, /, *, methods=default_methods, syntax_check=True):
14531496
"""

src/ducktools/classbuilder/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ def unified_gatherer(cls_or_ns: type | _CopiableMappings) -> _gatherer_returntyp
309309

310310
def check_argument_order(cls: type) -> None: ...
311311

312+
def replace(obj: _T, /, **changes: typing.Any) -> _T: ...
313+
312314
@typing.overload
313315
def slotclass(
314316
cls: _TypeT,

src/ducktools/classbuilder/prefab.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@
6767
from .annotations import get_func_annotations, is_type, replace_generic_with_arg
6868

6969
# These aren't used but are re-exported for ease of use
70-
from . import SlotFields as SlotFields, KW_ONLY as KW_ONLY
70+
from . import (
71+
KW_ONLY as KW_ONLY,
72+
SlotFields as SlotFields,
73+
replace as replace,
74+
)
7175

7276
PREFAB_FIELDS = "PREFAB_FIELDS"
7377
PREFAB_INIT_FUNC = "__prefab_init__"
@@ -123,7 +127,6 @@ def init_generator(cls, funcname="__init__"):
123127
attributes = get_attributes(cls)
124128
flags = get_flags(cls)
125129

126-
kw_only = flags.get("kw_only", False)
127130
frozen = flags.get("frozen", False)
128131
slotted = flags.get("slotted", False)
129132

@@ -218,7 +221,7 @@ def init_generator(cls, funcname="__init__"):
218221
globs[f"_{name}_factory"] = attrib.default_factory
219222
else:
220223
arg = name
221-
if attrib.kw_only or kw_only:
224+
if attrib.kw_only:
222225
kw_only_arglist.append(arg)
223226
else:
224227
pos_arglist.append(arg)
@@ -1038,20 +1041,3 @@ def as_dict(o):
10381041
for name, attrib in flds.items()
10391042
if attrib.serialize
10401043
}
1041-
1042-
1043-
def replace(obj, /, **changes):
1044-
"""
1045-
Create a copy of a prefab instance with values provided to 'changes' replaced
1046-
1047-
:param obj: prefab instance
1048-
:return: new prefab instance
1049-
"""
1050-
if not is_prefab_instance(obj):
1051-
raise TypeError("replace() should be called on prefab instances")
1052-
try:
1053-
replace_func = obj.__replace__
1054-
except AttributeError:
1055-
raise TypeError(f"{obj.__class__.__name__!r} does not support __replace__")
1056-
1057-
return replace_func(**changes)

src/ducktools/classbuilder/prefab.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ from . import (
1717
_SignatureMaker
1818
)
1919

20-
from . import SlotFields as SlotFields, KW_ONLY as KW_ONLY
20+
from . import (
21+
KW_ONLY as KW_ONLY,
22+
SlotFields as SlotFields,
23+
replace as replace,
24+
)
2125

2226
# noinspection PyUnresolvedReferences
2327
from . import _NothingType
@@ -262,5 +266,3 @@ def is_prefab(o: typing.Any) -> bool: ...
262266
def is_prefab_instance(o: object) -> bool: ...
263267

264268
def as_dict(o) -> dict[str, typing.Any]: ...
265-
266-
def replace(obj: _T, /, **changes: typing.Any) -> _T: ...

0 commit comments

Comments
 (0)