Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8358,11 +8358,53 @@ def _parse_keyword_args(context, node, offset, dim1, dim2) -> Tuple[Var]:

x, offset, dim1, dim2 = _parse_positional_args(context, node)
offset, dim1, dim2 = _parse_keyword_args(context, node, offset, dim1, dim2)
if isinstance(offset, Var):
offset = offset.val
if isinstance(dim1, Var):
dim1 = dim1.val
if isinstance(dim2, Var):
dim2 = dim2.val

# torch.diagonal(input, offset=0, dim1=0, dim2=1) returns the requested
# diagonal as a 1-D tensor (for 2-D input) by extracting elements at
# input[i + max(-offset, 0), i + max(offset, 0)]. The previous
# implementation used band_part, which only zeros out the off-diagonal
# entries and returns the same-shape matrix.
if x.rank != 2:
raise NotImplementedError(
f"diagonal currently supports 2-D input only, got rank {x.rank}"
)
if any_symbolic(x.shape):
raise NotImplementedError("diagonal requires a statically-shaped input")

if dim1 < 0:
dim1 += x.rank
if dim2 < 0:
dim2 += x.rank
if (dim1, dim2) == (1, 0):
# diagonal along (dim1=1, dim2=0) equals diagonal along (0, 1) with
# sign-flipped offset.
offset = -offset
elif (dim1, dim2) != (0, 1):
raise NotImplementedError(
f"diagonal supports dim1/dim2 in (0, 1) or (1, 0), got ({dim1}, {dim2})"
)

if offset == 0 and dim1 == 0 and dim2 == 1:
diagonal = mb.band_part(x=x, lower=0, upper=0, name=node.name)
n, m = x.shape
if offset >= 0:
diag_len = builtins.min(n, m - offset)
start = offset
else:
raise NotImplementedError("Only offset == 0 and dim1 == 0 and dim2 == 1 handled")
diag_len = builtins.min(n + offset, m)
start = -offset * m
if diag_len <= 0:
raise ValueError(
f"diagonal offset {offset} produces an empty diagonal for shape ({n}, {m})"
)

indices = [start + i * (m + 1) for i in range(diag_len)]
flat = mb.reshape(x=x, shape=[-1])
diagonal = mb.gather(x=flat, indices=indices, axis=0, name=node.name)

context.add(diagonal)

Expand Down
39 changes: 39 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6676,6 +6676,45 @@ def test_div(self, compute_unit, backend, frontend, rounding_mode, x2_type):
)


class TestDiagonal(TorchBaseTest):
    @pytest.mark.parametrize(
        "compute_unit, backend, frontend, shape, offset, dim1, dim2",
        itertools.product(
            compute_units,
            backends,
            frontends,
            [(5, 5), (3, 4), (4, 3)],
            [-2, -1, 0, 1, 2],
            [0, 1],
            [0, 1],
        ),
    )
    def test_diagonal(self, compute_unit, backend, frontend, shape, offset, dim1, dim2):
        # Regression test for #2565: previously diagonal returned a same-shape
        # matrix with off-diagonal zeroed out instead of a 1-D vector.
        if dim1 == dim2:
            pytest.skip("dim1 must differ from dim2")
        rows, cols = shape
        # Swapping dim1/dim2 is equivalent to negating the offset.
        effective = -offset if (dim1, dim2) == (1, 0) else offset
        if effective >= 0:
            length = min(rows, cols - effective)
        else:
            length = min(rows + effective, cols)
        if length <= 0:
            pytest.skip("offset produces empty diagonal")

        class DiagonalModel(nn.Module):
            def forward(self, x):
                return torch.diagonal(x, offset=offset, dim1=dim1, dim2=dim2)

        self.run_compare_torch(
            shape,
            DiagonalModel(),
            frontend=frontend,
            backend=backend,
            compute_unit=compute_unit,
        )


class TestElementWiseUnary(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, shape, op_string",
Expand Down