@@ -30,15 +30,6 @@
 QWEN3_MOE_FOPE_PATH = os.environ["QWEN3_MOE_FOPE_PATH"]
 
 
-# Skip fope tests for transformers >= 5.2.0 due to SlidingWindowCache incompatibility
-# in the model's remote code
-def skip_if_fope_incompatible(model_type):
-    """Skip fope model tests if transformers version is incompatible."""
-    if model_type == "qwen3_moe_fope" and Version(transformers_version) >= Version("5.2.0"):
-        return True
-    return False
-
-
 class TestQwen3MoE(DeterministicDDPTestCase):
     def prepare(self):
         self.temp_dir = tempfile.TemporaryDirectory()
@@ -58,8 +49,6 @@ def prepare(self):
     )
     def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_mode, model_type):
         assert model_type in ["qwen3_moe", "qwen3_moe_fope"]
-        if skip_if_fope_incompatible(model_type):
-            raise unittest.SkipTest(f"Skipping fope test for transformers {transformers_version} due to SlidingWindowCache incompatibility")
         os.environ["TRITON_CACHE_DIR"] = str(Path(self.temp_dir.name) / "triton_cache")
         self.create_pg(device)
 
@@ -99,7 +88,7 @@ def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_mod
         cfg.compile_cfg = False
         cfg.dispatcher = dispatcher
         cfg.ep_size = ep_size
-        qwen_model = cfg.build().to(torch.bfloat16)
+        qwen_model = cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True)
         qwen_model.from_hf(hf_model_path)
 
         losses = []
@@ -139,8 +128,6 @@ def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_mod
     )
     def test_fsdp_accuracy(self, device, dispatcher, ep_size, model_type):
         assert model_type in ["qwen3_moe", "qwen3_moe_fope"]
-        if skip_if_fope_incompatible(model_type):
-            raise unittest.SkipTest(f"Skipping fope test for transformers {transformers_version} due to SlidingWindowCache incompatibility")
         self.create_pg(device)
 
         hf_model_path = QWEN3_MOE_PATH if model_type == "qwen3_moe" else QWEN3_MOE_FOPE_PATH
@@ -179,7 +166,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size, model_type):
         cfg.compile_cfg = False
         cfg.ep_size = ep_size
         cfg.dispatcher = dispatcher
-        qwen_model = cfg.build().to(torch.bfloat16)
+        qwen_model = cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True)
 
         fsdp_config = FSDPConfig(
             ep_size=ep_size,
@@ -212,7 +199,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size, model_type):
             loss = output["loss"]
             losses.append(loss)
 
-        self._check_loss_curve(losses=torch.tensor(losses), losses_ref=torch.tensor(expected_losses), sim_tol=1e-2, rtol=1e-2)
+        self._check_loss_curve(losses=torch.tensor(losses), losses_ref=torch.tensor(expected_losses), sim_tol=3e-2, rtol=3e-2)
 
     @parametrize.parametrize(
         "use_sliding_window, max_window_layers, sliding_window",
@@ -235,7 +222,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi
             use_sliding_window=use_sliding_window,
             max_window_layers=max_window_layers,
             attention=attention)
-        qwen_model = cfg.build().to(torch.bfloat16)
+        qwen_model = cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True)
         loss_cfg = CELossConfig()
 
         if use_sliding_window is False or max_window_layers >= num_hidden_layers:
@@ -264,7 +251,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi
             use_sliding_window=use_sliding_window,
             max_window_layers=max_window_layers,
             attention=attention)
-        qwen_model = cfg.build().to(torch.bfloat16)
+        qwen_model = cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True)
 
         fsdp_config = FSDPConfig()
         tokenizer = AutoTokenizer.from_pretrained(QWEN3_MOE_PATH, trust_remote_code=True)
@@ -303,7 +290,7 @@ def test_save_hf(self, device, dispatcher, ep_size):
         cfg = Qwen3MoE30BA3Config()
         cfg.dispatcher = dispatcher
         cfg.ep_size = ep_size
-        qwen_model = cfg.build().to(torch.bfloat16)
+        qwen_model = cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True)
 
         fsdp_config = FSDPConfig(
             ep_size=ep_size,
@@ -382,8 +369,6 @@ def test_fope_auto_config_with_remote_code(self):
382369 ],
383370 )
384371 def test_save_hf_fope (self , device , dispatcher , ep_size ):
385- if Version (transformers_version ) >= Version ("5.2.0" ):
386- raise unittest .SkipTest (f"Skipping fope test for transformers { transformers_version } due to SlidingWindowCache incompatibility" )
387372 self .create_pg (device )
388373 with tempfile .TemporaryDirectory () as tmpdir :
389374 load_from = Path (QWEN3_MOE_FOPE_PATH )
@@ -518,4 +503,3 @@ def check_dict_equal(dict1: dict, dict2: dict) -> bool:
             print(f"[ERROR] key {key} value is not equal")
             return False
     return True
-
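
Note on the recurring change above: `cfg.build().to(torch.bfloat16)` is replaced everywhere with `cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True)`. Below is a minimal sketch of what such a helper could look like, assuming `skip_buffers_dtype=True` means "cast parameters to the target dtype but leave buffer dtypes untouched"; the helper name `to_device_dtype` and this body are illustrative, not the repo's private `_to_device_dtype` implementation.

```python
import torch
from torch import nn


def to_device_dtype(module: nn.Module, dtype: torch.dtype,
                    skip_buffers_dtype: bool = False) -> nn.Module:
    """Hypothetical sketch of a buffer-preserving dtype cast."""
    if not skip_buffers_dtype:
        # Plain module.to(dtype) casts parameters *and* floating-point buffers.
        return module.to(dtype)
    # Cast only floating-point parameters; buffers keep their original dtype,
    # so precision-sensitive state (e.g. a float32 RoPE inv_freq table) is
    # not silently downcast to bfloat16.
    for param in module.parameters():
        if param.is_floating_point():
            param.data = param.data.to(dtype)
    return module


# Usage mirroring the diff (cfg is assumed to expose a build() method):
# qwen_model = to_device_dtype(cfg.build(), torch.bfloat16, skip_buffers_dtype=True)
```

Skipping buffers in the cast is a common choice when a model keeps auxiliary state in float32 for numerical accuracy, which would also explain the loosened `sim_tol`/`rtol` in `test_fsdp_accuracy`.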