-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathtorchscript_utils.py
More file actions
1688 lines (1521 loc) · 69.7 KB
/
torchscript_utils.py
File metadata and controls
1688 lines (1521 loc) · 69.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains utils related to torchscript
"""
# pylint: disable=c-extension-no-member
# Standard
from copy import deepcopy
from typing import List, Tuple
import logging
import sys
# Third Party
from transformers.tokenization_utils_base import BatchEncoding
import torch
# Local
from fms_mo.modules import QBmm
from fms_mo.quant.quantizers import transformers_prepare_input
from fms_mo.utils.import_utils import available_packages
from fms_mo.utils.utils import move_to, patch_torch_bmm, prepare_data_4_fwd
# Module-level logger; the parsing and search helpers below report warnings/info through it.
logger = logging.getLogger(__name__)
def parse_operation(op_str: str):
    """
    Split a TorchScript operation string into its operator and its operand names.

    The operand list is the text between the LAST '(' and the last ')', so operators that
    carry their own parenthesized arguments (e.g. quantizer modules) stay intact in the
    operator part.

    Args:
        op_str (str):
            e.g. "^LearnedClippedLinearQuantizeSTE_rev1(4, True, False, None, None)(%input.5, %80)"

    Returns:
        tuple: (operator, operands). ``operator`` is a str; ``operands`` is a List[str] of
        stripped operand names, or None when the call has no operands or ``op_str``
        contains no '(' at all. For the example above:
            operator: "^LearnedClippedLinearQuantizeSTE_rev1(4, True, False, None, None)"
            operands: ["%input.5", "%80"]
    """
    open_idx = op_str.rfind("(")
    if open_idx == -1:
        # No call syntax: the whole string is the operator (slicing with -1 would
        # otherwise chop off its last character).
        return op_str, None

    close_idx = op_str.rfind(")")
    if close_idx < open_idx:
        # Unterminated operand list: take everything after the '('.
        close_idx = len(op_str)

    operator = op_str[:open_idx]
    raw_operands = op_str[open_idx + 1 : close_idx].split(",")
    # "op()" splits into [""]; report that as "no operands" rather than one empty name.
    if raw_operands == [""]:
        return operator, None
    return operator, [operand.strip() for operand in raw_operands]
class Node:
    r"""
    A class representing a node in a PyTorch model.

    A Node is built from one line of a TorchScript graph: either a string (used for the
    graph header entries) or a native ``torch._C.Node``. Starting example:
        %input.28 : Float(1:238144, 64:3721, 61:61, 61:1, requires_grad=1, device=cpu)
        = aten::_convolution(%input.27, %184, %60, %185, %186, %187, %57, %188, %59, %57, %57,
        %56, %56), scope: __module.layer1/__module.layer1.2/__module.layer1.2.conv2
        # directory_path \ file.py:419:0

    Step 1: Remove everything following "#" (the second-to-last ":"-separated field of
    that suffix is kept as ``lineno``) such that the example becomes:
        %input.28 : Float(1:238144, 64:3721, 61:61, 61:1, requires_grad=1, device=cpu)
        = aten::_convolution(%input.27, %184, %60, %185, %186, %187, %57, %188, %59, %57, %57,
        %56, %56), scope: __module.layer1/__module.layer1.2/__module.layer1.2.conv2

    Step 2: Keep only the last "/"-separated segment of the "scope" string such that the
    example becomes:
        %input.28 : Float(1:238144, 64:3721, 61:61, 61:1, requires_grad=1, device=cpu)
        = aten::_convolution(%input.27, %184, %60, %185, %186, %187, %57, %188, %59, %57, %57,
        %56, %56), scope: __module.layer1.2.conv2

    Step 3: Split "<definitions> = <operation>" and parse the definitions. A single line
    may define several values, e.g.
        '%77 : Float(16:1632000, 3:544000, 80:6800, 80:85, 85:1, requires_grad=1,
        device=cuda:0), %78 : Float(16:408000, 3:136000, 40:3400, 40:85, 85:1, requires_grad=1,
        device=cuda:0), %79 : Float(16:102000, 3:34000, 20:1700, 20:85, 85:1, requires_grad=1,
        device=cuda:0) = prim::TupleUnpack(%1248)'
    in which case one Node instance is created and registered per defined value.

    Attributes:
        name (str): The name of the node, e.g. "%input.28".
        obj (str): The type string after " : " in the definition, e.g. "Float(...)".
        Op (str): The full operation string (right-hand side of " = "), or None for
            lines without " = " (graph header entries).
        operator (str): The operator part of ``Op`` as returned by ``parse_operation``,
            or None.
        parents (list): Operand NAMES (str) of the operation, or None.
        parents_ptr (list): Parent Node objects; starts empty and is filled in by Graph.
        children (list): Child node NAMES (str); starts empty and is filled in by Graph.
        children_ptr (list): Child Node objects; starts empty and is filled in by Graph.
        scope (str): Last "/"-separated segment of the node's scope, or None.
        modname (str): ``scope`` with "__module." and "model." removed ("" if no scope).
        lineno (str): Line number taken from the "# ...:<line>:<col>" suffix; only set
            when such a suffix is present.
        unpackIdx (int): Position of this value on a multi-definition line; only set when
            the line defines more than one value.
        ch_in (list): Only for native ``aten::_conv*`` nodes: ``sizes()`` of the type of
            the first input.
        ch_out (list): Only for native ``aten::_conv*`` nodes: ``sizes()`` of the type of
            the first output.
        TSparents (list): "%"-prefixed names of the native TorchScript inputs, or None
            for string-built nodes.
        TSoutputs (list): "%"-prefixed names of the native TorchScript outputs, or None
            for string-built nodes.

    Methods:
        __init__(node_input, dictionary_of_nodes): Initializes the Node object.
        __repr__(): Returns a string representation of the Node object.
    """

    def __init__(self, node_input, dictionary_of_nodes: dict):
        """
        Initialize a Node object and register it in ``dictionary_of_nodes``.

        Args:
            node_input (str or torch._C.Node): The input to the Node object.
                If it's a string, it represents the node definition as a string.
                If it's a torch._C.Node, it represents a native TorchScript node.
            dictionary_of_nodes (dict): A dictionary that keeps track of all the nodes in
                the graph. Every value defined by ``node_input`` is stored in it under its
                name (deep copies of this object are created for the 2nd, 3rd, ... value
                of a multi-definition line).

        Any other input type only logs a warning and leaves the object uninitialized.
        """
        if isinstance(node_input, torch._C.Node):
            # Flatten multi-line reprs (e.g. large constant tensors) into one line.
            node_input_repr = node_input.__repr__().replace("\n", "")
            native_torchscript_node = node_input
            # "%"-prefixed names of the native inputs/outputs, used below as a
            # cross-check against the names parsed from the text.
            native_torchscript_parents = [
                "%" + n.__repr__().split(" defined in")[0]
                for n in native_torchscript_node.inputs()
            ]
            native_torchscript_outputs = [
                "%" + n.__repr__().split(" defined in")[0]
                for n in native_torchscript_node.outputs()
            ]
        elif isinstance(node_input, str):
            node_input_repr = node_input
            native_torchscript_node = None
            native_torchscript_parents = None
            native_torchscript_outputs = None
        else:
            logger.warning(
                "Input to class Node is neither a string nor a torchscript node"
            )
            return

        # Step 1: strip the trailing "# file:line:col" source location, keeping the line.
        if "# " in node_input_repr:
            line_number = node_input_repr.split("#")[1].split(":")[-2]
            node_input_repr = node_input_repr.split("#")[0]
        else:
            line_number = None

        # Step 2: keep only the innermost ("/"-separated last) scope segment.
        if "scope:" in node_input_repr:
            temp_str = node_input_repr.split("scope:")
            scope_repr = temp_str[1].split("/")[-1]
            node_input_repr = temp_str[0]
        else:
            scope_repr = None
        module_name = (
            scope_repr.replace("__module.", "") if scope_repr is not None else ""
        )
        module_name = module_name.replace(
            "model.", ""
        )  # Remove model. for shorter names

        # Step 3: split "<definitions> = <operation>".
        start_index = 0
        if " = " in node_input_repr:
            if node_input_repr.count(" = ") == 1:
                node_def, op_str = node_input_repr.split(" = ")
            else:
                # e.g., %2206 : Float(1, 3, 1, 1, 2, strides=[6, 2, 2, 2, 1], requires_grad=0,
                # device=cpu) = prim::Constant[value=(1,1,1,.,.) =
                # 1.2500 1.6250 (1,2,1,.,.) = 2.0000 3.7500 (1,3,1,.,.) =
                # 4.1250 2.8750 [ CPUFloatType{1,3,1,1,2} ]](), scope: __module.module_list.106
                # Only the FIRST " = " separates the definitions from the operation.
                idx1steq = node_input_repr.find(" = ")
                node_def, op_str = (
                    node_input_repr[:idx1steq],
                    node_input_repr[idx1steq + 3 :],
                )
            operator, operands = parse_operation(op_str)
            if "aten::_conv" in op_str:
                if native_torchscript_node:
                    # Full sizes() list of the first input / first output value types.
                    self.ch_in = (
                        list(native_torchscript_node.inputs())[0].type().sizes()
                    )
                    # NOTE: Needed for finding shortcut convolutions later
                    self.ch_out = (
                        list(native_torchscript_node.outputs())[0].type().sizes()
                    )
        else:
            # Header entries (graph inputs) have a definition but no operation.
            node_def = node_input_repr
            op_str, operator, operands = None, None, None

        # Number of values defined on this line (one " : " per "name : type" pair).
        # NOTE(review): assumes at least one " : " is present in node_def.
        node_def_in_one_line = node_def.count(" : ")
        # when unpacking OPs, will create N instances of node, i.e. not pointing to the same "self"
        node_instances = [self] + [
            deepcopy(self) for _ in range(node_def_in_one_line - 1)
        ]
        for node_index, node_instance in enumerate(node_instances):
            if node_index == node_def_in_one_line - 1:
                # The last definition runs to the end of node_def.
                end_index = len(node_def)
            else:
                # The current definition ends just before ", %<next name>": find the "%"
                # that starts the next name and step back over the ", " separator.
                current_colon_index = node_def.find(" : ", start_index)
                next_colon_index = node_def.find(" : ", current_colon_index + 1)
                end_index = node_def.rfind("%", start_index, next_colon_index) - 2
            # node_def is a prefix of node_input_repr, so the same indices apply.
            working_str = node_input_repr[start_index:end_index]
            start_index = end_index + 2
            # pylint: disable=line-too-long
            node_instance.name, node_instance.obj = working_str.split(" : ")  # type: ignore[attr-defined]
            node_instance.name = node_instance.name.strip()  # type: ignore[attr-defined]
            if native_torchscript_outputs:
                # Cross-check the parsed name against TorchScript's own output names.
                # pylint: disable=line-too-long
                if node_instance.name not in native_torchscript_outputs:  # type: ignore[attr-defined]
                    # pylint: disable=line-too-long
                    logger.error(
                        f"Node def {node_instance.name} not in nativeTSoutputs "  # type: ignore[attr-defined]
                        f"{native_torchscript_outputs}"
                    )
            node_instance.Op = op_str  # type: ignore[attr-defined]
            if node_def_in_one_line > 1:
                node_instance.unpackIdx = node_index  # type: ignore[attr-defined]
            if line_number:
                node_instance.lineno = line_number  # type: ignore[attr-defined]
            node_instance.operator = operator  # type: ignore[attr-defined]
            # This is the name of parents, not the pointer to the parent nodes
            node_instance.parents = operands  # type: ignore[attr-defined]
            node_instance.parents_ptr = []  # type: ignore[attr-defined]
            node_instance.scope = scope_repr  # type: ignore[attr-defined]
            node_instance.modname = module_name  # type: ignore[attr-defined]
            node_instance.children = []  # type: ignore[attr-defined]
            node_instance.children_ptr = []  # type: ignore[attr-defined]
            node_instance.TSparents = native_torchscript_parents  # type: ignore[attr-defined]
            node_instance.TSoutputs = native_torchscript_outputs  # type: ignore[attr-defined]
            # graph.dictionary_of_nodes will keep a record of all the nodes
            dictionary_of_nodes[node_instance.name] = node_instance  # type: ignore[attr-defined]

    def __repr__(self):
        # Nodes print as their name followed by a space.
        return f"{self.name} "
class Graph:
    """
    Parsed, pointer-linked representation of a TorchScript graph.

    Every value defined in the graph becomes a :class:`Node`. After parsing, each node's
    ``parents_ptr``/``children_ptr`` point at the neighbouring Node objects, so the graph
    can be searched (:meth:`DFS`, :meth:`brute_force_search`) or plotted.

    Attributes:
        dictionary_of_nodes (dict): Maps node names (e.g. "%input.1") to Node objects.
        inputs (list): Node objects for the graph inputs declared in the header,
            excluding the first header entry (see ``model_node``).
        disable_plots (bool): True when pygraphviz is not available; the plotting
            methods then return without doing anything.
        model_node (Node): Node for the first header entry (the model itself).
        return_node (list): Node objects named in the graph's return statement.

    :meth:`DFS` additionally sets ``visited``, ``node_found``, ``node_traced`` and
    ``br_begin`` as search state.
    """

    def __init__(self, graph):
        """
        Initializes the Graph object by parsing the given torch._C.Graph object.

        Parsing happens in four passes:
          1. the header ("graph(%self.1 : ..., %input.1 : ...):"), which defines the
             model node and the graph inputs;
          2. the body, from the native TorchScript nodes (nodes inside prim::If blocks
             are registered as well);
          3. the return statement, which defines ``return_node``;
          4. a linking pass that fills ``parents_ptr``, ``children`` and ``children_ptr``.

        Args:
            graph (torch._C.Graph): The torch._C.Graph object to be parsed.

        Raises:
            TypeError: If ``graph`` is not a ``torch._C.Graph``.
        """
        self.dictionary_of_nodes = {}
        self.inputs = []
        self.disable_plots = not available_packages["pygraphviz"]

        if not isinstance(graph, torch._C.Graph):
            raise TypeError(
                f"Graph expects a torch._C.Graph, got {type(graph).__name__}"
            )
        list_str = graph.__repr__().split("\n")

        # 1. Parse the header. The first argument must be the model itself, e.g.
        #    graph(%self.1 : __torch__.torchvision.models.resnet.ResNet,
        #          %input.1 : Float(1:178608, 3:59536, 244:244, 244:1, requires_grad=0, device=cpu)):
        # The header ends on the line where the parentheses become balanced, however
        # many inputs the graph has.
        curr_line = 0
        header = ""
        is_header = True
        while is_header and curr_line < len(list_str):
            line_str_i = list_str[curr_line]
            header += line_str_i
            # Unbalanced parentheses mean we are still inside the header.
            is_header = header.count("(") != header.count(")")
            # Drop the separator/terminator so each entry is a plain "name : type".
            if line_str_i.endswith(","):
                line_str_i = line_str_i[:-1]
            elif line_str_i.endswith("):"):
                line_str_i = line_str_i[:-2]
            if curr_line == 0:
                self.model_node = Node(
                    line_str_i.replace("graph(", ""), self.dictionary_of_nodes
                )
            else:
                self.inputs.append(Node(line_str_i, self.dictionary_of_nodes))
            curr_line += 1

        # 2. Parse the body; here we can use the native TorchScript nodes.
        for graph_node in graph.nodes():
            n_ptr = Node(graph_node, self.dictionary_of_nodes)
            if "prim::If" in graph_node.kind():
                # Register the nodes of every branch block and record their names in
                # the If node's scope as "Either %a or %b ...".
                n_ptr.scope = "Either"
                for block in graph_node.blocks():
                    for block_node in block.nodes():
                        block_node_ptr = Node(block_node, self.dictionary_of_nodes)
                        n_ptr.scope += f" {block_node_ptr.name} or"
                n_ptr.scope = n_ptr.scope[:-3]  # drop the trailing " or"

        # 3. Parse the return statement, usually "return (%xxx)" or "return (%a, %b)".
        temp_str = graph.return_node().__repr__()
        idx_s = temp_str.index("(")
        idx_e = temp_str.rindex(")")
        self.return_node = [
            self.dictionary_of_nodes[ri.strip()]
            for ri in temp_str[idx_s + 1 : idx_e].split(",")
        ]

        # 4. Second pass: resolve parent names to Node objects and add children info.
        for key_node, value_node in self.dictionary_of_nodes.items():
            if value_node.parents is None:
                continue
            value_node.parents_ptr = [
                self.dictionary_of_nodes[pi] for pi in value_node.parents
            ]
            for parent in value_node.parents_ptr:
                parent.children.append(key_node)
                parent.children_ptr.append(value_node)
            # Verify consistency between our data structure vs native torchscript's
            if value_node.TSparents is not None and set(value_node.parents) != set(
                value_node.TSparents
            ):
                logger.info(
                    f"{key_node} inconsistent parents {value_node.parents} "
                    f"{value_node.TSparents}"
                )
        logger.info("torchscript inlined_graph parsed successfully!")

    def DFS(
        self,
        target_op,
        verbose: bool = False,
        insert_annotation: bool = False,
        find_first_only: bool = False,
        reverse: bool = False,
        node_st=None,
        stop_op=None,
        ignore_op=None,
        hook=None,
    ):
        """
        Performs depth first search on the graph.

        A node matches when any of the given substrings occurs in its ``Op`` string.
        "aten::size" nodes are always ignored.

        Args:
            target_op (str or list of str): The target operation(s) to search for.
            verbose (bool): Whether to log search results for debugging purposes.
                Defaults to False.
            insert_annotation (bool): Whether to insert annotation strings ("... A branch
                begins here", "... A branch merges into ", "..., End of main branch") into
                the returned list. Defaults to False.
            find_first_only (bool): Whether to stop following a path after its first node
                that matches the target operation(s). Defaults to False.
            reverse (bool): Whether to search in reverse direction (following parents),
                from the end to the beginning of the graph. Defaults to False.
            node_st (Node, str or list): The starting node(s), as Node objects or node
                names. If not specified, the input nodes (or the return nodes when
                ``reverse``) are used.
            stop_op (str or list of str): The operations at which the search along a path
                stops (checked after target_op, so a stop node can itself be reported).
                Defaults to None.
            ignore_op (str or list of str): Additional operations to ignore; paths end at
                ignored nodes without reporting them. The caller's list is not modified.
                Defaults to None.
            hook (callable, optional): Called with every node that matches the target
                operation(s), including when ``find_first_only`` is set. Defaults to None.

        Returns:
            list: The matching Node objects, interleaved with annotation strings when
            ``insert_annotation`` is True.
        """
        # Make a table to record visit history, to avoid redundant search
        self.visited = {key: False for key in self.dictionary_of_nodes}

        # A list of criteria means more than one search criterion is allowed.
        if isinstance(target_op, str):
            target_op = [target_op]
        if isinstance(stop_op, str):
            stop_op = [stop_op]
        # A single starting node (object or name) becomes a one-element list; iterating
        # a bare name string would otherwise walk over its characters.
        if isinstance(node_st, (Node, str)):
            node_st = [node_st]
        # Always prescreen "aten::size". Build a NEW list so the caller's ignore_op is
        # not extended on every call.
        if ignore_op is None:
            ignore_op = ["aten::size"]
        elif isinstance(ignore_op, str):
            ignore_op = [ignore_op, "aten::size"]
        else:
            ignore_op = list(ignore_op) + ["aten::size"]

        # Search from a given starting node, the beginning (input nodes) or the end
        # (return node).
        if node_st is not None:
            starting_nodes = node_st
        elif reverse:
            starting_nodes = self.return_node
        else:
            starting_nodes = self.inputs

        self.node_found = []
        self.node_traced = 0
        for node_i in starting_nodes:
            self.br_begin = None
            self._dfs(
                node_i,
                target_op,
                verbose,
                insert_annotation,
                find_first_only,
                reverse,
                node_st,
                stop_op,
                ignore_op,
                hook,
            )
        if verbose:
            logger.info(
                f"Nodes traced={self.node_traced}/{len(self.dictionary_of_nodes.keys())}, "
                f"found {len(self.node_found)} that satisfy the criteria {target_op}"
            )
        return self.node_found

    def _dfs(
        self,
        curr_node,
        target_op,
        verbose,
        insert_annotation,
        find_first_only,
        reverse,
        node_st,
        stop_op,
        ignore_op,
        hook,
    ):
        """
        Recursive worker for :meth:`DFS`; matches are appended to ``self.node_found``.

        Returns:
            Node or None: ``curr_node`` if it had already been visited (the current
            branch ran into an explored part of the graph, i.e. it merges), else None.
        """
        # Make sure curr_node is a pointer, not a string
        if isinstance(curr_node, str):
            curr_node = self.dictionary_of_nodes[curr_node]
        if self.visited[curr_node.name]:  # Avoid redundant search
            return curr_node  # This is the end of a branch
        self.visited[curr_node.name] = True  # Mark this node visited first

        if curr_node.Op:  # e.g. input nodes have Op=None
            # Checking sequence ignore_op -> target_op -> stop_op
            # 1) Before checking target_op, filter by ignore_op first.
            if any(Op_i in curr_node.Op for Op_i in ignore_op):
                return None
            # 2) Then check if it satisfies target_op.
            if any(Op_i in curr_node.Op for Op_i in target_op):
                self.node_found.append(curr_node)
                if verbose:
                    logger.info(
                        f"{curr_node.name} {curr_node.operator} {curr_node.scope}"
                    )
                # Apply the hook to every match, including the one that ends this path
                # when find_first_only is set.
                if hook:
                    hook(curr_node)
                # find_first_only: stop searching after the 1st node with target_op is
                # found (if branches exist before it, the 1st node of each branch).
                if find_first_only:
                    return None
            # 3) Then decide if it satisfies stop_op, e.g. search from a Conv and stop at
            #    the next Conv: DFS(..., node_st=Conv.parents_ptr[i], stop_op='conv').
            if stop_op and any(Op_i in curr_node.Op for Op_i in stop_op):
                return None

        self.node_traced += 1
        next_nodes = curr_node.children_ptr if not reverse else curr_node.parents_ptr
        end_nodes = self.return_node if not reverse else self.inputs
        if next_nodes:
            num_next_nodes: int = len(next_nodes)
            if num_next_nodes > 1 and insert_annotation:
                for next_node_elem in next_nodes[1:]:
                    self.node_found.append(
                        f"{curr_node.name}->{next_node_elem.name}, A branch begins here"
                    )
            for next_node_index, next_node_elem in enumerate(next_nodes):
                if next_node_index > 0:
                    # Every next node after the first starts a new branch.
                    self.br_begin = next_node_elem
                    next_node_elem.isOnBranch = True
                    if verbose:
                        logger.info(
                            f"Start searching branch #{next_node_index}/{num_next_nodes} of "
                            f"node {curr_node.name}"
                        )
                br_end = self._dfs(
                    next_node_elem,
                    target_op,
                    verbose,
                    insert_annotation,
                    find_first_only,
                    reverse,
                    node_st,
                    stop_op,
                    ignore_op,
                    hook,
                )
                if br_end is not None:
                    if verbose:
                        logger.info(f"current branch merged into {br_end.name}")
                    if insert_annotation:
                        # NOTE(review): self.br_begin is still None if the first path
                        # explored from a starting node merges (possible with several
                        # starting nodes); this line would then raise AttributeError.
                        self.node_found.append(
                            f"{self.br_begin.name}->{br_end.name}, A branch merges into "
                        )
                    next_node_elem.isBranchMerge = True
        elif curr_node in end_nodes:  # return node also has no child
            if verbose:
                logger.info(
                    "Run into a return-node or an input-node (if reverse search) "
                )
            if insert_annotation:
                self.node_found.append(f"{curr_node.name}, End of main branch")
        return None

    def brute_force_search(self, target_op):
        """
        Searches dictionary_of_nodes directly for nodes whose Op contains target_op.

        Every match is also logged.

        Args:
            target_op (str): The operator substring to search for.

        Returns:
            list: A list of nodes that match the target operator.
        """
        self.node_found = []
        for curr_node in self.dictionary_of_nodes.values():
            if curr_node.Op and target_op in curr_node.Op:
                self.node_found.append(curr_node)
                logger.info(
                    f"{curr_node.name} {curr_node.operator} {curr_node.scope}"
                )
        return self.node_found

    def plot_full(self, output_name="test.svg"):
        """
        Plots the full graph (every node and every parent->child edge) with pygraphviz.

        Does nothing when pygraphviz is not available.

        Args:
            output_name (str, optional): The name of the file to save the plot to.
                Defaults to "test.svg".
        """
        if self.disable_plots:
            return
        # Third Party
        import pygraphviz as pgv

        G = pgv.AGraph(strict=False, directed=True)
        for node_name, node_ptr in self.dictionary_of_nodes.items():
            G.add_node(node_name.replace("%", ""), label=node_name, shape="record")
            for child_name in node_ptr.children:
                G.add_edge(node_name.replace("%", ""), child_name.replace("%", ""))
        G.layout(prog="dot")
        G.draw(output_name)

    def plot_short(
        self,
        kw,
        output_name="test.svg",
        plot_in_notebook: bool = False,
        verbose: bool = False,
        fields=1,
        showQW=True,
    ):
        """
        Plot the nodes whose Op matches a keyword, with their outgoing edges.

        Nodes are colored by operator: conv, quantizer ("^"), addmm/bmm/linear, If, other.
        Does nothing when pygraphviz is not available.

        Args:
            kw (str or list of str): The keyword(s) passed to :meth:`DFS` as target_op.
            output_name (str, optional): The name of the output file. Defaults to "test.svg".
            plot_in_notebook (bool, optional): Whether to also show the image with
                matplotlib (if available). Defaults to False.
            verbose (bool, optional): Whether to log the nodes found and also write a
                "<output_name>.dot" file. Defaults to False.
            fields (int, optional): How many of (name, operator, module name) to include in
                each node label. Defaults to 1.
            showQW (bool, optional): Whether to show the weight quantizers of conv nodes.
                Defaults to True.
        """
        if self.disable_plots:
            return
        # Third Party
        import pygraphviz as pgv

        G = pgv.AGraph(strict=False, directed=True)
        nodes_of_intr = self.DFS(kw)
        if verbose:
            logger.info(nodes_of_intr)
        for node_ptr in nodes_of_intr:
            nname = node_ptr.name.replace("%", "")
            modname = (
                node_ptr.scope.replace("__module.", "")
                if node_ptr.scope is not None
                else ""
            )
            # Further simplify the name for better visualization on SVG
            modname = modname.replace("model.", "")
            fieldsStr = ["%" + nname, node_ptr.operator, modname]
            # Can plot with fewer fields for each node as needed
            labelStr = "|".join(fieldsStr[: min(fields, len(fieldsStr))])

            op = node_ptr.operator
            if "conv" in op:
                fcolor = "#b2d3e4"
            elif "^" in op:  # quantizer
                fcolor = "#FFF2CC"
            elif "addmm" in op or "bmm" in op or "linear" in op:
                fcolor = "#DDDDDD"
            elif "If" in op:
                fcolor = "#C5E0B4"
            else:
                fcolor = "white"

            if "conv" in op:
                ch_info = (
                    f"{node_ptr.ch_in}|{node_ptr.ch_out}"
                    if hasattr(node_ptr, "ch_in") and hasattr(node_ptr, "ch_out")
                    else ""
                )
                labelStr = (
                    f"{fieldsStr[0]}: \n| {{ {fieldsStr[1]}|{fieldsStr[2]} }}| "
                    f"{{ input:|output: }} |{{ {ch_info} }}"
                )
                # Also include weight quantizers (the conv's second input). Guard against
                # missing parents and Op-less parents such as graph inputs.
                if showQW and len(node_ptr.parents_ptr) > 1:
                    nodeW = node_ptr.parents_ptr[1]
                    if nodeW.Op and "^" in nodeW.Op:  # A quantizer is being used.
                        nWname = nodeW.name.replace("%", "")
                        modWname = (
                            nodeW.scope.replace("__module.", "")
                            if nodeW.scope is not None
                            else ""
                        )
                        modWname = modWname.replace("model.", "")
                        fieldsWStr = ["%" + nWname, nodeW.operator, modWname]
                        labelWStr = "|".join(fieldsWStr[: min(fields, len(fieldsWStr))])
                        G.add_node(
                            nWname,
                            fillcolor="#FFF2CC",
                            style="filled",
                            label=labelWStr,
                            shape="record",
                        )
                        G.add_edge(nWname, nname)
            G.add_node(
                nname, fillcolor=fcolor, style="filled", label=labelStr, shape="record"
            )
            for j in node_ptr.children:
                G.add_edge(nname, j.replace("%", ""))

        if verbose:
            G.write(output_name + ".dot")
            logger.info(
                "Output graph to .dot. It could be very slow for DenseNet or complicated nets."
            )
        G.layout(prog="dot")
        G.draw(output_name)
        if plot_in_notebook and available_packages["matplotlib"]:
            # Third Party
            import matplotlib.image as mpimg
            import matplotlib.pyplot as plt

            plt.figure(figsize=(16, 50))
            plt.imshow(mpimg.imread(output_name))
            plt.axis("off")

    def is_child(self, child_node, parent_node):
        """
        Checks whether child_node is reachable downstream from parent_node.

        Searches forward from parent_node for nodes with the same Op string, then
        confirms a match by also comparing ``.obj`` (the type string).

        Args:
            child_node (Node): The child node to check
            parent_node (Node): The parent node to check against

        Returns:
            bool: True if a matching node is reachable from parent_node. False otherwise,
            including when child_node has no Op (e.g. a graph input), since the search
            is Op-based.
        """
        if not child_node.Op:
            return False
        potential_match = self.DFS(child_node.Op, node_st=parent_node)
        # Matching node.Op may be insufficient, need to check .obj as well
        return any(node.obj == child_node.obj for node in potential_match)

    def id_rpn_from_last(self, last_candidates):
        """
        Identify the RPN candidates from the last candidates.

        Candidates whose single child is a TupleConstruct/ListConstruct are grouped by
        that child. If there is more than one group and the largest group size is a
        multiple of the smallest, the members (parents) of the largest group are returned.

        Args:
            last_candidates (list): A list of candidate nodes.

        Returns:
            rpn_candidates (list): A list of identified RPN candidates (possibly empty).
        """
        grouped, rpn_candidates = [], []
        for n in last_candidates:
            if not n.children_ptr:  # Sometimes last nodes have no children anymore
                continue
            child0 = n.children_ptr[0]
            if len(n.children) == 1 and (
                "TupleConstruct" in child0.Op or "ListConstruct" in child0.Op
            ):
                # NOTE: We add the tuple construct, not the Op itself, here
                if child0 not in grouped:
                    grouped.append(child0)
        if len(grouped) > 1:
            num_member_per_group = [len(n.parents) for n in grouped]
            if max(num_member_per_group) % min(num_member_per_group) == 0:
                idx_max = num_member_per_group.index(max(num_member_per_group))
                rpn_candidates = grouped[idx_max].parents_ptr
        return rpn_candidates

    def find_fpn_convs(self):
        """
        Identify the FPN convolutions to be quantized.

        Looks for the chain upsample -> add -> upsample -> add. From each final add, the
        nearest TupleConstruct is searched backwards (FPN begin) and forwards (FPN end).
        If exactly one begin and one end are found, every conv from the begin up to the
        end is returned.

        Returns:
            List[Node]: Conv nodes inside the detected FPN, or an empty list if the
            pattern is absent or the detection is inconsistent.
        """
        upsample_nodes = self.DFS("aten::upsample")
        children = [
            n1 for n in upsample_nodes for n1 in n.children_ptr if "aten::add" in n1.Op
        ]
        grandchildren = [
            n1 for n in children for n1 in n.children_ptr if "aten::upsample" in n1.Op
        ]
        ggchildren = [
            n1 for n in grandchildren for n1 in n.children_ptr if "aten::add" in n1.Op
        ]
        if not ggchildren:
            return []

        fpn_begin, fpn_end = [], []
        for n in ggchildren:
            fpn_begin += self.DFS(
                "TupleConstruct", find_first_only=True, reverse=True, node_st=n
            )
            fpn_end += self.DFS(
                "TupleConstruct", find_first_only=True, reverse=False, node_st=n
            )
        # Remove duplicates while keeping discovery order (Nodes hash by identity).
        fpn_begin = list(dict.fromkeys(fpn_begin))
        fpn_end = list(dict.fromkeys(fpn_end))

        # Exactly one begin and one end are required; this also covers the case where
        # no TupleConstruct was found at all (previously an IndexError).
        if len(fpn_begin) != 1 or len(fpn_end) != 1:
            logger.warning(f"FPN detection is inconsistent. {fpn_begin} {fpn_end}")
            return []
        return self.DFS("aten::_conv", node_st=fpn_begin[0], stop_op=fpn_end[0].Op)

    def __repr__(self):
        return (
            f" model node={self.model_node}\n inputs={self.inputs}\n "
            f"{self.dictionary_of_nodes}\nreturn={self.return_node}"
        )
def find_shortcut_conv_v2(graph, verbose=False):
    """
    Revised algorithm for finding convolutional modules on shortcut paths.

    For every ``aten::add``/``aten::add_`` node whose first two inputs are Float tensors:
      1. climb from each of those two inputs towards the fork point, i.e. the first
         ancestor with at least two children (giving up after 20 levels or at a node
         without parents);
      2. if both climbs reach the same fork (or one fork is an ancestor of the other),
         search backwards along the shorter path for the first convolution;
      3. keep convolutions whose output sizes compare greater than their input sizes
         (lists compared lexicographically; assuming NCHW with equal batch sizes this
         compares channel counts).

    TODO: make it a member function of class Graph

    Args:
        graph (Graph): The input graph.
        verbose (bool, optional): Whether to log the Add nodes found. Defaults to False.

    Returns:
        list: Conv nodes that are likely on a shortcut path.
    """
    # 1. Find the Add nodes. Depending on user preference the residual could be written
    #    as out += shortcut, out = out + shortcut or out = shortcut + out.
    assert isinstance(graph, Graph), "input needs to be an object of our custom graph"
    NodesAdd = graph.DFS(["aten::add(", "aten::add_("])
    if verbose:
        logger.info(NodesAdd)

    max_levels = 20  # search limit when climbing towards the fork point
    qconv_candidate = []
    for node_i in NodesAdd:
        add_inputs = node_i.parents_ptr[:2]
        # Only Adds of two float tensors are of interest (no Long or Int Adds here).
        if not all("Float" in n.obj for n in add_inputs):
            continue

        # 2. Find where the branch begins, i.e. the common node along the 1st and 2nd
        #    parent of the Add, recording how many levels away from the Add it is.
        levels_from_add = [1, 1]
        branch_nodes = [None, None]
        for j, p_j in enumerate(add_inputs):
            while (
                len(p_j.children_ptr) < 2
                and levels_from_add[j] < max_levels
                and p_j.parents_ptr  # graph inputs have no parent to climb to
            ):
                levels_from_add[j] += 1
                p_j = p_j.parents_ptr[0]
            branch_nodes[j] = p_j

        # 2.1 Both climbs should meet at the same node. They may not, e.g. when the
        #     search limit was hit or a secondary branch (like a concat) exists in one of
        #     the paths; if one fork is an ancestor of the other, that suffices.
        #     is_child() searches by Op, so Op-less nodes (graph inputs) are skipped.
        if branch_nodes[0] == branch_nodes[1]:
            # Most likely case, the 2 searches return the same node
            real_branch_begin = branch_nodes[0]
        elif branch_nodes[0].Op and graph.is_child(
            child_node=branch_nodes[0], parent_node=branch_nodes[1]
        ):
            real_branch_begin = branch_nodes[1]
        elif branch_nodes[1].Op and graph.is_child(
            child_node=branch_nodes[1], parent_node=branch_nodes[0]
        ):
            real_branch_begin = branch_nodes[0]
        else:
            real_branch_begin = None
            logger.warning(f"{node_i}'s branch analysis could be incorrect!")

        if real_branch_begin is None:
            continue

        shorter_path = 0 if levels_from_add[0] < levels_from_add[1] else 1
        temp_nodes = graph.DFS(
            "conv",
            find_first_only=True,
            reverse=True,
            node_st=node_i.parents_ptr[shorter_path],
            stop_op=real_branch_begin.Op,
        )
        # Shortcut Convs must have channel_out > channel_in; FPN convs, on the other
        # hand, have channel_out <= channel_in. Nodes without size info are skipped.
        qconv_candidate += [
            n
            for n in temp_nodes
            if getattr(n, "ch_in", None) is not None
            and getattr(n, "ch_out", None) is not None
            and n.ch_out > n.ch_in
        ]
    return qconv_candidate
def find_all_conv_sorted(graph, verbose=False):
    """
    Find all the convolutional (and addmm) nodes in the graph, and insert those on a
    branch (e.g. a shortcut path) at the branch point, as a list in case there are more
    than 1.

    The nodes come from ``graph.DFS(["conv", "addmm"], insert_annotation=True)``, whose
    result mixes Node objects with annotation strings of three kinds:
    "X->Y, A branch begins here", "B->M, A branch merges into " and
    "Z, End of main branch".

    The indices of the merge annotations, in reverse order and followed by the index of
    the (last) end-of-main annotation, are paired up as consecutive (idx_st, idx_end).
    For each pair with idx_end - idx_st > 1, the entries strictly between them are
    stored, as one list, at the index of the "begins" annotation whose target node
    matches the merge annotation's first node. Finally, only non-string entries located
    before the end-of-main annotation are kept. This presumably relies on DFS emitting a
    branch's nodes right before that branch's merge annotation, after the end of the main
    branch — confirm when changing Graph.DFS.

    Args:
        graph (Graph): The graph to be searched.
        verbose (bool, optional): Whether to log the raw and the reordered node lists.
            Default is False.

    Returns:
        list: Nodes in search order, with branch nodes inserted as nested lists at their
        branch points.
    """
    NodesConv = graph.DFS(["conv", "addmm"], verbose=False, insert_annotation=True)
    if verbose:
        for node_name in NodesConv:
            if isinstance(node_name, str):
                logger.info(f"[Annotation] {node_name}")
            else:
                logger.info(node_name.name)

    # Pass 1: record where each kind of annotation sits in NodesConv.
    end_of_main_branch_index = 0
    branch_begin_index = []
    branch_merge_index = []
    # branch first-node name -> index of its "begins" annotation (an int, or a list of
    # ints when the same node starts several branches)
    dict_branch_first_node = {}
    for node_index, node_name in enumerate(NodesConv):
        if isinstance(node_name, str):
            if "begins" in node_name:
                # e.g. '%input.237->%input.240, A branch begins here'
                branch_begin_index.append(node_index)
                branch_first_node_name = node_name[
                    node_name.find("->") + 2 : node_name.find(",")
                ]
                if branch_first_node_name in dict_branch_first_node:
                    # this node already in dictBr1stNode, e.g. DenseNet
                    curr_val = dict_branch_first_node[branch_first_node_name]
                    if isinstance(curr_val, int):
                        dict_branch_first_node[branch_first_node_name] = [
                            curr_val,
                            node_index,
                        ]
                    else:
                        dict_branch_first_node[branch_first_node_name].append(
                            node_index
                        )
                else:
                    dict_branch_first_node[branch_first_node_name] = node_index
            elif "merges" in node_name:
                branch_merge_index.append(node_index)
            elif "End of main" in node_name:
                end_of_main_branch_index = node_index
            else:
                logger.error("Undefined annotations found")

    # Pass 2: move each branch's nodes into the slot of its "begins" annotation.
    if len(branch_begin_index) != len(branch_merge_index):
        logger.error(
            f"Total number of branch_begins {len(branch_begin_index)} and branch_merges "
            f"{len(branch_merge_index)} mismatch"
        )
    else:
        revBrMerge = branch_merge_index[::-1] + [end_of_main_branch_index]
        for idx_st, idx_end in zip(revBrMerge[1:], revBrMerge[:-1]):
            # Only pairs with at least one entry strictly between them are moved.
            if idx_end - idx_st > 1:
                strAnno = NodesConv[idx_end]
                br1stnode = strAnno[
                    : strAnno.find("->")
                ]  # e.g. '%input.248->%1730, A branch merges into'
                # look up dict['%input.248'] to get which line has '%xxxx->%input.248, ....',
                # insert this section of nodes into that line
                index_to_insert = dict_branch_first_node[br1stnode]
                if isinstance(index_to_insert, list):
                    # Several branches start at this node: consume the most recently
                    # recorded "begins" index.
                    index_to_insert = dict_branch_first_node[br1stnode].pop()
                NodesConv[index_to_insert] = NodesConv[idx_st + 1 : idx_end]

    # Keep only what precedes the end of the main branch, dropping leftover annotations.
    new_nodes_conv = [
        node
        for node in NodesConv[:end_of_main_branch_index]
        if not isinstance(node, str)
    ]
    if verbose:
        logger.info("\nAfter insertion of Convs on branches into the list\n")
        for node_name in new_nodes_conv:
            if isinstance(node_name, list):
                logger.info(node_name)
            else:
                logger.info(node_name)
    return new_nodes_conv
def check_activation_dir(curr_node):
"""
Check activation direction of a node in a PyTorch module. The function identifies the
`isActOutUnidir` and `isActOutBounded` attributes of the current node.
Args:
curr_node (Node): The current node in the PyTorch module.
"""
# Corresponds to ReLU, ReLU6, Sigmoid
rectifier_ops: List[str] = [
"aten::relu",
"aten::hardtanh",
"aten::sigmoid",
"aten::softmax",
]
#
bidirectional_ops: List[str] = [
"aten::_conv",
"aten::addmm",
"aten::linear",
"aten::index_put_", # Fills target tensor with a given src tensor based on an idx tensor
"aten::matmul",
]
add_ops: List[str] = ["aten::add(", "aten::add_("]
concat_ops: List[str] = ["prim::ListConstruct"]
# Bounded Ops, note that silu is only bounded on the neg side.
bounded_ops: List[str] = [
"aten::hardtanh",
"aten::sigmoid",
"aten::softmax",
]
# Deterministic cases
if any(rectifier_op in curr_node.Op for rectifier_op in rectifier_ops):
curr_node.isActOutUnidir = True
curr_node.isActOutBounded = bool(
any(Op_i in curr_node.Op for Op_i in bounded_ops)
)
elif any(bidrectional_op in curr_node.Op for bidrectional_op in bidirectional_ops):
curr_node.isActOutUnidir = False
curr_node.isActOutBounded = bool(
any(Op_i in curr_node.Op for Op_i in bounded_ops)