Skip to content

Commit 359f070

Browse files
committed
Remove redundant version check and update requirements
- Removed check_cuda_bindings_version() function since PyTorch core now provides the warning via _probe_tools_id() - Updated PyTorch requirement from 2.0+ to 2.13+ (required for the annotation APIs used in this tutorial) - Simplified error messaging to reference PyTorch's built-in warnings
1 parent d9b296c commit 359f070

1 file changed

Lines changed: 3 additions & 71 deletions

File tree

advanced_source/cuda_graph_annotations_tutorial.py

Lines changed: 3 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
#
6161
# For this tutorial, you'll need:
6262
#
63-
# - PyTorch 2.0+
63+
# - PyTorch 2.13+
6464
# - A CUDA GPU
6565
# - Driver/CUDA-compat >= 13.1 for annotation support
6666
# - The ``cuda-bindings`` package >= 13.3.0 (``pip install cuda-python``)
@@ -97,69 +97,6 @@
9797
_move_overlapping_to_stream,
9898
)
9999

100-
###############################################################################
101-
# Checking CUDA Bindings Version
102-
# -------------------------------
103-
#
104-
# The annotation APIs require ``cuda-bindings >= 13.3.0`` to access the
105-
# ``cudaGraphNodeGetToolsId`` API. Let's check if the installed version
106-
# supports annotations.
107-
108-
def check_cuda_bindings_version():
109-
"""Check if cuda-bindings version supports annotations."""
110-
try:
111-
import cuda.bindings.runtime as runtime
112-
113-
# Check if the required API is available
114-
has_tools_id = hasattr(runtime, 'cudaGraphNodeGetToolsId')
115-
116-
if not has_tools_id:
117-
# Try to get the cuda-bindings version
118-
try:
119-
import importlib.metadata
120-
cuda_bindings_version = importlib.metadata.version('cuda-bindings')
121-
except Exception:
122-
cuda_bindings_version = "unknown"
123-
124-
print("=" * 70)
125-
print("ERROR: CUDA Bindings version too old for annotation support")
126-
print("=" * 70)
127-
print(f"Current cuda-bindings version: {cuda_bindings_version}")
128-
print(f"Required: cuda-bindings >= 13.3.0")
129-
print()
130-
print("The cudaGraphNodeGetToolsId API is not available in your")
131-
print("cuda-bindings installation. This API is required for kernel")
132-
print("annotations to work.")
133-
print()
134-
print("To fix this issue, upgrade cuda-bindings:")
135-
print(" pip install --upgrade cuda-bindings")
136-
print()
137-
print("If you're in a managed environment, you may need:")
138-
print(" pip install --upgrade --break-system-packages cuda-bindings")
139-
print()
140-
print("After upgrading, cuda-bindings 13.3.0+ will include the")
141-
print("cudaGraphNodeGetToolsId API needed for semantic annotations.")
142-
print("=" * 70)
143-
print()
144-
return False
145-
146-
return True
147-
148-
except ImportError:
149-
print("=" * 70)
150-
print("ERROR: cuda-bindings not installed")
151-
print("=" * 70)
152-
print("The cuda-bindings package is required for annotation support.")
153-
print()
154-
print("To install it:")
155-
print(" pip install cuda-python")
156-
print()
157-
print("This will install the CUDA Python bindings needed for")
158-
print("kernel annotations.")
159-
print("=" * 70)
160-
print()
161-
return False
162-
163100
###############################################################################
164101
# Building a Model
165102
# ----------------
@@ -396,20 +333,15 @@ def main():
396333
if not torch.cuda.is_available():
397334
raise SystemExit("CUDA required for this tutorial")
398335

399-
# Check cuda-bindings version first
400-
if not check_cuda_bindings_version():
401-
print("WARNING: Continuing without annotation support.")
402-
print("The tutorial will run, but no semantic annotations will be captured.")
403-
print("Only the cleanup passes will organize kernels.\n")
404-
405336
# Check if annotation support is available
337+
# PyTorch will log a warning if cuda-bindings version is too old
406338
supported = not _is_tools_id_unavailable()
407339
print(f"Annotation support available: {supported}")
408340
if not supported:
409341
print("NOTE: Annotation API not available.")
410342
print("This could be due to:")
411343
print(" - Driver/CUDA-compat < 13.1")
412-
print(" - Outdated cuda-bindings (see error above)")
344+
print(" - Outdated cuda-bindings (check PyTorch warnings above)")
413345
print("Annotations will not be recorded, but the demo will still run.")
414346
print("Any lane changes you see are from cleanup passes, not annotations.\n")
415347

0 commit comments

Comments
 (0)