|
6035 | 6035 | { |
6036 | 6036 | "cell_type": "code", |
6037 | 6037 | "execution_count": null, |
6038 | | - "id": "2256bd09", |
| 6038 | + "id": "97f2fee5", |
6039 | 6039 | "metadata": {}, |
6040 | 6040 | "outputs": [], |
6041 | 6041 | "source": [ |
6042 | 6042 | "#| export\n", |
6043 | | - "def patch_to(cls, as_prop=False, cls_method=False, set_prop=False):\n", |
| 6043 | + "def _strip_patch_name(nm):\n", |
| 6044 | + " \"Strip trailing `__` from `nm` if it doesn't start with `_`\"\n", |
| 6045 | + " return nm[:-2] if nm.endswith('__') and not nm.startswith('_') else nm\n", |
| 6046 | + "\n", |
| 6047 | + "def patch_to(cls, as_prop=False, cls_method=False, set_prop=False, nm=None):\n", |
6044 | 6048 | " \"Decorator: add `f` to `cls`\"\n", |
6045 | 6049 | " if not isinstance(cls, (tuple,list)): cls=(cls,)\n", |
6046 | 6050 | " def _inner(f):\n", |
6047 | 6051 | " for c_ in cls:\n", |
6048 | 6052 | " nf = copy_func(f)\n", |
6049 | | - " nm = f.__name__\n", |
| 6053 | + " fnm = nm or _strip_patch_name(f.__name__)\n", |
6050 | 6054 | " # `functools.update_wrapper` when passing patched function to `Pipeline`, so we do it manually\n", |
6051 | 6055 | " for o in functools.WRAPPER_ASSIGNMENTS: setattr(nf, o, getattr(f,o))\n", |
6052 | | - " nf.__qualname__ = f\"{c_.__name__}.{nm}\"\n", |
6053 | | - " if cls_method: setattr(c_, nm, _clsmethod(nf))\n", |
| 6056 | + " nf.__name__ = fnm\n", |
| 6057 | + " nf.__qualname__ = f\"{c_.__name__}.{fnm}\"\n", |
| 6058 | + " if cls_method: setattr(c_, fnm, _clsmethod(nf))\n", |
6054 | 6059 | " else:\n", |
6055 | | - " if set_prop: setattr(c_, nm, getattr(c_, nm).setter(nf))\n", |
6056 | | - " elif as_prop: setattr(c_, nm, property(nf))\n", |
| 6060 | + " if set_prop: setattr(c_, fnm, getattr(c_, fnm).setter(nf))\n", |
| 6061 | + " elif as_prop: setattr(c_, fnm, property(nf))\n", |
6057 | 6062 | " else:\n", |
6058 | | - " onm = '_orig_'+nm\n", |
6059 | | - " if hasattr(c_, nm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, nm))\n", |
6060 | | - " setattr(c_, nm, nf)\n", |
| 6063 | + " onm = '_orig_'+fnm\n", |
| 6064 | + " if hasattr(c_, fnm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, fnm))\n", |
| 6065 | + " setattr(c_, fnm, nf)\n", |
6061 | 6066 | " # Avoid clobbering existing functions\n", |
6062 | | - " return globals().get(nm, builtins.__dict__.get(nm, None))\n", |
| 6067 | + " return globals().get(fnm, builtins.__dict__.get(fnm, None))\n", |
6063 | 6068 | " return _inner" |
6064 | 6069 | ] |
6065 | 6070 | }, |
|
6215 | 6220 | "test_eq(t.func_mult(4), 8)" |
6216 | 6221 | ] |
6217 | 6222 | }, |
| 6223 | + { |
| 6224 | + "cell_type": "markdown", |
| 6225 | + "id": "521f1b1e", |
| 6226 | + "metadata": {}, |
| 6227 | + "source": [ |
| 6228 | + "You can also rename the function in the patched class:" |
| 6229 | + ] |
| 6230 | + }, |
6218 | 6231 | { |
6219 | 6232 | "cell_type": "code", |
6220 | 6233 | "execution_count": null, |
6221 | | - "id": "e4c74c53", |
| 6234 | + "id": "d50c188d", |
| 6235 | + "metadata": {}, |
| 6236 | + "outputs": [], |
| 6237 | + "source": [ |
| 6238 | + "class _T8(int): pass \n", |
| 6239 | + "\n", |
| 6240 | + "@patch_to(_T8, nm='add_value')\n", |
| 6241 | + "def func2(self, a): return self+a\n", |
| 6242 | + "\n", |
| 6243 | + "t = _T8(1)\n", |
| 6244 | + "test_eq(t.add_value(2), 3)\n", |
| 6245 | + "test_eq(_T8.add_value.__name__, 'add_value')\n", |
| 6246 | + "assert not hasattr(t, 'func2')" |
| 6247 | + ] |
| 6248 | + }, |
| 6249 | + { |
| 6250 | + "cell_type": "markdown", |
| 6251 | + "id": "81877f71", |
| 6252 | + "metadata": {}, |
| 6253 | + "source": [ |
| 6254 | + "A `__` suffix is stripped (unless there's also a `_` prefix):" |
| 6255 | + ] |
| 6256 | + }, |
| 6257 | + { |
| 6258 | + "cell_type": "code", |
| 6259 | + "execution_count": null, |
| 6260 | + "id": "cdd9dedb", |
| 6261 | + "metadata": {}, |
| 6262 | + "outputs": [], |
| 6263 | + "source": [ |
| 6264 | + "class _T9(int): pass \n", |
| 6265 | + "\n", |
| 6266 | + "@patch_to(_T9)\n", |
| 6267 | + "def func__(self, a): return self+a\n", |
| 6268 | + "\n", |
| 6269 | + "t = _T9(1)\n", |
| 6270 | + "test_eq(t.func(2), 3)\n", |
| 6271 | + "test_eq(_T9.func.__name__, 'func')\n", |
| 6272 | + "assert not hasattr(t, 'func__')" |
| 6273 | + ] |
| 6274 | + }, |
| 6275 | + { |
| 6276 | + "cell_type": "code", |
| 6277 | + "execution_count": null, |
| 6278 | + "id": "8faf7b86", |
6222 | 6279 | "metadata": {}, |
6223 | 6280 | "outputs": [], |
6224 | 6281 | "source": [ |
6225 | 6282 | "#| export\n", |
6226 | | - "def patch(f=None, *, as_prop=False, cls_method=False, set_prop=False):\n", |
| 6283 | + "def patch(f=None, *, as_prop=False, cls_method=False, set_prop=False, nm=None):\n", |
6227 | 6284 | " \"Decorator: add `f` to the first parameter's class (based on f's type annotations)\"\n", |
6228 | | - " if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop)\n", |
| 6285 | + " if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop, nm=nm)\n", |
6229 | 6286 | " ann,glb,loc = get_annotations_ex(f)\n", |
6230 | | - " cls = union2tuple(eval_type(ann.pop('cls') if cls_method else next(iter(ann.values())), glb, loc))\n", |
6231 | | - " return patch_to(cls, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop)(f)" |
| 6287 | + " if cls_method:\n", |
| 6288 | + " if 'cls' not in ann: raise TypeError(f\"@patch with cls_method=True requires 'cls' to have a type annotation\")\n", |
| 6289 | + " cls = ann.pop('cls')\n", |
| 6290 | + " else:\n", |
| 6291 | + " if not ann: raise TypeError(f\"@patch requires the first parameter of `{f.__name__}` to have a type annotation\")\n", |
| 6292 | + " cls = next(iter(ann.values()))\n", |
| 6293 | + " cls = union2tuple(eval_type(cls, glb, loc))\n", |
| 6294 | + " return patch_to(cls, as_prop=as_prop, cls_method=cls_method, set_prop=set_prop, nm=nm)(f)" |
6232 | 6295 | ] |
6233 | 6296 | }, |
6234 | 6297 | { |
|
6348 | 6411 | { |
6349 | 6412 | "cell_type": "code", |
6350 | 6413 | "execution_count": null, |
6351 | | - "id": "f05267a8", |
| 6414 | + "id": "591af803", |
6352 | 6415 | "metadata": {}, |
6353 | 6416 | "outputs": [], |
6354 | 6417 | "source": [ |
6355 | | - "#| export\n", |
6356 | | - "def patch_property(f):\n", |
6357 | | - " \"Deprecated; use `patch(as_prop=True)` instead\"\n", |
6358 | | - " warnings.warn(\"`patch_property` is deprecated and will be removed; use `patch(as_prop=True)` instead\")\n", |
6359 | | - " cls = next(iter(f.__annotations__.values()))\n", |
6360 | | - " return patch_to(cls, as_prop=True)(f)" |
| 6418 | + "class _T8(int): pass \n", |
| 6419 | + "\n", |
| 6420 | + "@patch(nm='add_value')\n", |
| 6421 | + "def func2(self:_T8, a): return self+a\n", |
| 6422 | + "\n", |
| 6423 | + "t = _T8(1)\n", |
| 6424 | + "test_eq(t.add_value(2), 3)\n", |
| 6425 | + "test_eq(_T8.add_value.__name__, 'add_value')\n", |
| 6426 | + "assert not hasattr(t, 'func2')" |
| 6427 | + ] |
| 6428 | + }, |
| 6429 | + { |
| 6430 | + "cell_type": "code", |
| 6431 | + "execution_count": null, |
| 6432 | + "id": "fdbfff66", |
| 6433 | + "metadata": {}, |
| 6434 | + "outputs": [], |
| 6435 | + "source": [ |
| 6436 | + "class _T9(int): pass \n", |
| 6437 | + "\n", |
| 6438 | + "@patch\n", |
| 6439 | + "def func__(self:_T9, a): return self+a\n", |
| 6440 | + "\n", |
| 6441 | + "t = _T9(1)\n", |
| 6442 | + "test_eq(t.func(2), 3)\n", |
| 6443 | + "test_eq(_T9.func.__name__, 'func')\n", |
| 6444 | + "assert not hasattr(t, 'func__')" |
6361 | 6445 | ] |
6362 | 6446 | }, |
6363 | 6447 | { |
|
0 commit comments