Add devices parameter for GPU device selection#99
Conversation
Allow callers to specify which GPU devices to use (e.g. devices=[2, 3]) instead of always using devices 0..N. Useful for distributed eval where different processes need different GPUs. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces the ability to specify explicit GPU device indices for local inference by adding a devices parameter to the TabPFNTimeSeriesPipeline. The GPUParallelWorker has been updated to utilize these indices, overriding the default behavior of using all available GPUs. Feedback suggests adding validation for the devices list to ensure it is not empty and contains valid indices, preventing potential runtime errors like division by zero or out-of-range device access.
| if devices is not None: | ||
| self.devices = list(devices) | ||
| else: | ||
| num_gpus = num_gpus or torch.cuda.device_count() | ||
| self.devices = list(range(num_gpus)) |
There was a problem hiding this comment.
The devices parameter should be validated to ensure it is not empty and contains valid GPU indices. Providing an empty list will lead to a ValueError during data splitting (division by zero) or an IndexError when accessing the first device. Additionally, validating that indices are within the range of available GPUs prevents late-stage RuntimeError when calling torch.cuda.set_device.
if devices is not None:
if not devices:
raise ValueError("The 'devices' list cannot be empty.")
self.devices = list(devices)
num_available = torch.cuda.device_count()
if any(d < 0 or d >= num_available for d in self.devices):
raise ValueError(
f"Invalid device index in {devices}. "
f"Available device indices: 0 to {num_available - 1}."
)
else:
num_gpus = num_gpus or torch.cuda.device_count()
self.devices = list(range(num_gpus))Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
devices: list[int] | Noneparameter toGPUParallelWorker,TimeSeriesPredictor.from_tabpfn_family, andTabPFNTSPipelinedevices=[2, 3]) instead of always using devices 0..Ndevicesis only passed withTabPFNMode.LOCALnum_gpusare unaffectedMotivation
Distributed eval workloads need to pin different pipeline instances to specific GPUs. Previously the only control was
num_gpus, which always selected devices starting from 0.Usage
Test plan
self.devices[0]instead of hardcoded 0self.devicesdeviceswith CLIENT mode raisesValueErrornum_gpus=2still works as before🤖 Generated with Claude Code