Skip to content

Commit 56727dc

Browse files
joeljenningsTorax team
authored andcommitted
Add the Thomas Algorithm Block Tridiagonal solver
PiperOrigin-RevId: 884493419
1 parent 9cbeabe commit 56727dc

6 files changed

Lines changed: 266 additions & 5 deletions

File tree

torax/_src/fvm/implicit_solve_block.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dataclasses
2020

2121
import jax
22+
from torax._src import tridiagonal
2223
from torax._src.fvm import block_1d_coeffs
2324
from torax._src.fvm import cell_variable
2425
from torax._src.fvm import fvm_conversions
@@ -29,6 +30,7 @@
2930
static_argnames=[
3031
'convection_dirichlet_mode',
3132
'convection_neumann_mode',
33+
'implicit_solver_type',
3234
'theta_implicit',
3335
],
3436
)
@@ -41,6 +43,7 @@ def implicit_solve_block(
4143
theta_implicit: float = 1.0,
4244
convection_dirichlet_mode: str = 'ghost',
4345
convection_neumann_mode: str = 'ghost',
46+
implicit_solver_type: tridiagonal.SolverType = tridiagonal.SolverType.THOMAS,
4447
) -> tuple[cell_variable.CellVariable, ...]:
4548
# pyformat: disable # pyformat removes line breaks needed for readability
4649
"""Runs one time step of an implicit solver on the equation defined by `coeffs`.
@@ -65,6 +68,9 @@ def implicit_solve_block(
6568
`dirichlet_mode` argument.
6669
convection_neumann_mode: See docstring of the `convection_terms` function,
6770
`neumann_mode` argument.
71+
implicit_solver_type: The tridiagonal solver algorithm to use for the
72+
implicit linear system solve. Either SolverType.THOMAS (default) for the
73+
Thomas algorithm, or SolverType.DENSE for a dense matrix solve.
6874
6975
Returns:
7076
x_new: Tuple, with x_new[i] giving channel i of x at the next time step
@@ -94,7 +100,7 @@ def implicit_solve_block(
94100
)
95101

96102
rhs_result = rhs_matrix.matvec(x_old_array) + rhs_vec - lhs_vec
97-
x_new = lhs_matrix.solve(rhs_result)
103+
x_new = lhs_matrix.solve(rhs_result, solver_type=implicit_solver_type)
98104

99105
# Create updated CellVariable instances based on state_plus_dt which has
100106
# updated boundary conditions and prescribed profiles.

torax/_src/solver/predictor_corrector_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def loop_body(x_new_guess):
105105
theta_implicit=solver_params.theta_implicit,
106106
convection_dirichlet_mode=(solver_params.convection_dirichlet_mode),
107107
convection_neumann_mode=(solver_params.convection_neumann_mode),
108+
implicit_solver_type=solver_params.implicit_solver_type,
108109
)
109110

110111
if solver_params.use_predictor_corrector:

torax/_src/solver/pydantic_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pydantic
2121
from torax._src import models as models_lib
22+
from torax._src import tridiagonal
2223
from torax._src.fvm import enums
2324
from torax._src.solver import linear_theta_method
2425
from torax._src.solver import nonlinear_theta_method
@@ -45,6 +46,9 @@ class BaseSolver(torax_pydantic.BaseModelFrozen, abc.ABC):
4546
`neumann_mode` argument.
4647
use_pereverzev: Use pereverzev terms for linear solver. Is only applied in
4748
the nonlinear solver for the optional initial guess from the linear solver
49+
implicit_solver_type: The tridiagonal solver algorithm to use for the
50+
implicit linear system solve. Either SolverType.THOMAS (default) for the
51+
Thomas algorithm, or SolverType.DENSE for a dense matrix solve.
4852
chi_pereverzev: (deliberately) large heat conductivity for Pereverzev rule.
4953
D_pereverzev: (deliberately) large particle diffusion for Pereverzev rule.
5054
"""
@@ -63,6 +67,9 @@ class BaseSolver(torax_pydantic.BaseModelFrozen, abc.ABC):
6367
Literal['ghost', 'semi-implicit'], torax_pydantic.JAX_STATIC
6468
] = 'ghost'
6569
use_pereverzev: Annotated[bool, torax_pydantic.JAX_STATIC] = False
70+
implicit_solver_type: Annotated[
71+
tridiagonal.SolverType, torax_pydantic.JAX_STATIC
72+
] = tridiagonal.SolverType.THOMAS
6673
chi_pereverzev: pydantic.PositiveFloat = 30.0
6774
D_pereverzev: pydantic.NonNegativeFloat = 15.0
6875

@@ -107,6 +114,7 @@ def build_runtime_params(self) -> runtime_params.RuntimeParams:
107114
convection_neumann_mode=self.convection_neumann_mode,
108115
use_pereverzev=self.use_pereverzev,
109116
use_predictor_corrector=self.use_predictor_corrector,
117+
implicit_solver_type=self.implicit_solver_type,
110118
chi_pereverzev=self.chi_pereverzev,
111119
D_pereverzev=self.D_pereverzev,
112120
n_corrector_steps=self.n_corrector_steps,
@@ -160,6 +168,7 @@ def build_runtime_params(
160168
convection_neumann_mode=self.convection_neumann_mode,
161169
use_pereverzev=self.use_pereverzev,
162170
use_predictor_corrector=self.use_predictor_corrector,
171+
implicit_solver_type=self.implicit_solver_type,
163172
chi_pereverzev=self.chi_pereverzev,
164173
D_pereverzev=self.D_pereverzev,
165174
maxiter=self.n_max_iterations,
@@ -210,6 +219,7 @@ def build_runtime_params(
210219
convection_neumann_mode=self.convection_neumann_mode,
211220
use_pereverzev=self.use_pereverzev,
212221
use_predictor_corrector=self.use_predictor_corrector,
222+
implicit_solver_type=self.implicit_solver_type,
213223
chi_pereverzev=self.chi_pereverzev,
214224
D_pereverzev=self.D_pereverzev,
215225
n_max_iterations=self.n_max_iterations,

torax/_src/solver/runtime_params.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import dataclasses
1616

1717
import jax
18+
from torax._src import tridiagonal
1819

1920

2021
@jax.tree_util.register_dataclass
@@ -27,5 +28,8 @@ class RuntimeParams:
2728
convection_dirichlet_mode: str = dataclasses.field(metadata={'static': True})
2829
convection_neumann_mode: str = dataclasses.field(metadata={'static': True})
2930
use_pereverzev: bool = dataclasses.field(metadata={'static': True})
31+
implicit_solver_type: tridiagonal.SolverType = dataclasses.field(
32+
metadata={'static': True}
33+
)
3034
chi_pereverzev: float
3135
D_pereverzev: float # pylint: disable=invalid-name

torax/_src/tests/tridiagonal_test.py

Lines changed: 172 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from absl.testing import absltest
16+
import jax
1617
import jax.numpy as jnp
1718
import numpy as np
1819
from torax._src import tridiagonal
@@ -282,7 +283,7 @@ def test_solve(self):
282283
x_true = jnp.array(rng.randn(4, 3), dtype=jnp.float64)
283284
rhs = bt.matvec(x_true)
284285

285-
x_solved = bt.solve(rhs)
286+
x_solved = bt.solve(rhs, solver_type=tridiagonal.SolverType.THOMAS)
286287

287288
np.testing.assert_allclose(x_solved, x_true, atol=1e-10)
288289

@@ -292,7 +293,7 @@ def test_solve_recovers_rhs(self):
292293
rng = np.random.RandomState(55)
293294
rhs = jnp.array(rng.randn(3, 2), dtype=jnp.float64)
294295

295-
x = bt.solve(rhs)
296+
x = bt.solve(rhs, solver_type=tridiagonal.SolverType.THOMAS)
296297
reconstructed_rhs = bt.matvec(x)
297298

298299
np.testing.assert_allclose(reconstructed_rhs, rhs, atol=1e-10)
@@ -345,5 +346,174 @@ def test_from_tridiagonals_to_dense_matches_per_channel(self):
345346
np.testing.assert_allclose(dense, expected_full)
346347

347348

349+
class ThomasSolveTest(absltest.TestCase):
350+
"""Tests specifically targeting the Thomas algorithm for block-tridiagonal."""
351+
352+
def _make_nonsingular_block_tridiag(
353+
self, num_blocks: int, block_size: int, seed: int = 0
354+
) -> tridiagonal.BlockTriDiagonal:
355+
"""Helper to create a diagonally-dominant BlockTriDiagonal."""
356+
rng = np.random.RandomState(seed)
357+
lower = jnp.array(
358+
rng.randn(num_blocks - 1, block_size, block_size), dtype=jnp.float64
359+
)
360+
upper = jnp.array(
361+
rng.randn(num_blocks - 1, block_size, block_size), dtype=jnp.float64
362+
)
363+
diag_blocks = jnp.array(
364+
rng.randn(num_blocks, block_size, block_size), dtype=jnp.float64
365+
)
366+
diag_blocks = diag_blocks + 10.0 * jnp.eye(block_size, dtype=jnp.float64)
367+
return tridiagonal.BlockTriDiagonal(
368+
lower=lower, diagonal=diag_blocks, upper=upper
369+
)
370+
371+
def test_thomas_matches_dense_solve(self):
372+
"""Thomas and dense solvers should produce the same result."""
373+
bt = self._make_nonsingular_block_tridiag(num_blocks=5, block_size=3)
374+
rng = np.random.RandomState(42)
375+
rhs = jnp.array(rng.randn(5, 3), dtype=jnp.float64)
376+
377+
x_thomas = tridiagonal.thomas_solve(bt, rhs)
378+
x_dense = tridiagonal.dense_solve(bt, rhs)
379+
380+
np.testing.assert_allclose(x_thomas, x_dense, atol=1e-10)
381+
382+
def test_thomas_small_known_system(self):
383+
"""Thomas algorithm on a small 2-block system with known answer."""
384+
# 2x2 block system: [[D0, U0], [L0, D1]] @ x = rhs
385+
# D0 = [[10, 0], [0, 10]], U0 = [[1, 0], [0, 1]]
386+
# L0 = [[1, 0], [0, 1]], D1 = [[10, 0], [0, 10]]
387+
# This is close to identity so the answer is close to rhs/10.
388+
diag = jnp.array(
389+
[[[10.0, 0.0], [0.0, 10.0]], [[10.0, 0.0], [0.0, 10.0]]],
390+
dtype=jnp.float64,
391+
)
392+
upper = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.float64)
393+
lower = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.float64)
394+
bt = tridiagonal.BlockTriDiagonal(lower=lower, diagonal=diag, upper=upper)
395+
396+
x_true = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float64)
397+
rhs = bt.matvec(x_true)
398+
399+
x_solved = tridiagonal.thomas_solve(bt, rhs)
400+
401+
np.testing.assert_allclose(x_solved, x_true, atol=1e-12)
402+
403+
def test_thomas_identity_blocks(self):
404+
"""Solving with block-identity should return the RHS itself."""
405+
num_blocks = 4
406+
block_size = 2
407+
bt = tridiagonal.BlockTriDiagonal(
408+
lower=jnp.zeros(
409+
(num_blocks - 1, block_size, block_size), dtype=jnp.float64
410+
),
411+
diagonal=jnp.tile(
412+
jnp.eye(block_size, dtype=jnp.float64), (num_blocks, 1, 1)
413+
),
414+
upper=jnp.zeros(
415+
(num_blocks - 1, block_size, block_size), dtype=jnp.float64
416+
),
417+
)
418+
rng = np.random.RandomState(11)
419+
rhs = jnp.array(rng.randn(num_blocks, block_size), dtype=jnp.float64)
420+
421+
x = tridiagonal.thomas_solve(bt, rhs)
422+
423+
np.testing.assert_allclose(x, rhs, atol=1e-14)
424+
425+
def test_thomas_scalar_blocks(self):
426+
"""Thomas algorithm with block_size=1 should match scalar tridiagonal."""
427+
bt = self._make_nonsingular_block_tridiag(
428+
num_blocks=6, block_size=1, seed=7
429+
)
430+
rng = np.random.RandomState(13)
431+
rhs = jnp.array(rng.randn(6, 1), dtype=jnp.float64)
432+
433+
x = tridiagonal.thomas_solve(bt, rhs)
434+
435+
np.testing.assert_allclose(bt.matvec(x), rhs, atol=1e-12)
436+
437+
def test_thomas_two_blocks(self):
438+
"""Minimal multi-block case: 2 blocks."""
439+
bt = self._make_nonsingular_block_tridiag(
440+
num_blocks=2, block_size=2, seed=99
441+
)
442+
rng = np.random.RandomState(17)
443+
x_true = jnp.array(rng.randn(2, 2), dtype=jnp.float64)
444+
rhs = bt.matvec(x_true)
445+
446+
x_solved = tridiagonal.thomas_solve(bt, rhs)
447+
448+
np.testing.assert_allclose(x_solved, x_true, atol=1e-10)
449+
450+
def test_thomas_large_system(self):
451+
"""Thomas should handle larger systems accurately."""
452+
bt = self._make_nonsingular_block_tridiag(
453+
num_blocks=50, block_size=4, seed=22
454+
)
455+
rng = np.random.RandomState(33)
456+
x_true = jnp.array(rng.randn(50, 4), dtype=jnp.float64)
457+
rhs = bt.matvec(x_true)
458+
459+
x_solved = tridiagonal.thomas_solve(bt, rhs)
460+
461+
np.testing.assert_allclose(x_solved, x_true, atol=1e-8)
462+
463+
def test_solver_type_dispatch_thomas(self):
464+
"""solve() with SolverType.THOMAS should use thomas_solve."""
465+
bt = self._make_nonsingular_block_tridiag(num_blocks=3, block_size=2)
466+
rng = np.random.RandomState(44)
467+
rhs = jnp.array(rng.randn(3, 2), dtype=jnp.float64)
468+
469+
x_via_type = bt.solve(rhs, solver_type=tridiagonal.SolverType.THOMAS)
470+
x_direct = tridiagonal.thomas_solve(bt, rhs)
471+
472+
np.testing.assert_allclose(x_via_type, x_direct, atol=1e-14)
473+
474+
def test_solver_type_dispatch_dense(self):
475+
"""solve() with SolverType.DENSE should use dense_solve."""
476+
bt = self._make_nonsingular_block_tridiag(num_blocks=3, block_size=2)
477+
rng = np.random.RandomState(44)
478+
rhs = jnp.array(rng.randn(3, 2), dtype=jnp.float64)
479+
480+
x_via_type = bt.solve(rhs, solver_type=tridiagonal.SolverType.DENSE)
481+
x_direct = tridiagonal.dense_solve(bt, rhs)
482+
483+
np.testing.assert_allclose(x_via_type, x_direct, atol=1e-14)
484+
485+
def test_thomas_jit_compatible(self):
486+
"""thomas_solve should work under jax.jit."""
487+
bt = self._make_nonsingular_block_tridiag(num_blocks=4, block_size=2)
488+
rng = np.random.RandomState(66)
489+
x_true = jnp.array(rng.randn(4, 2), dtype=jnp.float64)
490+
rhs = bt.matvec(x_true)
491+
492+
jitted_solve = jax.jit(tridiagonal.thomas_solve)
493+
x_solved = jitted_solve(bt, rhs)
494+
495+
np.testing.assert_allclose(x_solved, x_true, atol=1e-10)
496+
497+
def test_thomas_from_tridiagonals(self):
498+
"""Thomas solve on a block system built from per-channel scalar tridiagonals."""
499+
ch0 = tridiagonal.TriDiagonal(
500+
diagonal=jnp.array([10.0, 12.0, 14.0], dtype=jnp.float64),
501+
above=jnp.array([1.0, 3.0], dtype=jnp.float64),
502+
below=jnp.array([5.0, 7.0], dtype=jnp.float64),
503+
)
504+
ch1 = tridiagonal.TriDiagonal(
505+
diagonal=jnp.array([11.0, 13.0, 15.0], dtype=jnp.float64),
506+
above=jnp.array([2.0, 4.0], dtype=jnp.float64),
507+
below=jnp.array([6.0, 8.0], dtype=jnp.float64),
508+
)
509+
bt = tridiagonal.BlockTriDiagonal.from_tridiagonals([ch0, ch1])
510+
rng = np.random.RandomState(77)
511+
rhs = jnp.array(rng.randn(3, 2), dtype=jnp.float64)
512+
513+
x = tridiagonal.thomas_solve(bt, rhs)
514+
515+
np.testing.assert_allclose(bt.matvec(x), rhs, atol=1e-12)
516+
517+
348518
if __name__ == '__main__':
349519
absltest.main()

0 commit comments

Comments
 (0)