
Commit 2949a3a

Merge pull request #610 from dushyantbehl/main
fix: typos
2 parents f41eb2c + b6aa877 commit 2949a3a

3 files changed

Lines changed: 32 additions & 36 deletions


build/nvcr.Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ RUN python -m pip install --upgrade pip
 RUN pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128
 
 # Install main package + flash attention
-RUN COPY . ${SOURCE_DIR}
+COPY . ${SOURCE_DIR}
 RUN cd ${SOURCE_DIR}
 RUN pip install --no-cache-dir ${SOURCE_DIR} && \
     pip install --no-cache-dir ${SOURCE_DIR}[flash-attn]
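Note: the removed line conflated two Dockerfile instructions. COPY is a build instruction in its own right, not a shell command, so RUN COPY . ${SOURCE_DIR} asks the build container to execute a program named COPY and fails the build. Dropping the RUN prefix lets the bare COPY instruction pull the source tree in from the build context as intended.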

tuning/data/setup_dataprocessor.py

Lines changed: 13 additions & 8 deletions
@@ -159,7 +159,9 @@ def process_dataconfig_file(
 
 
 # Data Format 1: Pretokenized Data
-def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
+def _get_pretokenized_dataset_handlers(
+    data_args: DataArguments, is_eval_present, is_eval_tokenized
+):
 
     # if the provided train dataset is pretokenized
     # however user provides formatting flags, error out
@@ -168,6 +170,7 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
         or data_args.data_formatter_template
         or data_args.dataset_text_field
         or data_args.instruction_template
+        or data_args.dataset_conversation_field
     ):
         raise ValueError(
             "fields response_template, data_formatter_template,"
@@ -177,7 +180,7 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
 
     # if the train dataset is pretokenized
     # ensure validation dataset is pretokenized otherwise error out
-    if is_eval_tokenized:
+    if is_eval_present and not is_eval_tokenized:
         raise ValueError(
             "validation data should be pretokenized to be used \
             along with pretokenized train data"
@@ -189,7 +192,9 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
 
 ### Data format 2
 # pylint: disable=unused-argument
-def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):
+def _get_dataset_formatting_handlers(
+    data_args: DataArguments, packing, is_padding_free=False
+):
 
     if data_args.response_template is None:
         if packing is False:
@@ -253,7 +258,7 @@ def _get_chat_dataset_handlers(data_args, tokenizer_kwargs):
     fn_kwargs["formatted_text_column_name"] = data_args.dataset_text_field
     fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs
     if data_args.dataset_conversation_field is not None:
-        fn_kwargs["conversation_column"] = data_args.dataset_conversation_field
+        fn_kwargs["conversation_column_name"] = data_args.dataset_conversation_field
 
     kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"}
 
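The rename from conversation_column to conversation_column_name matters because the kwargs dict above follows the signature of datasets.Dataset.map, which suggests fn_kwargs is forwarded to the handler as keyword arguments: each key has to match a parameter name exactly or the call fails. A minimal standalone sketch of that mechanism, with a hypothetical stand-in handler rather than this repo's real one:

from datasets import Dataset

# Hypothetical stand-in for the chat data handler; only the fn_kwargs
# wiring mirrors the patched code.
def render_conversation(example, conversation_column_name, formatted_text_column_name):
    turns = example[conversation_column_name]
    example[formatted_text_column_name] = "\n".join(
        "{}: {}".format(t["role"], t["content"]) for t in turns
    )
    return example

ds = Dataset.from_list(
    [{"messages": [{"role": "user", "content": "hi"},
                   {"role": "assistant", "content": "hello"}]}]
)
ds = ds.map(
    render_conversation,
    fn_kwargs={
        "conversation_column_name": "messages",  # must match the parameter name;
        "formatted_text_column_name": "text",    # a stale key raises TypeError
    },
    batched=False,
)
print(ds[0]["text"])  # "user: hi\nassistant: hello"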
@@ -284,14 +289,14 @@ def _get_default_dataset_handlers(data_args, tokenizer_kwargs):
 
 
 ### Vsion Data Format
-def _get_vision_dataset_handlers(data_args, processor_kwargs):
+def _get_vision_dataset_handlers(data_args: DataArguments, processor_kwargs):
 
     handlers = []
 
     # First data handler configuration
     handler_fn_kwargs1 = {
-        "dataset_text_field": data_args.dataset_text_field,
-        "conversation_column": data_args.dataset_text_field,
+        "formatted_text_column_name": data_args.dataset_text_field,
+        "conversation_column_name": data_args.dataset_conversation_field,
     }
     handler_kwargs1 = {
         "fn_kwargs": handler_fn_kwargs1,
@@ -403,7 +408,7 @@ def _process_raw_data_args(
     if is_traindata_tokenized:
         # Data Format 1: Pretokenized Data
         handlers, dataset_text_field = _get_pretokenized_dataset_handlers(
-            data_args, (is_eval_dataset_present and not is_evaldata_tokenized)
+            data_args, is_eval_dataset_present, is_evaldata_tokenized
         )
     elif processor and data_args.dataset_text_field and data_args.dataset_image_field:
 
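Two genuine fixes ride along with the renames above. In the vision handlers, the first handler's conversation column was previously wired to data_args.dataset_text_field, so the formatted-text column and the conversation column pointed at the same field; it now reads data_args.dataset_conversation_field as intended.

The pretokenized changes are two halves of one cleanup: the old call site pre-combined is_eval_dataset_present and not is_evaldata_tokenized and passed the result under the misleading parameter name is_eval_tokenized, so the behavior was right but the name said the opposite of what it held. The new signature takes the two flags separately and states the condition where the error is raised. A simplified sketch of the check in isolation (the function name here is illustrative, not the repo's):

def check_eval_matches_pretokenized_train(is_eval_present, is_eval_tokenized):
    # Train data is pretokenized: a validation set is optional, but any
    # validation set that is supplied must be pretokenized as well.
    if is_eval_present and not is_eval_tokenized:
        raise ValueError(
            "validation data should be pretokenized to be used "
            "along with pretokenized train data"
        )

check_eval_matches_pretokenized_train(False, False)  # no eval set: ok
check_eval_matches_pretokenized_train(True, True)    # pretokenized eval: ok
# check_eval_matches_pretokenized_train(True, False)  # raises ValueError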
tuning/sft_trainer.py

Lines changed: 18 additions & 27 deletions
@@ -255,6 +255,16 @@ def train(
 
     model_load_time = time.time()
     try:
+        model_kwargs = dict(  # pylint: disable=use-dict-literal
+            cache_dir=train_args.cache_dir,
+            torch_dtype=get_torch_dtype(model_args.torch_dtype),
+            attn_implementation=model_args.flash_attn_implementation
+            if model_args.use_flash_attn
+            else None,
+        )
+        if quantization_config is not None:
+            model_kwargs["quantization_config"] = quantization_config.to_hf_config()
+
         logger.info("Loading the model {} now".format(model_args.model_name_or_path))
         try:
             logger.info(
@@ -263,18 +273,8 @@ def train(
                 )
             )
             # try to load model as a vision model
-            model_loader = AutoModelForVision2Seq.from_pretrained
-
-            model = model_loader(
-                model_args.model_name_or_path,
-                cache_dir=train_args.cache_dir,
-                torch_dtype=get_torch_dtype(model_args.torch_dtype),
-                quantization_config=quantization_config.to_hf_config()
-                if quantization_config
-                else None,
-                attn_implementation=model_args.flash_attn_implementation
-                if model_args.use_flash_attn
-                else None,
+            model = AutoModelForVision2Seq.from_pretrained(
+                model_args.model_name_or_path, **model_kwargs
             )
             try:
                 if "use_cache" in model.language_model.config:
@@ -290,10 +290,10 @@ def train(
             logger.info("Loaded vision model as {} ".format(model))
             logger.info("Loaded vision model processor {} ".format(processor))
             logger.info("Loaded model tokenizer {} ".format(tokenizer))
-        except ValueError:
+        except Exception as e:  # pylint: disable=broad-except
             logger.info(
-                "Couldn't load model {} as a vision model".format(
-                    model_args.model_name_or_path
+                "Couldn't load model {} as a vision model due to {} ".format(
+                    model_args.model_name_or_path, e
                 )
             )
             model = None
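Broadening the handler from ValueError to Exception means any failure in the vision path (a missing processor, an architecture the vision loader does not support, a hub download error, and so on) now falls through to the causal-LM loading path instead of propagating and aborting training, and the caught exception is interpolated into the log line so the fallback is no longer silent about its cause.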
@@ -314,16 +314,7 @@ def train(
         model_loader = AutoModelForCausalLM.from_pretrained
 
         model = model_loader(
-            model_args.model_name_or_path,
-            cache_dir=train_args.cache_dir,
-            torch_dtype=get_torch_dtype(model_args.torch_dtype),
-            quantization_config=quantization_config.to_hf_config()
-            if quantization_config
-            else None,
-            attn_implementation=model_args.flash_attn_implementation
-            if model_args.use_flash_attn
-            else None,
-            use_cache=False,
+            model_args.model_name_or_path, use_cache=False, **model_kwargs
         )
 
         # TODO: Move these to a config as well
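With this hunk both load paths consume the same model_kwargs built once at the top of the try block, instead of repeating the cache_dir/torch_dtype/quantization/attention arguments inline. A standalone sketch of the resulting shape, with an illustrative model id and plain defaults in place of the commit's argument objects:

import torch
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq

model_name = "gpt2"  # illustrative hub id; any text-only model exercises the fallback
model_kwargs = dict(
    cache_dir=None,
    torch_dtype=torch.float32,
    attn_implementation=None,  # e.g. "flash_attention_2" when flash-attn is enabled
)
# Optional entries are added only when present rather than passed as None:
# model_kwargs["quantization_config"] = quantization_config.to_hf_config()

try:
    # Vision path first; text-only checkpoints raise and fall through.
    model = AutoModelForVision2Seq.from_pretrained(model_name, **model_kwargs)
except Exception as e:
    print("Couldn't load {} as a vision model due to {}".format(model_name, e))
    model = AutoModelForCausalLM.from_pretrained(
        model_name, use_cache=False, **model_kwargs
    )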
@@ -757,12 +748,12 @@ def main():
             "Tune Config": tune_config,
             "Quantization Config": quantization_config,
             "QLoRA Config": quantized_lora_config,
-            "Tracker Config": tracker_configs,
             "AADP (fms-acceleration) Config": attention_and_distributed_packing_config,
             "Fused Ops Kernels Config": fusedops_kernels_config,
             "Fast MoE Config": fast_moe_config,
-            "Trainer Controller Config": trainer_controller_args,
+            "Tracker Config": tracker_configs,
             "Extra Metadata": exp_metadata,
+            "Trainer Controller Config": trainer_controller_args,
         }
     )
     logger.info(args_dump)
