CUDA tensor inputs #918

@tvatter

Description

Describe the workflow you want to enable

I'd like to call TabPFNRegressor.fit(X, y) and
TabPFNRegressor.predict(X, ...) with CUDA torch.Tensor inputs when
the regressor is configured with device="cuda", without having to
.cpu() the inputs at every call site.

Concrete use case: I'm building a PyTorch-native pipeline on top of
TabPFN, i.e. a conditional density / copula estimator that uses
predict(output_type="full") heavily and keeps logits, the criterion
output, downstream Jacobians, and trapezoidal-integration tensors all
GPU-resident. The only point in the pipeline that forces a cpu copy is
the call into TabPFN itself:

import torch
from tabpfn import TabPFNRegressor
from tabpfn.constants import ModelVersion

m = TabPFNRegressor.create_default_for_version(
    ModelVersion.V2_5, device="cuda"
)

X_cuda = torch.randn(500, 3, device="cuda")
y_cuda = torch.randn(500, device="cuda")

m.fit(X_cuda, y_cuda)
# tabpfn.errors.TabPFNValidationError: can't convert cuda:0 device type
# tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

m.predict(X_cuda, output_type="full")  # same error

CPU tensors and numpy arrays both work today; CUDA tensors are the only common PyTorch input type that's rejected. Since the model is already configured for device="cuda" and the forward pass runs there, the public API requiring CPU input is surprising. Every PyTorch user I've shown this to expects CUDA tensors to "just work".

Describe your proposed solution

In tabpfn/validation.py, coerce CUDA tensors to CPU inside the validators before they reach sklearn's validate_data. Minimal patch:

import torch

def _to_validatable(arr):
    if isinstance(arr, torch.Tensor) and arr.is_cuda:
        return arr.detach().cpu()
    return arr

# inside ensure_compatible_fit_inputs_sklearn (and the predict-side
# validator), at the very top:
X = _to_validatable(X)
y = _to_validatable(y) if y is not None else None
# ... existing validate_data call unchanged from here

This is fully backward compatible. CPU tensors and numpy arrays go through the existing path unchanged. The to(device) step that already happens before the forward pass is unchanged. The only difference is that callers no longer have to write .cpu() themselves.

Eliminating the cuda ↔ cpu transfers entirely for CUDA inputs would be strictly better, but it requires a larger refactor of the preprocessing pipeline (SafePowerTransformer, the ensemble preprocessing, etc.), which today operates on numpy. Even the small patch above would resolve the API friction.

Today a CUDA input only gets through if the caller adds .cpu() (or if the minimal patch above is applied). With device="cuda", the data path is then:

CUDA tensor (caller)
  → cpu (sklearn validate_data, via np.asarray)
  → cpu (TabPFN preprocessing: SafePowerTransformer, ensemble preproc, all numpy)
  → cuda (.to(device) for forward pass)
  → cuda (logits / criterion output)

Two transfers (cuda → cpu and back) happen for every fit and every
predict. For PyTorch-native pipelines, both are pure overhead — the
caller already has the data on the device the model wants to compute on.

Concretely, two changes inside TabPFN:

  1. Input validation accepts CUDA tensors without copying.
    tabpfn.validation.ensure_compatible_fit_inputs_sklearn (and the
    predict-side validator) should validate shape / dtype / finiteness
    without going through np.asarray(). For a torch.Tensor, use
    torch.isfinite / tensor.shape / tensor.dtype directly. This
    removes the cuda → cpu copy entirely.

  2. Preprocessing runs on the input's device. The preprocessing
    pipeline (SafePowerTransformer, the ensemble preprocessing, any
    feature scaling) currently relies on numpy / scikit-learn. For
    torch-tensor inputs it should dispatch to torch equivalents (or use
    the Array API if you prefer a single code path). Power transforms,
    scaling, and quantile estimation all have straightforward torch
    implementations. This keeps the data on device from the caller all
    the way to the forward pass.
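
To illustrate change 1, here is a minimal sketch of validation that never leaves the input's device. The function name validate_tensor_inputs and the allow_nan flag are hypothetical and not part of TabPFN. Which checks TabPFN actually needs (for example whether NaN should be allowed as a missing-value marker, or whether integer dtypes are acceptable) is an assumption on my side.

```python
from typing import Optional

import torch


def validate_tensor_inputs(
    X: torch.Tensor,
    y: Optional[torch.Tensor] = None,
    allow_nan: bool = False,  # hypothetical flag, not a TabPFN option
) -> None:
    """Hypothetical helper: check shape, dtype and finiteness on X.device."""
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D, got shape {tuple(X.shape)}")
    # Assumption: only floating-point features are accepted in this sketch.
    if not X.is_floating_point():
        raise TypeError(f"X must be floating point, got {X.dtype}")

    # Both checks run on the tensor's own device. Only the single boolean
    # produced by .any() is synchronised back to the host by bool().
    bad = torch.isinf(X) if allow_nan else ~torch.isfinite(X)
    if bool(bad.any()):
        raise ValueError("X contains non-finite values")

    if y is not None:
        if y.ndim != 1 or y.shape[0] != X.shape[0]:
            raise ValueError(
                f"y must be 1-D with {X.shape[0]} entries, "
                f"got shape {tuple(y.shape)}"
            )
        if not bool(torch.isfinite(y).all()):
            raise ValueError("y contains non-finite values")
```

The full tensor is never copied to host memory; the only device-to-host traffic is one scalar per check.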

After both changes, the data path becomes:

CUDA tensor (caller)
  → CUDA (validation, no copy)
  → CUDA (preprocessing, torch ops)
  → CUDA (forward pass)
  → CUDA (logits / criterion output)

which is what users with device="cuda" reasonably expect.

If a full preprocessing port is too large for a single PR, the changes
could land incrementally: (a) first, validation that accepts CUDA tensors
without copying, then (b) preprocessing modules ported one at a time,
falling back to CPU when a torch-native implementation isn't available
yet. The validation-only change alone wouldn't actually save transfers
(preprocessing would still bounce through CPU), so it's only worthwhile
as the first step of a larger plan, not as a standalone fix.
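
A rough sketch of what step (b) could look like: a per-step dispatch that uses a torch port when one exists and otherwise falls back to a single round trip through host memory. TORCH_STEPS, standardize_torch and run_step are hypothetical names for illustration only, and the standardization step is a stand-in, not TabPFN's actual preprocessing.

```python
from typing import Callable, Dict

import numpy as np
import torch

# Hypothetical registry of preprocessing steps that already have a torch port.
TORCH_STEPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {}


def standardize_torch(x: torch.Tensor) -> torch.Tensor:
    """Column-wise zero-mean / unit-variance scaling, entirely on x.device."""
    mean = x.mean(dim=0, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=0, keepdim=True)
    std = var.sqrt().clamp_min(1e-12)  # guard against constant columns
    return (x - mean) / std


TORCH_STEPS["standardize"] = standardize_torch


def run_step(
    name: str,
    x: torch.Tensor,
    numpy_impl: Callable[[np.ndarray], np.ndarray],
) -> torch.Tensor:
    """Run one preprocessing step, staying on device when a port exists."""
    torch_impl = TORCH_STEPS.get(name)
    if torch_impl is not None:
        return torch_impl(x)
    # Fallback for steps not ported yet: one round trip through the host,
    # then back to the caller's device and dtype.
    out = numpy_impl(x.detach().cpu().numpy())
    return torch.as_tensor(out, dtype=x.dtype, device=x.device)
```

With this shape, each ported module removes one host round trip, and callers always get a tensor back on the device they passed in.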

Describe alternatives you've considered, if relevant

  • Caller-side .cpu() wrappers. Works, but adds boilerplate at every call site and makes it impossible to drop TabPFN into a standard PyTorch pipeline without a custom adapter.
  • Monkey-patching tabpfn.validation. Fragile; breaks silently on upgrades and feels wrong for a public API.
  • Subclassing TabPFNRegressor and overriding the validators. Same fragility, slightly more contained. Still hits numpy preprocessing inside the parent class.
  • Calling TabPFN's internal forward methods directly, bypassing predict. Reproduces too much of TabPFN's preprocessing pipeline in user code, with a real risk of subtle correctness bugs across versions.
  • Waiting for sklearn Array API support to mature. Long-term this would let the whole pipeline stay on GPU, but it's not yet a realistic dependency for TabPFN.

The proposed in-validator coercion is the smallest possible fix that removes the user-facing friction without any downstream changes.
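
For reference, the caller-side wrapper from the first alternative looks roughly like this. HostInputAdapter and to_host are hypothetical names; the sketch assumes only that the wrapped estimator exposes fit(X, y) and predict(X, **kwargs).

```python
import torch


def to_host(arr):
    """Return a CPU copy of a CUDA tensor; pass everything else through."""
    if isinstance(arr, torch.Tensor) and arr.is_cuda:
        return arr.detach().cpu()
    return arr


class HostInputAdapter:
    """Hypothetical adapter that moves CUDA inputs to the host before each call."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y):
        self.estimator.fit(to_host(X), to_host(y))
        return self

    def predict(self, X, **kwargs):
        # Keyword arguments such as output_type are forwarded unchanged.
        return self.estimator.predict(to_host(X), **kwargs)
```

This is exactly the adapter every PyTorch-native caller would otherwise have to write and maintain themselves.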

Additional context

No response

Impact

Medium (Significant enhancement)
