Skip to content

Commit 4606aa7

Browse files
committed
fixes #809
1 parent b8c9ed7 commit 4606aa7

3 files changed

Lines changed: 131 additions & 41 deletions

File tree

fastcore/_modidx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
'fastcore.basics.exec_import': ('basics.html#exec_import', 'fastcore/basics.py'),
139139
'fastcore.basics.exec_local': ('basics.html#exec_local', 'fastcore/basics.py'),
140140
'fastcore.basics.exec_new': ('basics.html#exec_new', 'fastcore/basics.py'),
141+
'fastcore.basics.extend_enum': ('basics.html#extend_enum', 'fastcore/basics.py'),
141142
'fastcore.basics.fastuple': ('basics.html#fastuple', 'fastcore/basics.py'),
142143
'fastcore.basics.fastuple.__new__': ('basics.html#fastuple.__new__', 'fastcore/basics.py'),
143144
'fastcore.basics.fastuple._op': ('basics.html#fastuple._op', 'fastcore/basics.py'),

fastcore/basics.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
'last', 'only', 'nested_attr', 'nested_setdefault', 'nested_callable', 'nested_idx', 'set_nested_idx',
1818
'val2idx', 'uniqueify', 'loop_first_last', 'loop_first', 'loop_last', 'first_match', 'last_match', 'joins',
1919
'fastuple', 'bind', 'mapt', 'map_ex', 'compose', 'maps', 'partialler', 'instantiate', 'using_attr', 'negate',
20-
'spread', 'dspread', 'copy_func', 'patch_to', 'patch', 'compile_re', 'ImportEnum', 'StrEnum', 'str_enum',
21-
'ValEnum', 'Stateful', 'NotStr', 'PrettyString', 'even_mults', 'num_cpus', 'add_props', 'str2bool',
22-
'str2int', 'str2float', 'str2list', 'str2date', 'to_bool', 'to_int', 'to_float', 'to_list', 'to_date',
23-
'typed', 'exec_new', 'exec_import', 'sig_with_params', 'fdelegates', 'lt', 'gt', 'le', 'ge', 'eq', 'ne',
24-
'add', 'sub', 'mul', 'truediv', 'is_', 'is_not', 'mod']
20+
'spread', 'dspread', 'copy_func', 'patch_to', 'patch', 'extend_enum', 'compile_re', 'ImportEnum', 'StrEnum',
21+
'str_enum', 'ValEnum', 'Stateful', 'NotStr', 'PrettyString', 'even_mults', 'num_cpus', 'add_props',
22+
'str2bool', 'str2int', 'str2float', 'str2list', 'str2date', 'to_bool', 'to_int', 'to_float', 'to_list',
23+
'to_date', 'typed', 'exec_new', 'exec_import', 'sig_with_params', 'fdelegates', 'lt', 'gt', 'le', 'ge', 'eq',
24+
'ne', 'add', 'sub', 'mul', 'truediv', 'is_', 'is_not', 'mod']
2525

2626
# %% ../nbs/01_basics.ipynb #0e91ed82
2727
from .imports import *
@@ -1098,7 +1098,7 @@ def __init__(self, f): self.f = f
10981098
def __get__(self, _, f_cls): return MethodType(self.f, f_cls)
10991099

11001100
# %% ../nbs/01_basics.ipynb #3f2733ef
1101-
def patch_to(cls, as_prop=False, cls_method=False, set_prop=False, nm=None, glb=None):
1101+
def patch_to(cls, as_prop=False, cls_method=False, set_prop=False, static_method=False, nm=None, glb=None):
11021102
"Decorator: add `f` to `cls`"
11031103
if glb is None: glb = sys._getframe(1).f_globals
11041104
def _inner(f):
@@ -1111,6 +1111,7 @@ def _inner(f):
11111111
nf.__qualname__ = f"{c_.__name__}.{_nm}"
11121112
if hasattr(c_, _nm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, _nm))
11131113
if cls_method: attr = _clsmethod(nf)
1114+
elif static_method: attr = staticmethod(nf)
11141115
elif set_prop: attr = getattr(c_, _nm).setter(nf)
11151116
elif as_prop: attr = property(nf)
11161117
else: attr = nf
@@ -1119,9 +1120,9 @@ def _inner(f):
11191120
return _inner
11201121

11211122
# %% ../nbs/01_basics.ipynb #8faf7b86
1122-
def patch(f=None, *, as_prop=False, cls_method=False, set_prop=False, nm=None):
1123+
def patch(f=None, *, as_prop=False, cls_method=False, static_method=False, set_prop=False, nm=None):
11231124
"Decorator: add `f` to the first parameter's class (based on f's type annotations)"
1124-
if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop, nm=nm)
1125+
if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method, static_method=static_method, set_prop=set_prop, nm=nm)
11251126
ann,glb,loc = get_annotations_ex(f)
11261127
if cls_method:
11271128
if 'cls' not in ann: raise TypeError(f"@patch with cls_method=True requires 'cls' to have a type annotation")
@@ -1130,7 +1131,25 @@ def patch(f=None, *, as_prop=False, cls_method=False, set_prop=False, nm=None):
11301131
if not ann: raise TypeError(f"@patch requires the first parameter of `{f.__name__}` to have a type annotation")
11311132
cls = next(iter(ann.values()))
11321133
cls = union2tuple(eval_type(cls, glb, loc))
1133-
return patch_to(cls, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop, nm=nm, glb=sys._getframe(1).f_globals)(f)
1134+
glbs = sys._getframe(1).f_globals
1135+
return patch_to(cls, as_prop=as_prop, cls_method=cls_method, static_method=static_method, set_prop=set_prop, nm=nm, glb=glbs)(f)
1136+
1137+
# %% ../nbs/01_basics.ipynb #c8805a92
1138+
def extend_enum(
1139+
cls, # Enum class to modify
1140+
n, # Name of the new enum member
1141+
v # Value of the new enum member
1142+
):
1143+
"Add new member `n` with value `v` to enum class `cls` at runtime"
1144+
if n in cls._member_map_: return cls[n]
1145+
typ = cls._member_type_
1146+
res = object.__new__(cls) if typ is object else typ.__new__(cls, v)
1147+
res._name_,res._value_ = n,v
1148+
cls._member_names_.append(n)
1149+
cls._member_map_[n] = res
1150+
cls._value2member_map_[v] = res
1151+
type.__setattr__(cls, n, res)
1152+
return res
11341153

11351154
# %% ../nbs/01_basics.ipynb #d1732261
11361155
def compile_re(pat):

nbs/01_basics.ipynb

Lines changed: 102 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -747,9 +747,7 @@
747747
" nm, fld_names:VAR_POSITIONAL, sup:NoneType=None, doc:NoneType=None, funcs:NoneType=None, anno:NoneType=None,\n",
748748
" flds:VAR_KEYWORD\n",
749749
"):\n",
750-
"\"\"\"\n",
751-
"Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`\n",
752-
"\"\"\""
750+
"\"\"\"Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`\"\"\""
753751
]
754752
},
755753
"execution_count": null,
@@ -950,9 +948,7 @@
950948
"def ignore_exceptions(\n",
951949
" args:VAR_POSITIONAL, kwargs:VAR_KEYWORD\n",
952950
"):\n",
953-
"\"\"\"\n",
954-
"Context manager to ignore exceptions\n",
955-
"\"\"\""
951+
"\"\"\"Context manager to ignore exceptions\"\"\""
956952
]
957953
},
958954
"execution_count": null,
@@ -1123,9 +1119,7 @@
11231119
"def noop(\n",
11241120
" x:NoneType=None, args:VAR_POSITIONAL, kwargs:VAR_KEYWORD\n",
11251121
"):\n",
1126-
"\"\"\"\n",
1127-
"Do nothing\n",
1128-
"\"\"\""
1122+
"\"\"\"Do nothing\"\"\""
11291123
]
11301124
},
11311125
"execution_count": null,
@@ -1176,9 +1170,7 @@
11761170
"def noops(\n",
11771171
" x:NoneType=None, args:VAR_POSITIONAL, kwargs:VAR_KEYWORD\n",
11781172
"):\n",
1179-
"\"\"\"\n",
1180-
"Do nothing (method)\n",
1181-
"\"\"\""
1173+
"\"\"\"Do nothing (method)\"\"\""
11821174
]
11831175
},
11841176
"execution_count": null,
@@ -2028,7 +2020,7 @@
20282020
{
20292021
"data": {
20302022
"text/plain": [
2031-
"typing.Union[__main__._T2a, __main__._T2b]"
2023+
"__main__._T2a | __main__._T2b"
20322024
]
20332025
},
20342026
"execution_count": null,
@@ -3305,9 +3297,7 @@
33053297
"def GetAttr(\n",
33063298
" args:VAR_POSITIONAL, kwargs:VAR_KEYWORD\n",
33073299
"):\n",
3308-
"\"\"\"\n",
3309-
"Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`\n",
3310-
"\"\"\""
3300+
"\"\"\"Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`\"\"\""
33113301
]
33123302
},
33133303
"execution_count": null,
@@ -5280,9 +5270,7 @@
52805270
"def fastuple(\n",
52815271
" args:VAR_POSITIONAL, kwargs:VAR_KEYWORD\n",
52825272
"):\n",
5283-
"\"\"\"\n",
5284-
"A `tuple` with elementwise ops and more friendly __init__ behavior\n",
5285-
"\"\"\""
5273+
"\"\"\"A `tuple` with elementwise ops and more friendly __init__ behavior\"\"\""
52865274
]
52875275
},
52885276
"execution_count": null,
@@ -5374,9 +5362,7 @@
53745362
"def add(\n",
53755363
" args:VAR_POSITIONAL\n",
53765364
"):\n",
5377-
"\"\"\"\n",
5378-
"`+` is already defined in `tuple` for concat, so use `add` instead\n",
5379-
"\"\"\""
5365+
"\"\"\"`+` is already defined in `tuple` for concat, so use `add` instead\"\"\""
53805366
]
53815367
},
53825368
"execution_count": null,
@@ -5430,9 +5416,7 @@
54305416
"def mul(\n",
54315417
" args:VAR_POSITIONAL\n",
54325418
"):\n",
5433-
"\"\"\"\n",
5434-
"`*` is already defined in `tuple` for replicating, so use `mul` instead\n",
5435-
"\"\"\""
5419+
"\"\"\"`*` is already defined in `tuple` for replicating, so use `mul` instead\"\"\""
54365420
]
54375421
},
54385422
"execution_count": null,
@@ -5607,9 +5591,7 @@
56075591
"def bind(\n",
56085592
" func, pargs:VAR_POSITIONAL, pkwargs:VAR_KEYWORD\n",
56095593
"):\n",
5610-
"\"\"\"\n",
5611-
"Same as `partial`, except you can use `arg0` `arg1` etc param placeholders\n",
5612-
"\"\"\""
5594+
"\"\"\"Same as `partial`, except you can use `arg0` `arg1` etc param placeholders\"\"\""
56135595
]
56145596
},
56155597
"execution_count": null,
@@ -6508,7 +6490,7 @@
65086490
"outputs": [],
65096491
"source": [
65106492
"#| export\n",
6511-
"def patch_to(cls, as_prop=False, cls_method=False, set_prop=False, nm=None, glb=None):\n",
6493+
"def patch_to(cls, as_prop=False, cls_method=False, set_prop=False, static_method=False, nm=None, glb=None):\n",
65126494
" \"Decorator: add `f` to `cls`\"\n",
65136495
" if glb is None: glb = sys._getframe(1).f_globals\n",
65146496
" def _inner(f):\n",
@@ -6521,6 +6503,7 @@
65216503
" nf.__qualname__ = f\"{c_.__name__}.{_nm}\"\n",
65226504
" if hasattr(c_, _nm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, _nm))\n",
65236505
" if cls_method: attr = _clsmethod(nf)\n",
6506+
" elif static_method: attr = staticmethod(nf)\n",
65246507
" elif set_prop: attr = getattr(c_, _nm).setter(nf)\n",
65256508
" elif as_prop: attr = property(nf)\n",
65266509
" else: attr = nf\n",
@@ -6715,9 +6698,9 @@
67156698
"outputs": [],
67166699
"source": [
67176700
"#| export\n",
6718-
"def patch(f=None, *, as_prop=False, cls_method=False, set_prop=False, nm=None):\n",
6701+
"def patch(f=None, *, as_prop=False, cls_method=False, static_method=False, set_prop=False, nm=None):\n",
67196702
" \"Decorator: add `f` to the first parameter's class (based on f's type annotations)\"\n",
6720-
" if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop, nm=nm)\n",
6703+
" if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method, static_method=static_method, set_prop=set_prop, nm=nm)\n",
67216704
" ann,glb,loc = get_annotations_ex(f)\n",
67226705
" if cls_method:\n",
67236706
" if 'cls' not in ann: raise TypeError(f\"@patch with cls_method=True requires 'cls' to have a type annotation\")\n",
@@ -6726,7 +6709,8 @@
67266709
" if not ann: raise TypeError(f\"@patch requires the first parameter of `{f.__name__}` to have a type annotation\")\n",
67276710
" cls = next(iter(ann.values()))\n",
67286711
" cls = union2tuple(eval_type(cls, glb, loc))\n",
6729-
" return patch_to(cls, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop, nm=nm, glb=sys._getframe(1).f_globals)(f)"
6712+
" glbs = sys._getframe(1).f_globals\n",
6713+
" return patch_to(cls, as_prop=as_prop, cls_method=cls_method, static_method=static_method, set_prop=set_prop, nm=nm, glb=glbs)(f)"
67306714
]
67316715
},
67326716
{
@@ -6754,6 +6738,25 @@
67546738
"test_eq(t.func.__qualname__, '_T8.func')"
67556739
]
67566740
},
6741+
{
6742+
"cell_type": "code",
6743+
"execution_count": null,
6744+
"id": "5583f90f",
6745+
"metadata": {},
6746+
"outputs": [],
6747+
"source": [
6748+
"class MyMath: pass\n",
6749+
"\n",
6750+
"@patch_to(MyMath, static_method=True)\n",
6751+
"def add(a, b): return a + b\n",
6752+
"\n",
6753+
"@patch(static_method=True)\n",
6754+
"def mul(a:MyMath, b): return a * b\n",
6755+
"\n",
6756+
"test_eq(MyMath.add(2, 3), 5)\n",
6757+
"test_eq(MyMath.mul(2, 3), 6)"
6758+
]
6759+
},
67576760
{
67586761
"cell_type": "markdown",
67596762
"id": "d96ec34b",
@@ -6898,6 +6901,73 @@
68986901
"## Other Helpers"
68996902
]
69006903
},
6904+
{
6905+
"cell_type": "code",
6906+
"execution_count": null,
6907+
"id": "c8805a92",
6908+
"metadata": {},
6909+
"outputs": [],
6910+
"source": [
6911+
"#| export\n",
6912+
"def extend_enum(\n",
6913+
" cls, # Enum class to modify\n",
6914+
" n, # Name of the new enum member\n",
6915+
" v # Value of the new enum member\n",
6916+
"):\n",
6917+
" \"Add new member `n` with value `v` to enum class `cls` at runtime\"\n",
6918+
" if n in cls._member_map_: return cls[n]\n",
6919+
" typ = cls._member_type_\n",
6920+
" res = object.__new__(cls) if typ is object else typ.__new__(cls, v)\n",
6921+
" res._name_,res._value_ = n,v\n",
6922+
" cls._member_names_.append(n)\n",
6923+
" cls._member_map_[n] = res\n",
6924+
" cls._value2member_map_[v] = res\n",
6925+
" type.__setattr__(cls, n, res)\n",
6926+
" return res"
6927+
]
6928+
},
6929+
{
6930+
"cell_type": "markdown",
6931+
"id": "c8d11c83",
6932+
"metadata": {},
6933+
"source": [
6934+
"`extend_enum` mutates an existing enum class by constructing a new member, registering it in the enum’s internal lookup tables, and attaching it as a class attribute, so it behaves like a normal enum member created in the original class definition."
6935+
]
6936+
},
6937+
{
6938+
"cell_type": "code",
6939+
"execution_count": null,
6940+
"id": "c0ded5de",
6941+
"metadata": {},
6942+
"outputs": [],
6943+
"source": [
6944+
"from enum import Enum"
6945+
]
6946+
},
6947+
{
6948+
"cell_type": "code",
6949+
"execution_count": null,
6950+
"id": "190756da",
6951+
"metadata": {},
6952+
"outputs": [
6953+
{
6954+
"data": {
6955+
"text/plain": [
6956+
"(<Color.green: 3>, <Color.green: 3>, <Color.green: 3>)"
6957+
]
6958+
},
6959+
"execution_count": null,
6960+
"metadata": {},
6961+
"output_type": "execute_result"
6962+
}
6963+
],
6964+
"source": [
6965+
"class Color(Enum): red = 1; blue = 2\n",
6966+
"\n",
6967+
"extend_enum(Color, 'green', 3)\n",
6968+
"Color.green, Color['green'], Color(3)"
6969+
]
6970+
},
69016971
{
69026972
"cell_type": "code",
69036973
"execution_count": null,

0 commit comments

Comments
 (0)