Skip to content

Commit e30ef24

Browse files
intel_gpu: fix Qwen3 MoE GEMM3_SWIGLU on MTL-class (non-systolic) iGPU
The FuseVectorizedMOE3GEMM pass was previously gated on supports_immad, causing it to be skipped on MTL-class iGPU (12.70.x, XeHPG, no DPAS). This left raw FP32 weight-decompression chains that overwhelmed propagate_constants with ~56 GB of constant-folding memory. Root cause of the inference failure: moe_3gemm_swiglu_opt uses oneDNN internally (onednn_linear for the gate/up/down matrix multiplications), and oneDNN requires an in-order OCL queue; MTL, however, uses an out-of-order queue by default because use_onednn is false when supports_immad=false. Fix: all three MoE transformation passes (FuseVectorizedMOE3GEMM, ConvertMOEToMOECompressed, FuseMOE3GemmCompressed) now run on all architectures. FuseMOE3GemmCompressed creates MOE3GemmFusedCompressed, which the OCL moe_3gemm_swiglu_opt kernel executes. - Detect MOE3GemmFusedCompressed in apply_model_specific_options and force use_onednn=true so that finalize_impl sets queue_type=in_order, satisfying the oneDNN in-order-queue requirement. - Fix moe_gather validate_impl to accept rank-2 input for models where the batch dimension is pre-flattened (Qwen3-style). - Re-apply the iGPU transfer skip (usm_shared -> usm_device) in network.cpp and program.cpp for integrated GPUs where both allocation types share system DRAM (xe2+ or 12.7x-class MTL/ARL-S). Tested on DUT1486ARLHx (12.70.4 / XeHPG / 64 GB): the model loads in 14 s, generates meaningful tokens, and Unevictable stays below 120 MB. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 2f873a5 commit e30ef24

6 files changed

Lines changed: 45 additions & 40 deletions

File tree

src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_gather.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ struct MoeGatherRef : public ImplementationManager {
3636
const auto& out_layout = node.get_output_layout(0);
3737
const auto& input_pshapes = in0_layout.get_partial_shape();
3838

39-
if (input_pshapes.rank() != 3 || input_pshapes[2].is_dynamic()) {
39+
// Accept rank-2 [tokens, hidden] (Qwen3-style, batch already flattened)
40+
// and rank-3 [batch, tokens, hidden]. The kernel only needs the last
41+
// dimension (hidden_size) to be static.
42+
const auto input_rank = input_pshapes.rank().get_length();
43+
if ((input_rank != 2 && input_rank != 3) || input_pshapes[input_rank - 1].is_dynamic()) {
4044
return false;
4145
}
4246

src/plugins/intel_gpu/src/graph/network.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,11 +1085,15 @@ void network::transfer_memory_to_device(std::shared_ptr<primitive_inst> instance
10851085
return;
10861086

10871087
if (alloc_type == allocation_type::usm_host || alloc_type == allocation_type::usm_shared) {
1088-
// usm_device memory does not provide performance benefits on the integrated Xe2+ platforms
1089-
if (get_engine().get_device_info().arch >= gpu_arch::xe2 &&
1090-
get_engine().get_device_info().dev_type == device_type::integrated_gpu) {
1088+
const auto& dev_info = get_engine().get_device_info();
1089+
const bool skip_transfer_on_igpu = dev_info.dev_type == device_type::integrated_gpu &&
1090+
(dev_info.arch >= gpu_arch::xe2 ||
1091+
(dev_info.gfx_ver.major == 12 && dev_info.gfx_ver.minor < 73));
1092+
// On MTL-class and xe2+ integrated GPUs, usm_shared and usm_device
1093+
// live in the same DRAM. Copying constant storage only inflates
1094+
// pinned memory without a corresponding benefit.
1095+
if (skip_transfer_on_igpu)
10911096
return;
1092-
}
10931097

10941098
// Allocate and transfer memory
10951099
auto device_mem = inst_mem.get_engine()->allocate_memory(inst_mem.get_layout(), allocation_type::usm_device, false);

src/plugins/intel_gpu/src/graph/program.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,10 +734,15 @@ void program::transfer_memory_to_device() {
734734
continue;
735735

736736
allocation_type target_alloc_type = alloc_type;
737-
// usm_device memory does not provide performance benefits on the LNL platform
737+
const auto& dev_info = get_engine().get_device_info();
738+
const bool skip_transfer_on_igpu = dev_info.dev_type == device_type::integrated_gpu &&
739+
(dev_info.arch >= gpu_arch::xe2 ||
740+
(dev_info.gfx_ver.major == 12 && dev_info.gfx_ver.minor < 73));
741+
// On MTL-class and xe2+ integrated GPUs, usm_shared and usm_device
742+
// live in the same DRAM. Copying constant storage only inflates
743+
// pinned memory without a corresponding benefit.
738744
if ((alloc_type == allocation_type::usm_host || alloc_type == allocation_type::usm_shared) &&
739-
!(get_engine().get_device_info().arch >= gpu_arch::xe2 &&
740-
get_engine().get_device_info().dev_type == device_type::integrated_gpu)) {
745+
!skip_transfer_on_igpu) {
741746
// Convert to usm_device for performance optimization
742747
target_alloc_type = allocation_type::usm_device;
743748
}

src/plugins/intel_gpu/src/plugin/ops/moe.cpp

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
//
44
#include "openvino/op/moe.hpp"
55

6-
#include <intel_gpu/primitives/eltwise.hpp>
76
#include <intel_gpu/primitives/moe_gather.hpp>
87
#include <intel_gpu/primitives/moe_scatter_reduction.hpp>
98
#include <intel_gpu/primitives/swiglu.hpp>
@@ -98,31 +97,11 @@ static void CreateMOECompressedOp(ProgramBuilder& p, const std::shared_ptr<ov::o
9897
input_infos.push_back(cldnn::input_info(input));
9998
}
10099
if (config.expert_type == ov::op::internal::MOE::Expert_type::GEMM3_SWIGLU) {
101-
// Create GEMM3_SWIGLU specific primitives
102-
// 0: hidden_states - input tensor with hidden representations
103-
// 1: routing_weights - [num_experts, ...] normalized weights for selected experts
104-
// (input to final multiplication)
105-
// 2: router_topk_output_indices - [..., topk] indices of selected top-k experts
106-
// 3: w0_weight - expert weights for first projection,
107-
// shape [num_experts, inter_size, group_num, group_size]
108-
// 4: w0_scale - expert scale for first projection for compressed experts,
109-
// shape [num_experts, inter_size, group_num, 1]
110-
// 5: w0_zp - expert zp for first projection for compressed experts,
111-
// shape [num_experts, inter_size, group_num, 1]
112-
// 6: w1_weight - expert weights for second projection,
113-
// shape [num_experts, inter_size, group_num, group_size]
114-
// 7: w1_scale - expert scale for second projection for compressed experts,
115-
// shape [num_experts, inter_size, group_num, 1]
116-
// 8: w1_zp - expert zp for second projection for compressed experts,
117-
// shape [num_experts, inter_size, group_num, 1]
118-
// 9: w2_weight - expert weights for final projection,
119-
// shape [num_experts, hidden_size, group_num, group_size]
120-
// 10: w2_scale - expert scale for final projection for compressed experts,
121-
// shape [num_experts, hidden_size, group_num, 1]
122-
// 11: w2_zp - expert zp for final projection for compressed experts,
123-
// shape [num_experts, hidden_size, group_num, 1]
124-
125-
// Use moe_3gemm_fused_compressed to replace it.
100+
// GEMM3_SWIGLU (Qwen3-style MoE) should be handled by FuseMOE3GemmCompressed
101+
// which converts MOECompressed(GEMM3_SWIGLU) → MOE3GemmFusedCompressed executed
102+
// by the OCL moe_3gemm_swiglu_opt kernel on all architectures. If execution
103+
// reaches here the transformation pipeline is misconfigured.
104+
OPENVINO_THROW("[GPU] MOECompressed(GEMM3_SWIGLU) must be handled by FuseMOE3GemmCompressed before program build");
126105
} else {
127106
// Create GEMM2_BIAS_SWIGLU_CLAMP specific primitives
128107
// input0 : input {#tokens, hidden_size}

src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,17 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
487487
return (info.arch != cldnn::gpu_arch::xe2) && (info.arch != cldnn::gpu_arch::xe3);
488488
});
489489

490+
// FuseVectorizedMOE3GEMM converts the original vectorized MoE graph
491+
// (separate MatMul + scatter/gather ops) into MOE(GEMM3_SWIGLU) with
492+
// packed INT4 weights. This structural conversion must run on ALL
493+
// architectures so that ConvertMOEToMOECompressed can match the INT4
494+
// constants downstream. Without it the raw FP32 decompression chains
495+
// reach propagate_constants and cause OOM on MTL-class iGPU.
496+
//
497+
// FuseMOE3GemmCompressed converts MOECompressed(GEMM3_SWIGLU) into
498+
// MOE3GemmFusedCompressed, executed by the OCL moe_3gemm_swiglu_opt
499+
// kernel on all architectures including non-systolic (MTL-class) iGPU.
490500
manager.register_pass<ov::pass::FuseVectorizedMOE3GEMM>();
491-
pass_config->set_callback<ov::pass::FuseVectorizedMOE3GEMM>([&](const_node_ptr& root) -> bool {
492-
// Currently moe gemm3 is only supported by systolic-array architectures
493-
auto& engine = m_context->get_engine();
494-
const auto& info = engine.get_device_info();
495-
return (!info.supports_immad);
496-
});
497501

498502
bool is_pa = false;
499503
for (const auto& op : func->get_ops()) {

src/plugins/intel_gpu/src/runtime/execution_config.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "intel_gpu/op/indirect_sdpa.hpp"
88
#include "intel_gpu/op/kv_cache.hpp"
9+
#include "intel_gpu/op/moe_3gemm_fused_compressed.hpp"
910
#include "intel_gpu/op/sdpa.hpp"
1011
#include "intel_gpu/plugin/remote_context.hpp"
1112
#include "intel_gpu/primitives/paged_attention.hpp"
@@ -215,6 +216,14 @@ void ExecutionConfig::apply_model_specific_options(const IRemoteContext* context
215216
m_use_onednn = true;
216217
}
217218

219+
// moe_3gemm_fused_compressed uses oneDNN internally for matrix multiplications,
220+
// so it requires an in-order queue. Force use_onednn=true here so that
221+
// finalize_impl will set queue_type=in_order regardless of whether the
222+
// hardware supports systolic (supports_immad).
223+
if (ov::is_type<ov::intel_gpu::op::MOE3GemmFusedCompressed>(op)) {
224+
m_use_onednn = true;
225+
}
226+
218227
if (auto multi_subgraph_op = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(op)) {
219228
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
220229
for (auto& sub_op : sub_graph->get_ops()) {

0 commit comments

Comments
 (0)