Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9331,13 +9331,17 @@ def aten_stft(
# core dump
# hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
hop_length = n_fft // 4
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))

# Pre-process input if needed
# ONNX's STFT requires a rank-3 signal of shape [batch_size, signal_length, 1]
# (the trailing dimension is the real component). torch.stft accepts rank-1 or
# rank-2 signals.
is_signal_rank1 = len(self.shape) == 1
if is_signal_rank1:
# Add a batch dimension
self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0])))
# [signal_length] -> [1, signal_length, 1]: add batch dim and trailing real-component dim
self = op.Unsqueeze(self, op.Constant(value_ints=[0, -1]))
else:
# [batch_size, signal_length] -> [batch_size, signal_length, 1]
self = op.Unsqueeze(self, op.Constant(value_ints=[-1]))

# Get window and make sure it's the same size as `win_length` or `n_fft`
if window is not None and window.shape[0] is not None:
Expand Down Expand Up @@ -9367,7 +9371,7 @@ def aten_stft(
else:
onesided = 0
window = op.CastLike(window, self)
result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided)
result = op.STFT(self, hop_length, window, n_fft, onesided=onesided)
result = op.Transpose(result, perm=[0, 2, 1, 3])
# Remove batch dimension, if needed
if is_signal_rank1:
Expand Down
47 changes: 47 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,53 @@ def forward(self, x):
)
_testing.assert_onnx_program(onnx_program)

@parameterized.parameterized.expand(
    [
        ("rank1", (100,)),
        ("rank2", (4, 100)),
    ]
)
def test_aten_stft_emits_spec_compliant_node(self, _: str, shape: tuple[int, ...]):
    """Check that exporting ``aten.stft`` yields a spec-compliant ONNX STFT node.

    Regression test for https://github.com/microsoft/onnxscript/issues/2942

    The ONNX STFT op requires a rank-3 signal ([batch, signal_length, 1]) and
    `frame_step`/`frame_length` to share the same (scalar) type. torch.stft
    accepts rank-1 or rank-2 signals, so aten_stft must reshape accordingly.

    Args:
        _: Human-readable case name supplied by ``parameterized`` (unused).
        shape: Shape of the random float32 input signal to export with.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.stft(x, n_fft=16, return_complex=False)

    x = torch.randn(*shape, dtype=torch.float32)
    onnx_program = torch.onnx.export(
        Model(),
        (x,),
        dynamo=True,
        verbose=False,
    )
    _testing.assert_onnx_program(onnx_program)

    model = onnx_program.model_proto

    def _rank(name: str) -> int:
        """Return the statically known rank of the graph value called ``name``.

        Raises:
            AssertionError: If no shape information is recorded for ``name``.
        """
        graph = model.graph
        for collection in (graph.value_info, graph.input, graph.output):
            for vi in collection:
                if vi.name != name:
                    continue
                tensor_type = vi.type.tensor_type
                # An unset `shape` means "rank unknown", but protobuf would
                # still hand back an empty `dim` list. Without this check an
                # unknown rank would be indistinguishable from a scalar and
                # the rank-0 assertions below could pass vacuously.
                if tensor_type.HasField("shape"):
                    return len(tensor_type.shape.dim)
        # Constants may be stored as initializers; their rank is the number
        # of entries in `dims`.
        for initializer in graph.initializer:
            if initializer.name == name:
                return len(initializer.dims)
        raise AssertionError(
            f"rank of {name} could not be determined "
            "(no value_info with a shape and no initializer)"
        )

    stft_nodes = [n for n in model.graph.node if n.op_type == "STFT"]
    self.assertEqual(len(stft_nodes), 1)
    node = stft_nodes[0]
    # STFT inputs are positional: signal, frame_step, window, frame_length.
    signal, frame_step = node.input[0], node.input[1]
    frame_length = node.input[3]
    # signal must be rank 3: [batch, signal_length, 1]
    self.assertEqual(_rank(signal), 3)
    # frame_step and frame_length must share the same (scalar) rank
    self.assertEqual(_rank(frame_step), 0)
    self.assertEqual(_rank(frame_length), 0)

def test_unbind_dim0(self):
"""Test unbind along dimension 0"""

Expand Down
Loading