|
8 | 8 | from typing import Any, Final, NamedTuple, TypeAlias, cast |
9 | 9 |
|
10 | 10 | from dltype._lib import _constants, _dtypes, _errors, _log_utils, _parser, _tensor_type_base |
| 11 | +from dltype._lib import _dependency_utilities as _deps |
11 | 12 |
|
| 13 | +if _deps.is_torch_available(): |
| 14 | + from torch.jit import TracerWarning # pyright: ignore[reportPrivateImportUsage] |
| 15 | +else: |
| 16 | + |
| 17 | + class _NullWarning(Warning): |
| 18 | + pass |
| 19 | + |
| 20 | + TracerWarning = _NullWarning |
12 | 21 | _logger: Final = _log_utils.get_logger(__name__) |
13 | 22 |
|
14 | 23 | EvaluatedDimensionT: TypeAlias = dict[str, int] |
@@ -123,43 +132,48 @@ def assert_context(self) -> None: |
123 | 132 | """Considering the current context, check if all tensors match their expected types.""" |
124 | 133 | __tracebackhide__ = not _constants.DEBUG_MODE |
125 | 134 |
|
126 | | - start_t = time.perf_counter_ns() |
| 135 | + with warnings.catch_warnings(): |
| 136 | + warnings.simplefilter(category=TracerWarning, action="ignore") |
127 | 137 |
|
128 | | - try: |
129 | | - while self._hinted_tensors: |
130 | | - tensor_context = self._hinted_tensors.popleft() |
131 | | - # first check if the tensor could possibly have the right shape |
132 | | - tensor_context.dltype_annotation.check( |
133 | | - tensor_context.tensor, |
134 | | - tensor_name=tensor_context.tensor_arg_name, |
135 | | - ) |
| 138 | + start_t = time.perf_counter_ns() |
136 | 139 |
|
137 | | - if tensor_context.tensor_arg_name in self.registered_tensor_dtypes: |
138 | | - raise _errors.DLTypeDuplicateError( |
| 140 | + try: |
| 141 | + while self._hinted_tensors: |
| 142 | + tensor_context = self._hinted_tensors.popleft() |
| 143 | + # first check if the tensor could possibly have the right shape |
| 144 | + tensor_context.dltype_annotation.check( |
| 145 | + tensor_context.tensor, |
139 | 146 | tensor_name=tensor_context.tensor_arg_name, |
140 | 147 | ) |
141 | 148 |
|
142 | | - self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype |
143 | | - expected_shape = tensor_context.get_expected_shape( |
144 | | - tensor_context.tensor, |
145 | | - ) |
146 | | - self._assert_tensor_shape( |
147 | | - tensor_context.tensor_arg_name, |
148 | | - expected_shape, |
149 | | - tensor_context.tensor, |
150 | | - ) |
| 149 | + if tensor_context.tensor_arg_name in self.registered_tensor_dtypes: |
| 150 | + raise _errors.DLTypeDuplicateError( |
| 151 | + tensor_name=tensor_context.tensor_arg_name, |
| 152 | + ) |
151 | 153 |
|
152 | | - finally: |
153 | | - end_t = time.perf_counter_ns() |
154 | | - runtime_ns = end_t - start_t |
155 | | - _logger.debug("Context evaluation took %d ns", runtime_ns) |
156 | | - if _maybe_warn_runtime(runtime_ns): |
157 | | - max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6 |
158 | | - warnings.warn( |
159 | | - f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms", |
160 | | - UserWarning, |
161 | | - stacklevel=2, |
162 | | - ) |
| 154 | + self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = ( |
| 155 | + tensor_context.tensor.dtype |
| 156 | + ) |
| 157 | + expected_shape = tensor_context.get_expected_shape( |
| 158 | + tensor_context.tensor, |
| 159 | + ) |
| 160 | + self._assert_tensor_shape( |
| 161 | + tensor_context.tensor_arg_name, |
| 162 | + expected_shape, |
| 163 | + tensor_context.tensor, |
| 164 | + ) |
| 165 | + |
| 166 | + finally: |
| 167 | + end_t = time.perf_counter_ns() |
| 168 | + runtime_ns = end_t - start_t |
| 169 | + _logger.debug("Context evaluation took %d ns", runtime_ns) |
| 170 | + if _maybe_warn_runtime(runtime_ns): |
| 171 | + max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6 |
| 172 | + warnings.warn( |
| 173 | + f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms", |
| 174 | + UserWarning, |
| 175 | + stacklevel=2, |
| 176 | + ) |
163 | 177 |
|
164 | 178 | def _assert_tensor_shape( |
165 | 179 | self, |
|
0 commit comments