@@ -229,7 +229,7 @@ def create_loader(
229229 worker_seeding : str = 'all' ,
230230 tf_preprocessing : bool = False ,
231231 balance_classes : bool = False ,
232- dataset_csv_path : Optional [str ] = None
232+ samples_csv_path : Optional [str ] = None
233233):
234234 """
235235
@@ -274,7 +274,7 @@ def create_loader(
274274 worker_seeding: Control worker random seeding at init.
275275 tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
276276 balance_classes: Sample classes with uniform probability
277- dataset_csv_path : Path to dataset csv, used for class balancing
277+ samples_csv_path : Path to dataset csv, used for class balancing
278278
279279 Returns:
280280 DataLoader
@@ -333,9 +333,9 @@ def create_loader(
333333 else :
334334 assert num_aug_repeats == 0 , "RepeatAugment not currently supported in non-distributed or IterableDataset use"
335335 if balance_classes :
336- assert dataset_csv_path , "Provide csv with labels to use balance_classes."
337- dataset_csv = pd .read_csv (dataset_csv_path )
338- all_labels = dataset_csv ["label" ].values
336+ assert samples_csv_path , "Provide csv with labels to use balance_classes."
337+ samples_csv = pd .read_csv (samples_csv_path )
338+ all_labels = samples_csv ["label" ].values
339339 unique , counts = np .unique (all_labels , return_counts = True )
340340 unique_counts = {v : c for v , c in zip (unique , counts )}
341341 label_weights = np .array ([1 / unique_counts [num ] for num in all_labels ])
0 commit comments