Skip to content

Commit cd0c958

Browse files
author
AndyDeng
committed
clean code
1 parent 72df11e commit cd0c958

1 file changed

Lines changed: 72 additions & 66 deletions

File tree

train_pytorch.py

Lines changed: 72 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,55 @@
3232
random.seed(0)
3333
dtype = torch.cuda.FloatTensor
3434

35+
36+
37+
# Load Hyperparameters
38+
parser = argparse.ArgumentParser()
39+
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
40+
parser.add_argument('--model', default='pointnet_cls',
41+
help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]')
42+
parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
43+
parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
44+
parser.add_argument('--max_epoch', type=int, default=2, help='Epoch to run [default: 250]')
45+
parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
46+
parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
47+
parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]')
48+
parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
49+
parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
50+
parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.8]')
51+
FLAGS = parser.parse_args()
52+
53+
NUM_POINT = FLAGS.num_point
54+
LEARNING_RATE = FLAGS.learning_rate
55+
GPU_INDEX = FLAGS.gpu
56+
MOMENTUM = FLAGS.momentum
57+
58+
MAX_NUM_POINT = 2048
59+
60+
BN_INIT_DECAY = 0.5
61+
BN_DECAY_DECAY_RATE = 0.5
62+
#BN_DECAY_DECAY_STEP = float(DECAY_STEP)
63+
BN_DECAY_CLIP = 0.99
64+
65+
decay_steps = FLAGS.decay_step
66+
decay_rate = FLAGS.decay_rate
67+
LEARNING_RATE_MIN = 0.00001
68+
69+
NUM_CLASS = 40
70+
#sample_num = 160
71+
BATCH_SIZE = FLAGS.batch_size #32
72+
NUM_EPOCHS = FLAGS.max_epoch
73+
jitter = 0.01
74+
jitter_val = 0.01
75+
76+
rotation_range = [0, math.pi / 18, 0, 'g']
77+
rotation_rage_val = [0, 0, 0, 'u']
78+
order = 'rxyz'
79+
80+
scaling_range = [0.05, 0.05, 0.05, 'g']
81+
scaling_range_val = [0, 0, 0, 'u']
82+
83+
3584
class modelnet40_dataset(Dataset):
3685

3786
def __init__(self, data, labels):
@@ -44,30 +93,32 @@ def __len__(self):
4493
def __getitem__(self, i):
4594
return self.data[i], self.labels[i]
4695

96+
4797
# C_in, C_out, D, N_neighbors, dilution, N_rep, r_indices_func, C_lifted = None, mlp_width = 2
4898
# (a, b, c, d, e) == (C_in, C_out, N_neighbors, dilution, N_rep)
4999
# Abbreviated PointCNN constructor.
50-
AbbPointCNN = lambda a,b,c,d,e: RandPointCNN(a, b, 3, c, d, e, knn_indices_func_gpu)
100+
AbbPointCNN = lambda a, b, c, d, e: RandPointCNN(a, b, 3, c, d, e, knn_indices_func_gpu)
101+
51102

52103
class Classifier(nn.Module):
53104

54105
def __init__(self):
55106
super(Classifier, self).__init__()
56-
57-
self.pcnn1 = AbbPointCNN( 3, 32, 8, 1, -1)
107+
108+
self.pcnn1 = AbbPointCNN(3, 32, 8, 1, -1)
58109
self.pcnn2 = nn.Sequential(
59-
AbbPointCNN( 32, 64, 8, 2, -1),
60-
AbbPointCNN( 64, 96, 8, 4, -1),
61-
AbbPointCNN( 96, 128, 12, 4, 120),
110+
AbbPointCNN(32, 64, 8, 2, -1),
111+
AbbPointCNN(64, 96, 8, 4, -1),
112+
AbbPointCNN(96, 128, 12, 4, 120),
62113
AbbPointCNN(128, 160, 12, 6, 120)
63114
)
64-
115+
65116
self.fcn = nn.Sequential(
66117
Dense(160, 128),
67-
Dense(128, 64, drop_rate = 0.5),
68-
Dense( 64, 40, with_bn = False, activation = None)
118+
Dense(128, 64, drop_rate=0.5),
119+
Dense(64, NUM_CLASS, with_bn=False, activation=None)
69120
)
70-
121+
71122
def forward(self, x):
72123
x = self.pcnn1(x)
73124
if False:
@@ -82,59 +133,14 @@ def forward(self, x):
82133
x = self.pcnn2(x)[1] # grab features
83134

84135
logits = self.fcn(x)
85-
logits_mean = torch.mean(logits, dim = 1)
136+
logits_mean = torch.mean(logits, dim=1)
86137
return logits_mean
87138

88-
# Load Hyperparameters
89-
parser = argparse.ArgumentParser()
90-
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
91-
parser.add_argument('--model', default='pointnet_cls',
92-
help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]')
93-
parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
94-
parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
95-
parser.add_argument('--max_epoch', type=int, default=2, help='Epoch to run [default: 250]')
96-
parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
97-
parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
98-
parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]')
99-
parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
100-
parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
101-
parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.8]')
102-
FLAGS = parser.parse_args()
103-
104-
NUM_POINT = FLAGS.num_point
105-
lr = FLAGS.learning_rate
106-
GPU_INDEX = FLAGS.gpu
107-
MOMENTUM = FLAGS.momentum
108-
109-
MAX_NUM_POINT = 2048
110-
111-
BN_INIT_DECAY = 0.5
112-
BN_DECAY_DECAY_RATE = 0.5
113-
#BN_DECAY_DECAY_STEP = float(DECAY_STEP)
114-
BN_DECAY_CLIP = 0.99
115-
116-
117-
num_class = 40
118-
#sample_num = 160
119-
BATCH_SIZE = FLAGS.batch_size #32
120-
num_epochs = FLAGS.max_epoch
121-
jitter = 0.01
122-
jitter_val = 0.01
123-
124-
rotation_range = [0, math.pi / 18, 0, 'g']
125-
rotation_rage_val = [0, 0, 0, 'u']
126-
order = 'rxyz'
127-
128-
scaling_range = [0.05, 0.05, 0.05, 'g']
129-
scaling_range_val = [0, 0, 0, 'u']
130139

131140
print("------Building model-------")
132141
model = Classifier().cuda()
133142
print("------Successfully Built model-------")
134143

135-
decay_steps = FLAGS.decay_step
136-
decay_rate = FLAGS.decay_rate
137-
lr_min = 0.00001
138144

139145
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
140146
loss_fn = nn.CrossEntropyLoss()
@@ -144,19 +150,19 @@ def forward(self, x):
144150
#model_save_dir = os.path.join(CURRENT_DIR, "models", "mnist2")
145151
#os.makedirs(model_save_dir, exist_ok = True)
146152

147-
TRAIN_FILES = provider.getDataFiles( \
148-
os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
149-
TEST_FILES = provider.getDataFiles(\
150-
os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
153+
TRAIN_FILES = provider.getDataFiles(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
154+
TEST_FILES = provider.getDataFiles(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
151155

152156
losses = []
153157
accuracies = []
154158

159+
'''
155160
if False:
156161
latest_model = sorted(os.listdir(model_save_dir))[-1]
157162
model.load_state_dict(torch.load(os.path.join(model_save_dir, latest_model)))
158-
159-
for epoch in range(1, num_epochs+1):
163+
'''
164+
165+
for epoch in range(1, NUM_EPOCHS+1):
160166
train_file_idxs = np.arange(0, len(TRAIN_FILES))
161167
np.random.shuffle(train_file_idxs)
162168

@@ -176,10 +182,10 @@ def forward(self, x):
176182
loss_sum = 0
177183

178184
if epoch > 1:
179-
lr *= decay_rate ** (global_step // decay_steps)
180-
if lr > lr_min:
181-
print("NEW LEARNING RATE:", lr)
182-
optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = 0.9)
185+
LEARNING_RATE *= decay_rate ** (global_step // decay_steps)
186+
if LEARNING_RATE > LEARNING_RATE_MIN:
187+
print("NEW LEARNING RATE:", LEARNING_RATE)
188+
optimizer = torch.optim.SGD(model.parameters(), lr = LEARNING_RATE, momentum = 0.9)
183189

184190
for batch_idx in range(num_batches):
185191
start_idx = batch_idx * BATCH_SIZE

0 commit comments

Comments
 (0)