From 911afda338a9dfb921392e90013e2c27a6895a1c Mon Sep 17 00:00:00 2001 From: devteamaegis Date: Wed, 17 Jun 2026 15:21:31 -0400 Subject: [PATCH] Fix MaskedLinearOperator.to(device=...) crash when no dtype given _to_helper returns dtype=None when only a device is passed. The MaskedLinearOperator.to override then accessed dtype.is_floating_point unconditionally, raising AttributeError. Guard for dtype is None so a device-only move keeps masks boolean, matching the base to() behavior. --- linear_operator/operators/masked_linear_operator.py | 6 +++++- test/operators/test_masked_linear_operator.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py index 782a4313..ef93fc51 100644 --- a/linear_operator/operators/masked_linear_operator.py +++ b/linear_operator/operators/masked_linear_operator.py @@ -138,7 +138,11 @@ def to( new_kwargs = {} for arg in self._args: if hasattr(arg, "to"): - if hasattr(arg, "dtype") and arg.dtype.is_floating_point == dtype.is_floating_point: + if ( + dtype is not None + and hasattr(arg, "dtype") + and arg.dtype.is_floating_point == dtype.is_floating_point + ): new_args.append(arg.to(dtype=dtype, device=device)) else: new_args.append(arg.to(device=device)) diff --git a/test/operators/test_masked_linear_operator.py b/test/operators/test_masked_linear_operator.py index 4c6c33aa..a71c0c52 100644 --- a/test/operators/test_masked_linear_operator.py +++ b/test/operators/test_masked_linear_operator.py @@ -38,6 +38,14 @@ def test_to_double(self): self.assertEqual(linear_op.col_mask.dtype, torch.bool) self.assertEqual(linear_op.row_mask.dtype, torch.bool) + def test_to_device(self): + # `.to(device=...)` with no dtype should not raise and should keep masks boolean. + linear_op = self.create_linear_op() + linear_op = linear_op.to(device=torch.device("cpu")) + self.assertEqual(linear_op.device.type, "cpu") + self.assertEqual(linear_op.col_mask.dtype, torch.bool) + self.assertEqual(linear_op.row_mask.dtype, torch.bool) + class TestMaskedLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase): seed = 2023