Skip to content

Commit d6967a1

Browse files
authored
[TRTLLM-12807][test] Guard thop attention kwarg aliases (#15335)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent 2ef2ea5 commit d6967a1

1 file changed

Lines changed: 60 additions & 2 deletions

File tree

tests/unittest/_torch/attention_backend/test_attention_op_sync.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@
5858
"forward_args": AttentionForwardArgs,
5959
}
6060

61+
_THOP_KWARG_SOURCE_ALIASES: dict[str, tuple[str, tuple[str, ...]]] = {
62+
"context_lengths": ("metadata", ("prompt_lens_cuda_runtime",)),
63+
"head_size": ("self", ("head_dim",)),
64+
"host_context_lengths": ("metadata", ("prompt_lens_cpu_runtime",)),
65+
"host_past_key_value_lengths": ("metadata", ("kv_lens_runtime",)),
66+
"host_request_types": ("metadata", ("host_request_types_runtime",)),
67+
"sequence_length": ("metadata", ("kv_lens_cuda_runtime",)),
68+
"spec_decoding_target_max_draft_tokens": (
69+
"metadata",
70+
("max_total_draft_tokens",),
71+
),
72+
"workspace_": ("metadata", ("effective_workspace",)),
73+
}
74+
6175
# The C++ attention() declaration is the single source of truth for kwarg
6276
# names, ordering, and types.
6377
_HEADER = pathlib.Path(__file__).resolve().parents[4] / ("cpp/tensorrt_llm/thop/attentionOp.h")
@@ -214,13 +228,39 @@ def _attribute_path(node: ast.AST) -> tuple[str, tuple[str, ...]] | None:
214228
return current.id, tuple(reversed(attrs))
215229

216230

231+
def _getattr_path(node: ast.AST) -> tuple[str, tuple[str, ...]] | None:
232+
"""If ``node`` is ``getattr(Name.attr..., "leaf", <default>)``, return
233+
``(root_name_id, (attr1, ..., leaf))``. Otherwise return ``None``.
234+
"""
235+
if (
236+
not isinstance(node, ast.Call)
237+
or not isinstance(node.func, ast.Name)
238+
or node.func.id != "getattr"
239+
or len(node.args) not in (2, 3)
240+
or not isinstance(node.args[1], ast.Constant)
241+
or not isinstance(node.args[1].value, str)
242+
):
243+
return None
244+
root_path: tuple[str, tuple[str, ...]]
245+
if isinstance(node.args[0], ast.Name):
246+
root_path = (node.args[0].id, ())
247+
else:
248+
path = _attribute_path(node.args[0])
249+
if path is None:
250+
return None
251+
root_path = path
252+
root, attrs = root_path
253+
return root, (*attrs, node.args[1].value)
254+
255+
217256
def _classify_kwargs() -> tuple[
218257
dict[str, tuple[str, tuple[str, ...]]], dict[str, object], set[str]
219258
]:
220259
"""Split the call site's kwargs into three buckets:
221260
222-
- ``attr_kwargs``: ``kwarg=source.attr[...]`` (or
223-
``kwarg=int(source.attr)``) → ``{kwarg: (root, path)}``.
261+
- ``attr_kwargs``: ``kwarg=source.attr[...]``,
262+
``kwarg=int(source.attr)``, or ``kwarg=getattr(source, "attr", ...)``
263+
→ ``{kwarg: (root, path)}``.
224264
- ``literal_kwargs``: ``kwarg=<constant>`` → ``{kwarg: value}``.
225265
- ``other_kwargs``: kwargs whose value is anything else (e.g. a bare
226266
Name like ``q``).
@@ -247,6 +287,10 @@ def _classify_kwargs() -> tuple[
247287
if isinstance(v, ast.Constant):
248288
literal_kwargs[kw.arg] = v.value
249289
continue
290+
path = _getattr_path(v)
291+
if path is not None:
292+
attr_kwargs[kw.arg] = path
293+
continue
250294
path = _attribute_path(v)
251295
if path is not None:
252296
attr_kwargs[kw.arg] = path
@@ -408,6 +452,20 @@ def test_each_source_attr_kwarg_resolves_uniquely():
408452
)
409453

410454

455+
def test_attr_kwarg_names_match_source_leaf_attrs_except_allowlisted_aliases():
456+
"""Most ``thop.attention`` kwargs should bind to a source attribute with
457+
the same name. Existing aliases must stay explicit so new semantic
458+
mismatches cannot slip in under a broad type-compatible mapping.
459+
"""
460+
attr_kwargs, _, _ = _classify_kwargs()
461+
aliases = {kwarg: source for kwarg, source in attr_kwargs.items() if kwarg != source[1][-1]}
462+
assert aliases == _THOP_KWARG_SOURCE_ALIASES, (
463+
"Unexpected thop kwarg/source attribute aliases.\n"
464+
f"new or changed aliases: {aliases.items() - _THOP_KWARG_SOURCE_ALIASES.items()}\n"
465+
f"stale allowlist entries: {_THOP_KWARG_SOURCE_ALIASES.items() - aliases.items()}"
466+
)
467+
468+
411469
def test_literal_kwargs_match_allowlist():
412470
"""Every literal-constant kwarg at the call site must appear in
413471
``_THOP_LITERALS`` with the matching value, and every entry in

0 commit comments

Comments
 (0)