|
8 | 8 | from dataclasses import dataclass |
9 | 9 | from typing import TYPE_CHECKING, Optional |
10 | 10 | |
11 | | -from cuda.core.experimental._utils.cuda_utils import CUDAError, check_or_create_options, driver, handle_return |
| 11 | +from cuda.core.experimental._utils.cuda_utils import ( |
| 12 | + CUDAError, |
| 13 | + check_or_create_options, |
| 14 | + driver, |
| 15 | + handle_return, |
| 16 | +) |
| 17 | +from cuda.core.experimental._utils.cuda_utils import ( |
| 18 | + _check_driver_error as raise_if_driver_error, |
| 19 | +) |
12 | 20 | |
13 | 21 | if TYPE_CHECKING: |
14 | 22 | import cuda.bindings |
@@ -117,13 +125,31 @@ def __rsub__(self, other): |
117 | 125 | |
118 | 126 | def __sub__(self, other): |
119 | 127 | # return self - other (in milliseconds) |
| 128 | + err, timing = driver.cuEventElapsedTime(other.handle, self.handle) |
120 | 129 | try: |
121 | | - timing = handle_return(driver.cuEventElapsedTime(other.handle, self.handle)) |
| 130 | + raise_if_driver_error(err) |
| 131 | + return timing |
122 | 132 | except CUDAError as e: |
123 | | - raise RuntimeError( |
124 | | - "Timing capability must be enabled in order to subtract two Events; timing is disabled by default." |
125 | | - ) from e |
126 | | - return timing |
| 133 | + if err == driver.CUresult.CUDA_ERROR_INVALID_HANDLE: |
| 134 | + if self.is_timing_disabled or other.is_timing_disabled: |
| 135 | + explanation = ( |
| 136 | + "Both Events must be created with timing enabled in order to subtract them; " |
| 137 | + "use EventOptions(enable_timing=True) when creating both events." |
| 138 | + ) |
| 139 | + else: |
| 140 | + explanation = ( |
| 141 | + "Both Events must be recorded before they can be subtracted; " |
| 142 | + "use Stream.record() to record both events to a stream." |
| 143 | + ) |
| 144 | + elif err == driver.CUresult.CUDA_ERROR_NOT_READY: |
| 145 | + explanation = ( |
| 146 | + "One or both events have not completed; " |
| 147 | + "use Event.sync(), Stream.sync(), or Device.sync() to wait for the events to complete " |
| 148 | + "before subtracting them." |
| 149 | + ) |
| 150 | + else: |
| 151 | + raise e |
| 152 | + raise RuntimeError(explanation) from e |
127 | 153 | |
128 | 154 | @property |
129 | 155 | def is_timing_disabled(self) -> bool: |
@@ -164,5 +190,11 @@ def is_done(self) -> bool: |
164 | 190 | |
165 | 191 | @property |
166 | 192 | def handle(self) -> cuda.bindings.driver.CUevent: |
167 | | - """Return the underlying CUevent object.""" |
| 193 | + """Return the underlying CUevent object. |
| 194 | + |
| 195 | + .. caution:: |
| 196 | + |
| 197 | + This handle is a Python object. To get the memory address of the underlying C |
| 198 | + handle, call ``int(Event.handle)``. |
| 199 | + """ |
168 | 200 | return self._mnff.handle |
0 commit comments