|
9 | 9 | import logging |
10 | 10 | import math |
11 | 11 | import os |
| 12 | +import pcre |
12 | 13 | import platform |
13 | 14 | import shutil |
14 | 15 | import subprocess |
15 | 16 | import sys |
16 | 17 | import threading |
17 | 18 | import time |
| 19 | +import traceback |
18 | 20 | from contextlib import contextmanager |
19 | 21 | from pathlib import Path |
| 22 | +from torch.utils.cpp_extension import CUDA_HOME, _get_build_directory, _get_cuda_arch_flags, load |
20 | 23 | from typing import Callable, Optional, Sequence |
21 | 24 |
|
22 | | -import pcre |
23 | 25 | import torch |
24 | | -from torch.utils.cpp_extension import CUDA_HOME, _get_build_directory, _get_cuda_arch_flags, load |
25 | | - |
26 | 26 | from .env import env_flag |
27 | 27 | from .jit_compile_baselines import get_jit_compile_baseline_seconds |
28 | 28 | from .logger import setup_logger |
29 | 29 |
|
30 | | - |
31 | 30 | log = logging.getLogger(__name__) |
32 | 31 |
|
33 | 32 | # One process-local lock serializes every torch.ops JIT cache mutation and |
|
64 | 63 | _TORCH_OPS_BUILD_ROOT_ENV = "GPTQMODEL_TORCH_EXTENSIONS_DIR" |
65 | 64 |
|
66 | 65 |
|
def _log_cache_clear_callsite(*, reason: str, target_path: str | Path) -> None:
    """Log a WARNING describing a JIT cache clear together with its call stack.

    The record carries the reason tag, the current process id, the target
    path, and up to 32 formatted stack frames ending at the caller.

    Args:
        reason: Short tag identifying the code path that requested the clear.
        target_path: Path the caller is reporting as the clear target.
    """
    # Formatting a stack is comparatively expensive; skip it entirely when the
    # record would be discarded anyway.
    if not log.isEnabledFor(logging.WARNING):
        return

    # Start from the caller's frame so the trace ends at the real call site
    # instead of at this helper's own `format_stack` line.
    caller_frame = sys._getframe(1)
    stack_text = "".join(traceback.format_stack(caller_frame, limit=32))
    log.warning(
        "[jit-cache-clear] reason=%s pid=%s path=%s\ncallstack:\n%s",
        reason,
        os.getpid(),
        target_path,
        stack_text,
    )
| 75 | + |
| 76 | + |
def _nvcc_path() -> Optional[str]:
    """Return the path of the ``nvcc`` executable found on ``PATH``, or ``None``."""
    nvcc_location = shutil.which("nvcc")
    return nvcc_location
69 | 79 |
|
@@ -826,6 +836,10 @@ def clear_cache(self) -> None: |
826 | 836 | self._op_cache = {} |
827 | 837 | build_root = self.base_build_root() |
828 | 838 | if build_root.exists(): |
| 839 | + _log_cache_clear_callsite( |
| 840 | + reason=f"{self.display_name}.clear_cache", |
| 841 | + target_path=build_root, |
| 842 | + ) |
829 | 843 | shutil.rmtree(build_root, ignore_errors=True) |
830 | 844 |
|
831 | 845 | def last_error_message(self) -> str: |
@@ -880,6 +894,10 @@ def load(self) -> bool: |
880 | 894 |
|
881 | 895 | if force_rebuild and base_build_root.exists(): |
882 | 896 | setup_logger().info(f"{self.display_name}: clearing cached JIT extension at `{base_build_root}`.") |
| 897 | + _log_cache_clear_callsite( |
| 898 | + reason=f"{self.display_name}.force_rebuild", |
| 899 | + target_path=base_build_root, |
| 900 | + ) |
883 | 901 | shutil.rmtree(base_build_root, ignore_errors=True) |
884 | 902 |
|
885 | 903 | build_root.mkdir(parents=True, exist_ok=True) |
@@ -1023,6 +1041,10 @@ def safe_load_cpp_ext( |
1023 | 1041 | build_directory = build_directory or _get_build_directory(name, verbose=verbose) |
1024 | 1042 | if os.path.exists(build_directory): |
1025 | 1043 | try: |
| 1044 | + _log_cache_clear_callsite( |
| 1045 | + reason="safe_load_cpp_ext.first_init_cleanup", |
| 1046 | + target_path=build_directory, |
| 1047 | + ) |
1026 | 1048 | shutil.rmtree(build_directory) |
1027 | 1049 | if verbose: |
1028 | 1050 | log.debug(f"[safe_cpp_extension_load] Removed old build directory: {build_directory}") |
|
0 commit comments