Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,11 +860,30 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
split_dimension_size = shape[axis]
if not isinstance(split_dimension_size, int):
return None
num_outputs = math.ceil(split_dimension_size / split_value.item())
split_size = int(split_value.item())
num_outputs = math.ceil(split_dimension_size / split_size)
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
split_values = op.Split(
input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
)
if split_dimension_size % split_size != 0:
# Uneven split: the last chunk is smaller. We must pass explicit split
# sizes to Split, because Split with only num_outputs would do an
# equal (or near-equal) split ignoring the original chunk size.
remainder = split_dimension_size - (num_outputs - 1) * split_size
explicit_split_sizes = [split_size] * (num_outputs - 1) + [remainder]
explicit_split = op.Constant(
value_ints=explicit_split_sizes,
_outputs=[f"{output.name}_split_sizes"],
)
split_values = op.Split(
input,
explicit_split,
axis=axis,
num_outputs=num_outputs,
_outputs=split_outputs,
)
else:
split_values = op.Split(
input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
)
else:
return None

Expand Down
39 changes: 39 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,45 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_
self.assertEqual(len(optimized.graph[-2].outputs), 4)
self.assertEqual(optimized.graph[-2].op_type, "Split")

def test_static_split_to_sequence_with_unequal_scalar_split_and_sequence_at_is_folded_as_split(
self,
):
"""Test that an unequal scalar split is preserved correctly (not turned into equal split).

Regression test for: SplitToSequence with scalar split that doesn't evenly divide
the axis dimension should produce a Split with explicit split sizes, not an equal split.
E.g., splitting dim=8400 with split=5000 should produce [5000, 3400], not [4200, 4200].
"""
model = """
<
ir_version: 8,
opset_import: ["" : 18]
>
func (float[1,8400,80] x) => (float[1,N,80] return_val) {
int64_5000 = Constant <value: tensor = int64 int64_5000 {5000}> ()
splits = SplitToSequence <axis: int = 1> (x, int64_5000)
int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
split_0 = SequenceAt (splits, int64_0)
int64_1 = Constant <value: tensor = int64 int64_1 {1}> ()
split_1 = SequenceAt (splits, int64_1)
return_val = Concat <axis: int = 1> (split_0, split_1)
}"""

optimized = self._fold(model)
split_nodes = [n for n in optimized.graph if n.op_type == "Split"]
self.assertEqual(len(split_nodes), 1)
split_node = split_nodes[0]
self.assertEqual(len(split_node.outputs), 2)
# The Split node must have an explicit split input (not just num_outputs),
# so that the split is [5000, 3400] and not [4200, 4200].
split_sizes_input = split_node.inputs[1]
self.assertIsNotNone(split_sizes_input, "Split node must have explicit split sizes")
# Verify the actual split sizes are [5000, 3400], not [4200, 4200]
self.assertIsNotNone(split_sizes_input.const_value)
np.testing.assert_array_equal(split_sizes_input.const_value.numpy(), [5000, 3400])
# Check no SequenceAt remains
self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph))

def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split(
self,
):
Expand Down
Loading