diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 877a83a403..df5d14e904 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -9331,13 +9331,17 @@ def aten_stft( # core dump # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) hop_length = n_fft // 4 - frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) - # Pre-process input if needed + # ONNX's STFT requires a rank-3 signal of shape [batch_size, signal_length, 1] + # (the trailing dimension is the real component). torch.stft accepts rank-1 or + # rank-2 signals. is_signal_rank1 = len(self.shape) == 1 if is_signal_rank1: - # Add a batch dimension - self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0]))) + # [signal_length] -> [1, signal_length, 1]: add batch dim and trailing real-component dim + self = op.Unsqueeze(self, op.Constant(value_ints=[0, -1])) + else: + # [batch_size, signal_length] -> [batch_size, signal_length, 1] + self = op.Unsqueeze(self, op.Constant(value_ints=[-1])) # Get window and make sure it's the same size as `win_length` or `n_fft` if window is not None and window.shape[0] is not None: @@ -9367,7 +9371,7 @@ def aten_stft( else: onesided = 0 window = op.CastLike(window, self) - result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided) + result = op.STFT(self, hop_length, window, n_fft, onesided=onesided) result = op.Transpose(result, perm=[0, 2, 1, 3]) # Remove batch dimension, if needed if is_signal_rank1: diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 019e6f7fe5..289da547e8 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -554,6 +554,53 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + @parameterized.parameterized.expand( + [ + ("rank1", (100,)), + ("rank2", (4, 100)), + ] + ) + def test_aten_stft_emits_spec_compliant_node(self, _: str, shape: tuple[int, ...]): + # Regression test for https://github.com/microsoft/onnxscript/issues/2942 + # The ONNX STFT op requires a rank-3 signal ([batch, signal_length, 1]) and + # `frame_step`/`frame_length` to share the same (scalar) type. torch.stft + # accepts rank-1 or rank-2 signals, so aten_stft must reshape accordingly. + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=16, return_complex=False) + + x = torch.randn(*shape, dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + model = onnx_program.model_proto + + def _rank(name: str) -> int: + for vi in ( + list(model.graph.value_info) + + list(model.graph.input) + + list(model.graph.output) + ): + if vi.name == name: + return len(vi.type.tensor_type.shape.dim) + raise AssertionError(f"value_info for {name} not found") + + stft_nodes = [n for n in model.graph.node if n.op_type == "STFT"] + self.assertEqual(len(stft_nodes), 1) + node = stft_nodes[0] + signal, frame_step = node.input[0], node.input[1] + frame_length = node.input[3] + # signal must be rank 3: [batch, signal_length, 1] + self.assertEqual(_rank(signal), 3) + # frame_step and frame_length must share the same (scalar) rank + self.assertEqual(_rank(frame_step), 0) + self.assertEqual(_rank(frame_length), 0) + def test_unbind_dim0(self): """Test unbind along dimension 0"""