|
| 1 | +# Copyright 2026 DeepMind Technologies Limited |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""A pedestal that forms dynamically based on the LH threshold and critical ballooning parameter.""" |
| 15 | + |
| 16 | +import dataclasses |
| 17 | +import jax |
| 18 | +from jax import numpy as jnp |
| 19 | +from torax._src import array_typing |
| 20 | +from torax._src import constants |
| 21 | +from torax._src import state |
| 22 | +from torax._src.config import runtime_params as runtime_params_lib |
| 23 | +from torax._src.geometry import geometry |
| 24 | +from torax._src.pedestal_model import pedestal_model |
| 25 | +from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib |
| 26 | +from torax._src.physics import formulas |
| 27 | +from torax._src.physics import scaling_laws |
| 28 | + |
| 29 | +# pylint: disable=invalid-name |
| 30 | + |
| 31 | + |
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class RuntimeParams(pedestal_runtime_params_lib.RuntimeParams):
  """Runtime params for the DynamicPedestalModel."""

  # Multiplier applied to transport coefficients when the plasma is above the
  # L-H power threshold; presumably < 1 so transport is suppressed — confirm
  # against the config defaults.
  suppression_factor: array_typing.FloatScalar
  # Width of the sigmoid L-H transition, as a fraction of P_LH.
  suppression_rate: array_typing.FloatScalar
  # Scales the transport increase once the ballooning parameter exceeds
  # alpha_crit (continuous-ELM-like response).
  augmentation_factor: array_typing.FloatScalar
  # Width of the softplus transition around alpha_crit, as a fraction of
  # alpha_crit.
  augmentation_rate: array_typing.FloatScalar
  # Critical normalized pressure gradient (ballooning alpha) above which
  # transport is augmented in the pedestal region.
  alpha_crit: array_typing.FloatScalar
  # Normalized radial coordinate of the pedestal top.
  rho_norm_ped_top: array_typing.FloatScalar
| 44 | + |
@dataclasses.dataclass(frozen=True, eq=False)
class DynamicPedestal(pedestal_model.PedestalModel):
  """A pedestal that forms dynamically based on the LH threshold and critical ballooning parameter."""

  def _call_implementation(
      self,
      runtime_params: runtime_params_lib.RuntimeParams,
      geo: geometry.Geometry,
      core_profiles: state.CoreProfiles,
  ) -> pedestal_model.PedestalModelOutput:
    # This model only makes sense when transport coefficients are adapted,
    # rather than profiles being directly prescribed.
    if (
        runtime_params.pedestal.mode
        != pedestal_runtime_params_lib.Mode.ADAPTIVE_TRANSPORT
    ):
      raise ValueError('DynamicPedestal only supports ADAPTIVE_TRANSPORT mode.')

    params = runtime_params.pedestal
    assert isinstance(params, RuntimeParams)

    # Snap the configured pedestal-top location to the nearest grid point.
    ped_top_idx = jnp.abs(geo.rho_norm - params.rho_norm_ped_top).argmin()
    ped_top = jax.lax.dynamic_index_in_dim(
        geo.rho_norm, ped_top_idx, keepdims=False
    )
    # Face-grid mask: 1.0 inside the pedestal region, 0.0 elsewhere.
    in_pedestal_face = jnp.where(geo.rho_face_norm >= ped_top, 1.0, 0.0)

    # L-H transition: if the power crossing the separatrix exceeds the L-H
    # threshold, transport is suppressed (chi decreases).
    _, _, P_LH, _ = scaling_laws.calculate_plh_scaling_factor(
        geo, core_profiles
    )
    # TODO(b/323504363): use the correct source profiles to calculate
    # P_SOL_total. Intended computation: volume-integrate the total 'T_e' and
    # 'T_i' source profiles over rho_norm to get total power out of the
    # separatrix [W].
    P_SOL_total = P_LH  # TODO(b/323504363): replace with calculated value
    # Sigmoid-smoothed transition:
    #   P < P_LH: h_mode_weight -> 0, decrease multiplier -> 1.0.
    #   P > P_LH: h_mode_weight -> 1, decrease multiplier ->
    #     suppression_factor.
    h_mode_weight = jax.nn.sigmoid(
        (P_SOL_total - P_LH) / (params.suppression_rate * P_LH)
    )
    transport_decrease_multiplier = (
        1.0 - h_mode_weight
    ) + h_mode_weight * params.suppression_factor

    # Ballooning limit: if the normalized pressure gradient exceeds the
    # critical value anywhere in the pedestal region, transport is increased
    # (chi increases).
    dp_dr_face = formulas.calc_pprime(core_profiles)
    alpha_face = jnp.abs(
        2
        * constants.CONSTANTS.mu_0
        * geo.R_major_profile_face
        * core_profiles.q_face**2
        / geo.B_0**2
        * dp_dr_face
    )
    max_alpha = jnp.max(in_pedestal_face * alpha_face)
    # Softplus-smoothed transition:
    #   max_alpha < alpha_crit: weight -> 0, increase multiplier -> 1.0.
    #   max_alpha > alpha_crit: weight grows without bound, increase
    #     multiplier -> inf.
    continuous_elm_weight = jax.nn.softplus(
        (max_alpha - params.alpha_crit)
        / (params.augmentation_rate * params.alpha_crit)
    )
    transport_increase_multiplier = 1.0 + (
        continuous_elm_weight * params.augmentation_factor
    )

    # Combine the two effects multiplicatively, composed in log space.
    transport_multiplier = jnp.exp(
        jnp.log(transport_decrease_multiplier)
        + jnp.log(transport_increase_multiplier)
    )

    # For simplicity, every transport coefficient is currently scaled by the
    # same factor.
    return pedestal_model.AdaptiveTransportPedestalModelOutput(
        rho_norm_ped_top=ped_top,
        rho_norm_ped_top_idx=ped_top_idx,
        chi_e_multiplier=transport_multiplier,
        chi_i_multiplier=transport_multiplier,
        D_e_multiplier=transport_multiplier,
        v_e_multiplier=transport_multiplier,
    )