
Fix DataParallel issues for GP-VAE, FILM, and FITS on multi-GPU setups#819

Draft
Claude wants to merge 2 commits into dev from claude/fix-gp-vae-fail-multi-cuda

@Claude Claude AI commented Mar 16, 2026

Multi-GPU training with torch.nn.DataParallel fails for GP-VAE, FILM, and FITS models with runtime errors related to distribution initialization, complex tensor operations, and einsum dimension mismatches.

Changes

GP-VAE (pypots/nn/modules/gpvae/backbone.py)

  • Pre-compute kernel matrices in __init__ and register as buffer via register_buffer("prior_covariance", ...)
  • Create prior distribution fresh per forward pass in new _get_prior() method instead of caching
  • Eliminates "lazy wrapper should be called at most once" error from shared distribution state across GPU replicas
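
A minimal sketch of the buffer-plus-fresh-prior pattern described above, not the actual backbone code. The `prior_covariance` buffer name and the `_get_prior()` method come from this description; the `PriorHolder` class, the RBF kernel, and the sizes are hypothetical stand-ins.

```python
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal


class PriorHolder(nn.Module):
    """Hypothetical stand-in for the prior handling in the GP-VAE backbone."""

    def __init__(self, time_length: int = 4, length_scale: float = 1.0):
        super().__init__()
        # Illustrative RBF kernel over time steps (the kernel choice is an assumption).
        t = torch.arange(time_length, dtype=torch.float32)
        sq_dist = (t[:, None] - t[None, :]) ** 2
        kernel = torch.exp(-sq_dist / (2 * length_scale**2))
        # A small jitter keeps the matrix safely positive definite.
        kernel = kernel + 1e-3 * torch.eye(time_length)
        # Buffers are replicated and moved to the right device with the module.
        self.register_buffer("prior_covariance", kernel)

    def _get_prior(self) -> MultivariateNormal:
        # Build a new distribution on every call instead of caching one on self.
        cov = self.prior_covariance
        mean = torch.zeros(cov.shape[-1], device=cov.device, dtype=cov.dtype)
        return MultivariateNormal(loc=mean, covariance_matrix=cov)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self._get_prior().log_prob(z)


holder = PriorHolder()
log_prob = holder(torch.randn(3, 4))
```

Because the distribution is rebuilt from the buffer on each call, every replica works with its own object on its own device rather than sharing one cached instance.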

FITS (pypots/nn/modules/fits/backbone.py)

  • Remove .to(torch.cfloat) from Linear layer initialization
  • Split complex FFT outputs into real/imaginary components, apply Linear transformations separately, then recombine with torch.complex()
  • Resolves the "t() expects a tensor with <= 2 dimensions, but self is 3D" error raised when the Linear layer holds complex-dtype parameters
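
A rough sketch of the split-and-recombine idea, assuming two separate real-valued Linear layers; whether the actual change shares one layer or uses two is not stated here. The `RealImagLinear` name and the tensor sizes are hypothetical.

```python
import torch
import torch.nn as nn


class RealImagLinear(nn.Module):
    """Hypothetical wrapper: real-valued Linear layers applied to a complex input."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Both layers keep ordinary float parameters, so nothing is complex-typed.
        self.linear_real = nn.Linear(in_features, out_features)
        self.linear_imag = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        real = self.linear_real(x.real)  # transform the real component
        imag = self.linear_imag(x.imag)  # transform the imaginary component
        return torch.complex(real, imag)  # recombine into a complex tensor


spec = torch.fft.rfft(torch.randn(2, 3, 16), dim=-1)  # complex, shape (2, 3, 9)
layer = RealImagLinear(9, 12)
out = layer(spec)
```

Note that transforming the two components independently is not the same map as a complex-valued Linear layer, which would mix real and imaginary parts; the sketch only illustrates the dtype handling.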

FILM (pypots/nn/modules/film/layers.py)

  • Decompose complex einsum operations: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
  • Apply einsum to real/imaginary parts separately, then recombine
  • Fixes the einsum() "number of subscripts in the equation (3) does not match the number of dimensions (4)" error under DataParallel
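
The decomposition above can be sketched as follows. The helper name `complex_einsum` and the tensor shapes are illustrative assumptions; the equation string is the one from the FILM traceback.

```python
import torch


def complex_einsum(equation, a, w_real, w_imag):
    """Hypothetical helper: complex einsum built from four real einsum calls."""
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    real = torch.einsum(equation, a.real, w_real) - torch.einsum(equation, a.imag, w_imag)
    imag = torch.einsum(equation, a.real, w_imag) + torch.einsum(equation, a.imag, w_real)
    return torch.complex(real, imag)


a = torch.randn(2, 3, 4, 5, dtype=torch.cfloat)
w_real = torch.randn(4, 6, 5)
w_imag = torch.randn(4, 6, 5)

out = complex_einsum("bjix,iox->bjox", a, w_real, w_imag)
# Reference result using a native complex weight on a single device.
ref = torch.einsum("bjix,iox->bjox", a, torch.complex(w_real, w_imag))
```

Since einsum is linear in each operand, the four real-valued calls reproduce the complex result while the stored weights stay real.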

Technical rationale

DataParallel re-replicates the model onto each GPU on every forward pass, and that replication breaks down for:

  1. Shared mutable state (cached distribution objects, as in GP-VAE)
  2. Non-standard parameter dtypes (complex-valued Linear layers, as in FITS)
  3. Complex tensor operations on replicated weights (the einsum calls in FILM)

The solution uses buffers for device-aware tensors and explicit real/imaginary arithmetic, consistent with the patterns from PR #633 (Koopa, USGAN, CRLI fixes).
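
To illustrate why buffers are the device-aware choice, here is a small CPU-only sketch with a hypothetical `WithBuffer` module. It uses a dtype cast as a stand-in for a device move, since `.double()`, `.to()`, and `.cuda()` all go through the same module-wide traversal of parameters and buffers.

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    """Hypothetical module contrasting a registered buffer with a plain attribute."""

    def __init__(self):
        super().__init__()
        self.register_buffer("kernel", torch.eye(3))  # tracked by nn.Module
        self.plain = torch.eye(3)  # ordinary attribute, invisible to nn.Module


module = WithBuffer().double()  # casts parameters and buffers only

buffer_dtype = module.kernel.dtype
plain_dtype = module.plain.dtype
state_keys = list(module.state_dict().keys())
```

The registered buffer follows the module, while the plain tensor attribute is left behind; the same applies when the module is moved or replicated across devices.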

Usage

# Single GPU (existing behavior unchanged)
model = GPVAE(n_steps, n_features, latent_size, device="cuda:0")

# Multiple GPUs (now works)
model = GPVAE(n_steps, n_features, latent_size, device=["cuda:0", "cuda:1"])
Original prompt

This section details the original issue you should resolve.

<issue_title>FITS/FILM/GP-VAE fail when running on multiple CUDA devices</issue_title>
<issue_description>### 1. System Info

v0.11

### 2. Information

  • The official example scripts
  • My own created scripts

### 3. Reproduction

  • pypots.clustering.crli
  • pypots.imputation.usgan
  • pypots.imputation.koopa
  • pypots.imputation.film
  • pypots.imputation.gpvae
  • pypots.imputation.fits
  • pypots.forecasting.fits

### 4. Expected behavior

For pypots.forecasting.fits and pypots.imputation.fits we have

E       RuntimeError: Caught RuntimeError in replica 0 on device 1.
E       Original Traceback (most recent call last):
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
E           output = module(*input, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/forecasting/fits/core.py", line 68, in forward
E           enc_out = self.backbone(enc_out)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/nn/modules/fits/backbone.py", line 63, in forward
E           low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 1))
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
E           return F.linear(input, self.weight, self.bias)
E       RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
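
The 3D weight in this traceback would be consistent with the complex Linear weight reaching the replica as a real tensor with a trailing dimension of 2. That reading is an assumption, not something the traceback confirms. A CPU-only sketch with hypothetical shapes that triggers the same kind of failure:

```python
import torch
import torch.nn.functional as F

# A complex 2D weight, like one produced by calling .to(torch.cfloat) on nn.Linear.
weight = torch.randn(12, 9, dtype=torch.cfloat)
# Viewing it as real adds a trailing dimension of size 2, giving a 3D tensor.
weight_as_real = torch.view_as_real(weight)

x = torch.randn(2, 3, 9)
try:
    F.linear(x, weight_as_real)
    linear_error = None
except RuntimeError as err:
    linear_error = str(err)
```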

For pypots.imputation.film we have

E       RuntimeError: Caught RuntimeError in replica 0 on device 1.
E       Original Traceback (most recent call last):
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
E           output = module(*input, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/imputation/film/core.py", line 65, in forward
E           backbone_output = self.backbone(X_embedding)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/nn/modules/film/backbone.py", line 65, in forward
E           out1 = self.spec_conv_1[i](x_in_c)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/nn/modules/film/layers.py", line 128, in forward
E           out_ft[:, :, :, : self.modes2] = torch.einsum("bjix,iox->bjox", a, self.weights1)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/functional.py", line 380, in einsum
E           return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
E       RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 1 and no ellipsis was given
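
The same assumption would explain this message: a 3D complex weight viewed as real becomes 4D, which no longer matches the three subscripts `iox`. A CPU-only sketch with hypothetical shapes, using the equation from the traceback:

```python
import torch

# A complex 3D weight matching the "iox" subscripts.
weights1 = torch.randn(4, 6, 5, dtype=torch.cfloat)
# Viewed as real, it gains a trailing dimension of size 2 and becomes 4D.
weights1_as_real = torch.view_as_real(weights1)

a = torch.randn(2, 3, 4, 5, dtype=torch.cfloat)
try:
    torch.einsum("bjix,iox->bjox", a, weights1_as_real)
    einsum_error = None
except RuntimeError as err:
    einsum_error = str(err)
```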

For pypots.imputation.gpvae we have

E       RuntimeError: Caught RuntimeError in replica 1 on device 2.
E       Original Traceback (most recent call last):
E         File...

</details>



<!-- START COPILOT CODING AGENT SUFFIX -->

- Fixes WenjieDu/PyPOTS#632

- GP-VAE: Pre-compute kernel matrices as buffers and create prior distribution on-the-fly to avoid DataParallel replication issues
- FITS: Handle complex numbers by splitting into real/imaginary parts instead of using .to(torch.cfloat) on Linear layers
- FILM: Handle complex einsum operations by splitting into real/imaginary parts to avoid DataParallel issues with complex tensors

Co-authored-by: WenjieDu <17807970+WenjieDu@users.noreply.github.com>
@Claude Claude AI changed the title from "[WIP] [FITS-123] Fix GP-VAE fail when running on multiple CUDA devices" to "Fix DataParallel issues for GP-VAE, FILM, and FITS on multi-GPU setups" Mar 16, 2026
@Claude Claude AI requested a review from WenjieDu March 16, 2026 16:51
@github-actions

This pull request had no activity for 14 days. It will be closed in 1 week unless there is some new activity.

@github-actions github-actions Bot added the stale label Mar 31, 2026
@github-actions github-actions Bot closed this Apr 7, 2026
@WenjieDu WenjieDu reopened this Apr 7, 2026
@WenjieDu WenjieDu added the keep label ("Keep this issue away from being stale.") and removed the stale label Apr 7, 2026
@WenjieDu WenjieDu changed the base branch from main to dev April 26, 2026 07:36
@WenjieDu WenjieDu closed this Apr 26, 2026
@WenjieDu WenjieDu reopened this Apr 26, 2026
Collaborator

coveralls commented Apr 26, 2026

Coverage Report for CI Build 24951327727

Coverage decreased (-0.03%) to 79.961%

Details

  • Coverage decreased (-0.03%) from the base build.
  • Patch coverage: 13 uncovered changes across 3 files (28 of 41 lines covered, 68.29%).
  • No coverage regressions found.

Uncovered Changes

| File | Changed | Covered | % |
|---|---|---|---|
| pypots/nn/modules/film/layers.py | 15 | 8 | 53.33% |
| pypots/nn/modules/fits/backbone.py | 9 | 5 | 55.56% |
| pypots/nn/modules/gpvae/backbone.py | 17 | 15 | 88.24% |

Coverage Regressions

No coverage regressions found.


Coverage Stats

Coverage Status
Relevant Lines: 19028
Covered Lines: 15215
Line Coverage: 79.96%
Coverage Strength: 1.6 hits per line

💛 - Coveralls

Labels

keep Keep this issue away from being stale.

3 participants