Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions docs/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,18 @@ equations being solved, constant numerical variables.
Prefactor for adaptive source term for setting density internal boundary
conditions.

``dW_dt_smoothing_time_scale`` (float [default = 0.3])
Time scale [s] for the exponential moving average smoothing of dW/dt terms
used in P_SOL and confinement time calculations. If 0.0, no smoothing is
applied and raw dW/dt is used.

``min_rho_norm`` (float [default = 0.015])
Minimum rho_norm value below which current profile values are extrapolated to
the axis in psi calculations, to avoid numerical artifacts near rho=0.

``dW_dt_window`` (float [default = 0.01])
Time window [s] over which to compute the windowed derivative of the stored
thermal energy.

``dW_dt_buffer_length`` (int [default = 50])
Number of elements to keep in the history buffer for computing the windowed
derivative.

.. TODO (b/434175938): consolidate naming to _min or _minimum.

``T_minimum_eV`` (float [default = 5.0])
Expand Down
22 changes: 10 additions & 12 deletions docs/output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -577,23 +577,21 @@ properties and characteristics, as well as scalar edge geometry quantities.
present if provided by FBT geometry.

``dW_thermal_dt`` (time)
Time derivative of the total thermal stored energy [:math:`W`], raw unsmoothed
value.
Time derivative of the total thermal stored energy [:math:`W`], computed over
the last time step.

``dW_thermal_dt_smoothed`` (time)
Smoothed time derivative of total stored thermal energy [:math:`W`].
Exponential moving average of ``dW_thermal_dt`` with a time window coefficient
set in the `numerics` config.
Time derivative of the total thermal stored energy [:math:`W`], computed over
a fixed time window.

``dW_thermal_i_dt_smoothed`` (time)
Time derivative of the ion thermal stored energy [:math:`W`], computed over
a fixed time window.

``dW_thermal_e_dt_smoothed`` (time)
Smoothed time derivative of electron stored thermal energy [:math:`W`].
Exponential moving average of ``dW_thermal_dt`` with a time window coefficient
set in the `numerics` config.
Time derivative of the electron thermal stored energy [:math:`W`], computed over
a fixed time window.

``dW_thermal_i_dt_smoothed`` (time)
Smoothed time derivative of ion stored thermal energy [:math:`W`].
Exponential moving average of ``dW_thermal_dt`` with a time window coefficient
set in the `numerics` config.

``drho`` (time)
Radial grid spacing in the unnormalized rho coordinate [:math:`m`].
Expand Down
13 changes: 9 additions & 4 deletions torax/_src/config/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class RuntimeParams:
resistivity_multiplier: array_typing.FloatScalar
adaptive_T_source_prefactor: float
adaptive_n_source_prefactor: float
dW_dt_smoothing_time_scale: float
dW_dt_window: array_typing.FloatScalar
dW_dt_buffer_length: int = dataclasses.field(metadata={'static': True})
min_rho_norm: float
evolve_ion_heat: bool = dataclasses.field(metadata={'static': True})
evolve_electron_heat: bool = dataclasses.field(metadata={'static': True})
Expand Down Expand Up @@ -121,7 +122,7 @@ class Numerics(torax_pydantic.BaseModelFrozen):
min_dt: torax_pydantic.Second = 1e-8
chi_timestep_prefactor: pydantic.PositiveFloat = 50.0
fixed_dt: torax_pydantic.NonNegativeTimeVaryingScalarStep = (
torax_pydantic.ValidatedDefault(1e-1)
torax_pydantic.ValidatedDefault(0.1)
)
adaptive_dt: Annotated[bool, torax_pydantic.JAX_STATIC] = True
dt_reduction_factor: pydantic.PositiveFloat = 3.0
Expand All @@ -135,7 +136,10 @@ class Numerics(torax_pydantic.BaseModelFrozen):
)
adaptive_T_source_prefactor: pydantic.PositiveFloat = 2.0e10
adaptive_n_source_prefactor: pydantic.PositiveFloat = 2.0e8
dW_dt_smoothing_time_scale: pydantic.NonNegativeFloat = 0.3
dW_dt_window: torax_pydantic.NonNegativeTimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.01)
)
dW_dt_buffer_length: Annotated[int, torax_pydantic.JAX_STATIC] = 10
min_rho_norm: torax_pydantic.UnitInterval = 0.015

T_minimum_eV: pydantic.PositiveFloat = 5.0
Expand Down Expand Up @@ -182,7 +186,8 @@ def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams:
resistivity_multiplier=self.resistivity_multiplier.get_value(t),
adaptive_T_source_prefactor=self.adaptive_T_source_prefactor,
adaptive_n_source_prefactor=self.adaptive_n_source_prefactor,
dW_dt_smoothing_time_scale=self.dW_dt_smoothing_time_scale,
dW_dt_window=self.dW_dt_window.get_value(t),
dW_dt_buffer_length=self.dW_dt_buffer_length,
min_rho_norm=self.min_rho_norm,
evolve_ion_heat=self.evolve_ion_heat,
evolve_electron_heat=self.evolve_electron_heat,
Expand Down
9 changes: 8 additions & 1 deletion torax/_src/core_profiles/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def initial_core_profiles(
# core profiles refactor.
core_profiles = dataclasses.replace(
core_profiles,
internal_plasma_energy=_initialise_internal_energy(core_profiles, geo),
internal_plasma_energy=_initialise_internal_energy(
runtime_params, core_profiles, geo
),
)

return _init_psi_and_psi_derived(
Expand All @@ -161,6 +163,7 @@ def initial_core_profiles(


def _initialise_internal_energy(
runtime_params: runtime_params_lib.RuntimeParams,
core_profiles: state.CoreProfiles,
geo: geometry.Geometry,
) -> state.PlasmaInternalEnergy:
Expand All @@ -173,6 +176,7 @@ def _initialise_internal_energy(
geo,
)
)
N = runtime_params.numerics.dW_dt_buffer_length
return state.PlasmaInternalEnergy(
W_thermal_i=W_thermal_i,
W_thermal_e=W_thermal_e,
Expand All @@ -181,6 +185,9 @@ def _initialise_internal_energy(
dW_thermal_e_dt=jnp.array(0.0, dtype=jax_utils.get_dtype()),
dW_thermal_i_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
dW_thermal_e_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
W_thermal_i_history=jnp.full((N,), W_thermal_i),
W_thermal_e_history=jnp.full((N,), W_thermal_e),
t_history=jnp.full((N,), runtime_params.numerics.t_initial),
)


Expand Down
146 changes: 145 additions & 1 deletion torax/_src/core_profiles/tests/updaters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
Expand All @@ -26,6 +26,7 @@
from torax._src.core_profiles import updaters
from torax._src.fvm import cell_variable
from torax._src.geometry import circular_geometry
from torax._src.physics import formulas
from torax._src.test_utils import default_configs
from torax._src.torax_pydantic import model_config

Expand All @@ -50,12 +51,33 @@ def setUp(self):
right_face_constraint=1.0,
right_face_grad_constraint=None,
)
pressure_thermal_e = cell_variable.CellVariable(
value=jnp.ones_like(self.geo.rho_norm),
face_centers=self.geo.rho_face_norm,
right_face_constraint=1.0,
right_face_grad_constraint=None,
)
pressure_thermal_i = cell_variable.CellVariable(
value=jnp.ones_like(self.geo.rho_norm),
face_centers=self.geo.rho_face_norm,
right_face_constraint=1.0,
right_face_grad_constraint=None,
)
pressure_thermal_total = cell_variable.CellVariable(
value=jnp.ones_like(self.geo.rho_norm),
face_centers=self.geo.rho_face_norm,
right_face_constraint=1.0,
right_face_grad_constraint=None,
)

self.core_profiles_t = mock.create_autospec(
state.CoreProfiles,
instance=True,
T_e=T_e,
n_e=n_e,
pressure_thermal_e=pressure_thermal_e,
pressure_thermal_i=pressure_thermal_i,
pressure_thermal_total=pressure_thermal_total,
)

@parameterized.named_parameters(
Expand Down Expand Up @@ -216,6 +238,128 @@ def test_psi_not_updated_if_evolve_current_true(self):
# Since it wasn't updated in provide_..., it should remain 10.0 here.
np.testing.assert_allclose(core_profiles_t1.psi.value, 10.0)

def test_update_energy_state(self):
"""Tests that energy state is updated correctly."""
config = default_configs.get_default_config_dict()
torax_config = model_config.ToraxConfig.from_dict(config)
provider = build_runtime_params.RuntimeParamsProvider.from_config(
torax_config
)
runtime_params = provider(t=0.0)

energy_state_t = initialization._initialise_internal_energy(
runtime_params, self.core_profiles_t, self.geo
)

# Pretend we take a step of half the window size and increase the electron
# and ion pressure.
mock_dt = runtime_params.numerics.dW_dt_window / 2.0
core_profiles_t_plus_dt = copy.deepcopy(self.core_profiles_t)
core_profiles_t_plus_dt.pressure_thermal_e = cell_variable.CellVariable(
value=jnp.full_like(self.geo.rho_norm, 2.0),
face_centers=self.geo.rho_face_norm,
right_face_constraint=1.0,
right_face_grad_constraint=None,
)
core_profiles_t_plus_dt.pressure_thermal_i = cell_variable.CellVariable(
value=jnp.full_like(self.geo.rho_norm, 3.0),
face_centers=self.geo.rho_face_norm,
right_face_constraint=1.0,
right_face_grad_constraint=None,
)
W_thermal_e_t_plus_dt, W_thermal_i_t_plus_dt, _ = (
formulas.calculate_stored_thermal_energy(
core_profiles_t_plus_dt.pressure_thermal_e,
core_profiles_t_plus_dt.pressure_thermal_i,
core_profiles_t_plus_dt.pressure_thermal_total,
self.geo,
)
)

# Get the new energy state from the update function.
energy_state_t_plus_dt = updaters._update_energy_state(
runtime_params,
self.geo,
core_profiles_t_plus_dt,
energy_state_t,
mock_dt,
)

# Check that the time history is updated.
expected_t_history = jnp.concatenate([
energy_state_t.t_history[1:],
jnp.atleast_1d(energy_state_t.t_history[-1]) + mock_dt,
])
np.testing.assert_allclose(
energy_state_t_plus_dt.t_history, expected_t_history
)

# Check that the W_thermal history is updated.
expected_W_thermal_i_history = jnp.concatenate([
energy_state_t.W_thermal_i_history[1:],
jnp.atleast_1d(W_thermal_i_t_plus_dt),
])
np.testing.assert_allclose(
energy_state_t_plus_dt.W_thermal_i_history, expected_W_thermal_i_history
)
expected_W_thermal_e_history = jnp.concatenate([
energy_state_t.W_thermal_e_history[1:],
jnp.atleast_1d(W_thermal_e_t_plus_dt),
])
np.testing.assert_allclose(
energy_state_t_plus_dt.W_thermal_e_history, expected_W_thermal_e_history
)

# Check that the dW_dt is calculated correctly.
np.testing.assert_allclose(
energy_state_t_plus_dt.dW_thermal_i_dt,
(W_thermal_i_t_plus_dt - energy_state_t.W_thermal_i) / mock_dt,
)
np.testing.assert_allclose(
energy_state_t_plus_dt.dW_thermal_e_dt,
(W_thermal_e_t_plus_dt - energy_state_t.W_thermal_e) / mock_dt,
)

# As we took a step of half the window size, the smoothed dW_dt should be
# the same as the un-smoothed dW_dt.
np.testing.assert_allclose(
energy_state_t_plus_dt.dW_thermal_i_dt_smoothed,
energy_state_t_plus_dt.dW_thermal_i_dt,
)
np.testing.assert_allclose(
energy_state_t_plus_dt.dW_thermal_e_dt_smoothed,
energy_state_t_plus_dt.dW_thermal_e_dt,
)

# Take another step and check the dW_dt.
energy_state_t_plus_2dt = updaters._update_energy_state(
runtime_params,
self.geo,
core_profiles_t_plus_dt,
energy_state_t_plus_dt,
mock_dt,
)
# Raw dW_dt values should be zero as we haven't changed the pressures.
np.testing.assert_allclose(
energy_state_t_plus_2dt.dW_thermal_i_dt,
0.0,
)
np.testing.assert_allclose(
energy_state_t_plus_2dt.dW_thermal_e_dt,
0.0,
)
# Smoothed dW_dt values should be computed vs the 0th state
np.testing.assert_allclose(
energy_state_t_plus_2dt.dW_thermal_i_dt_smoothed,
(energy_state_t_plus_2dt.W_thermal_i - energy_state_t.W_thermal_i)
/ (runtime_params.numerics.dW_dt_window),
)
np.testing.assert_allclose(
energy_state_t_plus_2dt.dW_thermal_e_dt_smoothed,
(energy_state_t_plus_2dt.W_thermal_e - energy_state_t.W_thermal_e)
/ (runtime_params.numerics.dW_dt_window),
)


if __name__ == '__main__':
absltest.main()
60 changes: 34 additions & 26 deletions torax/_src/core_profiles/updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import jax
import jax.numpy as jnp
from torax._src import array_typing
from torax._src import jax_utils
from torax._src import state
from torax._src.config import runtime_params as runtime_params_lib
from torax._src.core_profiles import convertors
Expand Down Expand Up @@ -372,37 +371,46 @@ def _update_energy_state(
geo,
)
)
dW_i_dt_raw = (W_thermal_i - prev_energy_state.W_thermal_i) / dt
dW_e_dt_raw = (W_thermal_e - prev_energy_state.W_thermal_e) / dt

exponential_smoothing_alpha = jax.lax.cond(
runtime_params.numerics.dW_dt_smoothing_time_scale > 0.0,
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype())
- jnp.exp(-dt / runtime_params.numerics.dW_dt_smoothing_time_scale),
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype()),
)
dW_i_dt_smoothed = _exponential_smoothing(
dW_i_dt_raw,
prev_energy_state.dW_thermal_i_dt_smoothed,
exponential_smoothing_alpha,
)
dW_e_dt_smoothed = _exponential_smoothing(
dW_e_dt_raw,
prev_energy_state.dW_thermal_e_dt_smoothed,
exponential_smoothing_alpha,

# Raw values: dW/dt over the last timestep.
dW_i_dt = (W_thermal_i - prev_energy_state.W_thermal_i_history[-1]) / dt
dW_e_dt = (W_thermal_e - prev_energy_state.W_thermal_e_history[-1]) / dt

# Smoothed values: dW/dt over the last window.
current_t = prev_energy_state.t_history[-1] + dt
t_target = current_t - runtime_params.numerics.dW_dt_window

# If t_target is before the beginning of the history, use the first element.
# Otherwise use closest element to t_target.
idx = jnp.maximum(
0, jnp.searchsorted(prev_energy_state.t_history, t_target) - 1
)
dW_i_dt_smoothed = (
W_thermal_i - prev_energy_state.W_thermal_i_history[idx]
) / (current_t - prev_energy_state.t_history[idx])
dW_e_dt_smoothed = (
W_thermal_e - prev_energy_state.W_thermal_e_history[idx]
) / (current_t - prev_energy_state.t_history[idx])

# Update history arrays.
W_i_hist_new = jnp.roll(prev_energy_state.W_thermal_i_history, -1)
W_i_hist_new = W_i_hist_new.at[-1].set(W_thermal_i)

W_e_hist_new = jnp.roll(prev_energy_state.W_thermal_e_history, -1)
W_e_hist_new = W_e_hist_new.at[-1].set(W_thermal_e)

t_hist_new = jnp.roll(prev_energy_state.t_history, -1)
t_hist_new = t_hist_new.at[-1].set(current_t)

return state.PlasmaInternalEnergy(
W_thermal_i=W_thermal_i,
W_thermal_e=W_thermal_e,
W_thermal_total=W_thermal_total,
dW_thermal_i_dt=dW_i_dt_raw,
dW_thermal_e_dt=dW_e_dt_raw,
dW_thermal_i_dt=dW_i_dt,
dW_thermal_e_dt=dW_e_dt,
dW_thermal_i_dt_smoothed=dW_i_dt_smoothed,
dW_thermal_e_dt_smoothed=dW_e_dt_smoothed,
W_thermal_i_history=W_i_hist_new,
W_thermal_e_history=W_e_hist_new,
t_history=t_hist_new,
)


def _exponential_smoothing(new_raw, old_smoothed, alpha):
"""Exponential moving average (EMA)."""
return (1.0 - alpha) * old_smoothed + alpha * new_raw
Loading
Loading