Commit a344d77

Peterc3-dev, Peter Clemente III, and claude authored
Wave 2 polish: fix shader build, remove hardcoded paths, add CI (#1)
- compile.sh: target Vulkan 1.1 (SPIR-V 1.3) so the 12 subgroup-using quantized kernels (matmul_q*k*, matmul_gpuq*) actually compile; the default `glslangValidator -V` emits SPIR-V 1.0 and failed on them. Added set -euo pipefail and a non-zero exit on any failure.
- vulkan_engine.cpp: drop the hardcoded /home/raz/... shader path (leaked a developer path and broke on every other machine). Fall back to the TORCH_VULKAN_SHADER_DIR env var instead.
- __init__.py: export the resolved bundled-shader dir into TORCH_VULKAN_SHADER_DIR so the lazily-constructed VulkanEngine resolves the same shaders. Aligned the module docstring with the README (.to("vulkan") is the supported path; torch.randn(device=) / .vulkan() are partial).
- tests: removed hardcoded /home/raz/builds/pytorch-gfx1150 sys.path inserts; honour TORCH_VULKAN_PYTORCH_PATH instead. Fixed the misleading "relu isn't implemented" fallback test (relu IS wired) and split it into a real relu correctness test plus a genuine CPU-fallback test using an unimplemented op.
- Removed dead imports (numpy, time, sys) and unused locals flagged by ruff.
- Added pyproject.toml (ruff config) and .github/workflows/ci.yml that compiles all shaders and runs ruff + py_compile. The GPU build/test suite is not run in CI (no GPU on hosted runners).

Co-authored-by: Peter Clemente III <peterc3@live.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 323e4d9 commit a344d77

12 files changed

Lines changed: 187 additions & 33 deletions

.github/workflows/ci.yml

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+name: CI
+
+on:
+  push:
+    branches: [master, main]
+  pull_request:
+
+jobs:
+  shaders:
+    name: Compile SPIR-V shaders
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install glslang
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y glslang-tools
+
+      # Compile every .comp to a scratch dir and fail if any shader does not
+      # compile. This guards the regression that the default `glslangValidator -V`
+      # (SPIR-V 1.0) could not build the subgroup-using quantized kernels, which
+      # need --target-env vulkan1.1 (SPIR-V 1.3).
+      - name: Compile all shaders
+        working-directory: csrc/shaders
+        run: |
+          set -euo pipefail
+          out="$(mktemp -d)"
+          fail=0
+          for comp in *.comp; do
+            echo "Compiling $comp"
+            if ! glslangValidator --target-env vulkan1.1 -V "$comp" -o "$out/${comp%.comp}.spv"; then
+              echo "::error file=csrc/shaders/$comp::shader failed to compile"
+              fail=1
+            fi
+          done
+          [ "$fail" -eq 0 ] || { echo "One or more shaders failed to compile"; exit 1; }
+          echo "Compiled $(ls "$out"/*.spv | wc -l) shaders."
+
+  python-lint:
+    name: Python lint + syntax
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+
+      - name: Install ruff
+        run: pip install ruff
+
+      # py_compile catches syntax errors without needing the compiled _C
+      # extension (which requires a Vulkan GPU + Kompute + custom PyTorch and
+      # cannot run on a stock CI runner).
+      - name: Syntax check
+        run: |
+          python -m py_compile setup.py persistent_pipeline.py \
+            torch_vulkan/__init__.py tests/*.py
+
+      - name: Ruff
+        run: ruff check .
+
+  # NOTE: The C++ extension build (CMake + Torch + Kompute) and the runtime test
+  # suite require a Vulkan-capable GPU and are not run here -- GitHub-hosted
+  # runners have no GPU. Build/test verification is done locally on the target
+  # hardware (AMD Radeon 890M / RADV). This workflow verifies what can be checked
+  # without a GPU: shader compilation and Python static checks.
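
Reviewer note: the "Compile all shaders" step above collects failures across the whole loop instead of stopping at the first one. A minimal Python sketch of the same pattern follows, for anyone who wants to run the check outside of CI. The function name `compile_shaders` and the injectable `run` parameter are illustrative assumptions and are not part of this commit; only the `glslangValidator --target-env vulkan1.1 -V ... -o ...` invocation is taken from the workflow.

```python
import subprocess
from pathlib import Path


def compile_shaders(shader_dir, out_dir, run=subprocess.run):
    """Compile every *.comp in shader_dir and return the names that failed.

    `run` is injectable (hypothetical design choice) so the loop can be
    exercised without glslangValidator installed.
    """
    failed = []
    for comp in sorted(Path(shader_dir).glob("*.comp")):
        spv = Path(out_dir) / (comp.stem + ".spv")
        # Same flags as the workflow: Vulkan 1.1 target, SPIR-V output.
        cmd = [
            "glslangValidator", "--target-env", "vulkan1.1",
            "-V", str(comp), "-o", str(spv),
        ]
        result = run(cmd)
        if result.returncode != 0:
            # Keep going so every broken shader is reported in one pass.
            failed.append(comp.name)
    return failed
```

A caller would treat a non-empty return value as a failed build, for example `sys.exit(1 if compile_shaders("csrc/shaders", tmp_dir) else 0)`.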

README.md

Lines changed: 10 additions & 1 deletion
@@ -78,7 +78,16 @@ attention + KV cache, kv_store.
 
 ```bash
 # Requires a Vulkan 1.2+ driver (RADV recommended) and glslangValidator for shader build.
-cd build && cmake .. && make -j$(nproc)
+
+# 1. (Re)compile the GLSL shaders to SPIR-V (pre-built .spv files are committed,
+# so this is only needed if you edit a shader). Targets Vulkan 1.1 / SPIR-V 1.3
+# because the quantized kernels use subgroup ops.
+bash csrc/shaders/compile.sh
+
+# 2. Build the C++ extension.
+cmake -S . -B build && cmake --build build -j
+
+# 3. Install the Python package (build_ext also invokes CMake).
 pip install -e .
 
 python -c "import torch, torch_vulkan; print(torch_vulkan.device_name())"

csrc/shaders/compile.sh

Lines changed: 17 additions & 1 deletion
@@ -1,14 +1,30 @@
 #!/bin/bash
 # Compile all GLSL compute shaders to SPIR-V
 # Requires: glslangValidator (from Vulkan SDK or `pacman -S glslang`)
+#
+# We target Vulkan 1.1, which produces SPIR-V 1.3. Several of the quantized
+# matmul kernels (matmul_q*k*, matmul_gpuq*) use GL_KHR_shader_subgroup
+# reductions, and those subgroup ops require SPIR-V >= 1.3. The default
+# `glslangValidator -V` emits SPIR-V 1.0 and fails on them, so the explicit
+# --target-env is required to compile the full shader set.
+set -euo pipefail
 
 SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
 cd "$SCRIPT_DIR"
 
+fail=0
 for comp in *.comp; do
     spv="${comp%.comp}.spv"
     echo "Compiling $comp -> $spv"
-    glslangValidator -V "$comp" -o "$spv"
+    if ! glslangValidator --target-env vulkan1.1 -V "$comp" -o "$spv"; then
+        echo " ERROR: failed to compile $comp" >&2
+        fail=1
+    fi
 done
 
+if [ "$fail" -ne 0 ]; then
+    echo "One or more shaders failed to compile." >&2
+    exit 1
+fi
+
 echo "Done. $(ls *.spv | wc -l) shaders compiled."

csrc/torch_vulkan.cpp

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,12 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, VulkanGuardImpl);
 
 // Python bindings
 void set_shader_dir(const std::string& path) {
+    // Set the shader dir on the Kompute-based VulkanContext (used by all the
+    // registered ops). The raw VulkanEngine (used only by the not-yet-wired
+    // mm_raw path) is constructed lazily and reads its directory from the
+    // TORCH_VULKAN_SHADER_DIR env var, which __init__.py points at this same
+    // directory -- so we deliberately do NOT touch VulkanEngine::instance()
+    // here, to avoid eagerly creating a second Vulkan device at import time.
     VulkanContext::instance().set_shader_dir(path);
 }
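
Reviewer note: the comment above relies on `__init__.py` exporting the bundled shader directory into `TORCH_VULKAN_SHADER_DIR`, but that file's diff is not shown on this page. The following is only a sketch of what such an export could look like. The helper name `export_shader_dir` and the `shaders` subdirectory are assumptions made for illustration; the real package layout may differ.

```python
import os
from pathlib import Path


def export_shader_dir(package_dir, environ=None):
    """Publish the bundled shader directory via TORCH_VULKAN_SHADER_DIR.

    `package_dir` and the "shaders" subdirectory are hypothetical; they
    stand in for wherever the package actually bundles its .spv files.
    """
    env = os.environ if environ is None else environ
    shader_dir = str(Path(package_dir) / "shaders")
    # A lazily constructed engine that reads the variable later would see
    # the same directory that was handed to set_shader_dir() at import time.
    env["TORCH_VULKAN_SHADER_DIR"] = shader_dir
    return shader_dir
```

Passing a plain dict as `environ` keeps the sketch testable without touching the real process environment.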

csrc/vulkan_engine.cpp

Lines changed: 11 additions & 1 deletion
@@ -2,6 +2,7 @@
 #include <fstream>
 #include <stdexcept>
 #include <cstring>
+#include <cstdlib>
 #include <algorithm>
 
 namespace torch_vulkan {
@@ -12,7 +13,16 @@ VulkanEngine& VulkanEngine::instance() {
 }
 
 VulkanEngine::VulkanEngine() {
-    shaderDir_ = "/home/raz/projects/torch-vulkan/csrc/shaders/";
+    // The shader directory is normally set at import time via
+    // setShaderDir() (driven by torch_vulkan._set_shader_dir in __init__.py).
+    // As a fallback for direct/standalone use, honour TORCH_VULKAN_SHADER_DIR
+    // so the path is never hardcoded to a developer's machine.
+    if (const char* env = std::getenv("TORCH_VULKAN_SHADER_DIR")) {
+        shaderDir_ = env;
+        if (!shaderDir_.empty() && shaderDir_.back() != '/') {
+            shaderDir_ += '/';
+        }
+    }
     initVulkan();
 }
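
Reviewer note: the constructor's fallback has three cases (variable unset, set without a trailing slash, set with one). The Python mirror below spells them out so the behaviour can be checked without building the extension. `resolve_shader_dir` is a hypothetical name, and returning an empty string for the unset case assumes `shaderDir_` is default-constructed, which this diff does not show.

```python
import os


def resolve_shader_dir(environ=None):
    """Python mirror of the TORCH_VULKAN_SHADER_DIR fallback (illustrative)."""
    env = os.environ if environ is None else environ
    shader_dir = env.get("TORCH_VULKAN_SHADER_DIR")
    if shader_dir is None:
        # Assumption: an unset variable leaves the C++ member empty.
        return ""
    # Append a separator only when the value is non-empty and lacks one,
    # matching the !empty() && back() != '/' check in the constructor.
    if shader_dir and not shader_dir.endswith("/"):
        shader_dir += "/"
    return shader_dir
```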

persistent_pipeline.py

Lines changed: 0 additions & 2 deletions
@@ -8,8 +8,6 @@
 """
 
 import torch
-import time
-import numpy as np
 
 
 class PersistentLayerPipeline:

pyproject.toml

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+[tool.ruff]
+# torch-vulkan is a thin Python shim over a C++/Vulkan extension; keep linting
+# focused on real problems (pyflakes + import sorting) rather than style churn.
+line-length = 100
+target-version = "py310"
+
+[tool.ruff.lint]
+select = ["E", "F", "I"]
+# E402 (module-level import not at top of file) is intentionally allowed: the
+# package must import `torch` first, then load the `_C` extension that registers
+# the PrivateUse1 backend, and finally alias the module into `torch.vulkan` --
+# all of which happen mid-module by design.
+ignore = ["E402"]

setup.py

Lines changed: 1 addition & 2 deletions
@@ -8,9 +8,8 @@
 
 import os
 import subprocess
-import sys
 
-from setuptools import setup, Extension
+from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext
 

tests/bench_layer.py

Lines changed: 13 additions & 8 deletions
@@ -4,14 +4,19 @@
 Focuses on the hot ops: mm, add, gelu that repeat across layers.
 """
 
-import sys
 import os
+import sys
 import time
 
-sys.path.insert(0, "/home/raz/builds/pytorch-gfx1150")
+# Allow pointing at a custom-built PyTorch without hardcoding a developer's
+# path. Set TORCH_VULKAN_PYTORCH_PATH if needed.
+_custom_torch = os.environ.get("TORCH_VULKAN_PYTORCH_PATH")
+if _custom_torch:
+    sys.path.insert(0, _custom_torch)
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
 
 import torch
+
 import torch_vulkan
 
 
@@ -23,12 +28,12 @@ def bench_mm_repeated(M, K, N, iters=20):
     b = b_cpu.to("vulkan")
 
     # Warmup (first call = cache miss)
-    c = torch.mm(a, b)
+    torch.mm(a, b)
 
     # Benchmark (subsequent calls should hit cache)
     t0 = time.perf_counter()
     for _ in range(iters):
-        c = torch.mm(a, b)
+        torch.mm(a, b)
     elapsed = (time.perf_counter() - t0) / iters
     print(f" mm [{M}x{K}] @ [{K}x{N}]: {elapsed*1000:.2f} ms/call")
     return elapsed
@@ -41,11 +46,11 @@ def bench_add_repeated(N, iters=20):
     a = a_cpu.to("vulkan")
     b = b_cpu.to("vulkan")
 
-    c = torch.add(a, b) # warmup
+    torch.add(a, b) # warmup
 
     t0 = time.perf_counter()
     for _ in range(iters):
-        c = torch.add(a, b)
+        torch.add(a, b)
     elapsed = (time.perf_counter() - t0) / iters
     print(f" add [{N}]: {elapsed*1000:.2f} ms/call")
     return elapsed
@@ -56,11 +61,11 @@ def bench_gelu_repeated(N, iters=20):
     a_cpu = torch.randn(N)
     a = a_cpu.to("vulkan")
 
-    c = torch.nn.functional.gelu(a) # warmup
+    torch.nn.functional.gelu(a) # warmup
 
     t0 = time.perf_counter()
     for _ in range(iters):
-        torch.nn.functional.gelu(a)
+        torch.nn.functional.gelu(a)
     elapsed = (time.perf_counter() - t0) / iters
     print(f" gelu [{N}]: {elapsed*1000:.2f} ms/call")
     return elapsed
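
Reviewer note: the three benchmark functions above share one shape: a single untimed warmup call, then `iters` timed calls averaged with `time.perf_counter()`. A small sketch of that shape as a reusable helper is shown here. `time_per_call` is a hypothetical name that does not exist in the repository, and the sketch makes no claim about whether the Vulkan backend needs an explicit synchronisation before the timer is read.

```python
import time


def time_per_call(fn, iters=20):
    """Return the mean wall-clock seconds per call of fn over `iters` runs."""
    fn()  # warmup: the first call is expected to miss the cache
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()  # result is discarded, as in the updated benchmarks
    return (time.perf_counter() - t0) / iters
```

Usage would look like `time_per_call(lambda: torch.mm(a, b))`, with `a` and `b` already moved to the device.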

tests/test_algo_cache.py

Lines changed: 7 additions & 3 deletions
@@ -4,15 +4,19 @@
 repeated dispatches with the same tensor buffers hit the cache.
 """
 
-import sys
 import os
+import sys
 import time
 
-# Use the custom-built PyTorch
-sys.path.insert(0, "/home/raz/builds/pytorch-gfx1150")
+# Allow pointing at a custom-built PyTorch (e.g. a local APU build) without
+# hardcoding a developer's path. Set TORCH_VULKAN_PYTORCH_PATH if needed.
+_custom_torch = os.environ.get("TORCH_VULKAN_PYTORCH_PATH")
+if _custom_torch:
+    sys.path.insert(0, _custom_torch)
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
 
 import torch
+
 import torch_vulkan
 
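
Reviewer note: the same `TORCH_VULKAN_PYTORCH_PATH` preamble now appears in both test files. One way to express it once is sketched below. The helper name `add_custom_torch_path` and its injectable `environ` and `search_path` parameters are assumptions for illustration, not something this commit adds.

```python
import os
import sys


def add_custom_torch_path(environ=None, search_path=None):
    """Prepend TORCH_VULKAN_PYTORCH_PATH to the import path when it is set.

    Returns True if a path was inserted. An unset or empty variable is a
    no-op, so the stock PyTorch install is used by default.
    """
    env = os.environ if environ is None else environ
    path = sys.path if search_path is None else search_path
    custom = env.get("TORCH_VULKAN_PYTORCH_PATH")
    if custom:
        path.insert(0, custom)
        return True
    return False
```

The helper would have to run before `import torch` to have any effect, which is the ordering the E402 exemption in pyproject.toml allows for.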
