Skip to content

Commit 2e59bfd

Browse files
refactor: move checkpoint helpers out of regressor
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent 31210d8 commit 2e59bfd

2 files changed

Lines changed: 135 additions & 106 deletions

File tree

pysr/_checkpoint.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Checkpoint and pickle helpers for PySRRegressor."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
import os
7+
import pickle as pkl
8+
import tempfile
9+
import warnings
10+
from pathlib import Path
11+
from typing import TYPE_CHECKING, Any, cast
12+
13+
if TYPE_CHECKING:
14+
import pandas as pd
15+
16+
from .sr import PySRRegressor
17+
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
def get_regressor_pickle_state(state: dict[str, Any]) -> dict[str, Any]:
    """Return a pickle-safe version of a PySRRegressor state dictionary.

    The input ``state`` is never modified; a new dictionary is returned.

    In the returned dictionary:

    - ``extra_sympy_mappings``, ``extra_jax_mappings``,
      ``extra_torch_mappings`` and ``logger_`` are replaced by ``None``
      wherever they are present in ``state``.
    - If ``equations_`` is present and not ``None``, ``output_torch_format``
      and ``output_jax_format`` are set to ``False`` and the ``jax_format``
      and ``torch_format`` columns are dropped from the equations. If the
      remaining equations still fail a trial ``pickle.dumps``, the
      ``sympy_format`` and ``lambda_format`` columns are dropped as well.

    Each removal of a non-``None`` mapping (and the fallback removal of the
    export-format columns) is reported with ``warnings.warn``, unless
    ``state["show_pickle_warnings_"]`` is present and falsy, in which case
    the message is sent to the module logger at debug level instead.

    Parameters
    ----------
    state : dict[str, Any]
        Attribute dictionary of a regressor (typically ``self.__dict__``).
        Keys that are missing are simply skipped.

    Returns
    -------
    dict[str, Any]
        A new dictionary with the unpicklable entries neutralised.
    """
    # Warnings are on by default; only an explicit falsy flag silences them.
    show_pickle_warning = bool(state.get("show_pickle_warnings_", True))

    unpicklable_mappings = (
        "extra_sympy_mappings",
        "extra_jax_mappings",
        "extra_torch_mappings",
    )
    for state_key in unpicklable_mappings:
        # `.get` so that a state dictionary lacking the key is tolerated
        # instead of raising KeyError.
        if state.get(state_key) is None:
            continue
        warn_msg = (
            f"`{state_key}` cannot be pickled and will be removed from the "
            "serialized instance. When loading the model, please redefine "
            f"`{state_key}` at runtime."
        )
        if show_pickle_warning:
            warnings.warn(warn_msg)
        else:
            logger.debug(warn_msg)

    # The logger is cleared silently: no warning is emitted for it.
    state_keys_to_clear = (*unpicklable_mappings, "logger_")
    pickled_state = {
        key: (None if key in state_keys_to_clear else value)
        for key, value in state.items()
    }

    if pickled_state.get("equations_") is not None:
        pickled_state["output_torch_format"] = False
        pickled_state["output_jax_format"] = False
        pickled_state["equations_"] = drop_equation_columns(
            pickled_state["equations_"], ["jax_format", "torch_format"]
        )
        try:
            # Trial serialization: only checks that the equations pickle.
            pkl.dumps(pickled_state["equations_"])
        except Exception as e:
            warn_msg = (
                "`equations_` export formats cannot be pickled and will be "
                "removed from the serialized instance. When loading the model, "
                "please redefine custom mappings at runtime."
            )
            if show_pickle_warning:
                warnings.warn(warn_msg)
            else:
                logger.debug(f"{warn_msg} Error: {e}")
            pickled_state["equations_"] = drop_equation_columns(
                pickled_state["equations_"], ["sympy_format", "lambda_format"]
            )
    return pickled_state
70+
71+
72+
def drop_equation_columns(
    equations: pd.DataFrame | list[pd.DataFrame],
    columns: list[str],
) -> pd.DataFrame | list[pd.DataFrame]:
    """Return copies of the equation table(s) without the named columns.

    Parameters
    ----------
    equations : pd.DataFrame | list[pd.DataFrame]
        A single dataframe, or a list of dataframes.
    columns : list[str]
        Column labels to leave out. Labels that do not occur in a
        dataframe are ignored.

    Returns
    -------
    pd.DataFrame | list[pd.DataFrame]
        A new dataframe for a dataframe input, or a new list of new
        dataframes for a list input. The inputs are left untouched.
    """

    def _without_columns(frame: pd.DataFrame) -> pd.DataFrame:
        # Boolean mask over the columns: True for those that are kept.
        keep_mask = ~frame.columns.isin(columns)
        return frame.loc[:, keep_mask].copy()

    if not isinstance(equations, list):
        return _without_columns(equations)
    return list(map(_without_columns, equations))
82+
83+
84+
def equations_missing_export_formats(
    equations: pd.DataFrame | list[pd.DataFrame],
) -> bool:
    """Tell whether any equation table lacks the export-format columns.

    Parameters
    ----------
    equations : pd.DataFrame | list[pd.DataFrame]
        A single dataframe, or a list of dataframes.

    Returns
    -------
    bool
        ``True`` if at least one dataframe is missing ``sympy_format`` or
        ``lambda_format``; ``False`` otherwise (including for an empty
        list).
    """
    needed = frozenset(("sympy_format", "lambda_format"))
    # Treat the single-dataframe case as a one-element list.
    frames = equations if isinstance(equations, list) else [equations]
    return not all(needed.issubset(frame.columns) for frame in frames)
93+
94+
95+
def save_checkpoint(model: PySRRegressor, pkl_filename: Path) -> None:
    """Pickle ``model`` to ``pkl_filename`` via a temporary file.

    The model is first written to a temporary file created in the
    directory of ``pkl_filename``; that file is then moved over
    ``pkl_filename`` with ``os.replace``.

    This is best-effort: any exception is logged at debug level and
    swallowed, and the temporary file (if one was created) is removed.

    Parameters
    ----------
    model : PySRRegressor
        The object to pickle.
    pkl_filename : Path
        Final location of the checkpoint file.
    """
    scratch_path: Path | None = None
    try:
        scratch_file = tempfile.NamedTemporaryFile(
            mode="wb", dir=pkl_filename.parent, delete=False
        )
        with scratch_file:
            # Remember the path right away so it can be cleaned up on failure.
            scratch_path = Path(scratch_file.name)
            pkl.dump(model, scratch_file)
        os.replace(scratch_path, pkl_filename)
    except Exception as e:
        logger.debug(f"Error checkpointing model: {e}")
        if scratch_path is None:
            return
        scratch_path.unlink(missing_ok=True)
108+
109+
110+
def load_checkpoint(pkl_filename: Path) -> PySRRegressor | None:
    """Load a pickled model from ``pkl_filename``.

    Parameters
    ----------
    pkl_filename : Path
        Path of the checkpoint file. It must exist: ``Path.stat`` is called
        on it before anything else.

    Returns
    -------
    PySRRegressor | None
        The unpickled object, or ``None`` (after logging a warning) when
        the file is empty or its contents cannot be unpickled. ``None``
        lets the caller fall back to another way of rebuilding the model.
    """
    if pkl_filename.stat().st_size == 0:
        logger.warning(
            f"Checkpoint file {pkl_filename} is empty. "
            "Attempting to recreate model from CSV backups..."
        )
        return None
    try:
        with open(pkl_filename, "rb") as f:
            # NOTE(security): unpickling can execute arbitrary code; only
            # checkpoint files from a trusted source should be loaded.
            return cast("PySRRegressor", pkl.load(f))
    except (
        EOFError,
        pkl.UnpicklingError,
        # The pickle documentation lists these as further exceptions that
        # unpickling may raise (e.g. corrupted data, or a class/module that
        # no longer exists); treat them as an unreadable checkpoint too.
        AttributeError,
        ImportError,
        IndexError,
    ) as e:
        logger.warning(
            f"Could not load checkpoint file {pkl_filename}: {e}. "
            "Attempting to recreate model from CSV backups..."
        )
        return None

pysr/sr.py

Lines changed: 10 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import copy
66
import logging
77
import os
8-
import pickle as pkl
98
import re
109
import sys
1110
import tempfile
@@ -26,6 +25,12 @@
2625
from sklearn.utils.validation import _check_feature_names_in # type: ignore
2726
from sklearn.utils.validation import check_is_fitted
2827

28+
from ._checkpoint import (
29+
equations_missing_export_formats,
30+
get_regressor_pickle_state,
31+
load_checkpoint,
32+
save_checkpoint,
33+
)
2934
from .denoising import denoise, multi_denoise
3035
from .deprecated import DEPRECATED_KWARGS
3136
from .export_latex import (
@@ -1288,7 +1293,7 @@ def from_file(
12881293
pkl_filename = Path(run_directory) / "checkpoint.pkl"
12891294
if pkl_filename.exists():
12901295
pysr_logger.info(f"Attempting to load model from {pkl_filename}...")
1291-
model = cls._load_checkpoint(pkl_filename)
1296+
model = load_checkpoint(pkl_filename)
12921297
if model is not None:
12931298
assert binary_operators is None
12941299
assert unary_operators is None
@@ -1302,7 +1307,7 @@ def from_file(
13021307
if (
13031308
"equations_" not in model.__dict__
13041309
or model.equations_ is None
1305-
or cls._equations_missing_export_formats(model.equations_)
1310+
or equations_missing_export_formats(model.equations_)
13061311
):
13071312
model.refresh()
13081313

@@ -1417,120 +1422,19 @@ def __getstate__(self) -> dict[str, Any]:
14171422
`pickle.dumps()`. However, some attributes do not support pickling
14181423
and need to be hidden, such as the JAX and Torch representations.
14191424
"""
1420-
state = self.__dict__
1421-
show_pickle_warning = not (
1422-
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
1423-
)
1424-
state_keys_to_clear = (
1425-
"extra_sympy_mappings",
1426-
"extra_jax_mappings",
1427-
"extra_torch_mappings",
1428-
)
1429-
for state_key in state_keys_to_clear:
1430-
warn_msg = (
1431-
f"`{state_key}` cannot be pickled and will be removed from the "
1432-
"serialized instance. When loading the model, please redefine "
1433-
f"`{state_key}` at runtime."
1434-
)
1435-
if state[state_key] is not None:
1436-
if show_pickle_warning:
1437-
warnings.warn(warn_msg)
1438-
else:
1439-
pysr_logger.debug(warn_msg)
1440-
state_keys_to_clear = (*state_keys_to_clear, "logger_")
1441-
pickled_state = {
1442-
key: (None if key in state_keys_to_clear else value)
1443-
for key, value in state.items()
1444-
}
1445-
if ("equations_" in pickled_state) and (
1446-
pickled_state["equations_"] is not None
1447-
):
1448-
pickled_state["output_torch_format"] = False
1449-
pickled_state["output_jax_format"] = False
1450-
if self.nout_ == 1:
1451-
pickled_state["equations_"] = self._drop_equation_columns(
1452-
pickled_state["equations_"], ["jax_format", "torch_format"]
1453-
)
1454-
else:
1455-
pickled_state["equations_"] = self._drop_equation_columns(
1456-
pickled_state["equations_"], ["jax_format", "torch_format"]
1457-
)
1458-
try:
1459-
pkl.dumps(pickled_state["equations_"])
1460-
except Exception as e:
1461-
warn_msg = (
1462-
"`equations_` export formats cannot be pickled and will be "
1463-
"removed from the serialized instance. When loading the model, "
1464-
"please redefine custom mappings at runtime."
1465-
)
1466-
if show_pickle_warning:
1467-
warnings.warn(warn_msg)
1468-
else:
1469-
pysr_logger.debug(f"{warn_msg} Error: {e}")
1470-
pickled_state["equations_"] = self._drop_equation_columns(
1471-
pickled_state["equations_"], ["sympy_format", "lambda_format"]
1472-
)
1473-
return pickled_state
1474-
1475-
@staticmethod
1476-
def _drop_equation_columns(equations, columns: list[str]):
1477-
if isinstance(equations, list):
1478-
return [
1479-
dataframe.loc[:, ~dataframe.columns.isin(columns)].copy()
1480-
for dataframe in equations
1481-
]
1482-
return equations.loc[:, ~equations.columns.isin(columns)].copy()
1483-
1484-
@staticmethod
1485-
def _equations_missing_export_formats(equations) -> bool:
1486-
required_columns = {"sympy_format", "lambda_format"}
1487-
if isinstance(equations, list):
1488-
return any(
1489-
not required_columns.issubset(dataframe.columns)
1490-
for dataframe in equations
1491-
)
1492-
return not required_columns.issubset(equations.columns)
1425+
return get_regressor_pickle_state(self.__dict__)
14931426

14941427
def _checkpoint(self):
14951428
"""Save the model's current state to a checkpoint file.
14961429
14971430
This should only be used internally by PySRRegressor.
14981431
"""
14991432
self.show_pickle_warnings_ = False
1500-
pkl_filename = self.get_pkl_filename()
1501-
tmp_filename = None
15021433
try:
1503-
with tempfile.NamedTemporaryFile(
1504-
mode="wb", dir=pkl_filename.parent, delete=False
1505-
) as f:
1506-
tmp_filename = Path(f.name)
1507-
pkl.dump(self, f)
1508-
os.replace(tmp_filename, pkl_filename)
1509-
except Exception as e:
1510-
pysr_logger.debug(f"Error checkpointing model: {e}")
1511-
if tmp_filename is not None:
1512-
tmp_filename.unlink(missing_ok=True)
1434+
save_checkpoint(self, self.get_pkl_filename())
15131435
finally:
15141436
self.show_pickle_warnings_ = True
15151437

1516-
@staticmethod
1517-
def _load_checkpoint(pkl_filename: Path) -> "PySRRegressor" | None:
1518-
if pkl_filename.stat().st_size == 0:
1519-
pysr_logger.warning(
1520-
f"Checkpoint file {pkl_filename} is empty. "
1521-
"Attempting to recreate model from CSV backups..."
1522-
)
1523-
return None
1524-
try:
1525-
with open(pkl_filename, "rb") as f:
1526-
return cast("PySRRegressor", pkl.load(f))
1527-
except (EOFError, pkl.UnpicklingError) as e:
1528-
pysr_logger.warning(
1529-
f"Could not load checkpoint file {pkl_filename}: {e}. "
1530-
"Attempting to recreate model from CSV backups..."
1531-
)
1532-
return None
1533-
15341438
def get_pkl_filename(self) -> Path:
15351439
path = Path(self.output_directory_) / self.run_id_ / "checkpoint.pkl"
15361440
path.parent.mkdir(parents=True, exist_ok=True)

0 commit comments

Comments
 (0)