Skip to content

Commit 89df2ec

Browse files
authored
Merge pull request #14 from cms-ml/feature/passthrough_xla_flags
Forward XLA_FLAGS and TF_XLA_FLAGS in AOT compilation.
2 parents 8b24692 + b5a33b1 commit 89df2ec

1 file changed

Lines changed: 25 additions & 1 deletion

File tree

cmsml/scripts/compile_tf_graph.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def compile_tf_graph(
2020
output_serving_key: str | None = None,
2121
compile_prefix: str | None = None,
2222
compile_class: str | None = None,
23+
xla_flags: list[str] | None = None,
24+
tf_xla_flags: list[str] | None = None,
2325
) -> None:
2426
"""
2527
For AOT compilation a static memory layout at runtime is required. This function prepares the given input SavedModel
@@ -35,6 +37,8 @@ def compile_tf_graph(
3537
3638
An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case
3739
*compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated files.
40+
41+
*xla_flags* and *tf_xla_flags* are forwarded to :py:func:`aot_compile`.
3842
"""
3943
tf = import_tf()[0]
4044

@@ -96,6 +100,8 @@ def compile_tf_graph(
96100
compile_class,
97101
batch_sizes=batch_sizes,
98102
serving_key=output_serving_key,
103+
xla_flags=xla_flags,
104+
tf_xla_flags=tf_xla_flags,
99105
)
100106

101107

@@ -106,13 +112,18 @@ def aot_compile(
106112
class_name: str,
107113
batch_sizes: tuple[int] = (1,),
108114
serving_key: str = r"serving_default_bs{}",
115+
xla_flags: list[str] | None = None,
116+
tf_xla_flags: list[str] | None = None,
109117
) -> None:
110118
"""
111119
Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key*
112120
from it, AOT compiles it.
113121
114122
This process generates header and object files at *output_path*. The *class_name* is used as class name within the
115123
header to access the AOT-compiled network.
124+
125+
When *xla_flags* and *tf_xla_flags* are given, they are forwarded as comma-separated values to the *XLA_FLAGS*
126+
and *TF_XLA_FLAGS* environment variables, respectively.
116127
"""
117128
# prepare model path
118129
model_path = os.path.abspath(os.path.expandvars(os.path.expanduser(str(model_path))))
@@ -131,6 +142,19 @@ def aot_compile(
131142
# get the compilation executable
132143
exe = _which_saved_model_cli()
133144

145+
# amend the env when xla flags were passed
146+
env = os.environ.copy()
147+
if xla_flags:
148+
xla_flags_orig = env.get("XLA_FLAGS", "")
149+
if xla_flags_orig:
150+
xla_flags = [xla_flags_orig.rstrip(",")] + xla_flags
151+
env["XLA_FLAGS"] = ",".join(map(str, xla_flags))
152+
if tf_xla_flags:
153+
tf_xla_flags_orig = env.get("TF_XLA_FLAGS", "")
154+
if tf_xla_flags_orig:
155+
tf_xla_flags = [tf_xla_flags_orig.rstrip(",")] + tf_xla_flags
156+
env["TF_XLA_FLAGS"] = ",".join(map(str, tf_xla_flags))
157+
134158
# compile for each batch size
135159
for bs in sorted(set(map(int, batch_sizes))):
136160
cmd = (
@@ -143,7 +167,7 @@ def aot_compile(
143167
)
144168

145169
print(f"compiling for batch size {colored(bs, 'magenta')}")
146-
code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path)[0]
170+
code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path, env=env)[0]
147171
if code != 0:
148172
raise Exception(f"aot compilation using {exe} failed with exit code {code}")
149173

0 commit comments

Comments
 (0)