Skip to content

Commit 4c8f84a

Browse files
theo-brownTorax team
authored andcommitted
Add EPEDNNmit pedestal model.
Uses a surrogate trained on EPED simulations of SPARC(https://github.com/aaronkho/epednn_mit) to predict the pedestal pressure and width. The implementation of this model in TORAX is experimental and undergoing validation. PiperOrigin-RevId: 870795084
1 parent 1f94c62 commit 4c8f84a

3 files changed

Lines changed: 295 additions & 1 deletion

File tree

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""EPEDNN-mit pedestal model.
15+
16+
This model is only valid for the SPARC parameter space, as specified in
17+
https://github.com/aaronkho/epednn_mit/tree/main/src/epednn_mit/models/sparc.
18+
19+
Please cite [M. Muraca et al. 2025 Nucl. Fusion 65
20+
096010](https://doi.org/10.1088/1741-4326/adf656) in any works using this model.
21+
"""
22+
23+
import dataclasses
24+
import functools
25+
import pathlib
26+
from typing import Any, Final, TypeAlias
27+
from epednn_mit.models.sparc import jax_model as epednn_mit_jax_model
28+
import jax
29+
from jax import numpy as jnp
30+
from torax._src import array_typing
31+
from torax._src import math_utils
32+
from torax._src import state
33+
from torax._src.config import runtime_params as runtime_params_lib
34+
from torax._src.geometry import geometry
35+
from torax._src.pedestal_model import pedestal_model
36+
from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib
37+
from torax._src.pedestal_model import set_pped_tpedratio_nped
38+
from torax._src.physics import formulas
39+
from typing_extensions import override
40+
41+
EPEDNNmitStats: TypeAlias = dict[str, jax.Array]
42+
EPEDNNmitParams: TypeAlias = dict[str, Any]
43+
44+
_INPUT_BOUNDS: Final[dict[str, tuple[float, float]]] = {
45+
"Ip": (1.6, 14.3),
46+
"Bt": (7.2, 12.2),
47+
"R": (1.85, 1.85),
48+
"a": (0.57, 0.57),
49+
"kappa": (1.53, 2.29),
50+
"delta": (0.39, 0.59),
51+
"neped": (2.84, 90.235),
52+
"betan": (0.8, 1.6),
53+
"zeff": (1.3, 2.5),
54+
}
55+
56+
57+
def _check_input_bounds(
58+
epednn_mit_inputs: jax.Array,
59+
) -> None:
60+
"""Checks that the EPEDNN-mit inputs are within the bounds."""
61+
for i, (key, (lower, upper)) in enumerate(_INPUT_BOUNDS.items()):
62+
if not (lower <= epednn_mit_inputs[i] <= upper):
63+
raise ValueError(
64+
f"EPEDNN-mit input {key} is out of bounds of the training"
65+
f" distribution. Value is {epednn_mit_inputs[i]}, but"
66+
f" bounds are [{lower}, {upper}]."
67+
)
68+
69+
70+
# pylint: disable=invalid-name
71+
@jax.tree_util.register_dataclass
72+
@dataclasses.dataclass(frozen=True)
73+
class RuntimeParams(pedestal_runtime_params_lib.RuntimeParams):
74+
"""Runtime params for the EPEDNNmitPedestalModel."""
75+
76+
n_e_ped: array_typing.FloatScalar
77+
T_i_T_e_ratio: array_typing.FloatScalar
78+
n_e_ped_is_fGW: array_typing.BoolScalar
79+
80+
81+
@dataclasses.dataclass(frozen=True, eq=False)
82+
class EPEDNNmitPedestalModel(
83+
set_pped_tpedratio_nped.SetPressureTemperatureRatioAndDensityPedestalModel
84+
):
85+
"""Pedestal model using EPEDNN-mit to predict pedestal pressure and width."""
86+
87+
def _prepare_epednn_mit_inputs(
88+
self,
89+
runtime_params: runtime_params_lib.RuntimeParams,
90+
geo: geometry.Geometry,
91+
core_profiles: state.CoreProfiles,
92+
) -> jax.Array:
93+
"""Prepares the inputs for EPEDNN-mit."""
94+
assert isinstance(runtime_params.pedestal, RuntimeParams)
95+
96+
_, _, beta_N = formulas.calculate_betas(core_profiles, geo)
97+
98+
# TODO(b/323504363): We really want the Z_eff at the pedestal top;
99+
# however, the location of the pedestal top is an *output* of the model.
100+
# Currently, we instead compute a density-weighted volume average of Z_eff
101+
# over the entire domain.
102+
Z_eff_average = math_utils.volume_integration(
103+
core_profiles.Z_eff * core_profiles.n_e.value, geo
104+
) / math_utils.volume_integration(core_profiles.n_e.value, geo)
105+
106+
inputs = jnp.array([
107+
core_profiles.Ip_profile_face[-1] * 1e-6, # [MA]
108+
geo.B_0, # [T]
109+
geo.R_major, # [m]
110+
geo.a_minor, # [m]
111+
geo.elongation_face[-1], # []
112+
geo.delta_face[-1], # []
113+
runtime_params.pedestal.n_e_ped * 1e-19, # [10^19 m^-3]
114+
beta_N, # [%]
115+
Z_eff_average, # [C]
116+
])
117+
_check_input_bounds(inputs)
118+
return inputs
119+
120+
@functools.cached_property
121+
def _get_model(
122+
self,
123+
) -> tuple[
124+
EPEDNNmitStats,
125+
EPEDNNmitParams,
126+
epednn_mit_jax_model.EPEDNNmitEnsemble,
127+
]:
128+
"""Returns the EPEDNN-mit model and parameters."""
129+
model_dir = pathlib.Path(epednn_mit_jax_model.__file__).parent
130+
model_weights = sorted(model_dir.glob("epednn_mit_sparc_*.pkl"))
131+
stats, params = epednn_mit_jax_model.load_ensemble_params_from_pickle(
132+
model_weights
133+
)
134+
model = epednn_mit_jax_model.EPEDNNmitEnsemble()
135+
return stats, params, model
136+
137+
@override
138+
def _call_implementation(
139+
self,
140+
runtime_params: runtime_params_lib.RuntimeParams,
141+
geo: geometry.Geometry,
142+
core_profiles: state.CoreProfiles,
143+
) -> pedestal_model.PedestalModelOutput:
144+
assert isinstance(runtime_params.pedestal, RuntimeParams)
145+
146+
# Get P_ped and rho_norm_ped_top from EPEDNN-mit.
147+
stats, params, model = self._get_model()
148+
epednn_mit_inputs = self._prepare_epednn_mit_inputs(
149+
runtime_params, geo, core_profiles
150+
)
151+
P_ped_kPa, pedestal_width_psi_norm = model.apply(
152+
params, epednn_mit_inputs, **stats
153+
)
154+
155+
# Convert pedestal width to rho_norm
156+
psi_norm = (core_profiles.psi.value - core_profiles.psi.value[0]) / (
157+
core_profiles.psi.value[-1] - core_profiles.psi.value[0]
158+
)
159+
psi_norm_ped_top = 1.0 - pedestal_width_psi_norm
160+
rho_norm_ped_top = jnp.interp(psi_norm_ped_top, psi_norm, geo.rho_norm)
161+
162+
# Convert P_ped from kPa to Pa.
163+
P_ped = P_ped_kPa * 1e3
164+
165+
# Use the set_pped_tpedratio_nped model to calculate the pedestal profiles.
166+
super_runtime_params = set_pped_tpedratio_nped.RuntimeParams(
167+
set_pedestal=runtime_params.pedestal.set_pedestal,
168+
P_ped=P_ped,
169+
n_e_ped=runtime_params.pedestal.n_e_ped,
170+
T_i_T_e_ratio=runtime_params.pedestal.T_i_T_e_ratio,
171+
rho_norm_ped_top=rho_norm_ped_top,
172+
n_e_ped_is_fGW=runtime_params.pedestal.n_e_ped_is_fGW,
173+
)
174+
modified_runtime_params = dataclasses.replace(
175+
runtime_params, pedestal=super_runtime_params
176+
)
177+
return super()._call_implementation(
178+
modified_runtime_params, geo, core_profiles
179+
)

torax/_src/pedestal_model/pydantic_model.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Annotated, Literal
1818

1919
import chex
20+
from torax._src.pedestal_model import epednn_mit_pedestal_model
2021
from torax._src.pedestal_model import no_pedestal
2122
from torax._src.pedestal_model import pedestal_model
2223
from torax._src.pedestal_model import runtime_params
@@ -100,6 +101,44 @@ def build_runtime_params(
100101
)
101102

102103

104+
class EPEDNNmit(BasePedestal):
105+
"""Uses EPEDNN-mit to predict pedestal pressure and width.
106+
107+
Attributes:
108+
n_e_ped: The electron density at the pedestal [m^-3] or fGW.
109+
n_e_ped_is_fGW: Whether the electron density at the pedestal is in units of
110+
fGW.
111+
T_i_T_e_ratio: Ratio of the ion and electron temperature at the pedestal
112+
[dimensionless].
113+
"""
114+
115+
model_name: Annotated[Literal['epednn_mit'], torax_pydantic.JAX_STATIC] = (
116+
'epednn_mit'
117+
)
118+
n_e_ped: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(
119+
0.7e20
120+
)
121+
n_e_ped_is_fGW: bool = False
122+
T_i_T_e_ratio: torax_pydantic.TimeVaryingScalar = (
123+
torax_pydantic.ValidatedDefault(1.0)
124+
)
125+
126+
def build_pedestal_model(
127+
self,
128+
) -> epednn_mit_pedestal_model.EPEDNNmitPedestalModel:
129+
return epednn_mit_pedestal_model.EPEDNNmitPedestalModel()
130+
131+
def build_runtime_params(
132+
self, t: chex.Numeric
133+
) -> epednn_mit_pedestal_model.RuntimeParams:
134+
return epednn_mit_pedestal_model.RuntimeParams(
135+
set_pedestal=self.set_pedestal.get_value(t),
136+
n_e_ped=self.n_e_ped.get_value(t),
137+
n_e_ped_is_fGW=self.n_e_ped_is_fGW,
138+
T_i_T_e_ratio=self.T_i_T_e_ratio.get_value(t),
139+
)
140+
141+
103142
class SetTpedNped(BasePedestal):
104143
"""A basic version of the pedestal model that uses direct specification.
105144
@@ -171,4 +210,4 @@ def build_runtime_params(
171210
)
172211

173212

174-
PedestalConfig = SetPpedTpedRatioNped | SetTpedNped | NoPedestal
213+
PedestalConfig = SetPpedTpedRatioNped | SetTpedNped | NoPedestal | EPEDNNmit
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
from absl.testing import parameterized
17+
import jax
18+
import numpy as np
19+
from torax._src.config import build_runtime_params
20+
from torax._src.core_profiles import initialization
21+
from torax._src.test_utils import default_configs
22+
from torax._src.torax_pydantic import model_config
23+
24+
# pylint: disable=invalid-name
25+
26+
27+
class EPEDNNmitPedestalModelTest(parameterized.TestCase):
28+
29+
def test_build_and_call_pedestal_model(self):
30+
"""Tests the EPEDNN-mit pedestal model.
31+
32+
Note that the EPEDNN-mit is only valid for SPARC parameter space, but we're
33+
testing here with a generic config. Hence, we don't perform checks on
34+
the values of the model outputs.
35+
"""
36+
config = default_configs.get_default_config_dict()
37+
config['pedestal'] = {
38+
'model_name': 'epednn_mit',
39+
'set_pedestal': True,
40+
'n_e_ped': 0.7e20,
41+
'n_e_ped_is_fGW': False,
42+
'T_i_T_e_ratio': 1.0,
43+
}
44+
torax_config = model_config.ToraxConfig.from_dict(config)
45+
provider = (
46+
build_runtime_params.RuntimeParamsProvider.from_config(
47+
torax_config
48+
)
49+
)
50+
source_models = torax_config.sources.build_models()
51+
neoclassical_models = torax_config.neoclassical.build_models()
52+
pedestal_model = torax_config.pedestal.build_pedestal_model()
53+
jitted_pedestal_model = jax.jit(pedestal_model)
54+
55+
geo = torax_config.geometry.build_provider(0.0)
56+
runtime_params = provider(t=0.0)
57+
core_profiles = initialization.initial_core_profiles(
58+
runtime_params,
59+
geo,
60+
source_models,
61+
neoclassical_models,
62+
)
63+
pedestal_model_output = jitted_pedestal_model(
64+
runtime_params=runtime_params,
65+
geo=geo,
66+
core_profiles=core_profiles,
67+
)
68+
69+
np.testing.assert_allclose(pedestal_model_output.n_e_ped, 0.7e20)
70+
np.testing.assert_allclose(
71+
pedestal_model_output.T_i_ped / pedestal_model_output.T_e_ped, 1.0
72+
)
73+
74+
75+
if __name__ == '__main__':
76+
absltest.main()

0 commit comments

Comments
 (0)