1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17+ import copy
1718from unittest import mock
18-
1919from absl .testing import absltest
2020from absl .testing import parameterized
2121from jax import numpy as jnp
2626from torax ._src .core_profiles import updaters
2727from torax ._src .fvm import cell_variable
2828from torax ._src .geometry import circular_geometry
29+ from torax ._src .physics import formulas
2930from torax ._src .test_utils import default_configs
3031from torax ._src .torax_pydantic import model_config
3132
@@ -50,12 +51,33 @@ def setUp(self):
5051 right_face_constraint = 1.0 ,
5152 right_face_grad_constraint = None ,
5253 )
54+ pressure_thermal_e = cell_variable .CellVariable (
55+ value = jnp .ones_like (self .geo .rho_norm ),
56+ face_centers = self .geo .rho_face_norm ,
57+ right_face_constraint = 1.0 ,
58+ right_face_grad_constraint = None ,
59+ )
60+ pressure_thermal_i = cell_variable .CellVariable (
61+ value = jnp .ones_like (self .geo .rho_norm ),
62+ face_centers = self .geo .rho_face_norm ,
63+ right_face_constraint = 1.0 ,
64+ right_face_grad_constraint = None ,
65+ )
66+ pressure_thermal_total = cell_variable .CellVariable (
67+ value = jnp .ones_like (self .geo .rho_norm ),
68+ face_centers = self .geo .rho_face_norm ,
69+ right_face_constraint = 1.0 ,
70+ right_face_grad_constraint = None ,
71+ )
5372
5473 self .core_profiles_t = mock .create_autospec (
5574 state .CoreProfiles ,
5675 instance = True ,
5776 T_e = T_e ,
5877 n_e = n_e ,
78+ pressure_thermal_e = pressure_thermal_e ,
79+ pressure_thermal_i = pressure_thermal_i ,
80+ pressure_thermal_total = pressure_thermal_total ,
5981 )
6082
6183 @parameterized .named_parameters (
@@ -216,6 +238,128 @@ def test_psi_not_updated_if_evolve_current_true(self):
216238 # Since it wasn't updated in provide_..., it should remain 10.0 here.
217239 np .testing .assert_allclose (core_profiles_t1 .psi .value , 10.0 )
218240
241+ def test_update_energy_state (self ):
242+ """Tests that energy state is updated correctly."""
243+ config = default_configs .get_default_config_dict ()
244+ torax_config = model_config .ToraxConfig .from_dict (config )
245+ provider = build_runtime_params .RuntimeParamsProvider .from_config (
246+ torax_config
247+ )
248+ runtime_params = provider (t = 0.0 )
249+
250+ energy_state_t = initialization ._initialise_internal_energy (
251+ runtime_params , self .core_profiles_t , self .geo
252+ )
253+
254+ # Pretend we take a step of half the window size and increase the electron
255+ # and ion pressure.
256+ mock_dt = runtime_params .numerics .dW_dt_window / 2.0
257+ core_profiles_t_plus_dt = copy .deepcopy (self .core_profiles_t )
258+ core_profiles_t_plus_dt .pressure_thermal_e = cell_variable .CellVariable (
259+ value = jnp .full_like (self .geo .rho_norm , 2.0 ),
260+ face_centers = self .geo .rho_face_norm ,
261+ right_face_constraint = 1.0 ,
262+ right_face_grad_constraint = None ,
263+ )
264+ core_profiles_t_plus_dt .pressure_thermal_i = cell_variable .CellVariable (
265+ value = jnp .full_like (self .geo .rho_norm , 3.0 ),
266+ face_centers = self .geo .rho_face_norm ,
267+ right_face_constraint = 1.0 ,
268+ right_face_grad_constraint = None ,
269+ )
270+ W_thermal_e_t_plus_dt , W_thermal_i_t_plus_dt , _ = (
271+ formulas .calculate_stored_thermal_energy (
272+ core_profiles_t_plus_dt .pressure_thermal_e ,
273+ core_profiles_t_plus_dt .pressure_thermal_i ,
274+ core_profiles_t_plus_dt .pressure_thermal_total ,
275+ self .geo ,
276+ )
277+ )
278+
279+ # Get the new energy state from the update function.
280+ energy_state_t_plus_dt = updaters ._update_energy_state (
281+ runtime_params ,
282+ self .geo ,
283+ core_profiles_t_plus_dt ,
284+ energy_state_t ,
285+ mock_dt ,
286+ )
287+
288+ # Check that the time history is updated.
289+ expected_t_history = jnp .concatenate ([
290+ energy_state_t .t_history [1 :],
291+ jnp .atleast_1d (energy_state_t .t_history [- 1 ]) + mock_dt ,
292+ ])
293+ np .testing .assert_allclose (
294+ energy_state_t_plus_dt .t_history , expected_t_history
295+ )
296+
297+ # Check that the W_thermal history is updated.
298+ expected_W_thermal_i_history = jnp .concatenate ([
299+ energy_state_t .W_thermal_i_history [1 :],
300+ jnp .atleast_1d (W_thermal_i_t_plus_dt ),
301+ ])
302+ np .testing .assert_allclose (
303+ energy_state_t_plus_dt .W_thermal_i_history , expected_W_thermal_i_history
304+ )
305+ expected_W_thermal_e_history = jnp .concatenate ([
306+ energy_state_t .W_thermal_e_history [1 :],
307+ jnp .atleast_1d (W_thermal_e_t_plus_dt ),
308+ ])
309+ np .testing .assert_allclose (
310+ energy_state_t_plus_dt .W_thermal_e_history , expected_W_thermal_e_history
311+ )
312+
313+ # Check that the dW_dt is calculated correctly.
314+ np .testing .assert_allclose (
315+ energy_state_t_plus_dt .dW_thermal_i_dt ,
316+ (W_thermal_i_t_plus_dt - energy_state_t .W_thermal_i ) / mock_dt ,
317+ )
318+ np .testing .assert_allclose (
319+ energy_state_t_plus_dt .dW_thermal_e_dt ,
320+ (W_thermal_e_t_plus_dt - energy_state_t .W_thermal_e ) / mock_dt ,
321+ )
322+
323+ # As we took a step of half the window size, the smoothed dW_dt should be
324+ # the same as the un-smoothed dW_dt.
325+ np .testing .assert_allclose (
326+ energy_state_t_plus_dt .dW_thermal_i_dt_smoothed ,
327+ energy_state_t_plus_dt .dW_thermal_i_dt ,
328+ )
329+ np .testing .assert_allclose (
330+ energy_state_t_plus_dt .dW_thermal_e_dt_smoothed ,
331+ energy_state_t_plus_dt .dW_thermal_e_dt ,
332+ )
333+
334+ # Take another step and check the dW_dt.
335+ energy_state_t_plus_2dt = updaters ._update_energy_state (
336+ runtime_params ,
337+ self .geo ,
338+ core_profiles_t_plus_dt ,
339+ energy_state_t_plus_dt ,
340+ mock_dt ,
341+ )
342+ # Raw dW_dt values should be zero as we haven't changed the pressures.
343+ np .testing .assert_allclose (
344+ energy_state_t_plus_2dt .dW_thermal_i_dt ,
345+ 0.0 ,
346+ )
347+ np .testing .assert_allclose (
348+ energy_state_t_plus_2dt .dW_thermal_e_dt ,
349+ 0.0 ,
350+ )
351+ # Smoothed dW_dt values should be computed vs the 0th state
352+ np .testing .assert_allclose (
353+ energy_state_t_plus_2dt .dW_thermal_i_dt_smoothed ,
354+ (energy_state_t_plus_2dt .W_thermal_i - energy_state_t .W_thermal_i )
355+ / (runtime_params .numerics .dW_dt_window ),
356+ )
357+ np .testing .assert_allclose (
358+ energy_state_t_plus_2dt .dW_thermal_e_dt_smoothed ,
359+ (energy_state_t_plus_2dt .W_thermal_e - energy_state_t .W_thermal_e )
360+ / (runtime_params .numerics .dW_dt_window ),
361+ )
362+
219363
220364if __name__ == '__main__' :
221365 absltest .main ()
0 commit comments