Skip to content

Commit 52fbf34

Browse files
Fix add_low_rank to only compute roots when cached roots exist
This fixes a numerical instability bug where add_low_rank would speculatively compute root decompositions even when no cached roots existed. This caused SVD failures on ill-conditioned matrices (e.g., when using LinearKernel). Changes: - Change early exit condition from 'not generate_roots and not has_roots' to 'not (generate_roots and has_roots)' - exit early if EITHER generate_roots=False OR no cached roots exist - Pass generate_roots parameter through in RootLinearOperator.add_low_rank (was being silently ignored) - Add regression test to verify root_decomposition is not called when no roots are cached - Update test_cat_rows to first cache roots before expecting them to be cached after cat_rows (aligns test with new expected behavior)
1 parent 0f31521 commit 52fbf34

4 files changed

Lines changed: 62 additions & 7 deletions

File tree

linear_operator/operators/_linear_operator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,9 +1074,10 @@ def add_low_rank(
10741074
new_linear_op = to_linear_operator(new_linear_op.to_dense())
10751075

10761076
# if the old LinearOperator does not have either a root decomposition or a root inverse decomposition
1077-
# don't create one
1077+
# don't create one. Also skip if the caller explicitly doesn't want roots generated.
1078+
# The root update is only beneficial when self already has cached roots that can be efficiently updated.
10781079
has_roots = any(_is_in_cache_ignore_args(self, key) for key in ("root_decomposition", "root_inv_decomposition"))
1079-
if not generate_roots and not has_roots:
1080+
if not (generate_roots and has_roots):
10801081
return new_linear_op
10811082

10821083
# we are going to compute the following
@@ -1253,15 +1254,16 @@ def cat_rows(
12531254
new_linear_op = CatLinearOperator(upper_row, lower_row, dim=-1, output_device=A.device)
12541255

12551256
# if the old LinearOperator does not have either a root decomposition or a root inverse decomposition
1256-
# don't create one
1257+
# don't create one. Also skip if the caller explicitly doesn't want roots generated.
1258+
# The root update is only beneficial when self already has cached roots that can be efficiently updated.
12571259
has_roots = any(
12581260
_is_in_cache_ignore_args(self, key)
12591261
for key in (
12601262
"root_decomposition",
12611263
"root_inv_decomposition",
12621264
)
12631265
)
1264-
if not generate_roots and not has_roots:
1266+
if not (generate_roots and has_roots):
12651267
return new_linear_op
12661268

12671269
# Get components for new root Z = [E 0; F G]

linear_operator/operators/root_linear_operator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@ def add_low_rank(
9595
generate_roots: bool | None = True,
9696
**root_decomp_kwargs,
9797
) -> LinearOperator: # shape: (*batch, N, N)
98-
return super().add_low_rank(low_rank_mat, root_inv_decomp_method=root_inv_decomp_method)
98+
return super().add_low_rank(
99+
low_rank_mat,
100+
root_inv_decomp_method=root_inv_decomp_method,
101+
generate_roots=generate_roots,
102+
)
99103

100104
def root_decomposition(
101105
self: LinearOperator, method: str | None = None # shape: (*batch, N, N)

linear_operator/test/linear_operator_test_case.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,8 +733,14 @@ def test_cat_rows(self):
733733
root_rhs = linear_operator.root_decomposition(new_lt).matmul(rhs)
734734
self.assertAllClose(root_rhs, concat_rhs, **self.tolerances["root_decomposition"])
735735

736-
# check that root inv is cached
737-
root_inv = get_from_cache(new_lt, "root_inv_decomposition")
736+
# Test root_inv caching: roots are only updated when cached roots already exist.
737+
# First, ensure linear_op has cached roots before calling cat_rows.
738+
_ = linear_op.root_decomposition()
739+
_ = linear_op.root_inv_decomposition()
740+
new_lt_with_roots = linear_op.cat_rows(new_rows, new_point)
741+
742+
# check that root inv is cached (since linear_op had cached roots)
743+
root_inv = get_from_cache(new_lt_with_roots, "root_inv_decomposition")
738744
# check that the inverse root decomposition is close
739745
concat_solve = torch.linalg.solve(concatenated_lt, rhs.unsqueeze(-1)).squeeze(-1)
740746
root_inv_solve = root_inv.matmul(rhs)

test/operators/test_dense_linear_operator.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import unittest
4+
from unittest.mock import patch
45

56
import torch
67

@@ -30,6 +31,48 @@ def test_root_decomposition_exact(self):
3031
actual = linear_op.matmul(test_mat)
3132
self.assertLess(torch.norm(res - actual) / actual.norm(), 0.1)
3233

34+
def test_no_root_computation_when_no_cached_roots(self):
35+
"""
36+
Regression test for add_low_rank speculative root computation bug.
37+
Verify root_decomposition is NOT called when no roots are cached.
38+
39+
This catches a bug where add_low_rank would unnecessarily compute expensive
40+
root decompositions even when the base LinearOperator had no cached roots.
41+
This caused numerical instability (SVD failures) on ill-conditioned matrices.
42+
43+
The fix ensures root updates only happen when BOTH:
44+
1. generate_roots=True (default)
45+
2. The base operator already has cached roots
46+
"""
47+
torch.manual_seed(42)
48+
49+
# Create a simple PSD matrix without any cached root decomposition
50+
n = 5
51+
A = torch.randn(n, n)
52+
base_matrix = A @ A.T + 0.1 * torch.eye(n)
53+
base_op = DenseLinearOperator(base_matrix)
54+
55+
# Create a low-rank term (like LinearKernel produces)
56+
low_rank = torch.randn(n, 2)
57+
58+
# Patch root_decomposition to track if it's called
59+
# Before the fix, add_low_rank would call root_decomposition even when none are cached
60+
# After the fix, it should NOT call root_decomposition
61+
with patch.object(
62+
DenseLinearOperator, "root_decomposition", wraps=base_op.root_decomposition
63+
) as mock_root_decomp:
64+
result = base_op.add_low_rank(low_rank)
65+
66+
# Verify root_decomposition was NOT called (the fix's behavior)
67+
# Before the fix, this would fail because root_decomposition was called
68+
# add_low_rank should NOT compute root_decomposition when no roots are cached
69+
self.assertEqual(mock_root_decomp.call_count, 0)
70+
71+
# Verify the result is still correct (simple matrix addition)
72+
expected = base_matrix + low_rank @ low_rank.T
73+
# add_low_rank should return correct sum
74+
self.assertTrue(torch.allclose(result.to_dense(), expected, atol=1e-5))
75+
3376

3477
class TestDenseLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase):
3578
seed = 0

0 commit comments

Comments
 (0)