@@ -205,9 +205,51 @@ def _colocated_cpu_mesh(mesh: Mesh) -> Mesh:
   return colocated_python.colocated_cpu_devices(mesh)


+@colocated_python.colocated_python_class
 class RemoteIterator:
-  "iterator class for using colocated python, iterator is initiated remotely and stored in the state of colocated python"
+  "Iterator that runs as a colocated Python class; the data iterator is created remotely and kept in instance state."
+
+  def __init__(self, get_ds_fn, preprocessing_fn, global_shape):
+    self.global_shape = global_shape
+    ds = get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count())
+    dataloader = preprocessing_fn(dataset=ds)
+    self.iterator = None
+    if isinstance(dataloader, tf.data.Dataset):
+      self.iterator = dataloader.as_numpy_iterator()
+    elif isinstance(dataloader, Iterable):
+      self.iterator = iter(dataloader)
+    else:
+      raise ValueError("Type error: dataloader should be either tf.data.Dataset or an Iterable.")

+  def get_next(self, dummy_array):
+    # dummy_array only conveys the target sharding and devices; its values are unused.
+    local_data = next(self.iterator)
+
+    def form_global_array_colocated_python(path, array, devices, global_shape, sharding):
+      try:
+        device_arrays = np.split(array, len(devices), axis=0)
+      except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to split array of shape {array.shape} across "
+            f"{len(devices)} local devices "
+            f"at {jtu.keystr(path)}"
+        ) from array_split_error
+      device_arrays = jax.device_put(device_arrays, devices)
+      return jax.make_array_from_single_device_arrays(shape=global_shape, sharding=sharding, arrays=device_arrays)
+
+    return jtu.tree_map_with_path(
+        partial(
+            form_global_array_colocated_python,
+            devices=list(dummy_array.sharding.addressable_devices),
+            global_shape=self.global_shape,
+            sharding=dummy_array.sharding,
+        ),
+        local_data,
+    )
+
+
+class RemoteIteratorWrapper:
+  """Host-side wrapper that builds the CPU dummy array and moves each remote batch onto the TPU mesh."""
   def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape):
     self.cpu_devices = _colocated_cpu_devices(jax.local_devices())
     self.tpu_devices = jax.local_devices()
@@ -216,42 +258,12 @@ def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape):
     self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names))
     self.dummy_array = jnp.zeros((len(self.cpu_devices)))
     self.dummy_array = jax.device_put(self.dummy_array, self.cpu_sharding)
-
-    @colocated_python.colocated_python
-    def init(dummy_array):
-      colocated_python.global_shape = global_shape
-      ds = get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count())
-      dataloader = preprocessing_fn(dataset=ds)
-      if isinstance(dataloader, tf.data.Dataset):
-        colocated_python.iterator = dataloader.as_numpy_iterator()
-      elif isinstance(dataloader, Iterable):
-        colocated_python.iterator = iter(dataloader)
-      else:
-        raise ValueError("Type error: dataloader should be either tf.data.Dataset or grain.DataLoader.")
-      return dummy_array
-
-    max_logging.log("Initiating RemoteIterator")
-    try:
-      out = jax.device_get(init(self.dummy_array))
-      if out is not None:
-        max_logging.log(f"RemoteIterator initiated. Test output: {out}")
-    except Exception as e:
-      max_logging.log(f"RemoteIterator init FAILED with error: {type(e).__name__}: {e}")
-      raise
-
-  def __iter__(self):
-    return self
+    self.remote_iterator = RemoteIterator(get_ds_fn, preprocessing_fn, global_shape)
+
+  def __iter__(self):
+    return self

   def __next__(self):
-    out = _get_next(self.dummy_array)
-
-    def put_to_tpu_devices(path, array, sharding):
-      try:
-        return jax.device_put(array, sharding)
-      except Exception as e:  # pylint: disable=broad-exception-caught
-        max_logging.log(f"Error putting data to TPU device path{path}, exception={e}")
-        raise
-
-    input_gdas = jtu.tree_map_with_path(partial(put_to_tpu_devices, sharding=self.tpu_sharding), out)
-
-    return input_gdas
+    out = self.remote_iterator.get_next(self.dummy_array)
+    # use tree_map if out is a dict
+    return jax.device_put(out, self.tpu_sharding)
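For reference, here is a minimal usage sketch of the new wrapper. It is not part of this commit: the dataset, batch size, and mesh layout are illustrative assumptions; only the constructor arguments and the get_ds_fn/preprocessing_fn keyword signatures are taken from the code above.

import jax
import tensorflow as tf
from jax.sharding import Mesh

def get_ds_fn(dataloading_host_index, dataloading_host_count):
  # Each host reads a disjoint shard; index/count are supplied by RemoteIterator.
  ds = tf.data.Dataset.range(1024)
  return ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)

def preprocessing_fn(dataset):
  # Runs on the colocated CPU workers before the transfer to TPU.
  return dataset.batch(8, drop_remainder=True)  # per-host batch of 8 (illustrative)

global_mesh = Mesh(jax.devices(), axis_names=("data",))
global_shape = (8 * jax.process_count(),)  # leading dim must split evenly across devices

iterator = RemoteIteratorWrapper(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
batch = next(iterator)  # a global jax.Array placed according to tpu_sharding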