Skip to content

Commit e8a6d87

Browse files
committed
refine code
1 parent c699120 commit e8a6d87

5 files changed

Lines changed: 17 additions & 15 deletions

File tree

CMakeLists.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
2828
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
2929
set(MPS_FILES csrc/mps_ops.mm)
3030
set(METAL_FILES csrc/mps_kernels.metal)
31-
#set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass.cpp csrc/xpu_cutlass-cute.cpp csrc/xpu_cutlass_fusion.cpp)
32-
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass_fusion.cpp)
31+
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass.cpp csrc/xpu_cutlass-cute.cpp csrc/xpu_cutlass_fusion.cpp)
32+
#set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass_fusion.cpp)
3333
# C++ sources are always included
3434
list(APPEND SRC_FILES ${CPP_FILES})
3535

@@ -321,11 +321,13 @@ if(BUILD_XPU)
321321
-Xs
322322
-options "-cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required"
323323
)
324-
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=intel_gpu_pvc;-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier;")
324+
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=intel_gpu_pvc;-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate;")
325325

326326
set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
327327
target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})
328328
target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})
329+
#find_package(IntelSYCL REQUIRED)
330+
#target_link_libraries(bitsandbytes PRIVATE Intel::SYCL OpenCL::OpenCL)
329331

330332
endif()
331333

csrc/xpu_cutlass-cute.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ using TiledMma =
6363
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6464
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
6565

66-
using DispatchPolicy = MainloopIntelPVC<Stages>; //, KernelPVC /*Schedule*/>;
66+
using DispatchPolicy = MainloopIntelXeXMX16<Stages>; //, KernelPVC /*Schedule*/>;
6767
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /*data_type of GEMM output*/, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
68-
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
68+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelXeXMX16, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
6969
using SharedStorage = FusionCallBacks::SharedStorage;
7070

7171
using ClusterShape = typename DispatchPolicy::ClusterShape;
@@ -79,7 +79,7 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
7979
using TileSchedulerParams = typename TileScheduler::Params;
8080

8181
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
82-
cutlass::epilogue::IntelPVCEpilogue,
82+
cutlass::epilogue::IntelXeXMX16,
8383
TileShape,
8484
ElementAccumulator,
8585
cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
@@ -280,8 +280,8 @@ class kgemv_4bit_inference_cutlass_cute {
280280
constexpr auto workgroup_shape = WorkgroupTileShape{};
281281
constexpr auto subgroup_shape = SubgroupTileShape{};
282282

283-
Tensor mA_mkl = cute::get_pvc_tensor(make_shape(M,K,L)); //(m,k,l)
284-
Tensor mB_nkl = cute::get_pvc_tensor(make_shape(N,K,L)); //(n,k,l)
283+
Tensor mA_mkl = cute::get_xe_tensor(make_shape(M,K,L)); //(m,k,l)
284+
Tensor mB_nkl = cute::get_xe_tensor(make_shape(N,K,L)); //(n,k,l)
285285

286286
Tensor gA = local_tile(mA_mkl, select<0,2>(blk_shape), make_coord(m_coord,_,l_coord));
287287
Tensor gB = local_tile(mB_nkl, select<1,2>(blk_shape), make_coord(n_coord,_,l_coord));

csrc/xpu_cutlass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
2727
// Create the Epilogue
2828
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /*data_type of GEMM output*/, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
2929

30-
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape,
30+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelXeXMX16, EpilogueOp, TileShape,
3131
decltype(tile_shape(TiledMma()))>;
3232
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
33-
cutlass::epilogue::IntelPVCEpilogue,
33+
cutlass::epilogue::IntelXeXMX16,
3434
TileShape,
3535
ElementAccumulator,
3636
cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
@@ -44,7 +44,7 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
4444

4545
// GEMM Mainloop - iteration over blocks in K dimension
4646
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
47-
cutlass::gemm::MainloopIntelPVC<2>, //use PipelineStages = 2
47+
cutlass::gemm::MainloopIntelXeXMX16<2>, //use PipelineStages = 2
4848
TileShape,
4949
bfloat16_t, // data_type of input: A
5050
cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation

run_case.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
#gdb -args python -m pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
31-
#pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
32-
python tests/test_xpu_db.py
31+
pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
32+
#python tests/test_xpu_db.py
3333
#gdb -args python tests/test_xpu_db.py
3434
#pytest tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=256-uint8-bf16-fc1-nf4-DQ_True-xpu]

tests/test_xpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
9393
#pdb.set_trace()
9494
diff = abs(C2-C3)
9595
print("diff = ", diff.sum())
96-
print(C3[0])
97-
print(C2[0])
96+
#print(C3[0])
97+
#print(C2[0])
9898
#print(C3)
9999
#print(C2)
100100
#A.requires_grad = True

0 commit comments

Comments
 (0)