Skip to content

Commit cfa7439

Browse files
committed
Patching
1 parent fcbc174 commit cfa7439

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

squeez/training/train.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,24 @@ def _prepare_text_tokenizer(model_name: str, tokenizer):
8383
return tokenizer
8484

8585

86+
def _ensure_transformers_pretrained_config_alias(transformers_module=None) -> None:
87+
"""Work around Unsloth expecting PreTrainedConfig on newer Transformers."""
88+
if transformers_module is None:
89+
import transformers as transformers_module
90+
91+
if not hasattr(transformers_module, "PreTrainedConfig") and hasattr(
92+
transformers_module, "PretrainedConfig"
93+
):
94+
transformers_module.PreTrainedConfig = transformers_module.PretrainedConfig
95+
logger.info(
96+
"Aliased transformers.PreTrainedConfig to PretrainedConfig for Unsloth compatibility"
97+
)
98+
99+
86100
def train(args: argparse.Namespace):
87101
"""Run LoRA fine-tuning with Unsloth + SFTTrainer."""
102+
_ensure_transformers_pretrained_config_alias()
103+
88104
from unsloth import FastLanguageModel # noqa: I001
89105
from unsloth.chat_templates import train_on_responses_only
90106

0 commit comments

Comments
 (0)