Commit f8bda5d
[PyTorch] Make
* make modules.GroupedLinear graph-safe
Signed-off-by: Xin Yao <xiny@nvidia.com>
* fix tests
Signed-off-by: Xin Yao <xiny@nvidia.com>
* Review suggestions
Handle tensor splits in both legacy and graph-safe impls. Create weight grad tensors as subviews of a larger buffer.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>modules.GroupedLinear graph-safe (#3038)1 parent 9e5a847 commit f8bda5d
5 files changed
Lines changed: 2326 additions & 1322 deletions
File tree
- benchmarks/linear
- qa/L0_pytorch_unittest
- tests/pytorch
- transformer_engine/pytorch/module
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3 | 3 | | |
4 | 4 | | |
5 | 5 | | |
| 6 | + | |
6 | 7 | | |
7 | 8 | | |
8 | 9 | | |
| |||
185 | 186 | | |
186 | 187 | | |
187 | 188 | | |
| 189 | + | |
| 190 | + | |
188 | 191 | | |
189 | 192 | | |
190 | 193 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
29 | 29 | | |
30 | 30 | | |
31 | 31 | | |
| 32 | + | |
32 | 33 | | |
33 | 34 | | |
34 | 35 | | |
| |||
0 commit comments