Skip to content

Commit 102265a

Browse files
committed
Update DataLoader to ignore graphs with < 2 DOM hits
1 parent 91c6b62 commit 102265a

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

src/gnn_reco/models/training/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@ def make_dataloader(
3434
selection=selection,
3535
)
3636

37+
def collate_fn(graphs):
38+
# Remove graphs with less than two DOM hits. Should not occur in "production."
39+
graphs = [g for g in graphs if g.n_pulses > 1]
40+
return Batch.from_data_list(graphs)
41+
3742
dataloader = DataLoader(
3843
dataset,
3944
batch_size=batch_size,
4045
shuffle=shuffle,
4146
num_workers=num_workers,
42-
collate_fn=Batch.from_data_list,
47+
collate_fn=collate_fn,
4348
persistent_workers=persistent_workers,
4449
prefetch_factor=2,
4550
)

0 commit comments

Comments
 (0)