Skip to content

Commit 403d57e

Browse files
rigaBogdan WBogdan-Wiederspan
authored
Add AOT tools (#10)
* added tools to share with MR * added operations checker.py * remove unnecessary file * clean up of methods that should be in another class * add test saved_model * added tensorboard function * added scripts dummies * clean up repo * more cleaning * Added tests and markdown tables for TF2.6.4 and TF1.X * create tools and tests * delete refactor directory structure * Removed many functions. This was possible by finding a way to use signatures to convert tf and keras the same way. * fixed dst path bug * finished unit tests for convert_model.py * Add more description to the functions, also refactor a bit * changed graph to concrete function, to be more clear. removed debugger * made a proper module out of test * Finished writing tests * added test for aot_compatibility * Start adjustments. * Start refactoring. * Add docstrings on all functions * Rename save and load graph functions. * Linting. * moved functions to handle the loading of Saved_models, and extraction of graph_defs from 'aot.py' to 'tools.py' * added tests for 'load_model' and 'load_graph_def' * Added tests for functions regarding aot-utility and aot-compilation scripts * moved functions to handle the loading of Saved_models, and extraction of graph_defs from 'aot.py' to 'tools.py' * added tests for 'load_model' and 'load_graph_def' * Added tests for functions regarding aot-utility and aot-compilation scripts * fixed usage of contextmanager in test_load_graph_def * fix bug when extracting nodes from graphdef of a concretefunction. Also fixed aot unittests * Fixed unit tests for general tensorflow tools * fixed test for aot compilation and fixed most linting issues * Skip some tests if skip_if_no_tf2xla_supported_ops not available. * Adjust tests. * Fix lazy loader tests. * Polish tests. * Move to pytest. * Update installation in images. * Update deps. * Update docker files. * Update type hints. * Polish code. 
--------- Co-authored-by: Bogdan W <bogdan.wiederspan@desy.de> Co-authored-by: Bogdan-Wiederspan <79155113+Bogdan-Wiederspan@users.noreply.github.com> Co-authored-by: Bogdan Wiederspan <b.wiederspan@web.de>
1 parent 1d07dda commit 403d57e

18 files changed

Lines changed: 1270 additions & 73 deletions

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[flake8]
22

3-
max-line-length = 101
3+
max-line-length = 120
44

55
# codes of errors to ignore
66
ignore = E128, E306, E402, E722, E731, E741, W504, Q003

.github/workflows/lint_and_test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
runs-on: ubuntu-latest
1010
steps:
1111
- name: Checkout 🛎️
12-
uses: actions/checkout@v3
12+
uses: actions/checkout@v4
1313
with:
1414
persist-credentials: false
1515

@@ -40,12 +40,12 @@ jobs:
4040
name: test (image=${{ matrix.versions.tag }}, tf=${{ matrix.versions.tf }})
4141
steps:
4242
- name: Checkout 🛎️
43-
uses: actions/checkout@v3
43+
uses: actions/checkout@v4
4444
with:
4545
persist-credentials: false
4646

4747
- name: Pull docker image 🐳
4848
run: docker pull cmsml/cmsml:${{ matrix.versions.tag }}
4949

5050
- name: Test 🎰
51-
run: bash tests/docker.sh cmsml/cmsml:${{ matrix.versions.tag }} "[ '${{ matrix.versions.tf }}' = 'default' ] || pip install -U tensorflow=='${{ matrix.versions.tf }}'; python -m unittest tests"
51+
run: bash tests/docker.sh cmsml/cmsml:${{ matrix.versions.tag }} "[ '${{ matrix.versions.tf }}' = 'default' ] || pip install -U tensorflow=='${{ matrix.versions.tf }}'; pytest -n 2 tests"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ To use the cmsml package via docker, checkout our [DockerHub](https://hub.docker
6262
The tests can be triggered with
6363

6464
```shell
65-
python -m unittest tests
65+
pytest -n auto tests
6666
```
6767

6868
and in general, they should be run for Python 3.7 to 3.11.

cmsml/scripts/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
11
# coding: utf-8
2+
3+
__all__ = ["compile_tf_graph", "aot_compile"]
4+
5+
# provisioning imports
6+
from cmsml.scripts.compile_tf_graph import compile_tf_graph, aot_compile
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# coding: utf-8
2+
3+
"""
4+
Script that provides insight on which TensorFlow operations are XLA / AOT compatible and whether a specified graph would
5+
be supported.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import tabulate
11+
12+
from cmsml.util import colored
13+
from cmsml.tensorflow.aot import OpsData, load_graph_def, get_graph_ops
14+
15+
16+
def check_aot_compatibility(
    model_path: str,
    serving_key: str = "serving_default",
    devices: tuple[str] = ("cpu",),
    table_format: str = "grid",
) -> None:
    """
    Checks the ops of a saved graph against the ops that have an XLA implementation.

    The GraphDef stored under *serving_key* is loaded from the model at *model_path* and its op names are extracted.
    The op table for *devices* is printed in the *table_format* style, restricted to the ops found in the graph, and
    is followed by a one-line summary per device that lists all ops without a truthy entry for that device.
    """
    # ops that are present in virtually every graph and are not worth checking
    trivial_ops = ("Placeholder", "NoOp")

    # load the graph and collect the names of all non-trivial ops
    op_names = []
    for op_name in get_graph_ops(load_graph_def(model_path, serving_key=serving_key)):
        if op_name in trivial_ops:
            continue
        op_names.append(op_name)

    # print the table, which also returns the devices it listed and the op data it is based on
    devices, ops = print_op_table(devices, filter_ops=op_names, table_format=table_format)

    # summarize the outcome for each listed device
    for device in devices:
        header = f"\n{colored(device, 'magenta')}: "

        # an op fails when it is unknown or has no truthy entry for this device
        failed_ops = []
        for op_name in op_names:
            if ops.get(op_name, {}).get(device):
                continue
            failed_ops.append(op_name)

        if failed_ops:
            verdict = colored("not compatible", "red")
            verdict += f", {len(failed_ops)} incompatible ops: {', '.join(failed_ops)}"
        else:
            verdict = colored("all ops compatible", "green")

        print(header + verdict)
54+
55+
56+
def print_op_table(
    devices: tuple[str],
    filter_ops: list[str] | None = None,
    table_format: str = "grid",
) -> tuple[list[str], OpsData]:
    """
    Builds the op data for *devices* and prints it as a table using the tabulate style *table_format*. When
    *filter_ops* is given and not empty, only ops whose names are contained in it are taken into account. The list of
    devices shown as table columns and the op data object are returned in a 2-tuple.
    """
    ops = OpsData(devices)

    def is_selected(op_name):
        # a missing or empty filter selects every op
        return not filter_ops or op_name in filter_ops

    # keep devices that have a truthy entry for at least one selected op
    candidates = []
    for device in ops.device_ids:
        for op_name, op_data in ops.items():
            if is_selected(op_name) and op_data.get(device):
                candidates.append(device)
                break

    # drop duplicates while preserving the order of first occurrence
    devices = list(dict.fromkeys(candidates))

    # one row per selected op, one yes/NO column per device
    headers = ["Operation", *devices]
    content = [
        [op_name] + ["yes" if op_data.get(device) else "NO" for device in devices]
        for op_name, op_data in ops.items()
        if is_selected(op_name)
    ]

    print(tabulate.tabulate(content, headers=headers, tablefmt=table_format))

    return devices, ops
97+
98+
99+
def main() -> None:
    """
    Command line entry point.

    Parses the command line and either prints the op table for the selected devices (``--table``) or runs the
    compatibility check on the model given as the positional ``model_path`` argument. When neither is requested, an
    error is written to stderr and the process exits with status 1.
    """
    import os
    import sys
    from argparse import ArgumentParser

    def parse_devices(value: str) -> tuple:
        # tolerate whitespace around commas and ignore empty entries, e.g. from a trailing comma
        return tuple(device.strip() for device in value.split(",") if device.strip())

    parser = ArgumentParser(
        prog=f"cmsml_{os.path.splitext(os.path.basename(__file__))[0]}",
        description="performs XLA / AOT compatibility checks on a TensorFlow graph",
    )

    # the model path is optional on the parser level since it is not needed in --table mode
    parser.add_argument(
        "model_path",
        nargs="?",
        help="the path of the model to open",
    )
    parser.add_argument(
        "--serving-key",
        "-k",
        default="serving_default",
        help="serving key of the graph in model_path; default: serving_default",
    )
    parser.add_argument(
        "--table",
        "-t",
        action="store_true",
        help="just print a table showing which operations are XLA / AOT supported for --devices",
    )
    parser.add_argument(
        "--table-format",
        "-f",
        default="grid",
        help="the tabulate format for printed tables; default: grid",
    )
    parser.add_argument(
        "--devices",
        "-d",
        type=parse_devices,
        # explicit default so that the behavior matches the help text and check_aot_compatibility's own default
        default=("cpu",),
        help="comma separated list of devices to check; choices: cpu,gpu,tpu, default: cpu",
    )

    args = parser.parse_args()

    if args.table:
        # print the op table
        print_op_table(
            devices=args.devices,
            table_format=args.table_format,
        )

    elif args.model_path:
        # run the compatibility check
        check_aot_compatibility(
            model_path=args.model_path,
            serving_key=args.serving_key,
            devices=args.devices,
            table_format=args.table_format,
        )

    else:
        # model_path is a positional argument, so refer to it as such
        print("either a positional 'model_path' or '--table' must be set", file=sys.stderr)
        sys.exit(1)


# run the command line interface when executed as a script
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)