Skip to content

Commit 0ab4f98

Browse files
authored
Merge pull request #15 from cms-ml/feature/forward_more_aot_flags
Allow passing arbitrary flags to the aot compiler.
2 parents 63ae87e + 66ed012 commit 0ab4f98

1 file changed

Lines changed: 27 additions & 10 deletions

File tree

cmsml/scripts/compile_tf_graph.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import os
1010

11-
from cmsml.util import colored, interruptable_popen
11+
from cmsml.util import colored, interruptable_popen, make_list
1212
from cmsml.tensorflow.tools import import_tf, load_model
1313

1414

@@ -20,8 +20,9 @@ def compile_tf_graph(
2020
output_serving_key: str | None = None,
2121
compile_prefix: str | None = None,
2222
compile_class: str | None = None,
23-
xla_flags: list[str] | None = None,
24-
tf_xla_flags: list[str] | None = None,
23+
xla_flags: list[str] | str | None = None,
24+
tf_xla_flags: list[str] | str | None = None,
25+
additional_flags: list[str] | str | None = None,
2526
) -> None:
2627
"""
2728
For AOT compilation a static memory layout at runtime is required. This function prepares the given input SavedModel
@@ -38,7 +39,7 @@ def compile_tf_graph(
3839
An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case
3940
*compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated files.
4041
41-
*xla_flags* and *tf_xla_flags* are forwarded to :py:func:`aot_compile`.
42+
*xla_flags*, *tf_xla_flags* and *additional_flags* are forwarded to :py:func:`aot_compile`.
4243
"""
4344
tf = import_tf()[0]
4445

@@ -102,6 +103,7 @@ def compile_tf_graph(
102103
serving_key=output_serving_key,
103104
xla_flags=xla_flags,
104105
tf_xla_flags=tf_xla_flags,
106+
additional_flags=additional_flags,
105107
)
106108

107109

@@ -112,8 +114,9 @@ def aot_compile(
112114
class_name: str,
113115
batch_sizes: tuple[int] = (1,),
114116
serving_key: str = r"serving_default_bs{}",
115-
xla_flags: list[str] | None = None,
116-
tf_xla_flags: list[str] | None = None,
117+
xla_flags: list[str] | str | None = None,
118+
tf_xla_flags: list[str] | str | None = None,
119+
additional_flags: list[str] | str | None = None,
117120
) -> None:
118121
"""
119122
Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key*
@@ -123,7 +126,8 @@ def aot_compile(
123126
header access the AOT-compiled network.
124127
125128
When *xla_flags* and *tf_xla_flags* are given, they are forwarded as comma-separated values to the *XLA_FLAGS*
126-
and *TF_XLA_FLAGS* environment variables, respectively.
129+
and *TF_XLA_FLAGS* environment variables, respectively. *additional_flags* are forwarded as is to the underlying
130+
aot compiler invocation.
127131
"""
128132
# prepare model path
129133
model_path = os.path.abspath(os.path.expandvars(os.path.expanduser(str(model_path))))
@@ -145,16 +149,21 @@ def aot_compile(
145149
# ammend the env when xla flags were passed
146150
env = os.environ.copy()
147151
if xla_flags:
152+
xla_flags = make_list(xla_flags)
148153
xla_flags_orig = env.get("XLA_FLAGS", "")
149154
if xla_flags_orig:
150155
xla_flags = [xla_flags_orig.rstrip(",")] + xla_flags
151156
env["XLA_FLAGS"] = ",".join(map(str, xla_flags))
152157
if tf_xla_flags:
158+
tf_xla_flags = make_list(tf_xla_flags)
153159
tf_xla_flags_orig = env.get("TF_XLA_FLAGS", "")
154160
if tf_xla_flags_orig:
155161
tf_xla_flags = [tf_xla_flags_orig.rstrip(",")] + tf_xla_flags
156162
env["TF_XLA_FLAGS"] = ",".join(map(str, tf_xla_flags))
157163

164+
# prepare additional flags
165+
additional_flags_str = " ".join(make_list(additional_flags)) if additional_flags else ""
166+
158167
# compile for each batch size
159168
for bs in sorted(set(map(int, batch_sizes))):
160169
cmd = (
@@ -164,7 +173,8 @@ def aot_compile(
164173
f" --output_prefix {prefix.format(bs)}"
165174
f" --cpp_class {class_name.format(bs)}"
166175
" --tag_set serve"
167-
)
176+
f" {additional_flags_str}"
177+
).strip()
168178

169179
print(f"compiling for batch size {colored(bs, 'magenta')}")
170180
code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path, env=env)[0]
@@ -222,17 +232,23 @@ def main() -> None:
222232
)
223233
parser.add_argument(
224234
"--output-serving-key",
225-
help=r"serving key pattern for concrete models in --output-path, with {} being replaced by "
235+
help=r"serving key pattern for concrete models in --output-path, with {} being replaced by "
226236
r"the batch size; default: <input_serving_key>__bs{}",
227237
)
228238
parser.add_argument(
229239
"--compile",
230240
"-c",
231241
nargs=2,
232-
help=r"file name prefix and class name of the AOT compiled objects; in both values, {} is "
242+
help=r"file name prefix and class name of the AOT compiled objects; in both values, {} is "
233243
"replaced by the batch size; no AOT compilation is triggered when empty; files will be "
234244
"saved at <output_path>/aot/<prefix>{.h,.o,_metadata.o,_makefile.inc}",
235245
)
246+
parser.add_argument(
247+
"--additional-flags",
248+
"-f",
249+
help="additional, space-separated flags to be passed to the underlying aot compiler invocation; "
250+
"for more info, see 'saved_model_cli --helpfull'",
251+
)
236252

237253
args = parser.parse_args()
238254

@@ -244,6 +260,7 @@ def main() -> None:
244260
output_serving_key=args.output_serving_key,
245261
compile_prefix=args.compile and args.compile[0],
246262
compile_class=args.compile and args.compile[1],
263+
additional_flags=args.additional_flags,
247264
)
248265

249266

0 commit comments

Comments
 (0)