Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rtp_llm/cpp/config/ConfigModules.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ struct FMHAConfig {
bool enable_xqa = true;
bool use_aiter_pa = true;
bool use_asm_pa = true;
bool use_triton_pa = true;
bool use_triton_pa = false;
int64_t absorb_opt_len = 1024;
std::string to_string() const;
};
Expand Down
2 changes: 1 addition & 1 deletion rtp_llm/models_py/modules/factory/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
PREFILL_MHA_IMPS.append(AiterPrefillImplPaged)
PREFILL_MHA_IMPS.append(AiterPrefillImplAsm)
PREFILL_MHA_IMPS.append(AiterPrefillImplNonAsm)
DECODE_MHA_IMPS.append(AiterDecodeImplTriton)
DECODE_MHA_IMPS.append(AiterDecodeImplAsm)
DECODE_MHA_IMPS.append(AiterDecodeImplNonAsm)
DECODE_MHA_IMPS.append(AiterDecodeImplTriton)
else:
# currently append early means impl has higher priority
if device_type == DeviceType.Cuda:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,14 @@ def _is_fmha_impl_disabled(
# Aiter ASM / Paged prefill
elif (
"AiterPrefillImplAsm" in impl_class_name
or "AiterDecodeImplAsm" in impl_class_name
or "AiterPrefillImplPaged" in impl_class_name
):
return not fmha_config.use_asm_pa
# Aiter ASM decode — disabled when triton PA is enabled (triton PA takes priority)
elif "AiterDecodeImplAsm" in impl_class_name:
if fmha_config.use_triton_pa:
return True
return not fmha_config.use_asm_pa
# Aiter Non-ASM implementations
elif (
"AiterPrefillImplNonAsm" in impl_class_name
Expand Down
8 changes: 8 additions & 0 deletions rtp_llm/server/server_args/fmha_group_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ def init_fmha_group_args(parser, fmha_config):
default=True,
help="Rocm是否使用AITER ASM Attention",
)
fmha_group.add_argument(
"--use_triton_pa",
env_name="USE_TRITON_PA",
bind_to=(fmha_config, "use_triton_pa"),
type=str2bool,
default=False,
help="Rocm decode阶段是否使用Triton PA",
)
fmha_group.add_argument(
"--absorb_opt_len",
env_name="RTP_LLM_ABSORB_OPT_LEN",
Expand Down
Loading