3939_WEIGHT_ALIGNMENT = 128
4040_MXFP4_DEFAULT_EXPERT_BLOCK_SIZE = 32
4141_MXFP4_SUPPORTED_EXPERT_BLOCK_SIZE = 32
42+ _MXFP4_VALUES_PER_BYTE = 2
4243_MXFP4_LAYOUT_ARG_NAMES = (
4344 "gate_up_blocks" ,
4445 "gate_up_scales" ,
@@ -375,12 +376,24 @@ def _register_existing_mxfp4_expert_layout_hooks(
375376 return num_hooks
376377
377378
379+ def _mxfp4_block_count (name : str , dim : int , expert_block_size : int ) -> int :
380+ if dim <= 0 :
381+ raise ValueError (f"MXFP4 expert { name } should be positive, got { dim } ." )
382+ if dim % expert_block_size != 0 :
383+ raise ValueError (
384+ f"MXFP4 expert { name } should be divisible by expert_block_size="
385+ f"{ expert_block_size } , got { dim } ."
386+ )
387+ return dim // expert_block_size
388+
389+
378390def _register_mxfp4_expert_params (
379391 gm : GraphModule ,
380392 gate_up_w_name : str ,
381393 gate_up_b_name : str ,
382394 down_w_name : str ,
383395 down_b_name : str ,
396+ expert_block_size : int = _MXFP4_DEFAULT_EXPERT_BLOCK_SIZE ,
384397) -> Tuple [str , str , str , str ]:
385398 """Create (if missing) the four MXFP4 params under the experts module and return their full names.
386399
@@ -404,9 +417,9 @@ def _register_mxfp4_expert_params(
404417 # Fallback: use down bias last dim
405418 H = int (dn_b .shape [1 ])
406419
407- # Compute block dims (assume divisible; zero-init anyway)
408- H_blk = max ( 1 , H // 32 )
409- I_blk = max ( 1 , In // 32 )
420+ packed_block_width = expert_block_size // _MXFP4_VALUES_PER_BYTE
421+ H_blk = _mxfp4_block_count ( "hidden_size" , H , expert_block_size )
422+ I_blk = _mxfp4_block_count ( "intermediate_size" , In , expert_block_size )
410423
411424 experts_mod , experts_path , _ = get_submodule_of_param (gm , gate_up_w_name )
412425
@@ -421,9 +434,13 @@ def _register_mxfp4_expert_params(
421434 # (meta in the normal meta-device build) so we don't materialize giant CPU
422435 # buffers before load.
423436 param_device = gu_w .device
424- gu_blocks = torch .empty ((E , 2 * In , H_blk , 16 ), dtype = torch .uint8 , device = param_device )
437+ gu_blocks = torch .empty (
438+ (E , 2 * In , H_blk , packed_block_width ), dtype = torch .uint8 , device = param_device
439+ )
425440 gu_scales = torch .empty ((E , 2 * In , H_blk ), dtype = torch .uint8 , device = param_device )
426- dn_blocks = torch .empty ((E , H , I_blk , 16 ), dtype = torch .uint8 , device = param_device )
441+ dn_blocks = torch .empty (
442+ (E , H , I_blk , packed_block_width ), dtype = torch .uint8 , device = param_device
443+ )
427444 dn_scales = torch .empty ((E , H , I_blk ), dtype = torch .uint8 , device = param_device )
428445
429446 experts_mod .register_parameter (gu_blocks_name , nn .Parameter (gu_blocks , requires_grad = False ))
@@ -722,8 +739,7 @@ def _apply(
722739 skipped = True , num_matches = 0 , is_clean = True , has_valid_shapes = True
723740 )
724741 checkpoint_layout = _get_packed_mxfp4_expert_layout (qcfg )
725- if checkpoint_layout is not None or qcfg .get ("expert_block_size" ) is not None :
726- _resolve_mxfp4_expert_block_size (qcfg , checkpoint_layout )
742+ expert_block_size = _resolve_mxfp4_expert_block_size (qcfg , checkpoint_layout )
727743 num_existing_hooks = 0
728744 if checkpoint_layout is not None :
729745 num_existing_hooks = _register_existing_mxfp4_expert_layout_hooks (
@@ -744,7 +760,9 @@ def _apply(
744760 ad_logger .info (f"quantize_mxfp4_moe: dispatching to backend={ backend !r} " )
745761
746762 if backend == "triton" :
747- gm , info = self ._apply_triton (gm , cm , factory , shared_config )
763+ gm , info = self ._apply_triton (
764+ gm , cm , factory , shared_config , expert_block_size = expert_block_size
765+ )
748766 elif backend == "trtllm" :
749767 gm , info = self ._apply_trtllm (gm , cm , factory , shared_config )
750768 else :
@@ -767,6 +785,8 @@ def _apply_triton(
767785 cm ,
768786 factory ,
769787 shared_config ,
788+ * ,
789+ expert_block_size : int = _MXFP4_DEFAULT_EXPERT_BLOCK_SIZE ,
770790 ) -> Tuple [GraphModule , TransformInfo ]:
771791 """Triton backend: graph rewrite to ``triton_mxfp4_moe``.
772792
@@ -820,7 +840,14 @@ def _apply_triton(
820840
821841 # Register MXFP4 params on experts
822842 gu_blocks_name , gu_scales_name , dn_blocks_name , dn_scales_name = (
823- _register_mxfp4_expert_params (gm , gu_w_name , gu_b_name , dn_w_name , dn_b_name )
843+ _register_mxfp4_expert_params (
844+ gm ,
845+ gu_w_name ,
846+ gu_b_name ,
847+ dn_w_name ,
848+ dn_b_name ,
849+ expert_block_size = expert_block_size ,
850+ )
824851 )
825852
826853 # Alpha/limit (from dense call)
0 commit comments