@@ -212,20 +212,36 @@ def test_fa4(self):
212212class TestInitFlashAttnVersion (unittest .TestCase ):
213213 """Tests for the init_flash_attn_version FA4 import branch (sm>=100)."""
214214
215+ _MODULE_NAMES = (
216+ "paddlefleet" ,
217+ "paddlefleet.ops" ,
218+ "paddlefleet.ops.flash_mask" ,
219+ "paddlefleet.ops.flash_mask.cute" ,
220+ "paddlefleet.ops.flash_mask.cute.interface" ,
221+ "paddlefleet_ops" ,
222+ "paddlefleet_ops.flash_mask" ,
223+ "paddlefleet_ops.flash_mask.cute" ,
224+ "paddlefleet_ops.flash_mask.cute.interface" ,
225+ )
226+
215227 def setUp (self ):
216228 # Save state to restore after each test.
217229 self ._saved_version = flash_attn_backend .FLASH_ATTN_VERSION
218230 self ._saved_v4 = flash_attn_backend .flashmask_attention_v4
219- self ._saved_modules = {
220- name : sys .modules .get (name )
221- for name in (
222- "paddlefleet" ,
223- "paddlefleet_ops" ,
224- "paddlefleet_ops.flash_mask" ,
225- "paddlefleet_ops.flash_mask.cute" ,
226- "paddlefleet_ops.flash_mask.cute.interface" ,
227- )
228- }
231+ self ._saved_modules = {name : sys .modules .get (name ) for name in self ._MODULE_NAMES }
232+ # Make sure each test starts with a clean module state.
233+ for name in self ._MODULE_NAMES :
234+ sys .modules .pop (name , None )
235+
236+ def _block_old_api (self ):
237+ """Force `paddlefleet.ops` import to fail regardless of what is installed."""
238+ # Setting sys.modules[name] = None makes importlib.import_module raise ImportError.
239+ sys .modules ["paddlefleet" ] = None
240+ sys .modules ["paddlefleet.ops" ] = None
241+
242+ def _block_new_api (self ):
243+ """Force `paddlefleet_ops` import to fail regardless of what is installed."""
244+ sys .modules ["paddlefleet_ops" ] = None
229245
230246 def tearDown (self ):
231247 flash_attn_backend .FLASH_ATTN_VERSION = self ._saved_version
@@ -236,10 +252,30 @@ def tearDown(self):
236252 else :
237253 sys .modules [name ] = mod
238254
239- def _install_fake_paddlefleet (self , is_available : bool ):
240- """Inject fake paddlefleet modules so the inner imports succeed ."""
255+ def _install_fake_paddlefleet_old_api (self , is_available : bool ):
256+ """Inject fake ` paddlefleet.ops` (old API) modules ."""
241257 pkg = types .ModuleType ("paddlefleet" )
242258 pkg .__path__ = []
259+ ops = types .ModuleType ("paddlefleet.ops" )
260+ ops .__path__ = []
261+ ops .is_flash_mask_available = lambda : is_available
262+ pkg .ops = ops
263+ flash_mask = types .ModuleType ("paddlefleet.ops.flash_mask" )
264+ flash_mask .__path__ = []
265+ cute = types .ModuleType ("paddlefleet.ops.flash_mask.cute" )
266+ cute .__path__ = []
267+ interface = types .ModuleType ("paddlefleet.ops.flash_mask.cute.interface" )
268+ interface .flashmask_attention = mock .MagicMock (name = "fa4_old" )
269+
270+ sys .modules ["paddlefleet" ] = pkg
271+ sys .modules ["paddlefleet.ops" ] = ops
272+ sys .modules ["paddlefleet.ops.flash_mask" ] = flash_mask
273+ sys .modules ["paddlefleet.ops.flash_mask.cute" ] = cute
274+ sys .modules ["paddlefleet.ops.flash_mask.cute.interface" ] = interface
275+ return interface .flashmask_attention
276+
277+ def _install_fake_paddlefleet_new_api (self , is_available : bool ):
278+ """Inject fake `paddlefleet_ops` (new API) modules."""
243279 ops = types .ModuleType ("paddlefleet_ops" )
244280 ops .__path__ = []
245281 ops .is_flash_mask_available = lambda : is_available
@@ -248,18 +284,19 @@ def _install_fake_paddlefleet(self, is_available: bool):
248284 cute = types .ModuleType ("paddlefleet_ops.flash_mask.cute" )
249285 cute .__path__ = []
250286 interface = types .ModuleType ("paddlefleet_ops.flash_mask.cute.interface" )
251- interface .flashmask_attention = mock .MagicMock (name = "fa4 " )
287+ interface .flashmask_attention = mock .MagicMock (name = "fa4_new " )
252288
253- sys .modules ["paddlefleet" ] = pkg
254289 sys .modules ["paddlefleet_ops" ] = ops
255290 sys .modules ["paddlefleet_ops.flash_mask" ] = flash_mask
256291 sys .modules ["paddlefleet_ops.flash_mask.cute" ] = cute
257292 sys .modules ["paddlefleet_ops.flash_mask.cute.interface" ] = interface
258293 return interface .flashmask_attention
259294
260- def test_fa4_import_success (self ):
261- """Covers lines 88, 89, 91, 92 (is_flash_mask_available True branch)."""
262- fake_fa4 = self ._install_fake_paddlefleet (is_available = True )
295+ def test_fa4_old_api_import_success (self ):
296+ """Old API (`paddlefleet.ops`) is preferred when available."""
297+ fake_fa4 = self ._install_fake_paddlefleet_old_api (is_available = True )
298+ # Also install new API to verify the old API takes precedence.
299+ new_fa4 = self ._install_fake_paddlefleet_new_api (is_available = True )
263300 flash_attn_backend .FLASH_ATTN_VERSION = None
264301 flash_attn_backend .flashmask_attention_v4 = None
265302
@@ -272,10 +309,12 @@ def test_fa4_import_success(self):
272309
273310 self .assertEqual (flash_attn_backend .FLASH_ATTN_VERSION , 4 )
274311 self .assertIs (flash_attn_backend .flashmask_attention_v4 , fake_fa4 )
312+ self .assertIsNot (flash_attn_backend .flashmask_attention_v4 , new_fa4 )
275313
276- def test_fa4_flash_mask_unavailable (self ):
277- """Covers lines 88, 89, 91, 96, 98, 99 (raise + except path)."""
278- self ._install_fake_paddlefleet (is_available = False )
314+ def test_fa4_old_api_flash_mask_unavailable (self ):
315+ """Old API present but `is_flash_mask_available` is False."""
316+ self ._install_fake_paddlefleet_old_api (is_available = False )
317+ self ._block_new_api ()
279318 flash_attn_backend .FLASH_ATTN_VERSION = None
280319 flash_attn_backend .flashmask_attention_v4 = None
281320
@@ -284,28 +323,53 @@ def test_fa4_flash_mask_unavailable(self):
284323 mock .patch .object (flash_attn_backend , "get_sm_version" , return_value = 100 ),
285324 mock .patch .object (paddle , "enable_compat" , create = True , return_value = None ),
286325 ):
287- # The inner except swallows ModuleNotFoundError, but `fa4` is then
288- # unbound, so the outer block raises NameError (not ImportError),
289- # which propagates. Verify the inner except actually executed by
290- # checking that FA4 was not selected.
291326 try :
292327 flash_attn_backend .init_flash_attn_version ()
293328 except NameError :
294329 pass
295330
296331 self .assertNotEqual (flash_attn_backend .FLASH_ATTN_VERSION , 4 )
297332
298- def test_fa4_paddlefleet_import_error (self ):
299- """Covers lines 88, 89, 98, 99 (ImportError caught by inner except)."""
300- # Ensure paddlefleet import fails.
301- for name in (
302- "paddlefleet" ,
303- "paddlefleet_ops" ,
304- "paddlefleet_ops.flash_mask" ,
305- "paddlefleet_ops.flash_mask.cute" ,
306- "paddlefleet_ops.flash_mask.cute.interface" ,
333+ def test_fa4_new_api_import_success (self ):
334+ """Falls back to new API (`paddlefleet_ops`) when old API is missing."""
335+ fake_fa4 = self ._install_fake_paddlefleet_new_api (is_available = True )
336+ self ._block_old_api ()
337+ flash_attn_backend .FLASH_ATTN_VERSION = None
338+ flash_attn_backend .flashmask_attention_v4 = None
339+
340+ with (
341+ mock .patch .object (flash_attn_backend .current_platform , "is_cuda" , return_value = True ),
342+ mock .patch .object (flash_attn_backend , "get_sm_version" , return_value = 100 ),
343+ mock .patch .object (paddle , "enable_compat" , create = True , return_value = None ),
307344 ):
308- sys .modules .pop (name , None )
345+ flash_attn_backend .init_flash_attn_version ()
346+
347+ self .assertEqual (flash_attn_backend .FLASH_ATTN_VERSION , 4 )
348+ self .assertIs (flash_attn_backend .flashmask_attention_v4 , fake_fa4 )
349+
350+ def test_fa4_new_api_flash_mask_unavailable (self ):
351+ """New API present but `is_flash_mask_available` is False."""
352+ self ._install_fake_paddlefleet_new_api (is_available = False )
353+ self ._block_old_api ()
354+ flash_attn_backend .FLASH_ATTN_VERSION = None
355+ flash_attn_backend .flashmask_attention_v4 = None
356+
357+ with (
358+ mock .patch .object (flash_attn_backend .current_platform , "is_cuda" , return_value = True ),
359+ mock .patch .object (flash_attn_backend , "get_sm_version" , return_value = 100 ),
360+ mock .patch .object (paddle , "enable_compat" , create = True , return_value = None ),
361+ ):
362+ try :
363+ flash_attn_backend .init_flash_attn_version ()
364+ except NameError :
365+ pass
366+
367+ self .assertNotEqual (flash_attn_backend .FLASH_ATTN_VERSION , 4 )
368+
369+ def test_fa4_paddlefleet_import_error (self ):
370+ """Neither old nor new API is importable."""
371+ self ._block_old_api ()
372+ self ._block_new_api ()
309373 flash_attn_backend .FLASH_ATTN_VERSION = None
310374 flash_attn_backend .flashmask_attention_v4 = None
311375
0 commit comments