1313# limitations under the License.
1414
1515from absl .testing import absltest
16+ import jax
1617import jax .numpy as jnp
1718import numpy as np
1819from torax ._src import tridiagonal
@@ -282,7 +283,7 @@ def test_solve(self):
282283 x_true = jnp .array (rng .randn (4 , 3 ), dtype = jnp .float64 )
283284 rhs = bt .matvec (x_true )
284285
285- x_solved = bt .solve (rhs )
286+ x_solved = bt .solve (rhs , solver_type = tridiagonal . SolverType . THOMAS )
286287
287288 np .testing .assert_allclose (x_solved , x_true , atol = 1e-10 )
288289
@@ -292,7 +293,7 @@ def test_solve_recovers_rhs(self):
292293 rng = np .random .RandomState (55 )
293294 rhs = jnp .array (rng .randn (3 , 2 ), dtype = jnp .float64 )
294295
295- x = bt .solve (rhs )
296+ x = bt .solve (rhs , solver_type = tridiagonal . SolverType . THOMAS )
296297 reconstructed_rhs = bt .matvec (x )
297298
298299 np .testing .assert_allclose (reconstructed_rhs , rhs , atol = 1e-10 )
@@ -345,5 +346,174 @@ def test_from_tridiagonals_to_dense_matches_per_channel(self):
345346 np .testing .assert_allclose (dense , expected_full )
346347
347348
349+ class ThomasSolveTest (absltest .TestCase ):
350+ """Tests specifically targeting the Thomas algorithm for block-tridiagonal."""
351+
352+ def _make_nonsingular_block_tridiag (
353+ self , num_blocks : int , block_size : int , seed : int = 0
354+ ) -> tridiagonal .BlockTriDiagonal :
355+ """Helper to create a diagonally-dominant BlockTriDiagonal."""
356+ rng = np .random .RandomState (seed )
357+ lower = jnp .array (
358+ rng .randn (num_blocks - 1 , block_size , block_size ), dtype = jnp .float64
359+ )
360+ upper = jnp .array (
361+ rng .randn (num_blocks - 1 , block_size , block_size ), dtype = jnp .float64
362+ )
363+ diag_blocks = jnp .array (
364+ rng .randn (num_blocks , block_size , block_size ), dtype = jnp .float64
365+ )
366+ diag_blocks = diag_blocks + 10.0 * jnp .eye (block_size , dtype = jnp .float64 )
367+ return tridiagonal .BlockTriDiagonal (
368+ lower = lower , diagonal = diag_blocks , upper = upper
369+ )
370+
371+ def test_thomas_matches_dense_solve (self ):
372+ """Thomas and dense solvers should produce the same result."""
373+ bt = self ._make_nonsingular_block_tridiag (num_blocks = 5 , block_size = 3 )
374+ rng = np .random .RandomState (42 )
375+ rhs = jnp .array (rng .randn (5 , 3 ), dtype = jnp .float64 )
376+
377+ x_thomas = tridiagonal .thomas_solve (bt , rhs )
378+ x_dense = tridiagonal .dense_solve (bt , rhs )
379+
380+ np .testing .assert_allclose (x_thomas , x_dense , atol = 1e-10 )
381+
382+ def test_thomas_small_known_system (self ):
383+ """Thomas algorithm on a small 2-block system with known answer."""
384+ # 2x2 block system: [[D0, U0], [L0, D1]] @ x = rhs
385+ # D0 = [[10, 0], [0, 10]], U0 = [[1, 0], [0, 1]]
386+ # L0 = [[1, 0], [0, 1]], D1 = [[10, 0], [0, 10]]
387+ # This is close to identity so the answer is close to rhs/10.
388+ diag = jnp .array (
389+ [[[10.0 , 0.0 ], [0.0 , 10.0 ]], [[10.0 , 0.0 ], [0.0 , 10.0 ]]],
390+ dtype = jnp .float64 ,
391+ )
392+ upper = jnp .array ([[[1.0 , 0.0 ], [0.0 , 1.0 ]]], dtype = jnp .float64 )
393+ lower = jnp .array ([[[1.0 , 0.0 ], [0.0 , 1.0 ]]], dtype = jnp .float64 )
394+ bt = tridiagonal .BlockTriDiagonal (lower = lower , diagonal = diag , upper = upper )
395+
396+ x_true = jnp .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = jnp .float64 )
397+ rhs = bt .matvec (x_true )
398+
399+ x_solved = tridiagonal .thomas_solve (bt , rhs )
400+
401+ np .testing .assert_allclose (x_solved , x_true , atol = 1e-12 )
402+
403+ def test_thomas_identity_blocks (self ):
404+ """Solving with block-identity should return the RHS itself."""
405+ num_blocks = 4
406+ block_size = 2
407+ bt = tridiagonal .BlockTriDiagonal (
408+ lower = jnp .zeros (
409+ (num_blocks - 1 , block_size , block_size ), dtype = jnp .float64
410+ ),
411+ diagonal = jnp .tile (
412+ jnp .eye (block_size , dtype = jnp .float64 ), (num_blocks , 1 , 1 )
413+ ),
414+ upper = jnp .zeros (
415+ (num_blocks - 1 , block_size , block_size ), dtype = jnp .float64
416+ ),
417+ )
418+ rng = np .random .RandomState (11 )
419+ rhs = jnp .array (rng .randn (num_blocks , block_size ), dtype = jnp .float64 )
420+
421+ x = tridiagonal .thomas_solve (bt , rhs )
422+
423+ np .testing .assert_allclose (x , rhs , atol = 1e-14 )
424+
425+ def test_thomas_scalar_blocks (self ):
426+ """Thomas algorithm with block_size=1 should match scalar tridiagonal."""
427+ bt = self ._make_nonsingular_block_tridiag (
428+ num_blocks = 6 , block_size = 1 , seed = 7
429+ )
430+ rng = np .random .RandomState (13 )
431+ rhs = jnp .array (rng .randn (6 , 1 ), dtype = jnp .float64 )
432+
433+ x = tridiagonal .thomas_solve (bt , rhs )
434+
435+ np .testing .assert_allclose (bt .matvec (x ), rhs , atol = 1e-12 )
436+
437+ def test_thomas_two_blocks (self ):
438+ """Minimal multi-block case: 2 blocks."""
439+ bt = self ._make_nonsingular_block_tridiag (
440+ num_blocks = 2 , block_size = 2 , seed = 99
441+ )
442+ rng = np .random .RandomState (17 )
443+ x_true = jnp .array (rng .randn (2 , 2 ), dtype = jnp .float64 )
444+ rhs = bt .matvec (x_true )
445+
446+ x_solved = tridiagonal .thomas_solve (bt , rhs )
447+
448+ np .testing .assert_allclose (x_solved , x_true , atol = 1e-10 )
449+
450+ def test_thomas_large_system (self ):
451+ """Thomas should handle larger systems accurately."""
452+ bt = self ._make_nonsingular_block_tridiag (
453+ num_blocks = 50 , block_size = 4 , seed = 22
454+ )
455+ rng = np .random .RandomState (33 )
456+ x_true = jnp .array (rng .randn (50 , 4 ), dtype = jnp .float64 )
457+ rhs = bt .matvec (x_true )
458+
459+ x_solved = tridiagonal .thomas_solve (bt , rhs )
460+
461+ np .testing .assert_allclose (x_solved , x_true , atol = 1e-8 )
462+
463+ def test_solver_type_dispatch_thomas (self ):
464+ """solve() with SolverType.THOMAS should use thomas_solve."""
465+ bt = self ._make_nonsingular_block_tridiag (num_blocks = 3 , block_size = 2 )
466+ rng = np .random .RandomState (44 )
467+ rhs = jnp .array (rng .randn (3 , 2 ), dtype = jnp .float64 )
468+
469+ x_via_type = bt .solve (rhs , solver_type = tridiagonal .SolverType .THOMAS )
470+ x_direct = tridiagonal .thomas_solve (bt , rhs )
471+
472+ np .testing .assert_allclose (x_via_type , x_direct , atol = 1e-14 )
473+
474+ def test_solver_type_dispatch_dense (self ):
475+ """solve() with SolverType.DENSE should use dense_solve."""
476+ bt = self ._make_nonsingular_block_tridiag (num_blocks = 3 , block_size = 2 )
477+ rng = np .random .RandomState (44 )
478+ rhs = jnp .array (rng .randn (3 , 2 ), dtype = jnp .float64 )
479+
480+ x_via_type = bt .solve (rhs , solver_type = tridiagonal .SolverType .DENSE )
481+ x_direct = tridiagonal .dense_solve (bt , rhs )
482+
483+ np .testing .assert_allclose (x_via_type , x_direct , atol = 1e-14 )
484+
485+ def test_thomas_jit_compatible (self ):
486+ """thomas_solve should work under jax.jit."""
487+ bt = self ._make_nonsingular_block_tridiag (num_blocks = 4 , block_size = 2 )
488+ rng = np .random .RandomState (66 )
489+ x_true = jnp .array (rng .randn (4 , 2 ), dtype = jnp .float64 )
490+ rhs = bt .matvec (x_true )
491+
492+ jitted_solve = jax .jit (tridiagonal .thomas_solve )
493+ x_solved = jitted_solve (bt , rhs )
494+
495+ np .testing .assert_allclose (x_solved , x_true , atol = 1e-10 )
496+
497+ def test_thomas_from_tridiagonals (self ):
498+ """Thomas solve on a block system built from per-channel scalar tridiagonals."""
499+ ch0 = tridiagonal .TriDiagonal (
500+ diagonal = jnp .array ([10.0 , 12.0 , 14.0 ], dtype = jnp .float64 ),
501+ above = jnp .array ([1.0 , 3.0 ], dtype = jnp .float64 ),
502+ below = jnp .array ([5.0 , 7.0 ], dtype = jnp .float64 ),
503+ )
504+ ch1 = tridiagonal .TriDiagonal (
505+ diagonal = jnp .array ([11.0 , 13.0 , 15.0 ], dtype = jnp .float64 ),
506+ above = jnp .array ([2.0 , 4.0 ], dtype = jnp .float64 ),
507+ below = jnp .array ([6.0 , 8.0 ], dtype = jnp .float64 ),
508+ )
509+ bt = tridiagonal .BlockTriDiagonal .from_tridiagonals ([ch0 , ch1 ])
510+ rng = np .random .RandomState (77 )
511+ rhs = jnp .array (rng .randn (3 , 2 ), dtype = jnp .float64 )
512+
513+ x = tridiagonal .thomas_solve (bt , rhs )
514+
515+ np .testing .assert_allclose (bt .matvec (x ), rhs , atol = 1e-12 )
516+
517+
348518if __name__ == '__main__' :
349519 absltest .main ()
0 commit comments