Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4757,9 +4757,9 @@ def _impl_v10(cls, bb, inputs, attr, params):
_, param_value = params[1][var_name]
max_output_boxes_per_class = int(param_value.numpy().item())
else:
max_output_boxes_per_class = 100 # Default value
max_output_boxes_per_class = 0 # Default value
else:
max_output_boxes_per_class = 100 # Default value
max_output_boxes_per_class = 0 # Default value
Comment on lines +4760 to +4762
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default value for max_output_boxes_per_class is updated to 0 here to align with the ONNX specification. However, AllClassNMS._impl_v1 (lines 4849 and 4851) still uses 100 as the default value. To maintain consistency across the ONNX frontend, AllClassNMS should also be updated to use 0 as the default value.


if iou_threshold is not None and isinstance(iou_threshold, relax.Constant):
iou_threshold = float(iou_threshold.data.numpy())
Expand Down Expand Up @@ -4846,9 +4846,9 @@ def _impl_v1(cls, bb, inputs, attr, params):
_, param_value = params[1][var_name]
max_output_boxes_per_class = int(param_value.numpy().item())
else:
max_output_boxes_per_class = 100 # Default value
max_output_boxes_per_class = 0 # Default value
else:
max_output_boxes_per_class = 100 # Default value
max_output_boxes_per_class = 0 # Default value

if iou_threshold is not None and isinstance(iou_threshold, relax.Constant):
iou_threshold = float(iou_threshold.data.numpy())
Expand Down
57 changes: 57 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4821,6 +4821,63 @@ def test_nms():
)


@pytest.mark.parametrize("with_explicit_max", [False, True])
def test_nms_max_output_boxes_per_class_zero(with_explicit_max: bool):
"""ONNX default for max_output_boxes_per_class is 0, yielding empty output."""
node_inputs = ["boxes", "scores"]
initializer = []
if with_explicit_max:
node_inputs.append("max_output_boxes_per_class")
initializer.append(
helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [0])
)

nms_node = helper.make_node(
"NonMaxSuppression",
node_inputs,
["selected_indices"],
center_point_box=0,
)

boxes_shape = [1, 4, 4]
scores_shape = [1, 1, 4]
graph = helper.make_graph(
[nms_node],
"nms_max_output_boxes_per_class_zero",
inputs=[
helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape),
helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape),
],
initializer=initializer,
outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [0, 3])],
)

model = helper.make_model(graph, producer_name="nms_max_output_boxes_per_class_zero")
model.ir_version = 8
model.opset_import[0].version = 11

inputs = {
"boxes": np.array(
[
[
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.1, 1.0, 1.1],
[2.0, 2.0, 3.0, 3.0],
[2.0, 2.1, 3.0, 3.1],
]
],
dtype=np.float32,
),
"scores": np.array([[[0.9, 0.8, 0.7, 0.6]]], dtype=np.float32),
}

check_correctness(model, inputs=inputs, opset=11)

tvm_out = run_in_tvm(model, inputs=inputs, opset=11)
tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, (list, tuple)) else tvm_out.numpy()
assert tvm_selected.shape == (0, 3)


def test_nms_algorithm_correctness():
"""Test NMS algorithm correctness with fixed data to verify suppression logic."""
nms_node = helper.make_node(
Expand Down
Loading