diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py index 782a4313..ef93fc51 100644 --- a/linear_operator/operators/masked_linear_operator.py +++ b/linear_operator/operators/masked_linear_operator.py @@ -138,7 +138,11 @@ def to( new_kwargs = {} for arg in self._args: if hasattr(arg, "to"): - if hasattr(arg, "dtype") and arg.dtype.is_floating_point == dtype.is_floating_point: + if ( + dtype is not None + and hasattr(arg, "dtype") + and arg.dtype.is_floating_point == dtype.is_floating_point + ): new_args.append(arg.to(dtype=dtype, device=device)) else: new_args.append(arg.to(device=device)) diff --git a/test/operators/test_masked_linear_operator.py b/test/operators/test_masked_linear_operator.py index 4c6c33aa..a71c0c52 100644 --- a/test/operators/test_masked_linear_operator.py +++ b/test/operators/test_masked_linear_operator.py @@ -38,6 +38,14 @@ def test_to_double(self): self.assertEqual(linear_op.col_mask.dtype, torch.bool) self.assertEqual(linear_op.row_mask.dtype, torch.bool) + def test_to_device(self): + # `.to(device=...)` with no dtype should not raise and should keep masks boolean. + linear_op = self.create_linear_op() + linear_op = linear_op.to(device=torch.device("cpu")) + self.assertEqual(linear_op.device.type, "cpu") + self.assertEqual(linear_op.col_mask.dtype, torch.bool) + self.assertEqual(linear_op.row_mask.dtype, torch.bool) + class TestMaskedLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase): seed = 2023