@@ -495,6 +495,22 @@ struct ggml_webgpu_binary_pipeline_key_hash {
495495 }
496496};
497497
498+ /* Add_Id */
499+
500+ struct ggml_webgpu_add_id_pipeline_key {
501+ bool inplace;
502+
503+ bool operator ==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace ; }
504+ };
505+
506+ struct ggml_webgpu_add_id_pipeline_key_hash {
507+ size_t operator ()(const ggml_webgpu_add_id_pipeline_key & key) const {
508+ size_t seed = 0 ;
509+ ggml_webgpu_hash_combine (seed, key.inplace );
510+ return seed;
511+ }
512+ };
513+
498514/* * Unary **/
499515
500516struct ggml_webgpu_unary_pipeline_key {
@@ -1058,7 +1074,9 @@ class ggml_webgpu_shader_lib {
10581074 std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
10591075 pad_pipelines; // circular/non-circular
10601076 std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
1061- binary_pipelines; // type/op/inplace/overlap
1077+ binary_pipelines; // type/op/inplace/overlap/src_overlap
1078+ std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash>
1079+ add_id_pipelines; // inplace
10621080 std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
10631081 concat_pipelines; // type
10641082 std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
@@ -1433,6 +1451,7 @@ class ggml_webgpu_shader_lib {
14331451 case GGML_TYPE_IQ3_S:
14341452 case GGML_TYPE_IQ1_S:
14351453 case GGML_TYPE_IQ4_NL:
1454+ case GGML_TYPE_MXFP4:
14361455 {
14371456 // Quantized types using u32 buffers for portability.
14381457 defines.push_back (" SRC_TYPE=u32" );
@@ -1451,6 +1470,7 @@ class ggml_webgpu_shader_lib {
14511470 defines.push_back (type_upper + " _SCALE_MIN" );
14521471 defines.push_back (type_upper + " _TABLES" );
14531472 defines.push_back (type_upper + " _GRID" );
1473+ defines.push_back (type_upper + " _LUT" );
14541474
14551475 variant += " _" ;
14561476 variant += type_str;
@@ -1460,7 +1480,7 @@ class ggml_webgpu_shader_lib {
14601480 if (key.src_type == GGML_TYPE_Q1_0) {
14611481 defines.push_back (" BLOCK_SIZE=128u" );
14621482 } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
1463- key.src_type == GGML_TYPE_IQ4_NL) {
1483+ key.src_type == GGML_TYPE_IQ4_NL || key. src_type == GGML_TYPE_MXFP4 ) {
14641484 defines.push_back (" BLOCK_SIZE=32u" );
14651485 } else if (key.src_type >= GGML_TYPE_Q2_K) {
14661486 defines.push_back (" BLOCK_SIZE=256u" );
@@ -1774,6 +1794,9 @@ class ggml_webgpu_shader_lib {
17741794 defines.push_back (type_upper + " _GRID" );
17751795 defines.push_back (type_upper + " _TABLES" );
17761796 break ;
1797+ case GGML_TYPE_MXFP4:
1798+ defines.push_back (type_upper + " _LUT" );
1799+ break ;
17771800 default :
17781801 break ;
17791802 }
@@ -1908,6 +1931,9 @@ class ggml_webgpu_shader_lib {
19081931 defines.push_back (type_upper + " _GRID" );
19091932 defines.push_back (type_upper + " _TABLES" );
19101933 break ;
1934+ case GGML_TYPE_MXFP4:
1935+ defines.push_back (type_upper + " _LUT" );
1936+ break ;
19111937 default :
19121938 break ;
19131939 }
@@ -2042,6 +2068,7 @@ class ggml_webgpu_shader_lib {
20422068 case GGML_TYPE_IQ3_S:
20432069 case GGML_TYPE_IQ1_S:
20442070 case GGML_TYPE_IQ4_NL:
2071+ case GGML_TYPE_MXFP4:
20452072 {
20462073 // Quantized types using u32 buffers for portability.
20472074 defines.push_back (" SRC0_TYPE=u32" );
@@ -2169,6 +2196,9 @@ class ggml_webgpu_shader_lib {
21692196 defines.push_back (type_upper + " _GRID" );
21702197 defines.push_back (type_upper + " _TABLES" );
21712198 break ;
2199+ case GGML_TYPE_MXFP4:
2200+ defines.push_back (type_upper + " _LUT" );
2201+ break ;
21722202 default :
21732203 break ;
21742204 }
@@ -2286,6 +2316,9 @@ class ggml_webgpu_shader_lib {
22862316 defines.push_back (type_upper + " _GRID" );
22872317 defines.push_back (type_upper + " _TABLES" );
22882318 break ;
2319+ case GGML_TYPE_MXFP4:
2320+ defines.push_back (type_upper + " _LUT" );
2321+ break ;
22892322 default :
22902323 break ;
22912324 }
@@ -2503,6 +2536,37 @@ class ggml_webgpu_shader_lib {
25032536 return binary_pipelines[key];
25042537 }
25052538
2539+ webgpu_pipeline get_add_id_pipeline (const ggml_webgpu_shader_lib_context & context) {
2540+ ggml_webgpu_add_id_pipeline_key key = {};
2541+ key.inplace = ggml_webgpu_tensor_equal (context.src0 , context.dst );
2542+
2543+ auto it = add_id_pipelines.find (key);
2544+ if (it != add_id_pipelines.end ()) {
2545+ return it->second ;
2546+ }
2547+
2548+ std::vector<std::string> defines;
2549+ std::string variant = " add_id" ;
2550+ const char * shader_src = wgsl_add_id;
2551+
2552+ if (key.inplace ) {
2553+ defines.push_back (" INPLACE" );
2554+ variant += " _inplace" ;
2555+ }
2556+
2557+ defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (context.max_wg_size ));
2558+
2559+ auto processed = preprocessor.preprocess (shader_src, defines);
2560+ auto pipeline_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2561+ pipeline_decisions->wg_size = context.max_wg_size ;
2562+ pipeline_decisions->inplace = key.inplace ;
2563+
2564+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed, variant);
2565+ pipeline.context = pipeline_decisions;
2566+ add_id_pipelines[key] = pipeline;
2567+ return pipeline;
2568+ }
2569+
25062570 webgpu_pipeline get_concat_pipeline (const ggml_webgpu_shader_lib_context & context) {
25072571 ggml_webgpu_concat_pipeline_key key = {};
25082572 key.type = context.dst ->type ;
0 commit comments