
Make GPU transforms more memory efficient #887

Merged
bejaeger merged 7 commits into main from ben/lower-mem-gpu-transforms on Apr 27, 2026
Conversation

@bejaeger (Collaborator) commented Apr 22, 2026

[image]

@bejaeger (Collaborator, Author) commented:

This change is part of the following stack:

Change managed by git-spice.

@bejaeger bejaeger requested a review from a team as a code owner April 22, 2026 12:59
@bejaeger bejaeger requested review from klemens-floege and removed request for a team and klemens-floege April 22, 2026 12:59
@chatgpt-codex-connector commented:
Codex usage limits have been reached for code reviews. Please check with the admins of this repo about increasing the limits by adding credits. Credits must be used to enable repository-wide code reviews.

@bejaeger bejaeger requested a review from oscarkey April 22, 2026 13:00
@bejaeger bejaeger added the "no changelog needed" label (PR does not require a changelog entry) Apr 22, 2026
@gemini-code-assist (Bot) left a comment:


Code Review

This pull request introduces memory-efficient chunked processing for TorchQuantileTransformer and TorchTruncatedSVD to bound peak memory usage during fitting and transformation. It also adds support for randomized SVD in TorchTruncatedSVD and implements a CPU fallback for the MPS backend to handle unsupported SVD operations. Review feedback suggests optimizing memory usage by replacing tensor allocations with Python scalars in torch.where calls and removing redundant tensor expansions.
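
For illustration, a minimal sketch of the torch.where suggestion; the tensor names and shapes are hypothetical, not from the PR:

```python
import torch

x = torch.randn(8192, 512)
mask = x > 0

# Before: torch.zeros_like(x) materializes a full tensor just to
# supply the fill value for the masked positions.
y = torch.where(mask, x, torch.zeros_like(x))

# After: torch.where accepts a Python scalar for `other`, so the
# extra allocation (and its peak-memory cost) disappears.
y = torch.where(mask, x, 0.0)
```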

Comment thread src/tabpfn/preprocessing/torch/torch_quantile_transformer.py Outdated
Comment thread src/tabpfn/preprocessing/torch/torch_svd.py Outdated
Comment thread src/tabpfn/preprocessing/torch/torch_svd.py Outdated
Comment thread src/tabpfn/preprocessing/torch/torch_svd.py Outdated
@oscarkey (Contributor) commented:

If you're maxed out then it's totally fine to keep it like this, but did you consider using torch.vmap(chunk_size=blah) rather than implementing the chunking manually? I just checked and it seems to be supported on torch==2.5.
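
For context, a sketch of what that could look like; the per-column function and shapes are hypothetical, not the PR's actual code:

```python
import torch

def transform_col(col: torch.Tensor) -> torch.Tensor:
    # Stand-in for a per-feature transform; the real ones live in
    # torch_svd.py and torch_quantile_transformer.py.
    return torch.sort(col).values

x = torch.randn(100_000, 64)

# vmap over the feature dimension; chunk_size bounds how many
# columns are processed at once, keeping peak memory flat.
batched = torch.vmap(transform_col, in_dims=1, out_dims=1, chunk_size=8)
y = batched(x)
```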

@bejaeger (Collaborator, Author) replied:

Yep, using it in torch svd now to reduce line count! :)
For the torch quantile transformer I'm sticking with the manual approach, because there PyTorch emits a "performance drop" warning and falls back to a slower path.
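
For reference, the manual chunking pattern being kept looks roughly like this; a sketch with a placeholder per-chunk transform, not the PR's actual code:

```python
import torch

def _transform_chunk(chunk: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-chunk quantile interpolation; the real
    # logic in torch_quantile_transformer.py is more involved.
    return chunk.clamp(0.0, 1.0)

def transform_in_chunks(x: torch.Tensor, chunk_size: int = 10_000) -> torch.Tensor:
    # Process chunk_size rows at a time so peak memory scales with a
    # single chunk instead of the full input, avoiding the vmap
    # "performance drop" fallback mentioned above.
    out = torch.empty_like(x)
    for start in range(0, x.shape[0], chunk_size):
        out[start : start + chunk_size] = _transform_chunk(x[start : start + chunk_size])
    return out
```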

@oscarkey (Contributor) left a comment:


great savings!!

Comment thread src/tabpfn/preprocessing/torch/torch_quantile_transformer.py Outdated
Comment thread src/tabpfn/preprocessing/torch/torch_quantile_transformer.py Outdated
Comment thread src/tabpfn/preprocessing/torch/torch_svd.py Outdated
Comment thread tests/test_torch_preprocessing/test_torch_quantile_transformer.py
@bejaeger bejaeger added this pull request to the merge queue Apr 27, 2026
Merged via the queue into main with commit e758def Apr 27, 2026
12 checks passed

Labels

no changelog needed (PR does not require a changelog entry)
