forked from foundation-model-stack/fms-hf-tuning
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathoffline_data_processing.py
More file actions
235 lines (204 loc) · 7.91 KB
/
offline_data_processing.py
File metadata and controls
235 lines (204 loc) · 7.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Standard
import logging
import os
import sys
import traceback
# Third Party
from transformers import AutoTokenizer
# Local
from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs
from tuning.sft_trainer import get_parser
from tuning.utils.error_logging import USER_ERROR_EXIT_CODE, write_termination_log
from tuning.utils.logging import set_log_level
from tuning.utils.tokenizer_data_utils import get_special_tokens_dict
def save_dataset_shards(
    dataset, output_dir: str, num_shards: int, dataset_name: str
) -> None:
    """
    Saves the given dataset in the specified number of shards.

    Each shard is written as ``ds_<index>.parquet`` (index zero-padded to five
    digits) inside ``output_dir``, which is created if it does not exist.

    Args:
        dataset: The dataset to shard and save. It must provide
            ``shard(index=..., num_shards=...)``, and each returned shard must
            provide ``to_parquet(path)``.
        output_dir (str): Directory to save the dataset shards.
        num_shards (int): Number of shards to create; must be at least 1.
        dataset_name (str): Name of the dataset (used for logging and errors).

    Raises:
        ValueError: If ``num_shards`` is less than 1.
    """
    # Reject non-positive counts up front: range(num_shards) would otherwise be
    # empty, writing nothing while still logging a successful dump.
    if num_shards < 1:
        raise ValueError(
            f"num_shards must be a positive integer, got {num_shards} "
            f"for {dataset_name}."
        )
    os.makedirs(output_dir, exist_ok=True)
    for shard_idx in range(num_shards):
        shard = dataset.shard(index=shard_idx, num_shards=num_shards)
        shard_path = os.path.join(output_dir, f"ds_{shard_idx:05d}.parquet")
        shard.to_parquet(shard_path)
    logging.info("Dumped %d shards of %s at %s", num_shards, dataset_name, output_dir)
def get_processed_dataset(
    model_args: configs.ModelArguments,
    data_args: configs.DataArguments,
    train_args: configs.TrainingArguments,
):
    """
    Builds the tokenizer and formats the datasets described by the data config.

    Steps:
        1. Configure logging via ``set_log_level``.
        2. Load a fast tokenizer from ``tokenizer_name_or_path``, falling back
           to ``model_name_or_path`` when no tokenizer path is given.
        3. If ``data_args.chat_template`` is set, convert its literal
           backslash-n sequences into real newlines (written back onto
           ``data_args``) and install it on the tokenizer, overriding any
           template the tokenizer already had.
        4. Add special tokens from ``get_special_tokens_dict`` plus any
           user-supplied ``data_args.add_special_tokens``.
        5. Run ``process_dataargs`` to format the datasets.

    Args:
        model_args (configs.ModelArguments): Model configuration arguments.
        data_args (configs.DataArguments): Data configuration arguments.
        train_args (configs.TrainingArguments): Training configuration arguments.

    Returns:
        tuple: ``(train_dataset, validation_dataset)``, i.e. the first two of
        the six values produced by ``process_dataargs``.
    """
    train_args, logger = set_log_level(train_args, "get_processed_dataset")
    logger.info(
        "Starting dataset processing with model_args: %s, data_args: %s, training_args: %s",
        model_args,
        data_args,
        train_args,
    )

    # An explicit tokenizer path wins; otherwise reuse the model path.
    source_path = model_args.tokenizer_name_or_path or model_args.model_name_or_path
    logger.debug("Loading tokenizer from %s", source_path)
    tokenizer = AutoTokenizer.from_pretrained(
        source_path, cache_dir=train_args.cache_dir, use_fast=True, legacy=True
    )
    logger.debug("Tokenizer loaded successfully.")

    user_template = data_args.chat_template
    if user_template:
        # Templates passed on the command line carry escaped newlines.
        user_template = user_template.replace(r"\n", "\n")
        data_args.chat_template = user_template
        logger.info("Adding chat_template to the tokenizer")
        previous_template = tokenizer.chat_template
        if previous_template:
            logger.warning(
                "replacing existing chat_template %s with the given chat_template %s",
                previous_template,
                user_template,
            )
        tokenizer.chat_template = user_template

    tokens_to_add = get_special_tokens_dict(
        tokenizer_name_or_path=model_args.tokenizer_name_or_path, tokenizer=tokenizer
    )
    user_tokens = data_args.add_special_tokens
    if user_tokens:
        logger.info("Adding user-defined special tokens: %s ", user_tokens)
        tokens_to_add["additional_special_tokens"] = user_tokens
    if tokens_to_add:
        logger.info("Adding special tokens: %s", tokens_to_add)
        # Keep any additional special tokens the tokenizer already defines.
        tokenizer.add_special_tokens(
            special_tokens_dict=tokens_to_add,
            replace_additional_special_tokens=False,
        )

    logger.info("Calling process_dataargs to format datasets.")
    train_split, validation_split, _, _, _, _ = process_dataargs(
        data_args, tokenizer, train_args
    )
    logger.info("Dataset processing completed successfully.")
    return train_split, validation_split
def _dump_split(dataset, split_dir: str, num_shards: int, label: str, logger) -> None:
    """
    Saves one formatted dataset split as parquet shards, or warns and skips it.

    Args:
        dataset: The formatted split returned by ``get_processed_dataset``;
            may be None, in which case nothing is written.
        split_dir (str): Directory the shards are written to.
        num_shards (int): Number of shards to write.
        label (str): Lowercase split name used in log messages, e.g. "train".
            The dataset name passed to ``save_dataset_shards`` is
            ``f"{label}_dataset"``.
        logger: Logger used for progress and warning messages.
    """
    logger.info(
        "Trying to dump %d shards of %s dataset at %s", num_shards, label, split_dir
    )
    if dataset is None:
        logger.warning(
            "%s dataset is None. Not saving %s dataset.", label.capitalize(), label
        )
        return
    save_dataset_shards(dataset, split_dir, num_shards, f"{label}_dataset")


def main():
    """
    Main function that parses arguments, processes datasets, and saves the output.

    The command line accepts the model, data and training argument dataclasses
    from ``get_parser()`` plus ``--num_dataset_shards`` (default 1, must be at
    least 1). The train and validation splits are written as parquet shards
    under ``<output_dir>/train_dataset`` and ``<output_dir>/validation_dataset``.

    Any failure while parsing, processing or saving is logged, recorded via
    ``write_termination_log`` and ends the process with ``USER_ERROR_EXIT_CODE``.
    """
    logger = logging.getLogger()
    logger.info("Starting Data Processing script execution.")
    parser = get_parser()
    parser.add_argument(
        "--num_dataset_shards",
        type=int,
        default=1,
        help="Number of shards to be used for saving the dataset.",
    )

    # Phase 1: parse and validate arguments.
    try:
        parsed_output = parser.parse_args_into_dataclasses()
        # Extract arguments based on type
        arg_types = {
            configs.ModelArguments: "model_args",
            configs.DataArguments: "data_args",
            configs.TrainingArguments: "training_args",
        }
        args = {key: None for key in arg_types.values()}
        for item in parsed_output:
            for arg_class, key in arg_types.items():
                if isinstance(item, arg_class):
                    args[key] = item
        # The extra --num_dataset_shards flag is not part of any dataclass, so
        # it arrives on a trailing namespace object; fall back to 1 if absent.
        num_dataset_shards = next(
            (
                item.num_dataset_shards
                for item in parsed_output
                if hasattr(item, "num_dataset_shards")
            ),
            1,
        )
        missing = [key for key, value in args.items() if value is None]
        if missing:
            raise ValueError(
                f"One of the arguments is None ({', '.join(missing)}). "
                "Please check the arguments passed."
            )
        # A non-positive shard count would write no data while reporting success.
        if num_dataset_shards < 1:
            raise ValueError(
                "--num_dataset_shards must be a positive integer, "
                f"got {num_dataset_shards}."
            )
        logger.debug(
            "Input args parsed:\n model_args: %s\n data_args: %s\n training_args: %s\n Shards: %d",
            args["model_args"],
            args["data_args"],
            args["training_args"],
            num_dataset_shards,
        )
        args["training_args"], logger = set_log_level(args["training_args"], __name__)
    except Exception as e:  # pylint: disable=broad-exception-caught
        logger.error("Error parsing arguments: %s", traceback.format_exc())
        write_termination_log(f"Exception raised during argument parsing: {e}")
        sys.exit(USER_ERROR_EXIT_CODE)

    # Phase 2: build the tokenizer and format the datasets.
    try:
        logger.info("Processing dataset.")
        formatted_train_dataset, formatted_validation_dataset = get_processed_dataset(
            model_args=args["model_args"],
            data_args=args["data_args"],
            train_args=args["training_args"],
        )
    except Exception as e:  # pylint: disable=broad-exception-caught
        logger.error("Error processing dataset: %s", traceback.format_exc())
        write_termination_log(f"Exception raised during dataset processing: {e}")
        sys.exit(USER_ERROR_EXIT_CODE)

    # Phase 3: write both splits. This phase is guarded like the others so
    # that I/O failures also produce a termination log and a clean exit code.
    output_dir = args["training_args"].output_dir
    try:
        _dump_split(
            formatted_train_dataset,
            os.path.join(output_dir, "train_dataset"),
            num_dataset_shards,
            "train",
            logger,
        )
        _dump_split(
            formatted_validation_dataset,
            os.path.join(output_dir, "validation_dataset"),
            num_dataset_shards,
            "validation",
            logger,
        )
    except Exception as e:  # pylint: disable=broad-exception-caught
        logger.error("Error saving dataset: %s", traceback.format_exc())
        write_termination_log(f"Exception raised during dataset saving: {e}")
        sys.exit(USER_ERROR_EXIT_CODE)

    logger.info(
        "Data Processing script execution completed. Data saved in %s directory",
        output_dir,
    )
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, process the datasets and write
    # them out as parquet shards (see main()).
    main()