Commit 3616eb3

migrate to colocated python class
1 parent 2b9cebb commit 3616eb3

4 files changed: 109 additions & 47 deletions
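
This commit moves the colocated-Python input pipeline from a function-style pattern, in which the remote iterator was stashed as module-level state on colocated_python, to a class decorated with colocated_python.colocated_python_class, whose instances hold that state on the colocated CPU hosts. A minimal sketch of the two styles, assuming colocated_python is jax.experimental.colocated_python (Counter and its contents are illustrative, not from this commit):

from jax.experimental import colocated_python

# Old style (removed below): a decorated function keeps remote state as
# module attributes of colocated_python.
@colocated_python.colocated_python
def init(dummy_array):
  colocated_python.iterator = iter(range(10))  # remote module-level state
  return dummy_array

# New style (added below): a decorated class owns its state per instance;
# __init__ and method calls execute on the colocated host.
@colocated_python.colocated_python_class
class Counter:
  def __init__(self):
    self.step = 0  # lives on the colocated host, not on the controller

  def get_next(self, dummy_array):
    # the sharding of dummy_array tells colocated python where to run
    self.step += 1
    return dummy_array + self.step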


src/maxtext/common/checkpointing.py

Lines changed: 5 additions & 5 deletions
@@ -24,7 +24,7 @@
 import jax
 from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
 from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator
-from maxtext.input_pipeline.multihost_dataloading import RemoteIterator
+from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper
 from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
 from maxtext.utils import exceptions
 from maxtext.utils import max_logging
@@ -132,11 +132,11 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
 
 
 def _is_remote_iterator(data_iterator):
-  """Check if data_iterator is a RemoteIterator or contains RemoteIterator instances."""
-  if isinstance(data_iterator, RemoteIterator):
+  """Check if data_iterator is a RemoteIteratorWrapper or contains RemoteIteratorWrapper instances."""
+  if isinstance(data_iterator, RemoteIteratorWrapper):
     return True
   if isinstance(data_iterator, list):
-    return any(isinstance(item, RemoteIterator) for item in data_iterator)
+    return any(isinstance(item, RemoteIteratorWrapper) for item in data_iterator)
   return False
 
 
@@ -619,7 +619,7 @@ def map_to_pspec(data):
         None,
     )
     # Case 2: Matches if dataset type is "grain" and the data iterator is not a
-    # PlaceHolderDataIterator or RemoteIterator and a specific checkpoint file exists for the iterator
+    # PlaceHolderDataIterator or RemoteIteratorWrapper and a specific checkpoint file exists for the iterator
     case (
         checkpoint_manager,
         dataset_type,
src/maxtext/input_pipeline/grain_data_processing.py

Lines changed: 2 additions & 2 deletions
@@ -454,7 +454,7 @@ def make_grain_train_iterator(
   )
   if config.colocated_python_data_input:
     global_shape = (config.global_batch_size_to_load, config.max_target_length)
-    return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
+    return multihost_dataloading.RemoteIteratorWrapper(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
   else:
     # config.expansion_factor_real_data is between 0 and 1
     num_dataloader_to_restore = int(1 / config.expansion_factor_real_data)
@@ -567,4 +567,4 @@ def make_grain_eval_iterator(
       grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
   )
   global_shape = (config.global_batch_size_to_load, config.max_target_length)
-  return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
+  return multihost_dataloading.RemoteIteratorWrapper(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
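
For the non-colocated else branch above, the context lines carry the restore-count arithmetic; a worked example with an assumed config value:

# expansion_factor_real_data = 0.25 is an assumed example value: one host in
# four loads real data, so four dataloader states must be restored.
expansion_factor_real_data = 0.25
num_dataloader_to_restore = int(1 / expansion_factor_real_data)
assert num_dataloader_to_restore == 4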

src/maxtext/input_pipeline/multihost_dataloading.py

Lines changed: 100 additions & 38 deletions
@@ -205,9 +205,53 @@ def _colocated_cpu_mesh(mesh: Mesh) -> Mesh:
   return colocated_python.colocated_cpu_devices(mesh)
 
 
+@colocated_python.colocated_python_class
 class RemoteIterator:
-  "iterator class for using colocated python, iterator is initiated remotely and stored in the state of colocated python"
+  "iterator class for using colocated python class"
+
+  def __init__(self, get_ds_fn, preprocessing_fn, global_shape):
+    # self.cpu_devices = _colocated_cpu_devices(jax.local_devices())
+    # self.tpu_devices = jax.local_devices()
+    # self.cpu_mesh = _colocated_cpu_mesh(global_mesh)
+    # self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
+    # self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names))
+    self.global_shape = global_shape
+    ds = get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count())
+    dataloader = preprocessing_fn(dataset=ds)
+    self.iterator = None
+    if isinstance(dataloader, tf.data.Dataset):
+      self.iterator = dataloader.as_numpy_iterator()
+    elif isinstance(dataloader, Iterable):
+      self.iterator = iter(dataloader)
+    else:
+      raise ValueError("Type error: dataloader should be Iterable.")
 
+  def get_next(self, dummy_array):
+    local_data = next(self.iterator)
+    def form_global_array_colocated_python(path, array, devices, global_shape, sharding):
+      try:
+        device_arrays = np.split(array, len(devices), axis=0)
+      except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to put to devices shape {array.shape} with "
+            f"local device count {len(devices)} "
+            f"at {jtu.keystr(path)}"
+        ) from array_split_error
+      device_arrays = jax.device_put(device_arrays, devices)
+      return jax.make_array_from_single_device_arrays(shape=global_shape, sharding=sharding, arrays=device_arrays)
+
+    return jtu.tree_map_with_path(
+        partial(
+            form_global_array_colocated_python,
+            devices=list(dummy_array.sharding.addressable_devices),
+            global_shape=self.global_shape,
+            sharding=dummy_array.sharding,
+        ),
+        local_data,
+    )
+
+
+class RemoteIteratorWrapper:
   def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape):
     self.cpu_devices = _colocated_cpu_devices(jax.local_devices())
     self.tpu_devices = jax.local_devices()
@@ -216,42 +260,60 @@ def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape):
     self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names))
     self.dummy_array = jnp.zeros((len(self.cpu_devices)))
     self.dummy_array = jax.device_put(self.dummy_array, self.cpu_sharding)
-
-    @colocated_python.colocated_python
-    def init(dummy_array):
-      colocated_python.global_shape = global_shape
-      ds = get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count())
-      dataloader = preprocessing_fn(dataset=ds)
-      if isinstance(dataloader, tf.data.Dataset):
-        colocated_python.iterator = dataloader.as_numpy_iterator()
-      elif isinstance(dataloader, Iterable):
-        colocated_python.iterator = iter(dataloader)
-      else:
-        raise ValueError("Type error: dataloader should be either tf.data.Dataset or grain.DataLoader.")
-      return dummy_array
-
-    max_logging.log("Initiating RemoteIterator")
-    try:
-      out = jax.device_get(init(self.dummy_array))
-      if out is not None:
-        max_logging.log(f"RemoteIterator initiated. Test output: {out}")
-    except Exception as e:
-      max_logging.log(f"RemoteIterator init FAILED with error: {type(e).__name__}: {e}")
-      raise
-
-  def __iter__(self):
-    return self
+    self.remote_iterator = RemoteIterator(get_ds_fn, preprocessing_fn, global_shape)
 
   def __next__(self):
-    out = _get_next(self.dummy_array)
-
-    def put_to_tpu_devices(path, array, sharding):
-      try:
-        return jax.device_put(array, sharding)
-      except Exception as e:  # pylint: disable=broad-exception-caught
-        max_logging.log(f"Error putting data to TPU device path{path}, exception={e}")
-        raise
-
-    input_gdas = jtu.tree_map_with_path(partial(put_to_tpu_devices, sharding=self.tpu_sharding), out)
-
-    return input_gdas
+    out = self.remote_iterator.get_next(self.dummy_array)
+    # use tree_map if out is a dict
+    return jax.device_put(out, self.tpu_sharding)
+
+# class RemoteIterator:
+#   "iterator class for using colocated python, iterator is initiated remotely and stored in the state of colocated python"
+
+#   def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape):
+#     self.cpu_devices = _colocated_cpu_devices(jax.local_devices())
+#     self.tpu_devices = jax.local_devices()
+#     self.cpu_mesh = _colocated_cpu_mesh(global_mesh)
+#     self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
+#     self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names))
+#     self.dummy_array = jnp.zeros((len(self.cpu_devices)))
+#     self.dummy_array = jax.device_put(self.dummy_array, self.cpu_sharding)
+
+#     @colocated_python.colocated_python
+#     def init(dummy_array):
+#       colocated_python.global_shape = global_shape
+#       ds = get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count())
+#       dataloader = preprocessing_fn(dataset=ds)
+#       if isinstance(dataloader, tf.data.Dataset):
+#         colocated_python.iterator = dataloader.as_numpy_iterator()
+#       elif isinstance(dataloader, Iterable):
+#         colocated_python.iterator = iter(dataloader)
+#       else:
+#         raise ValueError("Type error: dataloader should be either tf.data.Dataset or grain.DataLoader.")
+#       return dummy_array
+
+#     max_logging.log("Initiating RemoteIterator")
+#     try:
+#       out = jax.device_get(init(self.dummy_array))
+#       if out is not None:
+#         max_logging.log(f"RemoteIterator initiated. Test output: {out}")
+#     except Exception as e:
+#       max_logging.log(f"RemoteIterator init FAILED with error: {type(e).__name__}: {e}")
+#       raise
+
+#   def __iter__(self):
+#     return self
+
+#   def __next__(self):
+#     out = _get_next(self.dummy_array)
+
+#     def put_to_tpu_devices(path, array, sharding):
+#       try:
+#         return jax.device_put(array, sharding)
+#       except Exception as e:  # pylint: disable=broad-exception-caught
+#         max_logging.log(f"Error putting data to TPU device path{path}, exception={e}")
+#         raise
+
+#     input_gdas = jtu.tree_map_with_path(partial(put_to_tpu_devices, sharding=self.tpu_sharding), out)
+
+#     return input_gdas
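
The heart of get_next is the per-leaf assembly in form_global_array_colocated_python: split the host batch across the devices named by dummy_array's sharding, then stitch the shards into one global jax.Array. A standalone sketch of that assembly, runnable outside colocated python on whatever local devices are present (mesh and batch values are made up; the batch dimension must divide the device count):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.local_devices()
mesh = Mesh(np.array(devices), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

batch = np.arange(8 * 4, dtype=np.int32).reshape(8, 4)  # host-side batch
shards = np.split(batch, len(devices), axis=0)  # one shard per device
arrays = jax.device_put(shards, devices)  # list-to-list device_put, as in the diff
global_batch = jax.make_array_from_single_device_arrays(
    shape=batch.shape, sharding=sharding, arrays=arrays
)
assert global_batch.shape == (8, 4)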

src/maxtext/input_pipeline/tfds_data_processing.py

Lines changed: 2 additions & 2 deletions
@@ -232,7 +232,7 @@ def make_tfds_train_iterator(
       hf_access_token=config.hf_access_token,
   )
   global_shape = (config.global_batch_size_to_load, config.max_target_length)
-  return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
+  return multihost_dataloading.RemoteIteratorWrapper(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
 
 
 def make_tfds_eval_iterator(
@@ -296,4 +296,4 @@ def make_tfds_eval_iterator(
       use_dpo=config.use_dpo,
       hf_access_token=config.hf_access_token,
   )
-  return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, config, global_mesh)
+  return multihost_dataloading.RemoteIteratorWrapper(get_ds_fn, preprocessing_fn, config, global_mesh)
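
For context, these factories now return a RemoteIteratorWrapper, which implements __next__ (the old __iter__ was dropped in the multihost_dataloading.py diff), so batches are pulled with next(). A hypothetical loop, with config, global_mesh, state, and train_step as stand-ins rather than names from this commit:

train_iter = make_tfds_train_iterator(config, global_mesh)
for _ in range(config.steps):
  batch = next(train_iter)  # RemoteIteratorWrapper.__next__: fetch remotely, device_put to TPU sharding
  state = train_step(state, batch)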
