Skip to content

Commit e550ce4

Browse files
theo-brownTorax team
authored andcommitted
Add internal boundary conditions to ProfileConditions
Summary of changes: - Necessitated separating adaptive_source out from internal_boundary_conditions to avoid a circular import - Added an iterhybrid test case with an internal boundary condition - Combines internal boundary conditions from pedestal with ones set by user PiperOrigin-RevId: 868117952
1 parent 51ab828 commit e550ce4

13 files changed

Lines changed: 319 additions & 86 deletions

torax/_src/core_profiles/profile_conditions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import pydantic
2424
from torax._src import array_typing
25+
from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
2526
from torax._src.torax_pydantic import torax_pydantic
2627
from typing_extensions import Self
2728

@@ -66,6 +67,9 @@ class RuntimeParams:
6667
n_e_nbar_is_fGW: bool
6768
n_e_right_bc: array_typing.FloatScalar
6869
n_e_right_bc_is_fGW: bool
70+
internal_boundary_conditions: (
71+
internal_boundary_conditions_lib.InternalBoundaryConditions
72+
)
6973
current_profile_nu: float
7074
initial_j_is_total_current: bool = dataclasses.field(
7175
metadata={'static': True}
@@ -186,6 +190,11 @@ class ProfileConditions(torax_pydantic.BaseModelFrozen):
186190
n_e_nbar_is_fGW: bool = False
187191
n_e_right_bc: torax_pydantic.TimeVaryingScalar | None = None
188192
n_e_right_bc_is_fGW: bool = False
193+
internal_boundary_conditions: (
194+
internal_boundary_conditions_lib.InternalBoundaryConditionsConfig
195+
) = torax_pydantic.ValidatedDefault(
196+
internal_boundary_conditions_lib.InternalBoundaryConditionsConfig()
197+
)
189198
current_profile_nu: float = 1.0
190199
initial_j_is_total_current: Annotated[bool, torax_pydantic.JAX_STATIC] = False
191200
# TODO(b/434175938): Remove this before the V2 API release in place of
@@ -428,9 +437,18 @@ def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams:
428437
else:
429438
runtime_params['n_e_right_bc_is_absolute'] = True
430439

440+
runtime_params['internal_boundary_conditions'] = (
441+
self.internal_boundary_conditions.build_runtime_params(t)
442+
)
443+
431444
def _get_value(x):
432445
if isinstance(
433-
x, (torax_pydantic.TimeVaryingScalar, torax_pydantic.TimeVaryingArray)
446+
x,
447+
(
448+
torax_pydantic.TimeVaryingScalar,
449+
torax_pydantic.TimeVaryingArray,
450+
torax_pydantic.TimeVaryingPoints,
451+
),
434452
):
435453
return x.get_value(t)
436454
else:

torax/_src/core_profiles/tests/profile_conditions_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,32 @@ def test_multiple_validation_errors(self):
521521
with self.assertRaisesRegex(ValueError, '3 errors were found'):
522522
profile_conditions.ProfileConditions(**config_overrides)
523523

524+
def test_internal_boundary_conditions_set_correctly(self):
525+
"""Tests that internal_boundary_conditions is populated from config."""
526+
config = default_configs.get_default_config_dict()
527+
config['profile_conditions'] = {
528+
'internal_boundary_conditions': {
529+
'T_i': {
530+
0.0: {0: 1.0, 1: 2.0},
531+
1.0: {0: 3.0, 1: 4.0},
532+
},
533+
},
534+
}
535+
torax_config = model_config.ToraxConfig.from_dict(config)
536+
runtime_params_provider = (
537+
build_runtime_params.RuntimeParamsProvider.from_config(torax_config)
538+
)
539+
540+
runtime_params = runtime_params_provider(t=0.0)
541+
self.assertIsNotNone(
542+
runtime_params.profile_conditions.internal_boundary_conditions
543+
)
544+
# Basic check to ensure the config was actually used.
545+
np.testing.assert_array_equal(
546+
runtime_params.profile_conditions.internal_boundary_conditions.T_i,
547+
np.array([1.0, 0.0, 0.0, 2.0]),
548+
)
549+
524550

525551
if __name__ == '__main__':
526552
absltest.main()

torax/_src/fvm/calc_coeffs.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Calculates Block1DCoeffs for a time step."""
16+
1617
import functools
1718

1819
import jax
@@ -26,7 +27,7 @@
2627
from torax._src.fvm import block_1d_coeffs
2728
from torax._src.fvm import cell_variable
2829
from torax._src.geometry import geometry
29-
from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
30+
from torax._src.internal_boundary_conditions import adaptive_source
3031
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
3132
from torax._src.sources import source_profile_builders
3233
from torax._src.sources import source_profiles as source_profiles_lib
@@ -298,6 +299,9 @@ def _calc_coeffs_full(
298299
pedestal_model_output = physics_models.pedestal_model(
299300
runtime_params, geo, core_profiles
300301
)
302+
internal_boundary_conditions_from_config = (
303+
runtime_params.profile_conditions.internal_boundary_conditions
304+
)
301305

302306
conductivity = (
303307
physics_models.neoclassical_models.conductivity.calculate_conductivity(
@@ -521,27 +525,31 @@ def _calc_coeffs_full(
521525
* core_profiles.psi.grad()
522526
)
523527

524-
# Add internal boundary condition source terms
528+
# Add internal boundary condition source terms, combining user-specified
529+
# boundary conditions with pedestal model output.
530+
# Note that the pedestal model will overwrite any user-specified boundary
531+
# conditions, since the pedestal model is applied last.
532+
combined_internal_boundary_conditions = (
533+
internal_boundary_conditions_from_config.update(
534+
pedestal_model_output.to_internal_boundary_conditions(geo)
535+
)
536+
)
525537
(
526538
source_i,
527539
source_e,
528540
source_n_e,
529541
source_mat_ii,
530542
source_mat_ee,
531543
source_mat_nn,
532-
) = internal_boundary_conditions_lib.apply_adaptive_source(
544+
) = adaptive_source.apply_adaptive_source(
533545
source_T_i=source_i,
534546
source_T_e=source_e,
535547
source_n_e=source_n_e,
536548
source_mat_ii=source_mat_ii,
537549
source_mat_ee=source_mat_ee,
538550
source_mat_nn=source_mat_nn,
539551
runtime_params=runtime_params,
540-
# Pedestal contributes an internal boundary condition to the source
541-
# terms at the pedestal top.
542-
internal_boundary_conditions=pedestal_model_output.to_internal_boundary_conditions(
543-
geo
544-
),
552+
internal_boundary_conditions=combined_internal_boundary_conditions,
545553
)
546554

547555
# Build arguments to solver based on which variables are evolving
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Adaptive source for internal boundary conditions."""
16+
17+
import jax.numpy as jnp
18+
from torax._src import array_typing
19+
from torax._src.config import runtime_params as runtime_params_lib
20+
from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
21+
22+
23+
def apply_adaptive_source(
24+
*,
25+
source_T_i: array_typing.FloatVectorCell,
26+
source_T_e: array_typing.FloatVectorCell,
27+
source_n_e: array_typing.FloatVectorCell,
28+
source_mat_ii: array_typing.FloatVectorCell,
29+
source_mat_ee: array_typing.FloatVectorCell,
30+
source_mat_nn: array_typing.FloatVectorCell,
31+
runtime_params: runtime_params_lib.RuntimeParams,
32+
internal_boundary_conditions: internal_boundary_conditions_lib.InternalBoundaryConditions,
33+
) -> tuple[
34+
array_typing.FloatVectorCell,
35+
array_typing.FloatVectorCell,
36+
array_typing.FloatVectorCell,
37+
array_typing.FloatVectorCell,
38+
array_typing.FloatVectorCell,
39+
array_typing.FloatVectorCell,
40+
]:
41+
"""Applies an adaptive source to the source profiles to set internal boundary conditions."""
42+
43+
# Ion temperature
44+
source_T_i += (
45+
runtime_params.numerics.adaptive_T_source_prefactor
46+
* internal_boundary_conditions.T_i
47+
)
48+
source_mat_ii -= jnp.where(
49+
internal_boundary_conditions.T_i != 0.0,
50+
runtime_params.numerics.adaptive_T_source_prefactor,
51+
0.0,
52+
)
53+
54+
# Electron temperature
55+
source_T_e += (
56+
runtime_params.numerics.adaptive_T_source_prefactor
57+
* internal_boundary_conditions.T_e
58+
)
59+
source_mat_ee -= jnp.where(
60+
internal_boundary_conditions.T_e != 0.0,
61+
runtime_params.numerics.adaptive_T_source_prefactor,
62+
0.0,
63+
)
64+
65+
# Density
66+
source_n_e += (
67+
runtime_params.numerics.adaptive_n_source_prefactor
68+
* internal_boundary_conditions.n_e
69+
)
70+
source_mat_nn -= jnp.where(
71+
internal_boundary_conditions.n_e != 0.0,
72+
runtime_params.numerics.adaptive_n_source_prefactor,
73+
0.0,
74+
)
75+
76+
return (
77+
source_T_i,
78+
source_T_e,
79+
source_n_e,
80+
source_mat_ii,
81+
source_mat_ee,
82+
source_mat_nn,
83+
)

torax/_src/internal_boundary_conditions/internal_boundary_conditions.py

Lines changed: 24 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""Internal boundary conditions."""
1616

1717
import dataclasses
18+
from typing import Annotated
1819

20+
import chex
1921
import jax
2022
import jax.numpy as jnp
2123
from torax._src import array_typing
2224
from torax._src import jax_utils
23-
from torax._src.config import runtime_params as runtime_params_lib
2425
from torax._src.geometry import geometry
25-
26+
from torax._src.torax_pydantic import torax_pydantic
2627
# pylint: disable=invalid-name
2728

2829

@@ -72,64 +73,24 @@ def empty(cls, geo: geometry.Geometry) -> 'InternalBoundaryConditions':
7273
)
7374

7475

75-
def apply_adaptive_source(
76-
*,
77-
source_T_i: array_typing.FloatVectorCell,
78-
source_T_e: array_typing.FloatVectorCell,
79-
source_n_e: array_typing.FloatVectorCell,
80-
source_mat_ii: array_typing.FloatVectorCell,
81-
source_mat_ee: array_typing.FloatVectorCell,
82-
source_mat_nn: array_typing.FloatVectorCell,
83-
runtime_params: runtime_params_lib.RuntimeParams,
84-
internal_boundary_conditions: InternalBoundaryConditions,
85-
) -> tuple[
86-
array_typing.FloatVectorCell,
87-
array_typing.FloatVectorCell,
88-
array_typing.FloatVectorCell,
89-
array_typing.FloatVectorCell,
90-
array_typing.FloatVectorCell,
91-
array_typing.FloatVectorCell,
92-
]:
93-
"""Applies an adaptive source to the source profiles to set internal boundary conditions."""
94-
95-
# Ion temperature
96-
source_T_i += (
97-
runtime_params.numerics.adaptive_T_source_prefactor
98-
* internal_boundary_conditions.T_i
99-
)
100-
source_mat_ii -= jnp.where(
101-
internal_boundary_conditions.T_i != 0.0,
102-
runtime_params.numerics.adaptive_T_source_prefactor,
103-
0.0,
104-
)
105-
106-
# Electron temperature
107-
source_T_e += (
108-
runtime_params.numerics.adaptive_T_source_prefactor
109-
* internal_boundary_conditions.T_e
110-
)
111-
source_mat_ee -= jnp.where(
112-
internal_boundary_conditions.T_e != 0.0,
113-
runtime_params.numerics.adaptive_T_source_prefactor,
114-
0.0,
115-
)
116-
117-
# Density
118-
source_n_e += (
119-
runtime_params.numerics.adaptive_n_source_prefactor
120-
* internal_boundary_conditions.n_e
121-
)
122-
source_mat_nn -= jnp.where(
123-
internal_boundary_conditions.n_e != 0.0,
124-
runtime_params.numerics.adaptive_n_source_prefactor,
125-
0.0,
126-
)
127-
128-
return (
129-
source_T_i,
130-
source_T_e,
131-
source_n_e,
132-
source_mat_ii,
133-
source_mat_ee,
134-
source_mat_nn,
135-
)
76+
class InternalBoundaryConditionsConfig(torax_pydantic.BaseModelFrozen):
77+
"""Pydantic model for internal boundary conditions."""
78+
79+
# Set to zero by default, which is ignored by the adaptive source.
80+
T_i: Annotated[
81+
torax_pydantic.TimeVaryingPoints, torax_pydantic.JAX_STATIC
82+
] = torax_pydantic.ValidatedDefault(0.0)
83+
T_e: Annotated[
84+
torax_pydantic.TimeVaryingPoints, torax_pydantic.JAX_STATIC
85+
] = torax_pydantic.ValidatedDefault(0.0)
86+
n_e: Annotated[
87+
torax_pydantic.TimeVaryingPoints, torax_pydantic.JAX_STATIC
88+
] = torax_pydantic.ValidatedDefault(0.0)
89+
90+
def build_runtime_params(self, t: chex.Numeric) -> InternalBoundaryConditions:
91+
"""Builds the runtime params for the internal boundary conditions."""
92+
kwargs = {
93+
field.name: getattr(self, field.name).get_value(t)
94+
for field in dataclasses.fields(InternalBoundaryConditions)
95+
}
96+
return InternalBoundaryConditions(**kwargs)

torax/_src/internal_boundary_conditions/tests/internal_boundary_conditions_test.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,50 @@
1313
# limitations under the License.
1414

1515
from absl.testing import absltest
16+
from absl.testing import parameterized
1617
import jax.numpy as jnp
1718
import numpy as np
19+
from torax._src.geometry import circular_geometry
20+
from torax._src.internal_boundary_conditions import adaptive_source
1821
from torax._src.internal_boundary_conditions import internal_boundary_conditions
22+
from torax._src.torax_pydantic import torax_pydantic
1923

2024
# pylint: disable=invalid-name
2125

2226

23-
class InternalBoundaryConditionsTest(absltest.TestCase):
27+
class InternalBoundaryConditionsTest(parameterized.TestCase):
28+
29+
@parameterized.named_parameters(
30+
('initial_time', 0.0, [1.0, 0.0, 0.0, 2.0]),
31+
('intermediate_time', 0.5, [2.0, 0.0, 0.0, 3.0]),
32+
('final_time', 1.0, [3.0, 0.0, 0.0, 4.0]),
33+
('after_final_time', 1.5, [3.0, 0.0, 0.0, 4.0]),
34+
)
35+
def test_internal_boundary_conditions_config_build_runtime_params(
36+
self, t, expected_T_i
37+
):
38+
ibc_config = internal_boundary_conditions.InternalBoundaryConditionsConfig(
39+
T_i={
40+
0.0: {0: 1.0, 1: 2.0},
41+
1.0: {0: 3.0, 1: 4.0},
42+
},
43+
)
44+
geo = circular_geometry.CircularConfig(n_rho=4).build_geometry()
45+
torax_pydantic.set_grid(ibc_config, geo.torax_mesh)
46+
47+
runtime_params = ibc_config.build_runtime_params(t=t)
48+
np.testing.assert_array_equal(
49+
runtime_params.T_i,
50+
np.array(expected_T_i),
51+
)
52+
np.testing.assert_array_equal(
53+
runtime_params.T_e,
54+
np.array([0.0, 0.0, 0.0, 0.0]),
55+
)
56+
np.testing.assert_array_equal(
57+
runtime_params.n_e,
58+
np.array([0.0, 0.0, 0.0, 0.0]),
59+
)
2460

2561
def test_update(self):
2662
ibc1 = internal_boundary_conditions.InternalBoundaryConditions(
@@ -69,7 +105,7 @@ class MockRuntimeParams:
69105
source_mat_ii,
70106
source_mat_ee,
71107
source_mat_nn,
72-
) = internal_boundary_conditions.apply_adaptive_source(
108+
) = adaptive_source.apply_adaptive_source(
73109
source_T_i=source_T_i,
74110
source_T_e=source_T_e,
75111
source_n_e=source_n_e,

0 commit comments

Comments
 (0)