@@ -42,8 +42,9 @@ class TorchTruncatedSVD:
4242 """Truncated SVD for PyTorch tensors.
4343
4444 Similar to sklearn's TruncatedSVD but without any implicit state.
45- The state is returned explicitly. Uses full SVD and truncates to
46- n_components (efficient for typical TabPFN dimensions).
45+ The state is returned explicitly. Uses randomized SVD
46+ (``torch.svd_lowrank``) for large matrices and exact
47+ ``torch.linalg.svd`` for small ones.
4748
4849 Note: Unlike sklearn's TruncatedSVD, this does not center the data.
4950 If centering is needed, apply it before calling fit.
@@ -60,6 +61,11 @@ def __init__(self, n_components: int) -> None:
6061 def fit (self , x : torch .Tensor ) -> dict [str , torch .Tensor ]:
6162 """Compute the truncated SVD on the training data.
6263
64+ Uses randomized SVD (``torch.svd_lowrank``) when ``n_components`` is
65+ much smaller than the matrix dimensions. This reduces memory from
66+ O(N * min(N, F)) to O(N * n_components) — a large saving when
67+ n_components << min(N, F) (e.g. 128 vs 500).
68+
6369 Args:
6470 x: Input tensor with shape [n_samples, n_features].
6571
@@ -69,6 +75,17 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
6975 [n_components, n_features]
7076 - "singular_values": The singular values [n_components]
7177 """
78+ orig_device = x .device
79+
80+ if x .device .type == "mps" :
81+ warnings .warn (
82+ "SVD operators ('aten::linalg_svd', 'aten::linalg_qr') are not "
83+ "currently supported on the MPS backend and will fall back to "
84+ "run on the CPU. This may have performance implications." ,
85+ stacklevel = 2 ,
86+ )
87+ x = x .cpu ()
88+
7289 n_samples , n_features = x .shape
7390
7491 # Handle NaN values by replacing with 0 for SVD computation
@@ -79,25 +96,43 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
7996 n_components = min (self .n_components , n_samples , n_features )
8097 n_components = max (1 , n_components )
8198
82- # torch.linalg.svd requires float32 or float64; cast up if needed
99+ # torch SVD ops require float32 or float64; cast up if needed
83100 compute_dtype = x_filled .dtype
84101 if compute_dtype not in (torch .float32 , torch .float64 ):
85102 compute_dtype = torch .float32
86103 x_filled = x_filled .to (compute_dtype )
87104
88- # Compute full SVD: X = U @ diag(S) @ V^T
89- # torch.linalg.svd returns V^T directly (not V)
90- with warnings .catch_warnings ():
91- warnings .filterwarnings (
92- "ignore" ,
93- message = ".*linalg_svd.*not currently supported on the MPS backend.*" ,
94- )
95- u , s , vh = torch .linalg .svd (x_filled , full_matrices = False )
105+ # Use randomized SVD only when it is both (a) accurate — the matrix
106+ # is large enough that the top components are well-separated — and
107+ # (b) faster — the projected rank q is well below the matrix rank.
108+ # Benchmark shows svd_lowrank becomes favorable once
109+ # min(N, F) >= 2*q; below that, the random-projection overhead
110+ # exceeds the savings from avoiding the full decomposition.
111+ oversampling = 10 # matches sklearn's TruncatedSVD n_oversamples default
112+ q = n_components + oversampling
113+ use_lowrank = (
114+ n_samples * n_features > 1_000_000 and min (n_samples , n_features ) >= 2 * q
115+ )
96116
97- # Truncate to n_components
98- u = u [:, :n_components ]
99- s = s [:n_components ]
100- vh = vh [:n_components , :]
117+ if use_lowrank :
118+ # torch.svd_lowrank returns (U, S, V) with A ≈ U diag(S) V^T
119+ u , s , v = torch .svd_lowrank (x_filled , q = q , niter = 2 )
120+ # Truncate oversampling dimensions
121+ u = u [:, :n_components ]
122+ s = s [:n_components ]
123+ vh = v [:, :n_components ].T # V [n_features, q] → V^T [n_comp, n_features]
124+ else :
125+ # Fall back to full SVD for small matrices or when n_components
126+ # is close to min(n_samples, n_features).
127+         with warnings .catch_warnings ():  # MPS fallback warning was already emitted above
128+ warnings .filterwarnings (
129+ "ignore" ,
130+ message = ".*linalg_svd.*not currently supported on the MPS backend.*" , # noqa: E501
131+ )
132+ u , s , vh = torch .linalg .svd (x_filled , full_matrices = False )
133+ u = u [:, :n_components ]
134+ s = s [:n_components ]
135+ vh = vh [:n_components , :]
101136
102137 # Apply sign flip for deterministic output.
103138 # We use the same convention as sklearn (u_based_decision=False:
@@ -108,8 +143,8 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
108143 u , vh = _svd_flip_stable (u , vh )
109144
110145 return {
111- "components" : vh ,
112- "singular_values" : s ,
146+ "components" : vh . to ( orig_device ) ,
147+ "singular_values" : s . to ( orig_device ) ,
113148 }
114149
115150 def transform (
@@ -119,6 +154,9 @@ def transform(
119154 ) -> torch .Tensor :
120155 """Project the data onto the SVD components.
121156
157+ Automatically processes in row-chunks when the data is large to keep
158+ peak intermediate memory bounded.
159+
122160 Args:
123161 x: Input tensor to transform [n_samples, n_features].
124162 fitted_cache: Cache returned by fit.
@@ -130,27 +168,36 @@ def transform(
130168 raise ValueError ("Invalid fitted cache. Must contain 'components'." )
131169
132170 components = fitted_cache ["components" ]
133-
134- # Components may be in float32 (from fit) while input is float16.
135- # Compute in the components dtype, then cast back.
136171 orig_dtype = x .dtype
137172 compute_dtype = components .dtype
173+ chunk_size = self ._get_transform_chunk_size (x , components )
138174
139- # Handle NaN values: preserve them in output
140- x_compute = x .to (compute_dtype ) if x .dtype != compute_dtype else x
141- nan_mask = torch .isnan (x_compute )
142- x_filled = torch .where (nan_mask , torch .zeros_like (x_compute ), x_compute )
175+ def _per_row (row : torch .Tensor ) -> torch .Tensor :
176+ x_c = row .to (compute_dtype )
177+ nan_mask = torch .isnan (x_c )
178+ x_filled = torch .where (nan_mask , torch .zeros_like (x_c ), x_c )
179+ out = x_filled @ components .T
180+ return torch .where (nan_mask .any (), float ("nan" ), out )
143181
144- # Project: X @ V (V = components.T)
145- result = x_filled @ components .T
146-
147- # If any input feature was NaN, the corresponding output should be NaN
148- # Since projection is a linear combination, any NaN input affects all outputs
149- any_nan_per_row = nan_mask .any (dim = - 1 , keepdim = True )
150- nan_fill = torch .tensor (float ("nan" ), device = x .device , dtype = compute_dtype )
151- result = torch .where (any_nan_per_row .expand_as (result ), nan_fill , result )
182+ result = torch .vmap (_per_row , chunk_size = chunk_size )(x )
152183 return result .to (orig_dtype ) if orig_dtype != compute_dtype else result
153184
185+ def _get_transform_chunk_size (
186+ self , x : torch .Tensor , components : torch .Tensor
187+ ) -> int :
188+ """Compute a row-chunk size that keeps intermediate memory bounded.
189+
190+ Transform creates ~3x N*F intermediates (dtype cast, nan_mask,
191+ x_filled) plus the N*C result. Target ~2 GB of temporaries.
192+ """
193+ n_features = x .shape [- 1 ]
194+ n_components = components .shape [0 ]
195+ element_size = max (x .element_size (), components .element_size ())
196+ # x_compute + nan_mask(~1B) + x_filled + result
197+ bytes_per_row = (3 * n_features + n_components ) * element_size
198+ target_bytes = 2 * 1024 ** 3 # 2 GB
199+ return max (1_000 , target_bytes // max (bytes_per_row , 1 ))
200+
154201 def __call__ (
155202 self ,
156203 x : torch .Tensor ,
@@ -241,6 +288,9 @@ def transform(
241288 ) -> torch .Tensor :
242289 """Apply the fitted scaling to the data (no mean centering).
243290
291+ Automatically processes in row-chunks when the data is large to keep
292+ peak intermediate memory bounded.
293+
244294 Args:
245295 x: Input tensor to transform.
246296 fitted_cache: Cache returned by fit.
@@ -252,40 +302,43 @@ def transform(
252302 raise ValueError ("Invalid fitted cache. Must contain 'std'." )
253303
254304 std = fitted_cache ["std" ]
255-
256- # Align dtype: std is in float32+ from fit, input may be float16
257305 orig_dtype = x .dtype
258306 compute_dtype = std .dtype
259- x_compute = x . to ( compute_dtype ) if x . dtype != compute_dtype else x
260-
261- # Replace inf with nan before scaling
262- x_safe = torch . where (
263- torch . isinf ( x_compute ),
264- torch . tensor ( float ( "nan" ), device = x . device , dtype = compute_dtype ),
265- x_compute ,
307+ chunk_size = self . _get_transform_chunk_size (
308+ x , compute_element_size = max ( x . element_size (), std . element_size ())
309+ )
310+ col_means = (
311+ fitted_cache [ "mean" ]. to ( device = x . device , dtype = compute_dtype )
312+ if "mean" in fitted_cache
313+ else None
266314 )
267315
268- # Impute NaN with column means (matching CPU make_scaler_safe which
269- # wraps the scaler with SimpleImputer(strategy="mean") pre/post).
270- if "mean" in fitted_cache :
271- nan_mask = torch .isnan (x_safe )
272- if nan_mask .any ():
273- col_means = fitted_cache ["mean" ].to (
274- device = x_safe .device , dtype = x_safe .dtype
275- )
276- x_safe = torch .where (nan_mask , col_means .unsqueeze (0 ), x_safe )
277-
278- x_scaled = x_safe / (std + torch .finfo (std .dtype ).eps )
316+ def _per_row (row : torch .Tensor ) -> torch .Tensor :
317+ x_c = row .to (compute_dtype )
318+ x_safe = torch .where (torch .isinf (x_c ), float ("nan" ), x_c )
319+ # Impute NaN with column means (matching CPU make_scaler_safe).
320+ if col_means is not None :
321+ nan_mask = torch .isnan (x_safe )
322+ x_safe = torch .where (nan_mask , col_means , x_safe )
323+ x_scaled = x_safe / (std + torch .finfo (compute_dtype ).eps )
324+ x_scaled = torch .clip (x_scaled , min = - 100 , max = 100 )
325+ return torch .where (torch .isfinite (x_scaled ), x_scaled , 0 )
326+
327+ result = torch .vmap (_per_row , chunk_size = chunk_size )(x )
328+ return result .to (orig_dtype ) if orig_dtype != compute_dtype else result
279329
280- # Clip very large values
281- x_scaled = torch .clip (x_scaled , min = - 100 , max = 100 )
330+ def _get_transform_chunk_size (
331+ self , x : torch .Tensor , compute_element_size : int
332+ ) -> int :
333+ """Compute a row-chunk size that keeps intermediate memory bounded.
282334
283- # Replace any inf that might have been created with nan, then impute
284- # remaining non-finite values (matching CPU post-imputation safety net)
285- result = torch .where (
286- torch .isfinite (x_scaled ), x_scaled , torch .zeros_like (x_scaled )
287- )
288- return result .to (orig_dtype ) if orig_dtype != compute_dtype else result
335+        Transform holds up to ~5 N*F-sized float temporaries at peak (dtype
336+        cast, inf/nan-masked copy, imputed, scaled, clipped). Target ~2 GB.
337+ """
338+ n_features = x .shape [- 1 ] if x .ndim > 1 else 1
339+ bytes_per_row = n_features * compute_element_size * 5
340+ target_bytes = 2 * 1024 ** 3 # 2 GB
341+ return max (1_000 , target_bytes // max (bytes_per_row , 1 ))
289342
290343 def __call__ (
291344 self ,
0 commit comments