From 0f94339c5b40fdcceeca2786f320421f2b1044ba Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Fri, 19 Jun 2026 17:32:56 +0000 Subject: [PATCH 1/2] Fix aten_stft ONNX spec violations (frame_step/frame_length type + rank-3 signal) The ONNX STFT op requires: - a rank-3 signal of shape [batch, signal_length, 1], and - frame_step and frame_length to share the same (scalar) type. aten_stft previously passed a rank-1 frame_step (Reshape of hop_length) while frame_length (n_fft) was a rank-0 scalar, and only reshaped the signal up to rank 2. This produced STFT nodes that violate the spec. Fix by passing hop_length directly as the scalar frame_step and adding a trailing [1] dimension to the signal so it is rank 3 for both rank-1 and rank-2 torch.stft inputs. Adds an e2e regression test asserting the emitted STFT node is spec-compliant for rank-1 and rank-2 inputs. Fixes #2942 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../function_libs/torch_lib/ops/core.py | 8 +++- .../function_libs/torch_lib/e2e_ops_tests.py | 47 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 877a83a403..1c2f3b649c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -9331,7 +9331,6 @@ def aten_stft( # core dump # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) hop_length = n_fft // 4 - frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) # Pre-process input if needed is_signal_rank1 = len(self.shape) == 1 @@ -9339,6 +9338,11 @@ def aten_stft( # Add a batch dimension self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0]))) + # ONNX's STFT requires a rank-3 signal of shape [batch_size, signal_length, 1] + # (the trailing dimension is the real component). torch.stft accepts rank-1 or + # rank-2 signals, so add the trailing dimension here for both cases. + self = op.Unsqueeze(self, op.Constant(value_ints=[-1])) + # Get window and make sure it's the same size as `win_length` or `n_fft` if window is not None and window.shape[0] is not None: # first dimension @@ -9367,7 +9371,7 @@ def aten_stft( else: onesided = 0 window = op.CastLike(window, self) - result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided) + result = op.STFT(self, hop_length, window, n_fft, onesided=onesided) result = op.Transpose(result, perm=[0, 2, 1, 3]) # Remove batch dimension, if needed if is_signal_rank1: diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 019e6f7fe5..289da547e8 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -554,6 +554,53 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + @parameterized.parameterized.expand( + [ + ("rank1", (100,)), + ("rank2", (4, 100)), + ] + ) + def test_aten_stft_emits_spec_compliant_node(self, _: str, shape: tuple[int, ...]): + # Regression test for https://github.com/microsoft/onnxscript/issues/2942 + # The ONNX STFT op requires a rank-3 signal ([batch, signal_length, 1]) and + # `frame_step`/`frame_length` to share the same (scalar) type. torch.stft + # accepts rank-1 or rank-2 signals, so aten_stft must reshape accordingly. + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=16, return_complex=False) + + x = torch.randn(*shape, dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + model = onnx_program.model_proto + + def _rank(name: str) -> int: + for vi in ( + list(model.graph.value_info) + + list(model.graph.input) + + list(model.graph.output) + ): + if vi.name == name: + return len(vi.type.tensor_type.shape.dim) + raise AssertionError(f"value_info for {name} not found") + + stft_nodes = [n for n in model.graph.node if n.op_type == "STFT"] + self.assertEqual(len(stft_nodes), 1) + node = stft_nodes[0] + signal, frame_step = node.input[0], node.input[1] + frame_length = node.input[3] + # signal must be rank 3: [batch, signal_length, 1] + self.assertEqual(_rank(signal), 3) + # frame_step and frame_length must share the same (scalar) rank + self.assertEqual(_rank(frame_step), 0) + self.assertEqual(_rank(frame_length), 0) + def test_unbind_dim0(self): """Test unbind along dimension 0""" From 0b2c5a0885f3b2b7214e2064efe1123320bb864c Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:47:04 +0000 Subject: [PATCH 2/2] Address review: simplify stft signal reshape to if-else Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1c2f3b649c..df5d14e904 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -9332,16 +9332,16 @@ def aten_stft( # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) hop_length = n_fft // 4 - # Pre-process input if needed - is_signal_rank1 = len(self.shape) == 1 - if is_signal_rank1: - # Add a batch dimension - self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0]))) - # ONNX's STFT requires a rank-3 signal of shape [batch_size, signal_length, 1] # (the trailing dimension is the real component). torch.stft accepts rank-1 or - # rank-2 signals, so add the trailing dimension here for both cases. - self = op.Unsqueeze(self, op.Constant(value_ints=[-1])) + # rank-2 signals. + is_signal_rank1 = len(self.shape) == 1 + if is_signal_rank1: + # [signal_length] -> [1, signal_length, 1]: add batch dim and trailing real-component dim + self = op.Unsqueeze(self, op.Constant(value_ints=[0, -1])) + else: + # [batch_size, signal_length] -> [batch_size, signal_length, 1] + self = op.Unsqueeze(self, op.Constant(value_ints=[-1])) # Get window and make sure it's the same size as `win_length` or `n_fft` if window is not None and window.shape[0] is not None: