Skip to content

Commit e0e8f37

Browse files
adrian-prior authored and claude committed
Flip strict=False -> strict=True on zips with invariant-equal lengths
The ruff B905 auto-fix inserted `strict=False` everywhere to preserve pre-existing behavior. Audit of each site showed the conservative choice was wrong for most of them: the two iterables are paired by construction or invariant (feature_names to columns, estimators to weights, pred_np rows to y_indices, etc.), so truncation would hide a bug rather than be desired behavior. The only genuine intentional-truncation case is meta_models.py's `zip(range(n_estimators), relevant_config_product)`, where `n_estimators_per_model = max(n_estimators // len(product), 1)` deliberately allows the two lengths to differ. That one stays `strict=False`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b3e4366 commit e0e8f37

5 files changed

Lines changed: 7 additions & 7 deletions

File tree

src/tabpfn_extensions/sklearn_ensembles/weighted_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def fit(self, X, y):
3030
# Prune classifiers with weights below the threshold
3131
pruned_classifiers = []
3232
pruned_weights = []
33-
for clf, weight in zip(self.estimators, weights, strict=False):
33+
for clf, weight in zip(self.estimators, weights, strict=True):
3434
if weight >= self.weight_threshold:
3535
pruned_classifiers.append(clf)
3636
pruned_weights.append(weight)
@@ -40,7 +40,7 @@ def fit(self, X, y):
4040
pruned_classifiers = [
4141
clf
4242
for _, clf in sorted(
43-
zip(pruned_weights, pruned_classifiers, strict=False),
43+
zip(pruned_weights, pruned_classifiers, strict=True),
4444
key=lambda pair: pair[0],
4545
)
4646
]

src/tabpfn_extensions/unsupervised/experiments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def run(self, tabpfn, **kwargs):
141141
zip(
142142
self.feature_names,
143143
[self.X[:, i] for i in range(self.X.shape[1])],
144-
strict=False,
144+
strict=True,
145145
),
146146
),
147147
"real_or_synthetic": "Actual samples",
@@ -156,7 +156,7 @@ def run(self, tabpfn, **kwargs):
156156
self.synthetic_X[:, i]
157157
for i in range(self.synthetic_X.shape[1])
158158
],
159-
strict=False,
159+
strict=True,
160160
),
161161
),
162162
"real_or_synthetic": "Generated samples",

src/tabpfn_extensions/unsupervised/unsupervised.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ def outliers_single_permutation_(
624624
if valid_indices.any():
625625
# Get probabilities for each sample based on its class in y_predict
626626
for idx, (prob_row, y_idx) in enumerate(
627-
zip(pred_np, y_indices, strict=False)
627+
zip(pred_np, y_indices, strict=True)
628628
):
629629
if (
630630
0 <= y_idx < pred_np.shape[1]

src/tabpfn_extensions/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def product_dict(d: dict[str, list[T]]) -> Iterator[dict[str, T]]:
444444
keys = d.keys()
445445
values = [d[key] for key in keys]
446446
for combination in itertools.product(*values):
447-
yield dict(zip(keys, combination, strict=False))
447+
yield dict(zip(keys, combination, strict=True))
448448

449449

450450
# Get the TabPFN models with our wrappers applied

tests/test_many_class_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def predict_proba(self, X):
286286
for labels in fit_y_records:
287287
assert rest_code not in labels
288288

289-
for weights, labels in zip(fit_weight_records, fit_y_records, strict=False):
289+
for weights, labels in zip(fit_weight_records, fit_y_records, strict=True):
290290
if weights is not None:
291291
assert weights.shape[0] == labels.shape[0]
292292

0 commit comments

Comments (0)