diff --git a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py index d1fff4fe2..c9a4d0cfe 100644 --- a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py +++ b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py @@ -25,6 +25,24 @@ FLASH_ATTN_AVAILABLE = False +def _get_sparse_attn_topk(default: float = 0.4) -> float: + raw_topk = os.getenv("SPARSE_ATTN_TOPK", str(default)) + try: + topk_value = float(raw_topk) + except (TypeError, ValueError): + logger.warning("Invalid SPARSE_ATTN_TOPK={!r}; falling back to {:.3f}", raw_topk, default) + return default + + if not 0.0 < topk_value <= 1.0: + logger.warning( + "SPARSE_ATTN_TOPK={!r} is outside (0.0, 1.0]; falling back to {:.3f}", + raw_topk, + default, + ) + return default + return topk_value + + @PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_hygon_dcu") class FlashAttnHygonDcu(AttnWeightTemplate): """ @@ -114,7 +132,7 @@ def half(x): softmax_scale = 1.0 / math.sqrt(q.shape[-1]) # Use Flash Attention 2.6.1 (ROCm version) with varlen interface if SAPRDE_LINEAR_ATTN and int(os.getenv("USE_SLA", 0)) and q.shape[1] == k.shape[1]: - topk_value = float(os.getenv("SPARSE_ATTN_TOPK", "0.5")) + topk_value = _get_sparse_attn_topk() q = q_flat.unsqueeze(0) k = k_flat.unsqueeze(0) @@ -124,7 +142,7 @@ def half(x): q, k, v, - topk=0.4, + topk=topk_value, ) else: output = flash_attn_varlen_func(