Skip to content

Commit 2b5d591

Browse files
joeljenningsTorax team
authored andcommitted
Add block tri diagonal class to Torax to represent the discrete system
PiperOrigin-RevId: 903931872
1 parent b240282 commit 2b5d591

6 files changed

Lines changed: 601 additions & 105 deletions

File tree

torax/_src/fvm/discrete_system.py

Lines changed: 45 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
newton_raphson_solve_block can capture nonlinear dynamics even when
2424
each step is expressed using a matrix multiply.
2525
"""
26+
2627
from typing import TypeAlias
2728

2829
import jax
2930
from jax import numpy as jnp
31+
from torax._src import tridiagonal
3032
from torax._src.fvm import block_1d_coeffs
3133
from torax._src.fvm import cell_variable
3234
from torax._src.fvm import convection_terms
@@ -41,11 +43,11 @@ def calc_c(
4143
coeffs: Block1DCoeffs,
4244
convection_dirichlet_mode: str = 'ghost',
4345
convection_neumann_mode: str = 'ghost',
44-
) -> tuple[jax.Array, jax.Array]:
45-
"""Calculate C and c such that F = C x + c.
46+
) -> tuple[tridiagonal.BlockTriDiagonal, jax.Array]:
47+
"""Calculate banded blocks and vector c such that F = C x + c.
4648
47-
See docstrings for `Block1DCoeff` and `implicit_solve_block` for
48-
more detail.
49+
Returns the block-tridiagonal representation of C. The matrix structure comes
50+
from the 1D FVM stencil: each cell couples to itself and its two neighbors.
4951
5052
Args:
5153
x: Tuple containing CellVariables for each channel. This function uses only
@@ -57,8 +59,10 @@ def calc_c(
5759
`neumann_mode` argument.
5860
5961
Returns:
60-
c_mat: matrix C, such that F = C x + c
61-
c: the vector c
62+
A tuple of (c_matrix, c_forcing) where:
63+
c_matrix: BlockTriDiagonal with sub/main/super-diagonal blocks.
64+
c_forcing: An array with the terms arising from explicit sources and
65+
boundary conditions.
6266
"""
6367

6468
d_face = coeffs.d_face
@@ -75,72 +79,63 @@ def calc_c(
7579
f'but got {x_i.value.shape}.'
7680
)
7781

78-
zero_block = jnp.zeros((num_cells, num_cells))
79-
zero_row_of_blocks = [zero_block] * num_channels
80-
zero_vec = jnp.zeros((num_cells))
81-
zero_block_vec = [zero_vec] * num_channels
82-
83-
# Make a matrix C and vector c that will accumulate contributions from
84-
# diffusion, convection, and source terms.
85-
# C and c are both block structured, with one block per channel.
86-
c_mat = [zero_row_of_blocks.copy() for _ in range(num_channels)]
87-
c = zero_block_vec.copy()
88-
8982
# Add diffusion terms
90-
if d_face is not None:
91-
for i in range(num_channels):
92-
(
93-
diffusion_mat,
94-
diffusion_vec,
95-
) = diffusion_terms.make_diffusion_terms(
96-
d_face[i],
97-
x[i],
98-
)
99-
100-
c_mat[i][i] += diffusion_mat.to_dense()
101-
c[i] += diffusion_vec
83+
if d_face is None:
84+
c_matrix = tridiagonal.BlockTriDiagonal.zeros(num_cells, num_channels)
85+
c_forcing = jnp.zeros((num_cells, num_channels))
86+
else:
87+
d_terms = [
88+
diffusion_terms.make_diffusion_terms(d_face_i, x_i)
89+
for d_face_i, x_i in zip(d_face, x)
90+
]
91+
# stack the forcing terms along the channel axis (axis=1)
92+
c_forcing = jnp.stack([c_forcing for _, c_forcing in d_terms], axis=1)
93+
c_matrix = tridiagonal.BlockTriDiagonal.from_tridiagonals(
94+
[d_mat for d_mat, _ in d_terms]
95+
)
10296

10397
# Add convection terms
10498
if v_face is not None:
99+
conv_terms = []
105100
for i in range(num_channels):
106101
# Resolve diffusion to zeros if it is not specified
107102
d_face_i = d_face[i] if d_face is not None else None
108103
d_face_i = jnp.zeros_like(v_face[i]) if d_face_i is None else d_face_i
109-
110-
(
111-
conv_mat,
112-
conv_vec,
113-
) = convection_terms.make_convection_terms(
104+
conv_mat, conv_forcing = convection_terms.make_convection_terms(
114105
v_face[i],
115106
d_face_i,
116107
x[i],
117108
dirichlet_mode=convection_dirichlet_mode,
118109
neumann_mode=convection_neumann_mode,
119110
)
120-
121-
c_mat[i][i] += conv_mat.to_dense()
122-
c[i] += conv_vec
111+
conv_terms.append((conv_mat, conv_forcing))
112+
# stack the forcing terms along the channel axis (axis=1)
113+
conv_forcing = jnp.stack(
114+
[conv_forcing for _, conv_forcing in conv_terms], axis=1
115+
)
116+
c_matrix += tridiagonal.BlockTriDiagonal.from_tridiagonals(
117+
[conv_mat for conv_mat, _ in conv_terms]
118+
)
119+
c_forcing += conv_forcing
123120

124121
# Add implicit source terms
125122
if source_mat_cell is not None:
123+
diag = c_matrix.diagonal
126124
for i in range(num_channels):
127125
for j in range(num_channels):
128126
source = source_mat_cell[i][j]
129127
if source is not None:
130-
c_mat[i][j] += jnp.diag(source)
128+
diag = diag.at[:, i, j].add(source)
129+
c_matrix = tridiagonal.BlockTriDiagonal(
130+
lower=c_matrix.lower,
131+
diagonal=diag,
132+
upper=c_matrix.upper,
133+
)
131134

132135
# Add explicit source terms
133-
def add(left: jax.Array, right: jax.Array | None):
134-
"""Addition with adding None treated as no-op."""
135-
if right is not None:
136-
return left + right
137-
return left
138-
139136
if source_cell is not None:
140-
c = [add(c_i, source_i) for c_i, source_i in zip(c, source_cell)]
141-
142-
# Form block structure
143-
c_mat = jnp.block(c_mat)
144-
c = jnp.block(c)
137+
for i in range(num_channels):
138+
if source_cell[i] is not None:
139+
c_forcing = c_forcing.at[:, i].add(source_cell[i])
145140

146-
return c_mat, c
141+
return c_matrix, c_forcing

torax/_src/fvm/fvm_conversions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def cell_variable_tuple_to_vec(
3333
Returns:
3434
A flat array of evolving state variables.
3535
"""
36-
x_vec = jnp.concatenate([x.value for x in x_tuple])
37-
return x_vec
36+
return jnp.concatenate([x.value for x in x_tuple])
3837

3938

4039
def vec_to_cell_variable_tuple(
@@ -77,3 +76,20 @@ def vec_to_cell_variable_tuple(
7776
]
7877

7978
return tuple(x_out)
79+
80+
81+
def cell_variable_tuple_to_array(
82+
x_tuple: tuple[cell_variable.CellVariable, ...],
83+
axis: int,
84+
) -> jax.Array:
85+
"""Converts a tuple of CellVariables to a multi-dimensional array.
86+
87+
88+
Args:
89+
x_tuple: A tuple of CellVariables.
90+
axis: The axis along which to stack the CellVariables.
91+
92+
Returns:
93+
A multi-dimensional array of CellVariables.
94+
"""
95+
return jnp.stack([var.value for var in x_tuple], axis=axis)

torax/_src/fvm/implicit_solve_block.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import dataclasses
2020

2121
import jax
22-
from jax import numpy as jnp
2322
from torax._src.fvm import block_1d_coeffs
2423
from torax._src.fvm import cell_variable
2524
from torax._src.fvm import fvm_conversions
@@ -79,10 +78,9 @@ def implicit_solve_block(
7978
# or from Picard iterations with predictor-corrector.
8079
# See residual_and_loss.theta_method_matrix_equation for a complete
8180
# description of how the equation is set up.
81+
x_old_array = fvm_conversions.cell_variable_tuple_to_array(x_old, axis=1)
8282

83-
x_old_vec = fvm_conversions.cell_variable_tuple_to_vec(x_old)
84-
85-
lhs_mat, lhs_vec, rhs_mat, rhs_vec = (
83+
lhs_matrix, lhs_vec, rhs_matrix, rhs_vec = (
8684
residual_and_loss.theta_method_matrix_equation(
8785
dt=dt,
8886
x_old=x_old,
@@ -95,16 +93,12 @@ def implicit_solve_block(
9593
)
9694
)
9795

98-
rhs = jnp.dot(rhs_mat, x_old_vec) + rhs_vec - lhs_vec
99-
x_new = jnp.linalg.solve(lhs_mat, rhs)
96+
rhs_result = rhs_matrix.matvec(x_old_array) + rhs_vec - lhs_vec
97+
x_new = lhs_matrix.solve(rhs_result)
10098

10199
# Create updated CellVariable instances based on state_plus_dt which has
102100
# updated boundary conditions and prescribed profiles.
103-
x_new = jnp.split(x_new, len(x_old))
104-
out = [
105-
dataclasses.replace(var, value=value)
106-
for var, value in zip(x_new_guess, x_new)
107-
]
108-
out = tuple(out)
109-
110-
return out
101+
return tuple(
102+
dataclasses.replace(var, value=x_new[:, i])
103+
for i, var in enumerate(x_new_guess)
104+
)

0 commit comments

Comments
 (0)