@@ -60,22 +60,43 @@ def __init__(
6060 self ,
6161 original_error : Exception | None = None ,
6262 * ,
63+ n_train_samples : int | None = None ,
6364 n_test_samples : int | None = None ,
65+ n_features : int | None = None ,
6466 model_type : str = "classifier" ,
6567 ):
6668 predict_method = "predict_proba" if model_type == "classifier" else "predict"
6769
6870 size_info = f" with { n_test_samples :,} test samples" if n_test_samples else ""
6971
72+ size_line = ""
73+ if n_train_samples is not None and n_test_samples is not None :
74+ size_line = (
75+ f"Your sizes: { n_train_samples :,} train / "
76+ f"{ n_test_samples :,} test samples"
77+ )
78+ if n_features is not None :
79+ size_line += f", { n_features } features"
80+ size_line += ".\n "
81+
7082 message = (
7183 f"{ self .device_name } out of memory{ size_info } .\n \n "
72- f"Solution: Split your test data into smaller batches :\n \n "
73- f" batch_size = 1000 # depends on hardware \n "
84+ f"This issue is usually caused by one of the following two reasons :\n \n "
85+ f"1) Large test set — split into batches: \n \n "
7486 f" predictions = []\n "
75- f" for i in range(0, len(X_test), batch_size):\n "
76- f" batch = model.{ predict_method } (X_test[i:i + batch_size])\n "
77- f" predictions.append(batch)\n "
78- f" predictions = np.vstack(predictions)"
87+ f" for i in range(0, len(X_test), 100):\n "
88+ f" pred = model.{ predict_method } ("
89+ f"X_test[i:i + 100])\n "
90+ f" predictions.append(pred)\n "
91+ f" predictions = np.vstack(predictions)\n \n "
92+ f"2) Large training set — batching won't help.\n "
93+ f" You need subsampling or ensembling, see:\n "
94+ f" https://github.com/PriorLabs/tabpfn-extensions/"
95+ f"blob/main/examples/large_datasets/"
96+ f"large_datasets_example.py\n \n "
97+ f"{ size_line } "
98+ f"Not sure which? If model.{ predict_method } (X_test[:1]) "
99+ f"also fails, it's (2)."
79100 )
80101 if original_error is not None :
81102 message += f"\n \n Original error: { original_error } "
@@ -100,13 +121,17 @@ def handle_oom_errors(
100121 devices : tuple [torch .device , ...],
101122 X : XType ,
102123 model_type : str ,
124+ n_train_samples : int | None = None ,
125+ n_features : int | None = None ,
103126) -> Generator [None , None , None ]:
104127 """Context manager to catch OOM errors and raise helpful TabPFN exceptions.
105128
106129 Args:
107130 devices: The devices the model is running on.
108131 X: The input data (used to get n_samples for the error message).
109132 model_type: Either "classifier" or "regressor".
133+ n_train_samples: Number of training samples (for the error message).
134+ n_features: Number of features (for the error message).
110135
111136 Raises:
112137 TabPFNCUDAOutOfMemoryError: If a CUDA OOM error occurs.
@@ -115,16 +140,24 @@ def handle_oom_errors(
115140 try :
116141 yield
117142 except torch .OutOfMemoryError as e :
118- n_samples = X .shape [0 ] if hasattr (X , "shape" ) else len (X )
143+ n_test_samples = X .shape [0 ] if hasattr (X , "shape" ) else len (X )
119144 raise TabPFNCUDAOutOfMemoryError (
120- e , n_test_samples = n_samples , model_type = model_type
145+ e ,
146+ n_train_samples = n_train_samples ,
147+ n_test_samples = n_test_samples ,
148+ n_features = n_features ,
149+ model_type = model_type ,
121150 ) from None
122151 except RuntimeError as e :
123152 is_mps = any (d .type == "mps" for d in devices )
124153 is_oom = "out of memory" in str (e ).lower ()
125154 if is_mps and is_oom :
126- n_samples = X .shape [0 ] if hasattr (X , "shape" ) else len (X )
155+ n_test_samples = X .shape [0 ] if hasattr (X , "shape" ) else len (X )
127156 raise TabPFNMPSOutOfMemoryError (
128- e , n_test_samples = n_samples , model_type = model_type
157+ e ,
158+ n_train_samples = n_train_samples ,
159+ n_test_samples = n_test_samples ,
160+ n_features = n_features ,
161+ model_type = model_type ,
129162 ) from None
130163 raise
0 commit comments