@@ -1288,78 +1288,77 @@ def from_file(
12881288 pkl_filename = Path (run_directory ) / "checkpoint.pkl"
12891289 if pkl_filename .exists ():
12901290 pysr_logger .info (f"Attempting to load model from { pkl_filename } ..." )
1291- assert binary_operators is None
1292- assert unary_operators is None
1293- assert operators is None
1294- assert n_features_in is None
1295- with open ( pkl_filename , "rb" ) as f :
1296- model = cast ( "PySRRegressor" , pkl . load ( f ))
1291+ model = cls . _load_checkpoint ( pkl_filename )
1292+ if model is not None :
1293+ assert binary_operators is None
1294+ assert unary_operators is None
1295+ assert operators is None
1296+ assert n_features_in is None
12971297
1298- # Update any parameters if necessary, such as
1299- # extra_sympy_mappings:
1300- model .set_params (** pysr_kwargs )
1298+ # Update any parameters if necessary, such as
1299+ # extra_sympy_mappings:
1300+ model .set_params (** pysr_kwargs )
13011301
1302- if "equations_" not in model .__dict__ or model .equations_ is None :
1303- model .refresh ()
1302+ if (
1303+ "equations_" not in model .__dict__
1304+ or model .equations_ is None
1305+ or cls ._equations_missing_export_formats (model .equations_ )
1306+ ):
1307+ model .refresh ()
13041308
1305- if model .expression_spec is not None :
1306- warnings .warn (
1307- "Loading model from checkpoint file with a non-default expression spec "
1308- "is not fully supported as it relies on dynamic objects. This may result in unexpected behavior." ,
1309- )
1309+ if model .expression_spec is not None :
1310+ warnings .warn (
1311+ "Loading model from checkpoint file with a non-default expression spec "
1312+ "is not fully supported as it relies on dynamic objects. This may result in unexpected behavior." ,
1313+ )
13101314
1311- return model
1312- else :
1313- pysr_logger .info (
1314- f"Checkpoint file { pkl_filename } does not exist. "
1315- "Attempting to recreate model from scratch..."
1315+ return model
1316+ pysr_logger .info (
1317+ f"Checkpoint file { pkl_filename } does not exist or could not be loaded. "
1318+ "Attempting to recreate model from CSV backups..."
1319+ )
1320+ csv_filename = Path (run_directory ) / "hall_of_fame.csv"
1321+ csv_filename_bak = Path (run_directory ) / "hall_of_fame.csv.bak"
1322+ if not csv_filename .exists () and not csv_filename_bak .exists ():
1323+ raise FileNotFoundError (
1324+ f"Hall of fame file `{ csv_filename } ` or `{ csv_filename_bak } ` does not exist. "
1325+ "Please pass a `run_directory` containing a valid checkpoint file."
13161326 )
1317- csv_filename = Path (run_directory ) / "hall_of_fame.csv"
1318- csv_filename_bak = Path (run_directory ) / "hall_of_fame.csv.bak"
1319- if not csv_filename .exists () and not csv_filename_bak .exists ():
1320- raise FileNotFoundError (
1321- f"Hall of fame file `{ csv_filename } ` or `{ csv_filename_bak } ` does not exist. "
1322- "Please pass a `run_directory` containing a valid checkpoint file."
1323- )
1324- if (
1325- operators is None
1326- and binary_operators is None
1327- and unary_operators is None
1328- ):
1329- raise ValueError (
1330- "When recreating a model from CSV backups you must provide either "
1331- "`operators` or legacy `binary_operators`/`unary_operators`."
1332- )
1333- assert n_features_in is not None
1334- model = cls (
1335- binary_operators = binary_operators ,
1336- unary_operators = unary_operators ,
1337- operators = operators ,
1338- ** pysr_kwargs ,
1327+ if operators is None and binary_operators is None and unary_operators is None :
1328+ raise ValueError (
1329+ "When recreating a model from CSV backups you must provide either "
1330+ "`operators` or legacy `binary_operators`/`unary_operators`."
13391331 )
1340- model .nout_ = nout
1341- model .n_features_in_ = n_features_in
1332+ assert n_features_in is not None
1333+ model = cls (
1334+ binary_operators = binary_operators ,
1335+ unary_operators = unary_operators ,
1336+ operators = operators ,
1337+ ** pysr_kwargs ,
1338+ )
1339+ model .nout_ = nout
1340+ model .n_features_in_ = n_features_in
13421341
1343- if feature_names_in is None :
1344- model .feature_names_in_ = np .array (
1345- [f"x{ i } " for i in range (n_features_in )]
1346- )
1347- model .display_feature_names_in_ = np .array (
1348- [f"x{ _subscriptify (i )} " for i in range (n_features_in )]
1349- )
1350- else :
1351- assert len (feature_names_in ) == n_features_in
1352- model .feature_names_in_ = feature_names_in
1353- model .display_feature_names_in_ = feature_names_in
1342+ if feature_names_in is None :
1343+ model .feature_names_in_ = np .array (
1344+ [f"x{ i } " for i in range (n_features_in )]
1345+ )
1346+ model .display_feature_names_in_ = np .array (
1347+ [f"x{ _subscriptify (i )} " for i in range (n_features_in )]
1348+ )
1349+ else :
1350+ assert len (feature_names_in ) == n_features_in
1351+ model .feature_names_in_ = feature_names_in
1352+ model .display_feature_names_in_ = feature_names_in
13541353
1355- if selection_mask is None :
1356- model .selection_mask_ = np .ones (n_features_in , dtype = np .bool_ )
1357- else :
1358- model .selection_mask_ = selection_mask
1354+ if selection_mask is None :
1355+ model .selection_mask_ = np .ones (n_features_in , dtype = np .bool_ )
1356+ else :
1357+ model .selection_mask_ = selection_mask
13591358
1360- model .refresh (run_directory = run_directory )
1359+ model .refresh (run_directory = run_directory )
13611360
1362- return model
1361+ return model
13631362
13641363 def __repr__ (self ) -> str :
13651364 """
@@ -1424,8 +1423,12 @@ def __getstate__(self) -> dict[str, Any]:
14241423 show_pickle_warning = not (
14251424 "show_pickle_warnings_" in state and not state ["show_pickle_warnings_" ]
14261425 )
1427- state_keys_containing_lambdas = ["extra_sympy_mappings" , "extra_torch_mappings" ]
1428- for state_key in state_keys_containing_lambdas :
1426+ state_keys_to_clear = (
1427+ "extra_sympy_mappings" ,
1428+ "extra_jax_mappings" ,
1429+ "extra_torch_mappings" ,
1430+ )
1431+ for state_key in state_keys_to_clear :
14291432 warn_msg = (
14301433 f"`{ state_key } ` cannot be pickled and will be removed from the "
14311434 "serialized instance. When loading the model, please redefine "
@@ -1436,8 +1439,7 @@ def __getstate__(self) -> dict[str, Any]:
14361439 warnings .warn (warn_msg )
14371440 else :
14381441 pysr_logger .debug (warn_msg )
1439- state_keys_to_clear = state_keys_containing_lambdas
1440- state_keys_to_clear .append ("logger_" )
1442+ state_keys_to_clear = (* state_keys_to_clear , "logger_" )
14411443 pickled_state = {
14421444 key : (None if key in state_keys_to_clear else value )
14431445 for key , value in state .items ()
@@ -1448,38 +1450,88 @@ def __getstate__(self) -> dict[str, Any]:
14481450 pickled_state ["output_torch_format" ] = False
14491451 pickled_state ["output_jax_format" ] = False
14501452 if self .nout_ == 1 :
1451- pickled_columns = ~ pickled_state ["equations_" ].columns .isin (
1452- ["jax_format" , "torch_format" ]
1453- )
1454- pickled_state ["equations_" ] = (
1455- pickled_state ["equations_" ].loc [:, pickled_columns ].copy ()
1453+ pickled_state ["equations_" ] = self ._drop_equation_columns (
1454+ pickled_state ["equations_" ], ["jax_format" , "torch_format" ]
14561455 )
14571456 else :
1458- pickled_columns = [
1459- ~ dataframe .columns .isin (["jax_format" , "torch_format" ])
1460- for dataframe in pickled_state ["equations_" ]
1461- ]
1462- pickled_state ["equations_" ] = [
1463- dataframe .loc [:, signle_pickled_columns ]
1464- for dataframe , signle_pickled_columns in zip (
1465- pickled_state ["equations_" ], pickled_columns
1466- )
1467- ]
1457+ pickled_state ["equations_" ] = self ._drop_equation_columns (
1458+ pickled_state ["equations_" ], ["jax_format" , "torch_format" ]
1459+ )
1460+ try :
1461+ pkl .dumps (pickled_state ["equations_" ])
1462+ except Exception as e :
1463+ warn_msg = (
1464+ "`equations_` export formats cannot be pickled and will be "
1465+ "removed from the serialized instance. When loading the model, "
1466+ "please redefine custom mappings at runtime."
1467+ )
1468+ if show_pickle_warning :
1469+ warnings .warn (warn_msg )
1470+ else :
1471+ pysr_logger .debug (f"{ warn_msg } Error: { e } " )
1472+ pickled_state ["equations_" ] = self ._drop_equation_columns (
1473+ pickled_state ["equations_" ], ["sympy_format" , "lambda_format" ]
1474+ )
14681475 return pickled_state
14691476
1477+ @staticmethod
1478+ def _drop_equation_columns (equations , columns : list [str ]):
1479+ if isinstance (equations , list ):
1480+ return [
1481+ dataframe .loc [:, ~ dataframe .columns .isin (columns )].copy ()
1482+ for dataframe in equations
1483+ ]
1484+ return equations .loc [:, ~ equations .columns .isin (columns )].copy ()
1485+
1486+ @staticmethod
1487+ def _equations_missing_export_formats (equations ) -> bool :
1488+ required_columns = {"sympy_format" , "lambda_format" }
1489+ if isinstance (equations , list ):
1490+ return any (
1491+ not required_columns .issubset (dataframe .columns )
1492+ for dataframe in equations
1493+ )
1494+ return not required_columns .issubset (equations .columns )
1495+
14701496 def _checkpoint (self ):
14711497 """Save the model's current state to a checkpoint file.
14721498
14731499 This should only be used internally by PySRRegressor.
14741500 """
1475- # Save model state:
14761501 self .show_pickle_warnings_ = False
1477- with open (self .get_pkl_filename (), "wb" ) as f :
1478- try :
1502+ pkl_filename = self .get_pkl_filename ()
1503+ tmp_filename = None
1504+ try :
1505+ with tempfile .NamedTemporaryFile (
1506+ mode = "wb" , dir = pkl_filename .parent , delete = False
1507+ ) as f :
1508+ tmp_filename = Path (f .name )
14791509 pkl .dump (self , f )
1480- except Exception as e :
1481- pysr_logger .debug (f"Error checkpointing model: { e } " )
1482- self .show_pickle_warnings_ = True
1510+ os .replace (tmp_filename , pkl_filename )
1511+ except Exception as e :
1512+ pysr_logger .debug (f"Error checkpointing model: { e } " )
1513+ if tmp_filename is not None :
1514+ tmp_filename .unlink (missing_ok = True )
1515+ finally :
1516+ self .show_pickle_warnings_ = True
1517+
1518+ @staticmethod
1519+ def _load_checkpoint (pkl_filename : Path ) -> "PySRRegressor" | None :
1520+ if pkl_filename .stat ().st_size == 0 :
1521+ pysr_logger .warning (
1522+ f"Checkpoint file { pkl_filename } is empty. "
1523+ "Attempting to recreate model from CSV backups..."
1524+ )
1525+ return None
1526+ try :
1527+ with open (pkl_filename , "rb" ) as f :
1528+ return cast ("PySRRegressor" , pkl .load (f ))
1529+ except (EOFError , pkl .UnpicklingError ) as e :
1530+ pysr_logger .warning (
1531+ f"Could not load checkpoint file { pkl_filename } : { e } . "
1532+ "Attempting to recreate model from CSV backups..."
1533+ )
1534+ return None
14831535
14841536 def get_pkl_filename (self ) -> Path :
14851537 path = Path (self .output_directory_ ) / self .run_id_ / "checkpoint.pkl"
0 commit comments