@@ -63,6 +63,18 @@ def __init__(
6363 raise ValueError ("OCDBT not supported for Pathways." )
6464 super ().__init__ ()
6565
66+ async def _background_serialize (
67+ self ,
68+ values : Sequence [jax .Array ],
69+ locations : Sequence [str ],
70+ names : Sequence [str ],
71+ ) -> None :
72+ """Uses Pathways Persistence API to serialize a jax array."""
73+ f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
74+ futures_results = list (map (f , locations , names , values ))
75+ for future_result in futures_results :
76+ future_result .result ()
77+
6678 async def serialize (
6779 self ,
6880 values : Sequence [jax .Array ],
@@ -76,8 +88,12 @@ async def serialize(
7688 raise ValueError ("Casting during save not supported for Pathways." )
7789
7890 locations , names = extract_parent_dir_and_name (infos )
79- f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
80- return list (map (f , locations , names , values ))
91+ return [
92+ future .CommitFutureAwaitingContractedSignals (
93+ self ._background_serialize (values , locations , names ),
94+ name = "cloud_pathways_array_handler" ,
95+ )
96+ ]
8197
8298 async def deserialize (
8399 self ,
0 commit comments