Skip to content

Commit 6b80802

Browse files
jcitrinTorax team
authored andcommitted
Persist rho_norm_ped_top in PedestalTransitionState.
This change makes rho_norm_ped_top part of the PedestalTransitionState, allowing models that dynamically compute the pedestal top location (like EPEDNN) to propagate this value between timesteps. The rho_norm_ped_top from the pedestal model output at t+dt is now stored in the transition state for use in the next timestep's pre_step processing, ensuring accurate L-mode baseline capture. PiperOrigin-RevId: 890589780
1 parent 5d30d36 commit 6b80802

17 files changed

Lines changed: 641 additions & 37 deletions

torax/_src/fvm/calc_coeffs.py

Lines changed: 171 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Calculates Block1DCoeffs for a time step."""
1616

17+
import dataclasses
1718
import functools
1819
import jax
1920
import jax.numpy as jnp
@@ -28,6 +29,7 @@
2829
from torax._src.fvm import cell_variable
2930
from torax._src.geometry import geometry
3031
from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
32+
from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib
3133
from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib
3234
from torax._src.sources import source_profile_builders
3335
from torax._src.sources import source_profiles as source_profiles_lib
@@ -43,9 +45,13 @@ def __init__(
4345
self,
4446
physics_models: physics_models_lib.PhysicsModels,
4547
evolving_names: tuple[str, ...],
48+
pedestal_transition_state: (
49+
pedestal_transition_state_lib.PedestalTransitionState | None
50+
) = None,
4651
):
4752
self.physics_models = physics_models
4853
self.evolving_names = evolving_names
54+
self.pedestal_transition_state = pedestal_transition_state
4955

5056
def __hash__(self) -> int:
5157
return hash((
@@ -84,8 +90,8 @@ def __call__(
8490
state x.
8591
geo: The geometry of the system at this time step.
8692
core_profiles: The core profiles of the system at this time step.
87-
prev_core_profiles: The core profiles of the system at the previous
88-
time step.
93+
prev_core_profiles: The core profiles of the system at the previous time
94+
step.
8995
dt: The time step size.
9096
x: The state with cell-grid values of the evolving variables.
9197
explicit_source_profiles: Precomputed explicit source profiles. These
@@ -133,6 +139,7 @@ def __call__(
133139
evolving_names=self.evolving_names,
134140
use_pereverzev=use_pereverzev,
135141
explicit_call=explicit_call,
142+
pedestal_transition_state=self.pedestal_transition_state,
136143
)
137144

138145

@@ -145,6 +152,9 @@ def calc_coeffs(
145152
evolving_names: tuple[str, ...],
146153
use_pereverzev: bool = False,
147154
explicit_call: bool = False,
155+
pedestal_transition_state: (
156+
pedestal_transition_state_lib.PedestalTransitionState | None
157+
) = None,
148158
) -> block_1d_coeffs.Block1DCoeffs:
149159
"""Calculates Block1DCoeffs for the time step described by `core_profiles`.
150160
@@ -170,6 +180,9 @@ def calc_coeffs(
170180
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
171181
theta_implicit=1. This saves computation for the default fully implicit
172182
implementation.
183+
pedestal_transition_state: State for tracking pedestal L-H and H-L
184+
transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with
185+
use_formation_model_with_adaptive_source=True. None otherwise.
173186
174187
Returns:
175188
coeffs: Block1DCoeffs containing the coefficients at this time step.
@@ -192,6 +205,7 @@ def calc_coeffs(
192205
physics_models=physics_models,
193206
evolving_names=evolving_names,
194207
use_pereverzev=use_pereverzev,
208+
pedestal_transition_state=pedestal_transition_state,
195209
)
196210

197211

@@ -210,6 +224,9 @@ def _calc_coeffs_full(
210224
physics_models: physics_models_lib.PhysicsModels,
211225
evolving_names: tuple[str, ...],
212226
use_pereverzev: bool = False,
227+
pedestal_transition_state: (
228+
pedestal_transition_state_lib.PedestalTransitionState | None
229+
) = None,
213230
) -> block_1d_coeffs.Block1DCoeffs:
214231
"""See `calc_coeffs` for details."""
215232

@@ -415,26 +432,75 @@ def _calc_coeffs_full(
415432
runtime_params.pedestal.mode
416433
== pedestal_runtime_params_lib.Mode.ADAPTIVE_SOURCE
417434
):
435+
# Get the pedestal-top target values from the pedestal model.
436+
pedestal_top_values = (
437+
pedestal_model_output.to_internal_boundary_conditions(geo)
438+
)
439+
440+
# Apply ramp scaling if use_formation_model_with_adaptive_source is
441+
# enabled.
442+
if runtime_params.pedestal.use_formation_model_with_adaptive_source:
443+
assert pedestal_transition_state is not None, (
444+
'pedestal_transition_state must not be None when'
445+
' use_formation_model_with_adaptive_source is True.'
446+
)
447+
# Scale the pedestal-top values from the pedestal model by the ramp
448+
# fraction. Will be a no-op in H-mode following the transition_time_width.
449+
internal_boundary_conditions = _apply_transition_ramp_scaling(
450+
pedestal_top_values=pedestal_top_values,
451+
pedestal_transition_state=pedestal_transition_state,
452+
runtime_params=runtime_params,
453+
)
454+
# If in L-mode and the H->L ramp has completed (fraction >= 1.0), skip
455+
# the adaptive source entirely to revert to standard L-mode modeling.
456+
# ramp_fraction will be 1.0 if simulation initialized in L-mode and has
457+
# remained in L-mode, since initial transition_start_time is -inf.
458+
ramp_fraction = _compute_ramp_fraction(
459+
pedestal_transition_state=pedestal_transition_state,
460+
transition_time_width=runtime_params.pedestal.transition_time_width,
461+
t=runtime_params.t,
462+
)
463+
# Skip adaptive source if in L-mode and the H->L ramp has completed.
464+
skip_adaptive_source = ~pedestal_transition_state.in_H_mode & (
465+
ramp_fraction >= 1.0
466+
)
467+
else:
468+
internal_boundary_conditions = pedestal_top_values
469+
skip_adaptive_source = jnp.bool_(False)
470+
471+
def _apply_source():
472+
return internal_boundary_conditions_lib.apply_adaptive_source(
473+
source_T_i=source_i,
474+
source_T_e=source_e,
475+
source_n_e=source_n_e,
476+
source_mat_ii=source_mat_ii,
477+
source_mat_ee=source_mat_ee,
478+
source_mat_nn=source_mat_nn,
479+
runtime_params=runtime_params,
480+
internal_boundary_conditions=internal_boundary_conditions,
481+
)
482+
483+
def _skip_source():
484+
return (
485+
source_i,
486+
source_e,
487+
source_n_e,
488+
source_mat_ii,
489+
source_mat_ee,
490+
source_mat_nn,
491+
)
492+
418493
(
419494
source_i,
420495
source_e,
421496
source_n_e,
422497
source_mat_ii,
423498
source_mat_ee,
424499
source_mat_nn,
425-
) = internal_boundary_conditions_lib.apply_adaptive_source(
426-
source_T_i=source_i,
427-
source_T_e=source_e,
428-
source_n_e=source_n_e,
429-
source_mat_ii=source_mat_ii,
430-
source_mat_ee=source_mat_ee,
431-
source_mat_nn=source_mat_nn,
432-
runtime_params=runtime_params,
433-
# Pedestal contributes an internal boundary condition to the source
434-
# terms at the pedestal top.
435-
internal_boundary_conditions=pedestal_model_output.to_internal_boundary_conditions(
436-
geo
437-
),
500+
) = jax.lax.cond(
501+
skip_adaptive_source,
502+
_skip_source,
503+
_apply_source,
438504
)
439505

440506
# --- Build arguments to solver --- #
@@ -539,3 +605,93 @@ def _calc_coeffs_reduced(
539605
transient_in_cell=transient_in_cell,
540606
)
541607
return coeffs
608+
609+
610+
def _compute_ramp_fraction(
611+
pedestal_transition_state: pedestal_transition_state_lib.PedestalTransitionState,
612+
transition_time_width: array_typing.FloatScalar,
613+
t: array_typing.FloatScalar,
614+
) -> array_typing.FloatScalar:
615+
"""Computes the ramp fraction for a pedestal transition.
616+
617+
Returns a value in [0, 1] representing the progress of the current
618+
transition. 0 means the transition just started, 1 means it is complete.
619+
620+
Args:
621+
pedestal_transition_state: Current transition state.
622+
transition_time_width: Duration of the transition ramp.
623+
t: Current simulation time (i.e. t + dt when called from the solver).
624+
625+
Returns:
626+
Ramp fraction clipped to [0, 1].
627+
"""
628+
elapsed = t - pedestal_transition_state.transition_start_time
629+
fraction = elapsed / transition_time_width
630+
return jnp.clip(fraction, 0.0, 1.0)
631+
632+
633+
def _apply_transition_ramp_scaling(
634+
pedestal_top_values: internal_boundary_conditions_lib.InternalBoundaryConditions,
635+
pedestal_transition_state: pedestal_transition_state_lib.PedestalTransitionState,
636+
runtime_params: runtime_params_lib.RuntimeParams,
637+
) -> internal_boundary_conditions_lib.InternalBoundaryConditions:
638+
"""Applies ramp scaling to internal boundary conditions during transitions.
639+
640+
During an L-H transition, linearly ramps from L-mode values to the H-mode
641+
targets. During an H-L transition, ramps from the H-mode targets back to
642+
the L-mode values.
643+
644+
The L-mode values are stored in the pedestal_transition_state (captured
645+
at the start of an L->H transition). The H-mode targets are the full
646+
pedestal model output.
647+
648+
Args:
649+
pedestal_top_values: Pedestal-top target internal boundary conditions from
650+
the pedestal model.
651+
pedestal_transition_state: Current transition state containing L-mode
652+
baseline values.
653+
runtime_params: Runtime parameters (provides time t and pedestal config).
654+
655+
Returns:
656+
Scaled internal boundary conditions.
657+
"""
658+
ramp_fraction = _compute_ramp_fraction(
659+
pedestal_transition_state=pedestal_transition_state,
660+
transition_time_width=runtime_params.pedestal.transition_time_width,
661+
t=runtime_params.t,
662+
)
663+
664+
# Extract the nonzero pedestal-top values from the IBC. The IBC arrays are
665+
# cell-grid sized with a single nonzero element at the pedestal top. We use
666+
# jnp.max to extract the nonzero value.
667+
h_mode_T_i_ped = jnp.max(pedestal_top_values.T_i)
668+
h_mode_T_e_ped = jnp.max(pedestal_top_values.T_e)
669+
h_mode_n_e_ped = jnp.max(pedestal_top_values.n_e)
670+
671+
l_mode_T_i_ped = pedestal_transition_state.T_i_ped_L_mode
672+
l_mode_T_e_ped = pedestal_transition_state.T_e_ped_L_mode
673+
l_mode_n_e_ped = pedestal_transition_state.n_e_ped_L_mode
674+
675+
# In H-mode: ramp from L-mode to H-mode (L + fraction * (H - L))
676+
# In L-mode (H->L ramp): ramp from H-mode to L-mode (H + fraction * (L - H))
677+
def _lerp(l_val, h_val, frac, in_h_mode):
678+
return jnp.where(
679+
in_h_mode,
680+
l_val + frac * (h_val - l_val), # L->H ramp
681+
h_val + frac * (l_val - h_val), # H->L ramp
682+
)
683+
684+
in_h_mode = pedestal_transition_state.in_H_mode
685+
scaled_T_i = _lerp(l_mode_T_i_ped, h_mode_T_i_ped, ramp_fraction, in_h_mode)
686+
scaled_T_e = _lerp(l_mode_T_e_ped, h_mode_T_e_ped, ramp_fraction, in_h_mode)
687+
scaled_n_e = _lerp(l_mode_n_e_ped, h_mode_n_e_ped, ramp_fraction, in_h_mode)
688+
689+
# Reconstruct IBC with scaled values at the same pedestal-top location.
690+
# The nonzero mask from the original pedestal_top_values gives us the
691+
# location.
692+
return dataclasses.replace(
693+
pedestal_top_values,
694+
T_i=jnp.where(pedestal_top_values.T_i != 0.0, scaled_T_i, 0.0),
695+
T_e=jnp.where(pedestal_top_values.T_e != 0.0, scaled_T_e, 0.0),
696+
n_e=jnp.where(pedestal_top_values.n_e != 0.0, scaled_n_e, 0.0),
697+
)

torax/_src/fvm/tests/calc_coeffs_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616

1717
from absl.testing import absltest
1818
from absl.testing import parameterized
19+
import jax.numpy as jnp
1920
from torax._src.config import build_runtime_params
2021
from torax._src.core_profiles import initialization
2122
from torax._src.fvm import calc_coeffs
23+
from torax._src.internal_boundary_conditions import internal_boundary_conditions
24+
from torax._src.pedestal_model import pedestal_transition_state
2225
from torax._src.sources import source_profile_builders
2326
from torax._src.test_utils import default_sources
2427
from torax._src.torax_pydantic import model_config
@@ -159,5 +162,92 @@ def create_coeffs_callback(
159162
)
160163

161164

165+
class TransitionCalculationsTest(parameterized.TestCase):
166+
167+
def test_pedestal_transition_state_initial_state(self):
168+
state = pedestal_transition_state.PedestalTransitionState.initial_state()
169+
self.assertTrue(jnp.isneginf(state.transition_start_time))
170+
self.assertEqual(state.T_i_ped_L_mode, 0.0)
171+
self.assertFalse(state.in_H_mode)
172+
173+
def test_compute_ramp_fraction_very_small_width(self):
174+
state = pedestal_transition_state.PedestalTransitionState(
175+
transition_start_time=jnp.array(1.0),
176+
T_i_ped_L_mode=jnp.array(0.0),
177+
T_e_ped_L_mode=jnp.array(0.0),
178+
n_e_ped_L_mode=jnp.array(0.0),
179+
in_H_mode=jnp.array(True),
180+
rho_norm_ped_top=jnp.array(0.9),
181+
)
182+
# Very small transition_time_width clips to 1.0 when elapsed > 0.
183+
self.assertEqual(
184+
calc_coeffs._compute_ramp_fraction(state, 1e-10, 1.5), 1.0
185+
)
186+
187+
def test_compute_ramp_fraction_ramp(self):
188+
state = pedestal_transition_state.PedestalTransitionState(
189+
transition_start_time=jnp.array(1.0),
190+
T_i_ped_L_mode=jnp.array(0.0),
191+
T_e_ped_L_mode=jnp.array(0.0),
192+
n_e_ped_L_mode=jnp.array(0.0),
193+
in_H_mode=jnp.array(True),
194+
rho_norm_ped_top=jnp.array(0.9),
195+
)
196+
# transition_time_width = 1.0. Start at 1.0.
197+
# Clip at both ends
198+
self.assertEqual(
199+
calc_coeffs._compute_ramp_fraction(state, 1.0, 0.5), 0.0
200+
) # t < start
201+
self.assertEqual(
202+
calc_coeffs._compute_ramp_fraction(state, 1.0, 1.0), 0.0
203+
) # t = start
204+
self.assertEqual(
205+
calc_coeffs._compute_ramp_fraction(state, 1.0, 1.5), 0.5
206+
) # t = start + 0.5
207+
self.assertEqual(
208+
calc_coeffs._compute_ramp_fraction(state, 1.0, 2.0), 1.0
209+
) # t = start + 1.0
210+
self.assertEqual(
211+
calc_coeffs._compute_ramp_fraction(state, 1.0, 2.5), 1.0
212+
) # t = start + 1.5
213+
214+
def test_apply_transition_ramp_scaling_l_to_h(self):
215+
l_mode_baseline = 1.0
216+
h_mode_target = 3.0
217+
218+
state = pedestal_transition_state.PedestalTransitionState(
219+
transition_start_time=jnp.array(1.0),
220+
T_i_ped_L_mode=jnp.array(l_mode_baseline),
221+
T_e_ped_L_mode=jnp.array(l_mode_baseline),
222+
n_e_ped_L_mode=jnp.array(l_mode_baseline),
223+
in_H_mode=jnp.array(True), # L -> H
224+
rho_norm_ped_top=jnp.array(0.9),
225+
)
226+
227+
pedestal_top_values = (
228+
internal_boundary_conditions.InternalBoundaryConditions(
229+
T_i=jnp.array([0.0, h_mode_target, 0.0]),
230+
T_e=jnp.array([0.0, h_mode_target, 0.0]),
231+
n_e=jnp.array([0.0, h_mode_target, 0.0]),
232+
)
233+
)
234+
235+
class MockPedestalRuntimeParams:
236+
transition_time_width = 1.0
237+
238+
class MockRuntimeParams:
239+
pedestal = MockPedestalRuntimeParams()
240+
t = 1.5 # halfway
241+
242+
scaled_ibc = calc_coeffs._apply_transition_ramp_scaling( # pytype: disable=wrong-arg-types
243+
pedestal_top_values=pedestal_top_values,
244+
pedestal_transition_state=state,
245+
runtime_params=MockRuntimeParams(),
246+
)
247+
248+
# Expected: 1.0 + 0.5 * (3.0 - 1.0) = 2.0
249+
self.assertTrue(jnp.allclose(scaled_ibc.T_i, jnp.array([0.0, 2.0, 0.0])))
250+
251+
162252
if __name__ == '__main__':
163253
absltest.main()

0 commit comments

Comments
 (0)