
_take_long_axis body over-counts local tensor variables (9 vs. ~2 user-visible) #543

@khatchad


After #538's graceful-degradation fix, the analysis reports 9 distinct tensor-typed local values for _take_long_axis (tf2_test_take_along_axis.py), of which only ~2 are user-visible (row_indices, gather_indices). The other 7 are intermediates introduced by the runtime-shape idiom tf.shape → tf.range/tile/concat → tf.reshape.

This surfaced when the testTakeAlongAxis regression guard flipped from @Test(expected = IllegalStateException.class) to plain @Test. The post-fix Javadoc TODO on that test points at this issue.

Per-value-number counting picks up every SSA temporary that the analysis labels as a tensor, including runtime-shape arithmetic. The count of 9 is harmless for correctness (the test still pins the parameter types precisely: arr(2, 3) float32, indices(2, 2) int32), but it makes the local-tensor count a loose regression guard rather than a tight one.
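
As an illustration of why per-value-number counting over-reports, the following toy model counts the same set of tensor-typed values two ways: per value number (the current behaviour described above) and per source line (direction 3). Everything here is hypothetical: `TensorValue`, the line numbers, and the SSA listing are made up for illustration and do not reproduce the analysis's actual 9.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorValue:
    """Toy tensor-typed SSA value (hypothetical, not the analysis's own type)."""
    value_number: int  # SSA value number
    source_line: int   # source line whose expression produced the value
    op: str            # producing op, kept only for readability


# Hypothetical SSA for a two-line body:
#   line 2: row_indices    = tf.shape -> tf.range -> tf.reshape -> tf.tile chain
#   line 3: gather_indices = tf.concat of two tf.reshape results
TOY_BODY = [
    TensorValue(3, 2, "tf.shape"),
    TensorValue(4, 2, "tf.range"),
    TensorValue(5, 2, "tf.reshape"),
    TensorValue(6, 2, "tf.tile"),    # bound to row_indices
    TensorValue(7, 3, "tf.reshape"),
    TensorValue(8, 3, "tf.reshape"),
    TensorValue(9, 3, "tf.concat"),  # bound to gather_indices
]


def count_per_value_number(values):
    """Every tensor-typed SSA value counts, including shape temps."""
    return len({v.value_number for v in values})


def count_per_source_line(values):
    """Values produced by the same source line count once."""
    return len({v.source_line for v in values})


print(count_per_value_number(TOY_BODY))  # 7
print(count_per_source_line(TOY_BODY))   # 2
```

Counting by source line is closer to what a user reads, but it can also undercount: a single line that binds two tensors (for example via tuple unpacking) would be counted once.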

Possible directions

  1. Exclude runtime-shape intermediates from functionTensorVariables (e.g., results of tf.shape, tf.size, tf.rank and downstream tf.range/tf.tile chains).
  2. Recognize the tf.shape → tf.range/tile/concat → tf.reshape idiom as a unit and collapse it.
  3. Switch the count to per-source-line rather than per-value-number — closer to what a user reads.
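
Direction 1 could be prototyped as a forward dataflow pass that marks values produced by tf.shape, tf.size, or tf.rank as shape-derived and propagates that mark through the idiom's downstream ops. The sketch below uses a toy SSA model in plain Python; `SsaValue`, the `getitem` op name, the exact propagator set, and the rule that values bound to a source-level name are always kept are assumptions for illustration, not the analysis's actual representation.

```python
from dataclasses import dataclass
from typing import Optional

# Ops that produce runtime-shape values (as listed in direction 1).
SHAPE_SOURCES = {"tf.shape", "tf.size", "tf.rank"}
# Ops assumed, for this sketch, to pass the "shape-derived" mark downstream.
# "getitem" is a hypothetical name for subscripting, e.g. tf.shape(x)[0].
SHAPE_PROPAGATORS = {"getitem", "tf.range", "tf.tile", "tf.reshape", "tf.concat"}


@dataclass(frozen=True)
class SsaValue:
    """Toy tensor-typed SSA value (hypothetical, not the analysis's own type)."""
    value_number: int
    op: str
    operands: tuple = ()
    user_name: Optional[str] = None  # source-level local bound to this value, if any


def shape_intermediates(values):
    """Return value numbers of unnamed, shape-derived temps.

    Assumes `values` is in definition order, so one forward pass suffices.
    """
    derived = set()
    for v in values:
        if v.op in SHAPE_SOURCES or (
            v.op in SHAPE_PROPAGATORS and any(o in derived for o in v.operands)
        ):
            derived.add(v.value_number)
    # Shape-derived values bound to a user-visible name are still counted.
    return {v.value_number for v in values
            if v.value_number in derived and v.user_name is None}


def count_tensor_locals(values, exclude_shape=True):
    """Count non-parameter tensor values, optionally skipping shape temps."""
    excluded = shape_intermediates(values) if exclude_shape else set()
    return sum(1 for v in values
               if v.op != "param" and v.value_number not in excluded)


# Hypothetical body: two parameters, two user-visible locals, five temps.
TOY_BODY = [
    SsaValue(1, "param", user_name="arr"),
    SsaValue(2, "param", user_name="indices"),
    SsaValue(3, "tf.shape", (2,)),
    SsaValue(4, "getitem", (3,)),
    SsaValue(5, "tf.range", (4,)),
    SsaValue(6, "tf.reshape", (5,)),
    SsaValue(7, "tf.tile", (6,), user_name="row_indices"),
    SsaValue(8, "tf.reshape", (7,)),
    SsaValue(9, "tf.concat", (8, 2), user_name="gather_indices"),
]

print(count_tensor_locals(TOY_BODY, exclude_shape=False))  # 7
print(count_tensor_locals(TOY_BODY))                       # 2
```

Using `any` over the operands is deliberately coarse: a data-carrying op such as tf.reshape(result, tf.shape(indices)) would also be marked shape-derived (and dropped if unnamed), so a real filter would probably need to distinguish an op's data operand from its shape operand.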

Pointer

Test that documents the loose bound (and will be tightened when this is fixed): TestTensorflow2Model.testTakeAlongAxis, currently expecting 9 locals.
