@@ -120,7 +120,7 @@ def get_train_seq_ctx(
120120):
121121 seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
122122 if multimodal_train_info and len(multimodal_train_info) > 0:
123- position_ids = multimodal_train_info.pop("position_ids")  # (1,n) or (3,1,n)
123+ position_ids = multimodal_train_info.get("position_ids")  # (1,n) or (3,1,n)
124124 if position_ids is not None and len(position_ids.shape) == 3:
125125 # qwen3vl needs special handling; the others need no extra processing
126126 max_value = position_ids.max(dim=-1).values  # (3,1)
@@ -130,9 +130,8 @@ def get_train_seq_ctx(
130130 position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
131131 seq_ctx.position_ids = position_ids  # type: ignore[assignment]
132132 assert position_ids.size(-1) == input_ids.size(-1)
133- seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values")
134- seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw")
135- del multimodal_train_info
133+ seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
134+ seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
136135 return seq_ctx
137136
138137
@@ -803,6 +802,8 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
803802 seq_ctx.rollout_routed_experts = routed_experts  # n,layer,expert
804803
805804 data_batches.append(data_dict)
805+ if multimodal_train_info is not None:
806+ del multimodal_train_info
806807 random.shuffle(data_batches)
807808
808809 rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()
0 commit comments