Skip to content

Commit cd418c7

Browse files
authored
Merge pull request #16 from KMCzajkowski/main
rename csv path var name to match train function
2 parents 1145efa + acea20c commit cd418c7

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

timm/data/loader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def create_loader(
229229
worker_seeding: str = 'all',
230230
tf_preprocessing: bool = False,
231231
balance_classes: bool = False,
232-
dataset_csv_path: Optional[str] = None
232+
samples_csv_path: Optional[str] = None
233233
):
234234
"""
235235
@@ -274,7 +274,7 @@ def create_loader(
274274
worker_seeding: Control worker random seeding at init.
275275
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
276276
balance_classes: Sample classes with uniform probability
277-
dataset_csv_path: Path to dataset csv, used for class balancing
277+
samples_csv_path: Path to dataset csv, used for class balancing
278278
279279
Returns:
280280
DataLoader
@@ -333,9 +333,9 @@ def create_loader(
333333
else:
334334
assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
335335
if balance_classes:
336-
assert dataset_csv_path, "Provide csv with labels to use balance_classes."
337-
dataset_csv = pd.read_csv(dataset_csv_path)
338-
all_labels = dataset_csv["label"].values
336+
assert samples_csv_path, "Provide csv with labels to use balance_classes."
337+
samples_csv = pd.read_csv(samples_csv_path)
338+
all_labels = samples_csv["label"].values
339339
unique, counts = np.unique(all_labels, return_counts=True)
340340
unique_counts = {v: c for v, c in zip(unique, counts)}
341341
label_weights = np.array([1 / unique_counts[num] for num in all_labels])

0 commit comments

Comments
 (0)