-
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathebsynth_torch_loader.py
More file actions
48 lines (41 loc) · 1.78 KB
/
ebsynth_torch_loader.py
File metadata and controls
48 lines (41 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# ezsynth/ebsynth_torch_loader.py
import os
from pathlib import Path
import torch.utils.cpp_extension
# Name under which the JIT-compiled extension module is built and loaded.
MODULE_NAME = "ebsynth_torch_jit"

# Directory holding the C++/CUDA sources: the "ebsynth_extension" folder next
# to this loader file. __file__ is resolved first so the result is absolute and
# does not depend on the current working directory or on how the module was
# imported.
_ext_dir = Path(__file__).resolve().parent / "ebsynth_extension"

# All source files compiled into the extension.
_source_files = [
    _ext_dir / "ext.cpp",
    _ext_dir / "dispatch.cu",
    _ext_dir / "kernels.cu",
    _ext_dir / "integral_image.cu",
]
# torch.utils.cpp_extension.load() is given the source paths as strings.
_source_files_str = [str(p) for p in _source_files]

# JIT_VERBOSE=1/true/yes enables the progress messages below and the full
# compiler command output from PyTorch; otherwise the build runs quietly.
_verbose = os.getenv("JIT_VERBOSE", "").lower() in ("1", "true", "yes")

# JIT compilation via torch.utils.cpp_extension.load() runs at import time.
# PyTorch caches the compiled library in a build directory.
try:
    _missing = [p for p in _source_files if not p.is_file()]
    if _missing:
        # Fail early with an explicit message instead of an opaque build error.
        raise FileNotFoundError(
            "Missing extension source file(s): "
            + ", ".join(str(p) for p in _missing)
        )
    if _verbose:
        print(f"Attempting to JIT compile and load CUDA extension '{MODULE_NAME}'...")
    ebsynth_torch = torch.utils.cpp_extension.load(
        name=MODULE_NAME,
        sources=_source_files_str,
        # Compiler commands/output are shown only when JIT_VERBOSE is set; the
        # exception details printed below still report why a build failed.
        verbose=_verbose,
    )
    if _verbose:
        print("CUDA extension loaded successfully via JIT compilation.")
except Exception as e:
    # Deliberate best-effort boundary: report the failure and leave
    # ebsynth_torch set to None instead of raising at import time.
    print("=" * 50)
    print(f"[ERROR] Failed to JIT compile the CUDA extension '{MODULE_NAME}'.")
    print("Please ensure you have a compatible C++ compiler (MSVC on Windows)")
    print("and the NVIDIA CUDA Toolkit installed.")
    print(f"Error details: {e}")
    print("=" * 50)
    ebsynth_torch = None