88from torch .utils .data import DataLoader
99
1010from .sampler import Sampler
11+ from .file_samplers import FileSampler
1112
1213
1314def MultiFileSampler (* args , ** kwargs ):
@@ -88,11 +89,12 @@ def __init__(self,
8889 features ,
8990 save_datasets = save_datasets ,
9091 output_dir = output_dir )
91-
9292 self ._samplers = {
93- "train" : train_sampler if isinstance (train_sampler , Sampler ) \
93+ "train" : train_sampler if (isinstance (train_sampler , FileSampler ) or
94+ isinstance (train_sampler , Sampler )) \
9495 else None ,
95- "validate" : validate_sampler if isinstance (validate_sampler , Sampler ) \
96+ "validate" : validate_sampler if (isinstance (validate_sampler , FileSampler ) or
97+ isinstance (validate_sampler , Sampler )) \
9698 else None
9799 }
98100
@@ -115,7 +117,8 @@ def __init__(self,
115117 if test_sampler is not None :
116118 self .modes .append ("test" )
117119 self ._samplers ["test" ] = \
118- test_sampler if isinstance (test_sampler , Sampler ) else None
120+ test_sampler if (isinstance (test_sampler , FileSampler ) or
121+ isinstance (test_sampler , Sampler )) else None
119122 self ._dataloaders ["test" ] = \
120123 test_sampler if isinstance (test_sampler , DataLoader ) else None
121124 self ._iterators ["test" ] = iter (self ._dataloaders ["test" ]) \
0 commit comments