-
Notifications
You must be signed in to change notification settings - Fork 131
Expand file tree
/
Copy pathset_layer_param.py
More file actions
40 lines (36 loc) · 1.15 KB
/
Copy pathset_layer_param.py
File metadata and controls
40 lines (36 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import *
from PIL import Image
import caffe
import sys
import lmdb
from caffe.proto import caffe_pb2
from pittnuts import *
import os
from caffe_apps import *
import caffeparser
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--net_template', type=str, required=True)
parser.add_argument('--layer_type', type=str, required=True)
parser.add_argument('--param_value', type=str, required=True)
args = parser.parse_args()
net_template = args.net_template
layer_type = args.layer_type
param_value = args.param_value
caffe.set_mode_cpu()
net_parser = caffeparser.CaffeProtoParser(net_template)
net_msg = net_parser.readProtoNetFile()
for cur_layer in net_msg.layer:
if 'Sparsify' == cur_layer.type:
cur_layer.sparsify_param.coef = float(param_value)
# save
dirname = os.path.dirname(net_template)
filepath = dirname + "/generated.prototxt"
file = open(filepath, "w")
if not file:
raise IOError("ERROR (" + filepath + ")!")
file.write(str(net_msg))
file.close()