Skip to content

Commit d62e80b

Browse files
committed
feat: Add support for paramless function variants
1 parent 407adcc commit d62e80b

2 files changed

Lines changed: 23 additions & 2 deletions

File tree

src/fieldenum/_fieldenum.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def __call__(self, *args, **kwargs):
276276

277277

278278
class _FunctionVariant(Variant): # MARK: FunctionVariant
279-
__slots__ = ("_func", "_signature", "_match_args", "_self_included")
279+
__slots__ = ("_func", "_signature", "_match_args", "_self_included", "_paramless")
280280
name: str
281281

282282
def __init__(self, func: types.FunctionType) -> None:
@@ -285,6 +285,7 @@ def __init__(self, func: types.FunctionType) -> None:
285285
self._func = func
286286
signature = inspect.signature(func)
287287
parameters_raw = signature.parameters
288+
self._paramless = not parameters_raw or len(parameters_raw) == 1 and "self" in parameters_raw
288289
self._signature = signature
289290

290291
parameters_iter = iter(parameters_raw)
@@ -326,8 +327,13 @@ def attach(
326327
self._base = cls
327328
item = self
328329

330+
if self._paramless:
331+
meta = dict(metaclass=ParamlessSingletonMeta)
332+
else:
333+
meta = {}
334+
329335
# fmt: off
330-
class ConstructedVariant(cls):
336+
class ConstructedVariant(cls, **meta):
331337
if frozen and not typing.TYPE_CHECKING:
332338
__slots__ = tuple(f"__original_{name}" for name in item._slots_names)
333339
for name in item._slots_names:

tests/test_fieldenum.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def KwargsOnlyFuncVariant(*, a, b, c, d=None):
6060
def ParamlessFuncVariantWithBody(self):
6161
pass
6262

63+
@variant
64+
def ParamlessFuncVariantWithoutBody(self):
65+
pass
66+
6367

6468
def test_unreachable():
6569
with pytest.raises(UnreachableError, match="Unexpected type 'str' of 'hello'"):
@@ -212,6 +216,17 @@ def test_relationship():
212216
assert isinstance(Message.Pause(), Message)
213217

214218

219+
def test_fieldless_function():
220+
# with pytest.raises(TypeError, match="unhashable type:"):
221+
# {Message.ParamlessFuncVariantWithBody()}
222+
223+
# with pytest.raises(TypeError, match="unhashable type:"):
224+
# {Message.ParamlessFuncVariantWithoutBody()}
225+
226+
assert Message.ParamlessFuncVariantWithBody() is Message.ParamlessFuncVariantWithBody()
227+
assert Message.ParamlessFuncVariantWithoutBody() is Message.ParamlessFuncVariantWithoutBody()
228+
229+
215230
def test_instancing():
216231
type MyTypeAlias = int | str | bytes
217232

0 commit comments

Comments
 (0)