Skip to content

Commit 1b717a4

Browse files
feat: Add Windows support with triton-windows and PyTorch fallback (#237)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ef0c528 commit 1b717a4

14 files changed

Lines changed: 1335 additions & 86 deletions

File tree

README.md

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,60 @@ cd AngelSlim && python setup.py install
223223

224224
For more detailed installation instructions, please refer to the [Installation Documentation](https://angelslim.readthedocs.io/zh-cn/latest/getting_started/installation.html).
225225

226+
#### Windows Installation (with FP8 Triton Support)
227+
228+
AngelSlim supports Windows with FP8 Triton kernels. Follow these steps to build from source:
229+
230+
```batch
231+
:: Clone the repository
232+
git clone https://github.com/Tencent/AngelSlim.git
233+
cd AngelSlim
234+
235+
:: Create and activate virtual environment (Python 3.10 recommended)
236+
uv venv --python 3.10
237+
.venv\Scripts\activate
238+
239+
:: Install base dependencies
240+
uv pip install packaging wheel setuptools ninja numpy==1.26.4 pip build psutil
241+
242+
:: Install PyTorch with CUDA 12.8 support
243+
uv pip install torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128
244+
245+
:: Install Triton for Windows
246+
uv pip install -U triton-windows
247+
248+
:: Configure Visual Studio build environment
249+
set INCLUDE=
250+
set LIB=
251+
set LIBPATH=
252+
call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
253+
254+
:: Configure CUDA environment
255+
set CUDA_HOME=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8
256+
set PATH=%CUDA_HOME%\bin;%PATH%
257+
set DISTUTILS_USE_SDK=1
258+
259+
:: Set target CUDA architectures (adjust based on your GPU)
260+
set TORCH_CUDA_ARCH_LIST=8.0;8.6;8.9;9.0
261+
262+
:: Build the wheel
263+
set DG_USE_LOCAL_VERSION=0
264+
python setup.py bdist_wheel
265+
266+
:: Verify FP8 Triton kernels are working
267+
python -c "import torch; from angelslim.compressor.diffusion.kernels.python.quantizers import fp8_per_block_quant_triton; from angelslim.compressor.diffusion.kernels.python.gemm import fp8_gemm_triton_block; a,b=torch.randn(128,256,device='cuda'),torch.randn(512,256,device='cuda'); aq,a_s=fp8_per_block_quant_triton(a); bq,b_s=fp8_per_block_quant_triton(b); c=fp8_gemm_triton_block(aq,a_s,bq,b_s); print(f'FP8 GEMM OK: {c.shape}, {c.dtype}')"
268+
```
269+
270+
**Requirements:**
271+
- Windows 10/11 with NVIDIA GPU (Ampere or newer recommended)
272+
- Visual Studio 2022 with C++ build tools
273+
- CUDA Toolkit 12.8
274+
- Python 3.10
275+
276+
**Environment Variables:**
277+
- `ANGELSLIM_BACKEND`: Force backend selection (`triton` or `pytorch`)
278+
- `ANGELSLIM_TORCH_COMPILE`: Enable/disable torch.compile (`0` or `1`)
279+
226280
### 2. Quick Start
227281

228282
#### 2.1 Speculative Decoding

angelslim/compressor/_platform.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Platform detection and backend selection for AngelSlim.
17+
18+
This module provides utilities for detecting the runtime environment
19+
and selecting appropriate backends (Triton vs PyTorch) based on
20+
platform capabilities.
21+
22+
Environment Variables:
23+
ANGELSLIM_BACKEND: Force backend selection ("triton" or "pytorch")
24+
ANGELSLIM_TORCH_COMPILE: Enable/disable torch.compile ("0" or "1")
25+
"""
26+
27+
import os
28+
import sys
29+
from enum import Enum
30+
from functools import lru_cache
31+
from typing import Optional
32+
33+
import torch
34+
35+
36+
class Platform(Enum):
37+
"""Supported platforms."""
38+
39+
LINUX = "linux"
40+
WINDOWS = "windows"
41+
MACOS = "macos"
42+
UNKNOWN = "unknown"
43+
44+
45+
class Backend(Enum):
46+
"""Available computation backends."""
47+
48+
TRITON = "triton"
49+
PYTORCH = "pytorch"
50+
51+
52+
@lru_cache(maxsize=1)
53+
def get_platform() -> Platform:
54+
"""Detect the current platform."""
55+
if sys.platform.startswith("linux"):
56+
return Platform.LINUX
57+
elif sys.platform == "win32":
58+
return Platform.WINDOWS
59+
elif sys.platform == "darwin":
60+
return Platform.MACOS
61+
return Platform.UNKNOWN
62+
63+
64+
@lru_cache(maxsize=1)
65+
def is_triton_available() -> bool:
66+
"""
67+
Check if Triton is available and functional.
68+
69+
Returns:
70+
bool: True if Triton can be used, False otherwise.
71+
"""
72+
# Check environment variable override
73+
env_backend = os.environ.get("ANGELSLIM_BACKEND", "").lower()
74+
if env_backend == "pytorch":
75+
return False
76+
if env_backend == "triton":
77+
# User explicitly requested Triton, try to use it
78+
try:
79+
import triton
80+
81+
if not torch.cuda.is_available():
82+
raise RuntimeError("ANGELSLIM_BACKEND=triton but CUDA is not available")
83+
return True
84+
except ImportError:
85+
raise RuntimeError("ANGELSLIM_BACKEND=triton but triton is not installed")
86+
87+
# Auto-detection: check CUDA availability first
88+
if not torch.cuda.is_available():
89+
return False
90+
91+
# Try to import triton
92+
try:
93+
import triton
94+
95+
# Test if JIT compilation works
96+
return _test_triton_jit()
97+
except ImportError:
98+
return False
99+
except Exception:
100+
return False
101+
102+
103+
def _test_triton_jit() -> bool:
104+
"""
105+
Test if Triton JIT compilation actually works.
106+
107+
This is needed because triton-windows may import but fail at JIT time.
108+
"""
109+
try:
110+
import triton
111+
import triton.language as tl
112+
113+
@triton.jit
114+
def _test_kernel(x_ptr, BLOCK: tl.constexpr):
115+
pid = tl.program_id(0)
116+
offs = pid * BLOCK + tl.arange(0, BLOCK)
117+
x = tl.load(x_ptr + offs)
118+
tl.store(x_ptr + offs, x + 1.0)
119+
120+
# Try to compile and run the kernel
121+
x = torch.zeros(128, device="cuda", dtype=torch.float32)
122+
_test_kernel[(1,)](x, BLOCK=128)
123+
torch.cuda.synchronize()
124+
125+
# Verify the kernel ran correctly
126+
return torch.allclose(x, torch.ones(128, device="cuda", dtype=torch.float32))
127+
except Exception:
128+
return False
129+
130+
131+
@lru_cache(maxsize=1)
132+
def get_default_backend() -> Backend:
133+
"""
134+
Get the default computation backend for the current environment.
135+
136+
Priority:
137+
1. ANGELSLIM_BACKEND environment variable
138+
2. Triton if available and functional
139+
3. PyTorch fallback
140+
141+
Returns:
142+
Backend: The selected backend.
143+
"""
144+
if is_triton_available():
145+
return Backend.TRITON
146+
return Backend.PYTORCH
147+
148+
149+
@lru_cache(maxsize=1)
150+
def is_torch_compile_supported() -> bool:
151+
"""
152+
Check if torch.compile is supported and should be enabled.
153+
154+
Returns:
155+
bool: True if torch.compile should be used.
156+
"""
157+
# Check environment variable override
158+
env_compile = os.environ.get("ANGELSLIM_TORCH_COMPILE", "").lower()
159+
if env_compile == "0" or env_compile == "false":
160+
return False
161+
if env_compile == "1" or env_compile == "true":
162+
return True
163+
164+
# Windows: torch.compile has issues with dynamo
165+
if get_platform() == Platform.WINDOWS:
166+
return False
167+
168+
# Check PyTorch version (torch.compile requires 2.0+)
169+
try:
170+
version_parts = torch.__version__.split(".")[:2]
171+
major = int(version_parts[0])
172+
if major < 2:
173+
return False
174+
except Exception:
175+
return False
176+
177+
return True
178+
179+
180+
def use_triton() -> bool:
181+
"""Check if Triton backend should be used."""
182+
return get_default_backend() == Backend.TRITON
183+
184+
185+
def use_pytorch() -> bool:
186+
"""Check if PyTorch fallback should be used."""
187+
return get_default_backend() == Backend.PYTORCH
188+
189+
190+
def get_backend_info() -> dict:
191+
"""
192+
Get detailed information about the current backend configuration.
193+
194+
Returns:
195+
dict: Backend information including platform, backend, and capabilities.
196+
"""
197+
return {
198+
"platform": get_platform().value,
199+
"backend": get_default_backend().value,
200+
"triton_available": is_triton_available(),
201+
"torch_compile_supported": is_torch_compile_supported(),
202+
"cuda_available": torch.cuda.is_available(),
203+
"cuda_device": torch.cuda.get_device_name() if torch.cuda.is_available() else None,
204+
"torch_version": torch.__version__,
205+
"env_backend": os.environ.get("ANGELSLIM_BACKEND", "auto"),
206+
"env_torch_compile": os.environ.get("ANGELSLIM_TORCH_COMPILE", "auto"),
207+
}

angelslim/compressor/diffusion/cache/taylorcache_helper.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
import math
2-
from typing import Any, List, Optional, Set, Tuple
2+
from typing import Any, Callable, List, Optional, Set, Tuple
33

44
import torch
55
import torch.nn as nn
66

77
from .cache_helper import CacheHelper
88

9+
# Conditional torch.compile decorator
10+
# Disabled on Windows and when ANGELSLIM_TORCH_COMPILE=0
11+
try:
12+
from angelslim.compressor._platform import is_torch_compile_supported
13+
14+
_USE_TORCH_COMPILE = is_torch_compile_supported()
15+
except ImportError:
16+
_USE_TORCH_COMPILE = False
17+
18+
19+
def _conditional_compile(func: Callable) -> Callable:
20+
"""Apply torch.compile only if supported on this platform."""
21+
if _USE_TORCH_COMPILE:
22+
return torch.compile(func)
23+
return func
24+
925

1026
class TaylorCacheHelper(CacheHelper):
1127
"""
@@ -137,7 +153,7 @@ def clear_states(self) -> None:
137153
self.taylor_cache.clear_derivatives()
138154

139155

140-
@torch.compile
156+
@_conditional_compile
141157
def decomposition_FFT(
142158
x: torch.Tensor, cutoff_ratio: float = 0.1
143159
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -188,7 +204,7 @@ def decomposition_FFT(
188204
return low, high
189205

190206

191-
@torch.compile
207+
@_conditional_compile
192208
def reconstruction(low_freq: torch.Tensor, high_freq: torch.Tensor) -> torch.Tensor:
193209
return low_freq + high_freq
194210

angelslim/compressor/diffusion/kernels/python/gemm/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .fp8_gemm import fp8_gemm_triton_block
15+
"""
16+
FP8 GEMM kernels with automatic backend selection.
1617
17-
__all__ = ["fp8_gemm_triton_block"]
18+
This module automatically selects between Triton (for Linux/CUDA) and
19+
PyTorch (for Windows/CPU) implementations based on the runtime environment.
20+
"""
21+
22+
from angelslim.compressor._platform import use_triton
23+
24+
# Conditional imports based on platform/backend availability
25+
if use_triton():
26+
from .fp8_gemm import fp8_gemm_triton_block
27+
else:
28+
# PyTorch fallback implementation
29+
from .fp8_gemm_torch import fp8_gemm_torch_block as fp8_gemm_triton_block
30+
31+
# Also export PyTorch version directly for explicit use
32+
from .fp8_gemm_torch import fp8_gemm_torch_block
33+
34+
__all__ = ["fp8_gemm_triton_block", "fp8_gemm_torch_block"]

0 commit comments

Comments
 (0)