diff --git a/onnxoptimizer/passes/fuse_consecutive_squeezes.h b/onnxoptimizer/passes/fuse_consecutive_squeezes.h index 1a3018b0b..ea05a7b1a 100644 --- a/onnxoptimizer/passes/fuse_consecutive_squeezes.h +++ b/onnxoptimizer/passes/fuse_consecutive_squeezes.h @@ -44,6 +44,10 @@ struct FuseConsecutiveSqueezes final : public PredicateBasedPass { !GetValueFromAttrOrInput(n, kaxes, 1, axes_2)) { return false; } + if (std::any_of(axes_1.begin(), axes_1.end(), [](int64_t v) { return v < 0; }) || + std::any_of(axes_2.begin(), axes_2.end(), [](int64_t v) { return v < 0; })) { + return false; + } std::vector &ret = composed_axes; ret.clear(); diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index 7336d873c..940177137 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -2872,6 +2872,18 @@ def test_fuse_consecutive_squeezes_multi_uses(self): # type: () -> None if init.name == optimized_model.graph.node[2].input[1]: assert list(to_array(init)) == [0, 1, 4, 5, 6] + def test_fuse_consecutive_squeezes_negative_axes(self): # type: () -> None + graph = parser.parse_graph(""" + agraph (float[5, 7, 1, 1] X) => (float[5, 7] Z) + { + Axes = Constant () + Y = Squeeze (X, Axes) + Z = Squeeze (Y, Axes) + } + """) + optimized_model = self._optimized(graph, ["fuse_consecutive_squeezes"]) + assert len(optimized_model.graph.node) == 3 + @pytest.mark.xfail def test_fuse_consecutive_softmax_log_axis(self): # type: () -> None for axis in range(3):