Fail fast on dynamic_shapes and guard arange rewrite for symbolic args#1024
Draft
pctablet505 wants to merge 4 commits into
Draft
Fail fast on dynamic_shapes and guard arange rewrite for symbolic args#1024pctablet505 wants to merge 4 commits into
pctablet505 wants to merge 4 commits into
Conversation
…ecated pytree APIs.
When a torch.nn.Module accepts a dict as a positional forward argument,
litert_torch.convert flattens the dict values but previously renamed them
to generic args_0, args_1, etc. This change uses flat_dict_names for
positional args too (with an args_ prefix), so dict keys are preserved.
For example, a dict {"x": t1, "y": t2} passed positionally now produces:
args_0_x, args_0_y
instead of:
args_0, args_1
Also replaces deprecated treespec.children_specs with treespec.children()
and isinstance(treespec, LeafSpec) patterns.
Fixes google-ai-edge#1022
When dynamic shapes are used, torch.export may produce arange ops where start/end are symbolic expressions (SymInt/Node) rather than concrete ints. The in_i32 helper now gracefully skips the rewrite instead of crashing with TypeError on comparison. Related to google-ai-edge#1022.
The JAX-based lowering path in litert-torch does not yet support dynamic dimensions. Rather than crashing with obscure JAX assertion errors deep in the stack, we now fail fast with a clear message that points users to the existing tracking issue. Related to google-ai-edge#870.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
The LiteRT-Torch converter does not support dynamic shapes. When users pass
dynamic_shapestolitert_torch.convert(), the JAX-based lowering path crashes with obscure errors deep in the stack:AssertionError: float32[1,-9223372036854775808]— JAXjit(...).lower()receivesShapeDtypeStructwithIR_DYNAMICdimensions.TypeError: Shapes must be 1D sequences of concrete values, got (JitTracer(int32[]),)— shape-dependent ops likeaten.reshapeproduce JAX tracers for dynamic dims.Changes in this PR
Guard
in_i32against non-integer args (export.py)arangeops may have symbolicstart/endvalues. The helper now skips the rewrite gracefully instead of crashing withTypeError.Fail fast on
dynamic_shapes(interface.py)convert()now raises a clearNotImplementedErrorpointing to the tracking issue can only convert model using static shapes #870.Fix deprecated
children_specsin_get_output_names(litert_converter.py)spec.children_specswithspec.children().Why dynamic shapes are hard
Each aten op is lowered through a small JAX function →
jax.jit(...).lower() → StableHLO. JAXjitrequires concrete shapes at trace time. Dynamic dims becomeIR_DYNAMIC(-9223372036854775808) or JAX tracers, both of which breaklower(). Full support would require replacingjax.jit(...).lower()withjax.export.export()+symbolic_shapeacross the entire op lowering surface.Related to #870