Skip to content

Commit eb46880

Browse files
hamelphiTorax team
authored andcommitted
Enable fast ions by default and update sim test references
PiperOrigin-RevId: 868861998
1 parent 5873c65 commit eb46880

6 files changed

Lines changed: 745 additions & 50 deletions

File tree

torax/_src/config/numerics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class Numerics(torax_pydantic.BaseModelFrozen):
129129
evolve_electron_heat: Annotated[bool, torax_pydantic.JAX_STATIC] = True
130130
evolve_current: Annotated[bool, torax_pydantic.JAX_STATIC] = False
131131
evolve_density: Annotated[bool, torax_pydantic.JAX_STATIC] = False
132-
enable_fast_ions: Annotated[bool, torax_pydantic.JAX_STATIC] = False
132+
enable_fast_ions: Annotated[bool, torax_pydantic.JAX_STATIC] = True
133133
resistivity_multiplier: torax_pydantic.TimeVaryingScalar = (
134134
torax_pydantic.ValidatedDefault(1.0)
135135
)
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Fast ion utility functions."""
16+
17+
import jax
18+
from jax import numpy as jnp
19+
from torax._src import constants
20+
from torax._src import math_utils
21+
from torax._src.physics import collisions
22+
23+
24+
# pylint: disable=invalid-name
25+
26+
27+
def _nu_epsilon(
28+
m_a_amu: float,
29+
Z_a: float,
30+
T_a_keV: jax.Array,
31+
m_b_amu: float,
32+
Z_b: float,
33+
n_b_m3: jax.Array,
34+
T_b_keV: jax.Array,
35+
ln_lambda: jax.Array,
36+
) -> jax.Array:
37+
"""NRL Formulary energy exchange rate nu_epsilon [Hz].
38+
39+
See NRL Plasma Formulary, page 34.
40+
41+
Args:
42+
m_a_amu: Mass of species a [amu].
43+
Z_a: Charge number of species a.
44+
T_a_keV: Temperature of species a [keV].
45+
m_b_amu: Mass of species b [amu].
46+
Z_b: Charge number of species b.
47+
n_b_m3: Density of species b [m^-3].
48+
T_b_keV: Temperature of species b [keV].
49+
ln_lambda: Coulomb logarithm.
50+
51+
Returns:
52+
Energy exchange rate [Hz].
53+
"""
54+
55+
n_b_cm3 = n_b_m3 / 1.0e6
56+
T_a_ev = T_a_keV * 1000.0
57+
T_b_ev = T_b_keV * 1000.0
58+
59+
# The formulary uses cgs units. We convert the constant to use amu for
60+
# masses to avoid tiny values which can cause numerical issues.
61+
coeff = 1.8e-19 / jnp.sqrt(constants.CONSTANTS.m_amu * 1e3)
62+
63+
num = (
64+
coeff
65+
* jnp.sqrt(m_a_amu * m_b_amu)
66+
* Z_a**2
67+
* Z_b**2
68+
* n_b_cm3
69+
* ln_lambda
70+
)
71+
denom = jnp.power(m_b_amu * T_a_ev + m_a_amu * T_b_ev, 1.5)
72+
return jnp.asarray(math_utils.safe_divide(num, denom))
73+
74+
75+
def _compute_T_tail(
76+
P_density_W: jax.Array,
77+
T_e: jax.Array,
78+
n_e: jax.Array,
79+
n_total: jax.Array,
80+
charge_number: float,
81+
mass_number: float,
82+
) -> jax.Array:
83+
"""Computes the effective tail temperature via the Stix xi parameter.
84+
85+
Uses the Spitzer slowing-down time on electrons (tau_s) and the Stix
86+
parameter xi to compute T_tail = T_e * (1 + xi) [Stix, Nuc. Fus. 1975].
87+
88+
The slowing-down time uses:
89+
tau_s [s] = 6.27e8 * A * (T_e[eV])^1.5 / (Z^2 * n_e[cm^-3] * ln_lambda)
90+
[Stix, Plasma Physics 14, 367 (1972), formula 16].
91+
92+
Args:
93+
P_density_W: Absolute power density [W/m^3].
94+
T_e: Electron temperature [keV].
95+
n_e: Electron density [m^-3].
96+
n_total: Total minority density [m^-3].
97+
charge_number: Charge number of the minority species.
98+
mass_number: Mass number of the minority species.
99+
100+
Returns:
101+
T_tail: Effective tail temperature [keV].
102+
"""
103+
log_lambda_ei = collisions.calculate_log_lambda_ei(T_e, n_e)
104+
105+
T_e_eV = T_e * 1000.0
106+
n_e_cm3 = n_e / 1.0e6
107+
108+
tau_s = math_utils.safe_divide(
109+
6.27e8 * mass_number * jnp.power(T_e_eV, 1.5),
110+
charge_number**2 * n_e_cm3 * log_lambda_ei,
111+
)
112+
T_e_J = T_e * constants.CONSTANTS.keV_to_J
113+
energy_density = 1.5 * n_total * T_e_J
114+
# Accroding to Stix 1972 (page 374), the energy_slowing_down_time is half
115+
# the Spitzer slowing-down time.
116+
energy_slowing_down_time = 0.5 * tau_s
117+
118+
xi = math_utils.safe_divide(
119+
P_density_W * energy_slowing_down_time, energy_density
120+
)
121+
122+
return T_e * (1.0 + xi)
123+
124+
125+
def bimaxwellian_split(
126+
power_deposition: jax.Array,
127+
T_e: jax.Array,
128+
n_e: jax.Array,
129+
T_i: jax.Array,
130+
n_i: jax.Array,
131+
minority_concentration: jax.Array | float,
132+
P_total_W: float,
133+
charge_number: float,
134+
mass_number: float,
135+
bulk_ion_mass: float,
136+
Z_i: float,
137+
n_impurity: jax.Array,
138+
Z_impurity: float,
139+
A_impurity: float,
140+
) -> tuple[jax.Array, jax.Array]:
141+
"""Returns (n_tail, T_tail) using the Power Balance Closure.
142+
143+
Splits a minority species density into a bulk thermal component and a
144+
high-energy tail component based on Stix theory power balance.
145+
146+
Unlike the simplified Stix model, this implementation includes energy transfer
147+
to bulk ions and impurities via the NRL Formulary nu_epsilon rate, making it
148+
more accurate when T_tail is close to the critical energy.
149+
150+
Args:
151+
power_deposition: Power deposition profile [MW/m^3 / MW_in]. Normalized per
152+
MW of input power.
153+
T_e: Electron temperature profile [keV].
154+
n_e: Electron density profile [m^-3].
155+
T_i: Ion temperature profile [keV].
156+
n_i: Main ion density profile [m^-3].
157+
minority_concentration: Minority species fractional concentration
158+
(n_minority/n_e).
159+
P_total_W: Total absolute power absorbed [W].
160+
charge_number: Charge number of the minority species (e.g. 2 for He3).
161+
mass_number: Mass number of the minority species (e.g. 3.016 for He3).
162+
bulk_ion_mass: Mass of the bulk main ion species [amu] (e.g. 2.014 for D).
163+
Z_i: Charge number of the bulk main ion species.
164+
n_impurity: Impurity density profile [m^-3].
165+
Z_impurity: Charge number of the impurity species.
166+
A_impurity: Mass number of the impurity species [amu].
167+
168+
Returns:
169+
Tuple containing:
170+
n_tail: Density of the fast tail component [m^-3].
171+
T_tail: Temperature of the fast tail component [keV].
172+
"""
173+
consts = constants.CONSTANTS
174+
175+
n_total = n_e * minority_concentration
176+
177+
P_density_W = power_deposition * (P_total_W)
178+
179+
me_amu = consts.m_e / consts.m_amu
180+
181+
T_tail = _compute_T_tail(
182+
P_density_W=P_density_W,
183+
T_e=T_e,
184+
n_e=n_e,
185+
n_total=n_total,
186+
charge_number=charge_number,
187+
mass_number=mass_number,
188+
)
189+
190+
log_lambda_ei = collisions.calculate_log_lambda_ei(T_e, n_e)
191+
192+
nu_tail_e = _nu_epsilon(
193+
mass_number,
194+
charge_number,
195+
T_tail,
196+
me_amu,
197+
1.0,
198+
n_e,
199+
T_e,
200+
log_lambda_ei,
201+
)
202+
203+
log_lambda_ii = collisions.calculate_log_lambda_ii(
204+
T_tail, n_i, jnp.asarray(Z_i)
205+
)
206+
nu_tail_i = _nu_epsilon(
207+
mass_number,
208+
charge_number,
209+
T_tail,
210+
bulk_ion_mass,
211+
Z_i,
212+
n_i,
213+
T_i,
214+
log_lambda_ii,
215+
)
216+
217+
log_lambda_impurity = collisions.calculate_log_lambda_ii(
218+
T_tail, jnp.maximum(n_impurity, 1.0), jnp.asarray(Z_impurity)
219+
)
220+
nu_tail_impurity = _nu_epsilon(
221+
mass_number,
222+
charge_number,
223+
T_tail,
224+
A_impurity,
225+
Z_impurity,
226+
n_impurity,
227+
T_i,
228+
log_lambda_impurity,
229+
)
230+
231+
energy_loss_rate_per_particle = (
232+
1.5
233+
* consts.keV_to_J
234+
* (
235+
nu_tail_e * (T_tail - T_e)
236+
+ nu_tail_i * (T_tail - T_i)
237+
+ nu_tail_impurity * (T_tail - T_i)
238+
)
239+
)
240+
241+
n_tail = math_utils.safe_divide(P_density_W, energy_loss_rate_per_particle)
242+
n_tail = jnp.clip(n_tail, 0.0, n_total * 0.99)
243+
244+
n_tail = jnp.where(P_density_W <= 1.0e-6, 0.0, n_tail)
245+
T_tail = jnp.where(P_density_W <= 1.0e-6, T_i, T_tail)
246+
247+
return n_tail, T_tail

0 commit comments

Comments
 (0)