@@ -144,6 +144,48 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None:
144144 except Exception as e :
145145 self .logger .warning (f"Failed to save logs to { file } : { e } " )
146146
147+ _SUBGRAPH_SAFE_FLAGS = frozenset ({
148+ "stronglyTyped" , "maxTactics" , "sparsity" , "maxAuxStreams" ,
149+ "builderOptimizationLevel" ,
150+ "fp16" , "int8" , "fp8" , "best" , "noTF32" ,
151+ "useCudaGraph" , "useSpinWait" , "profilingVerbosity" , "verbose" ,
152+ "separateProfileRun" , "noDataTransfers" , "dumpProfile" , "dumpLayerInfo" ,
153+ "avgRuns" , "iterations" , "warmUp" , "duration" ,
154+ "saveEngine" , "timingCacheFile" ,
155+ "staticPlugins" , "dynamicPlugins" ,
156+ "noMyelinFusion" , "workspace" , "memPoolSize" ,
157+ "safe" ,
158+ })
159+
160+
161+ def _extract_flag_name (arg : str ) -> str :
162+ """Extract the flag name from a trtexec argument like '--maxTactics=2000'."""
163+ return arg .lstrip ("-" ).split ("=" , 1 )[0 ]
164+
165+
166+ def _filter_subgraph_safe_args (cmd : list [str ]) -> list [str ]:
167+ """Keep only subgraph-safe args from a trtexec command list.
168+
169+ Shape args (--optShapes, --minShapes, --maxShapes) are stripped here because
170+ they are rebuilt per-subgraph by the caller. All other model-specific args
171+ (--loadInputs, --exportOutput, --shapes, etc.) are also excluded.
172+ """
173+ filtered : list [str ] = []
174+ skip_next = False
175+ for c in cmd :
176+ if skip_next :
177+ skip_next = False
178+ continue
179+ if not c .startswith ("-" ):
180+ filtered .append (c )
181+ continue
182+ flag = _extract_flag_name (c )
183+ if flag in _SUBGRAPH_SAFE_FLAGS :
184+ filtered .append (c )
185+ elif "=" not in c :
186+ skip_next = True
187+ return filtered
188+
147189
148190def _dedup_trtexec_args (
149191 base_args : list [tuple [str , str | None ]], user_args : list [str ]
@@ -323,10 +365,11 @@ def run(
323365 cmd = list (self ._base_cmd )
324366
325367 if strip_shape_args :
326- cmd = [
327- c for c in cmd
328- if not any (c .startswith (p ) for p in ("--optShapes" , "--minShapes" , "--maxShapes" ))
329- ]
368+ before = cmd [:]
369+ cmd = _filter_subgraph_safe_args (cmd )
370+ removed = set (before ) - set (cmd )
371+ if removed :
372+ self .logger .debug (f"Subgraph filter removed: { removed } " )
330373
331374 cmd .append (f"--onnx={ model_path } " )
332375
0 commit comments