5858 "forward_args" : AttentionForwardArgs ,
5959}
6060
61+ _THOP_KWARG_SOURCE_ALIASES : dict [str , tuple [str , tuple [str , ...]]] = {
62+ "context_lengths" : ("metadata" , ("prompt_lens_cuda_runtime" ,)),
63+ "head_size" : ("self" , ("head_dim" ,)),
64+ "host_context_lengths" : ("metadata" , ("prompt_lens_cpu_runtime" ,)),
65+ "host_past_key_value_lengths" : ("metadata" , ("kv_lens_runtime" ,)),
66+ "host_request_types" : ("metadata" , ("host_request_types_runtime" ,)),
67+ "sequence_length" : ("metadata" , ("kv_lens_cuda_runtime" ,)),
68+ "spec_decoding_target_max_draft_tokens" : (
69+ "metadata" ,
70+ ("max_total_draft_tokens" ,),
71+ ),
72+ "workspace_" : ("metadata" , ("effective_workspace" ,)),
73+ }
74+
6175# The C++ attention() declaration is the single source of truth for kwarg
6276# names, ordering, and types.
6377_HEADER = pathlib .Path (__file__ ).resolve ().parents [4 ] / ("cpp/tensorrt_llm/thop/attentionOp.h" )
@@ -214,13 +228,39 @@ def _attribute_path(node: ast.AST) -> tuple[str, tuple[str, ...]] | None:
214228 return current .id , tuple (reversed (attrs ))
215229
216230
231+ def _getattr_path (node : ast .AST ) -> tuple [str , tuple [str , ...]] | None :
232+ """If ``node`` is ``getattr(Name.attr..., "leaf", <default>)``, return
233+ ``(root_name_id, (attr1, ..., leaf))``. Otherwise return ``None``.
234+ """
235+ if (
236+ not isinstance (node , ast .Call )
237+ or not isinstance (node .func , ast .Name )
238+ or node .func .id != "getattr"
239+ or len (node .args ) not in (2 , 3 )
240+ or not isinstance (node .args [1 ], ast .Constant )
241+ or not isinstance (node .args [1 ].value , str )
242+ ):
243+ return None
244+ root_path : tuple [str , tuple [str , ...]]
245+ if isinstance (node .args [0 ], ast .Name ):
246+ root_path = (node .args [0 ].id , ())
247+ else :
248+ path = _attribute_path (node .args [0 ])
249+ if path is None :
250+ return None
251+ root_path = path
252+ root , attrs = root_path
253+ return root , (* attrs , node .args [1 ].value )
254+
255+
217256def _classify_kwargs () -> tuple [
218257 dict [str , tuple [str , tuple [str , ...]]], dict [str , object ], set [str ]
219258]:
220259 """Split the call site's kwargs into three buckets:
221260
222- - ``attr_kwargs``: ``kwarg=source.attr[...]`` (or
223- ``kwarg=int(source.attr)``) → ``{kwarg: (root, path)}``.
261+ - ``attr_kwargs``: ``kwarg=source.attr[...]``,
262+ ``kwarg=int(source.attr)``, or ``kwarg=getattr(source, "attr", ...)``
263+ → ``{kwarg: (root, path)}``.
224264 - ``literal_kwargs``: ``kwarg=<constant>`` → ``{kwarg: value}``.
225265 - ``other_kwargs``: kwargs whose value is anything else (e.g. a bare
226266 Name like ``q``).
@@ -247,6 +287,10 @@ def _classify_kwargs() -> tuple[
247287 if isinstance (v , ast .Constant ):
248288 literal_kwargs [kw .arg ] = v .value
249289 continue
290+ path = _getattr_path (v )
291+ if path is not None :
292+ attr_kwargs [kw .arg ] = path
293+ continue
250294 path = _attribute_path (v )
251295 if path is not None :
252296 attr_kwargs [kw .arg ] = path
@@ -408,6 +452,20 @@ def test_each_source_attr_kwarg_resolves_uniquely():
408452 )
409453
410454
455+ def test_attr_kwarg_names_match_source_leaf_attrs_except_allowlisted_aliases ():
456+ """Most ``thop.attention`` kwargs should bind to a source attribute with
457+ the same name. Existing aliases must stay explicit so new semantic
458+ mismatches cannot slip in under a broad type-compatible mapping.
459+ """
460+ attr_kwargs , _ , _ = _classify_kwargs ()
461+ aliases = {kwarg : source for kwarg , source in attr_kwargs .items () if kwarg != source [1 ][- 1 ]}
462+ assert aliases == _THOP_KWARG_SOURCE_ALIASES , (
463+ "Unexpected thop kwarg/source attribute aliases.\n "
464+ f"new or changed aliases: { aliases .items () - _THOP_KWARG_SOURCE_ALIASES .items ()} \n "
465+ f"stale allowlist entries: { _THOP_KWARG_SOURCE_ALIASES .items () - aliases .items ()} "
466+ )
467+
468+
411469def test_literal_kwargs_match_allowlist ():
412470 """Every literal-constant kwarg at the call site must appear in
413471 ``_THOP_LITERALS`` with the matching value, and every entry in
0 commit comments