 import torch
 from monai.data import CacheDataset, Dataset
 from monai.transforms import Compose
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Sampler


 class ConnectomicsDataModule(pl.LightningDataModule):
@@ -125,21 +125,26 @@ def val_dataloader(self): |
         return self._create_dataloader(self.val_dataset, shuffle=False)

     def test_dataloader(self):
+        sampler = None
+        if self.test_dataset is not None and _is_distributed_evaluation_active():
+            sampler = DistributedEvaluationSampler(self.test_dataset)
         return self._create_dataloader(
             self.test_dataset,
             shuffle=False,
             collate_fn=collate_dict_list,
+            sampler=sampler,
         )

-    def _create_dataloader(self, dataset, shuffle, collate_fn=None):
+    def _create_dataloader(self, dataset, shuffle, collate_fn=None, sampler=None):
         if dataset is None:
             return None
         if collate_fn is None:
             collate_fn = collate_dict
         return DataLoader(
             dataset=dataset,
             batch_size=self.batch_size,
-            shuffle=shuffle,
+            shuffle=shuffle if sampler is None else False,
+            sampler=sampler,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             persistent_workers=(self.persistent_workers and self.num_workers > 0),
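
For context on the `shuffle=shuffle if sampler is None else False` guard: PyTorch's `DataLoader` treats an explicit `sampler` as mutually exclusive with `shuffle=True` and raises a `ValueError` if both are passed, so the guard quietly defers ordering to the sampler. A minimal sketch (the `TensorDataset` here is illustrative, not part of this module):

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(8))
try:
    # An explicit sampler combined with shuffle=True is rejected by DataLoader.
    DataLoader(ds, sampler=SequentialSampler(ds), shuffle=True)
except ValueError as err:
    print(err)
# With shuffle=False, the sampler alone controls iteration order.
loader = DataLoader(ds, sampler=SequentialSampler(ds), shuffle=False)
```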
@@ -189,6 +194,45 @@ def __getitem__(self, index): |
         return self.dataset[index % len(self.dataset)]


+def _is_distributed_evaluation_active() -> bool:
+    return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+
+class DistributedEvaluationSampler(Sampler[int]):
+    """Shard evaluation samples across DDP ranks without padding or duplication."""
+
+    def __init__(
+        self,
+        dataset,
+        *,
+        rank: Optional[int] = None,
+        world_size: Optional[int] = None,
+    ):
+        if rank is None or world_size is None:
+            if not _is_distributed_evaluation_active():
+                raise RuntimeError(
+                    "DistributedEvaluationSampler requires an initialized distributed process "
+                    "group or explicit rank/world_size."
+                )
+            rank = torch.distributed.get_rank()
+            world_size = torch.distributed.get_world_size()
+
+        if world_size <= 0:
+            raise ValueError(f"world_size must be positive, got {world_size}.")
+        if rank < 0 or rank >= world_size:
+            raise ValueError(f"rank must satisfy 0 <= rank < world_size, got {rank}/{world_size}.")
+
+        self.rank = int(rank)
+        self.world_size = int(world_size)
+        self.indices = list(range(len(dataset)))[self.rank :: self.world_size]
+
+    def __iter__(self):
+        return iter(self.indices)
+
+    def __len__(self):
+        return len(self.indices)
+
+
 def collate_dict(
     batch: List[Dict[str, Any]],
 ) -> Dict[str, Any]:
@@ -226,6 +270,7 @@ def collate_dict_list( |

 __all__ = [
     "ConnectomicsDataModule",
+    "DistributedEvaluationSampler",
     "SimpleDataModule",
     "collate_dict",
     "collate_dict_list",
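
A quick single-process check of the sharding behavior, with `DistributedEvaluationSampler` from this module in scope: passing `rank` and `world_size` explicitly skips the process-group requirement, and the ten-element list below is a hypothetical stand-in for a real dataset (anything with `__len__` works).

```python
dataset = list(range(10))  # hypothetical stand-in for the test dataset

for rank in range(3):
    sampler = DistributedEvaluationSampler(dataset, rank=rank, world_size=3)
    print(rank, list(sampler))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7]
# 2 [2, 5, 8]
```

Unlike `torch.utils.data.distributed.DistributedSampler`, which by default pads so every rank sees the same number of samples, shards here can have unequal lengths (rank 0 gets four samples above), so cross-rank metric reductions should weight by per-rank sample counts rather than assume equal shards.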