Skip to content

Commit 3118455

Browse files
theo-brownTorax team
authored andcommitted
Cache grid properties but not value properties of CellVariables.
The grid will be fixed for a whole simulation, so the cell properties can be cached. The left and right face values are not necessarily fixed for the whole simulation, so should not be cached. Also fixed the fact that right_face_value was a property and left_face_value was a method. PiperOrigin-RevId: 889163086
1 parent 77607a6 commit 3118455

2 files changed

Lines changed: 9 additions & 11 deletions

File tree

torax/_src/fvm/cell_variable.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,12 @@ def cell_centers(self) -> jt.Float[chex.Array, 'cell']:
139139
"""Locations of the cell centers."""
140140
return (self.face_centers[..., 1:] + self.face_centers[..., :-1]) / 2.0
141141

142-
@property
142+
@functools.cached_property
143143
def cell_widths(self) -> jt.Float[chex.Array, 'cell']:
144144
"""Size of each cell."""
145145
return jnp.diff(self.face_centers)
146146

147-
@property
147+
@functools.cached_property
148148
def cell_spacings(self) -> jt.Float[chex.Array, 'cell-1']:
149149
"""Spacing between each cell."""
150150
return jnp.diff(self.cell_centers)
@@ -272,6 +272,7 @@ def constrained_grad(
272272
right = jnp.expand_dims(right_grad, axis=0)
273273
return jnp.concatenate([left, inner_grad, right])
274274

275+
@property
275276
def left_face_value(self) -> jt.Float[chex.Array, '']:
276277
"""Calculates the value of the leftmost face."""
277278
if self.left_face_constraint is not None:
@@ -284,7 +285,7 @@ def left_face_value(self) -> jt.Float[chex.Array, '']:
284285
value = self.value[..., 0:1]
285286
return value
286287

287-
@functools.cached_property
288+
@property
288289
def right_face_value(self) -> jt.Float[chex.Array, '']:
289290
"""Calculates the value of the rightmost face."""
290291
if self.right_face_constraint is not None:
@@ -311,7 +312,7 @@ def face_value(self) -> jt.Float[chex.Array, 'face']:
311312
)
312313

313314
return jnp.concatenate(
314-
[self.left_face_value(), inner, self.right_face_value], axis=-1
315+
[self.left_face_value, inner, self.right_face_value], axis=-1
315316
)
316317

317318
def grad(self) -> jt.Float[chex.Array, 'cell']:
@@ -338,9 +339,6 @@ def __str__(self) -> str:
338339

339340
def cell_plus_boundaries(self) -> jt.Float[chex.Array, 'cell+2']:
340341
"""Returns the value of this variable plus left and right boundaries."""
341-
right_value = self.right_face_value
342-
left_value = self.left_face_value()
343342
return jnp.concatenate(
344-
[left_value, self.value, right_value],
345-
axis=-1,
343+
[self.left_face_value, self.value, self.right_face_value], axis=-1
346344
)

torax/_src/imas_tools/output/core_profiles.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,11 @@ def _fill_profiles_1d_grid(
246246
[[0.0], geometry_slice.rho, [geometry_slice.rho_b]]
247247
)
248248
ids.profiles_1d[i].grid.psi = cp_state.psi.cell_plus_boundaries()
249-
ids.profiles_1d[i].grid.psi_magnetic_axis = cp_state.psi.left_face_value()[0]
249+
ids.profiles_1d[i].grid.psi_magnetic_axis = cp_state.psi.left_face_value[0]
250250
ids.profiles_1d[i].grid.psi_boundary = cp_state.psi.right_face_value[0]
251251
ids.profiles_1d[i].grid.rho_pol_norm = np.sqrt(
252-
(cp_state.psi.cell_plus_boundaries() - cp_state.psi.left_face_value()[0])
253-
/ (cp_state.psi.right_face_value[0] - cp_state.psi.left_face_value()[0])
252+
(cp_state.psi.cell_plus_boundaries() - cp_state.psi.left_face_value[0])
253+
/ (cp_state.psi.right_face_value[0] - cp_state.psi.left_face_value[0])
254254
)
255255
ids.profiles_1d[i].grid.volume = output.extend_cell_grid_to_boundaries(
256256
[geometry_slice.volume], np.array([geometry_slice.volume_face])

0 commit comments

Comments
 (0)