11from abc import ABC , abstractmethod
22from dataclasses import dataclass , field
3- from typing import Optional , Dict , List , Type , Any , ClassVar , Union , Tuple , Callable , Literal
3+ from typing import Optional , Dict , List , Type , Any , Union , Tuple
44import torch
55import os
66import numpy as np
77from ..utils .checkpoint import LocalCheckpointManager , HuggingFaceCheckpointManager
88from ..utils .checker import MolecularInputChecker
99
10- @dataclass
1110class BaseModel (ABC ):
1211 """Base class for molecular models with shared functionality.
1312
1413 This abstract class provides common methods and utilities for molecular models,
1514 including model initialization, saving/loading, and parameter management.
16- """
1715
18- device : Optional [torch .device ] = field (default = None )
19- model_name : str = field (default = "BaseModel" )
20- model_class : Optional [Type [torch .nn .Module ]] = field (default = None , init = False ) # used for model initialization
21- model : Optional [torch .nn .Module ] = field (default = None , init = False ) # initialized model
22- is_fitted_ : bool = field (default = False , init = False )
23-
24- def __post_init__ (self ):
25- """Initialize common device settings after instance creation.
16+ Parameters
17+ ----------
18+ device : torch.device or str, optional
19+ Device to run the model on. A string such as "cpu" or "cuda:0" is converted to a torch.device. If None, automatically selects "cuda:0" if CUDA is available,
20+ otherwise CPU.
21+ model_name : str, default="BaseModel"
22+ String identifier for the model name which can be specified by the user.
23+
24+ Attributes
25+ ----------
26+ model_class : type or None
27+ The class of the model used to initialize the model instance.
28+ model : object or None
29+ The fitted model instance if the model has been trained, None otherwise.
30+ is_fitted_ : bool
31+ Whether the model has been fitted/trained. False by default.
32+ """
def __init__(self, device: Optional[Union[torch.device, str]] = None, model_name: str = "BaseModel"):
    """Store the device and model name and reset the fitted-state attributes.

    Parameters
    ----------
    device : torch.device or str, optional
        Device to run the model on. A string (e.g. "cpu" or "cuda:0") is
        converted to a ``torch.device``. If None, "cuda:0" is selected when
        CUDA is available, otherwise "cpu".
    model_name : str, default="BaseModel"
        String identifier for the model name which can be specified by the user.
    """
    # Normalise the device first so self.device always ends up a torch.device
    # when a string or None is supplied.
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    self.device = device
    self.model_name = model_name  # string of the model name which could be specified by the user

    self.is_fitted_ = False  # whether the model is fitted
    self.model = None  # the fitted model if not None
    self.model_class = None  # the class of the model used to initialize the model
45+
3546 @abstractmethod
3647 def _setup_optimizers (self ) -> Tuple [torch .optim .Optimizer , Optional [Any ]]:
3748 """Set up optimizers for model training.
@@ -78,7 +89,7 @@ def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]
7889 pass
7990
8091 @staticmethod
81- def _get_param_names (self ) -> List [str ]:
92+ def _get_param_names () -> List [str ]:
8293 """Get parameter names in the modeling class.
8394
8495 Returns
@@ -104,7 +115,7 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
104115 Dictionary of parameter names mapped to their values
105116 """
106117 out = {}
107- for key in self ._get_param_names ():
118+ for key in self .__class__ . _get_param_names ():
108119 value = getattr (self , key )
109120 if deep and hasattr (value , "get_params" ):
110121 deep_items = value .get_params ().items ()
@@ -392,5 +403,4 @@ def format_value(v):
392403 if len (repr_str ) > N_CHAR_MAX :
393404 repr_str = "\n " .join ([repr_str [:N_CHAR_MAX // 2 ], "..." , repr_str [- N_CHAR_MAX // 2 :]])
394405
395- return repr_str
396-
406+ return repr_str
0 commit comments