
Add fusion rule to remove Expand before broadcast-capable binary operators#2862

Merged
gramalingam merged 12 commits into main from copilot/create-fusion-rule-remove-expand-node on Apr 10, 2026
Conversation

Contributor

Copilot AI commented Mar 20, 2026

Adds a rewrite rule that eliminates redundant Expand nodes preceding binary operators that natively support NumPy-style broadcasting.

Pattern

```
BinaryOp(Expand(x, shape), y)  →  BinaryOp(x, y)
BinaryOp(x, Expand(y, shape))  →  BinaryOp(x, y)
```
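Because the supported ops broadcast NumPy-style, the equivalence is easy to check numerically. A minimal sketch, using `np.broadcast_to` as a stand-in for ONNX `Expand` (shapes chosen purely for illustration):

```python
import numpy as np

x = np.arange(3, dtype=np.float32).reshape(3, 1)
y = np.ones((3, 4), dtype=np.float32)

# Add(Expand(x, [3, 4]), y): explicit expansion before the binary op
with_expand = np.broadcast_to(x, (3, 4)) + y
# Add(x, y): Add's own broadcasting produces the identical result
without_expand = x + y

assert np.array_equal(with_expand, without_expand)
```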

Safety check

The rule applies a dimension-by-dimension analysis to determine if the Expand is redundant. For each dimension i, the expand is safe to remove if any of the following hold:

  • expand_shape[i] == 1 - expand cannot shrink a dimension, so it is a no-op.
  • x.shape[i] == expand_shape[i] - the expand is a no-op at this dimension.
  • y.shape[i] == expand_shape[i] - y already covers the expansion via its own broadcasting.
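The three conditions above can be sketched as a single loop over right-aligned shapes; `dims_sufficient` is a hypothetical name for illustration, not the module's actual helper, and this sketch handles concrete integer dims only:

```python
# Per-dimension safety check for removing Expand(x, expand_shape)
# before BinaryOp(x, y). Shapes are right-aligned to a common rank,
# as in NumPy/ONNX broadcasting.
def dims_sufficient(x_shape, y_shape, expand_shape):
    rank = max(len(x_shape), len(y_shape), len(expand_shape))
    pad = lambda s: (1,) * (rank - len(s)) + tuple(s)
    for xd, yd, ed in zip(pad(x_shape), pad(y_shape), pad(expand_shape)):
        if ed == 1:      # Expand cannot shrink a dimension: no-op here
            continue
        if xd == ed:     # x already has this extent: no-op here
            continue
        if yd == ed:     # y triggers the same broadcast on its own
            continue
        return False     # cannot prove redundancy; keep the Expand
    return True
```

For example, `dims_sufficient((1, 4), (3, 4), (3, 4))` holds because `y` covers the first dimension and `x` already matches the second, while `dims_sufficient((2, 1), (2, 1), (2, 3))` fails: neither input accounts for the expansion to 3.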

Otherwise the check fails conservatively. Three producer-agnostic strategies are used to resolve the expand target shape:

  1. Constant expand shape: When the shape argument is a compile-time constant, the check is applied directly. Individual dimensions of x or y may still be symbolic. For example, Add(Expand(x=[N], shape=[3,4]), y=[3,4]) is optimized to Add(x, y) because y statically provides all expansion dimensions.

  2. Expand output shape annotation: When shape is dynamic but the Expand node's output value already carries a shape annotation (e.g. after ONNX shape inference has been applied), those dimension values are used directly for the check. For example, after onnx.shape_inference.infer_shapes, Expand(x=[N,1], Concat(Shape(x,0:1), Shape(x,1:2))) gets output shape [N,1] and the rule fires.

  3. Binary op output shape: When neither of the above is available, the rule verifies that broadcast(x.shape, y.shape) symbolically equals the binary op's output shape. If they agree, the binary op's own broadcasting already accounts for all the expansion and the Expand is redundant.
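Strategy 3 amounts to a symbolic broadcast of the two input shapes compared against the op's recorded output. A hedged sketch (`broadcast_shape` is illustrative, not the module's `_compute_broadcast_shape`), where dims may be ints or strings naming symbolic dimensions:

```python
# Symbolically broadcast two shapes. Returns None when the result
# cannot be proven (mismatched or distinct symbolic dims), which is
# exactly the conservative failure the rule needs.
def broadcast_shape(a, b):
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da == 1:
            out.append(db)
        elif db == 1 or da == db:
            out.append(da)
        else:
            return None  # mismatch, or two different symbolic dims
    return tuple(out)

# If this equals the binary op's recorded output shape, the Expand
# added nothing that broadcasting would not have done anyway.
assert broadcast_shape(("N", 1), (1, 4)) == ("N", 4)
```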

Supported ops

Add, Sub, Mul, Div, Pow, And, Or, Xor, BitwiseAnd, BitwiseOr, BitwiseXor, Greater, Less, Equal, GreaterOrEqual, LessOrEqual, Mod, PRelu, BitShift

Changes

  • _remove_expand_before_binary_op.py — new module with _ExpandFirstInput / _ExpandSecondInput rule classes, _compute_broadcast_shape / _check_dims_sufficient helpers, and the exported expand_before_binary_op_rules RewriteRuleSet; rule classes access context.root to obtain the Expand output and binary op output values
  • _remove_expand_before_binary_op_test.py — tests covering removal when safe (including dynamic shapes via shape annotations and binary op output shape matching), and non-removal when the expansion cannot be statically verified
  • rules/common/__init__.py — exports expand_before_binary_op_rules
Original prompt

This section details the original issue you should resolve

<issue_title>create a fusion rule to remove an expand node before a binary operator if this op can handle it through broadcasting</issue_title>
<issue_description></issue_description>

Comments on the Issue (you are @copilot in this section)



… operators

Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Add fusion rule to remove expand node before binary operator" to "Add fusion rule to remove Expand before broadcast-capable binary operators" Mar 20, 2026
Copilot AI requested a review from xadupre March 20, 2026 16:29
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
xadupre added 3 commits March 20, 2026 18:23
Signed-off-by: Xavier Dupré <xadupre@microsoft.com>
@codecov

codecov Bot commented Mar 20, 2026

❌ 1 Tests Failed:

Tests completed: 15625 | Failed: 1 | Passed: 15624 | Skipped: 937
View the full list of 1 ❄️ flaky test(s)
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__logsumexp_cpu_float16

Flake rate in main: 14.29% (Passed 912 times, Failed 152 times)

Stack Traces | 0.537s run time
.../function_libs/torch_lib/ops_test.py:243: in run_test_output_match
    torch.testing.assert_close(
E   AssertionError: Tensor-likes are not close!
E   
E   Mismatched elements: 1 / 5 (20.0%)
E   Greatest absolute difference: 2.288818359375e-05 at index (1,) (up to 1e-05 allowed)
E   Greatest relative difference: 0.0022869110107421875 at index (1,) (up to 0.001 allowed)


Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
…adcast comparison

Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/4d4f4fb8-b66e-456e-a1d6-b1eb5ca1b532
Copilot AI requested a review from xadupre March 23, 2026 12:17
@xadupre xadupre marked this pull request as ready for review March 25, 2026 10:06
Collaborator

@gramalingam gramalingam left a comment


I think some improvements are possible. Will do it myself later.

@gramalingam gramalingam enabled auto-merge (squash) April 10, 2026 21:07
@gramalingam gramalingam merged commit 6c092e2 into main Apr 10, 2026
50 of 57 checks passed
@gramalingam gramalingam deleted the copilot/create-fusion-rule-remove-expand-node branch April 10, 2026 21:27
justinchuby pushed a commit that referenced this pull request Apr 17, 2026
…ators (#2862)

- Fixes #2861

---------

Signed-off-by: Xavier Dupré <xadupre@microsoft.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
Co-authored-by: Xavier Dupré <xadupre@microsoft.com>
Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>

Successfully merging this pull request may close these issues.

create a fusion rule to remove an expand node before a binary operator if this op can handle it through broadcasting
