Skip to content

Commit b8d83b0

Browse files
transforms (#272)
1 parent f70fa97 commit b8d83b0

3 files changed

Lines changed: 18 additions & 15 deletions

File tree

encoding/lib/cpu/operator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ py::array_t<float> apply_transform(int H, int W, int C, py::array_t<float> img,
9393
auto ctm_buf = ctm.request();
9494

9595
// printf("H: %d, W: %d, C: %d\n", H, W, C);
96-
py::array_t<float> result{img_buf.size};
96+
py::array_t<float> result{(unsigned long)img_buf.size};
9797
auto res_buf = result.request();
9898

9999
float *img_ptr = (float *)img_buf.ptr;

encoding/transforms/get_transform.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
3636
CenterCrop(crop_size),
3737
])
3838
train_transforms.extend([
39-
RandomHorizontalFlip(),
39+
RandomHorizontalFlip(),
4040
ColorJitter(0.4, 0.4, 0.4),
4141
ToTensor(),
4242
Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
@@ -65,16 +65,16 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
6565
normalize,
6666
])
6767
elif dataset == 'cifar10':
68-
transform_train = transforms.Compose([
69-
transforms.RandomCrop(32, padding=4),
70-
transforms.RandomHorizontalFlip(),
71-
transforms.ToTensor(),
72-
transforms.Normalize((0.4914, 0.4822, 0.4465),
73-
(0.2023, 0.1994, 0.2010)),
68+
transform_train = Compose([
69+
RandomCrop(32, padding=4),
70+
RandomHorizontalFlip(),
71+
ToTensor(),
72+
Normalize((0.4914, 0.4822, 0.4465),
73+
(0.2023, 0.1994, 0.2010)),
7474
])
75-
transform_val = transforms.Compose([
76-
transforms.ToTensor(),
77-
transforms.Normalize((0.4914, 0.4822, 0.4465),
75+
transform_val = Compose([
76+
ToTensor(),
77+
Normalize((0.4914, 0.4822, 0.4465),
7878
(0.2023, 0.1994, 0.2010)),
7979
])
8080
return transform_train, transform_val

encoding/utils/lr_scheduler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ class LR_Scheduler(object):
2929
iters_per_epoch: number of iterations per epoch
3030
"""
3131
def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
32-
lr_step=0, warmup_epochs=0):
32+
lr_step=0, warmup_epochs=0, quiet=False):
3333
self.mode = mode
34-
print('Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs))
34+
self.quiet = quiet
35+
if not quiet:
36+
print('Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs))
3537
if mode == 'step':
3638
assert lr_step
3739
self.base_lr = base_lr
@@ -57,8 +59,9 @@ def __call__(self, optimizer, i, epoch, best_pred):
5759
else:
5860
raise NotImplemented
5961
if epoch > self.epoch and (epoch == 0 or best_pred > 0.0):
60-
print('\n=>Epoch %i, learning rate = %.4f, \
61-
previous best = %.4f' % (epoch, lr, best_pred))
62+
if not self.quiet:
63+
print('\n=>Epoch %i, learning rate = %.4f, \
64+
previous best = %.4f' % (epoch, lr, best_pred))
6265
self.epoch = epoch
6366
assert lr >= 0
6467
self._adjust_learning_rate(optimizer, lr)

0 commit comments

Comments
 (0)