Skip to content

Commit 5873c65

Browse files
hamelphiTorax team
authored andcommitted
Add enable_fast_ions toggle to numerics config
PiperOrigin-RevId: 879030506
1 parent 3a9eeb6 commit 5873c65

4 files changed

Lines changed: 50 additions & 15 deletions

File tree

torax/_src/config/numerics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class RuntimeParams:
5353
evolve_density: bool = dataclasses.field(metadata={'static': True})
5454
exact_t_final: bool = dataclasses.field(metadata={'static': True})
5555
adaptive_dt: bool = dataclasses.field(metadata={'static': True})
56+
enable_fast_ions: bool = dataclasses.field(metadata={'static': True})
5657

5758
@functools.cached_property
5859
def evolving_names(self) -> tuple[str, ...]:
@@ -128,6 +129,7 @@ class Numerics(torax_pydantic.BaseModelFrozen):
128129
evolve_electron_heat: Annotated[bool, torax_pydantic.JAX_STATIC] = True
129130
evolve_current: Annotated[bool, torax_pydantic.JAX_STATIC] = False
130131
evolve_density: Annotated[bool, torax_pydantic.JAX_STATIC] = False
132+
enable_fast_ions: Annotated[bool, torax_pydantic.JAX_STATIC] = False
131133
resistivity_multiplier: torax_pydantic.TimeVaryingScalar = (
132134
torax_pydantic.ValidatedDefault(1.0)
133135
)
@@ -188,4 +190,5 @@ def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams:
188190
evolve_density=self.evolve_density,
189191
exact_t_final=self.exact_t_final,
190192
adaptive_dt=self.adaptive_dt,
193+
enable_fast_ions=self.enable_fast_ions,
191194
)

torax/_src/core_profiles/initialization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,9 @@ def initial_core_profiles(
9999
)
100100

101101
fast_ions_list = []
102-
for s in source_models.standard_sources.values():
103-
fast_ions_list.extend(s.zero_fast_ions(geo))
102+
if runtime_params.numerics.enable_fast_ions:
103+
for s in source_models.standard_sources.values():
104+
fast_ions_list.extend(s.zero_fast_ions(geo))
104105

105106
core_profiles = state.CoreProfiles(
106107
T_i=T_i,

torax/_src/sources/source.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -222,34 +222,55 @@ def get_value(
222222
conductivity,
223223
)
224224
case sources_runtime_params_lib.Mode.PRESCRIBED:
225-
if len(self.affected_core_profiles) != len(
226-
source_params.prescribed_values
225+
expected_len = len(self.affected_core_profiles)
226+
prescribed_len = len(source_params.prescribed_values)
227+
if (
228+
AffectedCoreProfile.FAST_IONS in self.affected_core_profiles
229+
and not runtime_params.numerics.enable_fast_ions
230+
and prescribed_len == expected_len - 1
227231
):
232+
fast_ions_idx = self.affected_core_profiles.index(
233+
AffectedCoreProfile.FAST_IONS
234+
)
235+
res_list = list(source_params.prescribed_values)
236+
res_list.insert(fast_ions_idx, ())
237+
res = tuple(res_list)
238+
elif prescribed_len != expected_len:
228239
raise ValueError(
229240
'When using PRESCRIBED mode, the number of prescribed values must'
230241
' match the number of affected core profiles. Was: '
231242
f'{len(source_params.prescribed_values)} '
232243
f' Expected: {len(self.affected_core_profiles)}.'
233244
)
234-
res = source_params.prescribed_values
245+
else:
246+
res = source_params.prescribed_values
235247
case sources_runtime_params_lib.Mode.ZERO:
236248
zeros = jnp.zeros(geo.rho_norm.shape)
237-
res = tuple(
238-
self.zero_fast_ions(geo)
239-
if acp == AffectedCoreProfile.FAST_IONS
240-
else zeros
241-
for acp in self.affected_core_profiles
242-
)
249+
res_list = []
250+
for affected_core_profile in self.affected_core_profiles:
251+
if affected_core_profile == AffectedCoreProfile.FAST_IONS:
252+
if runtime_params.numerics.enable_fast_ions:
253+
res_list.append(self.zero_fast_ions(geo))
254+
else:
255+
res_list.append(())
256+
else:
257+
res_list.append(zeros)
258+
res = tuple(res_list)
243259
case _:
244260
raise ValueError(f'Unknown mode: {mode}')
245261

246262
if AffectedCoreProfile.FAST_IONS in self.affected_core_profiles:
247263
fast_ions_idx = self.affected_core_profiles.index(
248264
AffectedCoreProfile.FAST_IONS
249265
)
250-
self._validate_fast_ions(
251-
res[fast_ions_idx],
252-
geo,
253-
)
266+
if runtime_params.numerics.enable_fast_ions:
267+
self._validate_fast_ions(
268+
res[fast_ions_idx],
269+
geo,
270+
)
271+
elif res[fast_ions_idx]:
272+
res_list = list(res)
273+
res_list[fast_ions_idx] = ()
274+
res = tuple(res_list)
254275

255276
return res

torax/_src/sources/tests/source_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from absl.testing import absltest
1919
from absl.testing import parameterized
2020
import numpy as np
21+
from torax._src.config import numerics
2122
from torax._src.config import runtime_params as runtime_params_lib
2223
from torax._src.geometry import geometry
2324
from torax._src.sources import electron_cyclotron_source
@@ -97,6 +98,9 @@ def test_correct_mode_called(
9798
dynamic_slice = mock.create_autospec(
9899
runtime_params_lib.RuntimeParams,
99100
sources=dynamic_source_params,
101+
numerics=mock.create_autospec(
102+
numerics.RuntimeParams, enable_fast_ions=True
103+
),
100104
)
101105
# Make a geo with rho_norm as we need it for the zero profile shape.
102106
geo = mock.create_autospec(
@@ -130,6 +134,9 @@ def test_prescribed_values_for_multiple_affected_profiles(self):
130134
dynamic_slice = mock.create_autospec(
131135
runtime_params_lib.RuntimeParams,
132136
sources=dynamic_source_params,
137+
numerics=mock.create_autospec(
138+
numerics.RuntimeParams, enable_fast_ions=True
139+
),
133140
)
134141
profile = source.get_value(
135142
runtime_params=dynamic_slice,
@@ -166,6 +173,9 @@ def test_source_with_mismatched_prescribed_values_raises_error(self):
166173
dynamic_slice = mock.create_autospec(
167174
runtime_params_lib.RuntimeParams,
168175
sources=dynamic_source_params,
176+
numerics=mock.create_autospec(
177+
numerics.RuntimeParams, enable_fast_ions=True
178+
),
169179
)
170180
with self.assertRaisesRegex(
171181
ValueError,

0 commit comments

Comments
 (0)