diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index f75051467..be20e5211 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -8358,11 +8358,53 @@ def _parse_keyword_args(context, node, offset, dim1, dim2) -> Tuple[Var]: x, offset, dim1, dim2 = _parse_positional_args(context, node) offset, dim1, dim2 = _parse_keyword_args(context, node, offset, dim1, dim2) + if isinstance(offset, Var): + offset = offset.val + if isinstance(dim1, Var): + dim1 = dim1.val + if isinstance(dim2, Var): + dim2 = dim2.val + + # torch.diagonal(input, offset=0, dim1=0, dim2=1) returns the requested + # diagonal as a 1-D tensor (for 2-D input) by extracting elements at + # input[i + max(-offset, 0), i + max(offset, 0)]. The previous + # implementation used band_part, which only zeros out the off-diagonal + # entries and returns the same-shape matrix. + if x.rank != 2: + raise NotImplementedError( + f"diagonal currently supports 2-D input only, got rank {x.rank}" + ) + if any_symbolic(x.shape): + raise NotImplementedError("diagonal requires a statically-shaped input") + + if dim1 < 0: + dim1 += x.rank + if dim2 < 0: + dim2 += x.rank + if (dim1, dim2) == (1, 0): + # diagonal along (dim1=1, dim2=0) equals diagonal along (0, 1) with + # sign-flipped offset. + offset = -offset + elif (dim1, dim2) != (0, 1): + raise NotImplementedError( + f"diagonal supports dim1/dim2 in (0, 1) or (1, 0), got ({dim1}, {dim2})" + ) - if offset == 0 and dim1 == 0 and dim2 == 1: - diagonal = mb.band_part(x=x, lower=0, upper=0, name=node.name) + n, m = x.shape + if offset >= 0: + diag_len = builtins.min(n, m - offset) + start = offset else: - raise NotImplementedError("Only offset == 0 and dim1 == 0 and dim2 == 1 handled") + diag_len = builtins.min(n + offset, m) + start = -offset * m + if diag_len <= 0: + raise ValueError( + f"diagonal offset {offset} produces an empty diagonal for shape ({n}, {m})" + ) + + indices = [start + i * (m + 1) for i in range(diag_len)] + flat = mb.reshape(x=x, shape=[-1]) + diagonal = mb.gather(x=flat, indices=indices, axis=0, name=node.name) context.add(diagonal) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index dc17406d6..d82752929 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -6676,6 +6676,45 @@ def test_div(self, compute_unit, backend, frontend, rounding_mode, x2_type): ) +class TestDiagonal(TorchBaseTest): + @pytest.mark.parametrize( + "compute_unit, backend, frontend, shape, offset, dim1, dim2", + itertools.product( + compute_units, + backends, + frontends, + [(5, 5), (3, 4), (4, 3)], + [-2, -1, 0, 1, 2], + [0, 1], + [0, 1], + ), + ) + def test_diagonal(self, compute_unit, backend, frontend, shape, offset, dim1, dim2): + # Regression test for #2565: previously diagonal returned a same-shape + # matrix with off-diagonal zeroed out instead of a 1-D vector. + if dim1 == dim2: + pytest.skip("dim1 must differ from dim2") + n, m = shape + eff_offset = -offset if (dim1, dim2) == (1, 0) else offset + diag_len = ( + min(n, m - eff_offset) if eff_offset >= 0 else min(n + eff_offset, m) + ) + if diag_len <= 0: + pytest.skip("offset produces empty diagonal") + + class TestModel(nn.Module): + def forward(self, x): + return torch.diagonal(x, offset=offset, dim1=dim1, dim2=dim2) + + self.run_compare_torch( + shape, + TestModel(), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + + class TestElementWiseUnary(TorchBaseTest): @pytest.mark.parametrize( "compute_unit, backend, frontend, shape, op_string",