Skip to content

Commit 5333624

Browse files
committed
Add compatibility with the old `paddlefleet.ops` API
1 parent f013e10 commit 5333624

2 files changed

Lines changed: 118 additions & 41 deletions

File tree

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474

7575
FLASH_ATTN_VERSION = None
7676

77+
from fastdeploy.model_executor.utils import try_import
78+
7779

7880
def init_flash_attn_version():
7981
"""
@@ -86,14 +88,25 @@ def init_flash_attn_version():
8688
try:
8789
paddle.enable_compat(scope={"cutlass"})
8890
try:
89-
from paddlefleet_ops import is_flash_mask_available
90-
91-
if is_flash_mask_available():
92-
from paddlefleet_ops.flash_mask.cute.interface import (
93-
flashmask_attention as fa4,
94-
)
91+
old_api = try_import(["paddlefleet.ops"])
92+
if old_api is not None:
93+
from paddlefleet.ops import is_flash_mask_available
94+
95+
if is_flash_mask_available():
96+
from paddlefleet.ops.flash_mask.cute.interface import (
97+
flashmask_attention as fa4,
98+
)
99+
else:
100+
raise ModuleNotFoundError("flash_mask not available.")
95101
else:
96-
raise ModuleNotFoundError("flash_mask not available.")
102+
from paddlefleet_ops import is_flash_mask_available
103+
104+
if is_flash_mask_available():
105+
from paddlefleet_ops.flash_mask.cute.interface import (
106+
flashmask_attention as fa4,
107+
)
108+
else:
109+
raise ModuleNotFoundError("flash_mask not available.")
97110

98111
except (ImportError, ModuleNotFoundError):
99112
logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.")

tests/layers/test_flash_attn_func.py

Lines changed: 98 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -212,20 +212,36 @@ def test_fa4(self):
212212
class TestInitFlashAttnVersion(unittest.TestCase):
213213
"""Tests for the init_flash_attn_version FA4 import branch (sm>=100)."""
214214

215+
_MODULE_NAMES = (
216+
"paddlefleet",
217+
"paddlefleet.ops",
218+
"paddlefleet.ops.flash_mask",
219+
"paddlefleet.ops.flash_mask.cute",
220+
"paddlefleet.ops.flash_mask.cute.interface",
221+
"paddlefleet_ops",
222+
"paddlefleet_ops.flash_mask",
223+
"paddlefleet_ops.flash_mask.cute",
224+
"paddlefleet_ops.flash_mask.cute.interface",
225+
)
226+
215227
def setUp(self):
216228
# Save state to restore after each test.
217229
self._saved_version = flash_attn_backend.FLASH_ATTN_VERSION
218230
self._saved_v4 = flash_attn_backend.flashmask_attention_v4
219-
self._saved_modules = {
220-
name: sys.modules.get(name)
221-
for name in (
222-
"paddlefleet",
223-
"paddlefleet_ops",
224-
"paddlefleet_ops.flash_mask",
225-
"paddlefleet_ops.flash_mask.cute",
226-
"paddlefleet_ops.flash_mask.cute.interface",
227-
)
228-
}
231+
self._saved_modules = {name: sys.modules.get(name) for name in self._MODULE_NAMES}
232+
# Make sure each test starts with a clean module state.
233+
for name in self._MODULE_NAMES:
234+
sys.modules.pop(name, None)
235+
236+
def _block_old_api(self):
237+
"""Force `paddlefleet.ops` import to fail regardless of what is installed."""
238+
# Setting sys.modules[name] = None makes importlib.import_module raise ImportError.
239+
sys.modules["paddlefleet"] = None
240+
sys.modules["paddlefleet.ops"] = None
241+
242+
def _block_new_api(self):
243+
"""Force `paddlefleet_ops` import to fail regardless of what is installed."""
244+
sys.modules["paddlefleet_ops"] = None
229245

230246
def tearDown(self):
231247
flash_attn_backend.FLASH_ATTN_VERSION = self._saved_version
@@ -236,10 +252,30 @@ def tearDown(self):
236252
else:
237253
sys.modules[name] = mod
238254

239-
def _install_fake_paddlefleet(self, is_available: bool):
240-
"""Inject fake paddlefleet modules so the inner imports succeed."""
255+
def _install_fake_paddlefleet_old_api(self, is_available: bool):
256+
"""Inject fake `paddlefleet.ops` (old API) modules."""
241257
pkg = types.ModuleType("paddlefleet")
242258
pkg.__path__ = []
259+
ops = types.ModuleType("paddlefleet.ops")
260+
ops.__path__ = []
261+
ops.is_flash_mask_available = lambda: is_available
262+
pkg.ops = ops
263+
flash_mask = types.ModuleType("paddlefleet.ops.flash_mask")
264+
flash_mask.__path__ = []
265+
cute = types.ModuleType("paddlefleet.ops.flash_mask.cute")
266+
cute.__path__ = []
267+
interface = types.ModuleType("paddlefleet.ops.flash_mask.cute.interface")
268+
interface.flashmask_attention = mock.MagicMock(name="fa4_old")
269+
270+
sys.modules["paddlefleet"] = pkg
271+
sys.modules["paddlefleet.ops"] = ops
272+
sys.modules["paddlefleet.ops.flash_mask"] = flash_mask
273+
sys.modules["paddlefleet.ops.flash_mask.cute"] = cute
274+
sys.modules["paddlefleet.ops.flash_mask.cute.interface"] = interface
275+
return interface.flashmask_attention
276+
277+
def _install_fake_paddlefleet_new_api(self, is_available: bool):
278+
"""Inject fake `paddlefleet_ops` (new API) modules."""
243279
ops = types.ModuleType("paddlefleet_ops")
244280
ops.__path__ = []
245281
ops.is_flash_mask_available = lambda: is_available
@@ -248,18 +284,19 @@ def _install_fake_paddlefleet(self, is_available: bool):
248284
cute = types.ModuleType("paddlefleet_ops.flash_mask.cute")
249285
cute.__path__ = []
250286
interface = types.ModuleType("paddlefleet_ops.flash_mask.cute.interface")
251-
interface.flashmask_attention = mock.MagicMock(name="fa4")
287+
interface.flashmask_attention = mock.MagicMock(name="fa4_new")
252288

253-
sys.modules["paddlefleet"] = pkg
254289
sys.modules["paddlefleet_ops"] = ops
255290
sys.modules["paddlefleet_ops.flash_mask"] = flash_mask
256291
sys.modules["paddlefleet_ops.flash_mask.cute"] = cute
257292
sys.modules["paddlefleet_ops.flash_mask.cute.interface"] = interface
258293
return interface.flashmask_attention
259294

260-
def test_fa4_import_success(self):
261-
"""Covers lines 88, 89, 91, 92 (is_flash_mask_available True branch)."""
262-
fake_fa4 = self._install_fake_paddlefleet(is_available=True)
295+
def test_fa4_old_api_import_success(self):
296+
"""Old API (`paddlefleet.ops`) is preferred when available."""
297+
fake_fa4 = self._install_fake_paddlefleet_old_api(is_available=True)
298+
# Also install new API to verify the old API takes precedence.
299+
new_fa4 = self._install_fake_paddlefleet_new_api(is_available=True)
263300
flash_attn_backend.FLASH_ATTN_VERSION = None
264301
flash_attn_backend.flashmask_attention_v4 = None
265302

@@ -272,10 +309,12 @@ def test_fa4_import_success(self):
272309

273310
self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
274311
self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4)
312+
self.assertIsNot(flash_attn_backend.flashmask_attention_v4, new_fa4)
275313

276-
def test_fa4_flash_mask_unavailable(self):
277-
"""Covers lines 88, 89, 91, 96, 98, 99 (raise + except path)."""
278-
self._install_fake_paddlefleet(is_available=False)
314+
def test_fa4_old_api_flash_mask_unavailable(self):
315+
"""Old API present but `is_flash_mask_available` is False."""
316+
self._install_fake_paddlefleet_old_api(is_available=False)
317+
self._block_new_api()
279318
flash_attn_backend.FLASH_ATTN_VERSION = None
280319
flash_attn_backend.flashmask_attention_v4 = None
281320

@@ -284,28 +323,53 @@ def test_fa4_flash_mask_unavailable(self):
284323
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
285324
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
286325
):
287-
# The inner except swallows ModuleNotFoundError, but `fa4` is then
288-
# unbound, so the outer block raises NameError (not ImportError),
289-
# which propagates. Verify the inner except actually executed by
290-
# checking that FA4 was not selected.
291326
try:
292327
flash_attn_backend.init_flash_attn_version()
293328
except NameError:
294329
pass
295330

296331
self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
297332

298-
def test_fa4_paddlefleet_import_error(self):
299-
"""Covers lines 88, 89, 98, 99 (ImportError caught by inner except)."""
300-
# Ensure paddlefleet import fails.
301-
for name in (
302-
"paddlefleet",
303-
"paddlefleet_ops",
304-
"paddlefleet_ops.flash_mask",
305-
"paddlefleet_ops.flash_mask.cute",
306-
"paddlefleet_ops.flash_mask.cute.interface",
333+
def test_fa4_new_api_import_success(self):
334+
"""Falls back to new API (`paddlefleet_ops`) when old API is missing."""
335+
fake_fa4 = self._install_fake_paddlefleet_new_api(is_available=True)
336+
self._block_old_api()
337+
flash_attn_backend.FLASH_ATTN_VERSION = None
338+
flash_attn_backend.flashmask_attention_v4 = None
339+
340+
with (
341+
mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
342+
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
343+
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
307344
):
308-
sys.modules.pop(name, None)
345+
flash_attn_backend.init_flash_attn_version()
346+
347+
self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
348+
self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4)
349+
350+
def test_fa4_new_api_flash_mask_unavailable(self):
351+
"""New API present but `is_flash_mask_available` is False."""
352+
self._install_fake_paddlefleet_new_api(is_available=False)
353+
self._block_old_api()
354+
flash_attn_backend.FLASH_ATTN_VERSION = None
355+
flash_attn_backend.flashmask_attention_v4 = None
356+
357+
with (
358+
mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
359+
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
360+
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
361+
):
362+
try:
363+
flash_attn_backend.init_flash_attn_version()
364+
except NameError:
365+
pass
366+
367+
self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
368+
369+
def test_fa4_paddlefleet_import_error(self):
370+
"""Neither old nor new API is importable."""
371+
self._block_old_api()
372+
self._block_new_api()
309373
flash_attn_backend.FLASH_ATTN_VERSION = None
310374
flash_attn_backend.flashmask_attention_v4 = None
311375

0 commit comments

Comments
 (0)