Skip to content

Commit aa86e16

Browse files
mridul-sahucopybara-github
authored andcommitted
Create copies of Pathways arrays before starting the serialization in orbax handler.
This prevents the buffers from being deallocated if the user discards the array before the asynchronous save operation completes. PiperOrigin-RevId: 766470507
1 parent 873ec9a commit aa86e16

2 files changed

Lines changed: 6 additions & 172 deletions

File tree

pathwaysutils/elastic/README.md

Lines changed: 0 additions & 171 deletions
This file was deleted.

pathwaysutils/persistence/orbax_handler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,15 @@ async def serialize(
8787
if any([arg.dtype is not None for arg in args]):
8888
raise ValueError("Casting during save not supported for Pathways.")
8989

90+
# Create a copy of the arrays to ensure their buffers are not deallocated
91+
# before the asynchronous write operation completes.
92+
copied_values = [v.copy() for v in values]
93+
jax.block_until_ready(copied_values)
94+
9095
locations, names = extract_parent_dir_and_name(infos)
9196
return [
9297
future.CommitFutureAwaitingContractedSignals(
93-
self._background_serialize(values, locations, names),
98+
self._background_serialize(copied_values, locations, names),
9499
name="cloud_pathways_array_handler",
95100
)
96101
]

0 commit comments

Comments
 (0)