
Commit 4e8d774

Merge branch 'main' into save-peft-fast-moe-limited
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

2 parents: 0f7796e + ebe35a3

2 files changed: 8 additions & 15 deletions

File tree

README.md

Lines changed: 6 additions & 6 deletions
```diff
@@ -919,12 +919,12 @@ For information on supported dataset formats and how to tune a vision-language m
 
 ? May be supported, but not tested
 
-Model Name & Size | Model Architecture | LoRA Tuning | Full Finetuning |
--------------------- | ---------------- | --------------- | --------------- |
-Llama 3.2-11B Vision | MllamaForConditionalGeneration | ✅* | ✅* |
-Llava 1.5-7B | LlavaForConditionalGeneration | ✅* | 🚫 |
-Granite 3.1-2B Vision | LlavaNextForConditionalGeneration | ✅* | 🚫 |
-Llava Mistral 1.6-7B | LlavaNextForConditionalGeneration | ✅* | 🚫 |
+Model Name & Size | Model Architecture | Full Finetuning |
+-------------------- | ---------------- | --------------- |
+Llama 3.2-11B Vision | MllamaForConditionalGeneration | ✅* |
+Llava 1.5-7B | LlavaForConditionalGeneration | ✅* |
+Granite 3.1-2B Vision | LlavaNextForConditionalGeneration | ✅* |
+Llava Mistral 1.6-7B | LlavaNextForConditionalGeneration | ✅* |
 
 (*) - Supported with `fms-hf-tuning` v2.8.0 or later.
```

tuning/sft_trainer.py

Lines changed: 2 additions & 9 deletions
```diff
@@ -264,16 +264,9 @@ def train(
             attn_implementation="flash_attention_2"
             if model_args.use_flash_attn
             else None,
+            # avoid warning that use_cache is incompatible with gradient checkpointing
+            use_cache=(not train_args.gradient_checkpointing),
         )
-        try:
-            if "use_cache" in model.language_model.config:
-                # avoid warning that use_cache is incompatible with gradient checkpointing
-                model.language_model.config.use_cache = (
-                    not train_args.gradient_checkpointing
-                )
-        except AttributeError as e:
-            # When the model doesn't have the use_cache attribute
-            logger.warning("Couldn't update use_cache for vision model: %s", e)
 
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
         tokenizer = processor.tokenizer
```
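The net effect of this change is that `use_cache` is now passed as a keyword to `from_pretrained` at load time, rather than being patched onto `model.language_model.config` afterwards inside a `try`/`except`. A minimal sketch of the flag logic, assuming the same two settings the diff reads (the helper name `build_model_load_kwargs` is hypothetical, not part of the repo):

```python
def build_model_load_kwargs(use_flash_attn: bool, gradient_checkpointing: bool) -> dict:
    """Hypothetical helper mirroring the kwargs this diff passes to from_pretrained.

    use_cache is disabled whenever gradient checkpointing is on, because the two
    features are incompatible and transformers emits a warning if both are enabled.
    """
    return {
        "attn_implementation": "flash_attention_2" if use_flash_attn else None,
        "use_cache": not gradient_checkpointing,
    }

# With gradient checkpointing on, KV caching is turned off up front:
print(build_model_load_kwargs(use_flash_attn=True, gradient_checkpointing=True))
# → {'attn_implementation': 'flash_attention_2', 'use_cache': False}
```

Setting the value at load time also removes the need to special-case models without a `language_model.config` attribute, which is what the deleted `AttributeError` handler was for.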
