@@ -397,18 +397,37 @@ def _get_pipeline_fn(config):
397397 return pretrain_preprocessing_pipeline
398398
399399
400- def _make_elastic_iterator (dataset , config , preprocessing_fn , shard_index = None , shard_count = None , mp_opts = None ):
400+ def _make_elastic_iterator (
401+ dataset ,
402+ config ,
403+ preprocessing_fn ,
404+ shard_index = None ,
405+ shard_count = None ,
406+ process_indices = None ,
407+ mp_opts = None ,
408+ ):
401409 """Applies preprocessing_fn then wraps the result with ElasticIterator.
402410
403411 When shard_index/shard_count are None, defaults to jax.process_index()/jax.process_count().
404412 """
405413 ds = preprocessing_fn (dataset = dataset )
414+ if shard_index is None :
415+ if process_indices is not None :
416+ shard_index = process_indices .index (jax .process_index ())
417+ else :
418+ shard_index = jax .process_index ()
419+ if shard_count is None :
420+ if process_indices is not None :
421+ shard_count = len (process_indices )
422+ else :
423+ shard_count = jax .process_count ()
424+
406425 return ElasticIterator (
407426 ds ,
408427 global_batch_size = config .global_batch_size_to_load ,
409428 shard_options = grain .ShardOptions (
410- shard_index = shard_index if shard_index is not None else jax . process_index () ,
411- shard_count = shard_count if shard_count is not None else jax . process_count () ,
429+ shard_index = shard_index ,
430+ shard_count = shard_count ,
412431 ),
413432 read_options = grain .ReadOptions (
414433 num_threads = config .grain_num_threads ,
@@ -463,7 +482,12 @@ def make_grain_train_iterator(
463482 # pass to MultiHostDataLoadIterator
464483 if config .colocated_python_data_input :
465484 if config .grain_use_elastic_iterator :
466- preprocessing_fn = functools .partial (_make_elastic_iterator , config = config , preprocessing_fn = preprocessing_fn )
485+ preprocessing_fn = functools .partial (
486+ _make_elastic_iterator ,
487+ config = config ,
488+ preprocessing_fn = preprocessing_fn ,
489+ process_indices = process_indices ,
490+ )
467491
468492 global_shape = (config .global_batch_size_to_load , config .max_target_length )
469493 return multihost_dataloading .RemoteIteratorWrapper (
0 commit comments