Skip to content

Commit ca64d76

Browse files
More informative Memory Error Msg (#805)
Co-authored-by: Phil <phil@priorlabs.ai>
1 parent 3a57521 commit ca64d76

4 files changed

Lines changed: 66 additions & 12 deletions

File tree

changelog/805.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
More informative Out-Of-Memory error message.

src/tabpfn/classifier.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ class TabPFNClassifier(ClassifierMixin, BaseEstimator):
148148
n_features_in_: int
149149
"""The number of features in the input data used during `fit()`."""
150150

151+
n_train_samples_: int
152+
"""The number of training samples used during `fit()`."""
153+
151154
inferred_feature_schema_: FeatureSchema
152155
"""The inferred feature schema. This contains the feature modalities per column,
153156
using heuristics and user-provided indices for categorical features."""
@@ -650,6 +653,7 @@ def _initialize_dataset_preprocessing(
650653
self.ordinal_encoder_ = ordinal_encoder
651654
self.feature_names_in_ = feature_names
652655
self.n_features_in_ = n_features
656+
self.n_train_samples_ = len(X)
653657

654658
# Label encoding
655659
self.label_encoder_ = TabPFNLabelEncoder(original_target_name=original_y_name)
@@ -1062,7 +1066,13 @@ def _raw_predict(
10621066
ord_encoder=getattr(self, "ordinal_encoder_", None),
10631067
)
10641068

1065-
with handle_oom_errors(self.devices_, X, model_type="classifier"):
1069+
with handle_oom_errors(
1070+
self.devices_,
1071+
X,
1072+
model_type="classifier",
1073+
n_train_samples=getattr(self, "n_train_samples_", None),
1074+
n_features=getattr(self, "n_features_in_", None),
1075+
):
10661076
return self.forward(
10671077
X,
10681078
use_inference_mode=True,

src/tabpfn/errors.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,22 +60,43 @@ def __init__(
6060
self,
6161
original_error: Exception | None = None,
6262
*,
63+
n_train_samples: int | None = None,
6364
n_test_samples: int | None = None,
65+
n_features: int | None = None,
6466
model_type: str = "classifier",
6567
):
6668
predict_method = "predict_proba" if model_type == "classifier" else "predict"
6769

6870
size_info = f" with {n_test_samples:,} test samples" if n_test_samples else ""
6971

72+
size_line = ""
73+
if n_train_samples is not None and n_test_samples is not None:
74+
size_line = (
75+
f"Your sizes: {n_train_samples:,} train / "
76+
f"{n_test_samples:,} test samples"
77+
)
78+
if n_features is not None:
79+
size_line += f", {n_features} features"
80+
size_line += ".\n"
81+
7082
message = (
7183
f"{self.device_name} out of memory{size_info}.\n\n"
72-
f"Solution: Split your test data into smaller batches:\n\n"
73-
f" batch_size = 1000 # depends on hardware\n"
84+
f"This issue is usually caused by one of the following two reasons:\n\n"
85+
f"1) Large test set — split into batches:\n\n"
7486
f" predictions = []\n"
75-
f" for i in range(0, len(X_test), batch_size):\n"
76-
f" batch = model.{predict_method}(X_test[i:i + batch_size])\n"
77-
f" predictions.append(batch)\n"
78-
f" predictions = np.vstack(predictions)"
87+
f" for i in range(0, len(X_test), 100):\n"
88+
f" pred = model.{predict_method}("
89+
f"X_test[i:i + 100])\n"
90+
f" predictions.append(pred)\n"
91+
f" predictions = np.vstack(predictions)\n\n"
92+
f"2) Large training set — batching won't help.\n"
93+
f" You need subsampling or ensembling, see:\n"
94+
f" https://github.com/PriorLabs/tabpfn-extensions/"
95+
f"blob/main/examples/large_datasets/"
96+
f"large_datasets_example.py\n\n"
97+
f"{size_line}"
98+
f"Not sure which? If model.{predict_method}(X_test[:1]) "
99+
f"also fails, it's (2)."
79100
)
80101
if original_error is not None:
81102
message += f"\n\nOriginal error: {original_error}"
@@ -100,13 +121,17 @@ def handle_oom_errors(
100121
devices: tuple[torch.device, ...],
101122
X: XType,
102123
model_type: str,
124+
n_train_samples: int | None = None,
125+
n_features: int | None = None,
103126
) -> Generator[None, None, None]:
104127
"""Context manager to catch OOM errors and raise helpful TabPFN exceptions.
105128
106129
Args:
107130
devices: The devices the model is running on.
108131
X: The input data (used to get the number of test samples for the error message).
109132
model_type: Either "classifier" or "regressor".
133+
n_train_samples: Number of training samples (for the error message).
134+
n_features: Number of features (for the error message).
110135
111136
Raises:
112137
TabPFNCUDAOutOfMemoryError: If a CUDA OOM error occurs.
@@ -115,16 +140,24 @@ def handle_oom_errors(
115140
try:
116141
yield
117142
except torch.OutOfMemoryError as e:
118-
n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
143+
n_test_samples = X.shape[0] if hasattr(X, "shape") else len(X)
119144
raise TabPFNCUDAOutOfMemoryError(
120-
e, n_test_samples=n_samples, model_type=model_type
145+
e,
146+
n_train_samples=n_train_samples,
147+
n_test_samples=n_test_samples,
148+
n_features=n_features,
149+
model_type=model_type,
121150
) from None
122151
except RuntimeError as e:
123152
is_mps = any(d.type == "mps" for d in devices)
124153
is_oom = "out of memory" in str(e).lower()
125154
if is_mps and is_oom:
126-
n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
155+
n_test_samples = X.shape[0] if hasattr(X, "shape") else len(X)
127156
raise TabPFNMPSOutOfMemoryError(
128-
e, n_test_samples=n_samples, model_type=model_type
157+
e,
158+
n_train_samples=n_train_samples,
159+
n_test_samples=n_test_samples,
160+
n_features=n_features,
161+
model_type=model_type,
129162
) from None
130163
raise

src/tabpfn/regressor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class TabPFNRegressor(RegressorMixin, BaseEstimator):
174174
n_features_in_: int
175175
"""The number of features in the input data used during `fit()`."""
176176

177+
n_train_samples_: int
178+
"""The number of training samples used during `fit()`."""
179+
177180
inferred_feature_schema_: FeatureSchema
178181
"""The inferred feature schema. This contains the feature modalities per column,
179182
using heuristics and user-provided indices for categorical features."""
@@ -619,6 +622,7 @@ def _initialize_dataset_preprocessing(
619622
# Set class variables for sklearn compatibility
620623
self.feature_names_in_ = feature_names
621624
self.n_features_in_ = n_features
625+
self.n_train_samples_ = len(X)
622626

623627
feature_schema = detect_feature_modalities(
624628
X=X,
@@ -926,7 +930,13 @@ def predict(
926930
)
927931

928932
# Runs over iteration engine
929-
with handle_oom_errors(self.devices_, X, model_type="regressor"):
933+
with handle_oom_errors(
934+
self.devices_,
935+
X,
936+
model_type="regressor",
937+
n_train_samples=getattr(self, "n_train_samples_", None),
938+
n_features=getattr(self, "n_features_in_", None),
939+
):
930940
(
931941
_,
932942
# list of tensors [N_est, N_samples, N_borders] (after forward)

0 commit comments

Comments
 (0)