Skip to content

Commit 604d74d

Browse files
feat!(save&load): enable save and load for ttc object
- implies pickling the tokenizer and metadata, and using a Lightning checkpoint — test adapted accordingly
1 parent 6c2ea00 commit 604d74d

4 files changed

Lines changed: 128 additions & 1 deletion

File tree

tests/test_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
159159
y_val=Y,
160160
training_config=training_config,
161161
)
162+
ttc.load(ttc.save_path) # test load
162163

163164
# Predict with explanations
164165
top_k = 5

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
2323
self.config = text_embedder_config
2424

2525
self.attention_config = text_embedder_config.attention_config
26+
if isinstance(self.attention_config, dict):
27+
self.attention_config = AttentionConfig(**self.attention_config)
28+
2629
if self.attention_config is not None:
2730
self.attention_config.n_embd = text_embedder_config.embedding_dim
2831

torchTextClassifiers/model/lightning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
scheduler_interval: Scheduler interval.
3737
"""
3838
super().__init__()
39-
self.save_hyperparameters(ignore=["model", "loss"])
39+
self.save_hyperparameters(ignore=["model"])
4040

4141
self.model = model
4242
self.loss = loss

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
2+
import pickle
23
import time
34
from dataclasses import asdict, dataclass, field
5+
from pathlib import Path
46
from typing import Any, Dict, List, Optional, Tuple, Type, Union
57

68
try:
@@ -75,6 +77,7 @@ class TrainingConfig:
7577
trainer_params: Optional[dict] = None
7678
optimizer_params: Optional[dict] = None
7779
scheduler_params: Optional[dict] = None
80+
save_path: Optional[str] = "my_ttc"
7881

7982
def to_dict(self) -> Dict[str, Any]:
8083
data = asdict(self)
@@ -362,6 +365,7 @@ def train(
362365
logger.info(f"Training completed in {end - start:.2f} seconds.")
363366

364367
best_model_path = trainer.checkpoint_callback.best_model_path
368+
self.checkpoint_path = best_model_path
365369

366370
self.lightning_module = TextClassificationModule.load_from_checkpoint(
367371
best_model_path,
@@ -372,6 +376,9 @@ def train(
372376

373377
self.pytorch_model = self.lightning_module.model.to(self.device)
374378

379+
self.save_path = training_config.save_path
380+
self.save(self.save_path)
381+
375382
self.lightning_module.eval()
376383

377384
def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
@@ -576,6 +583,122 @@ def predict(
576583
"confidence": confidence,
577584
}
578585

586+
def save(self, path: Union[str, Path]) -> None:
    """Persist this torchTextClassifiers instance to a directory.

    Writes up to three artifacts under *path*:

    - ``model_checkpoint.ckpt`` -- a Lightning checkpoint, produced only
      when the instance has been trained (``lightning_module`` exists);
    - ``metadata.pkl`` -- pickled configuration and bookkeeping values;
    - ``tokenizer.pkl`` -- the pickled tokenizer object.

    Args:
        path: Directory the artifacts are written into (created if missing).

    Example:
        >>> ttc = torchTextClassifiers(tokenizer, model_config)
        >>> ttc.train(X_train, y_train, training_config)
        >>> ttc.save("my_model")
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # A checkpoint is only produced for a trained instance.
    ckpt_file = None
    if hasattr(self, "lightning_module"):
        ckpt_file = out_dir / "model_checkpoint.ckpt"
        # Attach the module to a throwaway Trainer so the current weights
        # can be serialized through Lightning's checkpoint machinery.
        trainer = pl.Trainer()
        trainer.strategy.connect(self.lightning_module)
        trainer.save_checkpoint(ckpt_file)

    # Everything needed to rebuild the instance, apart from the weights.
    metadata = {
        "model_config": self.model_config.to_dict(),
        "ragged_multilabel": self.ragged_multilabel,
        "vocab_size": self.vocab_size,
        "embedding_dim": self.embedding_dim,
        "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
        "num_classes": self.num_classes,
        "checkpoint_path": str(ckpt_file) if ckpt_file else None,
        "device": str(self.device) if hasattr(self, "device") else None,
    }

    with (out_dir / "metadata.pkl").open("wb") as fh:
        pickle.dump(metadata, fh)

    with (out_dir / "tokenizer.pkl").open("wb") as fh:
        pickle.dump(self.tokenizer, fh)

    logger.info(f"Model saved successfully to {out_dir}")
637+
638+
@classmethod
def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
    """Load a torchTextClassifiers instance from disk.

    Reads the artifacts written by ``save`` (``metadata.pkl``,
    ``tokenizer.pkl`` and, when present, ``model_checkpoint.ckpt``)
    and rebuilds a classifier instance from them.

    Args:
        path: Directory path where the model was saved
        device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)

    Returns:
        Loaded torchTextClassifiers instance

    Raises:
        FileNotFoundError: If *path* does not exist.

    Example:
        >>> loaded_ttc = torchTextClassifiers.load("my_model")
        >>> predictions = loaded_ttc.predict(X_test)
    """
    path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Model directory not found: {path}")

    # Load metadata
    # SECURITY NOTE(review): pickle.load executes arbitrary code on
    # untrusted data -- only load directories produced by save() from a
    # trusted source.
    with open(path / "metadata.pkl", "rb") as f:
        metadata = pickle.load(f)

    # Load tokenizer
    with open(path / "tokenizer.pkl", "rb") as f:
        tokenizer = pickle.load(f)

    # Reconstruct model_config
    model_config = ModelConfig.from_dict(metadata["model_config"])

    # Create instance
    instance = cls(
        tokenizer=tokenizer,
        model_config=model_config,
        ragged_multilabel=metadata["ragged_multilabel"],
    )

    # Set device
    if device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    instance.device = device

    # Load checkpoint if it exists
    # The stored "checkpoint_path" may be a stale absolute path from the
    # saving machine; the truthiness check only signals "a checkpoint was
    # saved", and the file is re-resolved relative to *path* below.
    if metadata["checkpoint_path"]:
        checkpoint_path = path / "model_checkpoint.ckpt"
        if checkpoint_path.exists():
            # Load the checkpoint with weights_only=False since it's our own trusted checkpoint
            # NOTE(review): this assumes cls(...) already built
            # instance.pytorch_model; if the constructor does not, the
            # attribute access below raises AttributeError -- confirm.
            # NOTE(review): also confirm the installed Lightning version
            # accepts a weights_only kwarg here; older versions forward
            # unknown kwargs as hyperparameter overrides instead.
            instance.lightning_module = TextClassificationModule.load_from_checkpoint(
                str(checkpoint_path),
                model=instance.pytorch_model,
                weights_only=False,
            )
            instance.pytorch_model = instance.lightning_module.model.to(device)
            instance.checkpoint_path = str(checkpoint_path)
            logger.info(f"Model checkpoint loaded from {checkpoint_path}")
        else:
            logger.warning(f"Checkpoint file not found at {checkpoint_path}")

    logger.info(f"Model loaded successfully from {path}")
    return instance
701+
579702
def __repr__(self):
580703
model_type = (
581704
self.lightning_module.__repr__()

0 commit comments

Comments
 (0)