@@ -71,8 +71,7 @@ def __init__(self) -> None:
7171 super ().__init__ ()
7272
7373 def __call__ (self , pic ) -> torch .Tensor :
74- """
75- Args:
74+ """Args:
7675 pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
7776
7877 Returns:
@@ -93,8 +92,7 @@ def __init__(self) -> None:
9392 super ().__init__ ()
9493
9594 def __call__ (self , pic ):
96- """
97- Note: A deep copy of the underlying array is performed.
95+ """Note: A deep copy of the underlying array is performed.
9896
9997 Args:
10098 pic (PIL Image): Image to be converted to tensor.
@@ -185,17 +183,15 @@ def __init__(
185183 random_aspect_prob = 0.0 ,
186184 random_aspect_range = (0.9 , 1.11 ),
187185 ):
188- """
189-
190- Args:
191- size:
192- longest:
193- interpolation:
194- random_scale_prob:
195- random_scale_range:
196- random_scale_area:
197- random_aspect_prob:
198- random_aspect_range:
186+ """Args:
187+ size:
188+ longest:
189+ interpolation:
190+ random_scale_prob:
191+ random_scale_range:
192+ random_scale_area:
193+ random_aspect_prob:
194+ random_aspect_range:
199195 """
200196 if isinstance (size , (list , tuple )):
201197 self .size = tuple (size )
@@ -228,9 +224,7 @@ def get_params(
228224 target_h , target_w = target_size
229225 ratio_h = img_h / target_h
230226 ratio_w = img_w / target_w
231- ratio = max (ratio_h , ratio_w ) * longest + min (ratio_h , ratio_w ) * (
232- 1.0 - longest
233- )
227+ ratio = max (ratio_h , ratio_w ) * longest + min (ratio_h , ratio_w ) * (1.0 - longest )
234228
235229 if random_scale_prob > 0 and random .random () < random_scale_prob :
236230 ratio_factor = random .uniform (random_scale_range [0 ], random_scale_range [1 ])
@@ -260,8 +254,7 @@ def get_params(
260254 return size
261255
262256 def __call__ (self , img ):
263- """
264- Args:
257+ """Args:
265258 img (PIL Image): Image to be cropped and resized.
266259
267260 Returns:
@@ -286,9 +279,7 @@ def __call__(self, img):
286279
287280 def __repr__ (self ):
288281 if isinstance (self .interpolation , (tuple , list )):
289- interpolate_str = " " .join (
290- [interp_mode_to_str (x ) for x in self .interpolation ]
291- )
282+ interpolate_str = " " .join ([interp_mode_to_str (x ) for x in self .interpolation ])
292283 else :
293284 interpolate_str = interp_mode_to_str (self .interpolation )
294285 format_string = self .__class__ .__name__ + "(size={0}" .format (self .size )
@@ -297,7 +288,9 @@ def __repr__(self):
297288 format_string += f", random_scale_prob={ self .random_scale_prob :.3f} "
298289 format_string += f", random_scale_range=({ self .random_scale_range [0 ]:.3f} , { self .random_aspect_range [1 ]:.3f} )"
299290 format_string += f", random_aspect_prob={ self .random_aspect_prob :.3f} "
300- format_string += f", random_aspect_range=({ self .random_aspect_range [0 ]:.3f} , { self .random_aspect_range [1 ]:.3f} ))"
291+ format_string += (
292+ f", random_aspect_range=({ self .random_aspect_range [0 ]:.3f} , { self .random_aspect_range [1 ]:.3f} ))"
293+ )
301294 return format_string
302295
303296
@@ -371,8 +364,8 @@ def get_params(img, scale, ratio):
371364 log_ratio = (math .log (ratio [0 ]), math .log (ratio [1 ]))
372365 aspect_ratio = math .exp (random .uniform (* log_ratio ))
373366
374- target_w = int ( round (math .sqrt (target_area * aspect_ratio ) ))
375- target_h = int ( round (math .sqrt (target_area / aspect_ratio ) ))
367+ target_w = round (math .sqrt (target_area * aspect_ratio ))
368+ target_h = round (math .sqrt (target_area / aspect_ratio ))
376369 if target_w <= img_w and target_h <= img_h :
377370 i = random .randint (0 , img_h - target_h )
378371 j = random .randint (0 , img_w - target_w )
@@ -382,10 +375,10 @@ def get_params(img, scale, ratio):
382375 in_ratio = img_w / img_h
383376 if in_ratio < min (ratio ):
384377 target_w = img_w
385- target_h = int ( round (target_w / min (ratio ) ))
378+ target_h = round (target_w / min (ratio ))
386379 elif in_ratio > max (ratio ):
387380 target_h = img_h
388- target_w = int ( round (target_h * max (ratio ) ))
381+ target_w = round (target_h * max (ratio ))
389382 else : # whole image
390383 target_w = img_w
391384 target_h = img_h
@@ -394,8 +387,7 @@ def get_params(img, scale, ratio):
394387 return i , j , target_h , target_w
395388
396389 def __call__ (self , img ):
397- """
398- Args:
390+ """Args:
399391 img (PIL Image): Image to be cropped and resized.
400392
401393 Returns:
@@ -410,9 +402,7 @@ def __call__(self, img):
410402
411403 def __repr__ (self ):
412404 if isinstance (self .interpolation , (tuple , list )):
413- interpolate_str = " " .join (
414- [interp_mode_to_str (x ) for x in self .interpolation ]
415- )
405+ interpolate_str = " " .join ([interp_mode_to_str (x ) for x in self .interpolation ])
416406 else :
417407 interpolate_str = interp_mode_to_str (self .interpolation )
418408 format_string = self .__class__ .__name__ + "(size={0}" .format (self .size )
@@ -459,8 +449,8 @@ def center_crop_or_pad(
459449 if crop_width == image_width and crop_height == image_height :
460450 return img
461451
462- crop_top = int ( round ((image_height - crop_height ) / 2.0 ) )
463- crop_left = int ( round ((image_width - crop_width ) / 2.0 ) )
452+ crop_top = round ((image_height - crop_height ) / 2.0 )
453+ crop_left = round ((image_width - crop_width ) / 2.0 )
464454 return F .crop (img , crop_top , crop_left , crop_height , crop_width )
465455
466456
@@ -488,16 +478,13 @@ def __init__(
488478 self .padding_mode = padding_mode
489479
490480 def forward (self , img ):
491- """
492- Args:
481+ """Args:
493482 img (PIL Image or Tensor): Image to be cropped.
494483
495484 Returns:
496485 PIL Image or Tensor: Cropped image.
497486 """
498- return center_crop_or_pad (
499- img , self .size , fill = self .fill , padding_mode = self .padding_mode
500- )
487+ return center_crop_or_pad (img , self .size , fill = self .fill , padding_mode = self .padding_mode )
501488
502489 def __repr__ (self ) -> str :
503490 return f"{ self .__class__ .__name__ } (size={ self .size } )"
@@ -554,8 +541,7 @@ def get_params(img, size):
554541 return top , left
555542
556543 def forward (self , img ):
557- """
558- Args:
544+ """Args:
559545 img (PIL Image or Tensor): Image to be cropped.
560546
561547 Returns:
@@ -595,6 +581,7 @@ class RandomErasing:
595581
596582 This variant of RandomErasing is intended to be applied to either a batch
597583 or single image tensor after it has been normalized by dataset mean and std.
584+
598585 Args:
599586 probability: Probability that the Random Erasing operation will be performed.
600587 min_area: Minimum percentage of erased area wrt input image area.
@@ -644,19 +631,13 @@ def _erase(self, img, chan, img_h, img_w, dtype):
644631 if random .random () > self .probability :
645632 return
646633 area = img_h * img_w
647- count = (
648- self .min_count
649- if self .min_count == self .max_count
650- else random .randint (self .min_count , self .max_count )
651- )
634+ count = self .min_count if self .min_count == self .max_count else random .randint (self .min_count , self .max_count )
652635 for _ in range (count ):
653636 for attempt in range (10 ):
654- target_area = (
655- random .uniform (self .min_area , self .max_area ) * area / count
656- )
637+ target_area = random .uniform (self .min_area , self .max_area ) * area / count
657638 aspect_ratio = math .exp (random .uniform (* self .log_aspect_ratio ))
658- h = int ( round (math .sqrt (target_area * aspect_ratio ) ))
659- w = int ( round (math .sqrt (target_area / aspect_ratio ) ))
639+ h = round (math .sqrt (target_area * aspect_ratio ))
640+ w = round (math .sqrt (target_area / aspect_ratio ))
660641 if w < img_w and h < img_h :
661642 top = random .randint (0 , img_h - h )
662643 left = random .randint (0 , img_w - w )
@@ -709,11 +690,7 @@ def patchify_image(
709690 # Reshape image to patches
710691 patches = img .view (c , nh , ph , nw , pw ).permute (1 , 3 , 2 , 4 , 0 )
711692 # [nh, nw, ph, pw, c] -> [nh * nw, ph * pw * c] or [nh * nw, ph, pw, c]
712- patches = (
713- patches .reshape (- 1 , ph * pw * c )
714- if flatten_patches
715- else patches .reshape (- 1 , ph , pw , c )
716- )
693+ patches = patches .reshape (- 1 , ph * pw * c ) if flatten_patches else patches .reshape (- 1 , ph , pw , c )
717694
718695 if include_info :
719696 # Create coordinate indices
@@ -730,18 +707,13 @@ def patchify_image(
730707class Patchify (torch .nn .Module ):
731708 """Transform an image into patches with corresponding coordinates and type indicators."""
732709
733- def __init__ (
734- self , patch_size : Union [int , Tuple [int , int ]], flatten_patches : bool = True
735- ):
710+ def __init__ (self , patch_size : Union [int , Tuple [int , int ]], flatten_patches : bool = True ):
736711 super ().__init__ ()
737- self .patch_size = (
738- patch_size if isinstance (patch_size , tuple ) else (patch_size , patch_size )
739- )
712+ self .patch_size = patch_size if isinstance (patch_size , tuple ) else (patch_size , patch_size )
740713 self .flatten_patches = flatten_patches
741714
742715 def forward (self , img ):
743- """
744- Args:
716+ """Args:
745717 img: A PIL Image or tensor of shape [C, H, W]
746718
747719 Returns:
@@ -755,9 +727,7 @@ def forward(self, img):
755727 # Convert PIL Image to tensor [C, H, W]
756728 img = transforms .functional .to_tensor (img )
757729
758- patches , coord , valid = patchify_image (
759- img , self .patch_size , flatten_patches = self .flatten_patches
760- )
730+ patches , coord , valid = patchify_image (img , self .patch_size , flatten_patches = self .flatten_patches )
761731
762732 return {
763733 "patches" : patches ,
@@ -1005,9 +975,7 @@ def transforms_imagenet_eval(
1005975 # squash mode scales each edge to 1/pct of target, then crops
1006976 # aspect ratio is not preserved, no img lost if crop_pct == 1.0
1007977 tfl += [
1008- transforms .Resize (
1009- scale_size , interpolation = str_to_interp_mode (interpolation )
1010- ),
978+ transforms .Resize (scale_size , interpolation = str_to_interp_mode (interpolation )),
1011979 transforms .CenterCrop (img_size ),
1012980 ]
1013981 elif crop_mode == "border" :
@@ -1023,11 +991,7 @@ def transforms_imagenet_eval(
1023991 # aspect ratio is preserved, crops center within image, no borders are added, image is lost
1024992 if scale_size [0 ] == scale_size [1 ]:
1025993 # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
1026- tfl += [
1027- transforms .Resize (
1028- scale_size [0 ], interpolation = str_to_interp_mode (interpolation )
1029- )
1030- ]
994+ tfl += [transforms .Resize (scale_size [0 ], interpolation = str_to_interp_mode (interpolation ))]
1031995 else :
1032996 # resize the shortest edge to matching target dim for non-square target
1033997 tfl += [ResizeKeepRatio (scale_size )]
0 commit comments