1414
1515"""Calculates Block1DCoeffs for a time step."""
1616
17+ import dataclasses
1718import functools
1819import jax
1920import jax .numpy as jnp
2829from torax ._src .fvm import cell_variable
2930from torax ._src .geometry import geometry
3031from torax ._src .internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
32+ from torax ._src .pedestal_model import pedestal_transition_state as pedestal_transition_state_lib
3133from torax ._src .pedestal_model import runtime_params as pedestal_runtime_params_lib
3234from torax ._src .sources import source_profile_builders
3335from torax ._src .sources import source_profiles as source_profiles_lib
@@ -43,9 +45,13 @@ def __init__(
4345 self ,
4446 physics_models : physics_models_lib .PhysicsModels ,
4547 evolving_names : tuple [str , ...],
48+ pedestal_transition_state : (
49+ pedestal_transition_state_lib .PedestalTransitionState | None
50+ ) = None ,
4651 ):
4752 self .physics_models = physics_models
4853 self .evolving_names = evolving_names
54+ self .pedestal_transition_state = pedestal_transition_state
4955
5056 def __hash__ (self ) -> int :
5157 return hash ((
@@ -84,8 +90,8 @@ def __call__(
8490 state x.
8591 geo: The geometry of the system at this time step.
8692 core_profiles: The core profiles of the system at this time step.
87- prev_core_profiles: The core profiles of the system at the previous
88- time step.
93+ prev_core_profiles: The core profiles of the system at the previous time
94+ step.
8995 dt: The time step size.
9096 x: The state with cell-grid values of the evolving variables.
9197 explicit_source_profiles: Precomputed explicit source profiles. These
@@ -133,6 +139,7 @@ def __call__(
133139 evolving_names = self .evolving_names ,
134140 use_pereverzev = use_pereverzev ,
135141 explicit_call = explicit_call ,
142+ pedestal_transition_state = self .pedestal_transition_state ,
136143 )
137144
138145
@@ -145,6 +152,9 @@ def calc_coeffs(
145152 evolving_names : tuple [str , ...],
146153 use_pereverzev : bool = False ,
147154 explicit_call : bool = False ,
155+ pedestal_transition_state : (
156+ pedestal_transition_state_lib .PedestalTransitionState | None
157+ ) = None ,
148158) -> block_1d_coeffs .Block1DCoeffs :
149159 """Calculates Block1DCoeffs for the time step described by `core_profiles`.
150160
@@ -170,6 +180,9 @@ def calc_coeffs(
170180 explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
171181 theta_implicit=1. This saves computation for the default fully implicit
172182 implementation.
183+ pedestal_transition_state: State for tracking pedestal L-H and H-L
184+ transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with
185+ use_formation_model_with_adaptive_source=True. None otherwise.
173186
174187 Returns:
175188 coeffs: Block1DCoeffs containing the coefficients at this time step.
@@ -192,6 +205,7 @@ def calc_coeffs(
192205 physics_models = physics_models ,
193206 evolving_names = evolving_names ,
194207 use_pereverzev = use_pereverzev ,
208+ pedestal_transition_state = pedestal_transition_state ,
195209 )
196210
197211
@@ -210,6 +224,9 @@ def _calc_coeffs_full(
210224 physics_models : physics_models_lib .PhysicsModels ,
211225 evolving_names : tuple [str , ...],
212226 use_pereverzev : bool = False ,
227+ pedestal_transition_state : (
228+ pedestal_transition_state_lib .PedestalTransitionState | None
229+ ) = None ,
213230) -> block_1d_coeffs .Block1DCoeffs :
214231 """See `calc_coeffs` for details."""
215232
@@ -415,26 +432,75 @@ def _calc_coeffs_full(
415432 runtime_params .pedestal .mode
416433 == pedestal_runtime_params_lib .Mode .ADAPTIVE_SOURCE
417434 ):
435+ # Get the pedestal-top target values from the pedestal model.
436+ pedestal_top_values = (
437+ pedestal_model_output .to_internal_boundary_conditions (geo )
438+ )
439+
440+ # Apply ramp scaling if use_formation_model_with_adaptive_source is
441+ # enabled.
442+ if runtime_params .pedestal .use_formation_model_with_adaptive_source :
443+ assert pedestal_transition_state is not None , (
444+ 'pedestal_transition_state must not be None when'
445+ ' use_formation_model_with_adaptive_source is True.'
446+ )
447+ # Scale the pedestal-top values from the pedestal model by the ramp
448+ # fraction. Will be a no-op in H-mode following the transition_time_width.
449+ internal_boundary_conditions = _apply_transition_ramp_scaling (
450+ pedestal_top_values = pedestal_top_values ,
451+ pedestal_transition_state = pedestal_transition_state ,
452+ runtime_params = runtime_params ,
453+ )
454+ # If in L-mode and the H->L ramp has completed (fraction >= 1.0), skip
455+ # the adaptive source entirely to revert to standard L-mode modeling.
456+ # ramp_fraction will be 1.0 if simulation initialized in L-mode and has
457+ # remained in L-mode, since initial transition_start_time is -inf.
458+ ramp_fraction = _compute_ramp_fraction (
459+ pedestal_transition_state = pedestal_transition_state ,
460+ transition_time_width = runtime_params .pedestal .transition_time_width ,
461+ t = runtime_params .t ,
462+ )
463+ # Skip adaptive source if in L-mode and the H->L ramp has completed.
464+ skip_adaptive_source = ~ pedestal_transition_state .in_H_mode & (
465+ ramp_fraction >= 1.0
466+ )
467+ else :
468+ internal_boundary_conditions = pedestal_top_values
469+ skip_adaptive_source = jnp .bool_ (False )
470+
471+ def _apply_source ():
472+ return internal_boundary_conditions_lib .apply_adaptive_source (
473+ source_T_i = source_i ,
474+ source_T_e = source_e ,
475+ source_n_e = source_n_e ,
476+ source_mat_ii = source_mat_ii ,
477+ source_mat_ee = source_mat_ee ,
478+ source_mat_nn = source_mat_nn ,
479+ runtime_params = runtime_params ,
480+ internal_boundary_conditions = internal_boundary_conditions ,
481+ )
482+
483+ def _skip_source ():
484+ return (
485+ source_i ,
486+ source_e ,
487+ source_n_e ,
488+ source_mat_ii ,
489+ source_mat_ee ,
490+ source_mat_nn ,
491+ )
492+
418493 (
419494 source_i ,
420495 source_e ,
421496 source_n_e ,
422497 source_mat_ii ,
423498 source_mat_ee ,
424499 source_mat_nn ,
425- ) = internal_boundary_conditions_lib .apply_adaptive_source (
426- source_T_i = source_i ,
427- source_T_e = source_e ,
428- source_n_e = source_n_e ,
429- source_mat_ii = source_mat_ii ,
430- source_mat_ee = source_mat_ee ,
431- source_mat_nn = source_mat_nn ,
432- runtime_params = runtime_params ,
433- # Pedestal contributes an internal boundary condition to the source
434- # terms at the pedestal top.
435- internal_boundary_conditions = pedestal_model_output .to_internal_boundary_conditions (
436- geo
437- ),
500+ ) = jax .lax .cond (
501+ skip_adaptive_source ,
502+ _skip_source ,
503+ _apply_source ,
438504 )
439505
440506 # --- Build arguments to solver --- #
@@ -539,3 +605,93 @@ def _calc_coeffs_reduced(
539605 transient_in_cell = transient_in_cell ,
540606 )
541607 return coeffs
608+
609+
610+ def _compute_ramp_fraction (
611+ pedestal_transition_state : pedestal_transition_state_lib .PedestalTransitionState ,
612+ transition_time_width : array_typing .FloatScalar ,
613+ t : array_typing .FloatScalar ,
614+ ) -> array_typing .FloatScalar :
615+ """Computes the ramp fraction for a pedestal transition.
616+
617+ Returns a value in [0, 1] representing the progress of the current
618+ transition. 0 means the transition just started, 1 means it is complete.
619+
620+ Args:
621+ pedestal_transition_state: Current transition state.
622+ transition_time_width: Duration of the transition ramp.
623+ t: Current simulation time (i.e. t + dt when called from the solver).
624+
625+ Returns:
626+ Ramp fraction clipped to [0, 1].
627+ """
628+ elapsed = t - pedestal_transition_state .transition_start_time
629+ fraction = elapsed / transition_time_width
630+ return jnp .clip (fraction , 0.0 , 1.0 )
631+
632+
633+ def _apply_transition_ramp_scaling (
634+ pedestal_top_values : internal_boundary_conditions_lib .InternalBoundaryConditions ,
635+ pedestal_transition_state : pedestal_transition_state_lib .PedestalTransitionState ,
636+ runtime_params : runtime_params_lib .RuntimeParams ,
637+ ) -> internal_boundary_conditions_lib .InternalBoundaryConditions :
638+ """Applies ramp scaling to internal boundary conditions during transitions.
639+
640+ During an L-H transition, linearly ramps from L-mode values to the H-mode
641+ targets. During an H-L transition, ramps from the H-mode targets back to
642+ the L-mode values.
643+
644+ The L-mode values are stored in the pedestal_transition_state (captured
645+ at the start of an L->H transition). The H-mode targets are the full
646+ pedestal model output.
647+
648+ Args:
649+ pedestal_top_values: Pedestal-top target internal boundary conditions from
650+ the pedestal model.
651+ pedestal_transition_state: Current transition state containing L-mode
652+ baseline values.
653+ runtime_params: Runtime parameters (provides time t and pedestal config).
654+
655+ Returns:
656+ Scaled internal boundary conditions.
657+ """
658+ ramp_fraction = _compute_ramp_fraction (
659+ pedestal_transition_state = pedestal_transition_state ,
660+ transition_time_width = runtime_params .pedestal .transition_time_width ,
661+ t = runtime_params .t ,
662+ )
663+
664+ # Extract the nonzero pedestal-top values from the IBC. The IBC arrays are
665+ # cell-grid sized with a single nonzero element at the pedestal top. We use
666+ # jnp.max to extract the nonzero value.
667+ h_mode_T_i_ped = jnp .max (pedestal_top_values .T_i )
668+ h_mode_T_e_ped = jnp .max (pedestal_top_values .T_e )
669+ h_mode_n_e_ped = jnp .max (pedestal_top_values .n_e )
670+
671+ l_mode_T_i_ped = pedestal_transition_state .T_i_ped_L_mode
672+ l_mode_T_e_ped = pedestal_transition_state .T_e_ped_L_mode
673+ l_mode_n_e_ped = pedestal_transition_state .n_e_ped_L_mode
674+
675+ # In H-mode: ramp from L-mode to H-mode (L + fraction * (H - L))
676+ # In L-mode (H->L ramp): ramp from H-mode to L-mode (H + fraction * (L - H))
677+ def _lerp (l_val , h_val , frac , in_h_mode ):
678+ return jnp .where (
679+ in_h_mode ,
680+ l_val + frac * (h_val - l_val ), # L->H ramp
681+ h_val + frac * (l_val - h_val ), # H->L ramp
682+ )
683+
684+ in_h_mode = pedestal_transition_state .in_H_mode
685+ scaled_T_i = _lerp (l_mode_T_i_ped , h_mode_T_i_ped , ramp_fraction , in_h_mode )
686+ scaled_T_e = _lerp (l_mode_T_e_ped , h_mode_T_e_ped , ramp_fraction , in_h_mode )
687+ scaled_n_e = _lerp (l_mode_n_e_ped , h_mode_n_e_ped , ramp_fraction , in_h_mode )
688+
689+ # Reconstruct IBC with scaled values at the same pedestal-top location.
690+ # The nonzero mask from the original pedestal_top_values gives us the
691+ # location.
692+ return dataclasses .replace (
693+ pedestal_top_values ,
694+ T_i = jnp .where (pedestal_top_values .T_i != 0.0 , scaled_T_i , 0.0 ),
695+ T_e = jnp .where (pedestal_top_values .T_e != 0.0 , scaled_T_e , 0.0 ),
696+ n_e = jnp .where (pedestal_top_values .n_e != 0.0 , scaled_n_e , 0.0 ),
697+ )
0 commit comments