
Commit 1ae5717

Return a 1-D vector from torch.diagonal converter
`torch.diagonal(input, offset, dim1, dim2)` returns the requested diagonal as a 1-D tensor (for 2-D input), but the converter used `mb.band_part`, which only zeros the off-diagonal entries and returns a same-shape matrix. As a result, `torch.diagonal(x)` for a 5x5 matrix produced a 5x5 result instead of a length-5 vector.

Extract the diagonal by flattening the input and gathering the elements at strides of `m + 1`, mirroring NumPy's row-major diagonal indexing. Support `offset` and the `(dim1, dim2) == (1, 0)` transpose case in addition to the default. Higher-rank input still raises `NotImplementedError`, matching the pre-existing scope.

Verified end-to-end against the PyTorch reference for shapes {(5,5), (3,4), (4,3)}, offsets {-2,-1,0,1,2}, and dim swaps. Fixes #2565.
1 parent 5256644 commit 1ae5717
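The flat-index arithmetic this relies on is easy to verify in isolation: for an n x m row-major matrix, the diagonal at `offset` starts at flat index `offset` (or `-offset * m` when `offset` is negative) and steps by `m + 1`. A minimal standalone NumPy check of that formula (illustrative only; the `diag_flat_indices` helper is made up for this sketch and is not part of the commit):

import numpy as np

def diag_flat_indices(n, m, offset=0):
    # Flat (row-major) indices of the requested diagonal of an n x m matrix.
    if offset >= 0:
        diag_len = min(n, m - offset)
        start = offset              # first element is a[0, offset]
    else:
        diag_len = min(n + offset, m)
        start = -offset * m         # first element is a[-offset, 0]
    return [start + i * (m + 1) for i in range(diag_len)]

a = np.arange(12).reshape(3, 4)
for k in (-2, -1, 0, 1, 2):
    gathered = a.reshape(-1)[diag_flat_indices(3, 4, k)]
    assert np.array_equal(gathered, np.diagonal(a, offset=k))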

2 files changed

Lines changed: 84 additions & 3 deletions


coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 45 additions & 3 deletions
@@ -8358,11 +8358,53 @@ def _parse_keyword_args(context, node, offset, dim1, dim2) -> Tuple[Var]:

     x, offset, dim1, dim2 = _parse_positional_args(context, node)
     offset, dim1, dim2 = _parse_keyword_args(context, node, offset, dim1, dim2)
+    if isinstance(offset, Var):
+        offset = offset.val
+    if isinstance(dim1, Var):
+        dim1 = dim1.val
+    if isinstance(dim2, Var):
+        dim2 = dim2.val
+
+    # torch.diagonal(input, offset=0, dim1=0, dim2=1) returns the requested
+    # diagonal as a 1-D tensor (for 2-D input) by extracting elements at
+    # input[i + max(-offset, 0), i + max(offset, 0)]. The previous
+    # implementation used band_part, which only zeros out the off-diagonal
+    # entries and returns the same-shape matrix.
+    if x.rank != 2:
+        raise NotImplementedError(
+            f"diagonal currently supports 2-D input only, got rank {x.rank}"
+        )
+    if any_symbolic(x.shape):
+        raise NotImplementedError("diagonal requires a statically-shaped input")
+
+    if dim1 < 0:
+        dim1 += x.rank
+    if dim2 < 0:
+        dim2 += x.rank
+    if (dim1, dim2) == (1, 0):
+        # diagonal along (dim1=1, dim2=0) equals diagonal along (0, 1) with
+        # sign-flipped offset.
+        offset = -offset
+    elif (dim1, dim2) != (0, 1):
+        raise NotImplementedError(
+            f"diagonal supports dim1/dim2 in (0, 1) or (1, 0), got ({dim1}, {dim2})"
+        )

-    if offset == 0 and dim1 == 0 and dim2 == 1:
-        diagonal = mb.band_part(x=x, lower=0, upper=0, name=node.name)
+    n, m = x.shape
+    if offset >= 0:
+        diag_len = builtins.min(n, m - offset)
+        start = offset
     else:
-        raise NotImplementedError("Only offset == 0 and dim1 == 0 and dim2 == 1 handled")
+        diag_len = builtins.min(n + offset, m)
+        start = -offset * m
+    if diag_len <= 0:
+        raise ValueError(
+            f"diagonal offset {offset} produces an empty diagonal for shape ({n}, {m})"
+        )
+
+    indices = [start + i * (m + 1) for i in range(diag_len)]
+    flat = mb.reshape(x=x, shape=[-1])
+    diagonal = mb.gather(x=flat, indices=indices, axis=0, name=node.name)

     context.add(diagonal)
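The `(dim1, dim2) == (1, 0)` branch above leans on the identity that swapping the two dims is the same as negating the offset. A quick standalone PyTorch check of that identity (illustrative only, not part of the converter):

import torch

x = torch.arange(20.0).reshape(4, 5)
for k in (-2, -1, 0, 1, 2):
    swapped = torch.diagonal(x, offset=k, dim1=1, dim2=0)
    flipped = torch.diagonal(x, offset=-k, dim1=0, dim2=1)
    assert torch.equal(swapped, flipped)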

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 39 additions & 0 deletions
@@ -6676,6 +6676,45 @@ def test_div(self, compute_unit, backend, frontend, rounding_mode, x2_type):
         )


+class TestDiagonal(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend, shape, offset, dim1, dim2",
+        itertools.product(
+            compute_units,
+            backends,
+            frontends,
+            [(5, 5), (3, 4), (4, 3)],
+            [-2, -1, 0, 1, 2],
+            [0, 1],
+            [0, 1],
+        ),
+    )
+    def test_diagonal(self, compute_unit, backend, frontend, shape, offset, dim1, dim2):
+        # Regression test for #2565: previously diagonal returned a same-shape
+        # matrix with off-diagonal zeroed out instead of a 1-D vector.
+        if dim1 == dim2:
+            pytest.skip("dim1 must differ from dim2")
+        n, m = shape
+        eff_offset = -offset if (dim1, dim2) == (1, 0) else offset
+        diag_len = (
+            min(n, m - eff_offset) if eff_offset >= 0 else min(n + eff_offset, m)
+        )
+        if diag_len <= 0:
+            pytest.skip("offset produces empty diagonal")
+
+        class TestModel(nn.Module):
+            def forward(self, x):
+                return torch.diagonal(x, offset=offset, dim1=dim1, dim2=dim2)
+
+        self.run_compare_torch(
+            shape,
+            TestModel(),
+            frontend=frontend,
+            backend=backend,
+            compute_unit=compute_unit,
+        )
+
+
 class TestElementWiseUnary(TorchBaseTest):
     @pytest.mark.parametrize(
         "compute_unit, backend, frontend, shape, op_string",
