
Commit e758def

Make GPU transforms more memory efficient (#887)
1 parent 8163f50 commit e758def

4 files changed

Lines changed: 336 additions & 66 deletions


src/tabpfn/preprocessing/torch/torch_quantile_transformer.py

Lines changed: 98 additions & 6 deletions
@@ -61,20 +61,78 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         references = torch.linspace(
             0, 1, n_quantiles_effective, device=x.device, dtype=compute_dtype
         )
-        quantiles = torch.nanquantile(x.to(compute_dtype), references, dim=0)
+
+        x_compute = x.to(compute_dtype)
+        quantiles = self._nanquantile_chunked(x_compute, references)
+
         # Ensure monotonicity (handle floating point issues)
         # Use cumulative maximum along the quantile dimension
         quantiles = torch.cummax(quantiles, dim=0).values

         return {"quantiles": quantiles, "references": references}

+    def _nanquantile_chunked(
+        self,
+        x: torch.Tensor,
+        references: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute nanquantile in column-chunks to bound peak memory.
+
+        Each column is independent so we process subsets of columns to keep
+        peak memory at O(N * chunk_cols) instead of O(N * F). The chunk size
+        is determined by ``_get_fit_chunk_cols``.
+        """
+        original_shape = x.shape  # [N, ...]
+        n_samples = x.shape[0]
+
+        # Flatten trailing dimensions to [N, F]
+        x_flat = x.reshape(n_samples, -1) if x.ndim > 2 else x
+
+        n_features = x_flat.shape[1] if x_flat.ndim > 1 else 1
+        chunk_cols = self._get_fit_chunk_cols(x_flat)
+        process_all_at_once = n_features <= chunk_cols
+
+        if process_all_at_once:
+            quantiles = torch.nanquantile(x, references, dim=0)
+        else:
+            chunks = []
+            for col_start in range(0, n_features, chunk_cols):
+                q_chunk = torch.nanquantile(
+                    x_flat[:, col_start : col_start + chunk_cols], references, dim=0
+                )
+                chunks.append(q_chunk)
+            quantiles = torch.cat(chunks, dim=-1)
+
+        if x.ndim > 2:
+            quantiles = quantiles.reshape(len(references), *original_shape[1:])
+
+        return quantiles
+
+    def _get_fit_chunk_cols(self, x_flat: torch.Tensor) -> int:
+        """Compute a column-chunk size that keeps nanquantile peak memory bounded.
+
+        ``torch.nanquantile`` internally sorts the data along dim=0, requiring
+        roughly 10x the input size as temporary memory (sort buffer + indexing).
+        We target ~2 GB of intermediates so that 100k x 500 fits in one shot
+        while 1M-row datasets stay bounded at ~2 GB of workspace per chunk.
+        """
+        n_samples = x_flat.shape[0]
+        element_size = x_flat.element_size()
+        overhead_factor = 10
+        target_bytes = 2 * 1024**3  # 2 GB
+        bytes_per_col = n_samples * element_size * overhead_factor
+        return max(1, target_bytes // max(bytes_per_col, 1))
+
     def transform(
         self,
         x: torch.Tensor,
         fitted_cache: dict[str, torch.Tensor],
     ) -> torch.Tensor:
         """Transform the data to uniform distribution using fitted quantiles.

+        Automatically processes in row-chunks when the data is large to keep
+        peak intermediate memory bounded (~2 GB of temporaries).
+
         Args:
             x: Input tensor to transform.
             fitted_cache: Cache returned by fit.
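
For a sense of what the new `_get_fit_chunk_cols` heuristic yields, here is a standalone sketch of the same arithmetic (the 10x overhead factor and 2 GB target are the values from the diff; the helper name `fit_chunk_cols` is only for illustration):

```python
def fit_chunk_cols(n_samples: int, element_size: int,
                   overhead_factor: int = 10,
                   target_bytes: int = 2 * 1024**3) -> int:
    """Same arithmetic as _get_fit_chunk_cols: bound the nanquantile sort workspace."""
    bytes_per_col = n_samples * element_size * overhead_factor
    return max(1, target_bytes // max(bytes_per_col, 1))

# float32 elements are 4 bytes each.
print(fit_chunk_cols(100_000, 4))    # 536 -> a 100k x 500 matrix fits in one call
print(fit_chunk_cols(1_000_000, 4))  # 53  -> 1M-row data is split into ~53-column chunks
```
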
@@ -94,14 +152,48 @@ def transform
         # Compute in the cache dtype, then cast back.
         orig_dtype = x.dtype
         compute_dtype = quantiles.dtype
-        x_compute = x.to(compute_dtype) if x.dtype != compute_dtype else x
+        x_compute = x.to(compute_dtype)
+
+        chunk_size = self._get_transform_chunk_size(x_compute)
+        n_samples = x_compute.shape[0]
+
+        if n_samples <= chunk_size:
+            result = self._transform_chunk(x_compute, quantiles, references)
+        else:
+            chunks = []
+            for start in range(0, n_samples, chunk_size):
+                chunk = x_compute[start : start + chunk_size]
+                chunks.append(self._transform_chunk(chunk, quantiles, references))
+            result = torch.cat(chunks, dim=0)

-        nan_mask = torch.isnan(x_compute)
-        result = self._interpolate(x_compute, quantiles, references)
-        nan_fill = torch.tensor(float("nan"), device=x.device, dtype=compute_dtype)
-        result = torch.where(nan_mask, nan_fill, result)
         return result.to(orig_dtype) if orig_dtype != compute_dtype else result

+    def _transform_chunk(
+        self,
+        x: torch.Tensor,
+        quantiles: torch.Tensor,
+        references: torch.Tensor,
+    ) -> torch.Tensor:
+        """Transform a single chunk, preserving NaN positions."""
+        nan_mask = torch.isnan(x)
+        result = self._interpolate(x, quantiles, references)
+        return torch.where(nan_mask, float("nan"), result)
+
+    def _get_transform_chunk_size(self, x: torch.Tensor) -> int:
+        """Compute a row-chunk size that keeps intermediate memory bounded.
+
+        ``_interpolate`` creates ~15x intermediate memory per input element
+        (forward + backward searchsorted, gather, slope tensors). We target
+        ~2 GB of intermediates so that even very large datasets stay within
+        reasonable GPU memory.
+        """
+        n_features = x.shape[-1] if x.ndim > 1 else 1
+        element_size = x.element_size()
+        overhead_factor = 15
+        target_bytes = 2 * 1024**3  # 2 GB
+        bytes_per_row = n_features * element_size * overhead_factor
+        return max(1_000, target_bytes // max(bytes_per_row, 1))
+
     def _interpolate(
         self,
         x: torch.Tensor,
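
The row-chunking added to `transform` is the usual slice, process, concatenate pattern. A minimal self-contained sketch of that pattern (the quantile interpolation is replaced by `torch.clamp` as a stand-in, since `_interpolate` is not part of this diff):

```python
import torch

def chunked_rows(x: torch.Tensor, fn, chunk_size: int) -> torch.Tensor:
    # Process rows in bounded slices and stitch the results back together,
    # mirroring the chunk loop in TorchQuantileTransformer.transform.
    if x.shape[0] <= chunk_size:
        return fn(x)
    parts = [fn(x[start : start + chunk_size])
             for start in range(0, x.shape[0], chunk_size)]
    return torch.cat(parts, dim=0)

x = torch.randn(10_000, 8)
out = chunked_rows(x, lambda c: torch.clamp(c, -1.0, 1.0), chunk_size=1_000)
assert torch.equal(out, torch.clamp(x, -1.0, 1.0))  # chunking does not change the result
```
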

src/tabpfn/preprocessing/torch/torch_svd.py

Lines changed: 113 additions & 60 deletions
@@ -42,8 +42,9 @@ class TorchTruncatedSVD:
     """Truncated SVD for PyTorch tensors.

     Similar to sklearn's TruncatedSVD but without any implicit state.
-    The state is returned explicitly. Uses full SVD and truncates to
-    n_components (efficient for typical TabPFN dimensions).
+    The state is returned explicitly. Uses randomized SVD
+    (``torch.svd_lowrank``) for large matrices and exact
+    ``torch.linalg.svd`` for small ones.

     Note: Unlike sklearn's TruncatedSVD, this does not center the data.
     If centering is needed, apply it before calling fit.
@@ -60,6 +61,11 @@ def __init__(self, n_components: int) -> None:
     def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         """Compute the truncated SVD on the training data.

+        Uses randomized SVD (``torch.svd_lowrank``) when ``n_components`` is
+        much smaller than the matrix dimensions. This reduces memory from
+        O(N * min(N, F)) to O(N * n_components) — a large saving when
+        n_components << min(N, F) (e.g. 128 vs 500).
+
         Args:
             x: Input tensor with shape [n_samples, n_features].
@@ -69,6 +75,17 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
               [n_components, n_features]
             - "singular_values": The singular values [n_components]
         """
+        orig_device = x.device
+
+        if x.device.type == "mps":
+            warnings.warn(
+                "SVD operators ('aten::linalg_svd', 'aten::linalg_qr') are not "
+                "currently supported on the MPS backend and will fall back to "
+                "run on the CPU. This may have performance implications.",
+                stacklevel=2,
+            )
+            x = x.cpu()
+
         n_samples, n_features = x.shape

         # Handle NaN values by replacing with 0 for SVD computation
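
The MPS branch computes on CPU and, as a later hunk shows, moves the fitted tensors back with `.to(orig_device)`. A rough sketch of that round-trip in isolation; the function name and warning text below are illustrative, not from the codebase:

```python
import warnings
import torch

def svd_components_with_cpu_fallback(x: torch.Tensor, k: int) -> torch.Tensor:
    orig_device = x.device
    if x.device.type == "mps":
        # Mirror the diff: warn once, then run the decomposition on CPU.
        warnings.warn("SVD ops are unsupported on MPS; falling back to CPU.",
                      stacklevel=2)
        x = x.cpu()
    _, _, vh = torch.linalg.svd(x, full_matrices=False)
    # Move the result back so callers keep working on their original device.
    return vh[:k, :].to(orig_device)
```
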
@@ -79,25 +96,43 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         n_components = min(self.n_components, n_samples, n_features)
         n_components = max(1, n_components)

-        # torch.linalg.svd requires float32 or float64; cast up if needed
+        # torch SVD ops require float32 or float64; cast up if needed
         compute_dtype = x_filled.dtype
         if compute_dtype not in (torch.float32, torch.float64):
             compute_dtype = torch.float32
             x_filled = x_filled.to(compute_dtype)

-        # Compute full SVD: X = U @ diag(S) @ V^T
-        # torch.linalg.svd returns V^T directly (not V)
-        with warnings.catch_warnings():
-            warnings.filterwarnings(
-                "ignore",
-                message=".*linalg_svd.*not currently supported on the MPS backend.*",
-            )
-            u, s, vh = torch.linalg.svd(x_filled, full_matrices=False)
+        # Use randomized SVD only when it is both (a) accurate — the matrix
+        # is large enough that the top components are well-separated — and
+        # (b) faster — the projected rank q is well below the matrix rank.
+        # Benchmark shows svd_lowrank becomes favorable once
+        # min(N, F) >= 2*q; below that, the random-projection overhead
+        # exceeds the savings from avoiding the full decomposition.
+        oversampling = 10  # matches sklearn's TruncatedSVD n_oversamples default
+        q = n_components + oversampling
+        use_lowrank = (
+            n_samples * n_features > 1_000_000 and min(n_samples, n_features) >= 2 * q
+        )

-        # Truncate to n_components
-        u = u[:, :n_components]
-        s = s[:n_components]
-        vh = vh[:n_components, :]
+        if use_lowrank:
+            # torch.svd_lowrank returns (U, S, V) with A ≈ U diag(S) V^T
+            u, s, v = torch.svd_lowrank(x_filled, q=q, niter=2)
+            # Truncate oversampling dimensions
+            u = u[:, :n_components]
+            s = s[:n_components]
+            vh = v[:, :n_components].T  # V [n_features, q] → V^T [n_comp, n_features]
+        else:
+            # Fall back to full SVD for small matrices or when n_components
+            # is close to min(n_samples, n_features).
+            with warnings.catch_warnings():  # warning thrown above already
+                warnings.filterwarnings(
+                    "ignore",
+                    message=".*linalg_svd.*not currently supported on the MPS backend.*",  # noqa: E501
+                )
+                u, s, vh = torch.linalg.svd(x_filled, full_matrices=False)
+            u = u[:, :n_components]
+            s = s[:n_components]
+            vh = vh[:n_components, :]

         # Apply sign flip for deterministic output.
         # We use the same convention as sklearn (u_based_decision=False:
@@ -108,8 +143,8 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         u, vh = _svd_flip_stable(u, vh)

         return {
-            "components": vh,
-            "singular_values": s,
+            "components": vh.to(orig_device),
+            "singular_values": s.to(orig_device),
         }

     def transform(
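
To see how close the randomized path stays to the exact one for the leading components, a quick check can be run on synthetic data. The selection rule (matrix larger than 1M elements, min(N, F) >= 2*q) and the oversampling of 10 are taken from the diff; everything else below is an illustrative setup:

```python
import torch

torch.manual_seed(0)
n_samples, n_features, n_components = 5_000, 500, 128

# Low-rank signal plus small noise: a decaying spectrum, the regime where
# randomized SVD approximates the leading components well.
signal = torch.randn(n_samples, n_components) @ torch.randn(n_components, n_features)
x = signal / n_components**0.5 + 0.01 * torch.randn(n_samples, n_features)

q = n_components + 10  # n_components + oversampling, as in the diff
assert n_samples * n_features > 1_000_000 and min(n_samples, n_features) >= 2 * q

# Randomized path: U [N, q], S [q], V [F, q]; truncate and transpose V.
u_r, s_r, v_r = torch.svd_lowrank(x, q=q, niter=2)
vh_r = v_r[:, :n_components].T  # [n_components, n_features]

# Exact path for comparison.
_, s_e, _ = torch.linalg.svd(x, full_matrices=False)

# Worst relative error over the leading singular values (small for this setup).
rel_err = (s_r[:n_components] - s_e[:n_components]).abs() / s_e[:n_components]
print(vh_r.shape, rel_err.max())
```
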
@@ -119,6 +154,9 @@ def transform(
     ) -> torch.Tensor:
         """Project the data onto the SVD components.

+        Automatically processes in row-chunks when the data is large to keep
+        peak intermediate memory bounded.
+
         Args:
             x: Input tensor to transform [n_samples, n_features].
             fitted_cache: Cache returned by fit.
@@ -130,27 +168,36 @@
             raise ValueError("Invalid fitted cache. Must contain 'components'.")

         components = fitted_cache["components"]
-
-        # Components may be in float32 (from fit) while input is float16.
-        # Compute in the components dtype, then cast back.
         orig_dtype = x.dtype
         compute_dtype = components.dtype
+        chunk_size = self._get_transform_chunk_size(x, components)

-        # Handle NaN values: preserve them in output
-        x_compute = x.to(compute_dtype) if x.dtype != compute_dtype else x
-        nan_mask = torch.isnan(x_compute)
-        x_filled = torch.where(nan_mask, torch.zeros_like(x_compute), x_compute)
+        def _per_row(row: torch.Tensor) -> torch.Tensor:
+            x_c = row.to(compute_dtype)
+            nan_mask = torch.isnan(x_c)
+            x_filled = torch.where(nan_mask, torch.zeros_like(x_c), x_c)
+            out = x_filled @ components.T
+            return torch.where(nan_mask.any(), float("nan"), out)

-        # Project: X @ V (V = components.T)
-        result = x_filled @ components.T
-
-        # If any input feature was NaN, the corresponding output should be NaN
-        # Since projection is a linear combination, any NaN input affects all outputs
-        any_nan_per_row = nan_mask.any(dim=-1, keepdim=True)
-        nan_fill = torch.tensor(float("nan"), device=x.device, dtype=compute_dtype)
-        result = torch.where(any_nan_per_row.expand_as(result), nan_fill, result)
+        result = torch.vmap(_per_row, chunk_size=chunk_size)(x)
         return result.to(orig_dtype) if orig_dtype != compute_dtype else result

+    def _get_transform_chunk_size(
+        self, x: torch.Tensor, components: torch.Tensor
+    ) -> int:
+        """Compute a row-chunk size that keeps intermediate memory bounded.
+
+        Transform creates ~3x N*F intermediates (dtype cast, nan_mask,
+        x_filled) plus the N*C result. Target ~2 GB of temporaries.
+        """
+        n_features = x.shape[-1]
+        n_components = components.shape[0]
+        element_size = max(x.element_size(), components.element_size())
+        # x_compute + nan_mask(~1B) + x_filled + result
+        bytes_per_row = (3 * n_features + n_components) * element_size
+        target_bytes = 2 * 1024**3  # 2 GB
+        return max(1_000, target_bytes // max(bytes_per_row, 1))
+
     def __call__(
         self,
         x: torch.Tensor,
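
The projection now leans on `torch.vmap`'s `chunk_size` argument to bound how many rows are materialized at once. A small sketch of that mechanism with a toy per-row function (this assumes a PyTorch version recent enough for `chunk_size` in `torch.vmap`; the shapes are made up):

```python
import torch

components = torch.randn(16, 64)  # [n_components, n_features]

def per_row(row: torch.Tensor) -> torch.Tensor:
    # Rows containing NaN map to all-NaN outputs, as in the diff's _per_row.
    nan_mask = torch.isnan(row)
    filled = torch.where(nan_mask, torch.zeros_like(row), row)
    out = filled @ components.T
    return torch.where(nan_mask.any(), float("nan"), out)

x = torch.randn(10_000, 64)
x[3, 5] = float("nan")  # poison one row

# chunk_size caps how many rows vmap processes per internal batch of work.
result = torch.vmap(per_row, chunk_size=1_000)(x)
print(result.shape)                  # torch.Size([10000, 16])
print(torch.isnan(result[3]).all())  # tensor(True)
```
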
@@ -241,6 +288,9 @@ def transform(
     ) -> torch.Tensor:
         """Apply the fitted scaling to the data (no mean centering).

+        Automatically processes in row-chunks when the data is large to keep
+        peak intermediate memory bounded.
+
         Args:
             x: Input tensor to transform.
             fitted_cache: Cache returned by fit.
@@ -252,40 +302,43 @@
             raise ValueError("Invalid fitted cache. Must contain 'std'.")

         std = fitted_cache["std"]
-
-        # Align dtype: std is in float32+ from fit, input may be float16
         orig_dtype = x.dtype
         compute_dtype = std.dtype
-        x_compute = x.to(compute_dtype) if x.dtype != compute_dtype else x
-
-        # Replace inf with nan before scaling
-        x_safe = torch.where(
-            torch.isinf(x_compute),
-            torch.tensor(float("nan"), device=x.device, dtype=compute_dtype),
-            x_compute,
+        chunk_size = self._get_transform_chunk_size(
+            x, compute_element_size=max(x.element_size(), std.element_size())
+        )
+        col_means = (
+            fitted_cache["mean"].to(device=x.device, dtype=compute_dtype)
+            if "mean" in fitted_cache
+            else None
         )

-        # Impute NaN with column means (matching CPU make_scaler_safe which
-        # wraps the scaler with SimpleImputer(strategy="mean") pre/post).
-        if "mean" in fitted_cache:
-            nan_mask = torch.isnan(x_safe)
-            if nan_mask.any():
-                col_means = fitted_cache["mean"].to(
-                    device=x_safe.device, dtype=x_safe.dtype
-                )
-                x_safe = torch.where(nan_mask, col_means.unsqueeze(0), x_safe)
-
-        x_scaled = x_safe / (std + torch.finfo(std.dtype).eps)
+        def _per_row(row: torch.Tensor) -> torch.Tensor:
+            x_c = row.to(compute_dtype)
+            x_safe = torch.where(torch.isinf(x_c), float("nan"), x_c)
+            # Impute NaN with column means (matching CPU make_scaler_safe).
+            if col_means is not None:
+                nan_mask = torch.isnan(x_safe)
+                x_safe = torch.where(nan_mask, col_means, x_safe)
+            x_scaled = x_safe / (std + torch.finfo(compute_dtype).eps)
+            x_scaled = torch.clip(x_scaled, min=-100, max=100)
+            return torch.where(torch.isfinite(x_scaled), x_scaled, 0)
+
+        result = torch.vmap(_per_row, chunk_size=chunk_size)(x)
+        return result.to(orig_dtype) if orig_dtype != compute_dtype else result

-        # Clip very large values
-        x_scaled = torch.clip(x_scaled, min=-100, max=100)
+    def _get_transform_chunk_size(
+        self, x: torch.Tensor, compute_element_size: int
+    ) -> int:
+        """Compute a row-chunk size that keeps intermediate memory bounded.

-        # Replace any inf that might have been created with nan, then impute
-        # remaining non-finite values (matching CPU post-imputation safety net)
-        result = torch.where(
-            torch.isfinite(x_scaled), x_scaled, torch.zeros_like(x_scaled)
-        )
-        return result.to(orig_dtype) if orig_dtype != compute_dtype else result
+        Transform creates ~5x N*F intermediates (dtype cast, isinf check,
+        nan_mask, imputed, scaled, clipped, finite-check). Target ~2 GB.
+        """
+        n_features = x.shape[-1] if x.ndim > 1 else 1
+        bytes_per_row = n_features * compute_element_size * 5
+        target_bytes = 2 * 1024**3  # 2 GB
+        return max(1_000, target_bytes // max(bytes_per_row, 1))

     def __call__(
         self,