Skip to content

Fix aten_stft ONNX spec violations#2943

Merged
justinchuby merged 2 commits into
mainfrom
justinchu/fix-stft-spec-2942
Jun 23, 2026
Merged

Fix aten_stft ONNX spec violations#2943
justinchuby merged 2 commits into
mainfrom
justinchu/fix-stft-spec-2942

Conversation

@justinchuby

Copy link
Copy Markdown
Collaborator

Summary

aten_stft emitted an ONNX STFT node that violated the operator spec in two ways (reported in #2942 by @bas-aarts):

  1. frame_step / frame_length type mismatch. frame_step was built via op.Reshape(hop_length, [1]), producing a rank-1 tensor, while frame_length (n_fft) is a rank-0 scalar. The spec (T2 for both) requires them to share the same type/rank.
  2. Signal rank. The ONNX STFT signal must be rank 3 ([batch, signal_length, 1] for real input). torch.stft accepts rank-1 or rank-2 signals, and the code only unsqueezed a rank-1 signal up to rank 2, leaving the signal one dimension short of the spec.

Fix

  • Pass hop_length directly as frame_step (a rank-0 scalar, matching n_fft) and remove the frame_step_const = op.Reshape(...) line. hop_length is always a Python int in this trace-only function (input arg or n_fft // 4), so it is emitted as a rank-0 scalar exactly like n_fft.
  • Add self = op.Unsqueeze(self, [-1]) after the existing batch-dim handling so the signal becomes rank 3 for both rank-1 and rank-2 inputs.

Rank trace (verified by inspecting the emitted STFT node)

Before → after, STFT inputs [signal, frame_step, window, frame_length]:

input buggy rank fixed rank spec
signal 2 3 3
frame_step 1 0 scalar
window 1 1 1
frame_length 0 0 scalar

End-to-end shapes (fixed): rank-1 input [L] → batch unsqueeze [1, L] → trailing unsqueeze [1, L, 1] → STFT [1, frames, bins, 2] → Transpose [1, bins, frames, 2] → squeeze batch [bins, frames, 2]. Rank-2 input [B, L][B, L, 1][B, frames, bins, 2][B, bins, frames, 2]. Both match torch.stft's real output. The normalized and onesided paths are unaffected (only dtype/scaling, not rank).

Tests

  • The existing OpInfo (ops.aten.stft) and _testing.assert_onnx_program value-comparison harnesses pass both before and after — they don't strictly enforce the STFT rank, which is why this spec bug went unnoticed.
  • Added test_aten_stft_emits_spec_compliant_node (parameterized for rank-1 and rank-2 inputs) in tests/function_libs/torch_lib/e2e_ops_tests.py, which asserts the emitted STFT node's signal is rank 3 and frame_step/frame_length are both scalar. This test fails on the old code (2 != 3) and passes with the fix.

Ran:

  • pytest tests/function_libs/torch_lib/ops_test.py -k stft → 2 passed, 1 skipped, 2 xfailed
  • pytest tests/function_libs/torch_lib/e2e_ops_tests.py -k stft → 6 passed
  • lintrunner -a → no lint issues

Fixes #2942

…nk-3 signal)

The ONNX STFT op requires:
- a rank-3 signal of shape [batch, signal_length, 1], and
- frame_step and frame_length to share the same (scalar) type.

aten_stft previously passed a rank-1 frame_step (Reshape of hop_length)
while frame_length (n_fft) was a rank-0 scalar, and only reshaped the
signal up to rank 2. This produced STFT nodes that violate the spec.

Fix by passing hop_length directly as the scalar frame_step and adding a
trailing [1] dimension to the signal so it is rank 3 for both rank-1 and
rank-2 torch.stft inputs. Adds an e2e regression test asserting the
emitted STFT node is spec-compliant for rank-1 and rank-2 inputs.

Fixes #2942

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@codecov

codecov Bot commented Jun 19, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 72.64%. Comparing base (029441f) to head (0b2c5a0).
⚠️ Report is 13 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2943      +/-   ##
==========================================
+ Coverage   72.61%   72.64%   +0.02%     
==========================================
  Files         259      259              
  Lines       31597    31766     +169     
  Branches     2973     3007      +34     
==========================================
+ Hits        22945    23075     +130     
- Misses       7643     7672      +29     
- Partials     1009     1019      +10     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

@justinchuby justinchuby changed the title Fix aten_stft ONNX spec violations (Fixes #2942) Fix aten_stft ONNX spec violations Jun 19, 2026
@justinchuby justinchuby requested a review from Copilot June 22, 2026 17:20
@justinchuby justinchuby enabled auto-merge (squash) June 22, 2026 17:21
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Jun 22, 2026

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes ONNX STFT operator spec violations in the aten::stft lowering by ensuring the exported node receives spec-compliant input ranks and scalar parameters, and adds a regression test to prevent reintroduction.

Changes:

  • Make the STFT signal input rank-3 by appending a trailing Unsqueeze(-1) for both rank-1 and rank-2 Torch inputs.
  • Pass hop_length directly to ONNX STFT as frame_step (removing the prior reshape that produced a rank-1 tensor).
  • Add an end-to-end regression test that inspects the emitted ONNX STFT node inputs for rank/spec compliance.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
onnxscript/function_libs/torch_lib/ops/core.py Adjusts aten_stft preprocessing and STFT invocation so signal is rank-3 and frame_step/frame_length are scalar inputs per the ONNX spec.
tests/function_libs/torch_lib/e2e_ops_tests.py Adds a regression test validating the exported ONNX graph contains a spec-compliant STFT node for both rank-1 and rank-2 inputs.

Comment thread onnxscript/function_libs/torch_lib/ops/core.py Outdated
@justinchuby justinchuby disabled auto-merge June 22, 2026 17:44
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@justinchuby justinchuby merged commit 386769d into main Jun 23, 2026
29 of 33 checks passed
@justinchuby justinchuby deleted the justinchu/fix-stft-spec-2942 branch June 23, 2026 18:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

ONNX export at aten_stft generates STFT layer that violates the ONNX spec

4 participants