Skip to content

Commit 99543a7

Browse files
committed
Lint.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent a502309 commit 99543a7

7 files changed

Lines changed: 130 additions & 308 deletions

File tree

recipes/vit/.ruff.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
extend = "../.ruff.toml"
22
[lint]
3-
per-file-ignores = { "tokenizer_auto" = ["ALL"] }
4-
ignore = ["RUF","D","N","E","PLW","PERF","C","F"]
3+
ignore = ["D", "N", "C901", "PLW2901"]

recipes/vit/config/vit_te_base_patch16_224.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ model:
99
training:
1010
checkpoint:
1111
path: "./checkpoints/vit_te"
12-
resume_from_metric: "-" # + = Highest Metric (Score), - = Lowest Metric (Loss)
12+
resume_from_metric: "-" # + = Highest Metric (Score), - = Lowest Metric (Loss)

recipes/vit/imagenet_dataset.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def natural_key(string_):
5151

5252

5353
def load_class_map(map_or_filename: str):
54-
"""
55-
Parse a TSV or PKL file that contains a list of class IDs. Then create a class-to-index mapping
54+
"""Parse a TSV or PKL file that contains a list of class IDs. Then create a class-to-index mapping
5655
where the enumerated index will represent the class index when computing the cross-entropy loss.
5756
5857
Args:
@@ -82,8 +81,7 @@ def load_class_map(map_or_filename: str):
8281

8382

8483
def load_image_labels(map_or_filename: str):
85-
"""
86-
Parse a TSV or PKL file that maps image filenames to class IDs.
84+
"""Parse a TSV or PKL file that maps image filenames to class IDs.
8785
8886
Args:
8987
map_or_filename (str): Path to a TSV or PKL file that maps image filenames to class IDs.
@@ -115,8 +113,7 @@ def find_images_and_targets(
115113
sort: bool = True,
116114
class_filter: Optional[List[Any]] = None,
117115
):
118-
"""
119-
Walk folder recursively to discover images and map them to classes by folder names.
116+
"""Walk folder recursively to discover images and map them to classes by folder names.
120117
121118
Args:
122119
folder: root of folder to recursively search
@@ -147,12 +144,12 @@ def find_images_and_targets(
147144
if class_to_idx is None:
148145
# building class index
149146
unique_labels = set(labels)
150-
sorted_labels = list(sorted(unique_labels, key=natural_key))
147+
sorted_labels = sorted(unique_labels, key=natural_key)
151148
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
152149
images_and_targets = [
153-
(f, class_to_idx[l])
154-
for f, l in zip(filenames, labels)
155-
if l in class_to_idx and (class_filter is None or l in class_filter)
150+
(files, class_to_idx[labels])
151+
for files, labels in zip(filenames, labels)
152+
if labels in class_to_idx and (class_filter is None or labels in class_filter)
156153
]
157154
if sort:
158155
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
@@ -214,15 +211,11 @@ def filename(self, index, basename=False, absolute=False):
214211
return self._filename(index, basename=basename, absolute=absolute)
215212

216213
def filenames(self, basename=False, absolute=False):
217-
return [
218-
self._filename(index, basename=basename, absolute=absolute)
219-
for index in range(len(self))
220-
]
214+
return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
221215

222216

223217
class ImageNetDataset(Dataset):
224-
"""
225-
ImageDataset class for loading image datasets from a root directory.
218+
"""ImageNetDataset class for loading image datasets from a root directory.
226219
227220
Expects the following directory structure:
228221
@@ -281,9 +274,7 @@ def __getitem__(self, index):
281274
try:
282275
img = img.read() if self.load_bytes else Image.open(img)
283276
except Exception as e:
284-
_logger.warning(
285-
f"Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}"
286-
)
277+
_logger.warning(f"Skipped sample (index {index}, file {self.reader.filename(index)}). {e!s}")
287278
self._consecutive_errors += 1
288279
if self._consecutive_errors < 50:
289280
return self.__getitem__((index + 1) % len(self.reader))

recipes/vit/imagenet_utils.py

Lines changed: 40 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ def __init__(self) -> None:
7171
super().__init__()
7272

7373
def __call__(self, pic) -> torch.Tensor:
74-
"""
75-
Args:
74+
"""Args:
7675
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
7776
7877
Returns:
@@ -93,8 +92,7 @@ def __init__(self) -> None:
9392
super().__init__()
9493

9594
def __call__(self, pic):
96-
"""
97-
Note: A deep copy of the underlying array is performed.
95+
"""Note: A deep copy of the underlying array is performed.
9896
9997
Args:
10098
pic (PIL Image): Image to be converted to tensor.
@@ -185,17 +183,15 @@ def __init__(
185183
random_aspect_prob=0.0,
186184
random_aspect_range=(0.9, 1.11),
187185
):
188-
"""
189-
190-
Args:
191-
size:
192-
longest:
193-
interpolation:
194-
random_scale_prob:
195-
random_scale_range:
196-
random_scale_area:
197-
random_aspect_prob:
198-
random_aspect_range:
186+
"""Args:
187+
size:
188+
longest:
189+
interpolation:
190+
random_scale_prob:
191+
random_scale_range:
192+
random_scale_area:
193+
random_aspect_prob:
194+
random_aspect_range:
199195
"""
200196
if isinstance(size, (list, tuple)):
201197
self.size = tuple(size)
@@ -228,9 +224,7 @@ def get_params(
228224
target_h, target_w = target_size
229225
ratio_h = img_h / target_h
230226
ratio_w = img_w / target_w
231-
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
232-
1.0 - longest
233-
)
227+
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1.0 - longest)
234228

235229
if random_scale_prob > 0 and random.random() < random_scale_prob:
236230
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
@@ -260,8 +254,7 @@ def get_params(
260254
return size
261255

262256
def __call__(self, img):
263-
"""
264-
Args:
257+
"""Args:
265258
img (PIL Image): Image to be cropped and resized.
266259
267260
Returns:
@@ -286,9 +279,7 @@ def __call__(self, img):
286279

287280
def __repr__(self):
288281
if isinstance(self.interpolation, (tuple, list)):
289-
interpolate_str = " ".join(
290-
[interp_mode_to_str(x) for x in self.interpolation]
291-
)
282+
interpolate_str = " ".join([interp_mode_to_str(x) for x in self.interpolation])
292283
else:
293284
interpolate_str = interp_mode_to_str(self.interpolation)
294285
format_string = self.__class__.__name__ + "(size={0}".format(self.size)
@@ -297,7 +288,9 @@ def __repr__(self):
297288
format_string += f", random_scale_prob={self.random_scale_prob:.3f}"
298289
format_string += f", random_scale_range=({self.random_scale_range[0]:.3f}, {self.random_scale_range[1]:.3f})"
299290
format_string += f", random_aspect_prob={self.random_aspect_prob:.3f}"
300-
format_string += f", random_aspect_range=({self.random_aspect_range[0]:.3f}, {self.random_aspect_range[1]:.3f}))"
291+
format_string += (
292+
f", random_aspect_range=({self.random_aspect_range[0]:.3f}, {self.random_aspect_range[1]:.3f}))"
293+
)
301294
return format_string
302295

303296

@@ -371,8 +364,8 @@ def get_params(img, scale, ratio):
371364
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
372365
aspect_ratio = math.exp(random.uniform(*log_ratio))
373366

374-
target_w = int(round(math.sqrt(target_area * aspect_ratio)))
375-
target_h = int(round(math.sqrt(target_area / aspect_ratio)))
367+
target_w = round(math.sqrt(target_area * aspect_ratio))
368+
target_h = round(math.sqrt(target_area / aspect_ratio))
376369
if target_w <= img_w and target_h <= img_h:
377370
i = random.randint(0, img_h - target_h)
378371
j = random.randint(0, img_w - target_w)
@@ -382,10 +375,10 @@ def get_params(img, scale, ratio):
382375
in_ratio = img_w / img_h
383376
if in_ratio < min(ratio):
384377
target_w = img_w
385-
target_h = int(round(target_w / min(ratio)))
378+
target_h = round(target_w / min(ratio))
386379
elif in_ratio > max(ratio):
387380
target_h = img_h
388-
target_w = int(round(target_h * max(ratio)))
381+
target_w = round(target_h * max(ratio))
389382
else: # whole image
390383
target_w = img_w
391384
target_h = img_h
@@ -394,8 +387,7 @@ def get_params(img, scale, ratio):
394387
return i, j, target_h, target_w
395388

396389
def __call__(self, img):
397-
"""
398-
Args:
390+
"""Args:
399391
img (PIL Image): Image to be cropped and resized.
400392
401393
Returns:
@@ -410,9 +402,7 @@ def __call__(self, img):
410402

411403
def __repr__(self):
412404
if isinstance(self.interpolation, (tuple, list)):
413-
interpolate_str = " ".join(
414-
[interp_mode_to_str(x) for x in self.interpolation]
415-
)
405+
interpolate_str = " ".join([interp_mode_to_str(x) for x in self.interpolation])
416406
else:
417407
interpolate_str = interp_mode_to_str(self.interpolation)
418408
format_string = self.__class__.__name__ + "(size={0}".format(self.size)
@@ -459,8 +449,8 @@ def center_crop_or_pad(
459449
if crop_width == image_width and crop_height == image_height:
460450
return img
461451

462-
crop_top = int(round((image_height - crop_height) / 2.0))
463-
crop_left = int(round((image_width - crop_width) / 2.0))
452+
crop_top = round((image_height - crop_height) / 2.0)
453+
crop_left = round((image_width - crop_width) / 2.0)
464454
return F.crop(img, crop_top, crop_left, crop_height, crop_width)
465455

466456

@@ -488,16 +478,13 @@ def __init__(
488478
self.padding_mode = padding_mode
489479

490480
def forward(self, img):
491-
"""
492-
Args:
481+
"""Args:
493482
img (PIL Image or Tensor): Image to be cropped.
494483
495484
Returns:
496485
PIL Image or Tensor: Cropped image.
497486
"""
498-
return center_crop_or_pad(
499-
img, self.size, fill=self.fill, padding_mode=self.padding_mode
500-
)
487+
return center_crop_or_pad(img, self.size, fill=self.fill, padding_mode=self.padding_mode)
501488

502489
def __repr__(self) -> str:
503490
return f"{self.__class__.__name__}(size={self.size})"
@@ -554,8 +541,7 @@ def get_params(img, size):
554541
return top, left
555542

556543
def forward(self, img):
557-
"""
558-
Args:
544+
"""Args:
559545
img (PIL Image or Tensor): Image to be cropped.
560546
561547
Returns:
@@ -595,6 +581,7 @@ class RandomErasing:
595581
596582
This variant of RandomErasing is intended to be applied to either a batch
597583
or single image tensor after it has been normalized by dataset mean and std.
584+
598585
Args:
599586
probability: Probability that the Random Erasing operation will be performed.
600587
min_area: Minimum percentage of erased area wrt input image area.
@@ -644,19 +631,13 @@ def _erase(self, img, chan, img_h, img_w, dtype):
644631
if random.random() > self.probability:
645632
return
646633
area = img_h * img_w
647-
count = (
648-
self.min_count
649-
if self.min_count == self.max_count
650-
else random.randint(self.min_count, self.max_count)
651-
)
634+
count = self.min_count if self.min_count == self.max_count else random.randint(self.min_count, self.max_count)
652635
for _ in range(count):
653636
for attempt in range(10):
654-
target_area = (
655-
random.uniform(self.min_area, self.max_area) * area / count
656-
)
637+
target_area = random.uniform(self.min_area, self.max_area) * area / count
657638
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
658-
h = int(round(math.sqrt(target_area * aspect_ratio)))
659-
w = int(round(math.sqrt(target_area / aspect_ratio)))
639+
h = round(math.sqrt(target_area * aspect_ratio))
640+
w = round(math.sqrt(target_area / aspect_ratio))
660641
if w < img_w and h < img_h:
661642
top = random.randint(0, img_h - h)
662643
left = random.randint(0, img_w - w)
@@ -709,11 +690,7 @@ def patchify_image(
709690
# Reshape image to patches
710691
patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0)
711692
# [nh, nw, ph, pw, c] -> [nh * nw, ph * pw * c] or [nh * nw, ph, pw, c]
712-
patches = (
713-
patches.reshape(-1, ph * pw * c)
714-
if flatten_patches
715-
else patches.reshape(-1, ph, pw, c)
716-
)
693+
patches = patches.reshape(-1, ph * pw * c) if flatten_patches else patches.reshape(-1, ph, pw, c)
717694

718695
if include_info:
719696
# Create coordinate indices
@@ -730,18 +707,13 @@ def patchify_image(
730707
class Patchify(torch.nn.Module):
731708
"""Transform an image into patches with corresponding coordinates and type indicators."""
732709

733-
def __init__(
734-
self, patch_size: Union[int, Tuple[int, int]], flatten_patches: bool = True
735-
):
710+
def __init__(self, patch_size: Union[int, Tuple[int, int]], flatten_patches: bool = True):
736711
super().__init__()
737-
self.patch_size = (
738-
patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
739-
)
712+
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
740713
self.flatten_patches = flatten_patches
741714

742715
def forward(self, img):
743-
"""
744-
Args:
716+
"""Args:
745717
img: A PIL Image or tensor of shape [C, H, W]
746718
747719
Returns:
@@ -755,9 +727,7 @@ def forward(self, img):
755727
# Convert PIL Image to tensor [C, H, W]
756728
img = transforms.functional.to_tensor(img)
757729

758-
patches, coord, valid = patchify_image(
759-
img, self.patch_size, flatten_patches=self.flatten_patches
760-
)
730+
patches, coord, valid = patchify_image(img, self.patch_size, flatten_patches=self.flatten_patches)
761731

762732
return {
763733
"patches": patches,
@@ -1005,9 +975,7 @@ def transforms_imagenet_eval(
1005975
# squash mode scales each edge to 1/pct of target, then crops
1006976
# aspect ratio is not preserved, no img lost if crop_pct == 1.0
1007977
tfl += [
1008-
transforms.Resize(
1009-
scale_size, interpolation=str_to_interp_mode(interpolation)
1010-
),
978+
transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
1011979
transforms.CenterCrop(img_size),
1012980
]
1013981
elif crop_mode == "border":
@@ -1023,11 +991,7 @@ def transforms_imagenet_eval(
1023991
# aspect ratio is preserved, crops center within image, no borders are added, image is lost
1024992
if scale_size[0] == scale_size[1]:
1025993
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
1026-
tfl += [
1027-
transforms.Resize(
1028-
scale_size[0], interpolation=str_to_interp_mode(interpolation)
1029-
)
1030-
]
994+
tfl += [transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))]
1031995
else:
1032996
# resize the shortest edge to matching target dim for non-square target
1033997
tfl += [ResizeKeepRatio(scale_size)]

0 commit comments

Comments
 (0)