@@ -168,35 +168,52 @@ def get_read_request(
168168 dtype : np .dtype ,
169169 shape : Sequence [int ],
170170 sharding : jax .sharding .Sharding ,
171- devices : Sequence [jax .Device ],
171+ devices : Sequence [jax .Device ] | None ,
172172 timeout : datetime .timedelta ,
173173 return_dict : bool = False ,
174174) -> Union [str , dict [str , Any ]]:
175175 """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
176- if not isinstance (devices , np .ndarray ):
177- devices = np .array (devices )
178-
179176 timeout_seconds , timeout_fractional_seconds = divmod (
180177 timeout .total_seconds (), 1
181178 )
182179 timeout_nanoseconds = timeout_fractional_seconds * 1e9
183- d = {
184- "persistenceReadRequest" : {
185- "b64_location" : string_to_base64 (location_path ),
186- "shape" : get_shape_info (dtype , shape ),
187- "b64_name" : string_to_base64 (name ),
188- "b64_hlo_sharding_string" : get_hlo_sharding_string (
189- sharding , len (shape )
190- ),
191- "devices" : {
192- "device_ids" : [device .id for device in devices .flatten ()]
193- },
194- "timeout" : {
195- "seconds" : int (timeout_seconds ),
196- "nanos" : int (timeout_nanoseconds ),
197- },
198- }
199- }
180+
181+ if devices is None :
182+ d = {
183+ "persistenceReadRequest" : {
184+ "b64_location" : string_to_base64 (location_path ),
185+ "shape" : get_shape_info (dtype , shape ),
186+ "b64_name" : string_to_base64 (name ),
187+ "b64_hlo_sharding_string" : get_hlo_sharding_string (
188+ sharding , len (shape )
189+ ),
190+ "timeout" : {
191+ "seconds" : int (timeout_seconds ),
192+ "nanos" : int (timeout_nanoseconds ),
193+ },
194+ }
195+ }
196+ else :
197+ if not isinstance (devices , np .ndarray ):
198+ devices = np .array (devices )
199+
200+ d = {
201+ "persistenceReadRequest" : {
202+ "b64_location" : string_to_base64 (location_path ),
203+ "shape" : get_shape_info (dtype , shape ),
204+ "b64_name" : string_to_base64 (name ),
205+ "b64_hlo_sharding_string" : get_hlo_sharding_string (
206+ sharding , len (shape )
207+ ),
208+ "devices" : {
209+ "device_ids" : [device .id for device in devices .flatten ()]
210+ },
211+ "timeout" : {
212+ "seconds" : int (timeout_seconds ),
213+ "nanos" : int (timeout_nanoseconds ),
214+ },
215+ }
216+ }
200217
201218 if return_dict :
202219 return d
@@ -224,6 +241,38 @@ def get_bulk_read_request(
224241 )
225242
226243
def get_bulk_read_request_per_device_list(
    location_path: str,
    names: Sequence[str],
    dtypes: Sequence[np.dtype],
    shapes: Sequence[Sequence[int]],
    shardings: Sequence[jax.sharding.Sharding],
    devices: Sequence[jax.Device],
    timeout: datetime.timedelta,
) -> str:
  """Serializes a bulk read request in which all arrays share one device list.

  One per-array read request is built for every (name, dtype, shape, sharding)
  tuple by calling `get_read_request` with `devices=None`; the device ids are
  emitted once, next to the list of per-array requests, under
  "read_requests_per_device_list".

  Args:
    location_path: Location passed through to every per-array read request.
    names: Names of the arrays to read.
    dtypes: Dtype of each array, aligned with `names`.
    shapes: Shape of each array, aligned with `names`.
    shardings: Sharding of each array, aligned with `names`.
    devices: Devices whose ids form the shared device list. May be a nested
      sequence or a numpy array of devices; it is flattened before the ids are
      collected.
    timeout: Timeout passed through to every per-array read request.

  Returns:
    The JSON-encoded bulk read request.
  """
  # NOTE(review): zip() stops at the shortest input, so mismatched sequence
  # lengths are silently truncated -- confirm callers always pass equal lengths.
  per_array_requests = []
  for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings):
    single_request = get_read_request(
        location_path,
        name,
        dtype,
        shape,
        sharding,
        devices=None,
        timeout=timeout,
        return_dict=True,
    )
    per_array_requests.append(single_request["persistenceReadRequest"])

  # Normalize to an ndarray so arbitrarily nested device sequences can be
  # flattened uniformly.
  device_grid = devices if isinstance(devices, np.ndarray) else np.array(devices)
  shared_device_ids = [dev.id for dev in device_grid.flatten()]

  payload = {
      "bulk_persistence_read_request": {
          "read_requests_per_device_list": {
              "device_list": {"device_ids": shared_device_ids},
              "read_requests": per_array_requests,
          }
      }
  }
  return json.dumps(payload)
274+
275+
227276def write_one_array (
228277 location : str ,
229278 name : str ,
0 commit comments