1+ #! /usr/bin/python
2+ # -*- coding: utf-8 -*-
3+
4+ import tensorlayerx as tlx
5+ from tensorlayerx .nn import Module
6+ from tensorlayerx .nn import Linear , Conv2d , BatchNorm2d , MaxPool2d , Flatten
7+
8+ class CNN (Module ):
9+
10+ def __init__ (self ):
11+ super (CNN , self ).__init__ ()
12+ # weights init
13+ W_init = tlx .nn .initializers .truncated_normal (stddev = 5e-2 )
14+ W_init2 = tlx .nn .initializers .truncated_normal (stddev = 0.04 )
15+ b_init2 = tlx .nn .initializers .constant (value = 0.1 )
16+
17+ self .conv1 = Conv2d (64 , (5 , 5 ), (1 , 1 ), padding = 'SAME' , W_init = W_init , b_init = None , name = 'conv1' , in_channels = 3 , act = tlx .nn .ReLU )
18+ self .bn = BatchNorm2d (num_features = 64 , act = tlx .nn .ReLU )
19+ self .maxpool1 = MaxPool2d ((3 , 3 ), (2 , 2 ), padding = 'SAME' , name = 'pool1' )
20+
21+ self .conv2 = Conv2d (
22+ 64 , (5 , 5 ), (1 , 1 ), padding = 'SAME' , act = tlx .nn .ReLU , W_init = W_init , b_init = None , name = 'conv2' , in_channels = 64
23+ )
24+ self .maxpool2 = MaxPool2d ((3 , 3 ), (2 , 2 ), padding = 'SAME' , name = 'pool2' )
25+
26+ self .flatten = Flatten (name = 'flatten' )
27+ self .linear1 = Linear (384 , act = tlx .nn .ReLU , W_init = W_init2 , b_init = b_init2 , name = 'linear1relu' , in_features = 2304 )
28+ self .linear2 = Linear (192 , act = tlx .nn .ReLU , W_init = W_init2 , b_init = b_init2 , name = 'linear2relu' , in_features = 384 )
29+ self .linear3 = Linear (10 , act = None , W_init = W_init2 , name = 'output1' , in_features = 192 )
30+ self .linear4 = Linear (20 , act = None , W_init = W_init2 , name = 'output2' , in_features = 192 )
31+ self .concat = tlx .nn .Concat (name = 'concat' )
32+
33+ def forward (self , x ):
34+ z = self .conv1 (x )
35+ z = self .bn (z )
36+ z = self .maxpool1 (z )
37+ z = self .conv2 (z )
38+ z = self .maxpool2 (z )
39+ z = self .flatten (z )
40+ z = self .linear1 (z )
41+ z = self .linear2 (z )
42+ z1 = self .linear3 (z )
43+ z2 = self .linear4 (z )
44+ z = self .concat ([z1 , z2 ])
45+ return z
46+
47+ model = CNN ()
48+ inputs = tlx .nn .Input (shape = (3 , 24 , 24 , 3 ))
49+ outputs = model (inputs )
50+
51+ node_by_depth , all_layers = model .build_graph (inputs )
52+
53+ for depth , nodes in enumerate (node_by_depth ):
54+ if depth == 0 :
55+ if isinstance (inputs , list ):
56+ assert len (inputs ) == len (nodes )
57+ for idx , node in enumerate (nodes ):
58+ print (node .node_name , node .layer )
59+ else :
60+ print (nodes [0 ].node_name , nodes [0 ].layer )
61+ else :
62+ for node in nodes :
63+ print (node .node_name , node .layer )
0 commit comments