Skip to content

Commit fc2a107

Browse files
theo-brownTorax team
authored andcommitted
Replace exponential smoothing of dW/dt with computation with a time lag.
Introduces a history buffer for stored thermal energy and time, allowing for the calculation of dW/dt over a fixed time window (`dW_dt_window`). The buffer length is configurable via `dW_dt_buffer_length`. PiperOrigin-RevId: 893464692
1 parent b751c80 commit fc2a107

7 files changed

Lines changed: 220 additions & 49 deletions

File tree

docs/configuration.rst

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,15 +502,18 @@ equations being solved, constant numerical variables.
502502
Prefactor for adaptive source term for setting density internal boundary
503503
conditions.
504504

505-
``dW_dt_smoothing_time_scale`` (float [default = 0.3])
506-
Time scale [s] for the exponential moving average smoothing of dW/dt terms
507-
used in P_SOL and confinement time calculations. If 0.0, no smoothing is
508-
applied and raw dW/dt is used.
509-
510505
``min_rho_norm`` (float [default = 0.015])
511506
Minimum rho_norm value below which current profile values are extrapolated to
512507
the axis in psi calculations, to avoid numerical artifacts near rho=0.
513508

509+
``dW_dt_window`` (float [default = 0.01])
510+
Time window [s] over which to compute the windowed derivative of the stored
511+
thermal energy.
512+
513+
``dW_dt_buffer_length`` (int [default = 50])
514+
Number of elements to keep in the history buffer for computing the windowed
515+
derivative.
516+
514517
.. TODO (b/434175938): consolidate naming to _min or _minimum.
515518
516519
``T_minimum_eV`` (float [default = 5.0])

docs/output.rst

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -577,23 +577,21 @@ properties and characteristics, as well as scalar edge geometry quantities.
577577
present if provided by FBT geometry.
578578

579579
``dW_thermal_dt`` (time)
580-
Time derivative of the total thermal stored energy [:math:`W`], raw unsmoothed
581-
value.
580+
Time derivative of the total thermal stored energy [:math:`W`], computed over
581+
the last time step.
582582

583583
``dW_thermal_dt_smoothed`` (time)
584-
Smoothed time derivative of total stored thermal energy [:math:`W`].
585-
Exponential moving average of ``dW_thermal_dt`` with a time window coefficient
586-
set in the `numerics` config.
584+
Time derivative of the total thermal stored energy [:math:`W`], computed over
585+
a fixed time window.
586+
587+
``dW_thermal_i_dt_smoothed`` (time)
588+
Time derivative of the ion thermal stored energy [:math:`W`], computed over
589+
a fixed time window.
587590

588591
``dW_thermal_e_dt_smoothed`` (time)
589-
Smoothed time derivative of electron stored thermal energy [:math:`W`].
590-
Exponential moving average of ``dW_thermal_dt`` with a time window coefficient
591-
set in the `numerics` config.
592+
Time derivative of the electron thermal stored energy [:math:`W`], computed over
593+
a fixed time window.
592594

593-
``dW_thermal_i_dt_smoothed`` (time)
594-
Smoothed time derivative of ion stored thermal energy [:math:`W`].
595-
Exponential moving average of ``dW_thermal_dt`` with a time window coefficient
596-
set in the `numerics` config.
597595

598596
``drho`` (time)
599597
Radial grid spacing in the unnormalized rho coordinate [:math:`m`].

torax/_src/config/numerics.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ class RuntimeParams:
4545
resistivity_multiplier: array_typing.FloatScalar
4646
adaptive_T_source_prefactor: float
4747
adaptive_n_source_prefactor: float
48-
dW_dt_smoothing_time_scale: float
48+
dW_dt_window: array_typing.FloatScalar
49+
dW_dt_buffer_length: int = dataclasses.field(metadata={'static': True})
4950
min_rho_norm: float
5051
evolve_ion_heat: bool = dataclasses.field(metadata={'static': True})
5152
evolve_electron_heat: bool = dataclasses.field(metadata={'static': True})
@@ -121,7 +122,7 @@ class Numerics(torax_pydantic.BaseModelFrozen):
121122
min_dt: torax_pydantic.Second = 1e-8
122123
chi_timestep_prefactor: pydantic.PositiveFloat = 50.0
123124
fixed_dt: torax_pydantic.NonNegativeTimeVaryingScalarStep = (
124-
torax_pydantic.ValidatedDefault(1e-1)
125+
torax_pydantic.ValidatedDefault(0.1)
125126
)
126127
adaptive_dt: Annotated[bool, torax_pydantic.JAX_STATIC] = True
127128
dt_reduction_factor: pydantic.PositiveFloat = 3.0
@@ -135,7 +136,10 @@ class Numerics(torax_pydantic.BaseModelFrozen):
135136
)
136137
adaptive_T_source_prefactor: pydantic.PositiveFloat = 2.0e10
137138
adaptive_n_source_prefactor: pydantic.PositiveFloat = 2.0e8
138-
dW_dt_smoothing_time_scale: pydantic.NonNegativeFloat = 0.3
139+
dW_dt_window: torax_pydantic.NonNegativeTimeVaryingScalar = (
140+
torax_pydantic.ValidatedDefault(0.01)
141+
)
142+
dW_dt_buffer_length: Annotated[int, torax_pydantic.JAX_STATIC] = 10
139143
min_rho_norm: torax_pydantic.UnitInterval = 0.015
140144

141145
T_minimum_eV: pydantic.PositiveFloat = 5.0
@@ -182,7 +186,8 @@ def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams:
182186
resistivity_multiplier=self.resistivity_multiplier.get_value(t),
183187
adaptive_T_source_prefactor=self.adaptive_T_source_prefactor,
184188
adaptive_n_source_prefactor=self.adaptive_n_source_prefactor,
185-
dW_dt_smoothing_time_scale=self.dW_dt_smoothing_time_scale,
189+
dW_dt_window=self.dW_dt_window.get_value(t),
190+
dW_dt_buffer_length=self.dW_dt_buffer_length,
186191
min_rho_norm=self.min_rho_norm,
187192
evolve_ion_heat=self.evolve_ion_heat,
188193
evolve_electron_heat=self.evolve_electron_heat,

torax/_src/core_profiles/initialization.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def initial_core_profiles(
148148
# core profiles refactor.
149149
core_profiles = dataclasses.replace(
150150
core_profiles,
151-
internal_plasma_energy=_initialise_internal_energy(core_profiles, geo),
151+
internal_plasma_energy=_initialise_internal_energy(
152+
runtime_params, core_profiles, geo
153+
),
152154
)
153155

154156
return _init_psi_and_psi_derived(
@@ -161,6 +163,7 @@ def initial_core_profiles(
161163

162164

163165
def _initialise_internal_energy(
166+
runtime_params: runtime_params_lib.RuntimeParams,
164167
core_profiles: state.CoreProfiles,
165168
geo: geometry.Geometry,
166169
) -> state.PlasmaInternalEnergy:
@@ -173,6 +176,7 @@ def _initialise_internal_energy(
173176
geo,
174177
)
175178
)
179+
N = runtime_params.numerics.dW_dt_buffer_length
176180
return state.PlasmaInternalEnergy(
177181
W_thermal_i=W_thermal_i,
178182
W_thermal_e=W_thermal_e,
@@ -181,6 +185,9 @@ def _initialise_internal_energy(
181185
dW_thermal_e_dt=jnp.array(0.0, dtype=jax_utils.get_dtype()),
182186
dW_thermal_i_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
183187
dW_thermal_e_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
188+
W_thermal_i_history=jnp.full((N,), W_thermal_i),
189+
W_thermal_e_history=jnp.full((N,), W_thermal_e),
190+
t_history=jnp.full((N,), runtime_params.numerics.t_initial),
184191
)
185192

186193

torax/_src/core_profiles/tests/updaters_test.py

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import copy
1718
from unittest import mock
18-
1919
from absl.testing import absltest
2020
from absl.testing import parameterized
2121
from jax import numpy as jnp
@@ -26,6 +26,7 @@
2626
from torax._src.core_profiles import updaters
2727
from torax._src.fvm import cell_variable
2828
from torax._src.geometry import circular_geometry
29+
from torax._src.physics import formulas
2930
from torax._src.test_utils import default_configs
3031
from torax._src.torax_pydantic import model_config
3132

@@ -50,12 +51,33 @@ def setUp(self):
5051
right_face_constraint=1.0,
5152
right_face_grad_constraint=None,
5253
)
54+
pressure_thermal_e = cell_variable.CellVariable(
55+
value=jnp.ones_like(self.geo.rho_norm),
56+
face_centers=self.geo.rho_face_norm,
57+
right_face_constraint=1.0,
58+
right_face_grad_constraint=None,
59+
)
60+
pressure_thermal_i = cell_variable.CellVariable(
61+
value=jnp.ones_like(self.geo.rho_norm),
62+
face_centers=self.geo.rho_face_norm,
63+
right_face_constraint=1.0,
64+
right_face_grad_constraint=None,
65+
)
66+
pressure_thermal_total = cell_variable.CellVariable(
67+
value=jnp.ones_like(self.geo.rho_norm),
68+
face_centers=self.geo.rho_face_norm,
69+
right_face_constraint=1.0,
70+
right_face_grad_constraint=None,
71+
)
5372

5473
self.core_profiles_t = mock.create_autospec(
5574
state.CoreProfiles,
5675
instance=True,
5776
T_e=T_e,
5877
n_e=n_e,
78+
pressure_thermal_e=pressure_thermal_e,
79+
pressure_thermal_i=pressure_thermal_i,
80+
pressure_thermal_total=pressure_thermal_total,
5981
)
6082

6183
@parameterized.named_parameters(
@@ -216,6 +238,128 @@ def test_psi_not_updated_if_evolve_current_true(self):
216238
# Since it wasn't updated in provide_..., it should remain 10.0 here.
217239
np.testing.assert_allclose(core_profiles_t1.psi.value, 10.0)
218240

241+
def test_update_energy_state(self):
242+
"""Tests that energy state is updated correctly."""
243+
config = default_configs.get_default_config_dict()
244+
torax_config = model_config.ToraxConfig.from_dict(config)
245+
provider = build_runtime_params.RuntimeParamsProvider.from_config(
246+
torax_config
247+
)
248+
runtime_params = provider(t=0.0)
249+
250+
energy_state_t = initialization._initialise_internal_energy(
251+
runtime_params, self.core_profiles_t, self.geo
252+
)
253+
254+
# Pretend we take a step of half the window size and increase the electron
255+
# and ion pressure.
256+
mock_dt = runtime_params.numerics.dW_dt_window / 2.0
257+
core_profiles_t_plus_dt = copy.deepcopy(self.core_profiles_t)
258+
core_profiles_t_plus_dt.pressure_thermal_e = cell_variable.CellVariable(
259+
value=jnp.full_like(self.geo.rho_norm, 2.0),
260+
face_centers=self.geo.rho_face_norm,
261+
right_face_constraint=1.0,
262+
right_face_grad_constraint=None,
263+
)
264+
core_profiles_t_plus_dt.pressure_thermal_i = cell_variable.CellVariable(
265+
value=jnp.full_like(self.geo.rho_norm, 3.0),
266+
face_centers=self.geo.rho_face_norm,
267+
right_face_constraint=1.0,
268+
right_face_grad_constraint=None,
269+
)
270+
W_thermal_e_t_plus_dt, W_thermal_i_t_plus_dt, _ = (
271+
formulas.calculate_stored_thermal_energy(
272+
core_profiles_t_plus_dt.pressure_thermal_e,
273+
core_profiles_t_plus_dt.pressure_thermal_i,
274+
core_profiles_t_plus_dt.pressure_thermal_total,
275+
self.geo,
276+
)
277+
)
278+
279+
# Get the new energy state from the update function.
280+
energy_state_t_plus_dt = updaters._update_energy_state(
281+
runtime_params,
282+
self.geo,
283+
core_profiles_t_plus_dt,
284+
energy_state_t,
285+
mock_dt,
286+
)
287+
288+
# Check that the time history is updated.
289+
expected_t_history = jnp.concatenate([
290+
energy_state_t.t_history[1:],
291+
jnp.atleast_1d(energy_state_t.t_history[-1]) + mock_dt,
292+
])
293+
np.testing.assert_allclose(
294+
energy_state_t_plus_dt.t_history, expected_t_history
295+
)
296+
297+
# Check that the W_thermal history is updated.
298+
expected_W_thermal_i_history = jnp.concatenate([
299+
energy_state_t.W_thermal_i_history[1:],
300+
jnp.atleast_1d(W_thermal_i_t_plus_dt),
301+
])
302+
np.testing.assert_allclose(
303+
energy_state_t_plus_dt.W_thermal_i_history, expected_W_thermal_i_history
304+
)
305+
expected_W_thermal_e_history = jnp.concatenate([
306+
energy_state_t.W_thermal_e_history[1:],
307+
jnp.atleast_1d(W_thermal_e_t_plus_dt),
308+
])
309+
np.testing.assert_allclose(
310+
energy_state_t_plus_dt.W_thermal_e_history, expected_W_thermal_e_history
311+
)
312+
313+
# Check that the dW_dt is calculated correctly.
314+
np.testing.assert_allclose(
315+
energy_state_t_plus_dt.dW_thermal_i_dt,
316+
(W_thermal_i_t_plus_dt - energy_state_t.W_thermal_i) / mock_dt,
317+
)
318+
np.testing.assert_allclose(
319+
energy_state_t_plus_dt.dW_thermal_e_dt,
320+
(W_thermal_e_t_plus_dt - energy_state_t.W_thermal_e) / mock_dt,
321+
)
322+
323+
# As we took a step of half the window size, the smoothed dW_dt should be
324+
# the same as the un-smoothed dW_dt.
325+
np.testing.assert_allclose(
326+
energy_state_t_plus_dt.dW_thermal_i_dt_smoothed,
327+
energy_state_t_plus_dt.dW_thermal_i_dt,
328+
)
329+
np.testing.assert_allclose(
330+
energy_state_t_plus_dt.dW_thermal_e_dt_smoothed,
331+
energy_state_t_plus_dt.dW_thermal_e_dt,
332+
)
333+
334+
# Take another step and check the dW_dt.
335+
energy_state_t_plus_2dt = updaters._update_energy_state(
336+
runtime_params,
337+
self.geo,
338+
core_profiles_t_plus_dt,
339+
energy_state_t_plus_dt,
340+
mock_dt,
341+
)
342+
# Raw dW_dt values should be zero as we haven't changed the pressures.
343+
np.testing.assert_allclose(
344+
energy_state_t_plus_2dt.dW_thermal_i_dt,
345+
0.0,
346+
)
347+
np.testing.assert_allclose(
348+
energy_state_t_plus_2dt.dW_thermal_e_dt,
349+
0.0,
350+
)
351+
# Smoothed dW_dt values should be computed vs the 0th state
352+
np.testing.assert_allclose(
353+
energy_state_t_plus_2dt.dW_thermal_i_dt_smoothed,
354+
(energy_state_t_plus_2dt.W_thermal_i - energy_state_t.W_thermal_i)
355+
/ (runtime_params.numerics.dW_dt_window),
356+
)
357+
np.testing.assert_allclose(
358+
energy_state_t_plus_2dt.dW_thermal_e_dt_smoothed,
359+
(energy_state_t_plus_2dt.W_thermal_e - energy_state_t.W_thermal_e)
360+
/ (runtime_params.numerics.dW_dt_window),
361+
)
362+
219363

220364
if __name__ == '__main__':
221365
absltest.main()

torax/_src/core_profiles/updaters.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import jax
3434
import jax.numpy as jnp
3535
from torax._src import array_typing
36-
from torax._src import jax_utils
3736
from torax._src import state
3837
from torax._src.config import runtime_params as runtime_params_lib
3938
from torax._src.core_profiles import convertors
@@ -372,37 +371,46 @@ def _update_energy_state(
372371
geo,
373372
)
374373
)
375-
dW_i_dt_raw = (W_thermal_i - prev_energy_state.W_thermal_i) / dt
376-
dW_e_dt_raw = (W_thermal_e - prev_energy_state.W_thermal_e) / dt
377-
378-
exponential_smoothing_alpha = jax.lax.cond(
379-
runtime_params.numerics.dW_dt_smoothing_time_scale > 0.0,
380-
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype())
381-
- jnp.exp(-dt / runtime_params.numerics.dW_dt_smoothing_time_scale),
382-
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype()),
383-
)
384-
dW_i_dt_smoothed = _exponential_smoothing(
385-
dW_i_dt_raw,
386-
prev_energy_state.dW_thermal_i_dt_smoothed,
387-
exponential_smoothing_alpha,
388-
)
389-
dW_e_dt_smoothed = _exponential_smoothing(
390-
dW_e_dt_raw,
391-
prev_energy_state.dW_thermal_e_dt_smoothed,
392-
exponential_smoothing_alpha,
374+
375+
# Raw values: dW/dt over the last timestep.
376+
dW_i_dt = (W_thermal_i - prev_energy_state.W_thermal_i_history[-1]) / dt
377+
dW_e_dt = (W_thermal_e - prev_energy_state.W_thermal_e_history[-1]) / dt
378+
379+
# Smoothed values: dW/dt over the last window.
380+
current_t = prev_energy_state.t_history[-1] + dt
381+
t_target = current_t - runtime_params.numerics.dW_dt_window
382+
383+
# If t_target is before the beginning of the history, use the first element.
384+
# Otherwise use closest element to t_target.
385+
idx = jnp.maximum(
386+
0, jnp.searchsorted(prev_energy_state.t_history, t_target) - 1
393387
)
388+
dW_i_dt_smoothed = (
389+
W_thermal_i - prev_energy_state.W_thermal_i_history[idx]
390+
) / (current_t - prev_energy_state.t_history[idx])
391+
dW_e_dt_smoothed = (
392+
W_thermal_e - prev_energy_state.W_thermal_e_history[idx]
393+
) / (current_t - prev_energy_state.t_history[idx])
394+
395+
# Update history arrays.
396+
W_i_hist_new = jnp.roll(prev_energy_state.W_thermal_i_history, -1)
397+
W_i_hist_new = W_i_hist_new.at[-1].set(W_thermal_i)
398+
399+
W_e_hist_new = jnp.roll(prev_energy_state.W_thermal_e_history, -1)
400+
W_e_hist_new = W_e_hist_new.at[-1].set(W_thermal_e)
401+
402+
t_hist_new = jnp.roll(prev_energy_state.t_history, -1)
403+
t_hist_new = t_hist_new.at[-1].set(current_t)
394404

395405
return state.PlasmaInternalEnergy(
396406
W_thermal_i=W_thermal_i,
397407
W_thermal_e=W_thermal_e,
398408
W_thermal_total=W_thermal_total,
399-
dW_thermal_i_dt=dW_i_dt_raw,
400-
dW_thermal_e_dt=dW_e_dt_raw,
409+
dW_thermal_i_dt=dW_i_dt,
410+
dW_thermal_e_dt=dW_e_dt,
401411
dW_thermal_i_dt_smoothed=dW_i_dt_smoothed,
402412
dW_thermal_e_dt_smoothed=dW_e_dt_smoothed,
413+
W_thermal_i_history=W_i_hist_new,
414+
W_thermal_e_history=W_e_hist_new,
415+
t_history=t_hist_new,
403416
)
404-
405-
406-
def _exponential_smoothing(new_raw, old_smoothed, alpha):
407-
"""Exponential moving average (EMA)."""
408-
return (1.0 - alpha) * old_smoothed + alpha * new_raw

0 commit comments

Comments
 (0)