Skip to content

Add fusion rules (Whisper optimizations)#2221

Merged
shubhambhokare1 merged 27 commits intomainfrom
sbhokare/whisper-test
May 7, 2025
Merged

Add fusion rules (Whisper optimizations)#2221
shubhambhokare1 merged 27 commits intomainfrom
sbhokare/whisper-test

Conversation

@shubhambhokare1
Copy link
Copy Markdown
Contributor

@shubhambhokare1 shubhambhokare1 commented Apr 23, 2025

Add fusion rules to support the optimization of Whisper models.

Fusions added:

TODO:

  • Fix SDPA singular prescale case, due to lost shape information
  • - Enable check conditions when Allow sdpa fusion to accept custom scale factor #2210 is merged
  • - Improve/Rewrite whisper model test case to be similar to that of smollm (for eg)
  • - Fix failing test cases to account for new patterns
  • - Add isolated test cases for new fusions like BiasGelu, SkipLayerNorm etc

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 23, 2025

❌ 12 Tests Failed:

Tests completed Failed Passed Skipped
13075 12 13063 3001
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0220_test_bitwise_or_i32_2d
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_bitwise_or_i32_2d'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_bitwise_or_i32_2d' (e=No module named 'tests.onnx_backend_test_code.test_bitwise_or_i32_2d') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_bitwise_or_i32_2d.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_bitwise_or_i32_2d.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import INT32
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_bitwise_or_i32_2d(x: INT32[3,4], y: INT32[3,4]) -> (INT32[3,4]):
E       bitwiseor = opset18.BitwiseOr(x, y)
E       return bitwiseor
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0985_test_reduce_mean_negative_axes_keepdims_random
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_reduce_mean_negative_axes_keepdims_random'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_reduce_mean_negative_axes_keepdims_random' (e=No module named 'tests.onnx_backend_test_code.test_reduce_mean_negative_axes_keepdims_random') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reduce_mean_negative_axes_keepdims_random.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reduce_mean_negative_axes_keepdims_random.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_reduce_mean_negative_axes_keepdims_random(data: FLOAT[3,2,2], axes: INT64[1]) -> (FLOAT[3,1,2]):
E       reduced = opset18.ReduceMean(data, axes, keepdims=1)
E       return reduced
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1415_test_tril_zero
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_tril_zero'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_tril_zero' (e=No module named 'tests.onnx_backend_test_code.test_tril_zero') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_tril_zero.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_tril_zero.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import INT64
E   from onnxscript.onnx_opset import opset14
E   
E   @script()
E   def bck_test_tril_zero(x: INT64[3,0,5], k: INT64) -> (INT64[3,0,5]):
E       y = opset14.Trilu(x, k, upper=0)
E       return y

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Comment thread onnxscript/rewriter/ort_fusions/fuse_mha_bias.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/fuse_mha_bias.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/fuse_mha_bias.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/mha.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/mha.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/_core.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/sdpa.py Outdated
Comment thread onnxscript/rewriter/ort_fusions/skip_normalization.py Outdated
Comment thread onnxscript/rewriter/ort_fusions/fuse_mha_bias.py Outdated
Comment thread onnxscript/rewriter/ort_fusions/fuse_mha_bias.py
Comment thread onnxscript/rewriter/ort_fusions/_whisper_tiny.py Outdated
@shubhambhokare1 shubhambhokare1 marked this pull request as ready for review April 29, 2025 02:26
Comment thread onnxscript/rewriter/ort_fusions/models/_whisper_decoder.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/models/_whisper_decoder.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/models/_whisper_encoder.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/models/_whisper_encoder.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/models/_whisper_encoder.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/models/_whisper_encoder.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/_core.py Outdated
Comment thread onnxscript/rewriter/ort_fusions/attention.py Outdated
Comment thread onnxscript/rewriter/ort_fusions/attention.py Outdated
Comment thread onnxscript/rewriter/ort_fusions/mha.py Outdated
shubhambhokare1 added a commit that referenced this pull request May 2, 2025
- Rewrite SkipLayerNorm fusions and SkipRMSNorm fusions to match format
of other ort-fusion patterns.
- Added check functions for ensuring shapes are as expected.
- Moving these fusions out of PR #2221 

Fusion support patterns with:
- `Add(input, skip) -> Norm`
- `Add(input, skip) -> Add (result, bias) -> Norm`
- `Add(input, bias) -> Add (result, skip) -> Norm`


NOTE:
These fusions should support:
- Planned whisper-related optimizations
- Benchmark failures stemming from wrong bias shapes for SkipLayerNorm
fusions
@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/whisper-test branch from 1d3bc5b to 1d7ad44 Compare May 2, 2025 18:48
Comment thread onnxscript/rewriter/ort_fusions/fuse_mha_bias.py Fixed
Comment thread onnxscript/rewriter/ort_fusions/attention.py Outdated
@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/whisper-test branch from 9ca3c1a to ba12b06 Compare May 7, 2025 15:50
@shubhambhokare1 shubhambhokare1 enabled auto-merge (squash) May 7, 2025 16:47
@github-project-automation github-project-automation Bot moved this from Todo to Done in ONNX Script Review Board May 7, 2025
@shubhambhokare1 shubhambhokare1 merged commit 605e06e into main May 7, 2025
19 of 29 checks passed
@shubhambhokare1 shubhambhokare1 deleted the sbhokare/whisper-test branch May 7, 2025 16:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Development

Successfully merging this pull request may close these issues.

3 participants