@@ -397,18 +397,48 @@ def _get_pipeline_fn(config):
397397 return pretrain_preprocessing_pipeline
398398
399399
400- def _make_elastic_iterator (dataset , config , preprocessing_fn , shard_index = None , shard_count = None , mp_opts = None ):
400+ def _make_elastic_iterator (
401+ dataset ,
402+ config ,
403+ preprocessing_fn ,
404+ shard_index = None ,
405+ shard_count = None ,
406+ process_indices = None ,
407+ mp_opts = None ,
408+ ):
401409 """Applies preprocessing_fn then wraps the result with ElasticIterator.
402410
403- When shard_index/shard_count are None, defaults to jax.process_index()/jax.process_count().
411+ When shard_index/shard_count are None, they are derived from
412+ process_indices if provided (this process's position in it and its
413+ length); otherwise from jax.process_index()/jax.process_count().
414+
415+ Args:
416+ dataset: The input dataset.
417+ config: The hyperparameter configuration.
418+ preprocessing_fn: The function to apply before wrapping.
419+ shard_index: The shard index. Defaults to None.
420+ shard_count: The shard count. Defaults to None.
421+ process_indices: The active process indices. Defaults to None.
422+ mp_opts: Multiprocessing options. Defaults to None.
404423 """
405424 ds = preprocessing_fn (dataset = dataset )
425+ if shard_index is None :
426+ if process_indices is not None :
427+ shard_index = process_indices .index (jax .process_index ())
428+ else :
429+ shard_index = jax .process_index ()
430+ if shard_count is None :
431+ if process_indices is not None :
432+ shard_count = len (process_indices )
433+ else :
434+ shard_count = jax .process_count ()
435+
406436 return ElasticIterator (
407437 ds ,
408438 global_batch_size = config .global_batch_size_to_load ,
409439 shard_options = grain .ShardOptions (
410- shard_index = shard_index if shard_index is not None else jax . process_index () ,
411- shard_count = shard_count if shard_count is not None else jax . process_count () ,
440+ shard_index = shard_index ,
441+ shard_count = shard_count ,
412442 ),
413443 read_options = grain .ReadOptions (
414444 num_threads = config .grain_num_threads ,
@@ -463,7 +493,12 @@ def make_grain_train_iterator(
463493 # pass to MultiHostDataLoadIterator
464494 if config .colocated_python_data_input :
465495 if config .grain_use_elastic_iterator :
466- preprocessing_fn = functools .partial (_make_elastic_iterator , config = config , preprocessing_fn = preprocessing_fn )
496+ preprocessing_fn = functools .partial (
497+ _make_elastic_iterator ,
498+ config = config ,
499+ preprocessing_fn = preprocessing_fn ,
500+ process_indices = process_indices ,
501+ )
467502
468503 global_shape = (config .global_batch_size_to_load , config .max_target_length )
469504 return multihost_dataloading .RemoteIteratorWrapper (
0 commit comments