@@ -51,7 +51,20 @@ class QDQAutotunerBase:
5151 """Base class for pattern-based Q/DQ node insertion optimization in ONNX models."""
5252
5353 def __init__ (self , model : onnx .ModelProto | gs .Graph ):
54- """Initialize the autotuner with an ONNX model."""
54+ """Initialize the autotuner with an ONNX model.
55+
56+ Creates a clean copy of the model graph and initializes internal state.
57+ After construction, call initialize() to configure the autotuner, then
58+ use a subclass strategy to populate regions (e.g., QDQAutotuner does this
59+ automatically during initialize()).
60+
61+ Args:
62+ model: ONNX model (onnx.ModelProto) or graph (gs.Graph) to optimize.
63+ A clean copy is created internally, leaving the original unchanged.
64+
65+ Raises:
66+ TypeError: If model is neither onnx.ModelProto nor gs.Graph
67+ """
5568 if isinstance (model , onnx .ModelProto ):
5669 self .onnx_model = model
5770 elif isinstance (model , gs .Graph ):
@@ -76,7 +89,22 @@ def __init__(self, model: onnx.ModelProto | gs.Graph):
7689 def initialize (
7790 self , config : Config | None = None , pattern_cache : PatternCache | None = None
7891 ) -> None :
79- """Initialize autotuning session with configuration and pattern cache."""
92+ """Initialize autotuning session with configuration and pattern cache.
93+
94+ Prepares the autotuner for profiling by setting configuration parameters
95+ and optionally loading pattern cache data. This base method resets all profiling
96+ state and sets up the pattern cache storage.
97+
98+ Args:
99+ config: Autotuning configuration parameters. If None, uses default Config().
100+ Controls Q/DQ parameters, performance thresholds, and scheme generation.
101+ pattern_cache: Optional PatternCache object for seeding with known-good schemes.
102+ If None, creates a new empty pattern cache for tracking best schemes.
103+ If provided, uses existing schemes to warm-start optimization.
104+
105+ Raises:
106+ None (safe to call multiple times - will reset state each time)
107+ """
80108 if config is not None :
81109 self .config = config
82110
@@ -109,7 +137,24 @@ def initialize(
109137 self .initialized = True
110138
111139 def set_profile_region (self , region : Region | None , commit : bool = True ) -> None :
112- """Set the target region for profiling and scheme generation."""
140+ """Set the target region for profiling and scheme generation.
141+
142+ This method manages the profiling workflow:
143+ 1. If commit=True: Saves current schemes to profiled_patterns
144+ 2. Creates a RegionPattern from the new region's structure
145+ 3. For pattern-based: tries to seed schemes from pattern cache if available
146+ 4. Sets as current for generate() and submit() calls
147+
148+ Pass region=None to clear the current profile target without setting a new one.
149+
150+ Args:
151+ region: The region to profile next (None to clear current target)
152+ commit: If True, commit current schemes to profiled_patterns
153+ before switching. Set to False during initialization.
154+
155+ Raises:
156+ AutotunerNotInitializedError: If initialize() hasn't been called
157+ """
113158 if not self .initialized :
114159 raise AutotunerNotInitializedError (
115160 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -185,13 +230,24 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None
185230
186231 mode_info = f"seeded with { num_seeded } schemes" if num_seeded > 0 else "starting fresh"
187232 logger .info (
188- f"Profiling region { region .id } [pattern mode, level { region .level } , "
189- f"size { region .get_size_of_region_and_descendants ()} , { mode_info } ]"
233+ f"Profiling region { region .id } [level { region .level } , size "
234+ f"{ region .get_size_of_region_and_descendants ()} , { mode_info } ]"
190235 )
191236 logger .debug (f"Pattern signature: { region_pattern .signature } " )
192237
193238 def generate (self ) -> int :
194- """Generate a new Q/DQ insertion scheme for the current pattern or region."""
239+ """Generate a new Q/DQ insertion scheme for the current pattern or region.
240+
241+ Creates a new InsertionScheme by mutating the top-performing schemes:
242+ 1. Checks if there are any cached schemes (error=False, latency_ms=inf)
243+ 2. If cached schemes exist, picks one to re-profile
244+ 3. Otherwise, generates a new scheme by mutation
245+ 4. Selects a random scheme from the top 10 performers
246+ 5. Mutates it by adding/removing insertion points
247+ 6. Ensures the new scheme is unique (different from existing schemes)
248+ 7. Adds the scheme to current_profile_pattern_schemes
249+
250+ """
195251 if not self .initialized :
196252 raise AutotunerNotInitializedError (
197253 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -261,7 +317,28 @@ def generate(self) -> int:
261317 def export_onnx (
262318 self , output_path : str | None = None , insert_qdq : bool = True , best : bool = False
263319 ) -> bytes :
264- """Export ONNX model with Q/DQ nodes inserted according to tested schemes."""
320+ """Export ONNX model with Q/DQ nodes inserted according to tested schemes.
321+
322+ This method creates a modified version of the model by:
323+ 1. For each region, finding the matching pattern
324+ 2. Applying the best scheme for profiled patterns
325+ 3. Applying the current scheme for the active profile pattern
326+ 4. Resolving pattern-relative insertion points to actual tensor names
327+ 5. Inserting Q/DQ pairs at the resolved locations
328+ 6. Converting to FP8 if needed (always creates INT8 first, then converts)
329+
330+ Args:
331+ output_path: Optional file path where the modified ONNX model will be saved.
332+ If None, the model is not saved to disk and only bytes are returned.
333+ insert_qdq: If True, insert Q/DQ nodes. If False, export unmodified model
334+ (useful for baseline measurements)
335+
336+ Returns:
337+ bytes: Serialized ONNX model as bytes
338+
339+ Raises:
340+ AutotunerNotInitializedError: If initialize() hasn't been called
341+ """
265342 if not self .initialized :
266343 raise AutotunerNotInitializedError (
267344 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -387,7 +464,19 @@ def export_onnx(
387464 return model_bytes
388465
389466 def submit (self , latency_ms : float , success : bool = True ) -> None :
390- """Submit performance measurement for the most recently generated scheme."""
467+ """Submit performance measurement for the most recently generated scheme.
468+
469+ This method records the measured latency and manages the optimization state:
470+
471+ Args:
472+ latency_ms: Measured latency in milliseconds (must be > 0)
473+ success: Whether the measurement succeeded. If False, sets scheme.error=True,
474+ logs a warning, and skips speedup calculation.
475+
476+ Raises:
477+ AutotunerNotInitializedError: If initialize() hasn't been called
478+ InvalidSchemeError: If no pattern or region is set, or no schemes have been generated
479+ """
391480 if not self .initialized :
392481 raise AutotunerNotInitializedError (
393482 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -458,7 +547,19 @@ def submit(self, latency_ms: float, success: bool = True) -> None:
458547 )
459548
460549 def save_state (self , output_path : str ) -> None :
461- """Save complete autotuner state to a YAML file for later reuse."""
550+ """Save complete autotuner state to a YAML file for later reuse.
551+
552+ Serializes all optimization results including:
553+ - Baseline latency measurement
554+ - All profiled patterns with their signatures
555+ - All generated schemes with insertion points and latencies
556+ - Configuration parameters
557+ - Current profiling state
558+
559+ Args:
560+ output_path: File path where the YAML state file will be written.
561+ Pattern cache will be saved to <base>_pattern_cache.yaml
562+ """
462563 current_pattern_sig = None
463564 if self .current_profile_pattern_schemes is not None :
464565 current_pattern_sig = self .current_profile_pattern_schemes .pattern_signature
@@ -498,7 +599,20 @@ def save_state(self, output_path: str) -> None:
498599 )
499600
500601 def load_state (self , input_path : str ) -> None :
501- """Load autotuner state from a previously saved YAML file."""
602+ """Load autotuner state from a previously saved YAML file.
603+
604+ Restores optimization results from a previous session:
605+ 1. Matches saved patterns to current model's patterns by signature
606+ 2. Loads all schemes with their insertion points and latencies (including unmeasured ones)
607+ 3. Restores baseline latency and configuration
608+
609+ Args:
610+ input_path: File path to the YAML state file to load
611+
612+ Raises:
613+ AutotunerNotInitializedError: If initialize() hasn't been called
614+ FileNotFoundError: If the input_path doesn't exist
615+ """
502616 if not self .initialized :
503617 raise AutotunerNotInitializedError (
504618 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -571,7 +685,20 @@ def load_state(self, input_path: str) -> None:
571685 logger .debug (f"No pattern cache file at { cache_path } " )
572686
573687 def import_insertion_points (self , quantized_tensors : set [str ] | list [str ]) -> None :
574- """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache."""
688+ """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.
689+
690+ Analyzes the current model's regions against the provided quantized tensors
691+ to extract Q/DQ insertion patterns. For each region, creates a pattern cache
692+ entry that captures which insertion points correspond to the quantized tensors.
693+ These cached patterns can then be used as seeds for future autotuning sessions.
694+
695+ Args:
696+ quantized_tensors: Set or list of tensor names that are quantized
697+ (i.e., tensors that have Q/DQ nodes applied to them)
698+
699+ Raises:
700+ AutotunerNotInitializedError: If initialize() hasn't been called
701+ """
575702 if not self .initialized :
576703 raise AutotunerNotInitializedError (
577704 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -607,7 +734,22 @@ def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> No
607734 def _compute_convergence_metrics (
608735 self , schemes : list [InsertionScheme ], best_scheme : InsertionScheme | None
609736 ) -> tuple [int | None , float | None ]:
610- """Compute convergence metrics for a collection of schemes."""
737+ """Compute convergence metrics for a collection of schemes.
738+
739+ Analyzes when the best scheme was discovered during the profiling process
740+ by sorting schemes by their profile timestamps and finding the position
741+ of the best scheme.
742+
743+ Args:
744+ schemes: List of insertion schemes with profile timestamps
745+ best_scheme: The best performing scheme (lowest latency)
746+
747+ Returns:
748+ Tuple of (samples_before_best, time_to_best) where:
749+ - samples_before_best: Number of samples tested before finding best (0-based index)
750+ - time_to_best: Time in seconds from first sample to best sample
751+ Both values are None if metrics cannot be computed (e.g., missing timestamps)
752+ """
611753 samples_before_best = None
612754 time_to_best = None
613755
@@ -690,7 +832,29 @@ def _mutate_insertion_points(
690832 return [p for p in all_points if key_fn (p ) in current_points ]
691833
692834 def _generate_next_insertion_sample (self ) -> InsertionScheme :
693- """Generate a new insertion scheme by mutating top performers."""
835+ """Generate a new insertion scheme by mutating top performers.
836+
837+ This is the core scheme generation algorithm:
838+ 1. Identifies top schemes by latency
839+ 2. Randomly selects one as the base
840+ 3. Mutates node input insertion points (add, remove, or both)
841+ 4. Mutates region composite insertion points (child boundaries)
842+ 5. Mutates region output insertion points
843+ 6. Returns new unique scheme
844+
845+ **Mutation Strategy:**
846+ - Node input points: Add/remove 1-3 insertion points
847+ - Region composite points: Add/remove 1-3 boundary points
848+ - Region output points: Add/remove 1-3 output points
849+ - Mutation type chosen randomly: 'add', 'remove', or 'both'
850+
851+ **Baseline Case:**
852+ If no schemes exist yet, returns an empty baseline scheme.
853+
854+ Returns:
855+ New InsertionScheme with mutated insertion points.
856+ Returns empty scheme if no region is set or no candidates exist.
857+ """
694858 if self .current_profile_region is None :
695859 return InsertionScheme ()
696860
@@ -891,7 +1055,20 @@ def _create_qdq_nodes(
8911055 quant_dtype : np .dtype ,
8921056 q_scale : float ,
8931057 ) -> tuple [gs .Node , gs .Node ]:
894- """Create QuantizeLinear and DequantizeLinear node pair."""
1058+ """Create QuantizeLinear and DequantizeLinear node pair.
1059+
1060+ Args:
1061+ tensor_name: Name of the tensor being quantized
1062+ qdq_input: Input tensor to the Q node
1063+ output_shape: Shape for Q/DQ outputs (may be None)
1064+ output_dtype: Dtype for DQ output (also used for scale dtype)
1065+ quant_dtype: Dtype for quantized values
1066+ quant_type: Quantization type string
1067+ q_scale: Quantization scale
1068+
1069+ Returns:
1070+ Tuple of (q_node, dq_node)
1071+ """
8951072 # Create unique names for Q/DQ nodes
8961073 q_name = f"QDQ_Q_{ tensor_name } " .replace ("/" , "_" ).replace (":" , "_" )
8971074 dq_name = f"QDQ_DQ_{ tensor_name } " .replace ("/" , "_" ).replace (":" , "_" )
@@ -943,7 +1120,17 @@ def _create_qdq_nodes(
9431120 def _insert_qdq_at_tensors (
9441121 self , graph : gs .Graph , resolved_insertion_points : set [ResolvedInsertionPoint ]
9451122 ) -> None :
946- """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations."""
1123+ """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.
1124+
1125+ This is the main entry point for Q/DQ insertion. It:
1126+ 1. Builds tensor map and tensor-to-users map for efficient lookup
1127+ 2. Processes each resolved insertion point to insert Q/DQ nodes
1128+ 3. Handles two insertion modes based on node_index
1129+
1130+ Args:
1131+ graph: Graph to modify in-place
1132+ resolved_insertion_points: Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ
1133+ """
9471134 q_scale = self .config .default_q_scale
9481135 quant_type = self .config .default_quant_type
9491136 quant_dtype = self ._get_quant_dtype (quant_type )
@@ -1031,12 +1218,28 @@ class QDQAutotuner(QDQAutotunerBase):
10311218 def initialize (
10321219 self , config : Config | None = None , pattern_cache : PatternCache | None = None
10331220 ) -> None :
1034- """Initialize autotuner and discover optimization regions automatically."""
1221+ """Initialize autotuner and discover optimization regions automatically.
1222+
1223+ Extends base class initialization by automatically searching for regions
1224+ after configuration is set up. Regions are discovered using pattern-based
1225+ search around compute-intensive operations.
1226+ """
10351227 super ().initialize (config , pattern_cache )
10361228 self ._search_regions ()
10371229
10381230 def _visit_region_recursively (self , region : Region ) -> list [Region ]:
1039- """Recursively traverse region hierarchy and collect all regions."""
1231+ """Recursively traverse region hierarchy and collect all regions.
1232+
1233+ Performs depth-first traversal of the region tree starting from a given
1234+ region. Collects the root region and all descendant regions (children,
1235+ grandchildren, etc.) into a flat list.
1236+
1237+ Args:
1238+ region: Root region to start traversal from
1239+
1240+ Returns:
1241+ List of all regions in the subtree (including root), in pre-order DFS.
1242+ """
10401243 regions = [region ]
10411244
10421245 for child in region .get_children ():
@@ -1045,7 +1248,15 @@ def _visit_region_recursively(self, region: Region) -> list[Region]:
10451248 return regions
10461249
10471250 def _reassign_region_ids (self , regions : list [Region ]) -> None :
1048- """Reassign sequential IDs to regions in breadth-first order."""
1251+ """Reassign sequential IDs to regions in breadth-first order.
1252+
1253+ Traverses the region hierarchy (including children) and assigns new
1254+ sequential IDs starting from 0. This ensures clean, predictable region
1255+ numbering after region discovery and manipulation.
1256+
1257+ Args:
1258+ regions: List of top-level regions (children will be processed too)
1259+ """
10491260 region_id = 0
10501261
10511262 queue = deque (regions )
@@ -1057,7 +1268,19 @@ def _reassign_region_ids(self, regions: list[Region]) -> None:
10571268 queue .extend (region .get_children ())
10581269
10591270 def _search_regions (self ) -> None :
1060- """Discover and organize optimization regions automatically."""
1271+ """Discover and organize optimization regions automatically.
1272+
1273+ This is the core region discovery method that:
1274+ 1. Runs automatic region search to find optimization targets
1275+ 2. Flattens hierarchical structure into a list
1276+ 3. Prioritizes LEAF regions (contain actual nodes)
1277+ 4. Reassigns IDs for clean indexing
1278+
1279+ **Search Strategy:**
1280+ Uses CombinedRegionSearch which performs:
1281+ - Phase 1: Bottom-up partitioning based on divergence/convergence
1282+ - Phase 2: Top-down refinement creating hierarchical structure
1283+ """
10611284 logger .info ("Discovering optimization regions" )
10621285 search = CombinedRegionSearch (
10631286 self .graph ,
0 commit comments