1717import threading
1818import types
1919import warnings
20- from collections .abc import Callable , Mapping , Sequence , Iterable
20+ from collections .abc import Callable , Iterable , Mapping , Sequence
2121from contextlib import contextmanager
2222from typing import TYPE_CHECKING , Any , NamedTuple
2323
3434 resolve_callsite_frame ,
3535)
3636from ._device import Device
37+ from ._nvtx import NVTXRange
3738
3839if TYPE_CHECKING :
3940 from ._eval_context import EvalContext
4041 from ._ops import Operator , Reader
4142
4243
44+ def _nvtx_range (message : str ):
45+ return NVTXRange (message , color = 0xB58900 , category = "compile" )
46+
47+
4348class State (enum .Enum ):
4449 TRACING = enum .auto ()
4550 COMPILED = enum .auto ()
@@ -196,6 +201,7 @@ def make_source_batches(self, tensor_lists: Sequence[Any]) -> tuple[CompiledBatc
196201 for i , tl in enumerate (tensor_lists )
197202 )
198203
204+ @_nvtx_range ("Recording operator" )
199205 def record (
200206 self ,
201207 call_chain : CallChain ,
@@ -222,6 +228,7 @@ def record(
222228 self ._call_trie .insert (call_chain , node )
223229 return node
224230
231+ @_nvtx_range ("Building pipeline" )
225232 def build_pipeline (self , ctx : "EvalContext" ) -> None :
226233 if not self .nodes :
227234 warnings .warn (
@@ -268,6 +275,7 @@ def reader_callback():
268275 self .pipeline = pipe
269276 self .state = State .COMPILED
270277
278+ @_nvtx_range ("Running compiled pipeline" )
271279 def run_pipeline (self ) -> tuple | dict :
272280 """Run the compiled pipeline and cache all node results for this iteration."""
273281 assert self .pipeline is not None and self .source is not None
@@ -302,6 +310,7 @@ def _matches(self, actual: Any, expected: Any) -> bool:
302310 return actual is None
303311 return not isinstance (actual , Batch ) and actual == expected
304312
313+ @_nvtx_range ("Getting compiled result" )
305314 def get_compiled_result (
306315 self ,
307316 frame : types .FrameType ,
@@ -424,6 +433,7 @@ def _batches_compiled(self, start_idx: int = 0):
424433 yield batches
425434
426435
436+ @_nvtx_range ("Graph Wiring" )
427437def _wire_compile_graph (
428438 source : CompileSource ,
429439 nodes : Sequence [CompileNode ],
@@ -502,9 +512,8 @@ def _call():
502512 f"called with batch_size={ batch_size } . Cannot change batch_size in "
503513 f"compiled mode."
504514 )
505- cached = compile_ctx .get_compiled_result (frame , inputs , raw_kwargs , device = device )
506- if cached is not None :
507- return cached
515+ if result := compile_ctx .get_compiled_result (frame , inputs , raw_kwargs , device = device ):
516+ return result
508517 return _call ()
509518
510519 # Run first, classify after, we need the result before we can inspect it
0 commit comments