Skip to content

Fail fast on dynamic_shapes and guard arange rewrite for symbolic args#1024

Draft
pctablet505 wants to merge 4 commits into
google-ai-edge:mainfrom
pctablet505:fix-dynamic-shapes-limitation
Draft

Fail fast on dynamic_shapes and guard arange rewrite for symbolic args#1024
pctablet505 wants to merge 4 commits into
google-ai-edge:mainfrom
pctablet505:fix-dynamic-shapes-limitation

Conversation

@pctablet505
Copy link
Copy Markdown

Problem
The LiteRT-Torch converter does not support dynamic shapes. When users pass dynamic_shapes to litert_torch.convert(), the JAX-based lowering path crashes with obscure errors deep in the stack:

  • AssertionError: float32[1,-9223372036854775808] — JAX jit(...).lower() receives ShapeDtypeStruct with IR_DYNAMIC dimensions.
  • TypeError: Shapes must be 1D sequences of concrete values, got (JitTracer(int32[]),) — shape-dependent ops like aten.reshape produce JAX tracers for dynamic dims.

Changes in this PR

  1. Guard in_i32 against non-integer args (export.py)

    • When dynamic shapes are present, arange ops may have symbolic start/end values. The helper now skips the rewrite gracefully instead of crashing with TypeError.
  2. Fail fast on dynamic_shapes (interface.py)

  3. Fix deprecated children_specs in _get_output_names (litert_converter.py)

    • Replaces spec.children_specs with spec.children().

Why dynamic shapes are hard
Each aten op is lowered through a small JAX function → jax.jit(...).lower() → StableHLO. JAX jit requires concrete shapes at trace time. Dynamic dims become IR_DYNAMIC (-9223372036854775808) or JAX tracers, both of which break lower(). Full support would require replacing jax.jit(...).lower() with jax.export.export() + symbolic_shape across the entire op lowering surface.

Related to #870

…ecated pytree APIs.

When a torch.nn.Module accepts a dict as a positional forward argument,
litert_torch.convert flattens the dict values but previously renamed them
to generic args_0, args_1, etc. This change uses flat_dict_names for
positional args too (with an args_ prefix), so dict keys are preserved.

For example, a dict {"x": t1, "y": t2} passed positionally now produces:
  args_0_x, args_0_y
instead of:
  args_0, args_1

Also replaces deprecated treespec.children_specs with treespec.children()
and isinstance(treespec, LeafSpec) patterns.

Fixes google-ai-edge#1022
When dynamic shapes are used, torch.export may produce arange ops
where start/end are symbolic expressions (SymInt/Node) rather than
concrete ints. The in_i32 helper now gracefully skips the rewrite
instead of crashing with TypeError on comparison.

Related to google-ai-edge#1022.
The JAX-based lowering path in litert-torch does not yet support
dynamic dimensions. Rather than crashing with obscure JAX assertion
errors deep in the stack, we now fail fast with a clear message that
points users to the existing tracking issue.

Related to google-ai-edge#870.
@pctablet505 pctablet505 marked this pull request as draft May 13, 2026 17:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant