Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@
FLASH_ATTN_AVAILABLE = False


def _get_sparse_attn_topk(default: float = 0.4) -> float:
raw_topk = os.getenv("SPARSE_ATTN_TOPK", str(default))
try:
topk_value = float(raw_topk)
except (TypeError, ValueError):
logger.warning("Invalid SPARSE_ATTN_TOPK={!r}; falling back to {:.3f}", raw_topk, default)
return default

if not 0.0 < topk_value <= 1.0:
logger.warning(
"SPARSE_ATTN_TOPK={!r} is outside (0.0, 1.0]; falling back to {:.3f}",
raw_topk,
default,
)
return default
return topk_value


@PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_hygon_dcu")
class FlashAttnHygonDcu(AttnWeightTemplate):
"""
Expand Down Expand Up @@ -114,7 +132,7 @@ def half(x):
softmax_scale = 1.0 / math.sqrt(q.shape[-1])
# Use Flash Attention 2.6.1 (ROCm version) with varlen interface
if SAPRDE_LINEAR_ATTN and int(os.getenv("USE_SLA", 0)) and q.shape[1] == k.shape[1]:
topk_value = float(os.getenv("SPARSE_ATTN_TOPK", "0.5"))
topk_value = _get_sparse_attn_topk()

q = q_flat.unsqueeze(0)
k = k_flat.unsqueeze(0)
Expand All @@ -124,7 +142,7 @@ def half(x):
q,
k,
v,
topk=0.4,
topk=topk_value,
)
else:
output = flash_attn_varlen_func(
Expand Down