@@ -154,6 +154,7 @@ class MoEConfig(TransformerConfig):
     moe_bias: bool = False
     moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig()
     freeze_routers: bool = False
+    vision_hidden_layers: int = 0
 
     def build(self) -> "MoE":
         from xtuner.v1.model.moe.moe import MoE
@@ -430,8 +431,8 @@ def _micro_batch_forward(
         with async_save_on_cpu(
             h2d_stream=self.offload_stream,
             d2h_stream=self.offload_stream,
-            block_idx=layer_idx - self.config.first_k_dense_replace,
-            depth=len(self.layers) - self.config.first_k_dense_replace,
+            block_idx=layer_idx - self.config.first_k_dense_replace + self.config.vision_hidden_layers,
+            depth=len(self.layers) - self.config.first_k_dense_replace + self.config.vision_hidden_layers,
             custom_check_fn=lambda x: x.data_ptr()
             in [hidden_states.data_ptr() for hidden_states in hidden_states_list],
             prefetch=True,
@@ -577,8 +578,8 @@ def _forward(
         with async_save_on_cpu(
             h2d_stream=self.offload_stream,
             d2h_stream=self.offload_stream,
-            block_idx=int(idx),
-            depth=len(self.layers),
+            block_idx=int(idx) + self.config.vision_hidden_layers,
+            depth=len(self.layers) + self.config.vision_hidden_layers,
             custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
         ):
             layer_results = decoder_layer(
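
Taken together, the hunks shift the offload block numbering by `vision_hidden_layers`, so language-model layers slot in after any vision layers that share the same CPU-offloading schedule. Below is a minimal sketch of the resulting numbering, assuming `async_save_on_cpu` treats (`block_idx`, `depth`) as a position in one contiguous global schedule; the helper name and the example layer counts are hypothetical, not part of the repository:

    # Hypothetical illustration; only the index arithmetic mirrors the patch.
    def language_block_schedule(
        num_llm_layers: int,
        first_k_dense_replace: int,
        vision_hidden_layers: int,
    ) -> list[tuple[int, int]]:
        """Return (block_idx, depth) pairs as the patched code would pass them.

        Vision layers are assumed to occupy offload slots
        [0, vision_hidden_layers), so every language-model layer is shifted
        by that amount, and the total depth grows by the same offset.
        """
        depth = num_llm_layers - first_k_dense_replace + vision_hidden_layers
        return [
            (layer_idx - first_k_dense_replace + vision_hidden_layers, depth)
            for layer_idx in range(first_k_dense_replace, num_llm_layers)
        ]

    # e.g. 4 dense layers kept on-device, 32 LLM layers, 24 vision layers:
    # block_idx runs from 24 to 51 with depth == 52.
    print(language_block_schedule(32, 4, 24)[:2])  # [(24, 52), (25, 52)]

Without the offset, the first offloaded language-model layer would reuse `block_idx == 0` and a too-small `depth`, so its prefetch position would collide with the vision layers' slots in the schedule.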