11import logging
2+ import pickle
23import time
34from dataclasses import asdict , dataclass , field
5+ from pathlib import Path
46from typing import Any , Dict , List , Optional , Tuple , Type , Union
57
68try :
@@ -75,6 +77,7 @@ class TrainingConfig:
7577 trainer_params : Optional [dict ] = None
7678 optimizer_params : Optional [dict ] = None
7779 scheduler_params : Optional [dict ] = None
80+ save_path : Optional [str ] = "my_ttc"
7881
7982 def to_dict (self ) -> Dict [str , Any ]:
8083 data = asdict (self )
@@ -362,6 +365,7 @@ def train(
362365 logger .info (f"Training completed in { end - start :.2f} seconds." )
363366
364367 best_model_path = trainer .checkpoint_callback .best_model_path
368+ self .checkpoint_path = best_model_path
365369
366370 self .lightning_module = TextClassificationModule .load_from_checkpoint (
367371 best_model_path ,
@@ -372,6 +376,9 @@ def train(
372376
373377 self .pytorch_model = self .lightning_module .model .to (self .device )
374378
379+ self .save_path = training_config .save_path
380+ self .save (self .save_path )
381+
375382 self .lightning_module .eval ()
376383
377384 def _check_XY (self , X : np .ndarray , Y : np .ndarray ) -> Tuple [np .ndarray , np .ndarray ]:
@@ -576,6 +583,122 @@ def predict(
576583 "confidence" : confidence ,
577584 }
578585
    def save(self, path: Union[str, Path]) -> None:
        """Save the complete torchTextClassifiers instance to disk.

        This saves:
        - Model configuration (``metadata.pkl``)
        - Tokenizer state (``tokenizer.pkl``)
        - PyTorch Lightning checkpoint (``model_checkpoint.ckpt``, only if trained)
        - Selected instance attributes needed by :meth:`load`

        Args:
            path: Directory path where the model will be saved. Created
                (including parents) if it does not exist.

        Example:
            >>> ttc = torchTextClassifiers(tokenizer, model_config)
            >>> ttc.train(X_train, y_train, training_config)
            >>> ttc.save("my_model")
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save the checkpoint only if the model has been trained: train() is
        # what sets self.lightning_module, so its absence means "untrained"
        # and only config/tokenizer are written.
        checkpoint_path = None
        if hasattr(self, "lightning_module"):
            checkpoint_path = path / "model_checkpoint.ckpt"
            # Save the current state as a checkpoint. A throwaway Trainer is
            # built solely so Lightning can serialize the module.
            # NOTE(review): `trainer.strategy.connect` is an internal Lightning
            # API — confirm it remains stable across Lightning versions.
            trainer = pl.Trainer()
            trainer.strategy.connect(self.lightning_module)
            trainer.save_checkpoint(checkpoint_path)

        # Prepare metadata to save. The checkpoint path is stored as an
        # absolute string, but load() recomputes it relative to the target
        # directory, so a saved folder can be moved/renamed freely.
        metadata = {
            "model_config": self.model_config.to_dict(),
            "ragged_multilabel": self.ragged_multilabel,
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
            "num_classes": self.num_classes,
            "checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
            "device": str(self.device) if hasattr(self, "device") else None,
        }

        # Save metadata
        with open(path / "metadata.pkl", "wb") as f:
            pickle.dump(metadata, f)

        # Save tokenizer. The tokenizer is pickled whole: unpickling executes
        # arbitrary code, so saved directories must be treated as trusted
        # artifacts (see the matching note in load()).
        tokenizer_path = path / "tokenizer.pkl"
        with open(tokenizer_path, "wb") as f:
            pickle.dump(self.tokenizer, f)

        logger.info(f"Model saved successfully to {path}")
637+
638+ @classmethod
639+ def load (cls , path : Union [str , Path ], device : str = "auto" ) -> "torchTextClassifiers" :
640+ """Load a torchTextClassifiers instance from disk.
641+
642+ Args:
643+ path: Directory path where the model was saved
644+ device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)
645+
646+ Returns:
647+ Loaded torchTextClassifiers instance
648+
649+ Example:
650+ >>> loaded_ttc = torchTextClassifiers.load("my_model")
651+ >>> predictions = loaded_ttc.predict(X_test)
652+ """
653+ path = Path (path )
654+
655+ if not path .exists ():
656+ raise FileNotFoundError (f"Model directory not found: { path } " )
657+
658+ # Load metadata
659+ with open (path / "metadata.pkl" , "rb" ) as f :
660+ metadata = pickle .load (f )
661+
662+ # Load tokenizer
663+ with open (path / "tokenizer.pkl" , "rb" ) as f :
664+ tokenizer = pickle .load (f )
665+
666+ # Reconstruct model_config
667+ model_config = ModelConfig .from_dict (metadata ["model_config" ])
668+
669+ # Create instance
670+ instance = cls (
671+ tokenizer = tokenizer ,
672+ model_config = model_config ,
673+ ragged_multilabel = metadata ["ragged_multilabel" ],
674+ )
675+
676+ # Set device
677+ if device == "auto" :
678+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
679+ else :
680+ device = torch .device (device )
681+ instance .device = device
682+
683+ # Load checkpoint if it exists
684+ if metadata ["checkpoint_path" ]:
685+ checkpoint_path = path / "model_checkpoint.ckpt"
686+ if checkpoint_path .exists ():
687+ # Load the checkpoint with weights_only=False since it's our own trusted checkpoint
688+ instance .lightning_module = TextClassificationModule .load_from_checkpoint (
689+ str (checkpoint_path ),
690+ model = instance .pytorch_model ,
691+ weights_only = False ,
692+ )
693+ instance .pytorch_model = instance .lightning_module .model .to (device )
694+ instance .checkpoint_path = str (checkpoint_path )
695+ logger .info (f"Model checkpoint loaded from { checkpoint_path } " )
696+ else :
697+ logger .warning (f"Checkpoint file not found at { checkpoint_path } " )
698+
699+ logger .info (f"Model loaded successfully from { path } " )
700+ return instance
701+
579702 def __repr__ (self ):
580703 model_type = (
581704 self .lightning_module .__repr__ ()
0 commit comments