2323newton_raphson_solve_block can capture nonlinear dynamics even when
2424each step is expressed using a matrix multiply.
2525"""
26+
2627from typing import TypeAlias
2728
2829import jax
2930from jax import numpy as jnp
31+ from torax ._src import tridiagonal
3032from torax ._src .fvm import block_1d_coeffs
3133from torax ._src .fvm import cell_variable
3234from torax ._src .fvm import convection_terms
@@ -41,11 +43,11 @@ def calc_c(
4143 coeffs : Block1DCoeffs ,
4244 convection_dirichlet_mode : str = 'ghost' ,
4345 convection_neumann_mode : str = 'ghost' ,
44- ) -> tuple [jax . Array , jax .Array ]:
45- """Calculate C and c such that F = C x + c.
46+ ) -> tuple [tridiagonal . BlockTriDiagonal , jax .Array ]:
47+ """Calculate banded blocks and vector c such that F = C x + c.
4648
47- See docstrings for `Block1DCoeff` and `implicit_solve_block` for
48- more detail .
49+ Returns the block-tridiagonal representation of C. The matrix structure comes
50+ from the 1D FVM stencil: each cell couples to itself and its two neighbors .
4951
5052 Args:
5153 x: Tuple containing CellVariables for each channel. This function uses only
@@ -57,8 +59,10 @@ def calc_c(
5759 `neumann_mode` argument.
5860
5961 Returns:
60- c_mat: matrix C, such that F = C x + c
61- c: the vector c
62+ A tuple of (c_matrix, c_forcing) where:
63+ c_matrix: BlockTriDiagonal with sub/main/super-diagonal blocks.
64+ c_forcing: An array with the terms arising from explicit sources and
65+ boundary conditions.
6266 """
6367
6468 d_face = coeffs .d_face
@@ -75,72 +79,63 @@ def calc_c(
7579 f'but got { x_i .value .shape } .'
7680 )
7781
78- zero_block = jnp .zeros ((num_cells , num_cells ))
79- zero_row_of_blocks = [zero_block ] * num_channels
80- zero_vec = jnp .zeros ((num_cells ))
81- zero_block_vec = [zero_vec ] * num_channels
82-
83- # Make a matrix C and vector c that will accumulate contributions from
84- # diffusion, convection, and source terms.
85- # C and c are both block structured, with one block per channel.
86- c_mat = [zero_row_of_blocks .copy () for _ in range (num_channels )]
87- c = zero_block_vec .copy ()
88-
8982 # Add diffusion terms
90- if d_face is not None :
91- for i in range (num_channels ):
92- (
93- diffusion_mat ,
94- diffusion_vec ,
95- ) = diffusion_terms .make_diffusion_terms (
96- d_face [i ],
97- x [i ],
98- )
99-
100- c_mat [i ][i ] += diffusion_mat .to_dense ()
101- c [i ] += diffusion_vec
83+ if d_face is None :
84+ c_matrix = tridiagonal .BlockTriDiagonal .zeros (num_cells , num_channels )
85+ c_forcing = jnp .zeros ((num_cells , num_channels ))
86+ else :
87+ d_terms = [
88+ diffusion_terms .make_diffusion_terms (d_face_i , x_i )
89+ for d_face_i , x_i in zip (d_face , x )
90+ ]
91+ # stack the forcing terms along the channel axis (axis=1)
92+ c_forcing = jnp .stack ([c_forcing for _ , c_forcing in d_terms ], axis = 1 )
93+ c_matrix = tridiagonal .BlockTriDiagonal .from_tridiagonals (
94+ [d_mat for d_mat , _ in d_terms ]
95+ )
10296
10397 # Add convection terms
10498 if v_face is not None :
99+ conv_terms = []
105100 for i in range (num_channels ):
106101 # Resolve diffusion to zeros if it is not specified
107102 d_face_i = d_face [i ] if d_face is not None else None
108103 d_face_i = jnp .zeros_like (v_face [i ]) if d_face_i is None else d_face_i
109-
110- (
111- conv_mat ,
112- conv_vec ,
113- ) = convection_terms .make_convection_terms (
104+ conv_mat , conv_forcing = convection_terms .make_convection_terms (
114105 v_face [i ],
115106 d_face_i ,
116107 x [i ],
117108 dirichlet_mode = convection_dirichlet_mode ,
118109 neumann_mode = convection_neumann_mode ,
119110 )
120-
121- c_mat [i ][i ] += conv_mat .to_dense ()
122- c [i ] += conv_vec
111+ conv_terms .append ((conv_mat , conv_forcing ))
112+ # stack the forcing terms along the channel axis (axis=1)
113+ conv_forcing = jnp .stack (
114+ [conv_forcing for _ , conv_forcing in conv_terms ], axis = 1
115+ )
116+ c_matrix += tridiagonal .BlockTriDiagonal .from_tridiagonals (
117+ [conv_mat for conv_mat , _ in conv_terms ]
118+ )
119+ c_forcing += conv_forcing
123120
124121 # Add implicit source terms
125122 if source_mat_cell is not None :
123+ diag = c_matrix .diagonal
126124 for i in range (num_channels ):
127125 for j in range (num_channels ):
128126 source = source_mat_cell [i ][j ]
129127 if source is not None :
130- c_mat [i ][j ] += jnp .diag (source )
128+ diag = diag .at [:, i , j ].add (source )
129+ c_matrix = tridiagonal .BlockTriDiagonal (
130+ lower = c_matrix .lower ,
131+ diagonal = diag ,
132+ upper = c_matrix .upper ,
133+ )
131134
132135 # Add explicit source terms
133- def add (left : jax .Array , right : jax .Array | None ):
134- """Addition with adding None treated as no-op."""
135- if right is not None :
136- return left + right
137- return left
138-
139136 if source_cell is not None :
140- c = [add (c_i , source_i ) for c_i , source_i in zip (c , source_cell )]
141-
142- # Form block structure
143- c_mat = jnp .block (c_mat )
144- c = jnp .block (c )
137+ for i in range (num_channels ):
138+ if source_cell [i ] is not None :
139+ c_forcing = c_forcing .at [:, i ].add (source_cell [i ])
145140
146- return c_mat , c
141+ return c_matrix , c_forcing
0 commit comments