Commit 1ae5717
committed
Return a 1-D vector from torch.diagonal converter
`torch.diagonal(input, offset, dim1, dim2)` returns the requested
diagonal as a 1-D tensor (for 2-D input), but the converter used
`mb.band_part`, which only zeros the off-diagonal entries and returns
a same-shape matrix. As a result, `torch.diagonal(x)` for a 5x5 matrix
produced a 5x5 result instead of a length-5 vector.
Extract the diagonal by flattening the input and gathering the
elements at strides of `m + 1`, mirroring NumPy's row-major diagonal
indexing. Support `offset` and the `(dim1, dim2) == (1, 0)` transpose
case in addition to the default. Higher-rank input still raises
`NotImplementedError`, matching the pre-existing scope.
Verified end-to-end against PyTorch reference for shapes
{(5,5), (3,4), (4,3)}, offsets {-2,-1,0,1,2}, and dim swaps.
Fixes #2565.1 parent 5256644 commit 1ae5717
2 files changed
Lines changed: 84 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8358 | 8358 | | |
8359 | 8359 | | |
8360 | 8360 | | |
| 8361 | + | |
| 8362 | + | |
| 8363 | + | |
| 8364 | + | |
| 8365 | + | |
| 8366 | + | |
| 8367 | + | |
| 8368 | + | |
| 8369 | + | |
| 8370 | + | |
| 8371 | + | |
| 8372 | + | |
| 8373 | + | |
| 8374 | + | |
| 8375 | + | |
| 8376 | + | |
| 8377 | + | |
| 8378 | + | |
| 8379 | + | |
| 8380 | + | |
| 8381 | + | |
| 8382 | + | |
| 8383 | + | |
| 8384 | + | |
| 8385 | + | |
| 8386 | + | |
| 8387 | + | |
| 8388 | + | |
| 8389 | + | |
| 8390 | + | |
| 8391 | + | |
8361 | 8392 | | |
8362 | | - | |
8363 | | - | |
| 8393 | + | |
| 8394 | + | |
| 8395 | + | |
| 8396 | + | |
8364 | 8397 | | |
8365 | | - | |
| 8398 | + | |
| 8399 | + | |
| 8400 | + | |
| 8401 | + | |
| 8402 | + | |
| 8403 | + | |
| 8404 | + | |
| 8405 | + | |
| 8406 | + | |
| 8407 | + | |
8366 | 8408 | | |
8367 | 8409 | | |
8368 | 8410 | | |
| |||
Lines changed: 39 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6676 | 6676 | | |
6677 | 6677 | | |
6678 | 6678 | | |
| 6679 | + | |
| 6680 | + | |
| 6681 | + | |
| 6682 | + | |
| 6683 | + | |
| 6684 | + | |
| 6685 | + | |
| 6686 | + | |
| 6687 | + | |
| 6688 | + | |
| 6689 | + | |
| 6690 | + | |
| 6691 | + | |
| 6692 | + | |
| 6693 | + | |
| 6694 | + | |
| 6695 | + | |
| 6696 | + | |
| 6697 | + | |
| 6698 | + | |
| 6699 | + | |
| 6700 | + | |
| 6701 | + | |
| 6702 | + | |
| 6703 | + | |
| 6704 | + | |
| 6705 | + | |
| 6706 | + | |
| 6707 | + | |
| 6708 | + | |
| 6709 | + | |
| 6710 | + | |
| 6711 | + | |
| 6712 | + | |
| 6713 | + | |
| 6714 | + | |
| 6715 | + | |
| 6716 | + | |
| 6717 | + | |
6679 | 6718 | | |
6680 | 6719 | | |
6681 | 6720 | | |
| |||
0 commit comments