Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions gptqmodel/utils/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,24 @@
import logging
import math
import os
import pcre
import platform
import shutil
import subprocess
import sys
import threading
import time
import traceback
from contextlib import contextmanager
from pathlib import Path
from torch.utils.cpp_extension import CUDA_HOME, _get_build_directory, _get_cuda_arch_flags, load
from typing import Callable, Optional, Sequence

import pcre
import torch
from torch.utils.cpp_extension import CUDA_HOME, _get_build_directory, _get_cuda_arch_flags, load

from .env import env_flag
from .jit_compile_baselines import get_jit_compile_baseline_seconds
from .logger import setup_logger


log = logging.getLogger(__name__)

# One process-local lock serializes every torch.ops JIT cache mutation and
Expand Down Expand Up @@ -64,6 +63,17 @@
_TORCH_OPS_BUILD_ROOT_ENV = "GPTQMODEL_TORCH_EXTENSIONS_DIR"


def _log_cache_clear_callsite(*, reason: str, target_path: str | Path) -> None:
stack_text = "".join(traceback.format_stack(limit=32))
log.warning(
"[jit-cache-clear] reason=%s pid=%s path=%s\ncallstack:\n%s",
reason,
os.getpid(),
target_path,
stack_text,
)


def _nvcc_path() -> Optional[str]:
return shutil.which("nvcc")

Expand Down Expand Up @@ -826,6 +836,10 @@ def clear_cache(self) -> None:
self._op_cache = {}
build_root = self.base_build_root()
if build_root.exists():
_log_cache_clear_callsite(
reason=f"{self.display_name}.clear_cache",
target_path=build_root,
)
shutil.rmtree(build_root, ignore_errors=True)

def last_error_message(self) -> str:
Expand Down Expand Up @@ -880,6 +894,10 @@ def load(self) -> bool:

if force_rebuild and base_build_root.exists():
setup_logger().info(f"{self.display_name}: clearing cached JIT extension at `{base_build_root}`.")
_log_cache_clear_callsite(
reason=f"{self.display_name}.force_rebuild",
target_path=base_build_root,
)
shutil.rmtree(base_build_root, ignore_errors=True)

build_root.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -1023,6 +1041,10 @@ def safe_load_cpp_ext(
build_directory = build_directory or _get_build_directory(name, verbose=verbose)
if os.path.exists(build_directory):
try:
_log_cache_clear_callsite(
reason="safe_load_cpp_ext.first_init_cleanup",
target_path=build_directory,
)
shutil.rmtree(build_directory)
if verbose:
log.debug(f"[safe_cpp_extension_load] Removed old build directory: {build_directory}")
Expand Down