@@ -63,9 +63,9 @@ using TiledMma =
6363 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6464 Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
6565
66- using DispatchPolicy = MainloopIntelPVC <Stages>; // , KernelPVC /*Schedule*/>;
66+ using DispatchPolicy = MainloopIntelXeXMX16 <Stages>; // , KernelPVC /*Schedule*/>;
6767using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /* data_type of GEMM output*/ , ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
68- using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue , EpilogueOp, TileShape, decltype (tile_shape(TiledMma()))>;
68+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelXeXMX16 , EpilogueOp, TileShape, decltype (tile_shape(TiledMma()))>;
6969using SharedStorage = FusionCallBacks::SharedStorage;
7070
7171using ClusterShape = typename DispatchPolicy::ClusterShape;
@@ -79,7 +79,7 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
7979 using TileSchedulerParams = typename TileScheduler::Params;
8080
8181 using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
82- cutlass::epilogue::IntelPVCEpilogue ,
82+ cutlass::epilogue::IntelXeXMX16 ,
8383 TileShape,
8484 ElementAccumulator,
8585 cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
@@ -280,8 +280,8 @@ class kgemv_4bit_inference_cutlass_cute {
280280 constexpr auto workgroup_shape = WorkgroupTileShape{};
281281 constexpr auto subgroup_shape = SubgroupTileShape{};
282282
283- Tensor mA_mkl = cute::get_pvc_tensor (make_shape (M,K,L)); // (m,k,l)
284- Tensor mB_nkl = cute::get_pvc_tensor (make_shape (N,K,L)); // (n,k,l)
283+ Tensor mA_mkl = cute::get_xe_tensor (make_shape (M,K,L)); // (m,k,l)
284+ Tensor mB_nkl = cute::get_xe_tensor (make_shape (N,K,L)); // (n,k,l)
285285
286286 Tensor gA = local_tile (mA_mkl , select<0 ,2 >(blk_shape), make_coord (m_coord,_,l_coord));
287287 Tensor gB = local_tile (mB_nkl , select<1 ,2 >(blk_shape), make_coord (n_coord,_,l_coord));
0 commit comments