@@ -272,6 +272,7 @@ def constrained_grad(
272272 right = jnp .expand_dims (right_grad , axis = 0 )
273273 return jnp .concatenate ([left , inner_grad , right ])
274274
275+ @property
275276 def left_face_value (self ) -> jt .Float [chex .Array , '' ]:
276277 """Calculates the value of the leftmost face."""
277278 if self .left_face_constraint is not None :
@@ -284,7 +285,7 @@ def left_face_value(self) -> jt.Float[chex.Array, '']:
284285 value = self .value [..., 0 :1 ]
285286 return value
286287
287- @functools . cached_property
288+ @property
288289 def right_face_value (self ) -> jt .Float [chex .Array , '' ]:
289290 """Calculates the value of the rightmost face."""
290291 if self .right_face_constraint is not None :
@@ -311,7 +312,7 @@ def face_value(self) -> jt.Float[chex.Array, 'face']:
311312 )
312313
313314 return jnp .concatenate (
314- [self .left_face_value () , inner , self .right_face_value ], axis = - 1
315+ [self .left_face_value , inner , self .right_face_value ], axis = - 1
315316 )
316317
317318 def grad (self ) -> jt .Float [chex .Array , 'cell' ]:
@@ -338,9 +339,6 @@ def __str__(self) -> str:
338339
339340 def cell_plus_boundaries (self ) -> jt .Float [chex .Array , 'cell+2' ]:
340341 """Returns the value of this variable plus left and right boundaries."""
341- right_value = self .right_face_value
342- left_value = self .left_face_value ()
343342 return jnp .concatenate (
344- [left_value , self .value , right_value ],
345- axis = - 1 ,
343+ [self .left_face_value , self .value , self .right_face_value ], axis = - 1
346344 )
0 commit comments