-
Notifications
You must be signed in to change notification settings - Fork 111
Fix include_self for scatter_reduce #2090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 26 commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
c451d4d
Fix include_self for scatter_reduce
xadupre 87c4085
fix values
xadupre 0a6c181
lint
xadupre a11b84a
Merge branch 'main' of https://github.com/microsoft/onnxscript into s…
xadupre 9aaf142
Update onnxscript/function_libs/torch_lib/ops/core.py
xadupre b89b653
Merge branch 'scatter' of https://github.com/xadupre/onnxscript into …
xadupre 27d0ff7
more comments
xadupre 2ed4f9b
dtype
xadupre 8813d44
Merge branch 'main' into scatter
xadupre f00fe37
use inf
xadupre e662f81
fix one bug
xadupre 914283f
Update tests/function_libs/torch_lib/ops_test_data.py
justinchuby ba2a563
Add float16 dtype to xfail tests
justinchuby 6e9e756
Fix tuple syntax for torch.float16 dtypes
justinchuby d7ab3b8
fix dtype
xadupre dd5e5d3
Merge branch 'main' of https://github.com/microsoft/onnxscript into s…
xadupre 8af5353
simple try
xadupre 559888d
variant
xadupre 3976cd6
lint
xadupre 020ec1c
Update onnxscript/function_libs/torch_lib/ops/core.py
xadupre 02fdc55
fix missing type
xadupre f8071df
merge
xadupre b6b57f7
disable two tests
xadupre 8ce51a5
Merge branch 'main' of https://github.com/microsoft/onnxscript into s…
xadupre 5c063f9
fix remaining test
xadupre d2a40e5
Update tests/function_libs/torch_lib/e2e_ops_tests.py
xadupre f86e2f6
lint
xadupre 673a32d
fix merhe
xadupre a4d17ff
comment
xadupre 7fa5127
Merge branch 'main' into scatter
justinchuby 33eac3e
Refactor get_constant_value function
justinchuby fe0c0ed
Apply suggestions from code review
justinchuby 2743b8e
Merge branch 'main' into scatter
titaiwangms 7fc0c27
Merge branch 'main' into scatter
justinchuby 0f1e996
Add dtype assertion and fix dtype checks
justinchuby dfd613b
Replace ml_dtypes with torch for BFLOAT16
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
|
|
||
| # Licensed under the MIT License. | ||
|
xadupre marked this conversation as resolved.
|
||
|
|
||
| # TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo | ||
|
|
||
| import itertools | ||
|
|
||
| import unittest | ||
|
|
||
| import onnxruntime | ||
| import torch | ||
|
|
||
| from tests.common import testutils | ||
|
|
||
|
|
||
| class TorchLibe2eTest(testutils.TestBase): | ||
|
xadupre marked this conversation as resolved.
|
||
| def test_investigate_one_particular_model(self): | ||
| """This test can be used to investigate a particular issue.""" | ||
| red, include, stype = "amin", False, "int32" | ||
| dtype = getattr(torch, stype) | ||
|
|
||
| class Model(torch.nn.Module): | ||
| def __init__(self, include, red): | ||
| super().__init__() | ||
| self.include = include | ||
| self.red = red | ||
|
|
||
| def forward(self, x, indices, updates): | ||
| x = x.clone() | ||
| return x.scatter_reduce( | ||
| 0, indices, updates, self.red, include_self=self.include | ||
| ) | ||
|
|
||
| model = Model(include, red) | ||
| xs = ( | ||
| torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype), | ||
| torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64), | ||
| torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype), | ||
| ) | ||
| expected = model(*xs) | ||
| model_path = ( | ||
| f"test_aten_scatter_{red}_" | ||
| f"{'include' if include else 'exclude'}_{stype}.onnx" | ||
| ) | ||
| torch.onnx.export(model, xs, model_path, dynamo=True) | ||
| feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs])) | ||
|
|
||
| sess_options = onnxruntime.SessionOptions() | ||
| sess = onnxruntime.InferenceSession( | ||
| model_path, sess_options=sess_options, providers=["CPUExecutionProvider"] | ||
| ) | ||
| got = sess.run(None, feeds)[0] | ||
| torch.testing.assert_close( | ||
| expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5 | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.