diff --git a/torax/_src/fvm/cell_variable.py b/torax/_src/fvm/cell_variable.py index 501bb2694..fad348960 100644 --- a/torax/_src/fvm/cell_variable.py +++ b/torax/_src/fvm/cell_variable.py @@ -272,6 +272,7 @@ def constrained_grad( right = jnp.expand_dims(right_grad, axis=0) return jnp.concatenate([left, inner_grad, right]) + @functools.cached_property def left_face_value(self) -> jt.Float[chex.Array, '']: """Calculates the value of the leftmost face.""" if self.left_face_constraint is not None: @@ -311,7 +312,7 @@ def face_value(self) -> jt.Float[chex.Array, 'face']: ) return jnp.concatenate( - [self.left_face_value(), inner, self.right_face_value], axis=-1 + [self.left_face_value, inner, self.right_face_value], axis=-1 ) def grad(self) -> jt.Float[chex.Array, 'cell']: @@ -338,9 +339,6 @@ def __str__(self) -> str: def cell_plus_boundaries(self) -> jt.Float[chex.Array, 'cell+2']: """Returns the value of this variable plus left and right boundaries.""" - right_value = self.right_face_value - left_value = self.left_face_value() return jnp.concatenate( - [left_value, self.value, right_value], - axis=-1, + [self.left_face_value, self.value, self.right_face_value], axis=-1 ) diff --git a/torax/_src/imas_tools/output/core_profiles.py b/torax/_src/imas_tools/output/core_profiles.py index 9d1e9f230..792799225 100644 --- a/torax/_src/imas_tools/output/core_profiles.py +++ b/torax/_src/imas_tools/output/core_profiles.py @@ -246,11 +246,11 @@ def _fill_profiles_1d_grid( [[0.0], geometry_slice.rho, [geometry_slice.rho_b]] ) ids.profiles_1d[i].grid.psi = cp_state.psi.cell_plus_boundaries() - ids.profiles_1d[i].grid.psi_magnetic_axis = cp_state.psi.left_face_value()[0] + ids.profiles_1d[i].grid.psi_magnetic_axis = cp_state.psi.left_face_value[0] ids.profiles_1d[i].grid.psi_boundary = cp_state.psi.right_face_value[0] ids.profiles_1d[i].grid.rho_pol_norm = np.sqrt( - (cp_state.psi.cell_plus_boundaries() - cp_state.psi.left_face_value()[0]) - / (cp_state.psi.right_face_value[0] - cp_state.psi.left_face_value()[0]) + (cp_state.psi.cell_plus_boundaries() - cp_state.psi.left_face_value[0]) + / (cp_state.psi.right_face_value[0] - cp_state.psi.left_face_value[0]) ) ids.profiles_1d[i].grid.volume = output.extend_cell_grid_to_boundaries( [geometry_slice.volume], np.array([geometry_slice.volume_face])