|
60 | 60 | # |
61 | 61 | # For this tutorial, you'll need: |
62 | 62 | # |
63 | | -# - PyTorch 2.0+ |
| 63 | +# - PyTorch 2.13+ |
64 | 64 | # - A CUDA GPU |
65 | 65 | # - Driver/CUDA-compat >= 13.1 for annotation support |
66 | 66 | # - The ``cuda-bindings`` package >= 13.3.0 (``pip install cuda-python``) |
|
97 | 97 | _move_overlapping_to_stream, |
98 | 98 | ) |
99 | 99 |
|
100 | | -############################################################################### |
101 | | -# Checking CUDA Bindings Version |
102 | | -# ------------------------------- |
103 | | -# |
104 | | -# The annotation APIs require ``cuda-bindings >= 13.3.0`` to access the |
105 | | -# ``cudaGraphNodeGetToolsId`` API. Let's check if the installed version |
106 | | -# supports annotations. |
107 | | - |
108 | | -def check_cuda_bindings_version(): |
109 | | - """Check if cuda-bindings version supports annotations.""" |
110 | | - try: |
111 | | - import cuda.bindings.runtime as runtime |
112 | | - |
113 | | - # Check if the required API is available |
114 | | - has_tools_id = hasattr(runtime, 'cudaGraphNodeGetToolsId') |
115 | | - |
116 | | - if not has_tools_id: |
117 | | - # Try to get the cuda-bindings version |
118 | | - try: |
119 | | - import importlib.metadata |
120 | | - cuda_bindings_version = importlib.metadata.version('cuda-bindings') |
121 | | - except Exception: |
122 | | - cuda_bindings_version = "unknown" |
123 | | - |
124 | | - print("=" * 70) |
125 | | - print("ERROR: CUDA Bindings version too old for annotation support") |
126 | | - print("=" * 70) |
127 | | - print(f"Current cuda-bindings version: {cuda_bindings_version}") |
128 | | - print(f"Required: cuda-bindings >= 13.3.0") |
129 | | - print() |
130 | | - print("The cudaGraphNodeGetToolsId API is not available in your") |
131 | | - print("cuda-bindings installation. This API is required for kernel") |
132 | | - print("annotations to work.") |
133 | | - print() |
134 | | - print("To fix this issue, upgrade cuda-bindings:") |
135 | | - print(" pip install --upgrade cuda-bindings") |
136 | | - print() |
137 | | - print("If you're in a managed environment, you may need:") |
138 | | - print(" pip install --upgrade --break-system-packages cuda-bindings") |
139 | | - print() |
140 | | - print("After upgrading, cuda-bindings 13.3.0+ will include the") |
141 | | - print("cudaGraphNodeGetToolsId API needed for semantic annotations.") |
142 | | - print("=" * 70) |
143 | | - print() |
144 | | - return False |
145 | | - |
146 | | - return True |
147 | | - |
148 | | - except ImportError: |
149 | | - print("=" * 70) |
150 | | - print("ERROR: cuda-bindings not installed") |
151 | | - print("=" * 70) |
152 | | - print("The cuda-bindings package is required for annotation support.") |
153 | | - print() |
154 | | - print("To install it:") |
155 | | - print(" pip install cuda-python") |
156 | | - print() |
157 | | - print("This will install the CUDA Python bindings needed for") |
158 | | - print("kernel annotations.") |
159 | | - print("=" * 70) |
160 | | - print() |
161 | | - return False |
162 | | - |
163 | 100 | ############################################################################### |
164 | 101 | # Building a Model |
165 | 102 | # ---------------- |
@@ -396,20 +333,15 @@ def main(): |
396 | 333 | if not torch.cuda.is_available(): |
397 | 334 | raise SystemExit("CUDA required for this tutorial") |
398 | 335 |
|
399 | | - # Check cuda-bindings version first |
400 | | - if not check_cuda_bindings_version(): |
401 | | - print("WARNING: Continuing without annotation support.") |
402 | | - print("The tutorial will run, but no semantic annotations will be captured.") |
403 | | - print("Only the cleanup passes will organize kernels.\n") |
404 | | - |
405 | 336 | # Check if annotation support is available |
| 337 | + # PyTorch will log a warning if cuda-bindings version is too old |
406 | 338 | supported = not _is_tools_id_unavailable() |
407 | 339 | print(f"Annotation support available: {supported}") |
408 | 340 | if not supported: |
409 | 341 | print("NOTE: Annotation API not available.") |
410 | 342 | print("This could be due to:") |
411 | 343 | print(" - Driver/CUDA-compat < 13.1") |
412 | | - print(" - Outdated cuda-bindings (see error above)") |
| 344 | + print(" - Outdated cuda-bindings (check PyTorch warnings above)") |
413 | 345 | print("Annotations will not be recorded, but the demo will still run.") |
414 | 346 | print("Any lane changes you see are from cleanup passes, not annotations.\n") |
415 | 347 |
|
|
0 commit comments