Skip to content

Commit 1ebeb8b

Browse files
Merge remote-tracking branch 'origin/master' into fix/torch-export-preserve-dims
2 parents a342033 + f1963a4 commit 1ebeb8b

5 files changed

Lines changed: 52 additions & 78 deletions

File tree

.github/workflows/update_backend.yml

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ jobs:
2020
cache: pip
2121

2222
- name: "Install dependencies"
23-
run: |
24-
python -m pip install --upgrade pip
25-
pip install tomlkit
23+
run: python -m pip install --upgrade pip
2624

2725
- name: "Get SymbolicRegression.jl latest version"
2826
id: get-latest
@@ -36,13 +34,6 @@ jobs:
3634
run: |
3735
python .github/workflows/update_backend_version.py ${{ steps.get-latest.outputs.version }}
3836
39-
- name: "Restore changes if no diff to `pysr/juliapkg.json`"
40-
run: |
41-
if git diff --quiet pysr/juliapkg.json; then
42-
echo "No changes to pysr/juliapkg.json. Restoring changes."
43-
git restore pyproject.toml .release-please-manifest.json
44-
fi
45-
4637
- name: "Create PR if necessary"
4738
id: cpr
4839
uses: peter-evans/create-pull-request@v8
@@ -56,8 +47,6 @@ jobs:
5647
delete-branch: true
5748
commit-message: "chore: update backend to v${{ steps.get-latest.outputs.version }}"
5849
add-paths: |
59-
.release-please-manifest.json
60-
pyproject.toml
6150
pysr/juliapkg.json
6251
6352
- name: "Trigger CI workflows (backend update PR)"
Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,17 @@
11
import json
2-
import re
32
import sys
43
from pathlib import Path
54

6-
import tomlkit
7-
85
new_backend_version = sys.argv[1]
96

107
assert not new_backend_version.startswith("v"), "Version should not start with 'v'"
118

129
repo_root = Path(__file__).parent / ".." / ".."
13-
pyproject_toml = repo_root / "pyproject.toml"
1410
juliapkg_json = repo_root / "pysr" / "juliapkg.json"
15-
release_please_manifest = repo_root / ".release-please-manifest.json"
16-
17-
with open(pyproject_toml) as toml_file:
18-
pyproject_data = tomlkit.parse(toml_file.read())
19-
2011
with open(juliapkg_json) as f:
2112
juliapkg_data = json.load(f)
2213

23-
with open(release_please_manifest) as f:
24-
release_please_manifest_data = json.load(f)
25-
26-
current_version = pyproject_data["project"]["version"]
27-
parts = current_version.split(".")
28-
29-
if len(parts) < 3:
30-
raise ValueError(
31-
f"Invalid version format: {current_version}. Expected at least 3 components (major.minor.patch)"
32-
)
33-
34-
major, minor = parts[0], parts[1]
35-
36-
patch_match = re.match(r"^(\d+)(.*)$", parts[2])
37-
if not patch_match:
38-
raise ValueError(
39-
f"Could not parse patch version from '{parts[2]}' in version {current_version}. "
40-
f"Expected patch to start with a number (e.g., '0', '1a1', '2rc3')"
41-
)
42-
43-
patch_num_str, patch_suffix = patch_match.groups()
44-
patch_num = int(patch_num_str)
45-
46-
pre_release_match = re.fullmatch(r"(a|b|rc)(\d+)", patch_suffix)
47-
if pre_release_match:
48-
pre_tag, pre_num = pre_release_match.groups()
49-
new_patch = patch_num
50-
new_suffix = f"{pre_tag}{int(pre_num) + 1}"
51-
else:
52-
new_patch = patch_num + 1
53-
new_suffix = patch_suffix
54-
55-
# Add back any additional version components (e.g., "2.0.0.dev1" -> ".dev1")
56-
extra_parts = "." + ".".join(parts[3:]) if len(parts) > 3 else ""
57-
new_version = f"{major}.{minor}.{new_patch}{new_suffix}{extra_parts}"
58-
59-
pyproject_data["project"]["version"] = new_version
60-
release_please_manifest_data["."] = new_version
61-
62-
# Update backend - maintain current format (either "rev" or "version")
14+
# Update backend, maintain current format (either "rev" or "version")
6315
backend_pkg = juliapkg_data["packages"]["SymbolicRegression"]
6416
if "rev" in backend_pkg:
6517
backend_pkg["rev"] = f"v{new_backend_version}"
@@ -70,13 +22,6 @@
7022
"SymbolicRegression package must have either 'rev' or 'version' field"
7123
)
7224

73-
with open(pyproject_toml, "w") as toml_file:
74-
toml_file.write(tomlkit.dumps(pyproject_data))
75-
7625
with open(juliapkg_json, "w") as f:
7726
json.dump(juliapkg_data, f, indent=4)
7827
f.write("\n")
79-
80-
with open(release_please_manifest, "w") as f:
81-
json.dump(release_please_manifest_data, f, indent=2)
82-
f.write("\n")

.release-please-manifest.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
".": "2.0.0a2"
2+
".": "2.0.0a1"
33
}

pysr/sr.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,26 +235,34 @@ def _check_assertions(
235235
)
236236

237237

238-
def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
239-
"""Validate that a Julia `elementwise_loss` is callable.
238+
def _validate_elementwise_loss(
239+
custom_loss, *, has_weights: bool, probe_value: Any = 1.0
240+
) -> None:
241+
"""Check whether a Julia `elementwise_loss` accepts the expected inputs.
240242
241-
We require exactly 2 args unless the user passed `weights=` to fit,
242-
in which case we require 3 args.
243+
The function probes the loss with two or three arguments, depending on
244+
whether weights are present, using the same dtype that fitting will use.
245+
If the probe fails, it raises a `ValueError` describing the expected
246+
signature.
243247
"""
244248

245249
# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
246250
# Only validate arity when the evaluated object is actually a function.
247251
if not jl_is_function(custom_loss):
248252
return
249253

254+
probe_args = (
255+
(probe_value, probe_value, probe_value)
256+
if has_weights
257+
else (probe_value, probe_value)
258+
)
259+
ok = bool(jl.applicable(custom_loss, *probe_args))
250260
if has_weights:
251-
ok = bool(jl.applicable(custom_loss, 1.0, 1.0, 1.0))
252261
if not ok:
253262
raise ValueError(
254263
"`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
255264
)
256265
else:
257-
ok = bool(jl.applicable(custom_loss, 1.0, 1.0))
258266
if not ok:
259267
raise ValueError(
260268
"`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
@@ -2109,13 +2117,19 @@ def _run(
21092117
if isinstance(complexity_of_variables, list):
21102118
complexity_of_variables = jl_array(complexity_of_variables)
21112119

2120+
np_dtype = self._get_precision_mapped_dtype(np.array(X))
2121+
21122122
custom_loss = jl.seval(
21132123
str(self.elementwise_loss)
21142124
if self.elementwise_loss is not None
21152125
else "nothing"
21162126
)
21172127
if self.elementwise_loss is not None:
2118-
_validate_elementwise_loss(custom_loss, has_weights=weights is not None)
2128+
_validate_elementwise_loss(
2129+
custom_loss,
2130+
has_weights=weights is not None,
2131+
probe_value=np_dtype(1.0),
2132+
)
21192133

21202134
custom_full_objective = jl.seval(
21212135
str(self.loss_function) if self.loss_function is not None else "nothing"
@@ -2304,8 +2318,6 @@ def _run(
23042318
self.julia_options_stream_ = jl_serialize(options)
23052319

23062320
# Convert data to desired precision
2307-
test_X = np.array(X)
2308-
np_dtype = self._get_precision_mapped_dtype(test_X)
23092321

23102322
# This converts the data into a Julia array:
23112323
jl_X = jl_array(np.array(X, dtype=np_dtype).T)

pysr/test/test_main.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,34 @@ def test_elementwise_loss_with_weights_accepts_three_args(self):
313313
weights = np.array([1.0, 1.0])
314314
model.fit(X, y, weights=weights)
315315

316+
def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
317+
custom_loss = jl.seval(
318+
"(prediction::Float32, target::Float32) -> (prediction - target)^2"
319+
)
320+
_validate_elementwise_loss(
321+
custom_loss,
322+
has_weights=False,
323+
probe_value=np.float32(1.0),
324+
)
325+
326+
def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss(self):
327+
model = PySRRegressor(
328+
niterations=1,
329+
populations=1,
330+
procs=0,
331+
progress=False,
332+
verbosity=0,
333+
precision=32,
334+
temp_equation_file=True,
335+
binary_operators=["+"],
336+
elementwise_loss=(
337+
"(prediction::Float32, target::Float32) -> (prediction - target)^2"
338+
),
339+
)
340+
X = np.array([[0.0], [1.0]], dtype=np.float32)
341+
y = np.array([0.0, 1.0], dtype=np.float32)
342+
model.fit(X, y)
343+
316344
def test_validation_helpers_skip_nonfunction(self):
317345
_validate_elementwise_loss(jl.seval("1.0"), has_weights=False)
318346

0 commit comments

Comments
 (0)