diff --git a/torax/_src/core_profiles/profile_conditions.py b/torax/_src/core_profiles/profile_conditions.py index efd9764c4..e8fc5c063 100644 --- a/torax/_src/core_profiles/profile_conditions.py +++ b/torax/_src/core_profiles/profile_conditions.py @@ -22,6 +22,7 @@ import numpy as np import pydantic from torax._src import array_typing +from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib from torax._src.torax_pydantic import torax_pydantic from typing_extensions import Self @@ -66,6 +67,9 @@ class RuntimeParams: n_e_nbar_is_fGW: bool n_e_right_bc: array_typing.FloatScalar n_e_right_bc_is_fGW: bool + internal_boundary_conditions: ( + internal_boundary_conditions_lib.InternalBoundaryConditions + ) current_profile_nu: float initial_j_is_total_current: bool = dataclasses.field( metadata={'static': True} @@ -186,6 +190,11 @@ class ProfileConditions(torax_pydantic.BaseModelFrozen): n_e_nbar_is_fGW: bool = False n_e_right_bc: torax_pydantic.TimeVaryingScalar | None = None n_e_right_bc_is_fGW: bool = False + internal_boundary_conditions: ( + internal_boundary_conditions_lib.InternalBoundaryConditionsConfig + ) = torax_pydantic.ValidatedDefault( + internal_boundary_conditions_lib.InternalBoundaryConditionsConfig() + ) current_profile_nu: float = 1.0 initial_j_is_total_current: Annotated[bool, torax_pydantic.JAX_STATIC] = False # TODO(b/434175938): Remove this before the V2 API release in place of @@ -428,9 +437,18 @@ def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams: else: runtime_params['n_e_right_bc_is_absolute'] = True + runtime_params['internal_boundary_conditions'] = ( + self.internal_boundary_conditions.build_runtime_params(t) + ) + def _get_value(x): if isinstance( - x, (torax_pydantic.TimeVaryingScalar, torax_pydantic.TimeVaryingArray) + x, + ( + torax_pydantic.TimeVaryingScalar, + torax_pydantic.TimeVaryingArray, + torax_pydantic.TimeVaryingPoints, + ), ): return x.get_value(t) else: diff --git a/torax/_src/core_profiles/tests/profile_conditions_test.py b/torax/_src/core_profiles/tests/profile_conditions_test.py index 665b14056..39eb7610d 100644 --- a/torax/_src/core_profiles/tests/profile_conditions_test.py +++ b/torax/_src/core_profiles/tests/profile_conditions_test.py @@ -521,6 +521,32 @@ def test_multiple_validation_errors(self): with self.assertRaisesRegex(ValueError, '3 errors were found'): profile_conditions.ProfileConditions(**config_overrides) + def test_internal_boundary_conditions_set_correctly(self): + """Tests that internal_boundary_conditions is populated from config.""" + config = default_configs.get_default_config_dict() + config['profile_conditions'] = { + 'internal_boundary_conditions': { + 'T_i': { + 0.0: {0: 1.0, 1: 2.0}, + 1.0: {0: 3.0, 1: 4.0}, + }, + }, + } + torax_config = model_config.ToraxConfig.from_dict(config) + runtime_params_provider = ( + build_runtime_params.RuntimeParamsProvider.from_config(torax_config) + ) + + runtime_params = runtime_params_provider(t=0.0) + self.assertIsNotNone( + runtime_params.profile_conditions.internal_boundary_conditions + ) + # Basic check to ensure the config was actually used. + np.testing.assert_array_equal( + runtime_params.profile_conditions.internal_boundary_conditions.T_i, + np.array([1.0, 0.0, 0.0, 2.0]), + ) + if __name__ == '__main__': absltest.main() diff --git a/torax/_src/fvm/calc_coeffs.py b/torax/_src/fvm/calc_coeffs.py index 552193cf2..2d0a4cdfa 100644 --- a/torax/_src/fvm/calc_coeffs.py +++ b/torax/_src/fvm/calc_coeffs.py @@ -13,6 +13,7 @@ # limitations under the License. """Calculates Block1DCoeffs for a time step.""" + import functools import jax @@ -26,7 +27,7 @@ from torax._src.fvm import block_1d_coeffs from torax._src.fvm import cell_variable from torax._src.geometry import geometry -from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib +from torax._src.internal_boundary_conditions import adaptive_source from torax._src.pedestal_model import pedestal_model as pedestal_model_lib from torax._src.sources import source_profile_builders from torax._src.sources import source_profiles as source_profiles_lib @@ -298,6 +299,9 @@ def _calc_coeffs_full( pedestal_model_output = physics_models.pedestal_model( runtime_params, geo, core_profiles ) + internal_boundary_conditions_from_config = ( + runtime_params.profile_conditions.internal_boundary_conditions + ) conductivity = ( physics_models.neoclassical_models.conductivity.calculate_conductivity( @@ -521,7 +525,15 @@ def _calc_coeffs_full( * core_profiles.psi.grad() ) - # Add internal boundary condition source terms + # Add internal boundary condition source terms, combining user-specified + # boundary conditions with pedestal model output. + # Note that the pedestal model will overwrite any user-specified boundary + # conditions, since the pedestal model is applied last. + combined_internal_boundary_conditions = ( + internal_boundary_conditions_from_config.update( + pedestal_model_output.to_internal_boundary_conditions(geo) + ) + ) ( source_i, source_e, @@ -529,7 +541,7 @@ def _calc_coeffs_full( source_mat_ii, source_mat_ee, source_mat_nn, - ) = internal_boundary_conditions_lib.apply_adaptive_source( + ) = adaptive_source.apply_adaptive_source( source_T_i=source_i, source_T_e=source_e, source_n_e=source_n_e, @@ -537,11 +549,7 @@ def _calc_coeffs_full( source_mat_ee=source_mat_ee, source_mat_nn=source_mat_nn, runtime_params=runtime_params, - # Pedestal contributes an internal boundary condition to the source - # terms at the pedestal top. - internal_boundary_conditions=pedestal_model_output.to_internal_boundary_conditions( - geo - ), + internal_boundary_conditions=combined_internal_boundary_conditions, ) # Build arguments to solver based on which variables are evolving diff --git a/torax/_src/internal_boundary_conditions/adaptive_source.py b/torax/_src/internal_boundary_conditions/adaptive_source.py new file mode 100644 index 000000000..7fed9daba --- /dev/null +++ b/torax/_src/internal_boundary_conditions/adaptive_source.py @@ -0,0 +1,83 @@ +# Copyright 2026 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adaptive source for internal boundary conditions.""" + +import jax.numpy as jnp +from torax._src import array_typing +from torax._src.config import runtime_params as runtime_params_lib +from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib + + +def apply_adaptive_source( + *, + source_T_i: array_typing.FloatVectorCell, + source_T_e: array_typing.FloatVectorCell, + source_n_e: array_typing.FloatVectorCell, + source_mat_ii: array_typing.FloatVectorCell, + source_mat_ee: array_typing.FloatVectorCell, + source_mat_nn: array_typing.FloatVectorCell, + runtime_params: runtime_params_lib.RuntimeParams, + internal_boundary_conditions: internal_boundary_conditions_lib.InternalBoundaryConditions, +) -> tuple[ + array_typing.FloatVectorCell, + array_typing.FloatVectorCell, + array_typing.FloatVectorCell, + array_typing.FloatVectorCell, + array_typing.FloatVectorCell, + array_typing.FloatVectorCell, +]: + """Applies an adaptive source to the source profiles to set internal boundary conditions.""" + + # Ion temperature + source_T_i += ( + runtime_params.numerics.adaptive_T_source_prefactor + * internal_boundary_conditions.T_i + ) + source_mat_ii -= jnp.where( + internal_boundary_conditions.T_i != 0.0, + runtime_params.numerics.adaptive_T_source_prefactor, + 0.0, + ) + + # Electron temperature + source_T_e += ( + runtime_params.numerics.adaptive_T_source_prefactor + * internal_boundary_conditions.T_e + ) + source_mat_ee -= jnp.where( + internal_boundary_conditions.T_e != 0.0, + runtime_params.numerics.adaptive_T_source_prefactor, + 0.0, + ) + + # Density + source_n_e += ( + runtime_params.numerics.adaptive_n_source_prefactor + * internal_boundary_conditions.n_e + ) + source_mat_nn -= jnp.where( + internal_boundary_conditions.n_e != 0.0, + runtime_params.numerics.adaptive_n_source_prefactor, + 0.0, + ) + + return ( + source_T_i, + source_T_e, + source_n_e, + source_mat_ii, + source_mat_ee, + source_mat_nn, + ) diff --git a/torax/_src/internal_boundary_conditions/internal_boundary_conditions.py b/torax/_src/internal_boundary_conditions/internal_boundary_conditions.py index 359a4f71b..112e7368c 100644 --- a/torax/_src/internal_boundary_conditions/internal_boundary_conditions.py +++ b/torax/_src/internal_boundary_conditions/internal_boundary_conditions.py @@ -16,13 +16,13 @@ import dataclasses +import chex import jax import jax.numpy as jnp from torax._src import array_typing from torax._src import jax_utils -from torax._src.config import runtime_params as runtime_params_lib from torax._src.geometry import geometry - +from torax._src.torax_pydantic import torax_pydantic # pylint: disable=invalid-name @@ -72,64 +72,18 @@ def empty(cls, geo: geometry.Geometry) -> 'InternalBoundaryConditions': ) -def apply_adaptive_source( - *, - source_T_i: array_typing.FloatVectorCell, - source_T_e: array_typing.FloatVectorCell, - source_n_e: array_typing.FloatVectorCell, - source_mat_ii: array_typing.FloatVectorCell, - source_mat_ee: array_typing.FloatVectorCell, - source_mat_nn: array_typing.FloatVectorCell, - runtime_params: runtime_params_lib.RuntimeParams, - internal_boundary_conditions: InternalBoundaryConditions, -) -> tuple[ - array_typing.FloatVectorCell, - array_typing.FloatVectorCell, - array_typing.FloatVectorCell, - array_typing.FloatVectorCell, - array_typing.FloatVectorCell, - array_typing.FloatVectorCell, -]: - """Applies an adaptive source to the source profiles to set internal boundary conditions.""" - - # Ion temperature - source_T_i += ( - runtime_params.numerics.adaptive_T_source_prefactor - * internal_boundary_conditions.T_i - ) - source_mat_ii -= jnp.where( - internal_boundary_conditions.T_i != 0.0, - runtime_params.numerics.adaptive_T_source_prefactor, - 0.0, - ) - - # Electron temperature - source_T_e += ( - runtime_params.numerics.adaptive_T_source_prefactor - * internal_boundary_conditions.T_e - ) - source_mat_ee -= jnp.where( - internal_boundary_conditions.T_e != 0.0, - runtime_params.numerics.adaptive_T_source_prefactor, - 0.0, - ) +class InternalBoundaryConditionsConfig(torax_pydantic.BaseModelFrozen): + """Pydantic model for internal boundary conditions.""" - # Density - source_n_e += ( - runtime_params.numerics.adaptive_n_source_prefactor - * internal_boundary_conditions.n_e - ) - source_mat_nn -= jnp.where( - internal_boundary_conditions.n_e != 0.0, - runtime_params.numerics.adaptive_n_source_prefactor, - 0.0, - ) + # Set to zero by default, which is ignored by the adaptive source. + T_i: torax_pydantic.TimeVaryingPoints = torax_pydantic.ValidatedDefault(0.0) + T_e: torax_pydantic.TimeVaryingPoints = torax_pydantic.ValidatedDefault(0.0) + n_e: torax_pydantic.TimeVaryingPoints = torax_pydantic.ValidatedDefault(0.0) - return ( - source_T_i, - source_T_e, - source_n_e, - source_mat_ii, - source_mat_ee, - source_mat_nn, - ) + def build_runtime_params(self, t: chex.Numeric) -> InternalBoundaryConditions: + """Builds the runtime params for the internal boundary conditions.""" + kwargs = { + field.name: getattr(self, field.name).get_value(t) + for field in dataclasses.fields(InternalBoundaryConditions) + } + return InternalBoundaryConditions(**kwargs) diff --git a/torax/_src/internal_boundary_conditions/tests/internal_boundary_conditions_test.py b/torax/_src/internal_boundary_conditions/tests/internal_boundary_conditions_test.py index 6a5a86a64..d79339fb0 100644 --- a/torax/_src/internal_boundary_conditions/tests/internal_boundary_conditions_test.py +++ b/torax/_src/internal_boundary_conditions/tests/internal_boundary_conditions_test.py @@ -13,14 +13,50 @@ # limitations under the License. from absl.testing import absltest +from absl.testing import parameterized import jax.numpy as jnp import numpy as np +from torax._src.geometry import circular_geometry +from torax._src.internal_boundary_conditions import adaptive_source from torax._src.internal_boundary_conditions import internal_boundary_conditions +from torax._src.torax_pydantic import torax_pydantic # pylint: disable=invalid-name -class InternalBoundaryConditionsTest(absltest.TestCase): +class InternalBoundaryConditionsTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('initial_time', 0.0, [1.0, 0.0, 0.0, 2.0]), + ('intermediate_time', 0.5, [2.0, 0.0, 0.0, 3.0]), + ('final_time', 1.0, [3.0, 0.0, 0.0, 4.0]), + ('after_final_time', 1.5, [3.0, 0.0, 0.0, 4.0]), + ) + def test_internal_boundary_conditions_config_build_runtime_params( + self, t, expected_T_i + ): + ibc_config = internal_boundary_conditions.InternalBoundaryConditionsConfig( + T_i={ + 0.0: {0: 1.0, 1: 2.0}, + 1.0: {0: 3.0, 1: 4.0}, + }, + ) + geo = circular_geometry.CircularConfig(n_rho=4).build_geometry() + torax_pydantic.set_grid(ibc_config, geo.torax_mesh) + + runtime_params = ibc_config.build_runtime_params(t=t) + np.testing.assert_array_equal( + runtime_params.T_i, + np.array(expected_T_i), + ) + np.testing.assert_array_equal( + runtime_params.T_e, + np.array([0.0, 0.0, 0.0, 0.0]), + ) + np.testing.assert_array_equal( + runtime_params.n_e, + np.array([0.0, 0.0, 0.0, 0.0]), + ) def test_update(self): ibc1 = internal_boundary_conditions.InternalBoundaryConditions( @@ -69,7 +105,7 @@ class MockRuntimeParams: source_mat_ii, source_mat_ee, source_mat_nn, - ) = internal_boundary_conditions.apply_adaptive_source( + ) = adaptive_source.apply_adaptive_source( source_T_i=source_T_i, source_T_e=source_T_e, source_n_e=source_n_e, diff --git a/torax/_src/interpolated_param.py b/torax/_src/interpolated_param.py index 9e7b176d9..c877d607e 100644 --- a/torax/_src/interpolated_param.py +++ b/torax/_src/interpolated_param.py @@ -58,6 +58,7 @@ class InterpolationMode(enum.Enum): input greater than x_n. Options: + NONE: No interpolation. Only appropriate for constant parameters. PIECEWISE_LINEAR: Does piecewise-linear interpolation between the values provided. See numpy.interp for a longer description of how it works. (This uses JAX, but the behavior is the same.) @@ -65,12 +66,13 @@ class InterpolationMode(enum.Enum): x_k+1), the output will be y_k. """ + NONE = 'none' PIECEWISE_LINEAR = 'piecewise_linear' STEP = 'step' InterpolationModeLiteral: TypeAlias = Literal[ - 'step', 'STEP', 'piecewise_linear', 'PIECEWISE_LINEAR' + 'step', 'STEP', 'piecewise_linear', 'PIECEWISE_LINEAR', 'none', 'NONE' ] @@ -515,18 +517,49 @@ def __init__( self._rho_interpolation_mode = rho_interpolation_mode self._time_interpolation_mode = time_interpolation_mode - sorted_indices = np.array(sorted(values.keys())) - rho_norm_interpolated_values = np.stack( - [ - InterpolatedVarSingleAxis( - values[t], rho_interpolation_mode - ).get_value(rho_norm) - for t in sorted_indices - ], - axis=0, - ) + sorted_times = np.array(sorted(values.keys())) + if self._rho_interpolation_mode == InterpolationMode.NONE: + # If no rho interpolation is needed, the given rho_norm locations will be + # quantized onto the rho_norm grid, and the values are delta functions at + # the quantized rho_norm locations. + + # Check that the same rho_norm values are used for all times + given_rho_norm_locations = values[sorted_times[0]][0] + for t in sorted_times: + if not np.array_equal(values[t][0], given_rho_norm_locations): + raise ValueError( + 'When rho_interpolation_mode is InterpolationMode.NONE, the' + 'rho_norm locations must be the same for all times.' + ) + + # Quantize the rho_norm locations onto the grid, selecting the rho_norm + # grid point that is closest to the given rho_norm location. + quantized_rho_norm_indices = np.argmin( + np.abs(np.atleast_1d(rho_norm)[:, None] - given_rho_norm_locations), + axis=0, + ) + + # Convert the values to delta functions on the quantized rho_norm values. + rho_norm_interpolated_values = np.zeros( + (len(sorted_times), len(np.atleast_1d(rho_norm))), + dtype=jax_utils.get_np_dtype(), + ) + for t_idx in range(len(sorted_times)): + rho_norm_interpolated_values[t_idx, quantized_rho_norm_indices] = ( + values[sorted_times[t_idx]][1] + ) + else: + rho_norm_interpolated_values = np.stack( + [ + InterpolatedVarSingleAxis( + values[t], rho_interpolation_mode + ).get_value(rho_norm) + for t in sorted_times + ], + axis=0, + ) self._time_interpolated_var = InterpolatedVarSingleAxis( - value=(sorted_indices, rho_norm_interpolated_values), + value=(sorted_times, rho_norm_interpolated_values), interpolation_mode=time_interpolation_mode, ) diff --git a/torax/_src/torax_pydantic/interpolated_param_2d.py b/torax/_src/torax_pydantic/interpolated_param_2d.py index 08a60ff9f..c5b6d4be0 100644 --- a/torax/_src/torax_pydantic/interpolated_param_2d.py +++ b/torax/_src/torax_pydantic/interpolated_param_2d.py @@ -49,6 +49,7 @@ class Grid1D(model_base.BaseModelFrozen): of all faces (including boundary faces). For a grid with N cells, there are N+1 faces. """ + face_centers: pydantic_types.NumpyArray1DSorted @pydantic.model_validator(mode='before') @@ -444,6 +445,23 @@ def get_cached_interpolated_param_face_right( ) +class TimeVaryingPoints(TimeVaryingArray): + """A TimeVaryingArray that is defined on a fixed set of rho points, without interpolation in rho.""" + + rho_interpolation_mode: Literal[interpolated_param.InterpolationMode.NONE] = ( + interpolated_param.InterpolationMode.NONE + ) + + @pydantic.model_validator(mode='before') + @classmethod + def _conform_data( + cls, data: interpolated_param.TimeRhoInterpolatedInput | dict[str, Any] + ) -> dict[str, Any]: + data = super()._conform_data(data) + data['rho_interpolation_mode'] = interpolated_param.InterpolationMode.NONE + return data + + def _is_positive(array: TimeVaryingArray) -> TimeVaryingArray: for _, value in array.value.values(): if not np.all(value > 0): diff --git a/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py b/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py index 9c54361ab..4f425fb32 100644 --- a/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py +++ b/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py @@ -615,6 +615,29 @@ def f( np.testing.assert_allclose(cell, [8.0, 10.0, 12.0, 14.0]) self.assertEqual(jax_utils.get_number_of_compiles(f), 1) + def test_time_varying_points(self): + time_rho_interpolated_input = ( + np.array([0.0, 1.0]), # time + np.array([0.0, 1.0]), # rho_norm + np.array([[1.0, 2.0], [3.0, 4.0]]), # values + ) + tvp = interpolated_param_2d.TimeVaryingPoints.model_validate( + time_rho_interpolated_input + ) + face_centers = interpolated_param_2d.get_face_centers(4) + grid = interpolated_param_2d.Grid1D(face_centers=face_centers) + interpolated_param_2d.set_grid(tvp, grid=grid) + + np.testing.assert_array_equal( + tvp.get_value(0.0), np.array([1.0, 0.0, 0.0, 2.0]) + ) + np.testing.assert_array_equal( + tvp.get_value(0.5), np.array([2.0, 0.0, 0.0, 3.0]) + ) + np.testing.assert_array_equal( + tvp.get_value(1.0), np.array([3.0, 0.0, 0.0, 4.0]) + ) + if __name__ == '__main__': absltest.main() diff --git a/torax/_src/torax_pydantic/torax_pydantic.py b/torax/_src/torax_pydantic/torax_pydantic.py index 45a0a82b1..4faa30062 100644 --- a/torax/_src/torax_pydantic/torax_pydantic.py +++ b/torax/_src/torax_pydantic/torax_pydantic.py @@ -51,6 +51,7 @@ BaseModelFrozen = model_base.BaseModelFrozen TimeVaryingScalar = interpolated_param_1d.TimeVaryingScalar +TimeVaryingPoints = interpolated_param_2d.TimeVaryingPoints TimeVaryingArray = interpolated_param_2d.TimeVaryingArray NonNegativeTimeVaryingArray = interpolated_param_2d.NonNegativeTimeVaryingArray PositiveTimeVaryingScalar = interpolated_param_1d.PositiveTimeVaryingScalar diff --git a/torax/tests/sim_test.py b/torax/tests/sim_test.py index 71386126e..182485384 100644 --- a/torax/tests/sim_test.py +++ b/torax/tests/sim_test.py @@ -238,6 +238,11 @@ class SimTest(sim_test_case.SimTestCase): 'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_z_eff', 'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_z_eff.py', ), + # Tests time dependent internal boundary conditions. + ( + 'test_iterhybrid_predictor_corrector_internal_boundary', + 'test_iterhybrid_predictor_corrector_internal_boundary.py', + ), # Predictor-corrector solver with constant pressure pedestal model. ( 'test_iterhybrid_predictor_corrector_set_pped_tpedratio_nped', diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector_internal_boundary.nc b/torax/tests/test_data/test_iterhybrid_predictor_corrector_internal_boundary.nc new file mode 100644 index 000000000..23dea0a40 Binary files /dev/null and b/torax/tests/test_data/test_iterhybrid_predictor_corrector_internal_boundary.nc differ diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector_internal_boundary.py b/torax/tests/test_data/test_iterhybrid_predictor_corrector_internal_boundary.py new file mode 100644 index 000000000..a7ccab6b8 --- /dev/null +++ b/torax/tests/test_data/test_iterhybrid_predictor_corrector_internal_boundary.py @@ -0,0 +1,23 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests internal boundary conditions.""" + +import copy +from torax.tests.test_data import test_timedependence + +CONFIG = copy.deepcopy(test_timedependence.CONFIG) +CONFIG['profile_conditions']['internal_boundary_conditions'] = { + 'T_e': {0.0: {0.4: 7.0}, 2.0: {0.4: 10.0}, 5.0: {0.4: 15.0}}, +}