2727from torax ._src .config import runtime_params as runtime_params_lib
2828from torax ._src .geometry import geometry
2929from torax ._src .internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
30+ from torax ._src .pedestal_model import runtime_params as pedestal_runtime_params_lib
31+ from torax ._src .transport_model import turbulent_transport as turbulent_transport_lib
3032
3133# pylint: disable=invalid-name
3234# Using physics notation naming convention
3335
3436
3537@jax .tree_util .register_dataclass
3638@dataclasses .dataclass (frozen = True )
37- class PedestalModelOutput :
38- """Output of the PedestalModel."""
39+ class AdaptiveSourcePedestalModelOutput :
40+ """Output of a PedestalModel in ADAPTIVE_SOURCE mode ."""
3941
40- # The location of the pedestal.
42+ # The location of the pedestal top .
4143 rho_norm_ped_top : array_typing .FloatScalar
42- # The index of the pedestal in rho_norm.
44+ # The index of the pedestal top in rho_norm.
4345 rho_norm_ped_top_idx : array_typing .IntScalar
44- # The ion temperature at the pedestal.
46+ # The ion temperature at the pedestal top .
4547 T_i_ped : array_typing .FloatScalar
46- # The electron temperature at the pedestal.
48+ # The electron temperature at the pedestal top .
4749 T_e_ped : array_typing .FloatScalar
48- # The electron density at the pedestal in units 10^-3 .
50+ # The electron density at the pedestal top .
4951 n_e_ped : array_typing .FloatScalar
5052
5153 def to_internal_boundary_conditions (
5254 self ,
5355 geo : geometry .Geometry ,
5456 ) -> internal_boundary_conditions_lib .InternalBoundaryConditions :
5557 """Convert the pedestal model output to internal boundary conditions."""
58+ # In this case, the mask is only the pedestal top, not the whole pedestal
59+ # region. This is because we are adding a source/sink term only at the
60+ # pedestal top.
5661 pedestal_mask = (
5762 jnp .zeros_like (geo .rho , dtype = bool )
5863 .at [self .rho_norm_ped_top_idx ]
@@ -65,30 +70,122 @@ def to_internal_boundary_conditions(
6570 )
6671
6772
73+ @jax .tree_util .register_dataclass
74+ @dataclasses .dataclass (frozen = True )
75+ class AdaptiveTransportPedestalModelOutput :
76+ """Output of a PedestalModel in ADAPTIVE_TRANSPORT mode."""
77+
78+ # The location of the pedestal top.
79+ rho_norm_ped_top : array_typing .FloatScalar
80+ # The index of the pedestal top in rho_norm.
81+ rho_norm_ped_top_idx : array_typing .IntScalar
82+ # The multipliers for the turbulent transport coefficients.
83+ chi_e_multiplier : array_typing .FloatScalar
84+ chi_i_multiplier : array_typing .FloatScalar
85+ D_e_multiplier : array_typing .FloatScalar
86+ v_e_multiplier : array_typing .FloatScalar
87+
88+ def combine_with_turbulent_transport (
89+ self ,
90+ turbulent_transport : turbulent_transport_lib .TurbulentTransport ,
91+ geo : geometry .Geometry ,
92+ ) -> turbulent_transport_lib .TurbulentTransport :
93+ """Combine the pedestal model output with the turbulent transport coefficients."""
94+
95+ # In this case, the mask is the whole pedestal region, not just the top.
96+ # This is because we are modifying the transport coefficients in the whole
97+ # pedestal region.
98+ pedestal_mask_face = (
99+ jnp .zeros_like (geo .rho_face , dtype = bool )
100+ .at [self .rho_norm_ped_top_idx :]
101+ .set (True )
102+ )
103+
104+ modified_chi_face_ion = jnp .where (
105+ pedestal_mask_face ,
106+ turbulent_transport .chi_face_ion * self .chi_i_multiplier ,
107+ turbulent_transport .chi_face_ion ,
108+ )
109+ modified_chi_face_el = jnp .where (
110+ pedestal_mask_face ,
111+ turbulent_transport .chi_face_el * self .chi_e_multiplier ,
112+ turbulent_transport .chi_face_el ,
113+ )
114+ modified_d_face_el = jnp .where (
115+ pedestal_mask_face ,
116+ turbulent_transport .d_face_el * self .D_e_multiplier ,
117+ turbulent_transport .d_face_el ,
118+ )
119+ modified_v_face_el = jnp .where (
120+ pedestal_mask_face ,
121+ turbulent_transport .v_face_el * self .v_e_multiplier ,
122+ turbulent_transport .v_face_el ,
123+ )
124+ return turbulent_transport_lib .TurbulentTransport (
125+ chi_face_ion = modified_chi_face_ion ,
126+ chi_face_el = modified_chi_face_el ,
127+ d_face_el = modified_d_face_el ,
128+ v_face_el = modified_v_face_el ,
129+ )
130+
131+
132+ PedestalModelOutput = (
133+ AdaptiveSourcePedestalModelOutput | AdaptiveTransportPedestalModelOutput
134+ )
135+
136+
68137@dataclasses .dataclass (frozen = True , eq = False )
69138class PedestalModel (static_dataclass .StaticDataclass , abc .ABC ):
70- """Calculates temperature and density of the pedestal."""
139+ """Calculates properties of the pedestal."""
71140
72141 def __call__ (
73142 self ,
74143 runtime_params : runtime_params_lib .RuntimeParams ,
75144 geo : geometry .Geometry ,
76145 core_profiles : state .CoreProfiles ,
77146 ) -> PedestalModelOutput :
147+ if (
148+ runtime_params .pedestal .mode
149+ == pedestal_runtime_params_lib .Mode .ADAPTIVE_SOURCE
150+ ):
151+ # Set the pedestal location to infinite to indicate that the pedestal is
152+ # not present.
153+ # Set the index to outside of bounds of the mesh to indicate that the
154+ # pedestal is not present.
155+ dummy_output = AdaptiveSourcePedestalModelOutput (
156+ rho_norm_ped_top = jnp .inf ,
157+ T_i_ped = 0.0 ,
158+ T_e_ped = 0.0 ,
159+ n_e_ped = 0.0 ,
160+ rho_norm_ped_top_idx = geo .torax_mesh .nx ,
161+ )
162+ elif (
163+ runtime_params .pedestal .mode
164+ == pedestal_runtime_params_lib .Mode .ADAPTIVE_TRANSPORT
165+ ):
166+ # Set the pedestal location to infinite to indicate that the pedestal is
167+ # not present.
168+ # Set the index to outside of bounds of the mesh to indicate that the
169+ # pedestal is not present.
170+ # Set the multipliers to 1.0 to indicate that the transport coefficients
171+ # are not modified.
172+ dummy_output = AdaptiveTransportPedestalModelOutput (
173+ rho_norm_ped_top = jnp .inf ,
174+ rho_norm_ped_top_idx = geo .torax_mesh .nx ,
175+ chi_e_multiplier = 1.0 ,
176+ chi_i_multiplier = 1.0 ,
177+ D_e_multiplier = 1.0 ,
178+ v_e_multiplier = 1.0 ,
179+ )
180+ else :
181+ raise ValueError (
182+ f'Unsupported pedestal model mode: { runtime_params .pedestal .mode } '
183+ )
184+
78185 return jax .lax .cond (
79186 runtime_params .pedestal .set_pedestal ,
80187 lambda : self ._call_implementation (runtime_params , geo , core_profiles ),
81- # Set the pedestal location to infinite to indicate that the pedestal is
82- # not present.
83- # Set the index to outside of bounds of the mesh to indicate that the
84- # pedestal is not present.
85- lambda : PedestalModelOutput (
86- rho_norm_ped_top = jnp .inf ,
87- T_i_ped = 0.0 ,
88- T_e_ped = 0.0 ,
89- n_e_ped = 0.0 ,
90- rho_norm_ped_top_idx = geo .torax_mesh .nx ,
91- ),
188+ lambda : dummy_output ,
92189 )
93190
94191 @abc .abstractmethod
@@ -98,4 +195,4 @@ def _call_implementation(
98195 geo : geometry .Geometry ,
99196 core_profiles : state .CoreProfiles ,
100197 ) -> PedestalModelOutput :
101- """Calculate the pedestal values ."""
198+ """Calculate the pedestal properties ."""
0 commit comments