[torch_lib] Fix torchvision_roi_align signature mismatch for PyTorch 2.10+#2848
[torch_lib] Fix torchvision_roi_align signature mismatch for PyTorch 2.10+#28486zaille wants to merge 2 commits intomicrosoft:mainfrom
Conversation
|
@microsoft-github-policy-service agree |
|
Hi! I've signed the CLA. This PR fixes the torchvision_roi_align signature mismatch reported in PyTorch issue #138133. Tests passed locally. |
|
This seems to be a duplicate of #2830. I would like to know since which version the schema was changed, so we can define the op properly in a version compatible way. |
|
Hi @justinchuby, Regarding the version: I encountered this issue using PyTorch Nightly (2.10.0.dev). It seems the dispatcher now flattens the output_size tuple into pooled_height and pooled_width for the Dynamo-based exporter. If #2830 already covers the fix, feel free to close this. |
|
The most important info, if you can provide, is help us determine when this change happened by testing with older pytorch versions. Thanks! |
|
Merged #2830 |
| ): | ||
| """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor""" | ||
| pooled_height, pooled_width = output_size | ||
| """roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)""" |
Check warning
Code scanning / lintrunner
EDITORCONFIG-CHECKER/editorconfig Warning
| ): | ||
| """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor""" | ||
| pooled_height, pooled_width = output_size | ||
| """roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)""" |
Check warning
Code scanning / lintrunner
RUFF/W291 Warning
| @@ -0,0 +1,50 @@ | |||
| # Copyright (c) Microsoft Corporation. | |||
Check warning
Code scanning / lintrunner
EDITORCONFIG-CHECKER/editorconfig Warning test
| @@ -0,0 +1,50 @@ | |||
| # Copyright (c) Microsoft Corporation. | |||
Check warning
Code scanning / lintrunner
RUFF/format Warning test
| @@ -0,0 +1,50 @@ | |||
| # Copyright (c) Microsoft Corporation. | |||
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning test
| return torchvision.ops.roi_align( | ||
| x, | ||
| boxes, | ||
| output_size=(7, 7), |
Check warning
Code scanning / lintrunner
RUFF/W291 Warning test
| x, | ||
| boxes, | ||
| output_size=(7, 7), | ||
| spatial_scale=0.5, |
Check warning
Code scanning / lintrunner
RUFF/W291 Warning test
| boxes, | ||
| output_size=(7, 7), | ||
| spatial_scale=0.5, | ||
| sampling_ratio=2, |
Check warning
Code scanning / lintrunner
RUFF/W291 Warning test
| try: | ||
| export(model, (x, boxes), self.model_path) | ||
| export_success = True | ||
| except Exception as e: |
Check warning
Code scanning / lintrunner
PYLINT/W0718 Warning test
| import torch | ||
| import torchvision | ||
| from torch.onnx import export | ||
| import os |
Check notice
Code scanning / lintrunner
PYLINT/C0411 Note test
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2848 +/- ##
==========================================
+ Coverage 71.86% 71.90% +0.03%
==========================================
Files 239 239
Lines 29137 29136 -1
Branches 2875 2875
==========================================
+ Hits 20940 20949 +9
+ Misses 7219 7199 -20
- Partials 978 988 +10 ☔ View full report in Codecov by Sentry. |
This PR updates the torchvision_roi_align signature in onnxscript to support 7 positional arguments.
In recent PyTorch versions (2.10 and newer), the Dynamo-based ONNX exporter has updated the way it flattens operators. Specifically, for roi_align, the output_size (previously a Sequence[int]) is now decomposed into two separate integer arguments: pooled_height and pooled_width.
Previously, the onnxscript implementation expected 6 arguments, leading to a TypeError: torchvision_roi_align() takes from 3 to 6 positional arguments but 7 were given during the ONNX translation phase.
Changes :
Testing :
Added a new test case in tests/function_libs/torch_lib/ops/vision_test.py that mocks a torchvision.ops.roi_align call and verifies successful ONNX export using the latest PyTorch Nightly.
Thanks for reading my PR 👍