Skip to content

Commit d969a4d

Browse files
fix: avoid empty checkpoints for custom jax exports
1 parent 14cbe96 commit d969a4d

2 files changed

Lines changed: 212 additions & 87 deletions

File tree

pysr/sr.py

Lines changed: 139 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,78 +1288,77 @@ def from_file(
12881288
pkl_filename = Path(run_directory) / "checkpoint.pkl"
12891289
if pkl_filename.exists():
12901290
pysr_logger.info(f"Attempting to load model from {pkl_filename}...")
1291-
assert binary_operators is None
1292-
assert unary_operators is None
1293-
assert operators is None
1294-
assert n_features_in is None
1295-
with open(pkl_filename, "rb") as f:
1296-
model = cast("PySRRegressor", pkl.load(f))
1291+
model = cls._load_checkpoint(pkl_filename)
1292+
if model is not None:
1293+
assert binary_operators is None
1294+
assert unary_operators is None
1295+
assert operators is None
1296+
assert n_features_in is None
12971297

1298-
# Update any parameters if necessary, such as
1299-
# extra_sympy_mappings:
1300-
model.set_params(**pysr_kwargs)
1298+
# Update any parameters if necessary, such as
1299+
# extra_sympy_mappings:
1300+
model.set_params(**pysr_kwargs)
13011301

1302-
if "equations_" not in model.__dict__ or model.equations_ is None:
1303-
model.refresh()
1302+
if (
1303+
"equations_" not in model.__dict__
1304+
or model.equations_ is None
1305+
or cls._equations_missing_export_formats(model.equations_)
1306+
):
1307+
model.refresh()
13041308

1305-
if model.expression_spec is not None:
1306-
warnings.warn(
1307-
"Loading model from checkpoint file with a non-default expression spec "
1308-
"is not fully supported as it relies on dynamic objects. This may result in unexpected behavior.",
1309-
)
1309+
if model.expression_spec is not None:
1310+
warnings.warn(
1311+
"Loading model from checkpoint file with a non-default expression spec "
1312+
"is not fully supported as it relies on dynamic objects. This may result in unexpected behavior.",
1313+
)
13101314

1311-
return model
1312-
else:
1313-
pysr_logger.info(
1314-
f"Checkpoint file {pkl_filename} does not exist. "
1315-
"Attempting to recreate model from scratch..."
1315+
return model
1316+
pysr_logger.info(
1317+
f"Checkpoint file {pkl_filename} does not exist or could not be loaded. "
1318+
"Attempting to recreate model from CSV backups..."
1319+
)
1320+
csv_filename = Path(run_directory) / "hall_of_fame.csv"
1321+
csv_filename_bak = Path(run_directory) / "hall_of_fame.csv.bak"
1322+
if not csv_filename.exists() and not csv_filename_bak.exists():
1323+
raise FileNotFoundError(
1324+
f"Hall of fame file `{csv_filename}` or `{csv_filename_bak}` does not exist. "
1325+
"Please pass a `run_directory` containing a valid checkpoint file."
13161326
)
1317-
csv_filename = Path(run_directory) / "hall_of_fame.csv"
1318-
csv_filename_bak = Path(run_directory) / "hall_of_fame.csv.bak"
1319-
if not csv_filename.exists() and not csv_filename_bak.exists():
1320-
raise FileNotFoundError(
1321-
f"Hall of fame file `{csv_filename}` or `{csv_filename_bak}` does not exist. "
1322-
"Please pass a `run_directory` containing a valid checkpoint file."
1323-
)
1324-
if (
1325-
operators is None
1326-
and binary_operators is None
1327-
and unary_operators is None
1328-
):
1329-
raise ValueError(
1330-
"When recreating a model from CSV backups you must provide either "
1331-
"`operators` or legacy `binary_operators`/`unary_operators`."
1332-
)
1333-
assert n_features_in is not None
1334-
model = cls(
1335-
binary_operators=binary_operators,
1336-
unary_operators=unary_operators,
1337-
operators=operators,
1338-
**pysr_kwargs,
1327+
if operators is None and binary_operators is None and unary_operators is None:
1328+
raise ValueError(
1329+
"When recreating a model from CSV backups you must provide either "
1330+
"`operators` or legacy `binary_operators`/`unary_operators`."
13391331
)
1340-
model.nout_ = nout
1341-
model.n_features_in_ = n_features_in
1332+
assert n_features_in is not None
1333+
model = cls(
1334+
binary_operators=binary_operators,
1335+
unary_operators=unary_operators,
1336+
operators=operators,
1337+
**pysr_kwargs,
1338+
)
1339+
model.nout_ = nout
1340+
model.n_features_in_ = n_features_in
13421341

1343-
if feature_names_in is None:
1344-
model.feature_names_in_ = np.array(
1345-
[f"x{i}" for i in range(n_features_in)]
1346-
)
1347-
model.display_feature_names_in_ = np.array(
1348-
[f"x{_subscriptify(i)}" for i in range(n_features_in)]
1349-
)
1350-
else:
1351-
assert len(feature_names_in) == n_features_in
1352-
model.feature_names_in_ = feature_names_in
1353-
model.display_feature_names_in_ = feature_names_in
1342+
if feature_names_in is None:
1343+
model.feature_names_in_ = np.array(
1344+
[f"x{i}" for i in range(n_features_in)]
1345+
)
1346+
model.display_feature_names_in_ = np.array(
1347+
[f"x{_subscriptify(i)}" for i in range(n_features_in)]
1348+
)
1349+
else:
1350+
assert len(feature_names_in) == n_features_in
1351+
model.feature_names_in_ = feature_names_in
1352+
model.display_feature_names_in_ = feature_names_in
13541353

1355-
if selection_mask is None:
1356-
model.selection_mask_ = np.ones(n_features_in, dtype=np.bool_)
1357-
else:
1358-
model.selection_mask_ = selection_mask
1354+
if selection_mask is None:
1355+
model.selection_mask_ = np.ones(n_features_in, dtype=np.bool_)
1356+
else:
1357+
model.selection_mask_ = selection_mask
13591358

1360-
model.refresh(run_directory=run_directory)
1359+
model.refresh(run_directory=run_directory)
13611360

1362-
return model
1361+
return model
13631362

13641363
def __repr__(self) -> str:
13651364
"""
@@ -1424,8 +1423,12 @@ def __getstate__(self) -> dict[str, Any]:
14241423
show_pickle_warning = not (
14251424
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
14261425
)
1427-
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
1428-
for state_key in state_keys_containing_lambdas:
1426+
state_keys_to_clear = (
1427+
"extra_sympy_mappings",
1428+
"extra_jax_mappings",
1429+
"extra_torch_mappings",
1430+
)
1431+
for state_key in state_keys_to_clear:
14291432
warn_msg = (
14301433
f"`{state_key}` cannot be pickled and will be removed from the "
14311434
"serialized instance. When loading the model, please redefine "
@@ -1436,8 +1439,7 @@ def __getstate__(self) -> dict[str, Any]:
14361439
warnings.warn(warn_msg)
14371440
else:
14381441
pysr_logger.debug(warn_msg)
1439-
state_keys_to_clear = state_keys_containing_lambdas
1440-
state_keys_to_clear.append("logger_")
1442+
state_keys_to_clear = (*state_keys_to_clear, "logger_")
14411443
pickled_state = {
14421444
key: (None if key in state_keys_to_clear else value)
14431445
for key, value in state.items()
@@ -1448,38 +1450,88 @@ def __getstate__(self) -> dict[str, Any]:
14481450
pickled_state["output_torch_format"] = False
14491451
pickled_state["output_jax_format"] = False
14501452
if self.nout_ == 1:
1451-
pickled_columns = ~pickled_state["equations_"].columns.isin(
1452-
["jax_format", "torch_format"]
1453-
)
1454-
pickled_state["equations_"] = (
1455-
pickled_state["equations_"].loc[:, pickled_columns].copy()
1453+
pickled_state["equations_"] = self._drop_equation_columns(
1454+
pickled_state["equations_"], ["jax_format", "torch_format"]
14561455
)
14571456
else:
1458-
pickled_columns = [
1459-
~dataframe.columns.isin(["jax_format", "torch_format"])
1460-
for dataframe in pickled_state["equations_"]
1461-
]
1462-
pickled_state["equations_"] = [
1463-
dataframe.loc[:, signle_pickled_columns]
1464-
for dataframe, signle_pickled_columns in zip(
1465-
pickled_state["equations_"], pickled_columns
1466-
)
1467-
]
1457+
pickled_state["equations_"] = self._drop_equation_columns(
1458+
pickled_state["equations_"], ["jax_format", "torch_format"]
1459+
)
1460+
try:
1461+
pkl.dumps(pickled_state["equations_"])
1462+
except Exception as e:
1463+
warn_msg = (
1464+
"`equations_` export formats cannot be pickled and will be "
1465+
"removed from the serialized instance. When loading the model, "
1466+
"please redefine custom mappings at runtime."
1467+
)
1468+
if show_pickle_warning:
1469+
warnings.warn(warn_msg)
1470+
else:
1471+
pysr_logger.debug(f"{warn_msg} Error: {e}")
1472+
pickled_state["equations_"] = self._drop_equation_columns(
1473+
pickled_state["equations_"], ["sympy_format", "lambda_format"]
1474+
)
14681475
return pickled_state
14691476

1477+
@staticmethod
def _drop_equation_columns(equations, columns: list[str]):
    """Return copies of the equation table(s) without the named columns.

    Parameters
    ----------
    equations : pandas.DataFrame or list of pandas.DataFrame
        A single equations table, or a list of them.
    columns : list[str]
        Column labels to leave out. Labels that are not present in a
        table are simply ignored.

    Returns
    -------
    pandas.DataFrame or list of pandas.DataFrame
        A copy of each input table restricted to the remaining columns.
        A list input produces a list output; the inputs are not modified.
    """

    def _without(frame):
        # Boolean mask over the column index: True for columns to keep.
        keep = ~frame.columns.isin(columns)
        return frame.loc[:, keep].copy()

    if not isinstance(equations, list):
        return _without(equations)
    return list(map(_without, equations))
1485+
1486+
@staticmethod
def _equations_missing_export_formats(equations) -> bool:
    """Tell whether any equations table lacks an export-format column.

    Parameters
    ----------
    equations : pandas.DataFrame or list of pandas.DataFrame
        A single equations table, or a list of them.

    Returns
    -------
    bool
        `True` if at least one table does not contain both the
        `sympy_format` and `lambda_format` columns, `False` otherwise
        (including for an empty list).
    """
    # Treat the single-table case as a one-element list so both shapes
    # go through the same check.
    frames = equations if isinstance(equations, list) else [equations]
    for frame in frames:
        if not {"sympy_format", "lambda_format"} <= set(frame.columns):
            return True
    return False
1495+
14701496
def _checkpoint(self):
    """Save the model's current state to a checkpoint file.

    This should only be used internally by PySRRegressor.

    The pickle is first written to a temporary file created in the
    checkpoint's own directory and then moved over the checkpoint with
    `os.replace`, so a failed dump does not overwrite an existing
    checkpoint file. Errors raised while writing or replacing the file
    are logged at debug level and are not re-raised.
    """
    # While this flag is False, `__getstate__` reports dropped state via
    # debug logging instead of `warnings.warn`.
    self.show_pickle_warnings_ = False
    tmp_filename = None
    try:
        pkl_filename = self.get_pkl_filename()
        try:
            with tempfile.NamedTemporaryFile(
                mode="wb", dir=pkl_filename.parent, delete=False
            ) as f:
                tmp_filename = Path(f.name)
                pkl.dump(self, f)
            os.replace(tmp_filename, pkl_filename)
            # The temporary file is now the checkpoint; nothing to clean up.
            tmp_filename = None
        except Exception as e:
            pysr_logger.debug(f"Error checkpointing model: {e}")
    finally:
        # Runs for every exit path, including KeyboardInterrupt, so the
        # warning flag is always restored and no stray temporary file is
        # left behind in the run directory.
        self.show_pickle_warnings_ = True
        if tmp_filename is not None:
            tmp_filename.unlink(missing_ok=True)
1517+
1518+
@staticmethod
def _load_checkpoint(pkl_filename: Path) -> "PySRRegressor | None":
    """Load a pickled model from a checkpoint file, if it is usable.

    An empty file, a truncated pickle (`EOFError`) or a corrupt pickle
    (`pickle.UnpicklingError`) is reported with a warning and yields
    `None`. Any other exception raised while opening or unpickling the
    file propagates to the caller.

    Parameters
    ----------
    pkl_filename : Path
        Path to the checkpoint file. The file must exist, since its size
        is read with `stat()` before it is opened.

    Returns
    -------
    PySRRegressor or None
        The unpickled object, or `None` when the checkpoint is empty or
        could not be unpickled.
    """
    # The whole return annotation is quoted on purpose: evaluating
    # `"PySRRegressor" | None` eagerly is `str | None`, which raises
    # TypeError when annotations are not deferred.
    if pkl_filename.stat().st_size == 0:
        pysr_logger.warning(
            f"Checkpoint file {pkl_filename} is empty. "
            "Attempting to recreate model from CSV backups..."
        )
        return None
    try:
        with open(pkl_filename, "rb") as f:
            # NOTE(security): unpickling can execute arbitrary code; only
            # load checkpoint files that come from a trusted source.
            return cast("PySRRegressor", pkl.load(f))
    except (EOFError, pkl.UnpicklingError) as e:
        pysr_logger.warning(
            f"Could not load checkpoint file {pkl_filename}: {e}. "
            "Attempting to recreate model from CSV backups..."
        )
        return None
14831535

14841536
def get_pkl_filename(self) -> Path:
14851537
path = Path(self.output_directory_) / self.run_id_ / "checkpoint.pkl"

pysr/test/test_jax.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import tempfile
12
import unittest
23
from functools import partial
34
from pathlib import Path
@@ -149,6 +150,78 @@ def cos_approx(x):
149150
jax_output = jax_prediction(X.values)
150151
np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
151152

153+
def test_checkpoint_custom_jax_mapping(self):
    """Checkpoint and reload a model that uses a custom JAX mapping.

    Builds a model from a one-row hall-of-fame CSV whose equation uses a
    custom `cos_approx` operator, checkpoints it, verifies that the
    checkpoint file is not empty, reloads it with `from_file`, and
    compares the reloaded JAX callable against the polynomial written
    out by hand.
    """
    cos_symbol = sympy.Function("cos_approx")
    jax_mapping = {
        cos_symbol: "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        run_dir = Path(tmpdir) / "custom_jax"
        run_dir.mkdir()

        # Minimal hall of fame holding a single equation.
        hall_of_fame = pd.DataFrame(
            {
                "Complexity": [1],
                "Loss": [0.0],
                "Equation": ["cos_approx(x0)"],
            }
        )
        hall_of_fame.to_csv(run_dir / "hall_of_fame.csv")

        model = PySRRegressor(
            progress=False,
            unary_operators=[
                "cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"
            ],
            extra_sympy_mappings={"cos_approx": cos_symbol},
            extra_jax_mappings=jax_mapping,
            output_jax_format=True,
        )
        # Set the fitted attributes by hand instead of running a search.
        model.output_directory_ = tmpdir
        model.run_id_ = "custom_jax"
        model.nout_ = 1
        model.n_features_in_ = 1
        model.feature_names_in_ = np.array(["x0"])
        model.display_feature_names_in_ = np.array(["x0"])
        model.selection_mask_ = np.ones(1, dtype=np.bool_)
        model.refresh()

        model._checkpoint()
        checkpoint_size = (run_dir / "checkpoint.pkl").stat().st_size
        assert checkpoint_size > 0

        reloaded = PySRRegressor.from_file(
            run_directory=run_dir,
            extra_sympy_mappings={"cos_approx": cos_symbol},
            extra_jax_mappings=jax_mapping,
        )
        jax_format = reloaded.jax(index=0)
        X = self.jnp.array([[0.0], [1.0], [2.0]])
        expected = 1 - X[:, 0] ** 2 / 2 + X[:, 0] ** 4 / 24 + X[:, 0] ** 6 / 720
        actual = jax_format["callable"](X, jax_format["parameters"])
        np.testing.assert_allclose(np.array(actual), np.array(expected))
203+
204+
def test_from_file_empty_checkpoint_falls_back_to_csv(self):
    """Check `from_file` when `checkpoint.pkl` is a zero-byte file.

    The run directory holds an empty checkpoint next to a one-row
    hall-of-fame CSV; the loaded model is expected to expose the CSV's
    equation through `sympy(index=0)`.
    """
    csv_rows = {
        "Complexity": [1],
        "Loss": [0.0],
        "Equation": ["x0"],
    }
    with tempfile.TemporaryDirectory() as tmpdir:
        run_dir = Path(tmpdir) / "empty_checkpoint"
        run_dir.mkdir()
        # `touch()` creates the checkpoint with no content at all.
        (run_dir / "checkpoint.pkl").touch()
        pd.DataFrame(csv_rows).to_csv(run_dir / "hall_of_fame.csv")

        restored = PySRRegressor.from_file(
            run_directory=run_dir,
            binary_operators=["+"],
            unary_operators=[],
            n_features_in=1,
        )
        recovered_equation = str(restored.sympy(index=0))
        assert recovered_equation == "x0"
224+
152225

153226
def runtests(just_tests=False):
154227
"""Run all tests in test_jax.py."""

0 commit comments

Comments
 (0)