88
99import os
1010
11- from cmsml .util import colored , interruptable_popen
11+ from cmsml .util import colored , interruptable_popen , make_list
1212from cmsml .tensorflow .tools import import_tf , load_model
1313
1414
@@ -20,8 +20,9 @@ def compile_tf_graph(
2020 output_serving_key : str | None = None ,
2121 compile_prefix : str | None = None ,
2222 compile_class : str | None = None ,
23- xla_flags : list [str ] | None = None ,
24- tf_xla_flags : list [str ] | None = None ,
23+ xla_flags : list [str ] | str | None = None ,
24+ tf_xla_flags : list [str ] | str | None = None ,
25+ additional_flags : list [str ] | str | None = None ,
2526) -> None :
2627 """
2728 For AOT compilation a static memory layout at runtime is required. This function prepares the given input SavedModel
@@ -38,7 +39,7 @@ def compile_tf_graph(
3839 An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case
3940 *compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated files.
4041
41- *xla_flags* and *tf_xla_flags * are forwarded to :py:func:`aot_compile`.
42+ *xla_flags*, *tf_xla_flags* and *additional_flags * are forwarded to :py:func:`aot_compile`.
4243 """
4344 tf = import_tf ()[0 ]
4445
@@ -102,6 +103,7 @@ def compile_tf_graph(
102103 serving_key = output_serving_key ,
103104 xla_flags = xla_flags ,
104105 tf_xla_flags = tf_xla_flags ,
106+ additional_flags = additional_flags ,
105107 )
106108
107109
@@ -112,8 +114,9 @@ def aot_compile(
112114 class_name : str ,
113115 batch_sizes : tuple [int ] = (1 ,),
114116 serving_key : str = r"serving_default_bs{}" ,
115- xla_flags : list [str ] | None = None ,
116- tf_xla_flags : list [str ] | None = None ,
117+ xla_flags : list [str ] | str | None = None ,
118+ tf_xla_flags : list [str ] | str | None = None ,
119+ additional_flags : list [str ] | str | None = None ,
117120) -> None :
118121 """
119122 Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key*
@@ -123,7 +126,8 @@ def aot_compile(
123126 header access the AOT-compiled network.
124127
125128 When *xla_flags* and *tf_xla_flags* are given, they are forwarded as comma-separated values to the *XLA_FLAGS*
126- and *TF_XLA_FLAGS* environment variables, respectively.
129+ and *TF_XLA_FLAGS* environment variables, respectively. *additional_flags* are forwarded as is to the underlying
130+ aot compiler invocation.
127131 """
128132 # prepare model path
129133 model_path = os .path .abspath (os .path .expandvars (os .path .expanduser (str (model_path ))))
@@ -145,16 +149,21 @@ def aot_compile(
145149 # ammend the env when xla flags were passed
146150 env = os .environ .copy ()
147151 if xla_flags :
152+ xla_flags = make_list (xla_flags )
148153 xla_flags_orig = env .get ("XLA_FLAGS" , "" )
149154 if xla_flags_orig :
150155 xla_flags = [xla_flags_orig .rstrip ("," )] + xla_flags
151156 env ["XLA_FLAGS" ] = "," .join (map (str , xla_flags ))
152157 if tf_xla_flags :
158+ tf_xla_flags = make_list (tf_xla_flags )
153159 tf_xla_flags_orig = env .get ("TF_XLA_FLAGS" , "" )
154160 if tf_xla_flags_orig :
155161 tf_xla_flags = [tf_xla_flags_orig .rstrip ("," )] + tf_xla_flags
156162 env ["TF_XLA_FLAGS" ] = "," .join (map (str , tf_xla_flags ))
157163
164+ # prepare additional flags
165+ additional_flags_str = " " .join (make_list (additional_flags )) if additional_flags else ""
166+
158167 # compile for each batch size
159168 for bs in sorted (set (map (int , batch_sizes ))):
160169 cmd = (
@@ -164,7 +173,8 @@ def aot_compile(
164173 f" --output_prefix { prefix .format (bs )} "
165174 f" --cpp_class { class_name .format (bs )} "
166175 " --tag_set serve"
167- )
176+ f" { additional_flags_str } "
177+ ).strip ()
168178
169179 print (f"compiling for batch size { colored (bs , 'magenta' )} " )
170180 code = interruptable_popen (cmd , executable = "/bin/bash" , shell = True , cwd = output_path , env = env )[0 ]
@@ -222,17 +232,23 @@ def main() -> None:
222232 )
223233 parser .add_argument (
224234 "--output-serving-key" ,
225- help = r"serving key pattern for concrete models in --output-path, with {} being replaced by "
235+ help = r"serving key pattern for concrete models in --output-path, with {} being replaced by "
226236 r"the batch size; default: <input_serving_key>__bs{}" ,
227237 )
228238 parser .add_argument (
229239 "--compile" ,
230240 "-c" ,
231241 nargs = 2 ,
232- help = r"file name prefix and class name of the AOT compiled objects; in both values, {} is "
242+ help = r"file name prefix and class name of the AOT compiled objects; in both values, {} is "
233243 "replaced by the batch size; no AOT compilation is triggered when empty; files will be "
234244 "saved at <output_path>/aot/<prefix>{.h,.o,_metadata.o,_makefile.inc}" ,
235245 )
246+ parser .add_argument (
247+ "--additional-flags" ,
248+ "-f" ,
249+ help = "additional, space-separated flags to be passed to the underlying aot compiler invocation; "
250+ "for more info, see 'saved_model_cli --helpfull'" ,
251+ )
236252
237253 args = parser .parse_args ()
238254
@@ -244,6 +260,7 @@ def main() -> None:
244260 output_serving_key = args .output_serving_key ,
245261 compile_prefix = args .compile and args .compile [0 ],
246262 compile_class = args .compile and args .compile [1 ],
263+ additional_flags = args .additional_flags ,
247264 )
248265
249266
0 commit comments