@@ -36,7 +36,7 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
3636 CenterCrop (crop_size ),
3737 ])
3838 train_transforms .extend ([
39- RandomHorizontalFlip (),
39+ RandomHorizontalFlip (),
4040 ColorJitter (0.4 , 0.4 , 0.4 ),
4141 ToTensor (),
4242 Lighting (0.1 , _imagenet_pca ['eigval' ], _imagenet_pca ['eigvec' ]),
@@ -65,16 +65,16 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
6565 normalize ,
6666 ])
6767 elif dataset == 'cifar10' :
68- transform_train = transforms . Compose ([
69- transforms . RandomCrop (32 , padding = 4 ),
70- transforms . RandomHorizontalFlip (),
71- transforms . ToTensor (),
72- transforms . Normalize ((0.4914 , 0.4822 , 0.4465 ),
73- (0.2023 , 0.1994 , 0.2010 )),
68+ transform_train = Compose ([
69+ RandomCrop (32 , padding = 4 ),
70+ RandomHorizontalFlip (),
71+ ToTensor (),
72+ Normalize ((0.4914 , 0.4822 , 0.4465 ),
73+ (0.2023 , 0.1994 , 0.2010 )),
7474 ])
75- transform_val = transforms . Compose ([
76- transforms . ToTensor (),
77- transforms . Normalize ((0.4914 , 0.4822 , 0.4465 ),
75+ transform_val = Compose ([
76+ ToTensor (),
77+ Normalize ((0.4914 , 0.4822 , 0.4465 ),
7878 (0.2023 , 0.1994 , 0.2010 )),
7979 ])
8080 return transform_train , transform_val
0 commit comments