
Commit 1e78997

PR Change 1
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
1 parent 21de4a0 commit 1e78997

5 files changed: 29 additions & 35 deletions


tuning/config/configs.py

Lines changed: 0 additions & 6 deletions
@@ -145,12 +145,6 @@ class DataArguments:
             Add special tokens as new tokens and increase vocabulary and model embedding size."
         },
     )
-    use_streaming_dataset: bool = field(
-        default=False,
-        metadata={
-            "help": "Use of Streaming with Iterable dataset to be enabled, default is False"
-        },
-    )
 
     def __post_init__(self):
         def unescape(s):
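
Note: with `use_streaming_dataset` gone from `DataArguments`, streaming is driven by the data config instead; `_process_dataconfig_file` below reads `data_processor.processor_config.streaming`. A minimal sketch of where the switch now lives; the class body here is a hypothetical reduction of the real `DataPreProcessorConfig` that ships with this repo:

# Hypothetical reduction of DataPreProcessorConfig, only to show where the
# streaming switch now lives; the real dataclass is read as
# `processor_config.streaming` in setup_dataprocessor.py below.
from dataclasses import dataclass
from typing import Optional

@dataclass
class DataPreProcessorConfig:
    streaming: bool = False
    chat_template: Optional[str] = None

# Streaming is now enabled per data config, not via a CLI flag:
config = DataPreProcessorConfig(streaming=True)
assert config.streaming  # replaces the removed data_args.use_streaming_dataset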

tuning/data/data_handlers.py

Lines changed: 1 addition & 1 deletion
@@ -627,7 +627,7 @@ def tokenize_and_apply_chat_template_with_masking(
     "prepare_multimodal_data_processor": DataHandler(
         op=prepare_multimodal_data_processor,
         handler_type=DataHandlerType.MAP,
-        allows_batching=True,
+        allows_batching=False,
     ),
     "tokenize": DataHandler(
         op=tokenize,
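
Note: flipping `allows_batching` to `False` means the multimodal handler now maps one example at a time. Assuming the MAP handler type forwards this flag to `datasets.Dataset.map(..., batched=...)`, the difference in what the handler function receives looks like this (standalone toy sketch):

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "bb"]})

# batched=False: the function sees a single example dict, e.g. {"text": "a"}
per_example = ds.map(lambda ex: {"n": len(ex["text"])}, batched=False)

# batched=True: the function sees columns as lists, e.g. {"text": ["a", "bb"]}
per_batch = ds.map(lambda batch: {"n": [len(t) for t in batch["text"]]}, batched=True)

print(per_example["n"], per_batch["n"])  # [1, 2] [1, 2]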

tuning/data/data_processors.py

Lines changed: 9 additions & 9 deletions
@@ -43,7 +43,7 @@
 class DataPreProcessor:
 
     tokenizer = None
-    image_processor = None
+    processor = None
     data_config: DataConfig = None
     processor_config: DataPreProcessorConfig = None
     registered_handlers: Dict[str, DataHandler] = None
@@ -52,10 +52,10 @@ def __init__(
         self,
         processor_config: DataPreProcessorConfig,
         tokenizer: AutoTokenizer,
-        image_processor: AutoProcessor = None,
+        processor: AutoProcessor = None,
     ):
         self.tokenizer = tokenizer
-        self.image_processor = image_processor
+        self.processor = processor
         self.processor_config = processor_config
 
         # Initialize other objects
@@ -376,7 +376,7 @@ def _process_dataset_configs(
             kwargs["fn_kwargs"] = {}
 
         kwargs["fn_kwargs"]["tokenizer"] = self.tokenizer
-        kwargs["fn_kwargs"]["processor"] = self.image_processor
+        kwargs["fn_kwargs"]["processor"] = self.processor
         kwargs["fn_kwargs"]["column_names"] = column_names
 
         kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs)
@@ -457,13 +457,13 @@ def process_dataset_configs(
 def get_datapreprocessor(
     processor_config: DataPreProcessorConfig,
     tokenizer: AutoTokenizer,
-    image_processor: AutoProcessor = None,
+    processor: AutoProcessor = None,
     additional_data_handlers: Dict[str, DataHandler] = None,
 ) -> DataPreProcessor:
-    processor = DataPreProcessor(
+    data_processor = DataPreProcessor(
         processor_config=processor_config,
         tokenizer=tokenizer,
-        image_processor=image_processor,
+        processor=processor,
     )
-    processor.register_data_handlers(additional_data_handlers)
-    return processor
+    data_processor.register_data_handlers(additional_data_handlers)
+    return data_processor
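
Note: two renames in one. The attribute `image_processor` becomes `processor` (it holds a full `AutoProcessor`, not just an image processor), and the local in `get_datapreprocessor` becomes `data_processor` so it no longer shadows the incoming processor. A hedged usage sketch with the new signature; the model id is a placeholder and the `DataPreProcessorConfig` import path is assumed from this repo's layout:

from transformers import AutoProcessor, AutoTokenizer

from tuning.data.data_config import DataPreProcessorConfig  # path assumed
from tuning.data.data_processors import get_datapreprocessor

tokenizer = AutoTokenizer.from_pretrained("model-id")   # placeholder id
processor = AutoProcessor.from_pretrained("model-id")   # multimodal processor

data_processor = get_datapreprocessor(
    processor_config=DataPreProcessorConfig(),
    tokenizer=tokenizer,
    processor=processor,  # keyword was image_processor= before this commit
)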

tuning/data/setup_dataprocessor.py

Lines changed: 15 additions & 18 deletions
@@ -49,10 +49,10 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]):
 
     if isinstance(data, str):
         # Create a data processor with default processor config
-        processor = get_datapreprocessor(
+        data_processor = get_datapreprocessor(
             processor_config=DataPreProcessorConfig(), tokenizer=None
         )
-        data = processor.load_dataset(
+        data = data_processor.load_dataset(
             None,
             streaming=False,
             splitName="train[:1]",
@@ -73,23 +73,23 @@ def _process_dataconfig_file(
     is_multipack: bool = False,
 ):
     data_config = load_and_validate_data_config(data_args.data_config_path)
-    processor = get_datapreprocessor(
+    data_processor = get_datapreprocessor(
         processor_config=data_config.dataprocessor,
         tokenizer=tokenizer,
-        image_processor=processor,
+        processor=processor,
         additional_data_handlers=additional_data_handlers,
     )
 
-    if processor.processor_config.chat_template is not None:
+    if data_processor.processor_config.chat_template is not None:
         if tokenizer.chat_template:
             logger.warning(
                 "replacing existing chat_template %s with data config's chat_template %s",
                 tokenizer.chat_template,
-                processor.processor_config.chat_template,
+                data_processor.processor_config.chat_template,
             )
-        tokenizer.chat_template = processor.processor_config.chat_template
+        tokenizer.chat_template = data_processor.processor_config.chat_template
 
-    if processor.processor_config.streaming:
+    if data_processor.processor_config.streaming:
         if train_args.max_steps < 1:
             logging.error(
                 "ValueError: `--max_steps` must be set when streaming is set in data \
@@ -108,7 +108,7 @@ def _process_dataconfig_file(
                 "Multipack is not compatible with streaming=true please set streaming=false "
                 "or disable multipack sampler"
             )
-    train_dataset = processor.process_dataset_configs(data_config.datasets)
+    train_dataset = data_processor.process_dataset_configs(data_config.datasets)
 
     return (train_dataset, None, data_args.dataset_text_field)
 
@@ -239,17 +239,16 @@ def _get_vision_dataset_handlers(data_args, processor_kwargs):
     handlers = []
 
     # First data handler configuration
-    fn_kwargs1 = {
+    handler_fn_kwargs1 = {
         "dataset_text_field": data_args.dataset_text_field,
         "conversation_column": data_args.dataset_text_field,
     }
-    kwargs1 = {
-        "fn_kwargs": fn_kwargs1,
-        "batched": False,
+    handler_kwargs1 = {
+        "fn_kwargs": handler_fn_kwargs1,
         "remove_columns": None,
     }
     handlers.append(
-        DataHandlerConfig("apply_tokenizer_chat_template", arguments=kwargs1)
+        DataHandlerConfig("apply_tokenizer_chat_template", arguments=handler_kwargs1)
    )
 
     # Second data handler configuration
@@ -262,8 +261,6 @@ def _get_vision_dataset_handlers(data_args, processor_kwargs):
     }
     kwargs2 = {
         "fn_kwargs": fn_kwargs2,
-        "batched": False,
-        "num_proc": None,
     }
     handlers.append(
         DataHandlerConfig("prepare_multimodal_data_processor", arguments=kwargs2)
@@ -297,11 +294,10 @@ def _process_raw_data_args(
 
     # Create a data processor with default processor config
     default_processor_config = DataPreProcessorConfig()
-    default_processor_config.streaming = data_args.use_streaming_dataset
     data_processor = get_datapreprocessor(
         processor_config=default_processor_config,
         tokenizer=tokenizer,
-        image_processor=processor,
+        processor=processor,
         additional_data_handlers=additional_data_handlers,
     )
     assert isinstance(
@@ -488,6 +484,7 @@ def process_dataargs(
     )
 
     dataset_kwargs = {}
+    # For vision model tuning prepare_dataset is skipped.
     if processor is not None:
         dataset_kwargs["skip_prepare_dataset"] = True
 
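
Note: when a vision `processor` is present, the handlers above already produce fully prepared examples, so TRL's own preparation pass is switched off. A hedged sketch of where `dataset_kwargs` ends up, assuming a TRL version whose `SFTConfig` accepts `dataset_kwargs`:

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="out",                               # placeholder
    dataset_kwargs={"skip_prepare_dataset": True},  # set when processor is not None
)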

tuning/sft_trainer.py

Lines changed: 4 additions & 1 deletion
@@ -238,6 +238,8 @@ def train(
             attn_implementation="flash_attention_2"
             if model_args.use_flash_attn
             else None,
+            # avoid warning that use_cache is incompatible with gradient checkpointing
+            use_cache=(not train_args.gradient_checkpointing),
         )
 
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
@@ -256,6 +258,8 @@ def train(
             attn_implementation="flash_attention_2"
             if model_args.use_flash_attn
             else None,
+            # avoid warning that use_cache is incompatible with gradient checkpointing
+            use_cache=(not train_args.gradient_checkpointing),
         )
 
         # TODO: Move these to a config as well
@@ -268,7 +272,6 @@ def train(
             cache_dir=train_args.cache_dir,
             use_fast=True,
             legacy=True,
-            use_cache=(not train_args.gradient_checkpointing),
         )
     except Exception as e:  # pylint: disable=broad-except
         logger.error(traceback.format_exc())
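
Note: the `use_cache` flag belongs on the model, not the tokenizer. The KV cache only helps generation and triggers transformers' "`use_cache=True` is incompatible with gradient checkpointing" warning during training, so this commit moves the flag into the two model `from_pretrained` calls and drops it from the tokenizer call, where it had no effect. A minimal sketch of the pattern, with a placeholder model id:

from transformers import AutoModelForCausalLM

gradient_checkpointing = True  # stands in for train_args.gradient_checkpointing
model = AutoModelForCausalLM.from_pretrained(
    "model-id",  # placeholder
    use_cache=not gradient_checkpointing,  # silences the incompatibility warning
)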
