
Commit 02ae108

Apply PR huggingface#45055 fix for Trainer checkpoint configs
Direct merge conflicted after Trainer refactors; applied the minimal config-saving change from 57cb2b9.
1 parent c3d2b83 commit 02ae108

1 file changed

src/transformers/trainer.py

Lines changed: 5 additions & 4 deletions
@@ -3834,15 +3834,16 @@ def _save(self, output_dir: str | None = None, state_dict: dict | None = None) -
             if state_dict is None:
                 state_dict = self.model.state_dict()
 
-            if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes):
-                self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained(
-                    output_dir, state_dict=state_dict
-                )
+            unwrapped_model = self.accelerator.unwrap_model(self.model, keep_torch_compile=False)
+            if isinstance(unwrapped_model, supported_classes):
+                unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 safetensors.torch.save_file(
                     state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                 )
+                if hasattr(unwrapped_model, "config") and unwrapped_model.config is not None:
+                    unwrapped_model.config.save_pretrained(output_dir)
         else:
             self.model.save_pretrained(output_dir, state_dict=state_dict)
 
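For context, a minimal sketch of the behavior this change targets, assuming a custom torch.nn.Module that is not a PreTrainedModel but exposes a PretrainedConfig on a `config` attribute. The TinyRegressor class, its fields, and the tmp_ckpt directory are illustrative only, not part of this commit:

import torch
from transformers import PretrainedConfig, Trainer, TrainingArguments

# Hypothetical model: a plain nn.Module (not a PreTrainedModel) that still
# carries a PretrainedConfig on its `config` attribute.
class TinyRegressor(torch.nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.linear = torch.nn.Linear(config.hidden_size, 1)

    def forward(self, input_values, labels=None):
        logits = self.linear(input_values)
        loss = torch.nn.functional.mse_loss(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

config = PretrainedConfig(hidden_size=8)
model = TinyRegressor(config)
trainer = Trainer(model=model, args=TrainingArguments(output_dir="tmp_ckpt"))

trainer.save_model("tmp_ckpt")
# Without the change: tmp_ckpt/ holds only the safetensors state dict for such models.
# With the change: tmp_ckpt/config.json is also written from model.config, so the
# checkpoint directory carries the configuration needed to rebuild the model.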