55import copy
66import logging
77import os
8- import pickle as pkl
98import re
109import sys
1110import tempfile
2625from sklearn .utils .validation import _check_feature_names_in # type: ignore
2726from sklearn .utils .validation import check_is_fitted
2827
28+ from ._checkpoint import (
29+ equations_missing_export_formats ,
30+ get_regressor_pickle_state ,
31+ load_checkpoint ,
32+ save_checkpoint ,
33+ )
2934from .denoising import denoise , multi_denoise
3035from .deprecated import DEPRECATED_KWARGS
3136from .export_latex import (
@@ -1288,7 +1293,7 @@ def from_file(
12881293 pkl_filename = Path (run_directory ) / "checkpoint.pkl"
12891294 if pkl_filename .exists ():
12901295 pysr_logger .info (f"Attempting to load model from { pkl_filename } ..." )
1291- model = cls . _load_checkpoint (pkl_filename )
1296+ model = load_checkpoint (pkl_filename )
12921297 if model is not None :
12931298 assert binary_operators is None
12941299 assert unary_operators is None
@@ -1302,7 +1307,7 @@ def from_file(
13021307 if (
13031308 "equations_" not in model .__dict__
13041309 or model .equations_ is None
1305- or cls . _equations_missing_export_formats (model .equations_ )
1310+ or equations_missing_export_formats (model .equations_ )
13061311 ):
13071312 model .refresh ()
13081313
@@ -1417,120 +1422,19 @@ def __getstate__(self) -> dict[str, Any]:
14171422 `pickle.dumps()`. However, some attributes do not support pickling
14181423 and need to be hidden, such as the JAX and Torch representations.
14191424 """
1420- state = self .__dict__
1421- show_pickle_warning = not (
1422- "show_pickle_warnings_" in state and not state ["show_pickle_warnings_" ]
1423- )
1424- state_keys_to_clear = (
1425- "extra_sympy_mappings" ,
1426- "extra_jax_mappings" ,
1427- "extra_torch_mappings" ,
1428- )
1429- for state_key in state_keys_to_clear :
1430- warn_msg = (
1431- f"`{ state_key } ` cannot be pickled and will be removed from the "
1432- "serialized instance. When loading the model, please redefine "
1433- f"`{ state_key } ` at runtime."
1434- )
1435- if state [state_key ] is not None :
1436- if show_pickle_warning :
1437- warnings .warn (warn_msg )
1438- else :
1439- pysr_logger .debug (warn_msg )
1440- state_keys_to_clear = (* state_keys_to_clear , "logger_" )
1441- pickled_state = {
1442- key : (None if key in state_keys_to_clear else value )
1443- for key , value in state .items ()
1444- }
1445- if ("equations_" in pickled_state ) and (
1446- pickled_state ["equations_" ] is not None
1447- ):
1448- pickled_state ["output_torch_format" ] = False
1449- pickled_state ["output_jax_format" ] = False
1450- if self .nout_ == 1 :
1451- pickled_state ["equations_" ] = self ._drop_equation_columns (
1452- pickled_state ["equations_" ], ["jax_format" , "torch_format" ]
1453- )
1454- else :
1455- pickled_state ["equations_" ] = self ._drop_equation_columns (
1456- pickled_state ["equations_" ], ["jax_format" , "torch_format" ]
1457- )
1458- try :
1459- pkl .dumps (pickled_state ["equations_" ])
1460- except Exception as e :
1461- warn_msg = (
1462- "`equations_` export formats cannot be pickled and will be "
1463- "removed from the serialized instance. When loading the model, "
1464- "please redefine custom mappings at runtime."
1465- )
1466- if show_pickle_warning :
1467- warnings .warn (warn_msg )
1468- else :
1469- pysr_logger .debug (f"{ warn_msg } Error: { e } " )
1470- pickled_state ["equations_" ] = self ._drop_equation_columns (
1471- pickled_state ["equations_" ], ["sympy_format" , "lambda_format" ]
1472- )
1473- return pickled_state
1474-
1475- @staticmethod
1476- def _drop_equation_columns (equations , columns : list [str ]):
1477- if isinstance (equations , list ):
1478- return [
1479- dataframe .loc [:, ~ dataframe .columns .isin (columns )].copy ()
1480- for dataframe in equations
1481- ]
1482- return equations .loc [:, ~ equations .columns .isin (columns )].copy ()
1483-
1484- @staticmethod
1485- def _equations_missing_export_formats (equations ) -> bool :
1486- required_columns = {"sympy_format" , "lambda_format" }
1487- if isinstance (equations , list ):
1488- return any (
1489- not required_columns .issubset (dataframe .columns )
1490- for dataframe in equations
1491- )
1492- return not required_columns .issubset (equations .columns )
1425+ return get_regressor_pickle_state (self .__dict__ )
14931426
14941427 def _checkpoint (self ):
14951428 """Save the model's current state to a checkpoint file.
14961429
14971430 This should only be used internally by PySRRegressor.
14981431 """
14991432 self .show_pickle_warnings_ = False
1500- pkl_filename = self .get_pkl_filename ()
1501- tmp_filename = None
15021433 try :
1503- with tempfile .NamedTemporaryFile (
1504- mode = "wb" , dir = pkl_filename .parent , delete = False
1505- ) as f :
1506- tmp_filename = Path (f .name )
1507- pkl .dump (self , f )
1508- os .replace (tmp_filename , pkl_filename )
1509- except Exception as e :
1510- pysr_logger .debug (f"Error checkpointing model: { e } " )
1511- if tmp_filename is not None :
1512- tmp_filename .unlink (missing_ok = True )
1434+ save_checkpoint (self , self .get_pkl_filename ())
15131435 finally :
15141436 self .show_pickle_warnings_ = True
15151437
1516- @staticmethod
1517- def _load_checkpoint (pkl_filename : Path ) -> "PySRRegressor" | None :
1518- if pkl_filename .stat ().st_size == 0 :
1519- pysr_logger .warning (
1520- f"Checkpoint file { pkl_filename } is empty. "
1521- "Attempting to recreate model from CSV backups..."
1522- )
1523- return None
1524- try :
1525- with open (pkl_filename , "rb" ) as f :
1526- return cast ("PySRRegressor" , pkl .load (f ))
1527- except (EOFError , pkl .UnpicklingError ) as e :
1528- pysr_logger .warning (
1529- f"Could not load checkpoint file { pkl_filename } : { e } . "
1530- "Attempting to recreate model from CSV backups..."
1531- )
1532- return None
1533-
15341438 def get_pkl_filename (self ) -> Path :
15351439 path = Path (self .output_directory_ ) / self .run_id_ / "checkpoint.pkl"
15361440 path .parent .mkdir (parents = True , exist_ok = True )
0 commit comments