Add fusion rule to remove Expand before broadcast-capable binary operators#2862
Merged
gramalingam merged 12 commits intomainfrom Apr 10, 2026
Merged
Conversation
… operators Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
Copilot
AI
changed the title
[WIP] Add fusion rule to remove expand node before binary operator
Add fusion rule to remove Expand before broadcast-capable binary operators
Mar 20, 2026
xadupre
reviewed
Mar 20, 2026
xadupre
reviewed
Mar 20, 2026
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
…://github.com/microsoft/onnxscript into copilot/create-fusion-rule-remove-expand-node
Signed-off-by: Xavier Dupré <xadupre@microsoft.com>
xadupre
reviewed
Mar 20, 2026
❌ 1 Tests Failed:
View the full list of 1 ❄️ flaky test(s)
To view more test analytics, go to the Test Analytics Dashboard |
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
xadupre
reviewed
Mar 23, 2026
…adcast comparison Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/4d4f4fb8-b66e-456e-a1d6-b1eb5ca1b532
gramalingam
reviewed
Mar 26, 2026
gramalingam
reviewed
Mar 26, 2026
gramalingam
approved these changes
Apr 10, 2026
Collaborator
gramalingam
left a comment
There was a problem hiding this comment.
I think some improvements are possible. Will do it myself later.
justinchuby
pushed a commit
that referenced
this pull request
Apr 17, 2026
…ators (#2862) Adds a rewrite rule that eliminates redundant `Expand` nodes preceding binary operators that natively support NumPy-style broadcasting. ## Pattern ``` BinaryOp(Expand(x, shape), y) → BinaryOp(x, y) BinaryOp(x, Expand(y, shape)) → BinaryOp(x, y) ``` ## Safety check The rule applies a dimension-by-dimension analysis to determine if the `Expand` is redundant. For each dimension `i`, the expand is safe to remove if any of the following hold: - `expand_shape[i] == 1` - expand cannot shrink a dimension, so it is a no-op. - `x.shape[i] == expand_shape[i]` - the expand is a no-op at this dimension. - `y.shape[i] == expand_shape[i]` - `y` already covers the expansion via its own broadcasting. Otherwise the check fails conservatively. Three producer-agnostic strategies are used to resolve the expand target shape: 1. **Constant expand shape**: When the `shape` argument is a compile-time constant, the check is applied directly. Individual dimensions of `x` or `y` may still be symbolic. For example, `Add(Expand(x=[N], shape=[3,4]), y=[3,4])` is optimized to `Add(x, y)` because `y` statically provides all expansion dimensions. 2. **Expand output shape annotation**: When `shape` is dynamic but the Expand node's output value already carries a shape annotation (e.g. after ONNX shape inference has been applied), those dimension values are used directly for the check. For example, after `onnx.shape_inference.infer_shapes`, `Expand(x=[N,1], Concat(Shape(x,0:1), Shape(x,1:2)))` gets output shape `[N,1]` and the rule fires. 3. **Binary op output shape**: When neither of the above is available, the rule verifies that `broadcast(x.shape, y.shape)` symbolically equals the binary op's output shape. If they agree, the binary op's own broadcasting already accounts for all the expansion and the Expand is redundant. ## Supported ops `Add`, `Sub`, `Mul`, `Div`, `Pow`, `And`, `Or`, `Xor`, `BitwiseAnd`, `BitwiseOr`, `BitwiseXor`, `Greater`, `Less`, `Equal`, `GreaterOrEqual`, `LessOrEqual`, `Mod`, `PRelu`, `BitShift` ## Changes - **`_remove_expand_before_binary_op.py`** — new module with `_ExpandFirstInput` / `_ExpandSecondInput` rule classes, `_compute_broadcast_shape` / `_check_dims_sufficient` helpers, and the exported `expand_before_binary_op_rules` `RewriteRuleSet`; rule classes access `context.root` to obtain the Expand output and binary op output values - **`_remove_expand_before_binary_op_test.py`** — tests covering removal when safe (including dynamic shapes via shape annotations and binary op output shape matching), and non-removal when the expansion cannot be statically verified - **`rules/common/__init__.py`** — exports `expand_before_binary_op_rules` <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> > > ---- > > *This section details on the original issue you should resolve* > > <issue_title>create a fusion rule to remove an expand node before a binary operator if this op can handle it through broadcasting</issue_title> > <issue_description></issue_description> > > ## Comments on the Issue (you are @copilot in this section) > > <comments> > </comments> > </details> <!-- START COPILOT CODING AGENT SUFFIX --> - Fixes #2861 <!-- START COPILOT CODING AGENT TIPS --> --- ✨ Let Copilot coding agent [set things up for you](https://github.com/microsoft/onnxscript/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Signed-off-by: Xavier Dupré <xadupre@microsoft.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> Co-authored-by: Xavier Dupré <xadupre@microsoft.com> Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a rewrite rule that eliminates redundant
Expandnodes preceding binary operators that natively support NumPy-style broadcasting.Pattern
Safety check
The rule applies a dimension-by-dimension analysis to determine if the
Expandis redundant. For each dimensioni, the expand is safe to remove if any of the following hold:expand_shape[i] == 1- expand cannot shrink a dimension, so it is a no-op.x.shape[i] == expand_shape[i]- the expand is a no-op at this dimension.y.shape[i] == expand_shape[i]-yalready covers the expansion via its own broadcasting.Otherwise the check fails conservatively. Three producer-agnostic strategies are used to resolve the expand target shape:
Constant expand shape: When the
shapeargument is a compile-time constant, the check is applied directly. Individual dimensions ofxorymay still be symbolic. For example,Add(Expand(x=[N], shape=[3,4]), y=[3,4])is optimized toAdd(x, y)becauseystatically provides all expansion dimensions.Expand output shape annotation: When
shapeis dynamic but the Expand node's output value already carries a shape annotation (e.g. after ONNX shape inference has been applied), those dimension values are used directly for the check. For example, afteronnx.shape_inference.infer_shapes,Expand(x=[N,1], Concat(Shape(x,0:1), Shape(x,1:2)))gets output shape[N,1]and the rule fires.Binary op output shape: When neither of the above is available, the rule verifies that
broadcast(x.shape, y.shape)symbolically equals the binary op's output shape. If they agree, the binary op's own broadcasting already accounts for all the expansion and the Expand is redundant.Supported ops
Add,Sub,Mul,Div,Pow,And,Or,Xor,BitwiseAnd,BitwiseOr,BitwiseXor,Greater,Less,Equal,GreaterOrEqual,LessOrEqual,Mod,PRelu,BitShiftChanges
_remove_expand_before_binary_op.py— new module with_ExpandFirstInput/_ExpandSecondInputrule classes,_compute_broadcast_shape/_check_dims_sufficienthelpers, and the exportedexpand_before_binary_op_rulesRewriteRuleSet; rule classes accesscontext.rootto obtain the Expand output and binary op output values_remove_expand_before_binary_op_test.py— tests covering removal when safe (including dynamic shapes via shape annotations and binary op output shape matching), and non-removal when the expansion cannot be statically verifiedrules/common/__init__.py— exportsexpand_before_binary_op_rulesOriginal prompt
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.