Skip to content

Commit bf1c913

Browse files
authored
Merge pull request #122 from Balandat/speedup_matmul_linop_todense_diag
Speed up MatmulLinearOperator.to_dense() for DiagLazyTensors
2 parents cd6ec0d + 2606fb1 commit bf1c913

2 files changed

Lines changed: 62 additions & 1 deletion

File tree

linear_operator/operators/matmul_linear_operator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,10 @@ def _transpose_nonbatch(
135135
def to_dense(
    self: LinearOperator,  # shape: (*batch, M, N)
) -> Tensor:  # shape: (*batch, M, N)
    """Materialize the product of the left and right operators as a dense tensor.

    When either factor is a ``DiagLinearOperator`` the product is formed by
    broadcasting that factor's ``_diag`` against the dense form of the other
    factor, so no dense matrix-matrix multiplication is issued. The left
    factor is checked first; if neither factor is diagonal, both are
    densified and multiplied with ``torch.matmul``.
    """
    lhs = self.left_linear_op
    rhs = self.right_linear_op

    # Diagonal on the left: the trailing singleton dim lets each diagonal
    # entry multiply one full row of the dense right factor.
    if isinstance(lhs, DiagLinearOperator):
        row_scale = lhs._diag.unsqueeze(-1)
        return row_scale * rhs.to_dense()

    # Diagonal on the right: the singleton second-to-last dim lets each
    # diagonal entry multiply one full column of the dense left factor.
    if isinstance(rhs, DiagLinearOperator):
        dense_lhs = lhs.to_dense()
        return dense_lhs * rhs._diag.unsqueeze(-2)

    # General case: plain dense matrix product.
    return torch.matmul(lhs.to_dense(), rhs.to_dense())

test/operators/test_matmul_linear_operator.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66

7-
from linear_operator.operators import MatmulLinearOperator
7+
from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, MatmulLinearOperator
88
from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase
99

1010

@@ -56,5 +56,60 @@ def evaluate_linear_op(self, linear_op):
5656
return linear_op.left_linear_op.tensor.matmul(linear_op.right_linear_op.tensor)
5757

5858

59+
class TestMatmulLinearOperatorDiagOptimization(unittest.TestCase):
    """Value checks for to_dense() on products that involve DiagLinearOperators.

    Each test builds the product lazily, densifies it, and compares the result
    with the same product computed from explicit dense diagonal matrices.
    """

    def _check_to_dense(self, linear_op, reference):
        """Assert that ``linear_op.to_dense()`` is close to ``reference``."""
        self.assertTrue(torch.allclose(linear_op.to_dense(), reference))

    def test_diag_left_matmul_to_dense(self):
        """D @ A densifies to the same values as diag(d) @ A."""
        diag_vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
        dense_mat = torch.randn(4, 5)

        diag_op = DiagLinearOperator(diag_vals)
        product = MatmulLinearOperator(diag_op, DenseLinearOperator(dense_mat))

        self._check_to_dense(product, torch.diag(diag_vals) @ dense_mat)

    def test_diag_right_matmul_to_dense(self):
        """A @ D densifies to the same values as A @ diag(d)."""
        diag_vals = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        dense_mat = torch.randn(4, 5)

        diag_op = DiagLinearOperator(diag_vals)
        product = MatmulLinearOperator(DenseLinearOperator(dense_mat), diag_op)

        self._check_to_dense(product, dense_mat @ torch.diag(diag_vals))

    def test_diag_sandwich_to_dense(self):
        """D1 @ A @ D2 densifies to the same values as diag(d1) @ A @ diag(d2)."""
        left_vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
        right_vals = torch.tensor([0.5, 1.5, 2.5, 3.5])
        dense_mat = torch.randn(4, 4)

        left_op = DiagLinearOperator(left_vals)
        right_op = DiagLinearOperator(right_vals)

        product = left_op @ DenseLinearOperator(dense_mat) @ right_op
        reference = torch.diag(left_vals) @ dense_mat @ torch.diag(right_vals)
        self._check_to_dense(product, reference)

    def test_diag_sandwich_batch(self):
        """Batched D1 @ A @ D2 matches the product built with torch.diag_embed."""
        batch_size = 3
        n = 4

        # Draw order (left diagonal, right diagonal, dense matrix) is kept
        # fixed so a seeded run produces the same inputs as before.
        left_vals = torch.randn(batch_size, n).abs()
        right_vals = torch.randn(batch_size, n).abs()
        dense_mat = torch.randn(batch_size, n, n)

        left_op = DiagLinearOperator(left_vals)
        right_op = DiagLinearOperator(right_vals)

        product = left_op @ DenseLinearOperator(dense_mat) @ right_op
        reference = torch.diag_embed(left_vals) @ dense_mat @ torch.diag_embed(right_vals)
        self._check_to_dense(product, reference)
112+
113+
59114
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)