Skip to content

Commit b78da96

Browse files
committed
Add NVTX ranges
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent cee195c commit b78da96

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

  • dali/python/nvidia/dali/experimental/dynamic

dali/python/nvidia/dali/experimental/dynamic/_compile.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import threading
1818
import types
1919
import warnings
20-
from collections.abc import Callable, Mapping, Sequence, Iterable
20+
from collections.abc import Callable, Iterable, Mapping, Sequence
2121
from contextlib import contextmanager
2222
from typing import TYPE_CHECKING, Any, NamedTuple
2323

@@ -34,12 +34,17 @@
3434
resolve_callsite_frame,
3535
)
3636
from ._device import Device
37+
from ._nvtx import NVTXRange
3738

3839
if TYPE_CHECKING:
3940
from ._eval_context import EvalContext
4041
from ._ops import Operator, Reader
4142

4243

44+
def _nvtx_range(message: str):
45+
return NVTXRange(message, color=0xB58900, category="compile")
46+
47+
4348
class State(enum.Enum):
4449
TRACING = enum.auto()
4550
COMPILED = enum.auto()
@@ -196,6 +201,7 @@ def make_source_batches(self, tensor_lists: Sequence[Any]) -> tuple[CompiledBatc
196201
for i, tl in enumerate(tensor_lists)
197202
)
198203

204+
@_nvtx_range("Recording operator")
199205
def record(
200206
self,
201207
call_chain: CallChain,
@@ -222,6 +228,7 @@ def record(
222228
self._call_trie.insert(call_chain, node)
223229
return node
224230

231+
@_nvtx_range("Building pipeline")
225232
def build_pipeline(self, ctx: "EvalContext") -> None:
226233
if not self.nodes:
227234
warnings.warn(
@@ -268,6 +275,7 @@ def reader_callback():
268275
self.pipeline = pipe
269276
self.state = State.COMPILED
270277

278+
@_nvtx_range("Running compiled pipeline")
271279
def run_pipeline(self) -> tuple | dict:
272280
"""Run the compiled pipeline and cache all node results for this iteration."""
273281
assert self.pipeline is not None and self.source is not None
@@ -302,6 +310,7 @@ def _matches(self, actual: Any, expected: Any) -> bool:
302310
return actual is None
303311
return not isinstance(actual, Batch) and actual == expected
304312

313+
@_nvtx_range("Getting compiled result")
305314
def get_compiled_result(
306315
self,
307316
frame: types.FrameType,
@@ -424,6 +433,7 @@ def _batches_compiled(self, start_idx: int = 0):
424433
yield batches
425434

426435

436+
@_nvtx_range("Graph Wiring")
427437
def _wire_compile_graph(
428438
source: CompileSource,
429439
nodes: Sequence[CompileNode],
@@ -502,9 +512,8 @@ def _call():
502512
f"called with batch_size={batch_size}. Cannot change batch_size in "
503513
f"compiled mode."
504514
)
505-
cached = compile_ctx.get_compiled_result(frame, inputs, raw_kwargs, device=device)
506-
if cached is not None:
507-
return cached
515+
if result := compile_ctx.get_compiled_result(frame, inputs, raw_kwargs, device=device):
516+
return result
508517
return _call()
509518

510519
# Run first, classify after, we need the result before we can inspect it

0 commit comments

Comments
 (0)