Commit 2a095f4

[None][fix] read attention layout by schema position
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent 4f8e1ef

5 files changed

Lines changed: 5 additions & 38 deletions
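
The removed pattern tried to recover the layout by checking kwargs and then assuming that a trailing string positional argument must be the layout. The replacement, extract_op_args, reads the value by its position in the op's schema instead. As a rough sketch of what schema-position extraction can look like for a torch.fx node targeting a custom-op overload (illustrative only; the real extract_op_args helper in tensorrt_llm may differ in signature and internals):

from typing import Any, List

from torch.fx import Node


def extract_op_args_sketch(node: Node, *names: str) -> List[Any]:
    # Schema of the op overload this node calls (a torch._C.FunctionSchema).
    schema = node.target._schema
    positions = {arg.name: i for i, arg in enumerate(schema.arguments)}
    out: List[Any] = []
    for name in names:
        idx = positions[name]
        if name in node.kwargs:
            # Explicitly passed as a keyword argument.
            out.append(node.kwargs[name])
        elif idx < len(node.args):
            # Passed positionally at its schema-defined slot.
            out.append(node.args[idx])
        else:
            # Not passed at all: fall back to the schema default.
            out.append(schema.arguments[idx].default_value)
    return out

With such a helper, extract_op_args_sketch(node, "layout")[0] yields the same value whether the caller wrote layout="bsnd" as a keyword or passed it positionally at its schema slot.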


tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py

Lines changed: 1 addition & 8 deletions
@@ -565,14 +565,7 @@ def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCalla
     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         # Sanity check: layout == "bsnd"
-        # Prefer kwargs; fall back to the final positional arg if it's a string.
-        layout = source_attn_node.kwargs.get("layout", None)
-        if (
-            layout is None
-            and len(source_attn_node.args) > 0
-            and isinstance(source_attn_node.args[-1], str)
-        ):
-            layout = source_attn_node.args[-1]
+        layout = extract_op_args(source_attn_node, "layout")[0]
         if layout != "bsnd":
             raise RuntimeError(
                 f"Expected torch_attention layout='bsnd' but got {layout!r} "

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py

Lines changed: 1 addition & 8 deletions
@@ -499,14 +499,7 @@ def get_dynamic_inputs(cls, source_attn_node: Node) -> List[Optional[Node]]:
     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         # Sanity check: layout == "bsnd"
-        # Prefer kwargs; fall back to the final positional arg if it's a string.
-        layout = source_attn_node.kwargs.get("layout", None)
-        if (
-            layout is None
-            and len(source_attn_node.args) > 0
-            and isinstance(source_attn_node.args[-1], str)
-        ):
-            layout = source_attn_node.args[-1]
+        layout = extract_op_args(source_attn_node, "layout")[0]
         if layout != "bsnd":
             raise RuntimeError(
                 f"Expected torch_attention layout='bsnd' but got {layout!r} "

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py

Lines changed: 1 addition & 8 deletions
@@ -388,14 +388,7 @@ def get_cache_initializers(
     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         # Sanity check: layout == "bsnd"
-        # Prefer kwargs; fall back to the final positional arg if it's a string.
-        layout = source_attn_node.kwargs.get("layout", None)
-        if (
-            layout is None
-            and len(source_attn_node.args) > 0
-            and isinstance(source_attn_node.args[-1], str)
-        ):
-            layout = source_attn_node.args[-1]
+        layout = extract_op_args(source_attn_node, "layout")[0]
         if layout != "bsnd":
             raise RuntimeError(
                 f"Expected torch_attention layout='bsnd' but got {layout!r} "

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py

Lines changed: 1 addition & 7 deletions
@@ -1359,13 +1359,7 @@ def get_dynamic_inputs(cls, source_attn_node: Node) -> List[Optional[Node]]:
 
     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
-        layout = source_attn_node.kwargs.get("layout", None)
-        if (
-            layout is None
-            and len(source_attn_node.args) > 0
-            and isinstance(source_attn_node.args[-1], str)
-        ):
-            layout = source_attn_node.args[-1]
+        layout = extract_op_args(source_attn_node, "layout")[0]
         if layout != "bsnd":
             raise RuntimeError(
                 f"Expected torch_attention layout='bsnd' but got {layout!r} "

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py

Lines changed: 1 addition & 7 deletions
@@ -639,13 +639,7 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         from tensor shapes or SequenceInfo metadata at runtime.
         """
         # Sanity check: layout == "bsnd"
-        layout = source_attn_node.kwargs.get("layout", None)
-        if (
-            layout is None
-            and len(source_attn_node.args) > 0
-            and isinstance(source_attn_node.args[-1], str)
-        ):
-            layout = source_attn_node.args[-1]
+        layout = extract_op_args(source_attn_node, "layout")[0]
         if layout != "bsnd":
             raise RuntimeError(
                 f"Expected torch_attention layout='bsnd' but got {layout!r} "
