We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 91c6b62 commit 102265aCopy full SHA for 102265a
1 file changed
src/gnn_reco/models/training/utils.py
@@ -34,12 +34,17 @@ def make_dataloader(
34
selection=selection,
35
)
36
37
+ def collate_fn(graphs):
38
+ # Remove graphs with less than two DOM hits. Should not occur in "production."
39
+ graphs = [g for g in graphs if g.n_pulses > 1]
40
+ return Batch.from_data_list(graphs)
41
+
42
dataloader = DataLoader(
43
dataset,
44
batch_size=batch_size,
45
shuffle=shuffle,
46
num_workers=num_workers,
- collate_fn=Batch.from_data_list,
47
+ collate_fn=collate_fn,
48
persistent_workers=persistent_workers,
49
prefetch_factor=2,
50
0 commit comments