
Commit 493cf05

Bowen Fu and claude committed
feat(annotation/e2e): add LLM-domain, dynamic-shape, and cross-backend E2E tests
New tests in _BackendE2ETests mixin (×3 backends):
- test_llm_hidden_unary/binary: activation/gating at [batch=8, hidden=512]
- test_dynamic_batch: dynamic batch dim via torch_tensorrt.Input
- test_gated_ffn_llm: full FFN at hidden=256, inter=512 (2× LLM ratio)

New TestCrossBackendE2E class (9 tests):
- Verifies that Triton, CuTile, and CuTeDSL PluginV3 ops coexist in one engine.
- All at LLM-domain [batch=8, hidden=512].

Bug fixes triggered by dynamic-shape tests:
- _qdp_utils.py: _collect_shape_var_bindings now binds non-constant ShapeExpr dims (is_constant=False) so meta_impl runs during get_output_shapes for dynamic-shape engines.
- _qdp_utils.py: _safe_dim default changed from 256 → 1 so CuTeDSL dummy tensors compiled for dynamic dims use a [1, static_dim] layout whose offset formula (offset=idx) is correct for any batch size at runtime. (Using 256 caused the kernel to access only 8/256 columns, giving 96.8% output mismatch.)
- _triton.py: use _as_symint32() for grid_x/y/z assignment (matching CuTile and CuTeDSL) so _ShapeDim/ShapeExpr values are properly wrapped before being stored in KernelLaunchParams. Previously, direct assignment caused a TRT segfault during dynamic-shape engine builds.

E2E test count: 19 → 40 (50.6% of 79 total, up from 32%).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
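As context for the dynamic-shape tests, test_dynamic_batch presumably declares the dynamic batch dimension through torch_tensorrt.Input range shapes. A minimal sketch of that pattern follows; the toy module and the min/opt/max values are assumptions, and only the [batch, hidden=512] shape comes from the message above.

import torch
import torch_tensorrt

# Toy module standing in for the annotated plugin op under test (assumption).
model = torch.nn.Linear(512, 512).cuda().eval()

# Dynamic batch dim via range shapes; the exact min/opt/max are illustrative.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 512),
        opt_shape=(8, 512),   # LLM-domain shape used by the new tests
        max_shape=(32, 512),
        dtype=torch.float32,
    )
]

trt_model = torch_tensorrt.compile(model, inputs=inputs)
out = trt_model(torch.randn(8, 512, device="cuda"))  # any batch in [1, 32] works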
1 parent 9693733 commit 493cf05

3 files changed: 304 additions & 34 deletions


py/torch_tensorrt/annotation/_custom_plugin/_aot/_triton.py

Lines changed: 6 additions & 6 deletions
@@ -378,13 +378,13 @@ def _to_int(x: Any) -> Any:
     # extra_args from shape_expr-derived scalars; only tile sizes in constexprs.
     launch = trtp.KernelLaunchParams()
     if isinstance(grid, tuple):
-        launch.grid_x = grid[0] if len(grid) >= 1 else 1
-        launch.grid_y = grid[1] if len(grid) >= 2 else 1
-        launch.grid_z = grid[2] if len(grid) >= 3 else 1
+        launch.grid_x = _as_symint32(grid[0]) if len(grid) >= 1 else trtp.SymInt32(1)
+        launch.grid_y = _as_symint32(grid[1]) if len(grid) >= 2 else trtp.SymInt32(1)
+        launch.grid_z = _as_symint32(grid[2]) if len(grid) >= 3 else trtp.SymInt32(1)
     else:
-        launch.grid_x = grid
-        launch.grid_y = 1
-        launch.grid_z = 1
+        launch.grid_x = _as_symint32(grid)
+        launch.grid_y = trtp.SymInt32(1)
+        launch.grid_z = trtp.SymInt32(1)

     launch.block_x = compiled.metadata.num_warps * 32
     launch.block_y = 1
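_as_symint32 itself is defined in _qdp_utils.py; its final fallback, return trtp.SymInt32(1), is visible at the top of the next diff. Below is a hedged sketch of what such a normalizer plausibly does: the signature and the fallback line are taken from the diffs, while the middle branches are illustrative guesses, not the real implementation.

from typing import Any

import tensorrt.plugin as trtp  # assumption: the `trtp` these modules import


def _as_symint32(v: Any) -> Any:
    """Normalize a grid dimension to trtp.SymInt32 before it is stored in
    KernelLaunchParams (storing raw _ShapeDim/ShapeExpr objects segfaulted TRT)."""
    if isinstance(v, trtp.SymInt32):
        return v  # already wrapped: pass through unchanged
    if isinstance(v, int):
        return trtp.SymInt32(v)  # concrete grid size
    if getattr(v, "is_constant", False):
        return trtp.SymInt32(int(v))  # symbolic expression that folds to a constant
    # A truly dynamic dim would be translated to TRT's symbolic form here;
    # that path is not shown in this commit. The visible fallback is:
    return trtp.SymInt32(1)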

py/torch_tensorrt/annotation/_custom_plugin/_qdp_utils.py

Lines changed: 33 additions & 8 deletions
@@ -302,12 +302,18 @@ def _as_symint32(v: Any) -> Any:
     return trtp.SymInt32(1)


-def _safe_dim(d: Any, default: int = 256) -> int:
+def _safe_dim(d: Any, default: int = 1) -> int:
     """Extract a concrete int from a shape element safely.

     For TRT SymInt32 elements (dynamic shapes), calling int() directly does NOT raise
     but returns a garbage pointer-like value (~470 TB), causing OOM. Check
     _int_expr.is_constant() first; return *default* for dynamic dims.
+
+    The default is 1 (minimum valid tensor dimension) rather than a larger value
+    so that dummy tensors constructed for kernel compilation use the most compact
+    shape. CuTeDSL kernels bake the tensor layout into their type; using 1 for
+    dynamic dims gives a [1, static_dim, ...] dummy whose row-major offset formula
+    (offset = idx) is valid for any larger batch size at runtime.
     """
     if isinstance(d, int):
         return d
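Read as a whole, _safe_dim plausibly looks like the sketch below: the int fast path and the default mirror the diff, and the rest renders the docstring's _int_expr.is_constant() check as code (the exact attribute handling is an assumption). The 96.8% figure in the commit message lines up with this arithmetic: with only 8 of 256 baked columns accessed, 1 − 8/256 = 96.875% of outputs mismatch.

from typing import Any


def _safe_dim(d: Any, default: int = 1) -> int:
    # Plain ints pass through (from the diff). Symbolic dims are checked via
    # the `_int_expr` the docstring names before calling int(), which would
    # otherwise return a garbage pointer-like value for dynamic dims.
    if isinstance(d, int):
        return d
    expr = getattr(d, "_int_expr", None)  # attribute name from the docstring
    if expr is not None and expr.is_constant():
        return int(d)  # constant under the hood: int() is safe here
    return default  # dynamic dim: 1 keeps the compiled dummy tensor compact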
@@ -461,13 +467,21 @@ def collect_allowed_formats_for_io(


 def _collect_shape_var_bindings(shape_expr: Any, bindings: Dict[int, int]) -> None:
-    """Recursively find free/fake shape vars and assign them the minimum valid value.
+    """Recursively find free/fake/dynamic shape vars and assign them the minimum valid value.

     Walks *shape_expr* recursively (handling nested tensors with a `.shape_expr`
     attribute and plain list/tuple containers) and populates *bindings* with a
     mapping from ``id(var)`` → 1 for every element that is not a plain ``int``
-    and has ``is_fake == True``. The value 1 is the minimum positive integer
-    accepted as a tensor dimension.
+    and is either:
+    - marked as fake (``is_fake == True``), or
+    - a non-constant symbolic expression (``is_constant == False``), or
+    - not directly convertible to int via ``int()``.
+
+    The value 1 is the minimum positive integer accepted as a tensor dimension.
+    Both ``is_fake`` fakes (from TRT's shape-inference placeholder pass) and
+    true dynamic ``ShapeExpr`` dims (from dynamic-shape engines) are bound so
+    that ``_shape_expr_to_ints`` can produce a concrete fallback shape for
+    ``meta_impl`` even in dynamic-shape contexts.

     Mutates *bindings* in place.
     """
@@ -476,10 +490,21 @@ def _collect_shape_var_bindings(shape_expr: Any, bindings: Dict[int, int]) -> None:
             _collect_shape_var_bindings(d.shape_expr, bindings)
         elif isinstance(d, (list, tuple)):
             _collect_shape_var_bindings(d, bindings)
-        elif not isinstance(d, int) and getattr(d, "is_fake", False):
-            vid = id(d)
-            if vid not in bindings:
-                bindings[vid] = _MIN_VALID_DIM
+        elif not isinstance(d, int):
+            is_symbolic = (
+                getattr(d, "is_fake", False)
+                or (hasattr(d, "is_constant") and not d.is_constant)
+            )
+            if not is_symbolic:
+                # Last resort: try converting to int; if it fails, treat as symbolic.
+                try:
+                    int(d)
+                except (TypeError, ValueError):
+                    is_symbolic = True
+            if is_symbolic:
+                vid = id(d)
+                if vid not in bindings:
+                    bindings[vid] = _MIN_VALID_DIM


 def _shape_elem_to_int(d: Any, bindings: Dict[int, int]) -> int:
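Downstream, these bindings presumably feed _shape_expr_to_ints (named in the docstring above) so that meta_impl has a concrete shape to run against during dynamic-shape engine builds. A usage sketch, with the _shape_expr_to_ints call signature assumed rather than taken from the diff:

from typing import Any, Dict, List


def _fallback_shape(inp: Any) -> List[int]:
    """Sketch: concrete stand-in shape for meta_impl while building a
    dynamic-shape engine. `inp.shape_expr` might be [ShapeExpr(batch), 512]
    for the new dynamic-batch test, yielding [1, 512] here."""
    bindings: Dict[int, int] = {}
    _collect_shape_var_bindings(inp.shape_expr, bindings)  # every symbolic dim -> 1
    return _shape_expr_to_ints(inp.shape_expr, bindings)   # signature assumed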
