Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.

Commit 33d6373

Browse files
danieldksywangyi
andauthored
Improve cutlass-sycl support (#214)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 0348287 commit 33d6373

10 files changed

Lines changed: 220 additions & 30 deletions

File tree

build2cmake/src/config/v2.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ pub enum Dependencies {
217217
Cutlass3_8,
218218
#[serde(rename = "cutlass_3_9")]
219219
Cutlass3_9,
220-
#[serde(rename = "cutlass_sycl_3_9")]
221-
CutlassSycl3_9,
220+
#[serde(rename = "cutlass_sycl")]
221+
CutlassSycl,
222222
Torch,
223223
}
224224

build2cmake/src/templates/xpu/dep-cutlass-sycl.cmake

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,13 @@ if (NOT CutlassSycl_FOUND)
3434

3535
# Set Intel backend env
3636
message(STATUS "Setting Intel GPU optimization env vars for Cutlass-SYCL")
37-
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=intel_gpu_pvc" sycl_link_flags "${sycl_link_flags}")
38-
string(REPLACE "-device pvc,xe-lpg,ats-m150" "" sycl_link_flags "${sycl_link_flags}")
39-
string(APPEND sycl_link_flags "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier;")
40-
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=intel_gpu_pvc" sycl_flags "${sycl_flags}")
41-
4237
set(CUTLASS_ENABLE_SYCL ON CACHE BOOL "Enable SYCL for CUTLASS")
4338
add_compile_definitions(CUTLASS_ENABLE_SYCL=1)
4439
set(DPCPP_SYCL_TARGET "intel_gpu_pvc" CACHE STRING "SYCL target for Intel GPU")
4540
add_compile_definitions(DPCPP_SYCL_TARGET=intel_gpu_pvc)
4641
set(SYCL_INTEL_TARGET ON CACHE BOOL "Enable SYCL for INTEL")
4742
add_compile_definitions(SYCL_INTEL_TARGET=1)
48-
43+
4944
set(ENV{SYCL_PROGRAM_COMPILE_OPTIONS} "-ze-opt-large-register-file")
5045
set(ENV{IGC_VISAOptions} "-perfmodel")
5146
set(ENV{IGC_VectorAliasBBThreshold} "10000")
@@ -56,5 +51,11 @@ if (NOT CutlassSycl_FOUND)
5651
include_directories(${CUTLASS_INCLUDE_DIR})
5752
include_directories(${CUTLASS_TOOLS_UTIL_INCLUDE_DIR})
5853
else()
59-
message(STATUS "Using system cutlass with version: ${CutlassSycl_VERSION}")
54+
include_directories(${CUTLASS_INCLUDE_DIR})
55+
include_directories(${CUTLASS_TOOLS_UTIL_INCLUDE_DIR})
6056
endif(NOT CutlassSycl_FOUND)
57+
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=intel_gpu_pvc" sycl_link_flags "${sycl_link_flags}")
58+
string(REPLACE "-device pvc,xe-lpg,ats-m150" "" sycl_link_flags "${sycl_link_flags}")
59+
string(APPEND sycl_link_flags "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier;")
60+
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=intel_gpu_pvc" sycl_flags "${sycl_flags}")
61+

build2cmake/src/torch/xpu.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashSet;
2+
use std::env;
23
use std::io::Write;
34
use std::path::PathBuf;
45

@@ -196,16 +197,18 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
196197

197198
for dep in deps {
198199
match dep {
199-
Dependencies::CutlassSycl3_9 => {
200-
env.get_template("xpu/dep-cutlass-sycl.cmake")
201-
.wrap_err("Cannot get CUTLASS-SYCL dependency template")?
202-
.render_to_write(
203-
context! {
204-
version => "3.9-0.3",
205-
},
206-
&mut *write,
207-
)
208-
.wrap_err("Cannot render CUTLASS-SYCL dependency template")?;
200+
Dependencies::CutlassSycl => {
201+
let dpcpp_version = env::var("DPCPP_VERSION").unwrap_or("2025.1".to_string());
202+
let version = match dpcpp_version.as_str() {
203+
"2025.0" => "3.9-0.2",
204+
"2025.1" => "3.9-0.3",
205+
_ => bail!(
206+
"No cutlass_sycl version mapped for DPCPP_VERSION {}",
207+
dpcpp_version
208+
),
209+
};
210+
env.get_template("xpu/dep-cutlass-sycl.cmake")?
211+
.render_to_write(context! { version => version }, &mut *write)?;
209212
}
210213
Dependencies::Torch => (),
211214
_ => {

examples/cutlass-gemm/build.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,12 @@ depends = [
1515
"cutlass_3_6",
1616
]
1717
src = ["gemm.cu"]
18+
19+
[kernel.gemm_xpu]
20+
backend = "xpu"
21+
depends = [
22+
"torch",
23+
"cutlass_sycl",
24+
]
25+
src = ["gemm_sycl.cpp"]
26+
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
/*! \file
32+
\brief CUTLASS Intel BMG Gemm Example.
33+
34+
This example constructs and executes a simple CUTLASS GEMM kernel on Intel BMG hardware, and
35+
verifies its correctness with a reference implementation
36+
(cutlass::reference::device::GemmComplex). The example also provides a performance measurement
37+
for the GEMM in TFLOPS.
38+
39+
This example makes use of BMG's subgroup cooperative 2D-block copy operations and DPAS instructions.
40+
41+
The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the
42+
batch size is defined by `options.l`. The tile shape, which defines how much work is executed by
43+
a single work-group, is defined at compile time by:
44+
```
45+
using TileShape = Shape<_256, _256, _32>;
46+
```
47+
That is, each work-group processes a tile of M=256, N=256, and iterates over `options.k` in
48+
blocks of K=32.
49+
50+
Performance of GEMM on BMG is heavily dependent on prefetching the A and B matrices. That is,
51+
executing Intel specific prefetch instructions for future iterations to ensure that the required
52+
blocks of A and B are resident in cache before they are needed.
53+
54+
To build & run this example (from your build dir):
55+
56+
$ ninja 00_bmg_gemm
57+
$ ./examples/sycl/00_bmg_gemm/00_bmg_gemm
58+
59+
Call with `--help` for information about available options
60+
*/
61+
62+
#include "cutlass/epilogue/collective/default_epilogue.hpp"
63+
#include "cutlass/epilogue/collective/xe_epilogue.hpp"
64+
#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
65+
#include "cutlass/gemm/device/gemm_universal.h"
66+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
67+
#include "cutlass/gemm/collective/collective_mma.hpp"
68+
#include "cutlass/util/GPU_Clock.hpp"
69+
70+
#include <cute/tensor.hpp>
71+
#include <random>
72+
73+
#include "cutlass/util/command_line.h"
74+
#include "cutlass/util/device_memory.h"
75+
#include "cutlass/util/packed_stride.hpp"
76+
#include "cutlass/util/reference/device/gemm_complex.h"
77+
#include "cutlass/util/reference/device/tensor_compare.h"
78+
#include <torch/all.h>
79+
using namespace cute;
80+
81+
// Evaluates `status` once into a local cutlass::Status. If the result is not
// cutlass::Status::kSuccess, writes the status string and the current source
// line (__LINE__) to std::cerr and calls exit(EXIT_FAILURE).
// NOTE(review): exit() ends the whole process, not just this call; when this
// code runs inside a Python-loaded extension that presumably takes the
// interpreter down with it -- confirm this is intended versus TORCH_CHECK.
#define CUTLASS_CHECK(status) \
82+
{ \
83+
cutlass::Status error = status; \
84+
if (error != cutlass::Status::kSuccess) { \
85+
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
86+
<< std::endl; \
87+
exit(EXIT_FAILURE); \
88+
} \
89+
}
90+
91+
// Runs a single (L = 1) CUTLASS GEMM over the data of `A` and `B` and writes
// the result into the storage of `out`.
//
// Parameters:
//   out - destination tensor; its data pointer is reinterpreted as float.
//   A   - left operand; size(0) is used as M and size(1) as K; its data
//         pointer is reinterpreted as bfloat16_t.
//   B   - right operand; size(1) is used as N; its data pointer is
//         reinterpreted as bfloat16_t.
//
// Failure handling: a can_implement() failure is reported through TORCH_CHECK;
// initialize()/run() failures go through CUTLASS_CHECK.
//
// NOTE(review): nothing in this function checks dtype, rank, device or
// contiguity of the tensors, and B.size(0) is never compared against K. The
// pointers are reinterpreted regardless of the tensors' actual dtype, so the
// caller presumably must pass bfloat16 A/B and a float32 (M, N) `out` -- confirm.
void cutlass_gemm(torch::Tensor &out, torch::Tensor const &A, torch::Tensor const &B) {
92+
// Element types: bfloat16 inputs, float accumulator / epilogue compute / output.
using ElementAccumulator = float;
93+
using ElementComputeEpilogue = float;
94+
using ElementInputA = bfloat16_t;
95+
using ElementInputB = bfloat16_t;
96+
using ElementOutput = float;
97+
98+
// All four operands are declared row-major.
using LayoutA = cutlass::layout::RowMajor;
99+
using LayoutB = cutlass::layout::RowMajor;
100+
using LayoutC = cutlass::layout::RowMajor;
101+
using LayoutD = cutlass::layout::RowMajor;
102+
103+
// Compile-time kernel configuration: copy atoms for A and B, a 256x256x32
// tile shape, the tiled MMA built on XE_8x16x16_F32BF16BF16F32_TT, and a
// mainloop dispatch policy parameterized with PipelineStages = 2.
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
104+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
105+
using TileShape = Shape<_256, _256, _32>;
106+
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
107+
constexpr int PipelineStages = 2;
108+
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
109+
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
110+
// Epilogue is a LinearCombination fusion with round-to-nearest.
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
111+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
112+
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
113+
EpilogueDispatchPolicy,
114+
TileShape,
115+
ElementAccumulator,
116+
cutlass::gemm::TagToStrideC_t<LayoutC>,
117+
ElementOutput,
118+
cutlass::gemm::TagToStrideC_t<LayoutD>,
119+
FusionCallBacks,
120+
XE_2D_U32x8x16_LD_N,
121+
void, void,
122+
XE_2D_U32x8x16_ST_N,
123+
void, void>;
124+
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
125+
GEMMDispatchPolicy,
126+
TileShape,
127+
ElementInputA,
128+
cutlass::gemm::TagToStrideA_t<LayoutA>,
129+
ElementInputB,
130+
cutlass::gemm::TagToStrideB_t<LayoutB>,
131+
TiledMma,
132+
GmemTiledCopyA, void, void, cute::identity,
133+
GmemTiledCopyB, void, void, cute::identity>;
134+
// Problem shape is four runtime ints, filled below as {M, N, K, L}.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
135+
Shape<int, int, int, int>,
136+
CollectiveMainloop,
137+
CollectiveEpilogue>;
138+
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
139+
140+
// hw_info is default-constructed; its device_id is used as-is for the query.
cutlass::KernelHardwareInfo hw_info;
141+
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
142+
143+
// get shape
144+
// NOTE(review): size() results are stored into int; presumably size() is
// 64-bit, so very large dimensions would be narrowed -- confirm.
int M = A.size(0);
145+
int K = A.size(1);
146+
int N = B.size(1);
147+
int L = 1; // batch size
148+
149+
// Strides are derived from the logical shapes only (A: M,K,L; B: N,K,L;
// C and D: M,N,L); the tensors' own strides are not consulted, so the inputs
// are presumably expected to be densely packed -- confirm.
auto stride_A = cutlass::make_cute_packed_stride(GemmKernel::StrideA{}, cute::make_shape(M, K, L));
150+
auto stride_B = cutlass::make_cute_packed_stride(GemmKernel::StrideB{}, cute::make_shape(N, K, L));
151+
auto stride_C = cutlass::make_cute_packed_stride(GemmKernel::StrideC{}, cute::make_shape(M, N, L));
152+
auto stride_D = cutlass::make_cute_packed_stride(GemmKernel::StrideD{}, cute::make_shape(M, N, L));
153+
154+
// Kernel arguments: GEMM mode, problem shape, mainloop operands (A, B), then
// the epilogue operands. `out`'s pointer is passed twice, once with stride_C
// and once with stride_D. {1.0f, 0.0f} are presumably the alpha/beta scalars
// of the linear combination (i.e. the prior contents of `out` are ignored) --
// confirm against the FusionCallbacks argument layout.
GemmKernel::Arguments arguments{
155+
cutlass::gemm::GemmUniversalMode::kGemm,
156+
GemmKernel::ProblemShape{M, N, K, L},
157+
{reinterpret_cast<ElementInputA*>(A.data_ptr()), stride_A, reinterpret_cast<ElementInputB*>(B.data_ptr()), stride_B},
158+
{{1.0f, 0.0f}, reinterpret_cast<ElementOutput*>(out.data_ptr()), stride_C, reinterpret_cast<ElementOutput*>(out.data_ptr()), stride_D},
159+
hw_info
160+
};
161+
162+
// Allocate whatever workspace the kernel reports for these arguments.
Gemm gemm_op;
163+
size_t workspace_size = Gemm::get_workspace_size(arguments);
164+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
165+
166+
// Reject unsupported configurations via TORCH_CHECK, then initialize and
// launch; initialize()/run() failures go through CUTLASS_CHECK.
TORCH_CHECK(gemm_op.can_implement(arguments) == cutlass::Status::kSuccess, "Invalid GEMM problem size or configuration");
167+
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
168+
CUTLASS_CHECK(gemm_op.run());
169+
// Presumably blocks until the submitted device work has finished, so `out`
// is complete on return and `workspace` is not freed early -- confirm.
syclcompat::wait();
170+
}

examples/cutlass-gemm/tests/test_gemm.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
from cutlass_gemm import cutlass_gemm
44

55
def test_gemm():
6-
A = torch.randn((10, 20), device="cuda", dtype=torch.float32)
7-
B = torch.randn((20, 30), device="cuda", dtype=torch.float32)
8-
out = torch.randn((10, 30), device="cuda", dtype=torch.float32)
6+
if hasattr(torch, "xpu") and torch.xpu.is_available():
7+
A = torch.randn((64, 32), device=torch.device("xpu"), dtype=torch.bfloat16)
8+
B = torch.randn((32, 64), device=torch.device("xpu"), dtype=torch.bfloat16)
9+
out = torch.randn((64, 64), device=torch.device("xpu"), dtype=torch.float32)
10+
else:
11+
A = torch.randn((10, 20), device=torch.device("cuda"), dtype=torch.float32)
12+
B = torch.randn((20, 30), device=torch.device("cuda"), dtype=torch.float32)
13+
out = torch.randn((10, 30), device=torch.device("cuda"), dtype=torch.float32)
914

1015
cutlass_gemm(out, A, B)
1116

12-
torch.testing.assert_allclose(out, torch.mm(A, B))
17+
torch.testing.assert_allclose(out, torch.mm(A.float(), B.float()))

examples/cutlass-gemm/torch-ext/torch_binding.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
77
ops.def("cutlass_gemm(Tensor! out, Tensor A, Tensor B) -> ()");
8+
#if defined(CUDA_KERNEL)
89
ops.impl("cutlass_gemm", torch::kCUDA, &cutlass_gemm);
10+
#elif defined(XPU_KERNEL)
11+
ops.impl("cutlass_gemm", torch::kXPU, &cutlass_gemm);
12+
#endif
913
}
1014

1115
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

flake.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/deps.nix

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,12 @@ let
2525
"cutlass_3_9" = [
2626
pkgs.cutlass_3_9
2727
];
28-
"cutlass_sycl_3_9" = [
29-
pkgs.cutlass_sycl_3_9
30-
];
3128
"torch" = [
3229
torch
3330
torch.cxxdev
3431
];
32+
"cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ];
3533
};
36-
3734
in
3835
let
3936
depToPkg =

lib/torch-extension/default.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ stdenv.mkDerivation (prevAttrs: {
153153
// lib.optionalAttrs xpuSupport {
154154
MKLROOT = oneapi-torch-dev;
155155
SYCL_ROOT = oneapi-torch-dev;
156+
DPCPP_VERSION = (lib.versions.majorMinor xpuPackages.intel-oneapi-dpcpp-cpp.version);
156157
};
157158

158159
# If we use the default setup, CMAKE_CUDA_HOST_COMPILER gets set to nixpkgs g++.

0 commit comments

Comments
 (0)