Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 80 additions & 37 deletions src/ducktools/classbuilder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,12 @@ def __get__(self, instance, cls=None):
def get_init_generator(null=NOTHING, extra_code=None):
def cls_init_maker(cls, funcname="__init__"):
fields = get_fields(cls)
flags = get_flags(cls)

arglist = []
kw_only_arglist = []
assignments = []
globs = {}

kw_only_flag = flags.get("kw_only", False)

for k, v in fields.items():
if v.init:
if v.default is not null:
Expand All @@ -357,7 +354,7 @@ def cls_init_maker(cls, funcname="__init__"):
arg = f"{k}"
assignment = f"self.{k} = {k}"

if kw_only_flag or v.kw_only:
if v.kw_only:
kw_only_arglist.append(arg)
else:
arglist.append(arg)
Expand Down Expand Up @@ -548,31 +545,45 @@ def ge_generator(cls, funcname="__ge__"):
return get_order_generator(cls, funcname, operator=">=")


def replace_generator(cls, funcname="__replace__"):
# Generate the replace method for built classes
# unlike the dataclasses implementation this is generated
attribs = get_fields(cls)
def _get_replace_generator(private_type=False):
def cls_replace_generator(cls, funcname="__replace__"):
# Generate the replace method for built classes
# unlike the dataclasses implementation this is generated
attribs = get_fields(cls)

# This is essentially the as_dict generator for prefabs
# except based on attrib.init instead of .serialize
if private_type:
vals = ", ".join(
f"'{name}': self.{name}"
if name != "type"
else f"'{name}': self._{name}"
for name, attrib in attribs.items()
if attrib.init
)
else:
vals = ", ".join(
f"'{name}': self.{name}"
for name, attrib in attribs.items()
if attrib.init
)
init_dict = f"{{{vals}}}"

# This is essentially the as_dict generator for prefabs
# except based on attrib.init instead of .serialize
vals = ", ".join(
f"'{name}': self.{name}"
for name, attrib in attribs.items()
if attrib.init
)
init_dict = f"{{{vals}}}"
# fmt: off
code = (
f"def {funcname}(self, /, **changes):\n"
f" new_kwargs = {init_dict}\n"
f" new_kwargs |= changes\n"
f" return self.__class__(**new_kwargs)\n"
)
# fmt: on
globs = {}
return GeneratedCode(code, globs)

# fmt: off
code = (
f"def {funcname}(self, /, **changes):\n"
f" new_kwargs = {init_dict}\n"
f" new_kwargs |= changes\n"
f" return self.__class__(**new_kwargs)\n"
)
# fmt: on
globs = {}
return GeneratedCode(code, globs)
return cls_replace_generator

replace_generator = _get_replace_generator()
_field_replace_generator = _get_replace_generator(private_type=True)

def frozen_setattr_generator(cls, funcname="__setattr__"):
globs = {}
Expand Down Expand Up @@ -657,6 +668,10 @@ def hash_generator(cls, funcname="__hash__"):
)
)

# Special `__replace__` method for `Field` that will use the internal `_type`
# value instead of the resolved `type` property
_field_replace_maker = MethodMaker("__replace__", _field_replace_generator)


def add_methods(cls, methods, *, internals=None):
"""
Expand Down Expand Up @@ -730,13 +745,21 @@ def builder(cls=None, /, *, gatherer, methods, flags=None, fix_signature=True, f
flag_dict |= flags
internals["flags"] = flag_dict

kw_only = flag_dict.get("kw_only", False)
cls_gathered = cls.__dict__.get(GATHERED_DATA)

if cls_gathered:
cls_fields, modifications = cls_gathered
else:
cls_fields, modifications = gatherer(cls)

if kw_only:
# Update the class fields to make all Fields kw_only
cls_fields = {
k: v if v.kw_only else v.__replace__(kw_only=True)
for k, v in cls_fields.items()
}

for name, value in modifications.items():
if value is NOTHING:
delattr(cls, name)
Expand Down Expand Up @@ -1075,7 +1098,7 @@ def __init__(

def __init_subclass__(cls, frozen=False, ignore_annotations=False):
# Subclasses of Field can be created as if they are dataclasses
field_methods = {_field_init_maker, repr_maker, eq_maker, replace_maker}
field_methods = {_field_init_maker, repr_maker, eq_maker, _field_replace_maker}
if frozen or _UNDER_TESTING:
field_methods |= {frozen_setattr_maker, frozen_delattr_maker}

Expand Down Expand Up @@ -1110,8 +1133,9 @@ def from_field(cls, fld, /, **kwargs):
:param kwargs: Additional keyword arguments for subclasses
:return: new field subclass instance
"""
# type is special cased to get the internal value
inst_fields = {
k: getattr(fld, k)
k: getattr(fld, k) if k != "type" else getattr(fld, "_type")
for k in get_fields(type(fld))
}
argument_dict = {**inst_fields, **kwargs}
Expand Down Expand Up @@ -1152,19 +1176,21 @@ def _build_field():
"kw_only": "Make this a keyword only parameter in __init__",
}

# Fields here must be marked as kw_only to prevent the builder from trying
# to call the __replace__ method which doesn't exist yet
fields = {
"default": Field(default=NOTHING, doc=field_docs["default"]),
"default_factory": Field(default=NOTHING, doc=field_docs["default_factory"]),
"type": Field(default=NOTHING, doc=field_docs["type"]),
"doc": Field(default=None, doc=field_docs["doc"]),
"init": Field(default=True, doc=field_docs["init"]),
"repr": Field(default=True, doc=field_docs["repr"]),
"compare": Field(default=True, doc=field_docs["compare"]),
"kw_only": Field(default=False, doc=field_docs["kw_only"]),
"default": Field(default=NOTHING, doc=field_docs["default"], kw_only=True),
"default_factory": Field(default=NOTHING, doc=field_docs["default_factory"], kw_only=True),
"type": Field(default=NOTHING, doc=field_docs["type"], kw_only=True),
"doc": Field(default=None, doc=field_docs["doc"], kw_only=True),
"init": Field(default=True, doc=field_docs["init"], kw_only=True),
"repr": Field(default=True, doc=field_docs["repr"], kw_only=True),
"compare": Field(default=True, doc=field_docs["compare"], kw_only=True),
"kw_only": Field(default=False, doc=field_docs["kw_only"], kw_only=True),
}
modifications = {"__slots__": field_docs}

field_methods = {repr_maker, eq_maker, replace_maker}
field_methods = {repr_maker, eq_maker, _field_replace_maker}
if _UNDER_TESTING:
field_methods |= {frozen_setattr_maker, frozen_delattr_maker}

Expand Down Expand Up @@ -1448,6 +1474,23 @@ def check_argument_order(cls):
used_default = True


def replace(obj, /, **changes):
"""
Create a copy of a prefab instance with values provided to 'changes' replaced

:param obj: built class
:return: new built class instance with changes applied
"""
if not build_completed(type(obj)):
raise TypeError("replace() should be called on classbuilder class instances")
try:
replace_func = obj.__replace__
except AttributeError:
raise TypeError(f"{obj.__class__.__name__!r} does not support __replace__")

return replace_func(**changes)


# Class Decorators
def slotclass(cls=None, /, *, methods=default_methods, syntax_check=True):
"""
Expand Down
2 changes: 2 additions & 0 deletions src/ducktools/classbuilder/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ def unified_gatherer(cls_or_ns: type | _CopiableMappings) -> _gatherer_returntyp

def check_argument_order(cls: type) -> None: ...

def replace(obj: _T, /, **changes: typing.Any) -> _T: ...

@typing.overload
def slotclass(
cls: _TypeT,
Expand Down
26 changes: 6 additions & 20 deletions src/ducktools/classbuilder/prefab.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@
from .annotations import get_func_annotations, is_type, replace_generic_with_arg

# These aren't used but are re-exported for ease of use
from . import SlotFields as SlotFields, KW_ONLY as KW_ONLY
from . import (
KW_ONLY as KW_ONLY,
SlotFields as SlotFields,
replace as replace,
)

PREFAB_FIELDS = "PREFAB_FIELDS"
PREFAB_INIT_FUNC = "__prefab_init__"
Expand Down Expand Up @@ -123,7 +127,6 @@ def init_generator(cls, funcname="__init__"):
attributes = get_attributes(cls)
flags = get_flags(cls)

kw_only = flags.get("kw_only", False)
frozen = flags.get("frozen", False)
slotted = flags.get("slotted", False)

Expand Down Expand Up @@ -218,7 +221,7 @@ def init_generator(cls, funcname="__init__"):
globs[f"_{name}_factory"] = attrib.default_factory
else:
arg = name
if attrib.kw_only or kw_only:
if attrib.kw_only:
kw_only_arglist.append(arg)
else:
pos_arglist.append(arg)
Expand Down Expand Up @@ -1038,20 +1041,3 @@ def as_dict(o):
for name, attrib in flds.items()
if attrib.serialize
}


def replace(obj, /, **changes):
"""
Create a copy of a prefab instance with values provided to 'changes' replaced

:param obj: prefab instance
:return: new prefab instance
"""
if not is_prefab_instance(obj):
raise TypeError("replace() should be called on prefab instances")
try:
replace_func = obj.__replace__
except AttributeError:
raise TypeError(f"{obj.__class__.__name__!r} does not support __replace__")

return replace_func(**changes)
8 changes: 5 additions & 3 deletions src/ducktools/classbuilder/prefab.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ from . import (
_SignatureMaker
)

from . import SlotFields as SlotFields, KW_ONLY as KW_ONLY
from . import (
KW_ONLY as KW_ONLY,
SlotFields as SlotFields,
replace as replace,
)

# noinspection PyUnresolvedReferences
from . import _NothingType
Expand Down Expand Up @@ -262,5 +266,3 @@ def is_prefab(o: typing.Any) -> bool: ...
def is_prefab_instance(o: object) -> bool: ...

def as_dict(o) -> dict[str, typing.Any]: ...

def replace(obj: _T, /, **changes: typing.Any) -> _T: ...
Loading