-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathgraph.py
More file actions
138 lines (105 loc) · 3.32 KB
/
graph.py
File metadata and controls
138 lines (105 loc) · 3.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
DeepLabCut Toolbox (deeplabcut.org)
© A. & M. Mathis Labs
Licensed under GNU Lesser General Public License v3.0
"""
# NOTE DUPLICATED @C-Achard 2026-26-01: Duplication between this file
# and dlclive/graph.py
import tensorflow as tf
# For TensorFlow 2.x, or 1.x with a minor version above 12, rebind ``tf`` to
# ``tf.compat.v1`` so the rest of this module uses the TF1-style graph API.
vers = (tf.__version__).split(".")
_tf_major = int(vers[0])
# The minor version is only parsed when the major version is 1.
if _tf_major == 2 or (_tf_major == 1 and int(vers[1]) > 12):
    tf = tf.compat.v1
def read_graph(file):
    """
    Load a graph definition from a serialized protobuf file

    Parameters
    -----------
    file : string
        path to the protobuf file; it is opened in binary mode through
        ``tf.io.gfile.GFile``

    Returns
    --------
    graph_def :class:`tensorflow.compat.v1.GraphDef`
        The graph definition parsed from the contents of the file
    """
    # Read the raw bytes first so the file handle is released before parsing.
    with tf.io.gfile.GFile(file, "rb") as pb_file:
        serialized = pb_file.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def
def finalize_graph(graph_def):
    """
    Import a graph definition into a new graph and finalize it

    Parameters
    -----------
    graph_def :class:`tensorflow.compat.v1.GraphDef`
        The graph definition of the DeepLabCut model, read using the
        :func:`read_graph` method

    Returns
    --------
    graph :class:`tensorflow.Graph`
        A new graph holding the imported definition under the name ``"DLC"``,
        on which ``finalize()`` has been called
    """
    dlc_graph = tf.Graph()
    # Import into the new graph rather than into the current default graph.
    with dlc_graph.as_default():
        tf.import_graph_def(graph_def, name="DLC")
    dlc_graph.finalize()
    return dlc_graph
def get_output_nodes(graph):
    """
    Get the output node names from a graph

    The outputs are taken from the end of ``graph.get_operations()``: when the
    name of the last operation contains ``"concat_1"`` it is the only output,
    otherwise the last two operations are returned, last one first.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output node names as a list of strings (one or two entries)

    Raises
    -------
    IndexError
        if the graph has no operations, or has a single operation whose name
        does not contain ``"concat_1"``
    """
    names = [str(operation.name) for operation in graph.get_operations()]
    final_name = names[-1]
    if "concat_1" in final_name:
        return [final_name]
    return [final_name, names[-2]]
def get_output_tensors(graph):
    """
    Get the names of the output tensors from a graph

    Each output node name reported by :func:`get_output_nodes` is suffixed
    with ``":0"``.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output tensor names as a list of strings
    """
    return [node_name + ":0" for node_name in get_output_nodes(graph)]
def get_input_tensor(graph):
    """
    Get the name of the input tensor from a graph

    The name is built from the first operation returned by
    ``graph.get_operations()``, suffixed with ``":0"``.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    input_tensor : string
        the input tensor name
    """
    first_operation = graph.get_operations()[0]
    return f"{first_operation.name!s}:0"
def extract_graph(
    graph, tf_config=None
) -> tuple[tf.Session, tf.Tensor, list[tf.Tensor]]:
    """
    Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        a tensorflow graph containing the desired model
    tf_config :class:`tensorflow.ConfigProto`
        passed as ``config`` to the session constructor; defaults to None

    Returns
    --------
    sess :class:`tensorflow.Session`
        a tensorflow session created on the specified graph
    inputs :class:`tensorflow.Tensor`
        the input tensor of the model, looked up by the name given by
        :func:`get_input_tensor`
    outputs : list of :class:`tensorflow.Tensor`
        the output tensor(s) of the model, looked up by the names given by
        :func:`get_output_tensors`
    """
    # Resolve every tensor before opening the session: if a lookup fails, the
    # error propagates without leaving an unclosed session behind.
    inputs = graph.get_tensor_by_name(get_input_tensor(graph))
    outputs = [
        graph.get_tensor_by_name(tensor_name)
        for tensor_name in get_output_tensors(graph)
    ]
    sess = tf.Session(graph=graph, config=tf_config)
    return sess, inputs, outputs