
Commit bd18dfa

pick back docstrings
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent 98d7f55 commit bd18dfa

1 file changed: 242 additions & 19 deletions

File tree
modelopt/onnx/quantization/autotune/autotuner.py
@@ -51,7 +51,20 @@ class QDQAutotunerBase:
     """Base class for pattern-based Q/DQ node insertion optimization in ONNX models."""

     def __init__(self, model: onnx.ModelProto | gs.Graph):
-        """Initialize the autotuner with an ONNX model."""
+        """Initialize the autotuner with an ONNX model.
+
+        Creates a clean copy of the model graph and initializes internal state.
+        After construction, call initialize() to configure the autotuner, then
+        use a subclass strategy to populate regions (e.g., QDQAutotuner does this
+        automatically during initialize()).
+
+        Args:
+            model: ONNX model (onnx.ModelProto) or graph (gs.Graph) to optimize.
+                A clean copy is created internally, leaving the original unchanged.
+
+        Raises:
+            TypeError: If model is neither onnx.ModelProto nor gs.Graph
+        """
         if isinstance(model, onnx.ModelProto):
             self.onnx_model = model
         elif isinstance(model, gs.Graph):
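
Construction per the docstring above, as a minimal sketch; the import path follows the file tree at the top of this commit, and QDQAutotuner is the concrete subclass shown at the bottom of this diff:

import onnx
from modelopt.onnx.quantization.autotune.autotuner import QDQAutotuner

model = onnx.load("model.onnx")  # onnx.ModelProto; a gs.Graph works too
tuner = QDQAutotuner(model)      # a clean copy is made, `model` stays unchanged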
@@ -76,7 +89,22 @@ def __init__(self, model: onnx.ModelProto | gs.Graph):
     def initialize(
         self, config: Config | None = None, pattern_cache: PatternCache | None = None
     ) -> None:
-        """Initialize autotuning session with configuration and pattern cache."""
+        """Initialize autotuning session with configuration and pattern cache.
+
+        Prepares the autotuner for profiling by setting configuration parameters
+        and optionally loading pattern cache data. This base method resets all profiling
+        state and sets up the pattern cache storage.
+
+        Args:
+            config: Autotuning configuration parameters. If None, uses default Config().
+                Controls Q/DQ parameters, performance thresholds, and scheme generation.
+            pattern_cache: Optional PatternCache object for seeding with known-good schemes.
+                If None, creates a new empty pattern cache for tracking best schemes.
+                If provided, uses existing schemes to warm-start optimization.
+
+        Note:
+            Safe to call multiple times; profiling state is reset on each call.
+        """
         if config is not None:
             self.config = config

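Session setup per the Args above; a sketch that assumes Config and PatternCache are exported by the same module (their fields are not shown in this diff):

from modelopt.onnx.quantization.autotune.autotuner import Config, PatternCache

tuner.initialize()  # defaults: Config() plus a new, empty PatternCache
# Warm-start from a cache produced by an earlier session (placeholder value):
warm_cache: PatternCache = ...  # e.g., restored alongside a saved state file
tuner.initialize(config=Config(), pattern_cache=warm_cache)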
@@ -109,7 +137,24 @@ def initialize(
         self.initialized = True

     def set_profile_region(self, region: Region | None, commit: bool = True) -> None:
-        """Set the target region for profiling and scheme generation."""
+        """Set the target region for profiling and scheme generation.
+
+        This method manages the profiling workflow:
+        1. If commit=True: Saves current schemes to profiled_patterns
+        2. Creates a RegionPattern from the new region's structure
+        3. Seeds schemes from the pattern cache if a matching pattern is available
+        4. Sets as current for generate() and submit() calls
+
+        Pass region=None to clear the current profile target without setting a new one.
+
+        Args:
+            region: The region to profile next (None to clear current target)
+            commit: If True, commit current schemes to profiled_patterns
+                before switching. Set to False during initialization.
+
+        Raises:
+            AutotunerNotInitializedError: If initialize() hasn't been called
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
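Target selection per the workflow above, continuing the sketch; the `regions` attribute used to iterate discovered regions is an assumption, not confirmed by this diff:

for region in tuner.regions:          # assumed container of discovered regions
    tuner.set_profile_region(region)  # commit=True folds prior schemes into profiled_patterns
    # ... generate/submit loop for this region goes here (see the next hunks)
tuner.set_profile_region(None)        # clear the target without choosing a new one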
@@ -185,13 +230,24 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None

         mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh"
         logger.info(
-            f"Profiling region {region.id} [pattern mode, level {region.level}, "
-            f"size {region.get_size_of_region_and_descendants()}, {mode_info}]"
+            f"Profiling region {region.id} [level {region.level}, size "
+            f"{region.get_size_of_region_and_descendants()}, {mode_info}]"
         )
         logger.debug(f"Pattern signature: {region_pattern.signature}")

     def generate(self) -> int:
-        """Generate a new Q/DQ insertion scheme for the current pattern or region."""
+        """Generate a new Q/DQ insertion scheme for the current pattern or region.
+
+        Creates a new InsertionScheme, either by re-profiling a cached scheme or
+        by mutating the top-performing schemes:
+        1. Checks for cached schemes that are still unmeasured (error=False, latency_ms=inf)
+        2. If cached schemes exist, picks one to re-profile
+        3. Otherwise, selects a random scheme from the top 10 performers
+        4. Mutates it by adding/removing insertion points
+        5. Ensures the new scheme is unique (different from existing schemes)
+        6. Adds the scheme to current_profile_pattern_schemes
+
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
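The generate()/submit() contract above wired into a trial loop; measure_latency_ms is a placeholder benchmark the caller must supply, and the trial budget is arbitrary:

for _ in range(32):
    tuner.generate()                   # re-profiles a cached scheme or mutates a top performer
    model_bytes = tuner.export_onnx()  # current scheme applied (see export_onnx below)
    try:
        tuner.submit(measure_latency_ms(model_bytes))
    except RuntimeError:
        tuner.submit(float("inf"), success=False)  # marks scheme.error=True, skips speedup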
@@ -261,7 +317,28 @@ def generate(self) -> int:
     def export_onnx(
         self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False
     ) -> bytes:
-        """Export ONNX model with Q/DQ nodes inserted according to tested schemes."""
+        """Export ONNX model with Q/DQ nodes inserted according to tested schemes.
+
+        This method creates a modified version of the model by:
+        1. For each region, finding the matching pattern
+        2. Applying the best scheme for profiled patterns
+        3. Applying the current scheme for the active profile pattern
+        4. Resolving pattern-relative insertion points to actual tensor names
+        5. Inserting Q/DQ pairs at the resolved locations
+        6. Converting to FP8 if needed (always creates INT8 first, then converts)
+
+        Args:
+            output_path: Optional file path where the modified ONNX model will be saved.
+                If None, the model is not saved to disk and only bytes are returned.
+            insert_qdq: If True, insert Q/DQ nodes. If False, export unmodified model
+                (useful for baseline measurements)
+
+        Returns:
+            bytes: Serialized ONNX model as bytes
+
+        Raises:
+            AutotunerNotInitializedError: If initialize() hasn't been called
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
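Export variants implied by the signature above; the `best` parameter is not covered by the Args list, so the reading here (prefer the best known scheme over the current one) is inferred from steps 2-3 and should be treated as an assumption:

baseline = tuner.export_onnx(insert_qdq=False)              # unmodified model, for baseline timing
trial = tuner.export_onnx()                                 # current scheme for the active pattern
tuner.export_onnx(output_path="model_qdq.onnx", best=True)  # best schemes, also written to disk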
@@ -387,7 +464,19 @@ def export_onnx(
         return model_bytes

     def submit(self, latency_ms: float, success: bool = True) -> None:
-        """Submit performance measurement for the most recently generated scheme."""
+        """Submit performance measurement for the most recently generated scheme.
+
+        This method records the measured latency and updates the optimization state.
+
+        Args:
+            latency_ms: Measured latency in milliseconds (must be > 0)
+            success: Whether the measurement succeeded. If False, sets scheme.error=True,
+                logs a warning, and skips speedup calculation.
+
+        Raises:
+            AutotunerNotInitializedError: If initialize() hasn't been called
+            InvalidSchemeError: If no pattern or region is set, or no schemes have been generated
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -458,7 +547,19 @@ def submit(self, latency_ms: float, success: bool = True) -> None:
         )

     def save_state(self, output_path: str) -> None:
-        """Save complete autotuner state to a YAML file for later reuse."""
+        """Save complete autotuner state to a YAML file for later reuse.
+
+        Serializes all optimization results including:
+        - Baseline latency measurement
+        - All profiled patterns with their signatures
+        - All generated schemes with insertion points and latencies
+        - Configuration parameters
+        - Current profiling state
+
+        Args:
+            output_path: File path where the YAML state file will be written.
+                Pattern cache will be saved to <base>_pattern_cache.yaml
+        """
         current_pattern_sig = None
         if self.current_profile_pattern_schemes is not None:
             current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature
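
Persistence sketch; the companion cache file name follows the <base>_pattern_cache.yaml convention quoted above:

tuner.save_state("autotune_state.yaml")
# Also writes autotune_state_pattern_cache.yaml next to the state file.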
@@ -498,7 +599,20 @@ def save_state(self, output_path: str) -> None:
         )

     def load_state(self, input_path: str) -> None:
-        """Load autotuner state from a previously saved YAML file."""
+        """Load autotuner state from a previously saved YAML file.
+
+        Restores optimization results from a previous session:
+        1. Matches saved patterns to current model's patterns by signature
+        2. Loads all schemes with their insertion points and latencies (including unmeasured ones)
+        3. Restores baseline latency and configuration
+
+        Args:
+            input_path: File path to the YAML state file to load
+
+        Raises:
+            AutotunerNotInitializedError: If initialize() hasn't been called
+            FileNotFoundError: If the input_path doesn't exist
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
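Resuming in a fresh process; note the ordering constraint from the Raises section (initialize() must come first):

resumed = QDQAutotuner(onnx.load("model.onnx"))
resumed.initialize()
resumed.load_state("autotune_state.yaml")  # saved patterns matched to this model by signature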
@@ -571,7 +685,20 @@ def load_state(self, input_path: str) -> None:
             logger.debug(f"No pattern cache file at {cache_path}")

     def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None:
-        """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache."""
+        """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.
+
+        Analyzes the current model's regions against the provided quantized tensors
+        to extract Q/DQ insertion patterns. For each region, creates a pattern cache
+        entry that captures which insertion points correspond to the quantized tensors.
+        These cached patterns can then be used as seeds for future autotuning sessions.
+
+        Args:
+            quantized_tensors: Set or list of tensor names that are quantized
+                (i.e., tensors that have Q/DQ nodes applied to them)
+
+        Raises:
+            AutotunerNotInitializedError: If initialize() hasn't been called
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
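One way to build the quantized_tensors argument from a model that already carries Q/DQ nodes; this collection logic is illustrative, not part of the autotuner API:

quantized = onnx.load("model_already_quantized.onnx")
tensors = {
    node.input[0]  # the first input of QuantizeLinear is the tensor being quantized
    for node in quantized.graph.node
    if node.op_type == "QuantizeLinear"
}
tuner.import_insertion_points(tensors)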
@@ -607,7 +734,22 @@ def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None:
     def _compute_convergence_metrics(
         self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None
     ) -> tuple[int | None, float | None]:
-        """Compute convergence metrics for a collection of schemes."""
+        """Compute convergence metrics for a collection of schemes.
+
+        Analyzes when the best scheme was discovered during the profiling process
+        by sorting schemes by their profile timestamps and finding the position
+        of the best scheme.
+
+        Args:
+            schemes: List of insertion schemes with profile timestamps
+            best_scheme: The best performing scheme (lowest latency)
+
+        Returns:
+            Tuple of (samples_before_best, time_to_best) where:
+            - samples_before_best: Number of samples tested before finding best (0-based index)
+            - time_to_best: Time in seconds from first sample to best sample
+            Both values are None if metrics cannot be computed (e.g., missing timestamps)
+        """
         samples_before_best = None
         time_to_best = None

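A standalone sketch of the computation this docstring describes, assuming each scheme exposes a profile_timestamp attribute (the field name is inferred from the wording above):

def convergence_metrics(schemes, best_scheme):
    timed = sorted(
        (s for s in schemes if getattr(s, "profile_timestamp", None) is not None),
        key=lambda s: s.profile_timestamp,
    )
    if best_scheme is None or best_scheme not in timed:
        return None, None
    samples_before_best = timed.index(best_scheme)  # 0-based position in profiling order
    time_to_best = best_scheme.profile_timestamp - timed[0].profile_timestamp
    return samples_before_best, time_to_best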
@@ -690,7 +832,29 @@ def _mutate_insertion_points(
         return [p for p in all_points if key_fn(p) in current_points]

     def _generate_next_insertion_sample(self) -> InsertionScheme:
-        """Generate a new insertion scheme by mutating top performers."""
+        """Generate a new insertion scheme by mutating top performers.
+
+        This is the core scheme generation algorithm:
+        1. Identifies top schemes by latency
+        2. Randomly selects one as the base
+        3. Mutates node input insertion points (add, remove, or both)
+        4. Mutates region composite insertion points (child boundaries)
+        5. Mutates region output insertion points
+        6. Returns new unique scheme
+
+        **Mutation Strategy:**
+        - Node input points: Add/remove 1-3 insertion points
+        - Region composite points: Add/remove 1-3 boundary points
+        - Region output points: Add/remove 1-3 output points
+        - Mutation type chosen randomly: 'add', 'remove', or 'both'
+
+        **Baseline Case:**
+        If no schemes exist yet, returns an empty baseline scheme.
+
+        Returns:
+            New InsertionScheme with mutated insertion points.
+            Returns empty scheme if no region is set or no candidates exist.
+        """
         if self.current_profile_region is None:
             return InsertionScheme()

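The add/remove/both mutation from the strategy list above, sketched over plain sets; real insertion points are richer objects, so this is illustrative only:

import random

def mutate_points(current: set, candidates: set) -> set:
    mutated = set(current)
    mode = random.choice(["add", "remove", "both"])
    if mode in ("add", "both"):
        pool = list(candidates - mutated)
        mutated.update(random.sample(pool, min(len(pool), random.randint(1, 3))))
    if mode in ("remove", "both") and mutated:
        drop = random.sample(list(mutated), min(len(mutated), random.randint(1, 3)))
        mutated.difference_update(drop)
    return mutated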
@@ -891,7 +1055,20 @@ def _create_qdq_nodes(
         quant_dtype: np.dtype,
         q_scale: float,
     ) -> tuple[gs.Node, gs.Node]:
-        """Create QuantizeLinear and DequantizeLinear node pair."""
+        """Create QuantizeLinear and DequantizeLinear node pair.
+
+        Args:
+            tensor_name: Name of the tensor being quantized
+            qdq_input: Input tensor to the Q node
+            output_shape: Shape for Q/DQ outputs (may be None)
+            output_dtype: Dtype for DQ output (also used for scale dtype)
+            quant_dtype: Dtype for quantized values
+            quant_type: Quantization type string
+            q_scale: Quantization scale
+
+        Returns:
+            Tuple of (q_node, dq_node)
+        """
         # Create unique names for Q/DQ nodes
         q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_")
         dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_")
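
What the finished pair looks like in onnx-graphsurgeon terms; a sketch using the standard QuantizeLinear input layout (tensor, scale, zero_point) with INT8 hard-coded for brevity, not code copied from this file:

import numpy as np
import onnx_graphsurgeon as gs

def make_qdq_pair(qdq_input: gs.Variable, q_scale: float) -> tuple[gs.Node, gs.Node]:
    name = qdq_input.name
    scale = gs.Constant(f"{name}_scale", values=np.array(q_scale, dtype=np.float32))
    zero_point = gs.Constant(f"{name}_zp", values=np.array(0, dtype=np.int8))
    q_out = gs.Variable(f"{name}_quantized", dtype=np.int8)
    dq_out = gs.Variable(f"{name}_dequantized", dtype=np.float32)
    q = gs.Node("QuantizeLinear", f"QDQ_Q_{name}",
                inputs=[qdq_input, scale, zero_point], outputs=[q_out])
    dq = gs.Node("DequantizeLinear", f"QDQ_DQ_{name}",
                 inputs=[q_out, scale, zero_point], outputs=[dq_out])
    return q, dq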
@@ -943,7 +1120,17 @@ def _create_qdq_nodes(
     def _insert_qdq_at_tensors(
         self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint]
     ) -> None:
-        """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations."""
+        """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.
+
+        This is the main entry point for Q/DQ insertion. It:
+        1. Builds tensor map and tensor-to-users map for efficient lookup
+        2. Processes each resolved insertion point to insert Q/DQ nodes
+        3. Handles two insertion modes based on node_index
+
+        Args:
+            graph: Graph to modify in-place
+            resolved_insertion_points: Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ
+        """
         q_scale = self.config.default_q_scale
         quant_type = self.config.default_quant_type
         quant_dtype = self._get_quant_dtype(quant_type)
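
The whole-tensor insertion mode reduces to rewiring every consumer onto the DQ output; a sketch built on make_qdq_pair above, ignoring the node_index-specific mode mentioned in the docstring:

def insert_qdq_for_tensor(graph: gs.Graph, tensor: gs.Variable, q_scale: float) -> None:
    q, dq = make_qdq_pair(tensor, q_scale)
    dq_out = dq.outputs[0]
    for node in graph.nodes:  # move every consumer of `tensor` onto the DQ output
        node.inputs = [dq_out if t is tensor else t for t in node.inputs]
    graph.nodes.extend([q, dq])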
@@ -1031,12 +1218,28 @@ class QDQAutotuner(QDQAutotunerBase):
     def initialize(
         self, config: Config | None = None, pattern_cache: PatternCache | None = None
     ) -> None:
-        """Initialize autotuner and discover optimization regions automatically."""
+        """Initialize autotuner and discover optimization regions automatically.
+
+        Extends base class initialization by automatically searching for regions
+        after configuration is set up. Regions are discovered using pattern-based
+        search around compute-intensive operations.
+        """
         super().initialize(config, pattern_cache)
         self._search_regions()

     def _visit_region_recursively(self, region: Region) -> list[Region]:
-        """Recursively traverse region hierarchy and collect all regions."""
+        """Recursively traverse region hierarchy and collect all regions.
+
+        Performs depth-first traversal of the region tree starting from a given
+        region. Collects the root region and all descendant regions (children,
+        grandchildren, etc.) into a flat list.
+
+        Args:
+            region: Root region to start traversal from
+
+        Returns:
+            List of all regions in the subtree (including root), in pre-order DFS.
+        """
         regions = [region]

         for child in region.get_children():
@@ -1045,7 +1248,15 @@ def _visit_region_recursively(self, region: Region) -> list[Region]:
         return regions

     def _reassign_region_ids(self, regions: list[Region]) -> None:
-        """Reassign sequential IDs to regions in breadth-first order."""
+        """Reassign sequential IDs to regions in breadth-first order.
+
+        Traverses the region hierarchy (including children) and assigns new
+        sequential IDs starting from 0. This ensures clean, predictable region
+        numbering after region discovery and manipulation.
+
+        Args:
+            regions: List of top-level regions (children will be processed too)
+        """
         region_id = 0

         queue = deque(regions)
@@ -1057,7 +1268,19 @@ def _reassign_region_ids(self, regions: list[Region]) -> None:
             queue.extend(region.get_children())

     def _search_regions(self) -> None:
-        """Discover and organize optimization regions automatically."""
+        """Discover and organize optimization regions automatically.
+
+        This is the core region discovery method that:
+        1. Runs automatic region search to find optimization targets
+        2. Flattens hierarchical structure into a list
+        3. Prioritizes LEAF regions (contain actual nodes)
+        4. Reassigns IDs for clean indexing
+
+        **Search Strategy:**
+        Uses CombinedRegionSearch which performs:
+        - Phase 1: Bottom-up partitioning based on divergence/convergence
+        - Phase 2: Top-down refinement creating hierarchical structure
+        """
         logger.info("Discovering optimization regions")
         search = CombinedRegionSearch(
             self.graph,
