@@ -20,6 +20,8 @@ def compile_tf_graph(
2020 output_serving_key : str | None = None ,
2121 compile_prefix : str | None = None ,
2222 compile_class : str | None = None ,
23+ xla_flags : list [str ] | None = None ,
24+ tf_xla_flags : list [str ] | None = None ,
2325) -> None :
2426 """
2527 For AOT compilation a static memory layout at runtime is required. This function prepares the given input SavedModel
@@ -35,6 +37,8 @@ def compile_tf_graph(
3537
3638 An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case
3739 *compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated files.
40+
41+ *xla_flags* and *tf_xla_flags* are forwarded to :py:func:`aot_compile`.
3842 """
3943 tf = import_tf ()[0 ]
4044
@@ -96,6 +100,8 @@ def compile_tf_graph(
96100 compile_class ,
97101 batch_sizes = batch_sizes ,
98102 serving_key = output_serving_key ,
103+ xla_flags = xla_flags ,
104+ tf_xla_flags = tf_xla_flags ,
99105 )
100106
101107
@@ -106,13 +112,18 @@ def aot_compile(
106112 class_name : str ,
107113 batch_sizes : tuple [int ] = (1 ,),
108114 serving_key : str = r"serving_default_bs{}" ,
115+ xla_flags : list [str ] | None = None ,
116+ tf_xla_flags : list [str ] | None = None ,
109117) -> None :
110118 """
111119 Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key*
112120 from it, AOT compiles it.
113121
114122 This process generates header and object files at *output_path*. The *class_name* is used as class name within the
115123 header access the AOT-compiled network.
124+
125+ When *xla_flags* and *tf_xla_flags* are given, they are forwarded as comma-separated values to the *XLA_FLAGS*
126+ and *TF_XLA_FLAGS* environment variables, respectively.
116127 """
117128 # prepare model path
118129 model_path = os .path .abspath (os .path .expandvars (os .path .expanduser (str (model_path ))))
@@ -131,6 +142,19 @@ def aot_compile(
131142 # get the compilation executable
132143 exe = _which_saved_model_cli ()
133144
145+ # ammend the env when xla flags were passed
146+ env = os .environ .copy ()
147+ if xla_flags :
148+ xla_flags_orig = env .get ("XLA_FLAGS" , "" )
149+ if xla_flags_orig :
150+ xla_flags = [xla_flags_orig .rstrip ("," )] + xla_flags
151+ env ["XLA_FLAGS" ] = "," .join (map (str , xla_flags ))
152+ if tf_xla_flags :
153+ tf_xla_flags_orig = env .get ("TF_XLA_FLAGS" , "" )
154+ if tf_xla_flags_orig :
155+ tf_xla_flags = [tf_xla_flags_orig .rstrip ("," )] + tf_xla_flags
156+ env ["TF_XLA_FLAGS" ] = "," .join (map (str , tf_xla_flags ))
157+
134158 # compile for each batch size
135159 for bs in sorted (set (map (int , batch_sizes ))):
136160 cmd = (
@@ -143,7 +167,7 @@ def aot_compile(
143167 )
144168
145169 print (f"compiling for batch size { colored (bs , 'magenta' )} " )
146- code = interruptable_popen (cmd , executable = "/bin/bash" , shell = True , cwd = output_path )[0 ]
170+ code = interruptable_popen (cmd , executable = "/bin/bash" , shell = True , cwd = output_path , env = env )[0 ]
147171 if code != 0 :
148172 raise Exception (f"aot compilation using { exe } failed with exit code { code } " )
149173
0 commit comments