|
4 | 4 |
|
5 | 5 | import torch |
6 | 6 |
|
7 | | -from linear_operator.operators import MatmulLinearOperator |
| 7 | +from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, MatmulLinearOperator |
8 | 8 | from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase |
9 | 9 |
|
10 | 10 |
|
@@ -56,5 +56,60 @@ def evaluate_linear_op(self, linear_op): |
56 | 56 | return linear_op.left_linear_op.tensor.matmul(linear_op.right_linear_op.tensor) |
57 | 57 |
|
58 | 58 |
|
class TestMatmulLinearOperatorDiagOptimization(unittest.TestCase):
    """Tests for efficient diagonal matrix multiplication in to_dense()."""

    def test_diag_left_matmul_to_dense(self):
        """D @ A uses element-wise multiplication."""
        diag_values = torch.tensor([1.0, 2.0, 3.0, 4.0])
        dense = torch.randn(4, 5)

        product = MatmulLinearOperator(
            DiagLinearOperator(diag_values), DenseLinearOperator(dense)
        )

        # Reference computed via an explicit dense diagonal matrix.
        reference = torch.diag(diag_values) @ dense
        self.assertTrue(torch.allclose(product.to_dense(), reference))

    def test_diag_right_matmul_to_dense(self):
        """A @ D uses element-wise multiplication."""
        diag_values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        dense = torch.randn(4, 5)

        product = MatmulLinearOperator(
            DenseLinearOperator(dense), DiagLinearOperator(diag_values)
        )

        reference = dense @ torch.diag(diag_values)
        self.assertTrue(torch.allclose(product.to_dense(), reference))

    def test_diag_sandwich_to_dense(self):
        """D1 @ A @ D2 uses element-wise multiplication (the main bug fix)."""
        left_diag = torch.tensor([1.0, 2.0, 3.0, 4.0])
        right_diag = torch.tensor([0.5, 1.5, 2.5, 3.5])
        dense = torch.randn(4, 4)

        sandwich = (
            DiagLinearOperator(left_diag)
            @ DenseLinearOperator(dense)
            @ DiagLinearOperator(right_diag)
        )
        reference = torch.diag(left_diag) @ dense @ torch.diag(right_diag)
        self.assertTrue(torch.allclose(sandwich.to_dense(), reference))

    def test_diag_sandwich_batch(self):
        """D1 @ A @ D2 with batch dimensions."""
        batch_size, n = 3, 4

        # abs() keeps the diagonals positive; values themselves are arbitrary.
        left_diag = torch.randn(batch_size, n).abs()
        right_diag = torch.randn(batch_size, n).abs()
        dense = torch.randn(batch_size, n, n)

        sandwich = (
            DiagLinearOperator(left_diag)
            @ DenseLinearOperator(dense)
            @ DiagLinearOperator(right_diag)
        )
        # diag_embed builds one (n, n) diagonal matrix per batch entry.
        reference = torch.diag_embed(left_diag) @ dense @ torch.diag_embed(right_diag)
        self.assertTrue(torch.allclose(sandwich.to_dense(), reference))

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()