|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +"""deep_reloaderの初期型""" |
| 4 | + |
| 5 | +import _ast |
| 6 | +import ast |
| 7 | +import importlib |
| 8 | +import inspect |
| 9 | +import shutil |
| 10 | +import sys |
| 11 | +from pathlib import Path |
| 12 | +from types import ModuleType |
| 13 | +from typing import Any, Dict, List, Tuple, cast |
| 14 | + |
| 15 | +# ref. https://graphics.hatenablog.com/entry/2019/12/22/052819 |
| 16 | + |
| 17 | +__package_name = '' |
| 18 | + |
| 19 | + |
| 20 | +def module_reloader(module: ModuleType) -> None: |
| 21 | + """deep_reloaderの初期型 |
| 22 | +
|
| 23 | + Args: |
| 24 | + module: リロード対象のモジュール |
| 25 | + """ |
| 26 | + global __package_name |
| 27 | + |
| 28 | + # モジュール名からパッケージ名を自動推定 |
| 29 | + module_name = module.__name__ |
| 30 | + if '.' in module_name: |
| 31 | + # パッケージの一部の場合は、最上位パッケージ名を使用 |
| 32 | + __package_name = module_name.split('.')[0] |
| 33 | + else: |
| 34 | + # トップレベルモジュールの場合はモジュール名をそのまま使用 |
| 35 | + __package_name = module_name |
| 36 | + |
| 37 | + _delete_modules() |
| 38 | + |
| 39 | + from_import_symbols: List[Tuple[ModuleType, Dict[ModuleType, List[str]]]] = _get_symbols(module) |
| 40 | + |
| 41 | + parent: ModuleType |
| 42 | + children_symbols: Dict[ModuleType, List[str]] |
| 43 | + for parent, children_symbols in from_import_symbols: |
| 44 | + _reload(children_symbols) |
| 45 | + _overwrite_with_reloaded_symbols(parent, children_symbols) |
| 46 | + |
| 47 | + |
| 48 | +def _delete_modules() -> None: |
| 49 | + global __package_name |
| 50 | + |
| 51 | + # パッケージ名に基づいてsys.modulesからモジュールを削除 |
| 52 | + for module_name in list(sys.modules.keys()): |
| 53 | + if module_name.startswith(__package_name): |
| 54 | + del sys.modules[module_name] |
| 55 | + |
| 56 | + |
| 57 | +def _get_symbols(parent: ModuleType) -> List[Tuple[ModuleType, Dict[ModuleType, List[str]]]]: |
| 58 | + children_symbols: Dict[ModuleType, List[str]] = get_children_symbols(parent) |
| 59 | + result = [] |
| 60 | + for child_module in children_symbols.keys(): |
| 61 | + result.extend(_get_symbols(child_module)) |
| 62 | + result.append((parent, children_symbols)) |
| 63 | + return result |
| 64 | + |
| 65 | + |
| 66 | +def get_children_symbols(module: ModuleType): |
| 67 | + children_symbols: Dict[ModuleType, List[Any]] = {} |
| 68 | + |
| 69 | + try: |
| 70 | + source = inspect.getsource(module) |
| 71 | + except Exception: |
| 72 | + # ソースコードが取得できない場合(組み込みモジュールなど)はスキップ |
| 73 | + return children_symbols |
| 74 | + |
| 75 | + tree: _ast.Module = ast.parse(source) |
| 76 | + |
| 77 | + stmt: _ast.stmt |
| 78 | + for stmt in tree.body: |
| 79 | + # TODO: import xxx の場合のサポートも必要? |
| 80 | + # from xxx import でないならcontinue |
| 81 | + if stmt.__class__ != _ast.ImportFrom: |
| 82 | + continue |
| 83 | + |
| 84 | + imp_frm = cast(_ast.ImportFrom, stmt) |
| 85 | + |
| 86 | + # モジュール名を取得(相対インポートの場合の特別処理を含む) |
| 87 | + module_name = imp_frm.module |
| 88 | + |
| 89 | + # モジュールのフルネームを取得 |
| 90 | + if imp_frm.level == 0: |
| 91 | + # 絶対インポート: from module import something |
| 92 | + if module_name is None: |
| 93 | + continue |
| 94 | + module_full_name = f'{module_name}' |
| 95 | + elif imp_frm.level == 1: |
| 96 | + # 同階層相対インポート: from .module import something |
| 97 | + if module_name is None: |
| 98 | + # from . import something (現在のパッケージから直接インポート) |
| 99 | + module_full_name = module.__package__ |
| 100 | + else: |
| 101 | + # from .module import something |
| 102 | + module_full_name = f'{module.__package__}.{module_name}' |
| 103 | + elif imp_frm.level >= 2: |
| 104 | + # 上位階層相対インポート: from ..module import something |
| 105 | + package_names = module.__package__.split('.') |
| 106 | + package_names = package_names[: -(imp_frm.level - 1)] |
| 107 | + package_names = '.'.join(package_names) |
| 108 | + if module_name is None: |
| 109 | + # from .. import something (上位パッケージから直接インポート) |
| 110 | + module_full_name = package_names |
| 111 | + else: |
| 112 | + # from ..module import something |
| 113 | + module_full_name = f'{package_names}.{module_name}' |
| 114 | + else: |
| 115 | + raise Exception('module_reloaderにて例外が発生しました。ソースコードを確認してください') |
| 116 | + |
| 117 | + # リロード対象ではないならcontinue |
| 118 | + global __package_name |
| 119 | + if not module_full_name.startswith(__package_name): |
| 120 | + continue |
| 121 | + |
| 122 | + try: |
| 123 | + new_module: ModuleType = importlib.import_module(module_full_name) |
| 124 | + except Exception: |
| 125 | + # インポートに失敗した場合はスキップ |
| 126 | + continue |
| 127 | + |
| 128 | + # packageならスキップ(フリーズ防止のため重要) |
| 129 | + if _is_package(new_module): |
| 130 | + # NOTE: from xxx import yyy のyyyがモジュールのため、シンボルを上書きする必要はない。 |
| 131 | + continue |
| 132 | + |
| 133 | + symbol_names: List[str] = [x.name for x in imp_frm.names] |
| 134 | + |
| 135 | + # wildcard importの場合 |
| 136 | + if symbol_names[0] == '*': |
| 137 | + if '__all__' in new_module.__dict__: |
| 138 | + symbol_names = new_module.__dict__['__all__'] |
| 139 | + else: |
| 140 | + symbol_names = [x for x in new_module.__dict__ if not x.startswith('__')] |
| 141 | + |
| 142 | + children_symbols[new_module] = symbol_names |
| 143 | + |
| 144 | + return children_symbols |
| 145 | + |
| 146 | + |
| 147 | +def _is_package(module: ModuleType) -> bool: |
| 148 | + """モジュールがパッケージ(__init__.py)かどうかを判定""" |
| 149 | + file = module.__file__ |
| 150 | + return file is None or file.endswith('__init__.py') |
| 151 | + |
| 152 | + |
| 153 | +def _reload(children_symbols: Dict[ModuleType, List[str]]) -> None: |
| 154 | + for child_module in children_symbols.keys(): |
| 155 | + # 強力なリロード: sys.modulesから削除してから再インポート |
| 156 | + module_name = child_module.__name__ |
| 157 | + |
| 158 | + # .pycファイルを削除(キャッシュクリア) |
| 159 | + _clear_single_pycache(child_module) |
| 160 | + |
| 161 | + # sys.modulesから削除 |
| 162 | + if module_name in sys.modules: |
| 163 | + del sys.modules[module_name] |
| 164 | + |
| 165 | + # キャッシュをクリア |
| 166 | + importlib.invalidate_caches() |
| 167 | + |
| 168 | + # 再インポート |
| 169 | + try: |
| 170 | + reloaded_module = importlib.import_module(module_name) |
| 171 | + |
| 172 | + # 元のモジュールオブジェクトの辞書を更新 |
| 173 | + child_module.__dict__.clear() |
| 174 | + child_module.__dict__.update(reloaded_module.__dict__) |
| 175 | + |
| 176 | + except Exception: |
| 177 | + # フォールバック: 通常のリロード |
| 178 | + importlib.reload(child_module) |
| 179 | + |
| 180 | + |
| 181 | +def _clear_single_pycache(module: ModuleType) -> None: |
| 182 | + """ |
| 183 | + 1つのモジュールに対応する __pycache__ を削除 |
| 184 | + """ |
| 185 | + module_file = getattr(module, '__file__', None) |
| 186 | + if module_file is None: |
| 187 | + return |
| 188 | + |
| 189 | + module_dir = Path(module_file).parent |
| 190 | + pycache_dir = module_dir / '__pycache__' |
| 191 | + |
| 192 | + if pycache_dir.exists(): |
| 193 | + try: |
| 194 | + shutil.rmtree(pycache_dir) |
| 195 | + except Exception: |
| 196 | + pass # エラーは無視 |
| 197 | + |
| 198 | + |
| 199 | +def _overwrite_with_reloaded_symbols(parent: ModuleType, children_symbols: Dict[ModuleType, List[str]]) -> None: |
| 200 | + no_key = 'no key' |
| 201 | + |
| 202 | + for child_module, child_symbol_names in children_symbols.items(): |
| 203 | + for child_symbol_name in child_symbol_names: |
| 204 | + val = child_module.__dict__.get(child_symbol_name, no_key) |
| 205 | + if val == no_key: |
| 206 | + print(f'sys.modulesに{child_symbol_name}が存在しません') |
| 207 | + else: |
| 208 | + parent.__dict__[child_symbol_name] = val |
0 commit comments