Skip to content

Commit e370fc7

Browse files
authored
[Relax][ONNX] Normalize negative indices before the take call for Gather operator (#19525)
Hi Committers,

This PR fixes issue #19436. Any suggestions would be appreciated if you are available.

### Root Cause

1. ONNX `Gather` allows negative indices (counting from the end of the target axis).
2. In the Relax ONNX importer, `Gather` was lowered directly to `relax.op.take` without normalizing negative indices first.
3. This created a semantic mismatch and incorrect behavior in downstream lowering paths that assume non-negative indices.
4. Test failures were also caused by pytest parametrization issues:
   - using ONNX `TensorProto` enum values directly as NumPy dtypes,
   - and tuple-style parametrization triggering fixture interpretation errors.

### Solutions

1. Added conditional negative-index normalization in `Gather._impl_v13`:
   - apply it only for signed index dtypes,
   - use: `idx < 0 ? idx + axis_extent : idx`,
   - derive `axis_extent` from a shape/runtime expression to support dynamic shapes.
2. Skipped normalization for unsigned index dtypes to avoid redundant graph ops/checks.

---------

Co-authored-by: cchung100m <cchung100m@users.noreply.github.com>
1 parent 6b27d19 commit e370fc7

2 files changed

Lines changed: 81 additions & 0 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,25 @@ def _impl_v13(cls, bb, inputs, attr, params):
11061106
shape_val = data[np_index]
11071107
return relax.PrimValue(shape_val)
11081108

1109+
indices_dtype = indices.struct_info.dtype
1110+
if not indices_dtype.startswith("uint"):
1111+
data_shape = bb.normalize(relax.op.shape_of(data))
1112+
data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
1113+
axis_extent = bb.normalize(
1114+
relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap")
1115+
)
1116+
1117+
if indices_dtype !="int64":
1118+
axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype))
1119+
1120+
indices = bb.normalize(
1121+
relax.op.where(
1122+
relax.op.less(indices, relax.const(0, indices_dtype)),
1123+
relax.op.add(indices, axis_extent),
1124+
indices,
1125+
)
1126+
)
1127+
11091128
return relax.op.take(data, indices, axis)
11101129

11111130

tests/python/relax/test_frontend_onnx.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,68 @@ def _verify_gather(data_shape, indices, out_shape, axis=0):
874874
_verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
875875

876876

877+
@pytest.mark.parametrize(
    "axis, indices, out_shape",
    [
        (0, [-1, 0], [2, 4]),
        (1, [-1, 0], [3, 2]),
        (
            1,
            [[-1, 0], [1, -2]],
            [3, 2, 2],
        ),
    ],
)
@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32])
def test_gather_negative_indices(axis, indices, out_shape, indices_type):
    """Run an ONNX ``Gather`` whose indices contain negative values.

    A single-node model gathering from a ``[3, 4]`` float tensor along
    ``axis`` is built and handed to ``check_correctness`` together with
    concrete inputs. Each case is exercised with both int64 and int32
    index tensors.
    """
    # ONNX element-type enums are not NumPy dtypes, so translate explicitly.
    onnx_to_numpy = {
        TensorProto.INT64: np.int64,
        TensorProto.INT32: np.int32,
    }

    data_info = helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4])
    indices_info = helper.make_tensor_value_info(
        "indices", indices_type, np.asarray(indices).shape
    )
    result_info = helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)

    node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis)
    graph = helper.make_graph(
        [node],
        "gather_negative_indices_test",
        inputs=[data_info, indices_info],
        outputs=[result_info],
    )
    model = helper.make_model(graph, producer_name="gather_negative_indices_test")

    feeds = {
        "data": np.random.randn(3, 4).astype("float32"),
        "indices": np.array(indices).astype(onnx_to_numpy[indices_type]),
    }
    check_correctness(model, inputs=feeds)
914+
915+
916+
@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32])
def test_gather_negative_indices_ir_normalization(indices_type):
    """Check that importing ``Gather`` with signed indices emits normalization ops.

    The imported Relax ``main`` function is expected to contain the
    ``where`` / ``less`` / ``add`` calls used to wrap negative indices,
    in addition to the ``take`` call itself.
    """
    data_info = helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4])
    indices_info = helper.make_tensor_value_info("indices", indices_type, [2])
    result_info = helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 2])

    node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=1)
    graph = helper.make_graph(
        [node],
        "gather_negative_indices_ir_test",
        inputs=[data_info, indices_info],
        outputs=[result_info],
    )
    model = helper.make_model(graph, producer_name="gather_negative_indices_ir_test")

    relax_mod = from_onnx(model, opset=13, keep_params_in_input=True)
    seen_ops = collect_relax_call_ops(relax_mod["main"])

    # Same checks, same order: the three normalization ops, then the gather.
    for op_name in ("relax.where", "relax.less", "relax.add", "relax.take"):
        assert op_name in seen_ops
937+
938+
877939
@pytest.mark.parametrize(
878940
"data_shape, indices_shape, axis",
879941
[

0 commit comments

Comments
 (0)