|
16 | 16 | from oneshot_nas_blocks import NasBatchNorm |
17 | 17 |
|
18 | 18 | FLOP_MAX = 1 |
19 | | -PARAM_MAX = 1 |
| 19 | +PARAM_MAX = -1 |
20 | 20 | SCORE_ACC_RATIO = 1 # flop_param_score_weight/acc_weight for fitness |
21 | 21 | BLOCK_CHOICE = None # [0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2] |
22 | 22 | CHANNEL_CHOICE = None # [6, 5, 3, 5, 2, 6, 3, 4, 2, 5, 7, 5, 4, 6, 7, 4, 4, 5, 4, 3] |
23 | 23 |
|
| 24 | + |
24 | 25 | def parse_args(): |
25 | 26 | parser = argparse.ArgumentParser(description='Train a model for image classification.') |
26 | 27 |
|
@@ -236,22 +237,22 @@ def update_log(elem, logger=None): |
236 | 237 | """ |
237 | 238 | if logger: |
238 | 239 | logger.info('-' * 40) |
239 | | - logger.info("Overall score: {}".format(elem[0])) |
240 | | - logger.info("Val accuracy: {}".format(elem[1])) |
241 | | - logger.info("Model normalized score: {}.".format(elem[2])) |
242 | | - logger.info('Flops: {} MFLOPS'.format(elem[3])) |
243 | | - logger.info('# parameters: {} M'.format(elem[4])) |
244 | | - logger.info("Block choices: {}".format(elem[5])) |
245 | | - logger.info("Channel choices: {}".format(elem[6])) |
| 240 | + logger.info("Acc/computation balanced score: {}".format(elem[0])) |
| 241 | + logger.info("Val accuracy: {}".format(elem[1])) |
| 242 | + logger.info("Model normalized score: {}.".format(elem[2])) |
| 243 | + logger.info('Flops: {} MFLOPS'.format(elem[3])) |
| 244 | + logger.info('# parameters: {} M'.format(elem[4])) |
| 245 | + logger.info("Block choices: {}".format(elem[5])) |
| 246 | + logger.info("Channel choices: {}".format(elem[6])) |
246 | 247 | else: |
247 | 248 | print('-' * 40) |
248 | | - print("Overall score: {}".format(elem[0])) |
249 | | - print("Val accuracy: {}".format(elem[1])) |
250 | | - print("Model normalized score: {}.".format(elem[2])) |
251 | | - print('Flops: {} MFLOPS'.format(elem[3])) |
252 | | - print('# parameters: {} M'.format(elem[4])) |
253 | | - print("Block choices: {}".format(elem[5])) |
254 | | - print("Channel choices: {}".format(elem[6])) |
| 249 | + print("Acc/computation balanced score: {}".format(elem[0])) |
| 250 | + print("Val accuracy: {}".format(elem[1])) |
| 251 | + print("Model normalized score: {}.".format(elem[2])) |
| 252 | + print('Flops: {} MFLOPS'.format(elem[3])) |
| 253 | + print('# parameters: {} M'.format(elem[4])) |
| 254 | + print("Block choices: {}".format(elem[5])) |
| 255 | + print("Channel choices: {}".format(elem[6])) |
255 | 256 |
|
256 | 257 |
|
257 | 258 | class TopKHeap(object): |
@@ -325,21 +326,30 @@ def create_population(self): |
325 | 326 | get_flop_param_score(block_choices, channel_choices, comparison_model='SinglePathOneShot') |
326 | 327 |
|
327 | 328 | combined_score = 0.5 * flop_score + 0.5 * model_size_score |
328 | | - if flop_score > FLOP_MAX: |
| 329 | + if FLOP_MAX != -1 and flop_score > FLOP_MAX: |
329 | 330 | print("[SKIPPED] Current model normalized score: {}.".format(combined_score)) |
330 | 331 | print("[SKIPPED] Block choices: {}".format(block_choices.asnumpy())) |
331 | 332 | print("[SKIPPED] Channel choices: {}".format(channel_choices)) |
332 | 333 | print('[SKIPPED] Flops: {} MFLOPS'.format(flops)) |
333 | 334 | print('[SKIPPED] # parameters: {} M'.format(model_size)) |
334 | 335 | continue |
335 | | - if model_size_score > PARAM_MAX: |
| 336 | + if PARAM_MAX != -1 and model_size_score > PARAM_MAX: |
336 | 337 | print("[SKIPPED] Current model normalized score: {}.".format(combined_score)) |
337 | 338 | print("[SKIPPED] Block choices: {}".format(block_choices.asnumpy())) |
338 | 339 | print("[SKIPPED] Channel choices: {}".format(channel_choices)) |
339 | 340 | print('[SKIPPED] Flops: {} MFLOPS'.format(flops)) |
340 | 341 | print('[SKIPPED] # parameters: {} M'.format(model_size)) |
341 | 342 | continue |
342 | | - print("Population size + 1, total {}, with normalized score: {}".format(len(population) + 1, combined_score)) |
| 343 | + if combined_score > 1: |
| 344 | + print("[SKIPPED] Current model normalized score: {}.".format(combined_score)) |
| 345 | + print("[SKIPPED] Block choices: {}".format(block_choices.asnumpy())) |
| 346 | + print("[SKIPPED] Channel choices: {}".format(channel_choices)) |
| 347 | + print('[SKIPPED] Flops: {} MFLOPS'.format(flops)) |
| 348 | + print('[SKIPPED] # parameters: {} M'.format(model_size)) |
| 349 | + continue |
| 350 | + |
| 351 | + print("Population size + 1, total {}, with normalized score: {}, flop score: {}, param score: {}" |
| 352 | + .format(len(population) + 1, combined_score, flop_score, model_size_score)) |
343 | 353 | # Add the network to our population. |
344 | 354 | instance['flops'] = flops |
345 | 355 | instance['model_size'] = model_size |
@@ -390,9 +400,7 @@ def breed(self, mother, father): |
390 | 400 | """ |
391 | 401 | children = [] |
392 | 402 | for _ in range(2): |
393 | | - |
394 | 403 | child = {} |
395 | | - |
396 | 404 | # Crossover: loop through the parameters and pick params for the kid. |
397 | 405 | # for param_name in self.param_dict.keys(): |
398 | 406 | # child[param_name] = [0] * len(father[param_name]) |
@@ -444,8 +452,7 @@ def evolve(self, population, topk_items, logger=None): |
444 | 452 | copy.deepcopy(person['block']), copy.deepcopy(person['channel'])) |
445 | 453 | topk_items.push(net_obj) |
446 | 454 | update_log(net_obj, logger) |
447 | | - population = population.sort(key=lambda x: -SCORE_ACC_RATIO * x['score'] + x['acc'], reverse=True) |
448 | | - |
| 455 | + population.sort(key=lambda x: -SCORE_ACC_RATIO * x['score'] + x['acc'], reverse=True) |
449 | 456 | # The parents are every network we want to keep. |
450 | 457 | parents = population[:self.retain_length] |
451 | 458 |
|
@@ -579,7 +586,7 @@ def genetic_search(net, dtype='float32', logger=None, ctx=[mx.cpu()], comparison |
579 | 586 | topk_nets = TopKHeap(topk) # a list of tuple (acc, score, flops, model_size, block_choices, channel_choices) |
580 | 587 |
|
581 | 588 | # set channel and block value list |
582 | | - param_dict = {'channel': [2, 3, 4, 5, 6, 7, 8, 9], |
| 589 | + param_dict = {'channel': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], |
583 | 590 | 'block': [0, 1, 2, 3]} |
584 | 591 |
|
585 | 592 | # evolution |
|
0 commit comments