Skip to content

Commit ab9e61b

Browse files
log stack trace for marlin jit error (#2887)
1 parent 28e970e commit ab9e61b

1 file changed

Lines changed: 26 additions & 4 deletions

File tree

gptqmodel/utils/cpp.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,24 @@
99
import logging
1010
import math
1111
import os
12+
import pcre
1213
import platform
1314
import shutil
1415
import subprocess
1516
import sys
1617
import threading
1718
import time
19+
import traceback
1820
from contextlib import contextmanager
1921
from pathlib import Path
22+
from torch.utils.cpp_extension import CUDA_HOME, _get_build_directory, _get_cuda_arch_flags, load
2023
from typing import Callable, Optional, Sequence
2124

22-
import pcre
2325
import torch
24-
from torch.utils.cpp_extension import CUDA_HOME, _get_build_directory, _get_cuda_arch_flags, load
25-
2626
from .env import env_flag
2727
from .jit_compile_baselines import get_jit_compile_baseline_seconds
2828
from .logger import setup_logger
2929

30-
3130
log = logging.getLogger(__name__)
3231

3332
# One process-local lock serializes every torch.ops JIT cache mutation and
@@ -64,6 +63,17 @@
6463
_TORCH_OPS_BUILD_ROOT_ENV = "GPTQMODEL_TORCH_EXTENSIONS_DIR"
6564

6665

66+
def _log_cache_clear_callsite(*, reason: str, target_path: str | Path) -> None:
67+
stack_text = "".join(traceback.format_stack(limit=32))
68+
log.warning(
69+
"[jit-cache-clear] reason=%s pid=%s path=%s\ncallstack:\n%s",
70+
reason,
71+
os.getpid(),
72+
target_path,
73+
stack_text,
74+
)
75+
76+
6777
def _nvcc_path() -> Optional[str]:
6878
return shutil.which("nvcc")
6979

@@ -826,6 +836,10 @@ def clear_cache(self) -> None:
826836
self._op_cache = {}
827837
build_root = self.base_build_root()
828838
if build_root.exists():
839+
_log_cache_clear_callsite(
840+
reason=f"{self.display_name}.clear_cache",
841+
target_path=build_root,
842+
)
829843
shutil.rmtree(build_root, ignore_errors=True)
830844

831845
def last_error_message(self) -> str:
@@ -880,6 +894,10 @@ def load(self) -> bool:
880894

881895
if force_rebuild and base_build_root.exists():
882896
setup_logger().info(f"{self.display_name}: clearing cached JIT extension at `{base_build_root}`.")
897+
_log_cache_clear_callsite(
898+
reason=f"{self.display_name}.force_rebuild",
899+
target_path=base_build_root,
900+
)
883901
shutil.rmtree(base_build_root, ignore_errors=True)
884902

885903
build_root.mkdir(parents=True, exist_ok=True)
@@ -1023,6 +1041,10 @@ def safe_load_cpp_ext(
10231041
build_directory = build_directory or _get_build_directory(name, verbose=verbose)
10241042
if os.path.exists(build_directory):
10251043
try:
1044+
_log_cache_clear_callsite(
1045+
reason="safe_load_cpp_ext.first_init_cleanup",
1046+
target_path=build_directory,
1047+
)
10261048
shutil.rmtree(build_directory)
10271049
if verbose:
10281050
log.debug(f"[safe_cpp_extension_load] Removed old build directory: {build_directory}")

0 commit comments

Comments
 (0)