Skip to content

Commit 477ad5d

Browse files
authored
add isinstance check of Sampler to FileSampler in MultiSampler class (#177)
* switch check of Sampler to FileSampler * add back the Sampler check
1 parent 4313219 commit 477ad5d

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

selene_sdk/samplers/multi_sampler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch.utils.data import DataLoader
99

1010
from .sampler import Sampler
11+
from .file_samplers import FileSampler
1112

1213

1314
def MultiFileSampler(*args, **kwargs):
@@ -88,11 +89,12 @@ def __init__(self,
8889
features,
8990
save_datasets=save_datasets,
9091
output_dir=output_dir)
91-
9292
self._samplers = {
93-
"train": train_sampler if isinstance(train_sampler, Sampler) \
93+
"train": train_sampler if (isinstance(train_sampler, FileSampler) or
94+
isinstance(train_sampler, Sampler)) \
9495
else None,
95-
"validate": validate_sampler if isinstance(validate_sampler, Sampler) \
96+
"validate": validate_sampler if (isinstance(validate_sampler, FileSampler) or
97+
isinstance(validate_sampler, Sampler)) \
9698
else None
9799
}
98100

@@ -115,7 +117,8 @@ def __init__(self,
115117
if test_sampler is not None:
116118
self.modes.append("test")
117119
self._samplers["test"] = \
118-
test_sampler if isinstance(test_sampler, Sampler) else None
120+
test_sampler if (isinstance(test_sampler, FileSampler) or
121+
isinstance(test_sampler, Sampler)) else None
119122
self._dataloaders["test"] = \
120123
test_sampler if isinstance(test_sampler, DataLoader) else None
121124
self._iterators["test"] = iter(self._dataloaders["test"]) \

0 commit comments

Comments
 (0)