
Commit f2216e2

Merge pull request #3437 from AI-Hypercomputer:ajkv/sft-grain-implementation
PiperOrigin-RevId: 892619290
2 parents baec42f + c7bd12f commit f2216e2

4 files changed: 349 additions & 122 deletions


Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for data processing pipelines."""

import functools

import jax
from grain.experimental import BestFitPackIterDataset, pick_performance_config
import grain.python as grain

from maxtext.input_pipeline import input_pipeline_utils
from maxtext.input_pipeline import tokenizer


def parse_and_keep_features(dataset, config, data_columns, tokenize):
  """Parse arrayrecord features or keep specified columns for other formats."""
  if config.grain_file_type in ("arrayrecord", "tfrecord"):
    dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
  else:
    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
  return dataset


def get_tokenizer_and_pad_id(config):
  """Builds tokenizer and extracts pad_id safely."""
  tokenizer_model = tokenizer.build_tokenizer(
      config.tokenizer_path,
      config.tokenizer_type,
      config.add_bos,
      config.add_eos,
      config.hf_access_token,
  )
  if tokenizer_model.pad_id is not None:
    pad_id = tokenizer_model.pad_id
  elif tokenizer_model.unk_id is not None:
    pad_id = tokenizer_model.unk_id
  else:
    pad_id = 0
  return tokenizer_model, pad_id


def validate_and_configure_sft_columns(data_columns, tokenizer_model, chat_template=None):
  """Validates SFT data columns and configures the tokenizer chat template."""
  if chat_template and hasattr(tokenizer_model, "chat_template"):
    tokenizer_model.chat_template = chat_template

  supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]
  assert any(
      set(data_columns) == set(supported) for supported in supported_columns
  ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"

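# Illustrative note, not part of the original commit: the check above is
# order-insensitive, so data_columns=["completion", "prompt"] passes, while a
# column set such as ["text"] trips the assertion.
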
def get_local_batch_size(config):
  """Computes local batch size based on process count and expansion factor."""
  batch_size = config.global_batch_size_to_load // jax.process_count()
  if config.expansion_factor_real_data > 1:
    # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
    # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
    # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
    batch_size = int(batch_size // config.expansion_factor_real_data)
  return batch_size

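# Illustrative note, not part of the original commit: with hypothetical values
# global_batch_size_to_load=64, jax.process_count()=4, and
# expansion_factor_real_data=2, each process first gets 64 // 4 = 16 examples,
# and the expansion is then reverted to 16 // 2 = 8, matching the unexpanded
# per-host batch size recorded in the checkpoint.
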
def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model):
  """Packs or pads the dataset according to config and batches it."""
  if config.packing:
    length_struct = {col: config.max_target_length for col in data_columns}
    max_segments = config.max_segments_per_seq
    if max_segments is not None and max_segments <= 0:
      max_segments = None
    if config.grain_packing_type == "first_fit":
      dataset = grain.experimental.FirstFitPackIterDataset(
          dataset,
          length_struct=length_struct,
          num_packing_bins=batch_size,
          max_sequences_per_bin=max_segments,
      )
    elif config.grain_packing_type == "best_fit":
      dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
    elif config.grain_packing_type == "concat_then_split":
      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
        dataset = grain.experimental.ConcatThenSplitIterDataset(
            dataset,
            length_struct=length_struct,
            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
            bos_token_id=tokenizer_model.bos_id,
        )
      else:
        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
    else:
      raise ValueError(f"Unknown packing type: {config.grain_packing_type}")

    rekey_dict = {
        "targets_segmentation": "targets_segment_ids",
        "inputs_segmentation": "inputs_segment_ids",
        "targets_position": "targets_positions",
        "inputs_position": "inputs_positions",
    }
    dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
  else:
    dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))

  batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
  dataset = dataset.batch(batch_size, batch_fn=batch_fn)
  return dataset

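# Illustrative note, not part of the original commit: on the packed path the
# Grain packers are assumed to emit per-feature *_segment_ids and *_positions
# entries, which the Rekey step above is assumed to expose under the
# *_segmentation and *_position names MaxText expects; batch_and_pad is assumed
# to pad a final incomplete batch up to batch_size with pad_id rather than
# dropping it.
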
def shift_dataset(dataset, pad_id):
  """Shift tokens to create inputs and targets for standard next-token prediction."""
  return dataset.map(
      input_pipeline_utils.ShiftData(
          ignored_ids=[pad_id],
          axis=1,
      )
  )

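# Illustrative note, not part of the original commit: assuming the usual
# next-token convention, shifting offsets inputs and targets by one position
# along axis=1 (the sequence axis), so the model at position t is trained to
# predict the token at position t + 1, with ids listed in ignored_ids (here
# pad_id) left out of that bookkeeping.
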
def apply_multiprocessing_and_prefetch(dataset, config, grain_worker_count, grain_per_worker_buffer_size):
  """Applies multiprocessing and prefetching configurations to the dataset."""
  multiprocessing_options = (
      pick_performance_config(
          ds=dataset,
          ram_budget_mb=config.grain_ram_budget_mb,
          max_workers=None,
          max_buffer_size=None,
      ).multiprocessing_options
      if grain_worker_count == -1
      else grain.MultiprocessingOptions(
          num_workers=grain_worker_count,
          per_worker_buffer_size=grain_per_worker_buffer_size,
      )
  )
  return dataset.mp_prefetch(multiprocessing_options)
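For orientation, here is a minimal sketch, not part of the commit, of how the helpers above could be chained into an SFT Grain input pipeline. It assumes `source_dataset` is a Grain IterDataset of raw examples, that `config` exposes `grain_worker_count` and `grain_per_worker_buffer_size` as in MaxText's Grain configuration, and it omits the tokenization and prompt-formatting steps the real pipeline applies between parsing and packing.

def sketch_sft_grain_pipeline(source_dataset, config, data_columns):
  """Hypothetical wiring of the utilities above, in the order their docstrings imply."""
  tokenizer_model, pad_id = get_tokenizer_and_pad_id(config)
  validate_and_configure_sft_columns(data_columns, tokenizer_model)
  dataset = parse_and_keep_features(source_dataset, config, data_columns, tokenize=True)
  batch_size = get_local_batch_size(config)
  dataset = format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model)
  dataset = shift_dataset(dataset, pad_id)
  return apply_multiprocessing_and_prefetch(
      dataset, config, config.grain_worker_count, config.grain_per_worker_buffer_size
  )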
